diff --git a/src/functions.ts b/src/functions.ts
index d60c7a4..adee01a 100644
--- a/src/functions.ts
+++ b/src/functions.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { PromptFunction, InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { ModelsAPI } from "./models-api.js";
 
 // defaultModel is the model used for internal calls - for tool calling,
@@ -8,18 +9,18 @@ export const defaultModel = "gpt-4o-mini";
 // RunnerResponse is the response from a function call.
 export interface RunnerResponse {
   model: string;
-  messages: OpenAI.ChatCompletionMessageParam[];
+  messages: InteropMessage[];
 }
 
 export abstract class Tool {
   modelsAPI: ModelsAPI;
-  static definition: OpenAI.FunctionDefinition;
+  static definition: PromptFunction["function"];
 
   constructor(modelsAPI: ModelsAPI) {
     this.modelsAPI = modelsAPI;
   }
 
-  static get tool(): OpenAI.Chat.Completions.ChatCompletionTool {
+  static get tool(): PromptFunction {
     return {
       type: "function",
       function: this.definition,
@@ -27,7 +28,7 @@ export abstract class Tool {
   }
 
   abstract execute(
-    messages: OpenAI.ChatCompletionMessageParam[],
-    args: object
+    messages: InteropMessage[],
+    args: Record<string, unknown>
   ): Promise<RunnerResponse>;
 }
diff --git a/src/functions/describe-model.ts b/src/functions/describe-model.ts
index 0481227..c1026bf 100644
--- a/src/functions/describe-model.ts
+++ b/src/functions/describe-model.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class describeModel extends Tool {
@@ -19,7 +20,7 @@ export class describeModel extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[],
+    messages: InteropMessage[],
     args: { model: string }
   ): Promise<RunnerResponse> {
     const [model, modelSchema] = await Promise.all([
diff --git a/src/functions/execute-model.ts b/src/functions/execute-model.ts
index 405a618..62a3422 100644
--- a/src/functions/execute-model.ts
+++ b/src/functions/execute-model.ts
@@ -1,7 +1,8 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, Tool } from "../functions.js";
 
-type MessageWithReferences = OpenAI.ChatCompletionMessageParam & {
+type MessageWithReferences = InteropMessage & {
   copilot_references: Reference[];
 };
diff --git a/src/functions/list-models.ts b/src/functions/list-models.ts
index 3a47715..56d0183 100644
--- a/src/functions/list-models.ts
+++ b/src/functions/list-models.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class listModels extends Tool {
@@ -15,7 +16,7 @@ export class listModels extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[]
+    messages: InteropMessage[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();
diff --git a/src/functions/recommend-model.ts b/src/functions/recommend-model.ts
index f49345b..c177d39 100644
--- a/src/functions/recommend-model.ts
+++ b/src/functions/recommend-model.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class recommendModel extends Tool {
@@ -15,7 +16,7 @@ export class recommendModel extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[]
+    messages: InteropMessage[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();
diff --git a/src/index.ts b/src/index.ts
index 38854cf..cd1874d 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1,7 +1,6 @@
-import { createServer, IncomingMessage } from "node:http";
+import { createServer } from "node:http";
 
-import { verifyAndParseRequest, createAckEvent } from "@copilot-extensions/preview-sdk";
-import OpenAI from "openai";
+import { prompt, getFunctionCalls, createAckEvent, createDoneEvent, verifyAndParseRequest, createTextEvent } from "@copilot-extensions/preview-sdk";
 
 import { describeModel } from "./functions/describe-model.js";
 import { executeModel } from "./functions/execute-model.js";
@@ -12,6 +11,7 @@ import { ModelsAPI } from "./models-api.js";
 
 const server = createServer(async (request, response) => {
   if (request.method === "GET") {
+    // health check
     response.statusCode = 200;
     response.end(`OK`);
     return;
@@ -55,15 +55,9 @@ const server = createServer(async (request, response) => {
   response.write(createAckEvent().toString());
 
   // List of functions that are available to be called
-  const modelsAPI = new ModelsAPI(apiKey);
+  const modelsAPI = new ModelsAPI();
   const functions = [listModels, describeModel, executeModel, recommendModel];
 
-  // Use the Copilot API to determine which function to execute
-  const capiClient = new OpenAI({
-    baseURL: "https://api.githubcopilot.com",
-    apiKey,
-  });
-
   // Prepend a system message that includes the list of models, so that
   // tool calls can better select the right model to use.
   const models = await modelsAPI.listModels();
@@ -91,49 +85,41 @@ const server = createServer(async (request, response) => {
   ].concat(payload.messages);
 
   console.time("tool-call");
-  const toolCaller = await capiClient.chat.completions.create({
-    stream: false,
-    model: "gpt-4",
-    // @ts-expect-error - TODO @gr2m - type incompatibility between @openai/api and @copilot-extensions/preview-sdk
+  const promptResult = await prompt({
     messages: toolCallMessages,
-    tool_choice: "auto",
+    token: apiKey,
     tools: functions.map((f) => f.tool),
-  });
+  })
   console.timeEnd("tool-call");
 
+  const [functionToCall] = getFunctionCalls(promptResult)
+
   if (
-    !toolCaller.choices[0] ||
-    !toolCaller.choices[0].message ||
-    !toolCaller.choices[0].message.tool_calls ||
-    !toolCaller.choices[0].message.tool_calls[0].function
+    !functionToCall
   ) {
     console.log("No tool call found");
-    // No tool to call, so just call the model with the original messages
-    const stream = await capiClient.chat.completions.create({
-      stream: true,
-      model: "gpt-4",
-      // @ts-expect-error - TODO @gr2m - type incompatibility between @openai/api and @copilot-extensions/preview-sdk
+
+    const { stream } = await prompt.stream({
       messages: payload.messages,
-    });
+      token: apiKey,
+    })
 
     for await (const chunk of stream) {
-      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
-      response.write(chunkStr);
+      response.write(new TextDecoder().decode(chunk));
     }
-    response.write("data: [DONE]\n\n");
-    response.end();
+
+    response.end(createDoneEvent().toString());
     return;
   }
 
-  const functionToCall = toolCaller.choices[0].message.tool_calls[0].function;
-  const args = JSON.parse(functionToCall.arguments);
+  const args = JSON.parse(functionToCall.function.arguments);
 
   console.time("function-exec");
   let functionCallRes: RunnerResponse;
   try {
-    console.log("Executing function", functionToCall.name);
+    console.log("Executing function", functionToCall.function.name);
     const funcClass = functions.find(
-      (f) => f.definition.name === functionToCall.name
+      (f) => f.definition.name === functionToCall.function.name
     );
     if (!funcClass) {
       throw new Error("Unknown function");
     }
@@ -141,7 +127,6 @@ const server = createServer(async (request, response) => {
 
     console.log("\t with args", args);
     const func = new funcClass(modelsAPI);
-    // @ts-expect-error - TODO @gr2m - type incompatibility between @openai/api and @copilot-extensions/preview-sdk
     functionCallRes = await func.execute(payload.messages, args);
   } catch (err) {
     console.error(err);
@@ -152,23 +137,20 @@ const server = createServer(async (request, response) => {
   console.timeEnd("function-exec");
 
   try {
-    const stream = await modelsAPI.inference.chat.completions.create({
+    console.time("streaming");
+    const { stream } = await prompt.stream({
+      endpoint: 'https://models.inference.ai.azure.com/chat/completions',
       model: functionCallRes.model,
       messages: functionCallRes.messages,
-      stream: true,
-      stream_options: {
-        include_usage: false,
-      },
-    });
+      token: apiKey,
+    })
 
-    console.time("streaming");
     for await (const chunk of stream) {
-      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
-      response.write(chunkStr);
+      response.write(new TextDecoder().decode(chunk));
     }
-    response.write("data: [DONE]\n\n");
+
+    response.end(createDoneEvent().toString());
     console.timeEnd("streaming");
-    response.end();
   } catch (err) {
     console.error(err);
     response.statusCode = 500
@@ -180,12 +162,12 @@ const port = process.env.PORT || "3000"
 server.listen(port);
 console.log(`Server running at http://localhost:${port}`);
 
-function getBody(request: IncomingMessage): Promise<string> {
+function getBody(request: any): Promise<string> {
   return new Promise((resolve) => {
     const bodyParts: any[] = [];
     let body;
     request
-      .on("data", (chunk) => {
+      .on("data", (chunk: Buffer) => {
         bodyParts.push(chunk);
       })
       .on("end", () => {
diff --git a/src/models-api.ts b/src/models-api.ts
index 1db5b68..c53b150 100644
--- a/src/models-api.ts
+++ b/src/models-api.ts
@@ -1,5 +1,3 @@
-import OpenAI from "openai";
-
 // Model is the structure of a model in the model catalog.
 export interface Model {
   id: string;
@@ -33,16 +31,8 @@ export type ModelSchemaParameter = {
 };
 
 export class ModelsAPI {
-  inference: OpenAI;
   private _models: Model[] | null = null;
 
-  constructor(apiKey: string) {
-    this.inference = new OpenAI({
-      baseURL: "https://models.inference.ai.azure.com",
-      apiKey,
-    });
-  }
-
   async getModel(modelName: string): Promise<Model> {
     const modelRes = await fetch(
       "https://modelcatalog.azure-api.net/v1/model/" + modelName
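
Note for reviewers: the sketch below condenses the new call flow this diff introduces — prompt() with tools for function selection, getFunctionCalls() to read the result, and prompt.stream() for byte-chunked streaming. It is a minimal illustration assuming only the preview-sdk surface used above; the token source, the sample message, and the inline tool definition are placeholders, not part of this change.

// sketch.ts - illustration only; mirrors the preview-sdk usage in src/index.ts.
import {
  prompt,
  getFunctionCalls,
  type InteropMessage,
} from "@copilot-extensions/preview-sdk";

async function main() {
  const token = process.env.GITHUB_TOKEN ?? ""; // placeholder credential
  const messages: InteropMessage[] = [
    { role: "user", content: "Which model should I use for code review?" },
  ];

  // 1. Non-streaming prompt with tools: the model may select a function.
  const result = await prompt({
    messages,
    token,
    tools: [
      {
        type: "function",
        function: {
          // hypothetical inline definition; the real ones live in src/functions/
          name: "list-models",
          description: "List all available models",
          parameters: { type: "object", properties: {} },
        },
      },
    ],
  });

  // 2. getFunctionCalls() extracts tool calls; note the nested
  //    `.function.name` / `.function.arguments` shape used in the diff.
  const [call] = getFunctionCalls(result);
  if (call) {
    console.log(call.function.name, JSON.parse(call.function.arguments));
    return;
  }

  // 3. No tool call: stream the plain completion. Chunks arrive as bytes and
  //    are decoded before writing, exactly as in the new src/index.ts.
  //    prompt.stream() also accepts an `endpoint` override, which the diff
  //    uses to target models.inference.ai.azure.com for the final inference.
  const { stream } = await prompt.stream({ messages, token });
  const decoder = new TextDecoder();
  for await (const chunk of stream) {
    process.stdout.write(decoder.decode(chunk));
  }
}

main().catch(console.error);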