From bfca55bd17bed5bc9de7a5abbb7412d91eec7993 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Tue, 12 Mar 2024 13:20:26 -0700
Subject: [PATCH 1/3] added tests

---
 tests/__init__.py       |  0
 tests/fixtures.py       | 29 +++++++++++++++++
 tests/test_llm.py       | 59 +++++++++++++++++++++++++++++++++
 tests/test_lmm.py       | 72 +++++++++++++++++++++++++++++++++++++++++
 vision_agent/llm/llm.py |  3 +-
 vision_agent/lmm/lmm.py | 10 +++---
 6 files changed, 167 insertions(+), 6 deletions(-)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/fixtures.py
 create mode 100644 tests/test_llm.py
 create mode 100644 tests/test_lmm.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/fixtures.py b/tests/fixtures.py
new file mode 100644
index 00000000..5a64081f
--- /dev/null
+++ b/tests/fixtures.py
@@ -0,0 +1,29 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+@pytest.fixture
+def openai_llm_mock(request):
+    content = request.param
+    # Note the path here is adjusted to where OpenAI is used, not where it's defined
+    with patch("vision_agent.llm.llm.OpenAI") as mock:
+        # Setup a mock response structure that matches what your code expects
+        mock_instance = mock.return_value
+        mock_instance.chat.completions.create.return_value = MagicMock(
+            choices=[MagicMock(message=MagicMock(content=content))]
+        )
+        yield mock_instance
+
+
+@pytest.fixture
+def openai_lmm_mock(request):
+    content = request.param
+    # Note the path here is adjusted to where OpenAI is used, not where it's defined
+    with patch("vision_agent.lmm.lmm.OpenAI") as mock:
+        # Setup a mock response structure that matches what your code expects
+        mock_instance = mock.return_value
+        mock_instance.chat.completions.create.return_value = MagicMock(
+            choices=[MagicMock(message=MagicMock(content=content))]
+        )
+        yield mock_instance
diff --git a/tests/test_llm.py b/tests/test_llm.py
new file mode 100644
index 00000000..18a96fc4
--- /dev/null
+++ b/tests/test_llm.py
@@ -0,0 +1,59 @@
+import pytest
+
+from vision_agent.llm.llm import OpenAILLM
+from vision_agent.tools import CLIP
+from vision_agent.tools.tools import GroundingDINO
+
+from .fixtures import openai_llm_mock
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
+)
+def test_generate_with_mock(openai_llm_mock):
+    llm = OpenAILLM()
+    response = llm.generate("test prompt")
+    assert response == "mocked response"
+    openai_llm_mock.chat.completions.create.assert_called_once_with(
+        model="gpt-4-turbo-preview",
+        messages=[{"role": "user", "content": "test prompt"}],
+    )
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_llm_mock"],
+)
+def test_generate_classifier(openai_llm_mock):
+    llm = OpenAILLM()
+    prompt = "Can you generate a cat classifier?"
+    classifier = llm.generate_classifier(prompt)
+    assert isinstance(classifier, CLIP)
+    assert classifier.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_llm_mock"],
+)
+def test_generate_detector(openai_llm_mock):
+    llm = OpenAILLM()
+    prompt = "Can you generate a cat detector?"
+    detector = llm.generate_detector(prompt)
+    assert isinstance(detector, GroundingDINO)
+    assert detector.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_llm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_llm_mock"],
+)
+def test_generate_segmentor(openai_llm_mock):
+    llm = OpenAILLM()
+    prompt = "Can you generate a cat segmentor?"
+    segmentor = llm.generate_detector(prompt)
+    assert isinstance(segmentor, GroundingDINO)
+    assert segmentor.prompt == "cat"
diff --git a/tests/test_lmm.py b/tests/test_lmm.py
new file mode 100644
index 00000000..af356de6
--- /dev/null
+++ b/tests/test_lmm.py
@@ -0,0 +1,72 @@
+import tempfile
+
+import pytest
+from PIL import Image
+
+from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM
+
+from .fixtures import openai_lmm_mock
+
+
+def create_temp_image(image_format="jpeg"):
+    temp_file = tempfile.NamedTemporaryFile(suffix=f".{image_format}", delete=False)
+    image = Image.new("RGB", (100, 100), color=(255, 0, 0))
+    image.save(temp_file, format=image_format)
+    temp_file.seek(0)
+    return temp_file.name
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_generate_with_mock(openai_lmm_mock):
+    temp_image = create_temp_image()
+    lmm = OpenAILMM()
+    response = lmm.generate("test prompt", image=temp_image)
+    assert response == "mocked response"
+    assert (
+        "image_url"
+        in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][1]
+    )
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_lmm_mock"],
+)
+def test_generate_classifier(openai_lmm_mock):
+    lmm = OpenAILMM()
+    prompt = "Can you generate a cat classifier?"
+    classifier = lmm.generate_classifier(prompt)
+    assert isinstance(classifier, CLIP)
+    assert classifier.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_lmm_mock"],
+)
+def test_generate_classifier(openai_lmm_mock):
+    lmm = OpenAILMM()
+    prompt = "Can you generate a cat classifier?"
+    detector = lmm.generate_detector(prompt)
+    assert isinstance(detector, GroundingDINO)
+    assert detector.prompt == "cat"
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock",
+    ['{"Parameters": {"prompt": "cat"}}'],
+    indirect=["openai_lmm_mock"],
+)
+def test_generate_classifier(openai_lmm_mock):
+    lmm = OpenAILMM()
+    prompt = "Can you generate a cat classifier?"
+    segmentor = lmm.generate_segmentor(prompt)
+    assert isinstance(segmentor, GroundingSAM)
+    assert segmentor.prompt == "cat"
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 09b4afc8..af90f1b0 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -1,6 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from typing import Mapping, cast
+from openai import OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -22,8 +23,6 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
     def __init__(self, model_name: str = "gpt-4-turbo-preview"):
