From 0eba8ab6797f90593e2e9825eecf3dccf8865cfd Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Tue, 27 Aug 2024 14:55:46 -0700
Subject: [PATCH] added tests for ollama

---
 tests/unit/fixtures.py | 24 ++++++++++++++++++++++++
 tests/unit/test_lmm.py | 34 ++++++++++++++++++++++++++++++++--
 2 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py
index ccad51e8..a56ebac6 100644
--- a/tests/unit/fixtures.py
+++ b/tests/unit/fixtures.py
@@ -31,3 +31,27 @@ def generator():
         mock_instance = mock.return_value
         mock_instance.chat.completions.create.return_value = mock_generate()
         yield mock_instance
+
+
+@pytest.fixture
+def generate_ollama_lmm_mock(request):
+    content = request.param
+
+    mock_resp = MagicMock()
+    mock_resp.status_code = 200
+    mock_resp.json.return_value = {"response": content}
+    with patch("vision_agent.lmm.lmm.requests.post") as mock:
+        mock.return_value = mock_resp
+        yield mock
+
+
+@pytest.fixture
+def chat_ollama_lmm_mock(request):
+    content = request.param
+
+    mock_resp = MagicMock()
+    mock_resp.status_code = 200
+    mock_resp.json.return_value = {"message": {"content": content}}
+    with patch("vision_agent.lmm.lmm.requests.post") as mock:
+        mock.return_value = mock_resp
+        yield mock
diff --git a/tests/unit/test_lmm.py b/tests/unit/test_lmm.py
index 9cb43650..c954b173 100644
--- a/tests/unit/test_lmm.py
+++ b/tests/unit/test_lmm.py
@@ -1,3 +1,4 @@
+import json
 import tempfile
 from unittest.mock import patch
 
@@ -5,9 +6,13 @@
 import pytest
 from PIL import Image
 
-from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.lmm.lmm import OllamaLMM, OpenAILMM
 
-from .fixtures import openai_lmm_mock  # noqa: F401
+from .fixtures import (  # noqa: F401
+    chat_ollama_lmm_mock,
+    generate_ollama_lmm_mock,
+    openai_lmm_mock,
+)
 
 
 def create_temp_image(image_format="jpeg"):
@@ -135,6 +140,31 @@ def test_call_with_mock_stream(openai_lmm_mock):  # noqa: F811
     )
 
 
+@pytest.mark.parametrize(
+    "generate_ollama_lmm_mock",
+    ["mocked response"],
+    indirect=["generate_ollama_lmm_mock"],
+)
+def test_generate_ollama_mock(generate_ollama_lmm_mock):  # noqa: F811
+    temp_image = create_temp_image()
+    lmm = OllamaLMM()
+    response = lmm.generate("test prompt", media=[temp_image])
+    assert response == "mocked response"
+    call_args = json.loads(generate_ollama_lmm_mock.call_args.kwargs["data"])
+    assert call_args["prompt"] == "test prompt"
+
+
+@pytest.mark.parametrize(
+    "chat_ollama_lmm_mock", ["mocked response"], indirect=["chat_ollama_lmm_mock"]
+)
+def test_chat_ollama_mock(chat_ollama_lmm_mock):  # noqa: F811
+    lmm = OllamaLMM()
+    response = lmm.chat([{"role": "user", "content": "test prompt"}])
+    assert response == "mocked response"
+    call_args = json.loads(chat_ollama_lmm_mock.call_args.kwargs["data"])
+    assert call_args["messages"][0]["content"] == "test prompt"
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock",
     ['{"Parameters": {"prompt": "cat"}}'],