@@ -74,15 +74,11 @@ async def _run(self, input: Dict | str) -> RunOutput:
74
74
raise ValueError ("cot_prompt is required for cot_two_call strategy" )
75
75
messages .append ({"role" : "system" , "content" : cot_prompt })
76
76
77
- # First call for chain of thought
78
- cot_response = await litellm .acompletion (
79
- model = self .litellm_model_id (),
80
- messages = messages ,
81
- api_base = self ._api_base ,
82
- headers = self ._headers ,
83
- # TODO P1 - remove type ignore
84
- ** self ._additional_body_options , # type: ignore
77
+ # First call for chain of thought - No logprobs as only needed for final answer
78
+ completion_kwargs = await self .build_completion_kwargs (
79
+ provider , messages , None
85
80
)
81
+ cot_response = await litellm .acompletion (** completion_kwargs )
86
82
if (
87
83
not isinstance (cot_response , ModelResponse )
88
84
or not cot_response .choices
@@ -103,32 +99,10 @@ async def _run(self, input: Dict | str) -> RunOutput:
103
99
]
104
100
)
105
101
106
- # Build custom request params based on model provider
107
- extra_body = self .build_extra_body (provider )
108
-
109
- # Main completion call
110
- response_format_options = await self .response_format_options ()
111
-
112
- # Merge all parameters into a single kwargs dict for litellm
113
- # TODO P0 - make this shared
114
- completion_kwargs = {
115
- "model" : self .litellm_model_id (),
116
- "messages" : messages ,
117
- "api_base" : self ._api_base ,
118
- "headers" : self ._headers ,
119
- ** extra_body ,
120
- ** self ._additional_body_options ,
121
- }
122
-
123
- # Add logprobs if requested
124
- if self .base_adapter_config .top_logprobs is not None :
125
- completion_kwargs ["logprobs" ] = True
126
- completion_kwargs ["top_logprobs" ] = self .base_adapter_config .top_logprobs
127
-
128
- # Add response format options
129
- completion_kwargs .update (response_format_options )
130
-
131
102
# Make the API call using litellm
103
+ completion_kwargs = await self .build_completion_kwargs (
104
+ provider , messages , self .base_adapter_config .top_logprobs
105
+ )
132
106
response = await litellm .acompletion (** completion_kwargs )
133
107
134
108
if not isinstance (response , ModelResponse ):
@@ -379,3 +353,31 @@ def litellm_model_id(self) -> str:
379
353
380
354
self ._litellm_model_id = litellm_provider_name + "/" + provider .model_id
381
355
return self ._litellm_model_id
356
+
357
+ async def build_completion_kwargs (
358
+ self ,
359
+ provider : KilnModelProvider ,
360
+ messages : list [dict [str , Any ]],
361
+ top_logprobs : int | None ,
362
+ ) -> dict [str , Any ]:
363
+ extra_body = self .build_extra_body (provider )
364
+
365
+ # Merge all parameters into a single kwargs dict for litellm
366
+ completion_kwargs = {
367
+ "model" : self .litellm_model_id (),
368
+ "messages" : messages ,
369
+ "api_base" : self ._api_base ,
370
+ "headers" : self ._headers ,
371
+ ** extra_body ,
372
+ ** self ._additional_body_options ,
373
+ }
374
+
375
+ # Response format: json_schema, json_instructions, json_mode, function_calling, etc
376
+ response_format_options = await self .response_format_options ()
377
+ completion_kwargs .update (response_format_options )
378
+
379
+ if top_logprobs is not None :
380
+ completion_kwargs ["logprobs" ] = True
381
+ completion_kwargs ["top_logprobs" ] = top_logprobs
382
+
383
+ return completion_kwargs
0 commit comments