1
1
import { z } from "zod" ;
2
2
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream" ;
3
3
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream" ;
4
+ import type { CompletionCreateParamsStreaming } from "openai/resources/completions" ;
5
+ import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions" ;
4
6
import { buildPrompt } from "$lib/buildPrompt" ;
5
7
import { env } from "$env/dynamic/private" ;
6
8
import type { Endpoint } from "../endpoints" ;
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
16
18
. default ( "chat_completions" ) ,
17
19
defaultHeaders : z . record ( z . string ( ) ) . optional ( ) ,
18
20
defaultQuery : z . record ( z . string ( ) ) . optional ( ) ,
21
+ extraBody : z . record ( z . any ( ) ) . optional ( ) ,
19
22
} ) ;
20
23
21
24
export async function endpointOai (
22
25
input : z . input < typeof endpointOAIParametersSchema >
23
26
) : Promise < Endpoint > {
24
- const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
27
+ const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
25
28
endpointOAIParametersSchema . parse ( input ) ;
26
29
let OpenAI ;
27
30
try {
@@ -47,19 +50,22 @@ export async function endpointOai(
47
50
} ) ;
48
51
49
52
const parameters = { ...model . parameters , ...generateSettings } ;
53
+ const body : CompletionCreateParamsStreaming = {
54
+ model : model . id ?? model . name ,
55
+ prompt,
56
+ stream : true ,
57
+ max_tokens : parameters ?. max_new_tokens ,
58
+ stop : parameters ?. stop ,
59
+ temperature : parameters ?. temperature ,
60
+ top_p : parameters ?. top_p ,
61
+ frequency_penalty : parameters ?. repetition_penalty ,
62
+ } ;
50
63
51
- return openAICompletionToTextGenerationStream (
52
- await openai . completions . create ( {
53
- model : model . id ?? model . name ,
54
- prompt,
55
- stream : true ,
56
- max_tokens : parameters ?. max_new_tokens ,
57
- stop : parameters ?. stop ,
58
- temperature : parameters ?. temperature ,
59
- top_p : parameters ?. top_p ,
60
- frequency_penalty : parameters ?. repetition_penalty ,
61
- } )
62
- ) ;
64
+ const openAICompletion = await openai . completions . create ( body , {
65
+ body : { ...body , ...extraBody } ,
66
+ } ) ;
67
+
68
+ return openAICompletionToTextGenerationStream ( openAICompletion ) ;
63
69
} ;
64
70
} else if ( completion === "chat_completions" ) {
65
71
return async ( { messages, preprompt, generateSettings } ) => {
@@ -77,19 +83,22 @@ export async function endpointOai(
77
83
}
78
84
79
85
const parameters = { ...model . parameters , ...generateSettings } ;
86
+ const body : ChatCompletionCreateParamsStreaming = {
87
+ model : model . id ?? model . name ,
88
+ messages : messagesOpenAI ,
89
+ stream : true ,
90
+ max_tokens : parameters ?. max_new_tokens ,
91
+ stop : parameters ?. stop ,
92
+ temperature : parameters ?. temperature ,
93
+ top_p : parameters ?. top_p ,
94
+ frequency_penalty : parameters ?. repetition_penalty ,
95
+ } ;
96
+
97
+ const openChatAICompletion = await openai . chat . completions . create ( body , {
98
+ body : { ...body , ...extraBody } ,
99
+ } ) ;
80
100
81
- return openAIChatToTextGenerationStream (
82
- await openai . chat . completions . create ( {
83
- model : model . id ?? model . name ,
84
- messages : messagesOpenAI ,
85
- stream : true ,
86
- max_tokens : parameters ?. max_new_tokens ,
87
- stop : parameters ?. stop ,
88
- temperature : parameters ?. temperature ,
89
- top_p : parameters ?. top_p ,
90
- frequency_penalty : parameters ?. repetition_penalty ,
91
- } )
92
- ) ;
101
+ return openAIChatToTextGenerationStream ( openChatAICompletion ) ;
93
102
} ;
94
103
} else {
95
104
throw new Error ( "Invalid completion type" ) ;
0 commit comments