Skip to content
Merged
59 changes: 34 additions & 25 deletions src/lib/server/endpoints/openai/endpointOai.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
import { buildPrompt } from "$lib/buildPrompt";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
Expand All @@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
.default("chat_completions"),
defaultHeaders: z.record(z.string()).optional(),
defaultQuery: z.record(z.string()).optional(),
extraBody: z.record(z.string()).optional(),
});

export async function endpointOai(
input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
endpointOAIParametersSchema.parse(input);
let OpenAI;
try {
Expand All @@ -47,19 +50,22 @@ export async function endpointOai(
});

const parameters = { ...model.parameters, ...generateSettings };
const body: CompletionCreateParamsStreaming = {
model: model.id ?? model.name,
prompt,
stream: true,
max_tokens: parameters?.max_new_tokens,
stop: parameters?.stop,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
frequency_penalty: parameters?.repetition_penalty,
};

return openAICompletionToTextGenerationStream(
await openai.completions.create({
model: model.id ?? model.name,
prompt,
stream: true,
max_tokens: parameters?.max_new_tokens,
stop: parameters?.stop,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
frequency_penalty: parameters?.repetition_penalty,
})
);
const openAICompletion = await openai.completions.create(body, {
body: { ...body, ...extraBody },
});

return openAICompletionToTextGenerationStream(openAICompletion);
};
} else if (completion === "chat_completions") {
return async ({ messages, preprompt, generateSettings }) => {
Expand All @@ -77,19 +83,22 @@ export async function endpointOai(
}

const parameters = { ...model.parameters, ...generateSettings };
const body: ChatCompletionCreateParamsStreaming = {
model: model.id ?? model.name,
messages: messagesOpenAI,
stream: true,
max_tokens: parameters?.max_new_tokens,
stop: parameters?.stop,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
frequency_penalty: parameters?.repetition_penalty,
};

const openChatAICompletion = await openai.chat.completions.create(body, {
body: { ...body, ...extraBody },
});

return openAIChatToTextGenerationStream(
await openai.chat.completions.create({
model: model.id ?? model.name,
messages: messagesOpenAI,
stream: true,
max_tokens: parameters?.max_new_tokens,
stop: parameters?.stop,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
frequency_penalty: parameters?.repetition_penalty,
})
);
return openAIChatToTextGenerationStream(openChatAICompletion);
};
} else {
throw new Error("Invalid completion type");
Expand Down