Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added tests #14

Merged
merged 3 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added tests/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def openai_llm_mock(request):
content = request.param
# Note the path here is adjusted to where OpenAI is used, not where it's defined
with patch("vision_agent.llm.llm.OpenAI") as mock:
# Setup a mock response structure that matches what your code expects
mock_instance = mock.return_value
mock_instance.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=content))]
)
yield mock_instance


@pytest.fixture
def openai_lmm_mock(request):
content = request.param
# Note the path here is adjusted to where OpenAI is used, not where it's defined
with patch("vision_agent.lmm.lmm.OpenAI") as mock:
# Setup a mock response structure that matches what your code expects
mock_instance = mock.return_value
mock_instance.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=content))]
)
yield mock_instance
59 changes: 59 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

from vision_agent.llm.llm import OpenAILLM
from vision_agent.tools import CLIP
from vision_agent.tools.tools import GroundingDINO

from .fixtures import openai_llm_mock # noqa: F401


@pytest.mark.parametrize(
"openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
)
def test_generate_with_mock(openai_llm_mock): # noqa: F811
llm = OpenAILLM()
response = llm.generate("test prompt")
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo-preview",
messages=[{"role": "user", "content": "test prompt"}],
)


@pytest.mark.parametrize(
"openai_llm_mock",
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_llm_mock"],
)
def test_generate_classifier(openai_llm_mock): # noqa: F811
llm = OpenAILLM()
prompt = "Can you generate a cat classifier?"
classifier = llm.generate_classifier(prompt)
assert isinstance(classifier, CLIP)
assert classifier.prompt == "cat"


@pytest.mark.parametrize(
"openai_llm_mock",
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_llm_mock"],
)
def test_generate_detector(openai_llm_mock): # noqa: F811
llm = OpenAILLM()
prompt = "Can you generate a cat detector?"
detector = llm.generate_detector(prompt)
assert isinstance(detector, GroundingDINO)
assert detector.prompt == "cat"


@pytest.mark.parametrize(
"openai_llm_mock",
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_llm_mock"],
)
def test_generate_segmentor(openai_llm_mock): # noqa: F811
llm = OpenAILLM()
prompt = "Can you generate a cat segmentor?"
segmentor = llm.generate_detector(prompt)
assert isinstance(segmentor, GroundingDINO)
assert segmentor.prompt == "cat"
72 changes: 72 additions & 0 deletions tests/test_lmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import tempfile

import pytest
from PIL import Image

from vision_agent.lmm.lmm import OpenAILMM
from vision_agent.tools import CLIP, GroundingDINO, GroundingSAM

from .fixtures import openai_lmm_mock # noqa: F401


def create_temp_image(image_format="jpeg"):
temp_file = tempfile.NamedTemporaryFile(suffix=f".{image_format}", delete=False)
image = Image.new("RGB", (100, 100), color=(255, 0, 0))
image.save(temp_file, format=image_format)
temp_file.seek(0)
return temp_file.name


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_generate_with_mock(openai_lmm_mock): # noqa: F811
temp_image = create_temp_image()
lmm = OpenAILMM()
response = lmm.generate("test prompt", image=temp_image)
assert response == "mocked response"
assert (
"image_url"
in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][1]
)


@pytest.mark.parametrize(
"openai_lmm_mock",
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_lmm_mock"],
)
def test_generate_classifier(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
prompt = "Can you generate a cat classifier?"
classifier = lmm.generate_classifier(prompt)
assert isinstance(classifier, CLIP)
assert classifier.prompt == "cat"


@pytest.mark.parametrize(
"openai_lmm_mock",
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_lmm_mock"],
)
def test_generate_classifier(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
prompt = "Can you generate a cat classifier?"
detector = lmm.generate_detector(prompt)
assert isinstance(detector, GroundingDINO)
assert detector.prompt == "cat"


@pytest.mark.parametrize(
"openai_lmm_mock",
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_lmm_mock"],
)
def test_generate_classifier(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
prompt = "Can you generate a cat classifier?"
segmentor = lmm.generate_segmentor(prompt)
assert isinstance(segmentor, GroundingSAM)
assert segmentor.prompt == "cat"
3 changes: 1 addition & 2 deletions vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from abc import ABC, abstractmethod
from typing import Mapping, cast
from openai import OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
Expand All @@ -22,8 +23,6 @@ class OpenAILLM(LLM):
r"""An LLM class for any OpenAI LLM model."""

def __init__(self, model_name: str = "gpt-4-turbo-preview"):
from openai import OpenAI

self.model_name = model_name
self.client = OpenAI()

Expand Down
10 changes: 6 additions & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Dict, List, Mapping, Optional, Union, cast

import requests
from openai import OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
Expand Down Expand Up @@ -59,17 +60,18 @@ def generate(
json=data,
)
resp_json: Dict[str, Any] = res.json()
if resp_json["statusCode"] != 200:
_LOGGER.error(f"Request failed: {resp_json['data']}")
if (
"statusCode" in resp_json and resp_json["statusCode"] != 200
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
return cast(str, resp_json["data"])


class OpenAILMM(LMM):
r"""An LMM class for the OpenAI GPT-4 Vision model."""

def __init__(self, model_name: str = "gpt-4-vision-preview"):
from openai import OpenAI

self.model_name = model_name
self.client = OpenAI()

Expand Down
Loading