 import { z } from "zod";
 import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
 import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
+import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
+import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
 import { buildPrompt } from "$lib/buildPrompt";
 import { env } from "$env/dynamic/private";
 import type { Endpoint } from "../endpoints";
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
 		.default("chat_completions"),
 	defaultHeaders: z.record(z.string()).optional(),
 	defaultQuery: z.record(z.string()).optional(),
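+	// Extra fields merged verbatim into every request body, for provider-specific
+	// parameters that the OpenAI SDK types do not cover.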
+	extraBody: z.record(z.any()).optional(),
 });

 export async function endpointOai(
 	input: z.input<typeof endpointOAIParametersSchema>
 ): Promise<Endpoint> {
-	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
+	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
 		endpointOAIParametersSchema.parse(input);
 	let OpenAI;
 	try {
@@ -47,19 +50,22 @@ export async function endpointOai(
 			});

 			const parameters = { ...model.parameters, ...generateSettings };
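+			// Map chat-ui generation settings onto the OpenAI completion parameters
+			// (max_new_tokens -> max_tokens, repetition_penalty -> frequency_penalty).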
+			const body: CompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				prompt,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};

-			return openAICompletionToTextGenerationStream(
-				await openai.completions.create({
-					model: model.id ?? model.name,
-					prompt,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
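+			// The typed `body` satisfies the SDK; the `body` request option then
+			// overrides the serialized payload so `extraBody` fields reach the API.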
+			const openAICompletion = await openai.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});
+
+			return openAICompletionToTextGenerationStream(openAICompletion);
 		};
 	} else if (completion === "chat_completions") {
 		return async ({ messages, preprompt, generateSettings }) => {
@@ -77,19 +83,22 @@ export async function endpointOai(
 			}

 			const parameters = { ...model.parameters, ...generateSettings };
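+			// Same parameter mapping as the completions branch, applied to chat messages.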
+			const body: ChatCompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				messages: messagesOpenAI,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};
+
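+			// As above, override the serialized body so `extraBody` fields are sent.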
+			const openChatAICompletion = await openai.chat.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});

-			return openAIChatToTextGenerationStream(
-				await openai.chat.completions.create({
-					model: model.id ?? model.name,
-					messages: messagesOpenAI,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			return openAIChatToTextGenerationStream(openChatAICompletion);
 		};
 	} else {
 		throw new Error("Invalid completion type");
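
For reference, a minimal sketch of how the new `extraBody` option might be supplied when constructing this endpoint; the `model` object and any fields beyond those in the schema above are assumptions, not part of this diff:

// Hypothetical usage of extraBody (names outside the parsed schema are assumed).
const endpoint = await endpointOai({
	type: "openai",
	baseURL: "https://api.example.com/v1",
	completion: "chat_completions",
	model,
	extraBody: {
		// merged verbatim into the outgoing JSON body
		top_k: 50,
	},
});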