1- import base64
2- import io
31import json
42import logging
53import os
64from abc import ABC , abstractmethod
75from pathlib import Path
8- from typing import Any , Callable , Dict , Iterator , List , Optional , Union , cast
6+ from typing import Any , Dict , Iterator , List , Optional , Union , cast , Sequence
97
108import anthropic
119import requests
1210from anthropic .types import ImageBlockParam , MessageParam , TextBlockParam
1311from openai import AzureOpenAI , OpenAI
14- from PIL import Image
1512
16- import vision_agent .tools as T
17- from vision_agent .tools .prompts import CHOOSE_PARAMS , SYSTEM_PROMPT
13+ from vision_agent .utils .image_utils import encode_media
1814
1915from .types import Message
2016
2117_LOGGER = logging .getLogger (__name__ )
2218
2319
24- def encode_image_bytes (image : bytes ) -> str :
25- image = Image .open (io .BytesIO (image )).convert ("RGB" ) # type: ignore
26- buffer = io .BytesIO ()
27- image .save (buffer , format = "PNG" ) # type: ignore
28- encoded_image = base64 .b64encode (buffer .getvalue ()).decode ("utf-8" )
29- return encoded_image
30-
31-
32- def encode_media (media : Union [str , Path ]) -> str :
33- if type (media ) is str and media .startswith (("http" , "https" )):
34- # for mp4 video url, we assume there is a same url but ends with png
35- # vision-agent-ui will upload this png when uploading the video
36- if media .endswith ((".mp4" , "mov" )) and media .find ("vision-agent-dev.s3" ) != - 1 :
37- return media [:- 4 ] + ".png"
38- return media
39- extension = "png"
40- extension = Path (media ).suffix
41- if extension .lower () not in {
42- ".jpg" ,
43- ".jpeg" ,
44- ".png" ,
45- ".webp" ,
46- ".bmp" ,
47- ".mp4" ,
48- ".mov" ,
49- }:
50- raise ValueError (f"Unsupported image extension: { extension } " )
51-
52- image_bytes = b""
53- if extension .lower () in {".mp4" , ".mov" }:
54- frames = T .extract_frames (media )
55- image = frames [len (frames ) // 2 ]
56- buffer = io .BytesIO ()
57- Image .fromarray (image [0 ]).convert ("RGB" ).save (buffer , format = "PNG" )
58- image_bytes = buffer .getvalue ()
59- else :
60- image_bytes = open (media , "rb" ).read ()
61- return encode_image_bytes (image_bytes )
62-
63-
6420class LMM (ABC ):
6521 @abstractmethod
6622 def generate (
67- self , prompt : str , media : Optional [List [Union [str , Path ]]] = None , ** kwargs : Any
23+ self , prompt : str , media : Optional [Sequence [Union [str , Path ]]] = None , ** kwargs : Any
6824 ) -> Union [str , Iterator [Optional [str ]]]:
6925 pass
7026
7127 @abstractmethod
7228 def chat (
7329 self ,
74- chat : List [Message ],
30+ chat : Sequence [Message ],
7531 ** kwargs : Any ,
7632 ) -> Union [str , Iterator [Optional [str ]]]:
7733 pass
7834
7935 @abstractmethod
8036 def __call__ (
8137 self ,
82- input : Union [str , List [Message ]],
38+ input : Union [str , Sequence [Message ]],
8339 ** kwargs : Any ,
8440 ) -> Union [str , Iterator [Optional [str ]]]:
8541 pass
@@ -111,7 +67,7 @@ def __init__(
11167
11268 def __call__ (
11369 self ,
114- input : Union [str , List [Message ]],
70+ input : Union [str , Sequence [Message ]],
11571 ** kwargs : Any ,
11672 ) -> Union [str , Iterator [Optional [str ]]]:
11773 if isinstance (input , str ):
@@ -120,13 +76,13 @@ def __call__(
12076
12177 def chat (
12278 self ,
123- chat : List [Message ],
79+ chat : Sequence [Message ],
12480 ** kwargs : Any ,
12581 ) -> Union [str , Iterator [Optional [str ]]]:
12682 """Chat with the LMM model.
12783
12884 Parameters:
129- chat (List [Dict[str, str]]): A list of dictionaries containing the chat
85+ chat (Sequence [Dict[str, str]]): A list of dictionaries containing the chat
13086 messages. The messages can be in the format:
13187 [{"role": "user", "content": "Hello!"}, ...]
13288 or if it contains media, it should be in the format:
@@ -147,6 +103,7 @@ def chat(
147103 "url" : (
148104 encoded_media
149105 if encoded_media .startswith (("http" , "https" ))
106+ or encoded_media .startswith ("data:image/" )
150107 else f"data:image/png;base64,{ encoded_media } "
151108 ),
152109 "detail" : "low" ,
@@ -174,7 +131,7 @@ def f() -> Iterator[Optional[str]]:
174131 def generate (
175132 self ,
176133 prompt : str ,
177- media : Optional [List [Union [str , Path ]]] = None ,
134+ media : Optional [Sequence [Union [str , Path ]]] = None ,
178135 ** kwargs : Any ,
179136 ) -> Union [str , Iterator [Optional [str ]]]:
180137 message : List [Dict [str , Any ]] = [
@@ -192,7 +149,12 @@ def generate(
192149 {
193150 "type" : "image_url" ,
194151 "image_url" : {
195- "url" : f"data:image/png;base64,{ encoded_media } " ,
152+ "url" : (
153+ encoded_media
154+ if encoded_media .startswith (("http" , "https" ))
155+ or encoded_media .startswith ("data:image/" )
156+ else f"data:image/png;base64,{ encoded_media } "
157+ ),
196158 "detail" : "low" ,
197159 },
198160 },
@@ -214,81 +176,6 @@ def f() -> Iterator[Optional[str]]:
214176 else :
215177 return cast (str , response .choices [0 ].message .content )
216178
217- def generate_classifier (self , question : str ) -> Callable :
218- api_doc = T .get_tool_documentation ([T .clip ])
219- prompt = CHOOSE_PARAMS .format (api_doc = api_doc , question = question )
220- response = self .client .chat .completions .create (
221- model = self .model_name ,
222- messages = [
223- {"role" : "system" , "content" : SYSTEM_PROMPT },
224- {"role" : "user" , "content" : prompt },
225- ],
226- response_format = {"type" : "json_object" },
227- )
228-
229- try :
230- params = json .loads (cast (str , response .choices [0 ].message .content ))[
231- "Parameters"
232- ]
233- except json .JSONDecodeError :
234- _LOGGER .error (
235- f"Failed to decode response: { response .choices [0 ].message .content } "
236- )
237- raise ValueError ("Failed to decode response" )
238-
239- return lambda x : T .clip (x , params ["prompt" ])
240-
241- def generate_detector (self , question : str ) -> Callable :
242- api_doc = T .get_tool_documentation ([T .owl_v2 ])
243- prompt = CHOOSE_PARAMS .format (api_doc = api_doc , question = question )
244- response = self .client .chat .completions .create (
245- model = self .model_name ,
246- messages = [
247- {"role" : "system" , "content" : SYSTEM_PROMPT },
248- {"role" : "user" , "content" : prompt },
249- ],
250- response_format = {"type" : "json_object" },
251- )
252-
253- try :
254- params = json .loads (cast (str , response .choices [0 ].message .content ))[
255- "Parameters"
256- ]
257- except json .JSONDecodeError :
258- _LOGGER .error (
259- f"Failed to decode response: { response .choices [0 ].message .content } "
260- )
261- raise ValueError ("Failed to decode response" )
262-
263- return lambda x : T .owl_v2 (params ["prompt" ], x )
264-
265- def generate_segmentor (self , question : str ) -> Callable :
266- api_doc = T .get_tool_documentation ([T .grounding_sam ])
267- prompt = CHOOSE_PARAMS .format (api_doc = api_doc , question = question )
268- response = self .client .chat .completions .create (
269- model = self .model_name ,
270- messages = [
271- {"role" : "system" , "content" : SYSTEM_PROMPT },
272- {"role" : "user" , "content" : prompt },
273- ],
274- response_format = {"type" : "json_object" },
275- )
276-
277- try :
278- params = json .loads (cast (str , response .choices [0 ].message .content ))[
279- "Parameters"
280- ]
281- except json .JSONDecodeError :
282- _LOGGER .error (
283- f"Failed to decode response: { response .choices [0 ].message .content } "
284- )
285- raise ValueError ("Failed to decode response" )
286-
287- return lambda x : T .grounding_sam (params ["prompt" ], x )
288-
289- def generate_image_qa_tool (self , question : str ) -> Callable :
290- return lambda x : T .git_vqa_v2 (question , x )
291-
292179
293180class AzureOpenAILMM (OpenAILMM ):
294181 def __init__ (
@@ -362,7 +249,7 @@ def __init__(
362249
363250 def __call__ (
364251 self ,
365- input : Union [str , List [Message ]],
252+ input : Union [str , Sequence [Message ]],
366253 ** kwargs : Any ,
367254 ) -> Union [str , Iterator [Optional [str ]]]:
368255 if isinstance (input , str ):
@@ -371,13 +258,13 @@ def __call__(
371258
372259 def chat (
373260 self ,
374- chat : List [Message ],
261+ chat : Sequence [Message ],
375262 ** kwargs : Any ,
376263 ) -> Union [str , Iterator [Optional [str ]]]:
377264 """Chat with the LMM model.
378265
379266 Parameters:
380- chat (List [Dict[str, str]]): A list of dictionaries containing the chat
267+ chat (Sequence [Dict[str, str]]): A list of dictionaries containing the chat
381268 messages. The messages can be in the format:
382269 [{"role": "user", "content": "Hello!"}, ...]
383270 or if it contains media, it should be in the format:
@@ -429,7 +316,7 @@ def f() -> Iterator[Optional[str]]:
429316 def generate (
430317 self ,
431318 prompt : str ,
432- media : Optional [List [Union [str , Path ]]] = None ,
319+ media : Optional [Sequence [Union [str , Path ]]] = None ,
433320 ** kwargs : Any ,
434321 ) -> Union [str , Iterator [Optional [str ]]]:
435322 url = f"{ self .url } /generate"
@@ -493,7 +380,7 @@ def __init__(
493380
494381 def __call__ (
495382 self ,
496- input : Union [str , List [Dict [str , Any ]]],
383+ input : Union [str , Sequence [Dict [str , Any ]]],
497384 ** kwargs : Any ,
498385 ) -> Union [str , Iterator [Optional [str ]]]:
499386 if isinstance (input , str ):
@@ -502,7 +389,7 @@ def __call__(
502389
503390 def chat (
504391 self ,
505- chat : List [Dict [str , Any ]],
392+ chat : Sequence [Dict [str , Any ]],
506393 ** kwargs : Any ,
507394 ) -> Union [str , Iterator [Optional [str ]]]:
508395 messages : List [MessageParam ] = []
@@ -551,7 +438,7 @@ def f() -> Iterator[Optional[str]]:
551438 def generate (
552439 self ,
553440 prompt : str ,
554- media : Optional [List [Union [str , Path ]]] = None ,
441+ media : Optional [Sequence [Union [str , Path ]]] = None ,
555442 ** kwargs : Any ,
556443 ) -> Union [str , Iterator [Optional [str ]]]:
557444 content : List [Union [TextBlockParam , ImageBlockParam ]] = [
0 commit comments