diff --git a/src/functions/execute-model.ts b/src/functions/execute-model.ts index 5dc3e1b..cc52dff 100644 --- a/src/functions/execute-model.ts +++ b/src/functions/execute-model.ts @@ -26,8 +26,11 @@ Example Queries (IMPORTANT: Phrasing doesn't have to match): properties: { model: { type: "string", - description: - "The name of the model to execute. It is ONLY the name of the model, not the publisher or registry. For example: `gpt-4o`, or `cohere-command-r-plus`.", + description: [ + "The name of the model to execute. It is ONLY the name of the model, not the publisher or registry.", + "For example: `gpt-4o`, or `cohere-command-r-plus`.", + "The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.", + ].join("\n"), }, instruction: { type: "string", diff --git a/src/functions/list-models.ts b/src/functions/list-models.ts index ffd940f..e1552e3 100644 --- a/src/functions/list-models.ts +++ b/src/functions/list-models.ts @@ -23,26 +23,17 @@ export class listModels extends Tool { "The user is asking for a list of available models.", "Respond with a concise and readable list of the models, with a short description for each one.", "Use markdown formatting to make each description more readable.", - "Begin each model's description with a header consisting of the model's registry and name", - "The header must be formatted as `/`.", + "Begin each model's description with a header consisting of the model's name", "That list of models is as follows:", + JSON.stringify( + models.map((model) => ({ + name: model.friendly_name, + publisher: model.publisher, + description: model.summary, + })) + ), ]; - for (const model of models) { - systemMessage.push( - [ - `\t- Model Name: ${model.name}`, - `\t\tModel Version: ${model.model_version}`, - `\t\tPublisher: ${model.publisher}`, - `\t\tModel Family: ${model.model_family}`, - `\t\tModel Registry: ${model.model_registry}`, - `\t\tLicense: ${model.license}`, - `\t\tTask: ${model.task}`, - `\t\tSummary: ${model.summary}`, - ].join("\n") - ); - } - return { model: defaultModel, messages: [ diff --git a/src/index.ts b/src/index.ts index 2901384..84339dc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -37,14 +37,17 @@ app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => { "You are an extension of GitHub Copilot, built to interact with GitHub Models.", "GitHub Models is a language model playground, where you can experiment with different models and see how they respond to your prompts.", "Here is a list of some of the models available to the user:", + "<-- LIST OF MODELS -->", JSON.stringify( models.map((model) => ({ + friendly_name: model.friendly_name, name: model.name, publisher: model.publisher, registry: model.model_registry, description: model.summary, })) ), + "<-- END OF LIST OF MODELS -->", ].join("\n"), }, ...req.body.messages, @@ -107,22 +110,28 @@ app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => { } console.timeEnd("function-exec"); - console.time("stream"); - const stream = await modelsAPI.inference.chat.completions.create({ - model: functionCallRes.model, - messages: functionCallRes.messages, - stream: true, - }); - console.timeEnd("stream"); + try { + const stream = await modelsAPI.inference.chat.completions.create({ + model: functionCallRes.model, + messages: functionCallRes.messages, + stream: true, + stream_options: { + include_usage: false, + }, + }); - console.time("streaming"); - for await (const chunk of stream) { - const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n"; - res.write(chunkStr); + console.time("streaming"); + for await (const chunk of stream) { + const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n"; + res.write(chunkStr); + } + res.write("data: [DONE]\n\n"); + console.timeEnd("streaming"); + res.end(); + } catch (err) { + console.error(err); + res.status(500).end(); } - res.write("data: [DONE]\n\n"); - console.timeEnd("streaming"); - res.end(); }); // Health check