diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 8ed6b71a..ba35829a 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -5,7 +5,7 @@
 import os
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast
 
 import anthropic
 import requests
@@ -104,15 +104,17 @@ def __init__(
     def __call__(
         self,
         input: Union[str, List[Message]],
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         if isinstance(input, str):
-            return self.generate(input)
-        return self.chat(input)
+            return self.generate(input, **kwargs)
+        return self.chat(input, **kwargs)
 
     def chat(
         self,
         chat: List[Message],
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         """Chat with the LMM model.
 
         Parameters:
@@ -141,17 +143,28 @@ def chat(
                     )
             fixed_chat.append(fixed_c)
 
+        # prefers kwargs from second dictionary over first
+        tmp_kwargs = self.kwargs | kwargs
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **tmp_kwargs  # type: ignore
         )
-
-        return cast(str, response.choices[0].message.content)
+        if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
+
+            def f() -> Iterator[Optional[str]]:
+                for chunk in response:
+                    chunk_message = chunk.choices[0].delta.content
+                    yield chunk_message
+
+            return f()
+        else:
+            return cast(str, response.choices[0].message.content)
 
     def generate(
         self,
         prompt: str,
         media: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         message: List[Dict[str, Any]] = [
             {
                 "role": "user",
@@ -173,10 +186,21 @@
                 },
             )
 
+        # prefers kwargs from second dictionary over first
+        tmp_kwargs = self.kwargs | kwargs
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **tmp_kwargs  # type: ignore
         )
-        return cast(str, response.choices[0].message.content)
+        if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
+
+            def f() -> Iterator[Optional[str]]:
+                for chunk in response:
+                    chunk_message = chunk.choices[0].delta.content
+                    yield chunk_message
+
+            return f()
+        else:
+            return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, question: str) -> Callable:
         api_doc = T.get_tool_documentation([T.clip])
@@ -309,20 +333,22 @@ def __init__(
         self.url = base_url
         self.model_name = model_name
         self.json_mode = json_mode
-        self.stream = False
+        self.kwargs = kwargs
 
     def __call__(
         self,
         input: Union[str, List[Message]],
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         if isinstance(input, str):
-            return self.generate(input)
-        return self.chat(input)
+            return self.generate(input, **kwargs)
+        return self.chat(input, **kwargs)
 
     def chat(
         self,
         chat: List[Message],
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         """Chat with the LMM model.
 
         Parameters:
@@ -341,40 +367,85 @@ def chat(
         url = f"{self.url}/chat"
         model = self.model_name
         messages = fixed_chat
-        data = {"model": model, "messages": messages, "stream": self.stream}
+        data = {"model": model, "messages": messages}
+
+        tmp_kwargs = self.kwargs | kwargs
+        data.update(tmp_kwargs)
         json_data = json.dumps(data)
 
-        response = requests.post(url, data=json_data)
-        if response.status_code != 200:
-            raise ValueError(f"Request failed with status code {response.status_code}")
-        response = response.json()
-        return response["message"]["content"]  # type: ignore
+        if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
+
+            def f() -> Iterator[Optional[str]]:
+                with requests.post(url, data=json_data, stream=True) as stream:
+                    if stream.status_code != 200:
+                        raise ValueError(
+                            f"Request failed with status code {stream.status_code}"
+                        )
+
+                    for chunk in stream.iter_content(chunk_size=None):
+                        chunk_data = json.loads(chunk)
+                        if chunk_data["done"]:
+                            yield None
+                        else:
+                            yield chunk_data["message"]["content"]
+
+            return f()
+        else:
+            stream = requests.post(url, data=json_data)
+            if stream.status_code != 200:
+                raise ValueError(
+                    f"Request failed with status code {stream.status_code}"
+                )
+            stream = stream.json()
+            return stream["message"]["content"]  # type: ignore
 
     def generate(
         self,
         prompt: str,
         media: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         url = f"{self.url}/generate"
         data = {
             "model": self.model_name,
             "prompt": prompt,
             "images": [],
-            "stream": self.stream,
         }
-        json_data = json.dumps(data)
 
         if media and len(media) > 0:
             for m in media:
                 data["images"].append(encode_media(m))  # type: ignore
 
-        response = requests.post(url, data=json_data)
+        tmp_kwargs = self.kwargs | kwargs
+        data.update(tmp_kwargs)
+        json_data = json.dumps(data)
+        if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
+
+            def f() -> Iterator[Optional[str]]:
+                with requests.post(url, data=json_data, stream=True) as stream:
+                    if stream.status_code != 200:
+                        raise ValueError(
+                            f"Request failed with status code {stream.status_code}"
+                        )
+
+                    for chunk in stream.iter_content(chunk_size=None):
+                        chunk_data = json.loads(chunk)
+                        if chunk_data["done"]:
+                            yield None
+                        else:
+                            yield chunk_data["response"]
+
+            return f()
+        else:
+            stream = requests.post(url, data=json_data)
 
-        if response.status_code != 200:
-            raise ValueError(f"Request failed with status code {response.status_code}")
+            if stream.status_code != 200:
+                raise ValueError(
+                    f"Request failed with status code {stream.status_code}"
+                )
 
-        response = response.json()
-        return response["response"]  # type: ignore
+            stream = stream.json()
+            return stream["response"]  # type: ignore
 
 
 class ClaudeSonnetLMM(LMM):
@@ -385,27 +456,28 @@ def __init__(
         api_key: Optional[str] = None,
         model_name: str = "claude-3-sonnet-20240229",
         max_tokens: int = 4096,
-        temperature: float = 0.7,
         **kwargs: Any,
     ):
         self.client = anthropic.Anthropic(api_key=api_key)
         self.model_name = model_name
-        self.max_tokens = max_tokens
-        self.temperature = temperature
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
        self.kwargs = kwargs
 
     def __call__(
         self,
         input: Union[str, List[Dict[str, Any]]],
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         if isinstance(input, str):
-            return self.generate(input)
-        return self.chat(input)
+            return self.generate(input, **kwargs)
+        return self.chat(input, **kwargs)
 
     def chat(
         self,
         chat: List[Dict[str, Any]],
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         messages: List[MessageParam] = []
         for msg in chat:
             content: List[Union[TextBlockParam, ImageBlockParam]] = [
@@ -426,20 +498,32 @@ def chat(
                 )
             messages.append({"role": msg["role"], "content": content})
 
+        # prefers kwargs from second dictionary over first
+        tmp_kwargs = self.kwargs | kwargs
         response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-            messages=messages,
-            **self.kwargs,
+            model=self.model_name, messages=messages, **tmp_kwargs
         )
-        return cast(str, response.content[0].text)
+        if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
+
+            def f() -> Iterator[Optional[str]]:
+                for chunk in response:
+                    if chunk.type == "message_start" or chunk.type == "content_block_start":
+                        continue
+                    elif chunk.type == "content_block_delta":
+                        yield chunk.delta.text
+                    elif chunk.type == "message_stop":
+                        yield None
+
+            return f()
+        else:
+            return cast(str, response.content[0].text)
 
     def generate(
         self,
         prompt: str,
         media: Optional[List[Union[str, Path]]] = None,
-    ) -> str:
+        **kwargs: Any,
+    ) -> Union[str, Iterator[Optional[str]]]:
         content: List[Union[TextBlockParam, ImageBlockParam]] = [
             TextBlockParam(type="text", text=prompt)
         ]
@@ -456,11 +540,25 @@
                         },
                     )
                 )
+
+        # prefers kwargs from second dictionary over first
+        tmp_kwargs = self.kwargs | kwargs
         response = self.client.messages.create(
             model=self.model_name,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
             messages=[{"role": "user", "content": content}],
-            **self.kwargs,
+            **tmp_kwargs,
         )
-        return cast(str, response.content[0].text)
+        if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
+
+            def f() -> Iterator[Optional[str]]:
+                for chunk in response:
+                    if chunk.type == "message_start" or chunk.type == "content_block_start":
+                        continue
+                    elif chunk.type == "content_block_delta":
+                        yield chunk.delta.text
+                    elif chunk.type == "message_stop":
+                        yield None
+
+            return f()
+        else:
+            return cast(str, response.content[0].text)
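Taken together, the change threads per-call `**kwargs` through `__call__`, `chat`, and `generate` on all three LMM classes and merges them with the constructor-time defaults via `tmp_kwargs = self.kwargs | kwargs` (PEP 584 dict union, so call-site values win and Python 3.9+ is required). The streaming branches live in small inner generators (`f()`) because a Python function whose body contains `yield` anywhere is always a generator; without the wrapper, the non-streaming path could never return a plain `str`. A minimal usage sketch of the resulting API, assuming `OpenAILMM()` is constructible with no arguments when `OPENAI_API_KEY` is set and that `image.png` is a hypothetical local file:

```python
from vision_agent.lmm.lmm import OpenAILMM

lmm = OpenAILMM()  # assumes OPENAI_API_KEY is set in the environment

# Without stream=True the call returns the complete reply as a str.
reply = lmm.generate("Describe this image.", media=["image.png"])
print(reply)

# With stream=True the same call returns an iterator of chunks. OpenAI may
# yield None for role-only deltas, and the Ollama/Claude classes yield a
# final None sentinel, so filter out None before printing.
for chunk in lmm.generate("Describe this image.", media=["image.png"], stream=True):
    if chunk is not None:
        print(chunk, end="", flush=True)
```

Because the return type is now `Union[str, Iterator[Optional[str]]]`, existing callers keep working unchanged, while any caller that opts into `stream=True` must iterate the result.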