diff --git a/README.md b/README.md index e9a6863..df00d0a 100644 --- a/README.md +++ b/README.md @@ -321,7 +321,34 @@ const { message } = await prompt("What is the capital of France?", { console.log(message.content); ``` -⚠️ Not all of the arguments below are implemented yet. +In order to pass a history of messages, pass them as `options.messages`: + +```js +const { message } = await prompt("What about Spain?", { + model: "gpt-4", + token: process.env.TOKEN, + messages: [ + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + ], +}); +``` + +Alternatively, skip the `message` argument and pass all messages as `options.messages`: + +```js +const { message } = await prompt({ + model: "gpt-4", + token: process.env.TOKEN, + messages: [ + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + { role: "user", content: "What about Spain?" }, + ], +}); +``` + +⚠️ Not all of the arguments below are implemented yet. See [#5](https://github.com/copilot-extensions/preview-sdk.js/issues/5) sub issues for progress. ```js await prompt({ diff --git a/index.d.ts b/index.d.ts index f41811e..84d8edf 100644 --- a/index.d.ts +++ b/index.d.ts @@ -75,9 +75,7 @@ type ResponseEvent = type CopilotAckResponseEventData = { choices: [{ - delta: { - content: "", role: "assistant" - } + delta: InteropMessage<"assistant"> }] } @@ -92,9 +90,7 @@ type CopilotDoneResponseEventData = { type CopilotTextResponseEventData = { choices: [{ - delta: { - content: string, role: "assistant" - } + delta: InteropMessage<"assistant"> }] } type CopilotConfirmationResponseEventData = { @@ -134,7 +130,7 @@ interface CopilotReference { export interface CopilotRequestPayload { copilot_thread_id: string - messages: Message[] + messages: CopilotMessage[] stop: any top_p: number temperature: number @@ -146,14 +142,10 @@ export interface CopilotRequestPayload { } export interface OpenAICompatibilityPayload { - messages: { - role: string - name?: string - content: string - }[] + messages: InteropMessage[] } -export interface Message { +export interface CopilotMessage { role: string content: string copilot_references: MessageCopilotReference[] @@ -167,6 +159,14 @@ export interface Message { "type": "function" }[] name?: string + [key: string]: unknown +} + +export interface InteropMessage { + role: TRole + content: string + name?: string + [key: string]: unknown } export interface MessageCopilotReference { @@ -254,10 +254,23 @@ export interface GetUserConfirmationInterface { // prompt -/** model names supported by Copilot API */ +/** + * model names supported by Copilot API + * + * Based on https://api.githubcopilot.com/models from 2024-09-02 + */ export type ModelName = - | "gpt-4" | "gpt-3.5-turbo" + | "gpt-3.5-turbo-0613" + | "gpt-4" + | "gpt-4-0613" + | "gpt-4-o-preview" + | "gpt-4o" + | "gpt-4o-2024-05-13" + | "text-embedding-3-small" + | "text-embedding-3-small-inference" + | "text-embedding-ada-002" + | "text-embedding-ada-002-index" export interface PromptFunction { type: "function" @@ -274,6 +287,7 @@ export type PromptOptions = { model: ModelName token: string tools?: PromptFunction[] + messages?: InteropMessage[] request?: { fetch?: Function } @@ -281,11 +295,15 @@ export type PromptOptions = { export type PromptResult = { requestId: string - message: Message + message: CopilotMessage } +// https://stackoverflow.com/a/69328045 +type WithRequired = T & { [P in K]-?: T[P] } + interface PromptInterface { (userPrompt: string, options: PromptOptions): Promise; + (options: WithRequired): Promise; } // exported methods diff --git a/index.test-d.ts b/index.test-d.ts index 267d21c..0ddf2ef 100644 --- a/index.test-d.ts +++ b/index.test-d.ts @@ -17,6 +17,7 @@ import { getUserMessage, getUserConfirmation, type VerificationPublicKey, + type InteropMessage, CopilotRequestPayload, prompt, } from "./index.js"; @@ -79,11 +80,10 @@ export function createAckEventTest() { expectType<() => string>(event.toString); expectType(event.toString()); + expectType<{ choices: [{ - delta: { - content: "", role: "assistant" - } + delta: InteropMessage<"assistant"> }] }>(event.data); @@ -98,9 +98,7 @@ export function createTextEventTest() { expectType<{ choices: [{ - delta: { - content: string, role: "assistant" - } + delta: InteropMessage<"assistant"> }] }>(event.data); @@ -243,6 +241,7 @@ export function transformPayloadForOpenAICompatibilityTest(payload: CopilotReque content: string; role: string; name?: string + [key: string]: unknown }[] } >(result); @@ -307,12 +306,33 @@ export async function promptWithToolsTest() { function: { name: "", description: "", - parameters: { - - }, + parameters: {}, strict: true, } } ] }) +} + +export async function promptWithMessageAndMessages() { + await prompt("What about Spain?", { + model: "gpt-4", + token: 'secret', + messages: [ + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + ], + }); +} + +export async function promptWithoutMessageButMessages() { + await prompt({ + model: "gpt-4", + token: 'secret', + messages: [ + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + { role: "user", content: "What about Spain?" }, + ], + }); } \ No newline at end of file diff --git a/lib/prompt.js b/lib/prompt.js index 77bbab9..e7c9db0 100644 --- a/lib/prompt.js +++ b/lib/prompt.js @@ -2,12 +2,32 @@ /** @type {import('..').PromptInterface} */ export async function prompt(userPrompt, promptOptions) { - const promptFetch = promptOptions.request?.fetch || fetch; + const options = typeof userPrompt === "string" ? promptOptions : userPrompt; - const systemMessage = promptOptions.tools + const promptFetch = options.request?.fetch || fetch; + + const systemMessage = options.tools ? "You are a helpful assistant. Use the supplied tools to assist the user." : "You are a helpful assistant."; + const messages = [ + { + role: "system", + content: systemMessage, + }, + ]; + + if (options.messages) { + messages.push(...options.messages); + } + + if (typeof userPrompt === "string") { + messages.push({ + role: "user", + content: userPrompt, + }); + } + const response = await promptFetch( "https://api.githubcopilot.com/chat/completions", { @@ -16,22 +36,13 @@ export async function prompt(userPrompt, promptOptions) { accept: "application/json", "content-type": "application/json; charset=UTF-8", "user-agent": "copilot-extensions/preview-sdk.js", - authorization: `Bearer ${promptOptions.token}`, + authorization: `Bearer ${options.token}`, }, body: JSON.stringify({ - messages: [ - { - role: "system", - content: systemMessage, - }, - { - role: "user", - content: userPrompt, - }, - ], - model: promptOptions.model, - toolChoice: promptOptions.tools ? "auto" : undefined, - tools: promptOptions.tools, + messages: messages, + model: options.model, + toolChoice: options.tools ? "auto" : undefined, + tools: options.tools, }), } ); diff --git a/test/prompt.test.js b/test/prompt.test.js index fb8d723..2b83e12 100644 --- a/test/prompt.test.js +++ b/test/prompt.test.js @@ -1,145 +1,270 @@ -import { test } from "node:test"; +import { test, suite } from "node:test"; import { MockAgent } from "undici"; import { prompt } from "../index.js"; -test("smoke", (t) => { - t.assert.equal(typeof prompt, "function"); -}); +suite("prompt", () => { + test("smoke", (t) => { + t.assert.equal(typeof prompt, "function"); + }); -test("minimal usage", async (t) => { - const mockAgent = new MockAgent(); - function fetchMock(url, opts) { - opts ||= {}; - opts.dispatcher = mockAgent; - return fetch(url, opts); - } - - mockAgent.disableNetConnect(); - const mockPool = mockAgent.get("https://api.githubcopilot.com"); - mockPool - .intercept({ - method: "post", - path: `/chat/completions`, - body: JSON.stringify({ - messages: [ - { - role: "system", - content: "You are a helpful assistant.", - }, - { - role: "user", - content: "What is the capital of France?", - }, - ], - model: "gpt-4", - }), - }) - .reply( - 200, - { - choices: [ - { - message: { - content: "", + test("minimal usage", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://api.githubcopilot.com"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + body: JSON.stringify({ + messages: [ + { + role: "system", + content: "You are a helpful assistant.", }, + { + role: "user", + content: "What is the capital of France?", + }, + ], + model: "gpt-4", + }), + }) + .reply( + 200, + { + choices: [ + { + message: { + content: "", + }, + }, + ], + }, + { + headers: { + "content-type": "application/json", + "x-request-id": "", }, - ], + } + ); + + const result = await prompt("What is the capital of France?", { + token: "secret", + model: "gpt-4", + request: { fetch: fetchMock }, + }); + + t.assert.deepEqual(result, { + requestId: "", + message: { + content: "", }, - { - headers: { - "content-type": "application/json", - "x-request-id": "", + }); + }); + + test("options.messages", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://api.githubcopilot.com"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + body: JSON.stringify({ + messages: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + { role: "user", content: "What about Spain?" }, + ], + model: "gpt-4", + }), + }) + .reply( + 200, + { + choices: [ + { + message: { + content: "", + }, + }, + ], }, - } - ); + { + headers: { + "content-type": "application/json", + "x-request-id": "", + }, + } + ); - const result = await prompt("What is the capital of France?", { - token: "secret", - model: "gpt-4", - request: { fetch: fetchMock }, - }); + const result = await prompt("What about Spain?", { + model: "gpt-4", + token: "secret", + messages: [ + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + ], + request: { fetch: fetchMock }, + }); - t.assert.deepEqual(result, { - requestId: "", - message: { - content: "", - }, + t.assert.deepEqual(result, { + requestId: "", + message: { + content: "", + }, + }); }); -}); -test("function calling", async (t) => { - const mockAgent = new MockAgent(); - function fetchMock(url, opts) { - opts ||= {}; - opts.dispatcher = mockAgent; - return fetch(url, opts); - } - - mockAgent.disableNetConnect(); - const mockPool = mockAgent.get("https://api.githubcopilot.com"); - mockPool - .intercept({ - method: "post", - path: `/chat/completions`, - body: JSON.stringify({ - messages: [ - { - role: "system", - content: - "You are a helpful assistant. Use the supplied tools to assist the user.", - }, - { role: "user", content: "Call the function" }, - ], - model: "gpt-4", - toolChoice: "auto", - tools: [ - { - type: "function", - function: { name: "the_function", description: "The function" }, - }, - ], - }), - }) - .reply( - 200, - { - choices: [ - { - message: { - content: "", + test("single options argument", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://api.githubcopilot.com"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + body: JSON.stringify({ + messages: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + { role: "user", content: "What about Spain?" }, + ], + model: "gpt-4", + }), + }) + .reply( + 200, + { + choices: [ + { + message: { + content: "", + }, }, + ], + }, + { + headers: { + "content-type": "application/json", + "x-request-id": "", }, - ], + } + ); + + const result = await prompt({ + model: "gpt-4", + token: "secret", + messages: [ + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + { role: "user", content: "What about Spain?" }, + ], + request: { fetch: fetchMock }, + }); + + t.assert.deepEqual(result, { + requestId: "", + message: { + content: "", }, - { - headers: { - "content-type": "application/json", - "x-request-id": "", + }); + }); + + test("function calling", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://api.githubcopilot.com"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + body: JSON.stringify({ + messages: [ + { + role: "system", + content: + "You are a helpful assistant. Use the supplied tools to assist the user.", + }, + { role: "user", content: "Call the function" }, + ], + model: "gpt-4", + toolChoice: "auto", + tools: [ + { + type: "function", + function: { name: "the_function", description: "The function" }, + }, + ], + }), + }) + .reply( + 200, + { + choices: [ + { + message: { + content: "", + }, + }, + ], }, - } - ); - - const result = await prompt("Call the function", { - token: "secret", - model: "gpt-4", - tools: [ - { - type: "function", - function: { - name: "the_function", - description: "The function", + { + headers: { + "content-type": "application/json", + "x-request-id": "", + }, + } + ); + + const result = await prompt("Call the function", { + token: "secret", + model: "gpt-4", + tools: [ + { + type: "function", + function: { + name: "the_function", + description: "The function", + }, }, - }, - ], - request: { fetch: fetchMock }, - }); + ], + request: { fetch: fetchMock }, + }); - t.assert.deepEqual(result, { - requestId: "", - message: { - content: "", - }, + t.assert.deepEqual(result, { + requestId: "", + message: { + content: "", + }, + }); }); });