added tests for ollama
dillonalaird committed Aug 27, 2024
1 parent c14f147 commit 0eba8ab
Showing 2 changed files with 56 additions and 2 deletions.
24 changes: 24 additions & 0 deletions tests/unit/fixtures.py
@@ -31,3 +31,27 @@ def generator():
         mock_instance = mock.return_value
         mock_instance.chat.completions.create.return_value = mock_generate()
         yield mock_instance
+
+
+@pytest.fixture
+def generate_ollama_lmm_mock(request):
+    content = request.param
+
+    mock_resp = MagicMock()
+    mock_resp.status_code = 200
+    mock_resp.json.return_value = {"response": content}
+    with patch("vision_agent.lmm.lmm.requests.post") as mock:
+        mock.return_value = mock_resp
+        yield mock
+
+
+@pytest.fixture
+def chat_ollama_lmm_mock(request):
+    content = request.param
+
+    mock_resp = MagicMock()
+    mock_resp.status_code = 200
+    mock_resp.json.return_value = {"message": {"content": content}}
+    with patch("vision_agent.lmm.lmm.requests.post") as mock:
+        mock.return_value = mock_resp
+        yield mock
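Both fixtures follow the same pattern: the canned reply text arrives through request.param (filled in by indirect parametrization in the tests below), and requests.post is patched inside the vision_agent.lmm.lmm module, so the tests never touch a live Ollama server. The two payload shapes, {"response": ...} and {"message": {"content": ...}}, mirror what Ollama's generate and chat endpoints return. For orientation, here is a rough sketch of the kind of real call these mocks stand in for; the endpoint URL and model name are illustrative assumptions, not taken from this commit:

import json
import requests

# Hypothetical request along the lines of what OllamaLMM.generate would send;
# only the {"response": ...} reply shape is confirmed by the fixture above.
resp = requests.post(
    "http://localhost:11434/api/generate",  # assumed default local Ollama endpoint
    data=json.dumps({"model": "llava", "prompt": "test prompt"}),
)
resp.raise_for_status()
print(resp.json()["response"])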
34 changes: 32 additions & 2 deletions tests/unit/test_lmm.py
@@ -1,13 +1,18 @@
 import json
 import tempfile
 from unittest.mock import patch
 
 import numpy as np
 import pytest
 from PIL import Image
 
-from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.lmm.lmm import OllamaLMM, OpenAILMM
 
-from .fixtures import openai_lmm_mock  # noqa: F401
+from .fixtures import (  # noqa: F401
+    chat_ollama_lmm_mock,
+    generate_ollama_lmm_mock,
+    openai_lmm_mock,
+)
 
+
 def create_temp_image(image_format="jpeg"):
@@ -135,6 +140,31 @@ def test_call_with_mock_stream(openai_lmm_mock):  # noqa: F811
     )
 
 
+@pytest.mark.parametrize(
+    "generate_ollama_lmm_mock",
+    ["mocked response"],
+    indirect=["generate_ollama_lmm_mock"],
+)
+def test_generate_ollama_mock(generate_ollama_lmm_mock):  # noqa: F811
+    temp_image = create_temp_image()
+    lmm = OllamaLMM()
+    response = lmm.generate("test prompt", media=[temp_image])
+    assert response == "mocked response"
+    call_args = json.loads(generate_ollama_lmm_mock.call_args.kwargs["data"])
+    assert call_args["prompt"] == "test prompt"
+
+
+@pytest.mark.parametrize(
+    "chat_ollama_lmm_mock", ["mocked response"], indirect=["chat_ollama_lmm_mock"]
+)
+def test_chat_ollama_mock(chat_ollama_lmm_mock):  # noqa: F811
+    lmm = OllamaLMM()
+    response = lmm.chat([{"role": "user", "content": "test prompt"}])
+    assert response == "mocked response"
+    call_args = json.loads(chat_ollama_lmm_mock.call_args.kwargs["data"])
+    assert call_args["messages"][0]["content"] == "test prompt"
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock",
     ['{"Parameters": {"prompt": "cat"}}'],
…
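The final assertions in each test lean on unittest.mock's call recording: the fixture yields the patched requests.post mock, so the test can read back the keyword arguments of the recorded call and verify the JSON body that OllamaLMM serialized. A minimal, self-contained illustration of that pattern (the names here are generic, not from the commit):

import json
from unittest.mock import MagicMock

fake_post = MagicMock()
# The code under test would invoke the patched function like this:
fake_post("http://example/api/generate", data=json.dumps({"prompt": "hi"}))

# The test then inspects what was actually sent:
sent = json.loads(fake_post.call_args.kwargs["data"])
assert sent["prompt"] == "hi"

With the fixtures wired up, the new tests can be run on their own with something like: pytest tests/unit/test_lmm.py -k ollama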