-        from openai import OpenAI
-
         self.model_name = model_name
         self.client = OpenAI()
 
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index e46eeb4e..b6c20e27 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -6,6 +6,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Union, cast
 
 import requests
+from openai import OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -59,8 +60,11 @@ def generate(
             json=data,
         )
         resp_json: Dict[str, Any] = res.json()
-        if resp_json["statusCode"] != 200:
-            _LOGGER.error(f"Request failed: {resp_json['data']}")
+        if (
+            "statusCode" in resp_json and resp_json["statusCode"] != 200
+        ) or "statusCode" not in resp_json:
+            _LOGGER.error(f"Request failed: {resp_json}")
+            raise ValueError(f"Request failed: {resp_json}")
         return cast(str, resp_json["data"])
 
 
@@ -68,8 +72,6 @@ class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(self, model_name: str = "gpt-4-vision-preview"):
-        from openai import OpenAI
-
         self.model_name = model_name
         self.client = OpenAI()
 

From c7412c8379969f6bbd6273739eedfd85dbe9cf7b Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Tue, 12 Mar 2024 13:29:40 -0700
Subject: [PATCH 2/3] updated pyproject

---
 pyproject.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index a823de15..1e33b7ba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,9 @@ log_cli_level = "INFO"
 log_cli_format = "%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)s)"
 log_cli_date_format = "%Y-%m-%d %H:%M:%S"
 
+[tool.flake8]
+exclude = "tests/*"
+
 [tool.black]
 exclude = '.vscode|.eggs|venv'
 line-length = 88 # suggested by black official site

From 6540122844b5b35976664cbb02df4aff85ca627b Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Tue, 12 Mar 2024 14:07:29 -0700
Subject: [PATCH 3/3] fixed flake8

---
 pyproject.toml    |  3 ---
 tests/test_llm.py | 10 +++++-----
 tests/test_lmm.py | 10 +++++-----
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1e33b7ba..a823de15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,9 +52,6 @@ log_cli_level = "INFO"
 log_cli_format = "%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)s)"
 log_cli_date_format = "%Y-%m-%d %H:%M:%S"
 
