From 14f47e0f5794bf6ff57347738fd9b345399946e7 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 13 Aug 2024 13:51:44 -0700 Subject: [PATCH] Add Streaming for LMMs (#191) * added streaming * fixed type errors * fixed linting * fixed generator func type * black formatting * fixed tests for streaming --- tests/unit/fixtures.py | 16 +- tests/unit/test_lmm.py | 62 +++++++ vision_agent/agent/vision_agent.py | 2 +- vision_agent/agent/vision_agent_coder.py | 17 +- vision_agent/lmm/lmm.py | 208 +++++++++++++++++------ 5 files changed, 242 insertions(+), 63 deletions(-) diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index ab8d35e0..ccad51e8 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -13,11 +13,21 @@ def langsmith_wrap_oepnai_mock(request, openai_llm_mock): @pytest.fixture def openai_lmm_mock(request): content = request.param + + def mock_generate(*args, **kwargs): + if kwargs.get("stream", False): + + def generator(): + for chunk in content.split(" ") + [None]: + yield MagicMock(choices=[MagicMock(delta=MagicMock(content=chunk))]) + + return generator() + else: + return MagicMock(choices=[MagicMock(message=MagicMock(content=content))]) + # Note the path here is adjusted to where OpenAI is used, not where it's defined with patch("vision_agent.lmm.lmm.OpenAI") as mock: # Setup a mock response structure that matches what your code expects mock_instance = mock.return_value - mock_instance.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=content))] - ) + mock_instance.chat.completions.create.return_value = mock_generate() yield mock_instance diff --git a/tests/unit/test_lmm.py b/tests/unit/test_lmm.py index 82871fce..9cb43650 100644 --- a/tests/unit/test_lmm.py +++ b/tests/unit/test_lmm.py @@ -34,6 +34,24 @@ def test_generate_with_mock(openai_lmm_mock): # noqa: F811 ) +@pytest.mark.parametrize( + "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] +) +def test_generate_with_mock_stream(openai_lmm_mock): # noqa: F811 + temp_image = create_temp_image() + lmm = OpenAILMM() + response = lmm.generate("test prompt", media=[temp_image], stream=True) + expected_response = ["mocked", "response", None] + for i, chunk in enumerate(response): + assert chunk == expected_response[i] + assert ( + "image_url" + in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][1] + ) + + @pytest.mark.parametrize( "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] ) @@ -49,6 +67,23 @@ def test_chat_with_mock(openai_lmm_mock): # noqa: F811 ) +@pytest.mark.parametrize( + "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] +) +def test_chat_with_mock_stream(openai_lmm_mock): # noqa: F811 + lmm = OpenAILMM() + response = lmm.chat([{"role": "user", "content": "test prompt"}], stream=True) + expected_response = ["mocked", "response", None] + for i, chunk in enumerate(response): + assert chunk == expected_response[i] + assert ( + openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][0]["text"] + == "test prompt" + ) + + @pytest.mark.parametrize( "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] ) @@ -73,6 +108,33 @@ def test_call_with_mock(openai_lmm_mock): # noqa: F811 ) +@pytest.mark.parametrize( + "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] +) +def test_call_with_mock_stream(openai_lmm_mock): # noqa: F811 + expected_response = ["mocked", "response", None] + lmm = OpenAILMM() + 
response = lmm("test prompt", stream=True) + for i, chunk in enumerate(response): + assert chunk == expected_response[i] + assert ( + openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][0]["text"] + == "test prompt" + ) + + response = lmm([{"role": "user", "content": "test prompt"}], stream=True) + for i, chunk in enumerate(response): + assert chunk == expected_response[i] + assert ( + openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][0]["text"] + == "test prompt" + ) + + @pytest.mark.parametrize( "openai_lmm_mock", ['{"Parameters": {"prompt": "cat"}}'], diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 9090b706..a41fd09f 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -63,7 +63,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]: dir=WORKSPACE, conversation=conversation, ) - return extract_json(orch([{"role": "user", "content": prompt}])) + return extract_json(orch([{"role": "user", "content": prompt}], stream=False)) # type: ignore def run_code_action(code: str, code_interpreter: CodeInterpreter) -> str: diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 352697d7..3a370c5e 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -129,7 +129,7 @@ def write_plans( context = USER_REQ.format(user_request=user_request) prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory) chat[-1]["content"] = prompt - return extract_json(model.chat(chat)) + return extract_json(model(chat, stream=False)) # type: ignore def pick_plan( @@ -160,7 +160,7 @@ def pick_plan( docstring=tool_info, plans=plan_str, previous_attempts="", media=media ) - code = extract_code(model(prompt)) + code = extract_code(model(prompt, stream=False)) # type: ignore log_progress( { "type": "log", @@ -211,7 +211,7 @@ def pick_plan( "code": DefaultImports.prepend_imports(code), } ) - code = extract_code(model(prompt)) + code = extract_code(model(prompt, stream=False)) # type: ignore tool_output = code_interpreter.exec_isolation( DefaultImports.prepend_imports(code) ) @@ -251,7 +251,7 @@ def pick_plan( tool_output=tool_output_str[:20_000], ) chat[-1]["content"] = prompt - best_plan = extract_json(model(chat)) + best_plan = extract_json(model(chat, stream=False)) # type: ignore if verbosity >= 1: _LOGGER.info(f"Best plan:\n{best_plan}") @@ -286,7 +286,7 @@ def write_code( feedback=feedback, ) chat[-1]["content"] = prompt - return extract_code(coder(chat)) + return extract_code(coder(chat, stream=False)) # type: ignore def write_test( @@ -310,7 +310,7 @@ def write_test( media=media, ) chat[-1]["content"] = prompt - return extract_code(tester(chat)) + return extract_code(tester(chat, stream=False)) # type: ignore def write_and_test_code( @@ -439,13 +439,14 @@ def debug_code( while not success and count < 3: try: fixed_code_and_test = extract_json( - debugger( + debugger( # type: ignore FIX_BUG.format( code=code, tests=test, result="\n".join(result.text().splitlines()[-50:]), feedback=format_memory(working_memory + new_working_memory), - ) + ), + stream=False, ) ) success = True diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 8ed6b71a..9a8c5bf1 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -5,7 +5,7 @@ import os from abc import ABC, abstractmethod from pathlib import Path -from typing 
import Any, Callable, Dict, List, Optional, Union, cast +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast import anthropic import requests @@ -58,22 +58,24 @@ def encode_media(media: Union[str, Path]) -> str: class LMM(ABC): @abstractmethod def generate( - self, prompt: str, media: Optional[List[Union[str, Path]]] = None - ) -> str: + self, prompt: str, media: Optional[List[Union[str, Path]]] = None, **kwargs: Any + ) -> Union[str, Iterator[Optional[str]]]: pass @abstractmethod def chat( self, chat: List[Message], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: pass @abstractmethod def __call__( self, input: Union[str, List[Message]], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: pass @@ -104,15 +106,17 @@ def __init__( def __call__( self, input: Union[str, List[Message]], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): - return self.generate(input) - return self.chat(input) + return self.generate(input, **kwargs) + return self.chat(input, **kwargs) def chat( self, chat: List[Message], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: """Chat with the LMM model. Parameters: @@ -141,17 +145,28 @@ def chat( ) fixed_chat.append(fixed_c) + # prefers kwargs from second dictionary over first + tmp_kwargs = self.kwargs | kwargs response = self.client.chat.completions.create( - model=self.model_name, messages=fixed_chat, **self.kwargs # type: ignore + model=self.model_name, messages=fixed_chat, **tmp_kwargs # type: ignore ) + if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + + def f() -> Iterator[Optional[str]]: + for chunk in response: + chunk_message = chunk.choices[0].delta.content # type: ignore + yield chunk_message - return cast(str, response.choices[0].message.content) + return f() + else: + return cast(str, response.choices[0].message.content) def generate( self, prompt: str, media: Optional[List[Union[str, Path]]] = None, - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: message: List[Dict[str, Any]] = [ { "role": "user", @@ -173,10 +188,21 @@ def generate( }, ) + # prefers kwargs from second dictionary over first + tmp_kwargs = self.kwargs | kwargs response = self.client.chat.completions.create( - model=self.model_name, messages=message, **self.kwargs # type: ignore + model=self.model_name, messages=message, **tmp_kwargs # type: ignore ) - return cast(str, response.choices[0].message.content) + if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + + def f() -> Iterator[Optional[str]]: + for chunk in response: + chunk_message = chunk.choices[0].delta.content # type: ignore + yield chunk_message + + return f() + else: + return cast(str, response.choices[0].message.content) def generate_classifier(self, question: str) -> Callable: api_doc = T.get_tool_documentation([T.clip]) @@ -309,20 +335,22 @@ def __init__( self.url = base_url self.model_name = model_name self.json_mode = json_mode - self.stream = False + self.kwargs = kwargs def __call__( self, input: Union[str, List[Message]], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): - return self.generate(input) - return self.chat(input) + return self.generate(input, **kwargs) + return self.chat(input, **kwargs) def chat( self, chat: List[Message], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: """Chat with the LMM model. 
Parameters: @@ -341,40 +369,85 @@ def chat( url = f"{self.url}/chat" model = self.model_name messages = fixed_chat - data = {"model": model, "messages": messages, "stream": self.stream} + data = {"model": model, "messages": messages} + + tmp_kwargs = self.kwargs | kwargs + data.update(tmp_kwargs) json_data = json.dumps(data) - response = requests.post(url, data=json_data) - if response.status_code != 200: - raise ValueError(f"Request failed with status code {response.status_code}") - response = response.json() - return response["message"]["content"] # type: ignore + if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + + def f() -> Iterator[Optional[str]]: + with requests.post(url, data=json_data, stream=True) as stream: + if stream.status_code != 200: + raise ValueError( + f"Request failed with status code {stream.status_code}" + ) + + for chunk in stream.iter_content(chunk_size=None): + chunk_data = json.loads(chunk) + if chunk_data["done"]: + yield None + else: + yield chunk_data["message"]["content"] + + return f() + else: + stream = requests.post(url, data=json_data) + if stream.status_code != 200: + raise ValueError( + f"Request failed with status code {stream.status_code}" + ) + stream = stream.json() + return stream["message"]["content"] # type: ignore def generate( self, prompt: str, media: Optional[List[Union[str, Path]]] = None, - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: url = f"{self.url}/generate" data = { "model": self.model_name, "prompt": prompt, "images": [], - "stream": self.stream, } - json_data = json.dumps(data) if media and len(media) > 0: for m in media: data["images"].append(encode_media(m)) # type: ignore - response = requests.post(url, data=json_data) + tmp_kwargs = self.kwargs | kwargs + data.update(tmp_kwargs) + json_data = json.dumps(data) + if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + + def f() -> Iterator[Optional[str]]: + with requests.post(url, data=json_data, stream=True) as stream: + if stream.status_code != 200: + raise ValueError( + f"Request failed with status code {stream.status_code}" + ) + + for chunk in stream.iter_content(chunk_size=None): + chunk_data = json.loads(chunk) + if chunk_data["done"]: + yield None + else: + yield chunk_data["response"] - if response.status_code != 200: - raise ValueError(f"Request failed with status code {response.status_code}") + return f() + else: + stream = requests.post(url, data=json_data) + + if stream.status_code != 200: + raise ValueError( + f"Request failed with status code {stream.status_code}" + ) - response = response.json() - return response["response"] # type: ignore + stream = stream.json() + return stream["response"] # type: ignore class ClaudeSonnetLMM(LMM): @@ -385,27 +458,28 @@ def __init__( api_key: Optional[str] = None, model_name: str = "claude-3-sonnet-20240229", max_tokens: int = 4096, - temperature: float = 0.7, **kwargs: Any, ): self.client = anthropic.Anthropic(api_key=api_key) self.model_name = model_name - self.max_tokens = max_tokens - self.temperature = temperature + if "max_tokens" not in kwargs: + kwargs["max_tokens"] = max_tokens self.kwargs = kwargs def __call__( self, input: Union[str, List[Dict[str, Any]]], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): - return self.generate(input) - return self.chat(input) + return self.generate(input, **kwargs) + return self.chat(input, **kwargs) def chat( self, chat: List[Dict[str, Any]], - ) -> str: + **kwargs: Any, + ) -> Union[str, 
Iterator[Optional[str]]]: messages: List[MessageParam] = [] for msg in chat: content: List[Union[TextBlockParam, ImageBlockParam]] = [ @@ -426,20 +500,35 @@ def chat( ) messages.append({"role": msg["role"], "content": content}) + # prefers kwargs from second dictionary over first + tmp_kwargs = self.kwargs | kwargs response = self.client.messages.create( - model=self.model_name, - max_tokens=self.max_tokens, - temperature=self.temperature, - messages=messages, - **self.kwargs, + model=self.model_name, messages=messages, **tmp_kwargs ) - return cast(str, response.content[0].text) + if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + + def f() -> Iterator[Optional[str]]: + for chunk in response: + if ( + chunk.type == "message_start" + or chunk.type == "content_block_start" + ): + continue + elif chunk.type == "content_block_delta": + yield chunk.delta.text + elif chunk.type == "message_stop": + yield None + + return f() + else: + return cast(str, response.content[0].text) def generate( self, prompt: str, media: Optional[List[Union[str, Path]]] = None, - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: content: List[Union[TextBlockParam, ImageBlockParam]] = [ TextBlockParam(type="text", text=prompt) ] @@ -456,11 +545,28 @@ def generate( }, ) ) + + # prefers kwargs from second dictionary over first + tmp_kwargs = self.kwargs | kwargs response = self.client.messages.create( model=self.model_name, - max_tokens=self.max_tokens, - temperature=self.temperature, messages=[{"role": "user", "content": content}], - **self.kwargs, + **tmp_kwargs, ) - return cast(str, response.content[0].text) + if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + + def f() -> Iterator[Optional[str]]: + for chunk in response: + if ( + chunk.type == "message_start" + or chunk.type == "content_block_start" + ): + continue + elif chunk.type == "content_block_delta": + yield chunk.delta.text + elif chunk.type == "message_stop": + yield None + + return f() + else: + return cast(str, response.content[0].text)
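
Usage note (illustrative, not part of the patch): a minimal sketch of how the streaming interface added above can be consumed. The import path and the image filename are assumptions for the example; what the patch and its unit tests establish is that passing stream=True returns a generator of text chunks that ends with None, while stream=False (or omitting the kwarg) keeps the original plain-string return.

from vision_agent.lmm.lmm import OpenAILMM  # OpenAILMM is defined in vision_agent/lmm/lmm.py above

lmm = OpenAILMM()

# Non-streaming call: unchanged behavior, returns the full response as a string.
# "example.jpg" is a placeholder path for illustration only.
answer = lmm.generate("Describe the image.", media=["example.jpg"])
print(answer)

# Streaming call: returns a generator yielding text chunks; the final item is None,
# matching the sequence asserted in the new tests ("mocked", "response", None).
for chunk in lmm.generate("Describe the image.", media=["example.jpg"], stream=True):
    if chunk is not None:
        print(chunk, end="", flush=True)
print()

The agent call sites touched by the patch pass stream=False explicitly (for example orch([...], stream=False)) so that helpers such as extract_json and extract_code always receive a plain string rather than a generator.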