From 17d6f4539206e45903761ed4c57a909899b20b60 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Sun, 11 Aug 2024 10:47:00 -0700
Subject: [PATCH] fixed tests for streaming

---
 tests/unit/fixtures.py | 16 +++++++++--
 tests/unit/test_lmm.py | 62 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py
index ab8d35e0..ccad51e8 100644
--- a/tests/unit/fixtures.py
+++ b/tests/unit/fixtures.py
@@ -13,11 +13,21 @@ def langsmith_wrap_oepnai_mock(request, openai_llm_mock):
 @pytest.fixture
 def openai_lmm_mock(request):
     content = request.param
+
+    def mock_generate(*args, **kwargs):
+        if kwargs.get("stream", False):
+
+            def generator():
+                for chunk in content.split(" ") + [None]:
+                    yield MagicMock(choices=[MagicMock(delta=MagicMock(content=chunk))])
+
+            return generator()
+        else:
+            return MagicMock(choices=[MagicMock(message=MagicMock(content=content))])
+
     # Note the path here is adjusted to where OpenAI is used, not where it's defined
     with patch("vision_agent.lmm.lmm.OpenAI") as mock:
         # Setup a mock response structure that matches what your code expects
         mock_instance = mock.return_value
-        mock_instance.chat.completions.create.return_value = MagicMock(
-            choices=[MagicMock(message=MagicMock(content=content))]
-        )
+        mock_instance.chat.completions.create.return_value = mock_generate()
         yield mock_instance
diff --git a/tests/unit/test_lmm.py b/tests/unit/test_lmm.py
index 82871fce..9cb43650 100644
--- a/tests/unit/test_lmm.py
+++ b/tests/unit/test_lmm.py
@@ -34,6 +34,24 @@ def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
     )
 
 
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_generate_with_mock_stream(openai_lmm_mock):  # noqa: F811
+    temp_image = create_temp_image()
+    lmm = OpenAILMM()
+    response = lmm.generate("test prompt", media=[temp_image], stream=True)
+    expected_response = ["mocked", "response", None]
+    for i, chunk in enumerate(response):
+        assert chunk == expected_response[i]
+    assert (
+        "image_url"
+        in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][1]
+    )
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
 )
@@ -49,6 +67,23 @@ def test_chat_with_mock(openai_lmm_mock):  # noqa: F811
     )
 
 
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_chat_with_mock_stream(openai_lmm_mock):  # noqa: F811
+    lmm = OpenAILMM()
+    response = lmm.chat([{"role": "user", "content": "test prompt"}], stream=True)
+    expected_response = ["mocked", "response", None]
+    for i, chunk in enumerate(response):
+        assert chunk == expected_response[i]
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
 )
@@ -73,6 +108,33 @@ def test_call_with_mock(openai_lmm_mock):  # noqa: F811
     )
 
 
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_call_with_mock_stream(openai_lmm_mock):  # noqa: F811
+    expected_response = ["mocked", "response", None]
+    lmm = OpenAILMM()
+    response = lmm("test prompt", stream=True)
+    for i, chunk in enumerate(response):
+        assert chunk == expected_response[i]
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+    response = lmm([{"role": "user", "content": "test prompt"}], stream=True)
+    for i, chunk in enumerate(response):
+        assert chunk == expected_response[i]
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock",
     ['{"Parameters": {"prompt": "cat"}}'],