added tests (#14)
* added tests

* updated pyproject

* fixed flake8
dillonalaird authored Mar 12, 2024
1 parent 91693c6 commit 9364e76
Showing 6 changed files with 167 additions and 6 deletions.
Empty file added tests/__init__.py
29 changes: 29 additions & 0 deletions tests/fixtures.py
@@ -0,0 +1,29 @@
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def openai_llm_mock(request):
    content = request.param
    # Note the path here is adjusted to where OpenAI is used, not where it's defined
    with patch("vision_agent.llm.llm.OpenAI") as mock:
        # Setup a mock response structure that matches what your code expects
        mock_instance = mock.return_value
        mock_instance.chat.completions.create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content=content))]
        )
        yield mock_instance


@pytest.fixture
def openai_lmm_mock(request):
    content = request.param
    # Note the path here is adjusted to where OpenAI is used, not where it's defined
    with patch("vision_agent.lmm.lmm.OpenAI") as mock:
        # Setup a mock response structure that matches what your code expects
        mock_instance = mock.return_value
        mock_instance.chat.completions.create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content=content))]
        )
        yield mock_instance
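For context, the nested MagicMock mirrors the shape of an OpenAI chat-completions response, which the code under test presumably reads as response.choices[0].message.content. A minimal standalone sketch of that assumption (not part of the commit):

```python
from unittest.mock import MagicMock

# Hypothetical illustration: the mock is shaped so that choices[0].message.content
# resolves to the parametrized string, matching how a chat.completions.create()
# response is typically consumed.
fake_response = MagicMock(
    choices=[MagicMock(message=MagicMock(content="mocked response"))]
)
assert fake_response.choices[0].message.content == "mocked response"
```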
59 changes: 59 additions & 0 deletions tests/test_llm.py
@@ -0,0 +1,59 @@
import pytest

from vision_agent.llm.llm import OpenAILLM
from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM

from .fixtures import openai_llm_mock  # noqa: F401


@pytest.mark.parametrize(
    "openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
)
def test_generate_with_mock(openai_llm_mock):  # noqa: F811
    llm = OpenAILLM()
    response = llm.generate("test prompt")
    assert response == "mocked response"
    openai_llm_mock.chat.completions.create.assert_called_once_with(
        model="gpt-4-turbo-preview",
        messages=[{"role": "user", "content": "test prompt"}],
    )


@pytest.mark.parametrize(
    "openai_llm_mock",
    ['{"Parameters": {"prompt": "cat"}}'],
    indirect=["openai_llm_mock"],
)
def test_generate_classifier(openai_llm_mock):  # noqa: F811
    llm = OpenAILLM()
    prompt = "Can you generate a cat classifier?"
    classifier = llm.generate_classifier(prompt)
    assert isinstance(classifier, CLIP)
    assert classifier.prompt == "cat"


@pytest.mark.parametrize(
    "openai_llm_mock",
    ['{"Parameters": {"prompt": "cat"}}'],
    indirect=["openai_llm_mock"],
)
def test_generate_detector(openai_llm_mock):  # noqa: F811
    llm = OpenAILLM()
    prompt = "Can you generate a cat detector?"
    detector = llm.generate_detector(prompt)
    assert isinstance(detector, GroundingDINO)
    assert detector.prompt == "cat"


@pytest.mark.parametrize(
    "openai_llm_mock",
    ['{"Parameters": {"prompt": "cat"}}'],
    indirect=["openai_llm_mock"],
)
def test_generate_segmentor(openai_llm_mock):  # noqa: F811
    llm = OpenAILLM()
    prompt = "Can you generate a cat segmentor?"
    segmentor = llm.generate_segmentor(prompt)
    assert isinstance(segmentor, GroundingSAM)
    assert segmentor.prompt == "cat"
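These tests lean on pytest's indirect parametrization: the value listed in parametrize is delivered to the fixture as request.param rather than passed to the test directly, so each test controls what the mocked client returns. A tiny self-contained sketch of the mechanism (hypothetical fixture name, not from the commit):

```python
import pytest


@pytest.fixture
def reply(request):
    # With indirect=["reply"], the parametrized value arrives here as request.param.
    return request.param


@pytest.mark.parametrize("reply", ["hello"], indirect=["reply"])
def test_indirect_param(reply):
    assert reply == "hello"
```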
72 changes: 72 additions & 0 deletions tests/test_lmm.py
@@ -0,0 +1,72 @@
import tempfile

import pytest
from PIL import Image

from vision_agent.lmm.lmm import OpenAILMM
from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM

from .fixtures import openai_lmm_mock  # noqa: F401


def create_temp_image(image_format="jpeg"):
    temp_file = tempfile.NamedTemporaryFile(suffix=f".{image_format}", delete=False)
    image = Image.new("RGB", (100, 100), color=(255, 0, 0))
    image.save(temp_file, format=image_format)
    temp_file.seek(0)
    return temp_file.name


@pytest.mark.parametrize(
    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
    temp_image = create_temp_image()
    lmm = OpenAILMM()
    response = lmm.generate("test prompt", image=temp_image)
    assert response == "mocked response"
    assert (
        "image_url"
        in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
            "content"
        ][1]
    )


@pytest.mark.parametrize(
    "openai_lmm_mock",
    ['{"Parameters": {"prompt": "cat"}}'],
    indirect=["openai_lmm_mock"],
)
def test_generate_classifier(openai_lmm_mock):  # noqa: F811
    lmm = OpenAILMM()
    prompt = "Can you generate a cat classifier?"
    classifier = lmm.generate_classifier(prompt)
    assert isinstance(classifier, CLIP)
    assert classifier.prompt == "cat"


@pytest.mark.parametrize(
    "openai_lmm_mock",
    ['{"Parameters": {"prompt": "cat"}}'],
    indirect=["openai_lmm_mock"],
)
def test_generate_detector(openai_lmm_mock):  # noqa: F811
    lmm = OpenAILMM()
    prompt = "Can you generate a cat detector?"
    detector = lmm.generate_detector(prompt)
    assert isinstance(detector, GroundingDINO)
    assert detector.prompt == "cat"


@pytest.mark.parametrize(
    "openai_lmm_mock",
    ['{"Parameters": {"prompt": "cat"}}'],
    indirect=["openai_lmm_mock"],
)
def test_generate_segmentor(openai_lmm_mock):  # noqa: F811
    lmm = OpenAILMM()
    prompt = "Can you generate a cat segmentor?"
    segmentor = lmm.generate_segmentor(prompt)
    assert isinstance(segmentor, GroundingSAM)
    assert segmentor.prompt == "cat"
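The image_url assertion in test_generate_with_mock inspects the second content part of the first message; for a GPT-4 Vision style chat request the user message content is presumably a list with a text part followed by an image part. A hedged sketch of that expected payload shape (illustrative values only; the real payload is built inside OpenAILMM and may differ in detail):

```python
# Assumed message layout for a vision-style chat request.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "test prompt"},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<encoded image>"}},
        ],
    }
]
assert "image_url" in messages[0]["content"][1]
```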
3 changes: 1 addition & 2 deletions vision_agent/llm/llm.py
@@ -1,6 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from typing import Mapping, cast
+from openai import OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -22,8 +23,6 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
     def __init__(self, model_name: str = "gpt-4-turbo-preview"):
-        from openai import OpenAI
-
         self.model_name = model_name
         self.client = OpenAI()
 
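Hoisting the import to module level is also what the new fixtures rely on: OpenAI is now a name inside vision_agent.llm.llm, so patch("vision_agent.llm.llm.OpenAI") replaces exactly what the constructor resolves. A short sketch of that patch-where-it-is-used rule (assumes the package is installed locally):

```python
from unittest.mock import patch

from vision_agent.llm.llm import OpenAILLM

# Patch the name where it is looked up, not where it is defined.
with patch("vision_agent.llm.llm.OpenAI") as mock_openai:
    llm = OpenAILLM()  # __init__ calls the patched OpenAI()
    assert llm.client is mock_openai.return_value
```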
10 changes: 6 additions & 4 deletions vision_agent/lmm/lmm.py
@@ -6,6 +6,7 @@
 from typing import Any, Dict, List, Mapping, Optional, Union, cast
 
 import requests
+from openai import OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -59,17 +60,18 @@ def generate(
             json=data,
         )
         resp_json: Dict[str, Any] = res.json()
-        if resp_json["statusCode"] != 200:
-            _LOGGER.error(f"Request failed: {resp_json['data']}")
+        if (
+            "statusCode" in resp_json and resp_json["statusCode"] != 200
+        ) or "statusCode" not in resp_json:
+            _LOGGER.error(f"Request failed: {resp_json}")
             raise ValueError(f"Request failed: {resp_json}")
         return cast(str, resp_json["data"])
 
 
 class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(self, model_name: str = "gpt-4-vision-preview"):
-        from openai import OpenAI
-
         self.model_name = model_name
         self.client = OpenAI()
 
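The reworked check above treats a response as failed both when statusCode is present but not 200 and when the key is missing entirely; only an explicit statusCode of 200 passes through. A standalone sketch of that truth table, mirroring the diff's condition:

```python
def request_failed(resp_json: dict) -> bool:
    # Same condition as the updated lmm.py: anything other than an explicit
    # statusCode == 200 counts as a failure.
    return (
        "statusCode" in resp_json and resp_json["statusCode"] != 200
    ) or "statusCode" not in resp_json


assert request_failed({"error": "boom"}) is True            # key missing
assert request_failed({"statusCode": 500}) is True          # non-200 status
assert request_failed({"statusCode": 200, "data": "ok"}) is False
```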
