
Commit f895484

Better code reuse and tests for LiteLLM params
1 parent 5a8f388

File tree

3 files changed: 99 additions, 34 deletions

.cursorrules

Lines changed: 6 additions & 1 deletion
```diff
@@ -1,4 +1,9 @@
+- call me "boss"
 - Always assume pydantic 2 (not pydantic 1)
-- Always use pytest for tests
 - The project supports Python 3.10 and above
+- When writing tests:
+1) Always use pytest for tests in python code
+2) Assume an appropriate test file already exists, find it, and suggest tests get appended to that file. If no such file exists, ask me before assuming a new test file is the correct route.
+3) Test brevity is important. Use approaches for re-use and brevity, including fixtures for repeated code and pytest parametrize for similar tests.
+
 
```
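For context, here is a minimal sketch of the testing style rules 1)–3) describe: pytest throughout, a fixture for repeated setup, and `pytest.mark.parametrize` to collapse similar cases. The `normalize_name` function and the data in it are hypothetical, invented for illustration only:

```python
import pytest


def normalize_name(raw: str) -> str:
    # Hypothetical function under test, defined inline so the sketch runs standalone
    return "-".join(raw.strip().lower().split())


@pytest.fixture
def sample_names() -> list[str]:
    # Rule 3: setup that several tests would repeat lives in a fixture
    return ["Ada Lovelace", " alan turing "]


# Rule 3: similar cases collapse into one parametrized test
@pytest.mark.parametrize(
    "raw,expected",
    [
        ("Ada Lovelace", "ada-lovelace"),
        (" alan turing ", "alan-turing"),
        ("GRACE  HOPPER", "grace-hopper"),
    ],
)
def test_normalize_name(raw, expected):
    assert normalize_name(raw) == expected


def test_sample_names_all_normalize(sample_names):
    assert [normalize_name(n) for n in sample_names] == ["ada-lovelace", "alan-turing"]
```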

libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py

Lines changed: 35 additions & 33 deletions
```diff
@@ -74,15 +74,11 @@ async def _run(self, input: Dict | str) -> RunOutput:
                 raise ValueError("cot_prompt is required for cot_two_call strategy")
             messages.append({"role": "system", "content": cot_prompt})
 
-            # First call for chain of thought
-            cot_response = await litellm.acompletion(
-                model=self.litellm_model_id(),
-                messages=messages,
-                api_base=self._api_base,
-                headers=self._headers,
-                # TODO P1 - remove type ignore
-                **self._additional_body_options,  # type: ignore
+            # First call for chain of thought - No logprobs as only needed for final answer
+            completion_kwargs = await self.build_completion_kwargs(
+                provider, messages, None
             )
+            cot_response = await litellm.acompletion(**completion_kwargs)
             if (
                 not isinstance(cot_response, ModelResponse)
                 or not cot_response.choices
@@ -103,32 +99,10 @@ async def _run(self, input: Dict | str) -> RunOutput:
             ]
         )
 
-        # Build custom request params based on model provider
-        extra_body = self.build_extra_body(provider)
-
-        # Main completion call
-        response_format_options = await self.response_format_options()
-
-        # Merge all parameters into a single kwargs dict for litellm
-        # TODO P0 - make this shared
-        completion_kwargs = {
-            "model": self.litellm_model_id(),
-            "messages": messages,
-            "api_base": self._api_base,
-            "headers": self._headers,
-            **extra_body,
-            **self._additional_body_options,
-        }
-
-        # Add logprobs if requested
-        if self.base_adapter_config.top_logprobs is not None:
-            completion_kwargs["logprobs"] = True
-            completion_kwargs["top_logprobs"] = self.base_adapter_config.top_logprobs
-
-        # Add response format options
-        completion_kwargs.update(response_format_options)
-
         # Make the API call using litellm
+        completion_kwargs = await self.build_completion_kwargs(
+            provider, messages, self.base_adapter_config.top_logprobs
+        )
         response = await litellm.acompletion(**completion_kwargs)
 
         if not isinstance(response, ModelResponse):
@@ -379,3 +353,31 @@ def litellm_model_id(self) -> str:
 
         self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
         return self._litellm_model_id
+
+    async def build_completion_kwargs(
+        self,
+        provider: KilnModelProvider,
+        messages: list[dict[str, Any]],
+        top_logprobs: int | None,
+    ) -> dict[str, Any]:
+        extra_body = self.build_extra_body(provider)
+
+        # Merge all parameters into a single kwargs dict for litellm
+        completion_kwargs = {
+            "model": self.litellm_model_id(),
+            "messages": messages,
+            "api_base": self._api_base,
+            "headers": self._headers,
+            **extra_body,
+            **self._additional_body_options,
+        }
+
+        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
+        response_format_options = await self.response_format_options()
+        completion_kwargs.update(response_format_options)
+
+        if top_logprobs is not None:
+            completion_kwargs["logprobs"] = True
+            completion_kwargs["top_logprobs"] = top_logprobs
+
+        return completion_kwargs
```
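To make the new helper's merge order concrete, here is a hypothetical illustration of the dict it would return, assuming `build_extra_body` yields `{"reasoning_effort": 0.8}`, `response_format_options` yields a JSON response format, and `top_logprobs=3`. The model id, URL, and message are placeholders, not values from this repo:

```python
# Hypothetical return value of build_completion_kwargs(provider, messages, 3).
# Later stages win on key collisions: extra_body and _additional_body_options
# are merged into the base dict first, then response_format_options is applied
# via update(), then the logprobs pair is set.
completion_kwargs = {
    "model": "openrouter/example-model",  # litellm_model_id()
    "messages": [{"role": "user", "content": "Hello"}],
    "api_base": "https://example.com/v1",  # self._api_base
    "headers": {},  # self._headers
    "reasoning_effort": 0.8,  # merged from build_extra_body(provider)
    "response_format": {"type": "json_object"},  # from response_format_options()
    "logprobs": True,  # set because top_logprobs is not None
    "top_logprobs": 3,
}
```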

libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py

Lines changed: 58 additions & 0 deletions
```diff
@@ -337,3 +337,61 @@ def test_litellm_model_id_unknown_provider(config, mock_task):
 
     with pytest.raises(Exception, match="Test error"):
         adapter.litellm_model_id()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "top_logprobs,response_format,extra_body",
+    [
+        (None, {}, {}),  # Basic case
+        (5, {}, {}),  # With logprobs
+        (
+            None,
+            {"response_format": {"type": "json_object"}},
+            {},
+        ),  # With response format
+        (
+            3,
+            {"tools": [{"type": "function"}]},
+            {"reasoning_effort": 0.8},
+        ),  # Combined options
+    ],
+)
+async def test_build_completion_kwargs(
+    config, mock_task, top_logprobs, response_format, extra_body
+):
+    """Test build_completion_kwargs with various configurations"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value=extra_body),
+        patch.object(adapter, "response_format_options", return_value=response_format),
+    ):
+        kwargs = await adapter.build_completion_kwargs(
+            mock_provider, messages, top_logprobs
+        )
+
+    # Verify core functionality
+    assert kwargs["model"] == "openai/test-model"
+    assert kwargs["messages"] == messages
+    assert kwargs["api_base"] == config.base_url
+
+    # Verify optional parameters
+    if top_logprobs is not None:
+        assert kwargs["logprobs"] is True
+        assert kwargs["top_logprobs"] == top_logprobs
+    else:
+        assert "logprobs" not in kwargs
+        assert "top_logprobs" not in kwargs
+
+    # Verify response format is included
+    for key, value in response_format.items():
+        assert kwargs[key] == value
+
+    # Verify extra body is included
+    for key, value in extra_body.items():
+        assert kwargs[key] == value
```
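With the parametrization above, this one test body covers all four configurations. To run just the new cases locally, something like `pytest libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py -k test_build_completion_kwargs` should select them (standard pytest `-k` filtering; the path is from the file tree above).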
