|
| 1 | +import os |
1 | 2 | from unittest.mock import AsyncMock, MagicMock, patch
|
2 | 3 |
|
| 4 | +import pytest |
| 5 | +from langchain_aws import ChatBedrockConverse |
3 | 6 | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 7 | +from langchain_fireworks import ChatFireworks |
4 | 8 | from langchain_groq import ChatGroq
|
| 9 | +from langchain_ollama import ChatOllama |
| 10 | +from langchain_openai import ChatOpenAI |
5 | 11 |
|
6 | 12 | from kiln_ai.adapters.langchain_adapters import (
|
7 | 13 | LangchainAdapter,
|
8 | 14 | get_structured_output_options,
|
| 15 | + langchain_model_from_provider, |
9 | 16 | )
|
| 17 | +from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName |
10 | 18 | from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
|
11 | 19 | from kiln_ai.adapters.test_prompt_adaptors import build_test_task
|
12 | 20 |
|
@@ -150,3 +158,178 @@ async def test_get_structured_output_options():
|
150 | 158 | ):
|
151 | 159 | options = await get_structured_output_options("model_name", "provider")
|
152 | 160 | assert options == {}
|
| 161 | + |
| 162 | + |
@pytest.mark.asyncio
async def test_langchain_model_from_provider_openai():
    """An OpenAI provider entry should build a ChatOpenAI for the named model."""
    openai_provider = KilnModelProvider(
        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as shared_config:
        # The adapter reads the API key from the shared config singleton.
        shared_config.return_value.open_ai_api_key = "test_key"
        built = await langchain_model_from_provider(openai_provider, "gpt-4")
        assert isinstance(built, ChatOpenAI)
        assert built.model_name == "gpt-4"
| 174 | + |
| 175 | + |
@pytest.mark.asyncio
async def test_langchain_model_from_provider_groq():
    """A Groq provider entry should build a ChatGroq for the named model."""
    groq_provider = KilnModelProvider(
        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as shared_config:
        # The adapter reads the API key from the shared config singleton.
        shared_config.return_value.groq_api_key = "test_key"
        built = await langchain_model_from_provider(groq_provider, "mixtral-8x7b")
        assert isinstance(built, ChatGroq)
        assert built.model_name == "mixtral-8x7b"
| 187 | + |
| 188 | + |
@pytest.mark.asyncio
async def test_langchain_model_from_provider_bedrock():
    """A Bedrock provider should build ChatBedrockConverse and export AWS creds."""
    bedrock_provider = KilnModelProvider(
        name=ModelProviderName.amazon_bedrock,
        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as shared_config:
        shared_config.return_value.bedrock_access_key = "test_access"
        shared_config.return_value.bedrock_secret_key = "test_secret"
        built = await langchain_model_from_provider(
            bedrock_provider, "anthropic.claude-v2"
        )
        assert isinstance(built, ChatBedrockConverse)
        # Building a Bedrock model is expected to export the configured
        # credentials into the process environment for boto3 to pick up.
        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
| 203 | + |
| 204 | + |
@pytest.mark.asyncio
async def test_langchain_model_from_provider_fireworks():
    """A Fireworks provider entry should build a ChatFireworks instance."""
    fireworks_provider = KilnModelProvider(
        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
    )

    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as shared_config:
        # The adapter reads the API key from the shared config singleton.
        shared_config.return_value.fireworks_api_key = "test_key"
        built = await langchain_model_from_provider(fireworks_provider, "mixtral-8x7b")
        assert isinstance(built, ChatFireworks)
| 215 | + |
| 216 | + |
@pytest.mark.asyncio
async def test_langchain_model_from_provider_ollama():
    """Ollama provider builds a ChatOllama once the model is confirmed installed.

    All Ollama-side collaborators (connection probe, install check, base URL)
    are patched out so no local Ollama server is required.
    """
    provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
    )

    mock_connection = MagicMock()
    with (
        patch(
            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
            # NOTE(review): this makes get_ollama_connection() return an
            # AsyncMock (a coroutine factory), not mock_connection directly —
            # the adapter would need to call *and await* that result for it to
            # resolve to mock_connection. Confirm against the call site in
            # langchain_adapters.
            return_value=AsyncMock(return_value=mock_connection),
        ),
        patch(
            # Short-circuit the "is the model pulled?" check to True.
            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
            return_value=True,
        ),
        patch(
            # Pin the server URL so the test never touches real config.
            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
            return_value="http://localhost:11434",
        ),
    ):
        model = await langchain_model_from_provider(provider, "llama2")
        assert isinstance(model, ChatOllama)
        assert model.model == "llama2"
| 242 | + |
| 243 | + |
@pytest.mark.asyncio
async def test_langchain_model_from_provider_invalid():
    """An unknown provider name must raise rather than silently build a model."""
    # model_construct bypasses pydantic validation so an out-of-enum provider
    # name can be injected to exercise the error path.
    bogus_provider = KilnModelProvider.model_construct(
        name="invalid_provider", provider_options={}
    )

    with pytest.raises(ValueError, match="Invalid model or provider"):
        await langchain_model_from_provider(bogus_provider, "test_model")
| 252 | + |
| 253 | + |
@pytest.mark.asyncio
async def test_langchain_adapter_model_caching(tmp_path):
    """A custom model handed to the adapter is reused verbatim on every call."""
    task = build_test_task(tmp_path)
    preloaded_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")

    adapter = LangchainAdapter(kiln_task=task, custom_model=preloaded_model)

    # Every call must hand back the very same object — identity, not equality —
    # proving the adapter caches rather than rebuilding the model.
    assert await adapter.model() is preloaded_model
    assert await adapter.model() is preloaded_model
| 268 | + |
| 269 | + |
@pytest.mark.asyncio
async def test_langchain_adapter_model_structured_output(tmp_path):
    """When the task defines an output schema, the base model must be wrapped
    via with_structured_output using the task schema plus provider options."""
    task = build_test_task(tmp_path)
    task.output_json_schema = """
    {
        "type": "object",
        "properties": {
            "count": {"type": "integer"}
        }
    }
    """

    base_model = MagicMock()
    base_model.with_structured_output = MagicMock(return_value="structured_model")

    adapter = LangchainAdapter(
        kiln_task=task, model_name="test_model", provider="test_provider"
    )

    with (
        patch(
            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
            AsyncMock(return_value=base_model),
        ),
        patch(
            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
            AsyncMock(return_value={"option1": "value1"}),
        ),
    ):
        wrapped = await adapter.model()

    # The schema passed through is the parsed task schema augmented with a
    # title/description, plus the provider-specific structured-output options.
    base_model.with_structured_output.assert_called_once_with(
        {
            "type": "object",
            "properties": {"count": {"type": "integer"}},
            "title": "task_response",
            "description": "A response from the task",
        },
        include_raw=True,
        option1="value1",
    )
    assert wrapped == "structured_model"
| 313 | + |
| 314 | + |
@pytest.mark.asyncio
async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
    """A model lacking with_structured_output cannot serve a structured task."""
    task = build_test_task(tmp_path)
    task.output_json_schema = (
        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
    )

    bare_model = MagicMock()
    # Deleting the attribute makes attribute/hasattr probes on the mock fail,
    # simulating a chat model with no structured-output support.
    del bare_model.with_structured_output

    adapter = LangchainAdapter(
        kiln_task=task, model_name="test_model", provider="test_provider"
    )

    with patch(
        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
        AsyncMock(return_value=bare_model),
    ):
        with pytest.raises(ValueError, match="does not support structured output"):
            await adapter.model()
0 commit comments