Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prompt): prompt.stream(), options.endpoint #57

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ await prompt({
});
```

#### `prompt.stream(message, options)`

Works the same way as `prompt()`, but resolves with a `stream` key instead of a `message` key.

```js
import { prompt } from "@copilot-extensions/preview-sdk";

const { requestId, stream } = prompt.stream("What is the capital of France?", {
token: process.env.TOKEN,
});

for await (const chunk of stream) {
console.log(new TextDecoder().decode(chunk));
}
```

### `getFunctionCalls()`

Convenience metthod if a result from a `prompt()` call includes function calls.
Expand Down
18 changes: 16 additions & 2 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ export interface OpenAICompatibilityPayload {
export interface CopilotMessage {
role: string;
content: string;
copilot_references: MessageCopilotReference[];
copilot_references?: MessageCopilotReference[];
copilot_confirmations?: MessageCopilotConfirmation[];
tool_calls?: {
function: {
Expand Down Expand Up @@ -300,8 +300,9 @@ export interface PromptFunction {
}

export type PromptOptions = {
model?: ModelName;
token: string;
endpoint?: string;
model?: ModelName;
tools?: PromptFunction[];
messages?: InteropMessage[];
request?: {
Expand All @@ -314,12 +315,25 @@ export type PromptResult = {
message: CopilotMessage;
};

export type PromptStreamResult = {
requestId: string;
stream: ReadableStream<Uint8Array>;
};

// https://stackoverflow.com/a/69328045
type WithRequired<T, K extends keyof T> = T & { [P in K]-?: T[P] };

interface PromptInterface {
(userPrompt: string, options: PromptOptions): Promise<PromptResult>;
(options: WithRequired<PromptOptions, "messages">): Promise<PromptResult>;
stream: PromptStreamInterface;
}

interface PromptStreamInterface {
(userPrompt: string, options: PromptOptions): Promise<PromptStreamResult>;
(
options: WithRequired<PromptOptions, "messages">,
): Promise<PromptStreamResult>;
}

interface GetFunctionCallsInterface {
Expand Down
21 changes: 18 additions & 3 deletions index.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ export function getUserConfirmationTest(payload: CopilotRequestPayload) {

export async function promptTest() {
const result = await prompt("What is the capital of France?", {
model: "gpt-4",
token: "secret",
});

Expand All @@ -311,7 +310,6 @@ export async function promptTest() {

// with custom fetch
await prompt("What is the capital of France?", {
model: "gpt-4",
token: "secret",
request: {
fetch: () => {},
Expand All @@ -327,7 +325,6 @@ export async function promptTest() {

export async function promptWithToolsTest() {
await prompt("What is the capital of France?", {
model: "gpt-4",
token: "secret",
tools: [
{
Expand Down Expand Up @@ -366,6 +363,24 @@ export async function promptWithoutMessageButMessages() {
});
}

export async function otherPromptOptionsTest() {
const result = await prompt("What is the capital of France?", {
token: "secret",
model: "gpt-4",
endpoint: "https://api.githubcopilot.com",
});
}

export async function promptStreamTest() {
const result = await prompt.stream("What is the capital of France?", {
model: "gpt-4",
token: "secret",
});

expectType<string>(result.requestId);
expectType<ReadableStream<Uint8Array>>(result.stream);
}

export async function getFunctionCallsTest(
promptResponsePayload: PromptResult,
) {
Expand Down
114 changes: 81 additions & 33 deletions lib/prompt.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
// @ts-check

/** @type {import('..').PromptInterface} */
export async function prompt(userPrompt, promptOptions) {
const options = typeof userPrompt === "string" ? promptOptions : userPrompt;

const promptFetch = options.request?.fetch || fetch;
const modelName = options.model || "gpt-4";
function parsePromptArguments(userPrompt, promptOptions) {
const { request: requestOptions, ...options } =
typeof userPrompt === "string" ? promptOptions : userPrompt;

const promptFetch = requestOptions?.fetch || fetch;
const model = options.model || "gpt-4";
const endpoint =
options.endpoint || "https://api.githubcopilot.com/chat/completions";

const systemMessage = options.tools
? "You are a helpful assistant. Use the supplied tools to assist the user."
: "You are a helpful assistant.";
const toolsChoice = options.tools ? "auto" : undefined;

const messages = [
{
Expand All @@ -29,44 +34,87 @@ export async function prompt(userPrompt, promptOptions) {
});
}

const response = await promptFetch(
"https://api.githubcopilot.com/chat/completions",
{
method: "POST",
headers: {
accept: "application/json",
"content-type": "application/json; charset=UTF-8",
"user-agent": "copilot-extensions/preview-sdk.js",
authorization: `Bearer ${options.token}`,
},
body: JSON.stringify({
messages: messages,
model: modelName,
toolChoice: options.tools ? "auto" : undefined,
tools: options.tools,
}),
}
);
return [promptFetch, { ...options, messages, model, endpoint, toolsChoice }];
}

if (response.ok) {
const data = await response.json();
async function sendPromptRequest(promptFetch, options) {
const { endpoint, token, ...payload } = options;
const method = "POST";
const headers = {
accept: "application/json",
"content-type": "application/json; charset=UTF-8",
"user-agent": "copilot-extensions/preview-sdk.js",
authorization: `Bearer ${token}`,
};

return {
requestId: response.headers.get("x-request-id"),
message: data.choices[0].message,
};
const response = await promptFetch(endpoint, {
method,
headers,
body: JSON.stringify(payload),
});

if (response.ok) {
return response;
}

const body = await response.text();
console.log({ body });

throw Object.assign(
new Error(
`[@copilot-extensions/preview-sdk] An error occured with the chat completions API`,
),
{
name: "PromptError",
request: {
method: "POST",
url: endpoint,
headers: {
...headers,
authorization: `Bearer [REDACTED]`,
},
body: payload,
},
response: {
status: response.status,
headers: [...response.headers],
body: body,
},
},
);
}
export async function prompt(userPrompt, promptOptions) {
const [promptFetch, options] = parsePromptArguments(
userPrompt,
promptOptions,
);
const response = await sendPromptRequest(promptFetch, options);
const requestId = response.headers.get("x-request-id");

const data = await response.json();

return {
requestId: requestId,
message: {
role: "Sssistant",
content: `Sorry, an error occured with the chat completions API. (Status: ${response.status}, request ID: ${requestId})`,
},
requestId,
message: data.choices[0].message,
};
}

prompt.stream = async function promptStream(userPrompt, promptOptions) {
const [promptFetch, options] = parsePromptArguments(
userPrompt,
promptOptions,
);
const response = await sendPromptRequest(promptFetch, {
...options,
stream: true,
});

return {
requestId: response.headers.get("x-request-id"),
stream: response.body,
};
};

/** @type {import('..').GetFunctionCallsInterface} */
export function getFunctionCalls(payload) {
const functionCalls = payload.message.tool_calls;
Expand Down
Loading