From be9c476f3146cacee43f5258fd816b4625a0f4ab Mon Sep 17 00:00:00 2001 From: Gregor Martynus <39992+gr2m@users.noreply.github.com> Date: Mon, 2 Sep 2024 21:06:14 -0700 Subject: [PATCH] feat: `getFunctionCalls()` -b closes (#50) --- README.md | 17 ++++++++ index.d.ts | 13 +++++- index.test-d.ts | 13 ++++++ lib/prompt.js | 34 +++++++++++++-- test/prompt.test.js | 100 +++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 172 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index df00d0a..c53b721 100644 --- a/README.md +++ b/README.md @@ -409,6 +409,23 @@ await prompt({ }); ``` +### `getFunctionCalls()` + +Convenience metthod if a result from a `prompt()` call includes function calls. + +```js +import { prompt, getFunctionCalls } from "@copilot-extensions/preview-sdk"; + +const result = await prompt(options); +const [functionCall] = getFunctionCalls(result); + +if (functionCall) { + console.log("Received a function call", functionCall); +} else { + console.log("No function call received"); +} +``` + ## Dreamcode While implementing the lower-level functionality, we also dream big: what would our dream SDK for Coplitot extensions look like? Please have a look and share your thoughts and ideas: diff --git a/index.d.ts b/index.d.ts index 84d8edf..7f5a605 100644 --- a/index.d.ts +++ b/index.d.ts @@ -306,6 +306,16 @@ interface PromptInterface { (options: WithRequired): Promise; } +interface GetFunctionCallsInterface { + (payload: PromptResult): { + id: string; + function: { + name: string, + arguments: string, + } + }[] +} + // exported methods export declare const verifyRequest: VerifyRequestInterface; @@ -325,4 +335,5 @@ export declare const verifyAndParseRequest: VerifyAndParseRequestInterface; export declare const getUserMessage: GetUserMessageInterface; export declare const getUserConfirmation: GetUserConfirmationInterface; -export declare const prompt: PromptInterface; \ No newline at end of file +export declare const prompt: PromptInterface; +export declare const getFunctionCalls: GetFunctionCallsInterface; \ No newline at end of file diff --git a/index.test-d.ts b/index.test-d.ts index 0ddf2ef..0c0ad09 100644 --- a/index.test-d.ts +++ b/index.test-d.ts @@ -20,6 +20,8 @@ import { type InteropMessage, CopilotRequestPayload, prompt, + PromptResult, + getFunctionCalls, } from "./index.js"; const token = ""; @@ -335,4 +337,15 @@ export async function promptWithoutMessageButMessages() { { role: "user", content: "What about Spain?" }, ], }); +} + +export async function getFunctionCallsTest(promptResponsePayload: PromptResult) { + const result = getFunctionCalls(promptResponsePayload) + + expectType<{ + id: string, function: { + name: string, + arguments: string, + } + }[]>(result) } \ No newline at end of file diff --git a/lib/prompt.js b/lib/prompt.js index e7c9db0..8fe4901 100644 --- a/lib/prompt.js +++ b/lib/prompt.js @@ -47,10 +47,38 @@ export async function prompt(userPrompt, promptOptions) { } ); - const data = await response.json(); + if (response.ok) { + const data = await response.json(); + return { + requestId: response.headers.get("x-request-id"), + message: data.choices[0].message, + }; + } + + const requestId = response.headers.get("x-request-id"); return { - requestId: response.headers.get("x-request-id"), - message: data.choices[0].message, + requestId: requestId, + message: { + role: "Sssistant", + content: `Sorry, an error occured with the chat completions API. (Status: ${response.status}, request ID: ${requestId})`, + }, }; } + +/** @type {import('..').GetFunctionCallsInterface} */ +export function getFunctionCalls(payload) { + const functionCalls = payload.message.tool_calls; + + if (!functionCalls) return []; + + return functionCalls.map((call) => { + return { + id: call.id, + function: { + name: call.function.name, + arguments: call.function.arguments, + }, + }; + }); +} diff --git a/test/prompt.test.js b/test/prompt.test.js index 2b83e12..26c8fbf 100644 --- a/test/prompt.test.js +++ b/test/prompt.test.js @@ -2,7 +2,7 @@ import { test, suite } from "node:test"; import { MockAgent } from "undici"; -import { prompt } from "../index.js"; +import { prompt, getFunctionCalls } from "../index.js"; suite("prompt", () => { test("smoke", (t) => { @@ -267,4 +267,102 @@ suite("prompt", () => { }, }); }); + + test("Handles error", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://api.githubcopilot.com"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + body: JSON.stringify({ + messages: [ + { + role: "system", + content: "You are a helpful assistant.", + }, + { + role: "user", + content: "What is the capital of France?", + }, + ], + model: "gpt-4", + }), + }) + .reply(400, "Bad Request", { + headers: { + "content-type": "text/plain", + "x-request-id": "", + }, + }); + + const result = await prompt("What is the capital of France?", { + token: "secret", + model: "gpt-4", + request: { fetch: fetchMock }, + }); + + t.assert.deepEqual(result, { + message: { + content: + "Sorry, an error occured with the chat completions API. (Status: 400, request ID: )", + role: "Sssistant", + }, + requestId: "", + }); + }); + + suite("getFunctionCalls()", () => { + test("includes function calls", async (t) => { + const tool_calls = [ + { + function: { + arguments: '{\n "order_id": "123"\n}', + name: "get_delivery_date", + }, + id: "call_Eko8Jz0mgchNOqiJJrrMr8YW", + type: "function", + }, + ]; + const result = getFunctionCalls({ + requestId: "", + message: { + role: "assistant", + tool_calls, + }, + }); + + t.assert.deepEqual( + result, + tool_calls.map((call) => { + return { + id: call.id, + function: { + name: call.function.name, + arguments: call.function.arguments, + }, + }; + }) + ); + }); + + test("does not include function calls", async (t) => { + const result = getFunctionCalls({ + requestId: "", + message: { + content: "Hello! How can I assist you today?", + role: "assistant", + }, + }); + + t.assert.deepEqual(result, []); + }); + }); });