Commit
bring back openai
gr2m committed Sep 6, 2024
1 parent 6a5563f commit bee45e7
Showing 9 changed files with 519 additions and 44 deletions.
465 changes: 461 additions & 4 deletions package-lock.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion package.json
@@ -12,7 +12,8 @@
   "license": "ISC",
   "description": "",
   "dependencies": {
-    "@copilot-extensions/preview-sdk": "^3.0.0"
+    "@copilot-extensions/preview-sdk": "^3.0.0",
+    "openai": "^4.55.0"
   },
   "devDependencies": {
     "@types/express": "^4.17.21",
13 changes: 6 additions & 7 deletions src/functions.ts
@@ -1,5 +1,4 @@
-import type { PromptFunction, InteropMessage } from "@copilot-extensions/preview-sdk";
-
+import OpenAI from "openai";
 import { ModelsAPI } from "./models-api.js";

 // defaultModel is the model used for internal calls - for tool calling,
@@ -9,26 +8,26 @@ export const defaultModel = "gpt-4o-mini";
 // RunnerResponse is the response from a function call.
 export interface RunnerResponse {
   model: string;
-  messages: InteropMessage[];
+  messages: OpenAI.ChatCompletionMessageParam[];
 }

 export abstract class Tool {
   modelsAPI: ModelsAPI;
-  static definition: PromptFunction["function"];
+  static definition: OpenAI.FunctionDefinition;

   constructor(modelsAPI: ModelsAPI) {
     this.modelsAPI = modelsAPI;
   }

-  static get tool(): PromptFunction {
+  static get tool(): OpenAI.Chat.Completions.ChatCompletionTool {
     return {
       type: "function",
       function: this.definition,
     };
   }

   abstract execute(
-    messages: InteropMessage[],
-    args: Record<string, unknown>
+    messages: OpenAI.ChatCompletionMessageParam[],
+    args: object
   ): Promise<RunnerResponse>;
 }
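
With the Tool base class now typed against the OpenAI SDK, a concrete tool only has to supply an OpenAI.FunctionDefinition plus an execute method. A minimal sketch of what that looks like after this change — the echo tool below is hypothetical, not part of this repository:

import OpenAI from "openai";
import { RunnerResponse, defaultModel, Tool } from "./functions.js";

export class echoTool extends Tool {
  // JSON Schema function definition advertised to the model for tool calling.
  static definition: OpenAI.FunctionDefinition = {
    name: "echo",
    description: "Echo the provided text back to the user.",
    parameters: {
      type: "object",
      properties: { text: { type: "string" } },
      required: ["text"],
    },
  };

  async execute(
    messages: OpenAI.ChatCompletionMessageParam[],
    args: object
  ): Promise<RunnerResponse> {
    // args arrives untyped; narrow it to this tool's parameter shape.
    const { text } = args as { text: string };
    return {
      model: defaultModel,
      messages: [...messages, { role: "assistant", content: text }],
    };
  }
}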
5 changes: 2 additions & 3 deletions src/functions/describe-model.ts
@@ -1,5 +1,4 @@
-import type { InteropMessage } from "@copilot-extensions/preview-sdk";
-
+import OpenAI from "openai";
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";

 export class describeModel extends Tool {
@@ -20,7 +19,7 @@ export class describeModel extends Tool {
   };

   async execute(
-    messages: InteropMessage[],
+    messages: OpenAI.ChatCompletionMessageParam[],
     args: { model: string }
   ): Promise<RunnerResponse> {
     const [model, modelSchema] = await Promise.all([
5 changes: 2 additions & 3 deletions src/functions/execute-model.ts
@@ -1,8 +1,7 @@
-import type { InteropMessage } from "@copilot-extensions/preview-sdk";
-
+import OpenAI from "openai";
 import { RunnerResponse, Tool } from "../functions.js";

-type MessageWithReferences = InteropMessage & {
+type MessageWithReferences = OpenAI.ChatCompletionMessageParam & {
   copilot_references: Reference[];
 };
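
The intersection type keeps the standard OpenAI message shape while letting Copilot's copilot_references ride along. A short sketch of how a handler might pull references off the latest message — the Reference shape here is assumed for illustration:

import OpenAI from "openai";

type Reference = { type: string; data?: unknown }; // assumed shape
type MessageWithReferences = OpenAI.ChatCompletionMessageParam & {
  copilot_references: Reference[];
};

function referencesFromLastMessage(
  messages: OpenAI.ChatCompletionMessageParam[]
): Reference[] {
  // Read references, if present, from the most recent message.
  const last = messages.at(-1) as MessageWithReferences | undefined;
  return last?.copilot_references ?? [];
}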
5 changes: 2 additions & 3 deletions src/functions/list-models.ts
@@ -1,5 +1,4 @@
-import type { InteropMessage } from "@copilot-extensions/preview-sdk";
-
+import OpenAI from "openai";
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";

 export class listModels extends Tool {
@@ -16,7 +15,7 @@ export class listModels extends Tool {
   };

   async execute(
-    messages: InteropMessage[]
+    messages: OpenAI.ChatCompletionMessageParam[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();

5 changes: 2 additions & 3 deletions src/functions/recommend-model.ts
@@ -1,5 +1,4 @@
-import type { InteropMessage } from "@copilot-extensions/preview-sdk";
-
+import OpenAI from "openai";
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";

 export class recommendModel extends Tool {
@@ -16,7 +15,7 @@ export class recommendModel extends Tool {
   };

   async execute(
-    messages: InteropMessage[]
+    messages: OpenAI.ChatCompletionMessageParam[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();

52 changes: 32 additions & 20 deletions src/index.ts
@@ -1,6 +1,7 @@
-import { createServer } from "node:http";
+import { createServer, type IncomingMessage } from "node:http";

-import { prompt, getFunctionCalls, createDoneEvent, verifyAndParseRequest } from "@copilot-extensions/preview-sdk";
+import { getFunctionCalls, createDoneEvent, verifyAndParseRequest } from "@copilot-extensions/preview-sdk";
+import OpenAI from "openai";

 import { describeModel } from "./functions/describe-model.js";
 import { executeModel } from "./functions/execute-model.js";
@@ -52,9 +53,15 @@ const server = createServer(async (request, response) => {
   }

   // List of functions that are available to be called
-  const modelsAPI = new ModelsAPI();
+  const modelsAPI = new ModelsAPI(apiKey);
   const functions = [listModels, describeModel, executeModel, recommendModel];

+  // Use the Copilot API to determine which function to execute
+  const capiClient = new OpenAI({
+    baseURL: "https://api.githubcopilot.com",
+    apiKey,
+  });
+
   // Prepend a system message that includes the list of models, so that
   // tool calls can better select the right model to use.
   const models = await modelsAPI.listModels();
@@ -82,27 +89,29 @@ const server = createServer(async (request, response) => {
   ].concat(payload.messages);

   console.time("tool-call");
-  const promptResult = await prompt({
+  const toolCaller = await capiClient.chat.completions.create({
     messages: toolCallMessages,
-    token: apiKey,
+    stream: false,
+    model: "gpt-4",
     tools: functions.map((f) => f.tool),
   })
   console.timeEnd("tool-call");

-  const [functionToCall] = getFunctionCalls(promptResult)
+  const [functionToCall] = getFunctionCalls(toolCaller)

   if (
     !functionToCall
   ) {
     console.log("No tool call found");
-
-    const { stream } = await prompt.stream({
-      messages: payload.messages,
-      token: apiKey,
-    })
+    // No tool to call, so just call the model with the original messages
+    const stream = await capiClient.chat.completions.create({
+      stream: true,
+      model: "gpt-4",
+    });

     for await (const chunk of stream) {
-      response.write(new TextDecoder().decode(chunk));
+      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
+      response.write(chunkStr);
     }

     response.end(createDoneEvent().toString());
@@ -134,16 +143,19 @@ const server = createServer(async (request, response) => {
   console.timeEnd("function-exec");

   try {
-    console.time("streaming");
-    const { stream } = await prompt.stream({
-      endpoint: 'https://models.inference.ai.azure.com/chat/completions',
+    const stream = await modelsAPI.inference.chat.completions.create({
       model: functionCallRes.model,
       messages: functionCallRes.messages,
-      token: apiKey,
-    })
+      stream: true,
+      stream_options: {
+        include_usage: false,
+      },
+    });

+    console.time("streaming");
     for await (const chunk of stream) {
-      response.write(new TextDecoder().decode(chunk));
+      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
+      response.write(chunkStr);
     }

     response.end(createDoneEvent().toString());
@@ -159,9 +171,9 @@ const port = process.env.PORT || "3000"
 server.listen(port);
 console.log(`Server running at http://localhost:${port}`);

-function getBody(request: any): Promise<string> {
+function getBody(request: IncomingMessage): Promise<string> {
   return new Promise((resolve) => {
-    const bodyParts: any[] = [];
+    const bodyParts: Buffer[] = [];
     let body;
     request
       .on("data", (chunk: Buffer) => {
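
After this commit the extension talks to two OpenAI-compatible endpoints: api.githubcopilot.com to decide which tool to call, and models.inference.ai.azure.com (via modelsAPI.inference) to run the chosen model, re-framing each streamed chunk as a server-sent event. A condensed, standalone sketch of that flow — the token source and messages are placeholders, not code from this repository:

import OpenAI from "openai";

const apiKey = process.env.GITHUB_TOKEN!; // placeholder; the server takes it from the incoming request

// 1. Non-streaming call against the Copilot API to pick a tool.
const capiClient = new OpenAI({
  baseURL: "https://api.githubcopilot.com",
  apiKey,
});
const toolCaller = await capiClient.chat.completions.create({
  stream: false,
  model: "gpt-4",
  messages: [{ role: "user", content: "Which model fits a code review task?" }],
  tools: [
    {
      type: "function",
      function: { name: "list_models", description: "List all available models." },
    },
  ],
});
// In the server, getFunctionCalls(toolCaller) extracts any tool calls from this response.

// 2. Streaming call against the GitHub Models inference endpoint,
// with each chunk framed as a server-sent event ("data: <json>\n\n").
const inference = new OpenAI({
  baseURL: "https://models.inference.ai.azure.com",
  apiKey,
});
const stream = await inference.chat.completions.create({
  stream: true,
  model: "gpt-4o-mini",
  messages: [{ role: "user", content: "Hello" }],
});
for await (const chunk of stream) {
  process.stdout.write("data: " + JSON.stringify(chunk) + "\n\n");
}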
10 changes: 10 additions & 0 deletions src/models-api.ts
@@ -1,3 +1,5 @@
+import OpenAI from "openai";
+
 // Model is the structure of a model in the model catalog.
 export interface Model {
   id: string;
@@ -31,8 +33,16 @@ export type ModelSchemaParameter = {
 };

 export class ModelsAPI {
+  inference: OpenAI;
   private _models: Model[] | null = null;

+  constructor(apiKey: string) {
+    this.inference = new OpenAI({
+      baseURL: "https://models.inference.ai.azure.com",
+      apiKey,
+    });
+  }
+
   async getModel(modelName: string): Promise<Model> {
     const modelRes = await fetch(
       "https://modelcatalog.azure-api.net/v1/model/" + modelName
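
ModelsAPI now takes the caller's API key and owns an OpenAI client bound to the GitHub Models inference endpoint, so tools can run completions without building their own client. A usage sketch, assuming a token in the environment:

import { ModelsAPI } from "./models-api.js";

const modelsAPI = new ModelsAPI(process.env.GITHUB_TOKEN!); // assumed token source

// Catalog lookups go through plain fetch() against the model catalog.
const models = await modelsAPI.listModels();
console.log(models.map((m) => m.id));

// Inference goes through the embedded OpenAI client.
const completion = await modelsAPI.inference.chat.completions.create({
  model: "gpt-4o-mini",
  messages: [{ role: "user", content: "Say hello" }],
});
console.log(completion.choices[0].message.content);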
