Skip to content

Commit d7ef3b1

Browse files
taeminlee and nsarrazin authored
Extend endpointOai.ts to allow usage of extra sampling parameters (huggingface#1032)
* Extend endpointOai.ts to allow usage of extra sampling parameters when calling vLLM as an OpenAI-compatible server

* refactor: prettier endpointOai.ts

* Fix: Corrected type imports in endpointOai.ts

* Simplifies code a bit and adds `extraBody` to OpenAI endpoint

* Update zod schema to allow any type in extraBody

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
1 parent 0fa2340 commit d7ef3b1

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

src/lib/server/endpoints/openai/endpointOai.ts

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import { z } from "zod";
22
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
33
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
4+
import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
5+
import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
46
import { buildPrompt } from "$lib/buildPrompt";
57
import { env } from "$env/dynamic/private";
68
import type { Endpoint } from "../endpoints";
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
1618
.default("chat_completions"),
1719
defaultHeaders: z.record(z.string()).optional(),
1820
defaultQuery: z.record(z.string()).optional(),
21+
extraBody: z.record(z.any()).optional(),
1922
});
2023

2124
export async function endpointOai(
2225
input: z.input<typeof endpointOAIParametersSchema>
2326
): Promise<Endpoint> {
24-
const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
27+
const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
2528
endpointOAIParametersSchema.parse(input);
2629
let OpenAI;
2730
try {
@@ -47,19 +50,22 @@ export async function endpointOai(
4750
});
4851

4952
const parameters = { ...model.parameters, ...generateSettings };
53+
const body: CompletionCreateParamsStreaming = {
54+
model: model.id ?? model.name,
55+
prompt,
56+
stream: true,
57+
max_tokens: parameters?.max_new_tokens,
58+
stop: parameters?.stop,
59+
temperature: parameters?.temperature,
60+
top_p: parameters?.top_p,
61+
frequency_penalty: parameters?.repetition_penalty,
62+
};
5063

51-
return openAICompletionToTextGenerationStream(
52-
await openai.completions.create({
53-
model: model.id ?? model.name,
54-
prompt,
55-
stream: true,
56-
max_tokens: parameters?.max_new_tokens,
57-
stop: parameters?.stop,
58-
temperature: parameters?.temperature,
59-
top_p: parameters?.top_p,
60-
frequency_penalty: parameters?.repetition_penalty,
61-
})
62-
);
64+
const openAICompletion = await openai.completions.create(body, {
65+
body: { ...body, ...extraBody },
66+
});
67+
68+
return openAICompletionToTextGenerationStream(openAICompletion);
6369
};
6470
} else if (completion === "chat_completions") {
6571
return async ({ messages, preprompt, generateSettings }) => {
@@ -77,19 +83,22 @@ export async function endpointOai(
7783
}
7884

7985
const parameters = { ...model.parameters, ...generateSettings };
86+
const body: ChatCompletionCreateParamsStreaming = {
87+
model: model.id ?? model.name,
88+
messages: messagesOpenAI,
89+
stream: true,
90+
max_tokens: parameters?.max_new_tokens,
91+
stop: parameters?.stop,
92+
temperature: parameters?.temperature,
93+
top_p: parameters?.top_p,
94+
frequency_penalty: parameters?.repetition_penalty,
95+
};
96+
97+
const openChatAICompletion = await openai.chat.completions.create(body, {
98+
body: { ...body, ...extraBody },
99+
});
80100

81-
return openAIChatToTextGenerationStream(
82-
await openai.chat.completions.create({
83-
model: model.id ?? model.name,
84-
messages: messagesOpenAI,
85-
stream: true,
86-
max_tokens: parameters?.max_new_tokens,
87-
stop: parameters?.stop,
88-
temperature: parameters?.temperature,
89-
top_p: parameters?.top_p,
90-
frequency_penalty: parameters?.repetition_penalty,
91-
})
92-
);
101+
return openAIChatToTextGenerationStream(openChatAICompletion);
93102
};
94103
} else {
95104
throw new Error("Invalid completion type");

0 commit comments

Comments
 (0)