diff --git a/README.md b/README.md index 611b5ca..cf51484 100644 --- a/README.md +++ b/README.md @@ -422,6 +422,22 @@ await prompt({ }); ``` +#### `prompt.stream(message, options)` + +Works the same way as `prompt()`, but resolves with a `stream` key instead of a `message` key. + +```js +import { prompt } from "@copilot-extensions/preview-sdk"; + +const { requestId, stream } = prompt.stream("What is the capital of France?", { + token: process.env.TOKEN, +}); + +for await (const chunk of stream) { + console.log(new TextDecoder().decode(chunk)); +} +``` + ### `getFunctionCalls()` Convenience metthod if a result from a `prompt()` call includes function calls. diff --git a/index.d.ts b/index.d.ts index 96c3483..22b8870 100644 --- a/index.d.ts +++ b/index.d.ts @@ -175,7 +175,7 @@ export interface OpenAICompatibilityPayload { export interface CopilotMessage { role: string; content: string; - copilot_references: MessageCopilotReference[]; + copilot_references?: MessageCopilotReference[]; copilot_confirmations?: MessageCopilotConfirmation[]; tool_calls?: { function: { @@ -300,8 +300,9 @@ export interface PromptFunction { } export type PromptOptions = { - model?: ModelName; token: string; + endpoint?: string; + model?: ModelName; tools?: PromptFunction[]; messages?: InteropMessage[]; request?: { @@ -314,12 +315,25 @@ export type PromptResult = { message: CopilotMessage; }; +export type PromptStreamResult = { + requestId: string; + stream: ReadableStream; +}; + // https://stackoverflow.com/a/69328045 type WithRequired = T & { [P in K]-?: T[P] }; interface PromptInterface { (userPrompt: string, options: PromptOptions): Promise; (options: WithRequired): Promise; + stream: PromptStreamInterface; +} + +interface PromptStreamInterface { + (userPrompt: string, options: PromptOptions): Promise; + ( + options: WithRequired, + ): Promise; } interface GetFunctionCallsInterface { diff --git a/index.test-d.ts b/index.test-d.ts index 61897ca..844b63a 100644 --- a/index.test-d.ts +++ b/index.test-d.ts @@ -302,7 +302,6 @@ export function getUserConfirmationTest(payload: CopilotRequestPayload) { export async function promptTest() { const result = await prompt("What is the capital of France?", { - model: "gpt-4", token: "secret", }); @@ -311,7 +310,6 @@ export async function promptTest() { // with custom fetch await prompt("What is the capital of France?", { - model: "gpt-4", token: "secret", request: { fetch: () => {}, @@ -327,7 +325,6 @@ export async function promptTest() { export async function promptWithToolsTest() { await prompt("What is the capital of France?", { - model: "gpt-4", token: "secret", tools: [ { @@ -366,6 +363,24 @@ export async function promptWithoutMessageButMessages() { }); } +export async function otherPromptOptionsTest() { + const result = await prompt("What is the capital of France?", { + token: "secret", + model: "gpt-4", + endpoint: "https://api.githubcopilot.com", + }); +} + +export async function promptStreamTest() { + const result = await prompt.stream("What is the capital of France?", { + model: "gpt-4", + token: "secret", + }); + + expectType(result.requestId); + expectType>(result.stream); +} + export async function getFunctionCallsTest( promptResponsePayload: PromptResult, ) { diff --git a/lib/prompt.js b/lib/prompt.js index 7d90ed7..6dd4ff7 100644 --- a/lib/prompt.js +++ b/lib/prompt.js @@ -1,15 +1,20 @@ // @ts-check /** @type {import('..').PromptInterface} */ -export async function prompt(userPrompt, promptOptions) { - const options = typeof userPrompt === "string" ? promptOptions : userPrompt; - const promptFetch = options.request?.fetch || fetch; - const modelName = options.model || "gpt-4"; +function parsePromptArguments(userPrompt, promptOptions) { + const { request: requestOptions, ...options } = + typeof userPrompt === "string" ? promptOptions : userPrompt; + + const promptFetch = requestOptions?.fetch || fetch; + const model = options.model || "gpt-4"; + const endpoint = + options.endpoint || "https://api.githubcopilot.com/chat/completions"; const systemMessage = options.tools ? "You are a helpful assistant. Use the supplied tools to assist the user." : "You are a helpful assistant."; + const toolsChoice = options.tools ? "auto" : undefined; const messages = [ { @@ -29,44 +34,87 @@ export async function prompt(userPrompt, promptOptions) { }); } - const response = await promptFetch( - "https://api.githubcopilot.com/chat/completions", - { - method: "POST", - headers: { - accept: "application/json", - "content-type": "application/json; charset=UTF-8", - "user-agent": "copilot-extensions/preview-sdk.js", - authorization: `Bearer ${options.token}`, - }, - body: JSON.stringify({ - messages: messages, - model: modelName, - toolChoice: options.tools ? "auto" : undefined, - tools: options.tools, - }), - } - ); + return [promptFetch, { ...options, messages, model, endpoint, toolsChoice }]; +} - if (response.ok) { - const data = await response.json(); +async function sendPromptRequest(promptFetch, options) { + const { endpoint, token, ...payload } = options; + const method = "POST"; + const headers = { + accept: "application/json", + "content-type": "application/json; charset=UTF-8", + "user-agent": "copilot-extensions/preview-sdk.js", + authorization: `Bearer ${token}`, + }; - return { - requestId: response.headers.get("x-request-id"), - message: data.choices[0].message, - }; + const response = await promptFetch(endpoint, { + method, + headers, + body: JSON.stringify(payload), + }); + + if (response.ok) { + return response; } + const body = await response.text(); + console.log({ body }); + + throw Object.assign( + new Error( + `[@copilot-extensions/preview-sdk] An error occured with the chat completions API`, + ), + { + name: "PromptError", + request: { + method: "POST", + url: endpoint, + headers: { + ...headers, + authorization: `Bearer [REDACTED]`, + }, + body: payload, + }, + response: { + status: response.status, + headers: [...response.headers], + body: body, + }, + }, + ); +} +export async function prompt(userPrompt, promptOptions) { + const [promptFetch, options] = parsePromptArguments( + userPrompt, + promptOptions, + ); + const response = await sendPromptRequest(promptFetch, options); const requestId = response.headers.get("x-request-id"); + + const data = await response.json(); + return { - requestId: requestId, - message: { - role: "Sssistant", - content: `Sorry, an error occured with the chat completions API. (Status: ${response.status}, request ID: ${requestId})`, - }, + requestId, + message: data.choices[0].message, }; } +prompt.stream = async function promptStream(userPrompt, promptOptions) { + const [promptFetch, options] = parsePromptArguments( + userPrompt, + promptOptions, + ); + const response = await sendPromptRequest(promptFetch, { + ...options, + stream: true, + }); + + return { + requestId: response.headers.get("x-request-id"), + stream: response.body, + }; +}; + /** @type {import('..').GetFunctionCallsInterface} */ export function getFunctionCalls(payload) { const functionCalls = payload.message.tool_calls; diff --git a/test/prompt.test.js b/test/prompt.test.js index 576893c..7820464 100644 --- a/test/prompt.test.js +++ b/test/prompt.test.js @@ -1,4 +1,5 @@ import { test, suite } from "node:test"; +import assert from "node:assert"; import { MockAgent } from "undici"; @@ -84,13 +85,13 @@ suite("prompt", () => { method: "post", path: `/chat/completions`, body: JSON.stringify({ + model: "", messages: [ { role: "system", content: "You are a helpful assistant." }, { role: "user", content: "What is the capital of France?" }, { role: "assistant", content: "The capital of France is Paris." }, { role: "user", content: "What about Spain?" }, ], - model: "", }), }) .reply( @@ -190,6 +191,67 @@ suite("prompt", () => { }); }); + test("options.endpoint", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://my-copilot-endpoint.test"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + body: JSON.stringify({ + messages: [ + { + role: "system", + content: "You are a helpful assistant.", + }, + { + role: "user", + content: "What is the capital of France?", + }, + ], + model: "gpt-4", + }), + }) + .reply( + 200, + { + choices: [ + { + message: { + content: "", + }, + }, + ], + }, + { + headers: { + "content-type": "application/json", + "x-request-id": "", + }, + }, + ); + + const result = await prompt("What is the capital of France?", { + token: "secret", + endpoint: "https://my-copilot-endpoint.test/chat/completions", + request: { fetch: fetchMock }, + }); + + t.assert.deepEqual(result, { + requestId: "", + message: { + content: "", + }, + }); + }); + test("single options argument", async (t) => { const mockAgent = new MockAgent(); function fetchMock(url, opts) { @@ -266,6 +328,12 @@ suite("prompt", () => { method: "post", path: `/chat/completions`, body: JSON.stringify({ + tools: [ + { + type: "function", + function: { name: "the_function", description: "The function" }, + }, + ], messages: [ { role: "system", @@ -275,13 +343,7 @@ suite("prompt", () => { { role: "user", content: "Call the function" }, ], model: "gpt-4", - toolChoice: "auto", - tools: [ - { - type: "function", - function: { name: "the_function", description: "The function" }, - }, - ], + toolsChoice: "auto", }), }) .reply( @@ -360,19 +422,51 @@ suite("prompt", () => { }, }); - const result = await prompt("What is the capital of France?", { - token: "secret", - request: { fetch: fetchMock }, - }); - - t.assert.deepEqual(result, { - message: { - content: - "Sorry, an error occured with the chat completions API. (Status: 400, request ID: )", - role: "Sssistant", + await assert.rejects( + async () => { + await prompt("What is the capital of France?", { + token: "secret", + request: { fetch: fetchMock }, + }); }, - requestId: "", - }); + { + name: "PromptError", + message: + "[@copilot-extensions/preview-sdk] An error occured with the chat completions API", + request: { + method: "POST", + url: "https://api.githubcopilot.com/chat/completions", + headers: { + "content-type": "application/json; charset=UTF-8", + "user-agent": "copilot-extensions/preview-sdk.js", + accept: "application/json", + authorization: "Bearer [REDACTED]", + }, + body: { + messages: [ + { + content: "You are a helpful assistant.", + role: "system", + }, + { + content: "What is the capital of France?", + role: "user", + }, + ], + model: "gpt-4", + toolsChoice: undefined, + }, + }, + response: { + status: 400, + headers: [ + ["content-type", "text/plain"], + ["x-request-id", ""], + ], + body: "Bad Request", + }, + }, + ); }); suite("getFunctionCalls()", () => {