-[tool.flake8]
-exclude = "tests/*"
-
 [tool.black]
 exclude = '.vscode|.eggs|venv'
 line-length = 88 # suggested by black official site
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 18a96fc4..74453a4b 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -4,13 +4,13 @@
 from vision_agent.tools import CLIP
 from vision_agent.tools.tools import GroundingDINO
 
-from .fixtures import openai_llm_mock
+from .fixtures import openai_llm_mock  # noqa: F401
 
 
 @pytest.mark.parametrize(
     "openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
 )
-def test_generate_with_mock(openai_llm_mock):
+def test_generate_with_mock(openai_llm_mock):  # noqa: F811
     llm = OpenAILLM()
     response = llm.generate("test prompt")
     assert response == "mocked response"
@@ -25,7 +25,7 @@ def test_generate_with_mock(openai_llm_mock):
     ['{"Parameters": {"prompt": "cat"}}'],
     indirect=["openai_llm_mock"],
 )
-def test_generate_classifier(openai_llm_mock):
+def test_generate_classifier(openai_llm_mock):  # noqa: F811
     llm = OpenAILLM()
     prompt = "Can you generate a cat classifier?"
     classifier = llm.generate_classifier(prompt)
@@ -38,7 +38,7 @@
     ['{"Parameters": {"prompt": "cat"}}'],
     indirect=["openai_llm_mock"],
 )
-def test_generate_detector(openai_llm_mock):
+def test_generate_detector(openai_llm_mock):  # noqa: F811
     llm = OpenAILLM()
     prompt = "Can you generate a cat detector?"
     detector = llm.generate_detector(prompt)
@@ -51,7 +51,7 @@
     ['{"Parameters": {"prompt": "cat"}}'],
     indirect=["openai_llm_mock"],
 )
-def test_generate_segmentor(openai_llm_mock):
+def test_generate_segmentor(openai_llm_mock):  # noqa: F811
     llm = OpenAILLM()
     prompt = "Can you generate a cat segmentor?"
     segmentor = llm.generate_detector(prompt)
diff --git a/tests/test_lmm.py b/tests/test_lmm.py
index af356de6..97cce581 100644
--- a/tests/test_lmm.py
+++ b/tests/test_lmm.py
@@ -6,7 +6,7 @@
 from vision_agent.lmm.lmm import OpenAILMM
 from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM
 
-from .fixtures import openai_lmm_mock
+from .fixtures import openai_lmm_mock  # noqa: F401
 
 
 def create_temp_image(image_format="jpeg"):
@@ -20,7 +20,7 @@ def create_temp_image(image_format="jpeg"):
 @pytest.mark.parametrize(
     "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
 )
-def test_generate_with_mock(openai_lmm_mock):
+def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
     temp_image = create_temp_image()
     lmm = OpenAILMM()
     response = lmm.generate("test prompt", image=temp_image)
@@ -38,7 +38,7 @@
     ['{"Parameters": {"prompt": "cat"}}'],
     indirect=["openai_lmm_mock"],
 )
-def test_generate_classifier(openai_lmm_mock):
+def test_generate_classifier(openai_lmm_mock):  # noqa: F811
     lmm = OpenAILMM()
     prompt = "Can you generate a cat classifier?"
     classifier = lmm.generate_classifier(prompt)
@@ -51,7 +51,7 @@
     ['{"Parameters": {"prompt": "cat"}}'],
     indirect=["openai_lmm_mock"],
 )
-def test_generate_classifier(openai_lmm_mock):
+def test_generate_classifier(openai_lmm_mock):  # noqa: F811
     lmm = OpenAILMM()
     prompt = "Can you generate a cat classifier?"
     detector = lmm.generate_detector(prompt)
@@ -64,7 +64,7 @@
     ['{"Parameters": {"prompt": "cat"}}'],
     indirect=["openai_lmm_mock"],
 )
-def test_generate_classifier(openai_lmm_mock):
+def test_generate_classifier(openai_lmm_mock):  # noqa: F811
     lmm = OpenAILMM()
     prompt = "Can you generate a cat classifier?"
     segmentor = lmm.generate_segmentor(prompt)