Add unit tests for OpenAILLM / OpenAILMM and move the OpenAI import to module
level so the tests can patch it where it is used.

Review notes:
- test_llm.py: the segmentor test now exercises generate_segmentor/GroundingSAM
  instead of repeating the detector check.
- test_lmm.py: the three tests previously shared the name
  test_generate_classifier, so only the last one was collected; they now have
  distinct names.
- create_temp_image closes its file handle before returning the path.
- lmm.py: the statusCode check is expressed with dict.get (same behavior).
- The post-image blob ids on the `index` lines were not recomputed after these
  review fixes.

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/fixtures.py b/tests/fixtures.py
new file mode 100644
index 00000000..5a64081f
--- /dev/null
+++ b/tests/fixtures.py
@@ -0,0 +1,37 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+@pytest.fixture
+def openai_llm_mock(request):
+    """Patch the OpenAI client used by ``vision_agent.llm.llm``.
+
+    The completion text is taken from ``request.param`` (indirect parametrization).
+    """
+    content = request.param
+    # Note the path here is adjusted to where OpenAI is used, not where it's defined
+    with patch("vision_agent.llm.llm.OpenAI") as mock:
+        # Setup a mock response structure that matches what your code expects
+        mock_instance = mock.return_value
+        mock_instance.chat.completions.create.return_value = MagicMock(
+            choices=[MagicMock(message=MagicMock(content=content))]
+        )
+        yield mock_instance
+
+
+@pytest.fixture
+def openai_lmm_mock(request):
+    """Patch the OpenAI client used by ``vision_agent.lmm.lmm``.
+
+    The completion text is taken from ``request.param`` (indirect parametrization).
+    """
+    content = request.param
+    # Note the path here is adjusted to where OpenAI is used, not where it's defined
+    with patch("vision_agent.lmm.lmm.OpenAI") as mock:
+        # Setup a mock response structure that matches what your code expects
+        mock_instance = mock.return_value
+        mock_instance.chat.completions.create.return_value = MagicMock(
+            choices=[MagicMock(message=MagicMock(content=content))]
+        )
+        yield mock_instance
diff --git a/tests/test_llm.py b/tests/test_llm.py
new file mode 100644
index 00000000..74453a4b
--- /dev/null
+++ b/tests/test_llm.py
@@ -0,0 +1,62 @@
+import pytest
+
+from vision_agent.llm.llm import OpenAILLM
+from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM
+
+from .fixtures import openai_llm_mock  # noqa: F401
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
+)
+def test_generate_with_mock(openai_llm_mock):  # noqa: F811
+    """generate() returns the mocked completion and calls OpenAI with the prompt."""
+    llm = OpenAILLM()
+    response = llm.generate("test prompt")
+    assert response == "mocked response"
+    openai_llm_mock.chat.completions.create.assert_called_once_with(
+        model="gpt-4-turbo-preview",
+        messages=[{"role": "user", "content": "test prompt"}],
+    )
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_llm_mock"],
+)
+def test_generate_classifier(openai_llm_mock):  # noqa: F811
+    """generate_classifier() returns a CLIP tool carrying the mocked prompt."""
+    llm = OpenAILLM()
+    prompt = "Can you generate a cat classifier?"
+    classifier = llm.generate_classifier(prompt)
+    assert isinstance(classifier, CLIP)
+    assert classifier.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_llm_mock"],
+)
+def test_generate_detector(openai_llm_mock):  # noqa: F811
+    """generate_detector() returns a GroundingDINO tool carrying the mocked prompt."""
+    llm = OpenAILLM()
+    prompt = "Can you generate a cat detector?"
+    detector = llm.generate_detector(prompt)
+    assert isinstance(detector, GroundingDINO)
+    assert detector.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_llm_mock"],
+)
+def test_generate_segmentor(openai_llm_mock):  # noqa: F811
+    """generate_segmentor() returns a GroundingSAM tool carrying the mocked prompt."""
+    llm = OpenAILLM()
+    prompt = "Can you generate a cat segmentor?"
+    segmentor = llm.generate_segmentor(prompt)
+    assert isinstance(segmentor, GroundingSAM)
+    assert segmentor.prompt == "cat"
diff --git a/tests/test_lmm.py b/tests/test_lmm.py
new file mode 100644
index 00000000..97cce581
--- /dev/null
+++ b/tests/test_lmm.py
@@ -0,0 +1,78 @@
+import tempfile
+
+import pytest
+from PIL import Image
+
+from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM
+
+from .fixtures import openai_lmm_mock  # noqa: F401
+
+
+def create_temp_image(image_format="jpeg"):
+    """Write a 100x100 red image to a named temporary file and return its path."""
+    image = Image.new("RGB", (100, 100), color=(255, 0, 0))
+    # Close the handle before returning so the image is flushed to disk; the file
+    # itself is kept (delete=False) because callers open it by path.
+    with tempfile.NamedTemporaryFile(suffix=f".{image_format}", delete=False) as f:
+        image.save(f, format=image_format)
+    return f.name
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
+    """generate() returns the mocked completion and attaches an image_url part."""
+    temp_image = create_temp_image()
+    lmm = OpenAILMM()
+    response = lmm.generate("test prompt", image=temp_image)
+    assert response == "mocked response"
+    assert (
+        "image_url"
+        in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][1]
+    )
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_lmm_mock"],
+)
+def test_generate_classifier(openai_lmm_mock):  # noqa: F811
+    """generate_classifier() returns a CLIP tool carrying the mocked prompt."""
+    lmm = OpenAILMM()
+    prompt = "Can you generate a cat classifier?"
+    classifier = lmm.generate_classifier(prompt)
+    assert isinstance(classifier, CLIP)
+    assert classifier.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_lmm_mock"],
+)
+def test_generate_detector(openai_lmm_mock):  # noqa: F811
+    """generate_detector() returns a GroundingDINO tool carrying the mocked prompt."""
+    lmm = OpenAILMM()
+    prompt = "Can you generate a cat detector?"
+    detector = lmm.generate_detector(prompt)
+    assert isinstance(detector, GroundingDINO)
+    assert detector.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_lmm_mock"],
+)
+def test_generate_segmentor(openai_lmm_mock):  # noqa: F811
+    """generate_segmentor() returns a GroundingSAM tool carrying the mocked prompt."""
+    lmm = OpenAILMM()
+    prompt = "Can you generate a cat segmentor?"
+    segmentor = lmm.generate_segmentor(prompt)
+    assert isinstance(segmentor, GroundingSAM)
+    assert segmentor.prompt == "cat"
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 09b4afc8..af90f1b0 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -1,6 +1,8 @@
 import json
 from abc import ABC, abstractmethod
 from typing import Mapping, cast
+
+from openai import OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -22,8 +24,6 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
     def __init__(self, model_name: str = "gpt-4-turbo-preview"):
-        from openai import OpenAI
-
         self.model_name = model_name
         self.client = OpenAI()
 
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index e46eeb4e..b6c20e27 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -6,6 +6,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Union, cast
 
 import requests
+from openai import OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -59,8 +60,10 @@ def generate(
             json=data,
         )
         resp_json: Dict[str, Any] = res.json()
-        if resp_json["statusCode"] != 200:
-            _LOGGER.error(f"Request failed: {resp_json['data']}")
+        # A missing "statusCode" is treated as a failure, same as a non-200 one.
+        if resp_json.get("statusCode") != 200:
+            _LOGGER.error(f"Request failed: {resp_json}")
+            raise ValueError(f"Request failed: {resp_json}")
         return cast(str, resp_json["data"])
 
 
@@ -68,8 +71,6 @@ class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(self, model_name: str = "gpt-4-vision-preview"):
-        from openai import OpenAI
-
         self.model_name = model_name
         self.client = OpenAI()
 