Skip to content

Commit

Permalink
PR requests
Browse files Browse the repository at this point in the history
  • Loading branch information
moutasemalakkad committed Jul 18, 2024
1 parent 9f69bcf commit 7cc15f4
Showing 1 changed file with 23 additions and 34 deletions.
57 changes: 23 additions & 34 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import anthropic
import requests
from anthropic.types import ImageBlockParam, MessageParam, TextBlock
from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam
from openai import AzureOpenAI, OpenAI
from PIL import Image

Expand Down Expand Up @@ -380,9 +380,10 @@ def generate(


class ClaudeSonnetLMM(LMM):
r"""An LMM class for Anthropic's Claude Sonnet model."""
def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
model_name: str = "claude-3-sonnet-20240229",
max_tokens: int = 4096,
temperature: float = 0.7,
Expand All @@ -396,7 +397,7 @@ def __init__(

def __call__(
self,
input: Union[str, List[Message]],
input: Union[str, List[Dict[str, Any]]],
) -> str:
if isinstance(input, str):
return self.generate(input)
Expand All @@ -408,25 +409,22 @@ def chat(
) -> str:
messages: List[MessageParam] = []
for msg in chat:
content: List[Union[TextBlock, ImageBlockParam]] = []
if isinstance(msg["content"], str):
content.append(TextBlock(type="text", text=msg["content"]))
elif isinstance(msg["content"], list):
for item in msg["content"]:
if isinstance(item, str):
content.append(TextBlock(type="text", text=item))
elif isinstance(item, (str, Path)):
encoded_media = self.encode_image(item)
content.append(
ImageBlockParam(
type="image",
source={
"type": "base64",
"media_type": "image/png",
"data": encoded_media,
},
)
content: List[Union[TextBlockParam, ImageBlockParam]] = [
TextBlockParam(type="text", text=msg["content"])
]
if "media" in msg:
for media_path in msg["media"]:
encoded_media = encode_media(media_path)
content.append(
ImageBlockParam(
type="image",
source={
"type": "base64",
"media_type": "image/png",
"data": encoded_media,
},
)
)
messages.append({"role": msg["role"], "content": content})

response = self.client.messages.create(
Expand All @@ -443,12 +441,12 @@ def generate(
prompt: str,
media: Optional[List[Union[str, Path]]] = None,
) -> str:
content: List[Union[TextBlock, ImageBlockParam]] = [
TextBlock(type="text", text=prompt)
content: List[Union[TextBlockParam, ImageBlockParam]] = [
TextBlockParam(type="text", text=prompt)
]
if media:
for m in media:
encoded_media = self.encode_image(m)
encoded_media = encode_media(m)
content.append(
ImageBlockParam(
type="image",
Expand All @@ -466,13 +464,4 @@ def generate(
messages=[{"role": "user", "content": content}],
**self.kwargs,
)
return cast(str, response.content[0].text)

def encode_image(self, image_path: Union[str, Path]) -> str:
with open(image_path, "rb") as image_file:
image = Image.open(io.BytesIO(image_file.read())).convert("RGB")
buffer = io.BytesIO()
image.save(buffer, format="PNG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

return encoded_image
return cast(str, response.content[0].text)

0 comments on commit 7cc15f4

Please sign in to comment.