Skip to content

Commit

Permalink
Merge pull request #2 from copilot-extensions/include-models-in-funcs
Browse files Browse the repository at this point in the history
Include list of models in tool call system prompt
  • Loading branch information
JasonEtco authored Aug 15, 2024
2 parents c0bd65e + fe031c7 commit 9c8400d
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 17 deletions.
10 changes: 4 additions & 6 deletions src/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ export interface RunnerResponse {
messages: OpenAI.ChatCompletionMessageParam[];
}

export class Tool {
export abstract class Tool {
modelsAPI: ModelsAPI;
static definition: OpenAI.FunctionDefinition;

constructor(modelsAPI: ModelsAPI) {
this.modelsAPI = modelsAPI;
Expand All @@ -24,12 +25,9 @@ export class Tool {
function: this.definition,
};
}
static definition: OpenAI.FunctionDefinition;

async execute(
abstract execute(
messages: OpenAI.ChatCompletionMessageParam[],
args: object
): Promise<RunnerResponse> {
throw new Error("Not implemented");
}
): Promise<RunnerResponse>;
}
2 changes: 1 addition & 1 deletion src/functions/describe-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export class describeModel extends Tool {
model: {
type: "string",
description:
'The model to describe. Looks like "publisher/model-name".',
'The model to describe. Looks like "registry/model-name". For example, `azureml/Phi-3-medium-128k-instruct` or `azure-openai/gpt-4o',
},
},
required: ["model"],
Expand Down
1 change: 1 addition & 0 deletions src/functions/recommend-model.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import OpenAI from "openai";
import { RunnerResponse, defaultModel, Tool } from "../functions";
import { ModelsAPI } from "../models-api";

export class recommendModel extends Tool {
static definition = {
Expand Down
42 changes: 32 additions & 10 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ import { listModels } from "./functions/list-models";
import { RunnerResponse } from "./functions";
import { recommendModel } from "./functions/recommend-model";
import { ModelsAPI } from "./models-api";

// List of functions that are available to be called
const functions = [listModels, describeModel, executeModel, recommendModel];

const app = express();

app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => {
Expand All @@ -21,17 +17,44 @@ app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => {
return;
}

// List of functions that are available to be called
const modelsAPI = new ModelsAPI(apiKey);
const functions = [listModels, describeModel, executeModel, recommendModel];

// Use the Copilot API to determine which function to execute
const capiClient = new OpenAI({
baseURL: "https://api.githubcopilot.com",
apiKey,
});

// Prepend a system message that includes the list of models, so that
// tool calls can better select the right model to use.
const models = await modelsAPI.listModels();
const toolCallMessages = [
{
role: "system",
content: [
"You are an extension of GitHub Copilot, built to interact with GitHub Models.",
"GitHub Models is a language model playground, where you can experiment with different models and see how they respond to your prompts.",
"Here is a list of some of the models available to the user:",
JSON.stringify(
models.map((model) => ({
name: model.name,
publisher: model.publisher,
registry: model.model_registry,
description: model.summary,
}))
),
].join("\n"),
},
...req.body.messages,
].concat(req.body.messages);

console.time("tool-call");
const toolCaller = await capiClient.chat.completions.create({
stream: false,
model: "gpt-4",
messages: req.body.messages,
messages: toolCallMessages,
tool_choice: "auto",
tools: functions.map((f) => f.tool),
});
Expand Down Expand Up @@ -63,20 +86,19 @@ app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => {
const args = JSON.parse(functionToCall.arguments);

console.time("function-exec");
const modelsAPI = new ModelsAPI(apiKey);
let functionCallRes: RunnerResponse;
try {
console.log("Executing function", functionToCall.name);
const klass = functions.find(
const funcClass = functions.find(
(f) => f.definition.name === functionToCall.name
);
if (!klass) {
if (!funcClass) {
throw new Error("Unknown function");
}

console.log("\t with args", args);
const inst = new klass(modelsAPI);
functionCallRes = await inst.execute(req.body.messages, args);
const func = new funcClass(modelsAPI);
functionCallRes = await func.execute(req.body.messages, args);
} catch (err) {
console.error(err);
res.status(500).end();
Expand Down
6 changes: 6 additions & 0 deletions src/models-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export type ModelSchemaParameter = {

export class ModelsAPI {
inference: OpenAI;
private _models: Model[] | null = null;

constructor(apiKey: string) {
this.inference = new OpenAI({
Expand Down Expand Up @@ -67,6 +68,10 @@ export class ModelsAPI {
}

async listModels(): Promise<Model[]> {
if (this._models) {
return this._models;
}

const modelsRes = await fetch(
"https://modelcatalog.azure-api.net/v1/models"
);
Expand All @@ -75,6 +80,7 @@ export class ModelsAPI {
}

const models = (await modelsRes.json()) as Model[];
this._models = models;
return models;
}
}

0 comments on commit 9c8400d

Please sign in to comment.