diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 45a09428..cd0729bd 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -9,7 +9,7 @@ import anthropic import requests -from anthropic.types import ImageBlockParam, MessageParam, TextBlock +from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam from openai import AzureOpenAI, OpenAI from PIL import Image @@ -380,9 +380,10 @@ def generate( class ClaudeSonnetLMM(LMM): + r"""An LMM class for Anthropic's Claude Sonnet model.""" def __init__( self, - api_key: str, + api_key: Optional[str] = None, model_name: str = "claude-3-sonnet-20240229", max_tokens: int = 4096, temperature: float = 0.7, @@ -396,7 +397,7 @@ def __init__( def __call__( self, - input: Union[str, List[Message]], + input: Union[str, List[Dict[str, Any]]], ) -> str: if isinstance(input, str): return self.generate(input) @@ -408,25 +409,22 @@ def chat( ) -> str: messages: List[MessageParam] = [] for msg in chat: - content: List[Union[TextBlock, ImageBlockParam]] = [] - if isinstance(msg["content"], str): - content.append(TextBlock(type="text", text=msg["content"])) - elif isinstance(msg["content"], list): - for item in msg["content"]: - if isinstance(item, str): - content.append(TextBlock(type="text", text=item)) - elif isinstance(item, (str, Path)): - encoded_media = self.encode_image(item) - content.append( - ImageBlockParam( - type="image", - source={ - "type": "base64", - "media_type": "image/png", - "data": encoded_media, - }, - ) + content: List[Union[TextBlockParam, ImageBlockParam]] = [ + TextBlockParam(type="text", text=msg["content"]) + ] + if "media" in msg: + for media_path in msg["media"]: + encoded_media = encode_media(media_path) + content.append( + ImageBlockParam( + type="image", + source={ + "type": "base64", + "media_type": "image/png", + "data": encoded_media, + }, ) + ) messages.append({"role": msg["role"], "content": content}) response = self.client.messages.create( @@ -443,12 +441,12 @@ def generate( prompt: str, media: Optional[List[Union[str, Path]]] = None, ) -> str: - content: List[Union[TextBlock, ImageBlockParam]] = [ - TextBlock(type="text", text=prompt) + content: List[Union[TextBlockParam, ImageBlockParam]] = [ + TextBlockParam(type="text", text=prompt) ] if media: for m in media: - encoded_media = self.encode_image(m) + encoded_media = encode_media(m) content.append( ImageBlockParam( type="image", @@ -466,13 +464,4 @@ def generate( messages=[{"role": "user", "content": content}], **self.kwargs, ) - return cast(str, response.content[0].text) - - def encode_image(self, image_path: Union[str, Path]) -> str: - with open(image_path, "rb") as image_file: - image = Image.open(io.BytesIO(image_file.read())).convert("RGB") - buffer = io.BytesIO() - image.save(buffer, format="PNG") - encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - - return encoded_image + return cast(str, response.content[0].text) \ No newline at end of file