Skip to content

Commit 4bd5cb3

Browse files
committed
Add tests
1 parent 13ebbee commit 4bd5cb3

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
omit =
33
**/test_*.py
44
libs/core/kiln_ai/adapters/ml_model_list.py
5+
conftest.py

libs/core/kiln_ai/adapters/test_langchain_adapter.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
import os
12
from unittest.mock import AsyncMock, MagicMock, patch
23

4+
import pytest
5+
from langchain_aws import ChatBedrockConverse
36
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
7+
from langchain_fireworks import ChatFireworks
48
from langchain_groq import ChatGroq
9+
from langchain_ollama import ChatOllama
10+
from langchain_openai import ChatOpenAI
511

612
from kiln_ai.adapters.langchain_adapters import (
713
LangchainAdapter,
814
get_structured_output_options,
15+
langchain_model_from_provider,
916
)
17+
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
1018
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
1119
from kiln_ai.adapters.test_prompt_adaptors import build_test_task
1220

@@ -150,3 +158,178 @@ async def test_get_structured_output_options():
150158
):
151159
options = await get_structured_output_options("model_name", "provider")
152160
assert options == {}
161+
162+
163+
@pytest.mark.asyncio
164+
async def test_langchain_model_from_provider_openai():
165+
provider = KilnModelProvider(
166+
name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
167+
)
168+
169+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
170+
mock_config.return_value.open_ai_api_key = "test_key"
171+
model = await langchain_model_from_provider(provider, "gpt-4")
172+
assert isinstance(model, ChatOpenAI)
173+
assert model.model_name == "gpt-4"
174+
175+
176+
@pytest.mark.asyncio
177+
async def test_langchain_model_from_provider_groq():
178+
provider = KilnModelProvider(
179+
name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
180+
)
181+
182+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
183+
mock_config.return_value.groq_api_key = "test_key"
184+
model = await langchain_model_from_provider(provider, "mixtral-8x7b")
185+
assert isinstance(model, ChatGroq)
186+
assert model.model_name == "mixtral-8x7b"
187+
188+
189+
@pytest.mark.asyncio
190+
async def test_langchain_model_from_provider_bedrock():
191+
provider = KilnModelProvider(
192+
name=ModelProviderName.amazon_bedrock,
193+
provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
194+
)
195+
196+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
197+
mock_config.return_value.bedrock_access_key = "test_access"
198+
mock_config.return_value.bedrock_secret_key = "test_secret"
199+
model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
200+
assert isinstance(model, ChatBedrockConverse)
201+
assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
202+
assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
203+
204+
205+
@pytest.mark.asyncio
206+
async def test_langchain_model_from_provider_fireworks():
207+
provider = KilnModelProvider(
208+
name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
209+
)
210+
211+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
212+
mock_config.return_value.fireworks_api_key = "test_key"
213+
model = await langchain_model_from_provider(provider, "mixtral-8x7b")
214+
assert isinstance(model, ChatFireworks)
215+
216+
217+
@pytest.mark.asyncio
218+
async def test_langchain_model_from_provider_ollama():
219+
provider = KilnModelProvider(
220+
name=ModelProviderName.ollama,
221+
provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
222+
)
223+
224+
mock_connection = MagicMock()
225+
with (
226+
patch(
227+
"kiln_ai.adapters.langchain_adapters.get_ollama_connection",
228+
return_value=AsyncMock(return_value=mock_connection),
229+
),
230+
patch(
231+
"kiln_ai.adapters.langchain_adapters.ollama_model_installed",
232+
return_value=True,
233+
),
234+
patch(
235+
"kiln_ai.adapters.langchain_adapters.ollama_base_url",
236+
return_value="http://localhost:11434",
237+
),
238+
):
239+
model = await langchain_model_from_provider(provider, "llama2")
240+
assert isinstance(model, ChatOllama)
241+
assert model.model == "llama2"
242+
243+
244+
@pytest.mark.asyncio
245+
async def test_langchain_model_from_provider_invalid():
246+
provider = KilnModelProvider.model_construct(
247+
name="invalid_provider", provider_options={}
248+
)
249+
250+
with pytest.raises(ValueError, match="Invalid model or provider"):
251+
await langchain_model_from_provider(provider, "test_model")
252+
253+
254+
@pytest.mark.asyncio
255+
async def test_langchain_adapter_model_caching(tmp_path):
256+
task = build_test_task(tmp_path)
257+
custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
258+
259+
adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
260+
261+
# First call should return the cached model
262+
model1 = await adapter.model()
263+
assert model1 is custom_model
264+
265+
# Second call should return the same cached instance
266+
model2 = await adapter.model()
267+
assert model2 is model1
268+
269+
270+
@pytest.mark.asyncio
271+
async def test_langchain_adapter_model_structured_output(tmp_path):
272+
task = build_test_task(tmp_path)
273+
task.output_json_schema = """
274+
{
275+
"type": "object",
276+
"properties": {
277+
"count": {"type": "integer"}
278+
}
279+
}
280+
"""
281+
282+
mock_model = MagicMock()
283+
mock_model.with_structured_output = MagicMock(return_value="structured_model")
284+
285+
adapter = LangchainAdapter(
286+
kiln_task=task, model_name="test_model", provider="test_provider"
287+
)
288+
289+
with (
290+
patch(
291+
"kiln_ai.adapters.langchain_adapters.langchain_model_from",
292+
AsyncMock(return_value=mock_model),
293+
),
294+
patch(
295+
"kiln_ai.adapters.langchain_adapters.get_structured_output_options",
296+
AsyncMock(return_value={"option1": "value1"}),
297+
),
298+
):
299+
model = await adapter.model()
300+
301+
# Verify the model was configured with structured output
302+
mock_model.with_structured_output.assert_called_once_with(
303+
{
304+
"type": "object",
305+
"properties": {"count": {"type": "integer"}},
306+
"title": "task_response",
307+
"description": "A response from the task",
308+
},
309+
include_raw=True,
310+
option1="value1",
311+
)
312+
assert model == "structured_model"
313+
314+
315+
@pytest.mark.asyncio
316+
async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
317+
task = build_test_task(tmp_path)
318+
task.output_json_schema = (
319+
'{"type": "object", "properties": {"count": {"type": "integer"}}}'
320+
)
321+
322+
mock_model = MagicMock()
323+
# Remove with_structured_output method
324+
del mock_model.with_structured_output
325+
326+
adapter = LangchainAdapter(
327+
kiln_task=task, model_name="test_model", provider="test_provider"
328+
)
329+
330+
with patch(
331+
"kiln_ai.adapters.langchain_adapters.langchain_model_from",
332+
AsyncMock(return_value=mock_model),
333+
):
334+
with pytest.raises(ValueError, match="does not support structured output"):
335+
await adapter.model()

0 commit comments

Comments
 (0)