Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: getFunctionCalls() #50

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,23 @@ await prompt({
});
```

### `getFunctionCalls()`

Convenience metthod if a result from a `prompt()` call includes function calls.

```js
import { prompt, getFunctionCalls } from "@copilot-extensions/preview-sdk";

const result = await prompt(options);
const [functionCall] = getFunctionCalls(result);

if (functionCall) {
console.log("Received a function call", functionCall);
} else {
console.log("No function call received");
}
```

## Dreamcode

While implementing the lower-level functionality, we also dream big: what would our dream SDK for Coplitot extensions look like? Please have a look and share your thoughts and ideas:
Expand Down
13 changes: 12 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,16 @@ interface PromptInterface {
(options: WithRequired<PromptOptions, "messages">): Promise<PromptResult>;
}

interface GetFunctionCallsInterface {
(payload: PromptResult): {
id: string;
function: {
name: string,
arguments: string,
}
}[]
}

// exported methods

export declare const verifyRequest: VerifyRequestInterface;
Expand All @@ -325,4 +335,5 @@ export declare const verifyAndParseRequest: VerifyAndParseRequestInterface;
export declare const getUserMessage: GetUserMessageInterface;
export declare const getUserConfirmation: GetUserConfirmationInterface;

export declare const prompt: PromptInterface;
export declare const prompt: PromptInterface;
export declare const getFunctionCalls: GetFunctionCallsInterface;
13 changes: 13 additions & 0 deletions index.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {
type InteropMessage,
CopilotRequestPayload,
prompt,
PromptResult,
getFunctionCalls,
} from "./index.js";

const token = "";
Expand Down Expand Up @@ -335,4 +337,15 @@ export async function promptWithoutMessageButMessages() {
{ role: "user", content: "What about Spain?" },
],
});
}

export async function getFunctionCallsTest(promptResponsePayload: PromptResult) {
const result = getFunctionCalls(promptResponsePayload)

expectType<{
id: string, function: {
name: string,
arguments: string,
}
}[]>(result)
}
34 changes: 31 additions & 3 deletions lib/prompt.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,38 @@ export async function prompt(userPrompt, promptOptions) {
}
);

const data = await response.json();
if (response.ok) {
const data = await response.json();

return {
requestId: response.headers.get("x-request-id"),
message: data.choices[0].message,
};
}

const requestId = response.headers.get("x-request-id");
return {
requestId: response.headers.get("x-request-id"),
message: data.choices[0].message,
requestId: requestId,
message: {
role: "Sssistant",
content: `Sorry, an error occured with the chat completions API. (Status: ${response.status}, request ID: ${requestId})`,
},
};
}

/** @type {import('..').GetFunctionCallsInterface} */
export function getFunctionCalls(payload) {
const functionCalls = payload.message.tool_calls;

if (!functionCalls) return [];

return functionCalls.map((call) => {
return {
id: call.id,
function: {
name: call.function.name,
arguments: call.function.arguments,
},
};
});
}
100 changes: 99 additions & 1 deletion test/prompt.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { test, suite } from "node:test";

import { MockAgent } from "undici";

import { prompt } from "../index.js";
import { prompt, getFunctionCalls } from "../index.js";

suite("prompt", () => {
test("smoke", (t) => {
Expand Down Expand Up @@ -267,4 +267,102 @@ suite("prompt", () => {
},
});
});

test("Handles error", async (t) => {
const mockAgent = new MockAgent();
function fetchMock(url, opts) {
opts ||= {};
opts.dispatcher = mockAgent;
return fetch(url, opts);
}

mockAgent.disableNetConnect();
const mockPool = mockAgent.get("https://api.githubcopilot.com");
mockPool
.intercept({
method: "post",
path: `/chat/completions`,
body: JSON.stringify({
messages: [
{
role: "system",
content: "You are a helpful assistant.",
},
{
role: "user",
content: "What is the capital of France?",
},
],
model: "gpt-4",
}),
})
.reply(400, "Bad Request", {
headers: {
"content-type": "text/plain",
"x-request-id": "<request-id>",
},
});

const result = await prompt("What is the capital of France?", {
token: "secret",
model: "gpt-4",
request: { fetch: fetchMock },
});

t.assert.deepEqual(result, {
message: {
content:
"Sorry, an error occured with the chat completions API. (Status: 400, request ID: <request-id>)",
role: "Sssistant",
},
requestId: "<request-id>",
});
});

suite("getFunctionCalls()", () => {
test("includes function calls", async (t) => {
const tool_calls = [
{
function: {
arguments: '{\n "order_id": "123"\n}',
name: "get_delivery_date",
},
id: "call_Eko8Jz0mgchNOqiJJrrMr8YW",
type: "function",
},
];
const result = getFunctionCalls({
requestId: "<request-id>",
message: {
role: "assistant",
tool_calls,
},
});

t.assert.deepEqual(
result,
tool_calls.map((call) => {
return {
id: call.id,
function: {
name: call.function.name,
arguments: call.function.arguments,
},
};
})
);
});

test("does not include function calls", async (t) => {
const result = getFunctionCalls({
requestId: "<request-id>",
message: {
content: "Hello! How can I assist you today?",
role: "assistant",
},
});

t.assert.deepEqual(result, []);
});
});
});