1
- import base64
2
- import io
3
1
import json
4
2
import logging
5
3
import os
6
4
from abc import ABC , abstractmethod
7
5
from pathlib import Path
8
- from typing import Any , Callable , Dict , Iterator , List , Optional , Union , cast
6
+ from typing import Any , Dict , Iterator , List , Optional , Union , cast , Sequence
9
7
10
8
import anthropic
11
9
import requests
12
10
from anthropic .types import ImageBlockParam , MessageParam , TextBlockParam
13
11
from openai import AzureOpenAI , OpenAI
14
- from PIL import Image
15
12
16
- import vision_agent .tools as T
17
- from vision_agent .tools .prompts import CHOOSE_PARAMS , SYSTEM_PROMPT
13
+ from vision_agent .utils .image_utils import encode_media
18
14
19
15
from .types import Message
20
16
21
17
_LOGGER = logging .getLogger (__name__ )
22
18
23
19
24
- def encode_image_bytes (image : bytes ) -> str :
25
- image = Image .open (io .BytesIO (image )).convert ("RGB" ) # type: ignore
26
- buffer = io .BytesIO ()
27
- image .save (buffer , format = "PNG" ) # type: ignore
28
- encoded_image = base64 .b64encode (buffer .getvalue ()).decode ("utf-8" )
29
- return encoded_image
30
-
31
-
32
- def encode_media (media : Union [str , Path ]) -> str :
33
- if type (media ) is str and media .startswith (("http" , "https" )):
34
- # for mp4 video url, we assume there is a same url but ends with png
35
- # vision-agent-ui will upload this png when uploading the video
36
- if media .endswith ((".mp4" , "mov" )) and media .find ("vision-agent-dev.s3" ) != - 1 :
37
- return media [:- 4 ] + ".png"
38
- return media
39
- extension = "png"
40
- extension = Path (media ).suffix
41
- if extension .lower () not in {
42
- ".jpg" ,
43
- ".jpeg" ,
44
- ".png" ,
45
- ".webp" ,
46
- ".bmp" ,
47
- ".mp4" ,
48
- ".mov" ,
49
- }:
50
- raise ValueError (f"Unsupported image extension: { extension } " )
51
-
52
- image_bytes = b""
53
- if extension .lower () in {".mp4" , ".mov" }:
54
- frames = T .extract_frames (media )
55
- image = frames [len (frames ) // 2 ]
56
- buffer = io .BytesIO ()
57
- Image .fromarray (image [0 ]).convert ("RGB" ).save (buffer , format = "PNG" )
58
- image_bytes = buffer .getvalue ()
59
- else :
60
- image_bytes = open (media , "rb" ).read ()
61
- return encode_image_bytes (image_bytes )
62
-
63
-
64
20
class LMM (ABC ):
65
21
@abstractmethod
66
22
def generate (
67
- self , prompt : str , media : Optional [List [Union [str , Path ]]] = None , ** kwargs : Any
23
+ self , prompt : str , media : Optional [Sequence [Union [str , Path ]]] = None , ** kwargs : Any
68
24
) -> Union [str , Iterator [Optional [str ]]]:
69
25
pass
70
26
71
27
@abstractmethod
72
28
def chat (
73
29
self ,
74
- chat : List [Message ],
30
+ chat : Sequence [Message ],
75
31
** kwargs : Any ,
76
32
) -> Union [str , Iterator [Optional [str ]]]:
77
33
pass
78
34
79
35
@abstractmethod
80
36
def __call__ (
81
37
self ,
82
- input : Union [str , List [Message ]],
38
+ input : Union [str , Sequence [Message ]],
83
39
** kwargs : Any ,
84
40
) -> Union [str , Iterator [Optional [str ]]]:
85
41
pass
@@ -111,7 +67,7 @@ def __init__(
111
67
112
68
def __call__ (
113
69
self ,
114
- input : Union [str , List [Message ]],
70
+ input : Union [str , Sequence [Message ]],
115
71
** kwargs : Any ,
116
72
) -> Union [str , Iterator [Optional [str ]]]:
117
73
if isinstance (input , str ):
@@ -120,13 +76,13 @@ def __call__(
120
76
121
77
def chat (
122
78
self ,
123
- chat : List [Message ],
79
+ chat : Sequence [Message ],
124
80
** kwargs : Any ,
125
81
) -> Union [str , Iterator [Optional [str ]]]:
126
82
"""Chat with the LMM model.
127
83
128
84
Parameters:
129
- chat (List [Dict[str, str]]): A list of dictionaries containing the chat
85
+ chat (Sequence [Dict[str, str]]): A list of dictionaries containing the chat
130
86
messages. The messages can be in the format:
131
87
[{"role": "user", "content": "Hello!"}, ...]
132
88
or if it contains media, it should be in the format:
@@ -147,6 +103,7 @@ def chat(
147
103
"url" : (
148
104
encoded_media
149
105
if encoded_media .startswith (("http" , "https" ))
106
+ or encoded_media .startswith ("data:image/" )
150
107
else f"data:image/png;base64,{ encoded_media } "
151
108
),
152
109
"detail" : "low" ,
@@ -174,7 +131,7 @@ def f() -> Iterator[Optional[str]]:
174
131
def generate (
175
132
self ,
176
133
prompt : str ,
177
- media : Optional [List [Union [str , Path ]]] = None ,
134
+ media : Optional [Sequence [Union [str , Path ]]] = None ,
178
135
** kwargs : Any ,
179
136
) -> Union [str , Iterator [Optional [str ]]]:
180
137
message : List [Dict [str , Any ]] = [
@@ -192,7 +149,12 @@ def generate(
192
149
{
193
150
"type" : "image_url" ,
194
151
"image_url" : {
195
- "url" : f"data:image/png;base64,{ encoded_media } " ,
152
+ "url" : (
153
+ encoded_media
154
+ if encoded_media .startswith (("http" , "https" ))
155
+ or encoded_media .startswith ("data:image/" )
156
+ else f"data:image/png;base64,{ encoded_media } "
157
+ ),
196
158
"detail" : "low" ,
197
159
},
198
160
},
@@ -214,81 +176,6 @@ def f() -> Iterator[Optional[str]]:
214
176
else :
215
177
return cast (str , response .choices [0 ].message .content )
216
178
217
- def generate_classifier (self , question : str ) -> Callable :
218
- api_doc = T .get_tool_documentation ([T .clip ])
219
- prompt = CHOOSE_PARAMS .format (api_doc = api_doc , question = question )
220
- response = self .client .chat .completions .create (
221
- model = self .model_name ,
222
- messages = [
223
- {"role" : "system" , "content" : SYSTEM_PROMPT },
224
- {"role" : "user" , "content" : prompt },
225
- ],
226
- response_format = {"type" : "json_object" },
227
- )
228
-
229
- try :
230
- params = json .loads (cast (str , response .choices [0 ].message .content ))[
231
- "Parameters"
232
- ]
233
- except json .JSONDecodeError :
234
- _LOGGER .error (
235
- f"Failed to decode response: { response .choices [0 ].message .content } "
236
- )
237
- raise ValueError ("Failed to decode response" )
238
-
239
- return lambda x : T .clip (x , params ["prompt" ])
240
-
241
- def generate_detector (self , question : str ) -> Callable :
242
- api_doc = T .get_tool_documentation ([T .owl_v2 ])
243
- prompt = CHOOSE_PARAMS .format (api_doc = api_doc , question = question )
244
- response = self .client .chat .completions .create (
245
- model = self .model_name ,
246
- messages = [
247
- {"role" : "system" , "content" : SYSTEM_PROMPT },
248
- {"role" : "user" , "content" : prompt },
249
- ],
250
- response_format = {"type" : "json_object" },
251
- )
252
-
253
- try :
254
- params = json .loads (cast (str , response .choices [0 ].message .content ))[
255
- "Parameters"
256
- ]
257
- except json .JSONDecodeError :
258
- _LOGGER .error (
259
- f"Failed to decode response: { response .choices [0 ].message .content } "
260
- )
261
- raise ValueError ("Failed to decode response" )
262
-
263
- return lambda x : T .owl_v2 (params ["prompt" ], x )
264
-
265
- def generate_segmentor (self , question : str ) -> Callable :
266
- api_doc = T .get_tool_documentation ([T .grounding_sam ])
267
- prompt = CHOOSE_PARAMS .format (api_doc = api_doc , question = question )
268
- response = self .client .chat .completions .create (
269
- model = self .model_name ,
270
- messages = [
271
- {"role" : "system" , "content" : SYSTEM_PROMPT },
272
- {"role" : "user" , "content" : prompt },
273
- ],
274
- response_format = {"type" : "json_object" },
275
- )
276
-
277
- try :
278
- params = json .loads (cast (str , response .choices [0 ].message .content ))[
279
- "Parameters"
280
- ]
281
- except json .JSONDecodeError :
282
- _LOGGER .error (
283
- f"Failed to decode response: { response .choices [0 ].message .content } "
284
- )
285
- raise ValueError ("Failed to decode response" )
286
-
287
- return lambda x : T .grounding_sam (params ["prompt" ], x )
288
-
289
- def generate_image_qa_tool (self , question : str ) -> Callable :
290
- return lambda x : T .git_vqa_v2 (question , x )
291
-
292
179
293
180
class AzureOpenAILMM (OpenAILMM ):
294
181
def __init__ (
@@ -362,7 +249,7 @@ def __init__(
362
249
363
250
def __call__ (
364
251
self ,
365
- input : Union [str , List [Message ]],
252
+ input : Union [str , Sequence [Message ]],
366
253
** kwargs : Any ,
367
254
) -> Union [str , Iterator [Optional [str ]]]:
368
255
if isinstance (input , str ):
@@ -371,13 +258,13 @@ def __call__(
371
258
372
259
def chat (
373
260
self ,
374
- chat : List [Message ],
261
+ chat : Sequence [Message ],
375
262
** kwargs : Any ,
376
263
) -> Union [str , Iterator [Optional [str ]]]:
377
264
"""Chat with the LMM model.
378
265
379
266
Parameters:
380
- chat (List [Dict[str, str]]): A list of dictionaries containing the chat
267
+ chat (Sequence [Dict[str, str]]): A list of dictionaries containing the chat
381
268
messages. The messages can be in the format:
382
269
[{"role": "user", "content": "Hello!"}, ...]
383
270
or if it contains media, it should be in the format:
@@ -429,7 +316,7 @@ def f() -> Iterator[Optional[str]]:
429
316
def generate (
430
317
self ,
431
318
prompt : str ,
432
- media : Optional [List [Union [str , Path ]]] = None ,
319
+ media : Optional [Sequence [Union [str , Path ]]] = None ,
433
320
** kwargs : Any ,
434
321
) -> Union [str , Iterator [Optional [str ]]]:
435
322
url = f"{ self .url } /generate"
@@ -493,7 +380,7 @@ def __init__(
493
380
494
381
def __call__ (
495
382
self ,
496
- input : Union [str , List [Dict [str , Any ]]],
383
+ input : Union [str , Sequence [Dict [str , Any ]]],
497
384
** kwargs : Any ,
498
385
) -> Union [str , Iterator [Optional [str ]]]:
499
386
if isinstance (input , str ):
@@ -502,7 +389,7 @@ def __call__(
502
389
503
390
def chat (
504
391
self ,
505
- chat : List [Dict [str , Any ]],
392
+ chat : Sequence [Dict [str , Any ]],
506
393
** kwargs : Any ,
507
394
) -> Union [str , Iterator [Optional [str ]]]:
508
395
messages : List [MessageParam ] = []
@@ -551,7 +438,7 @@ def f() -> Iterator[Optional[str]]:
551
438
def generate (
552
439
self ,
553
440
prompt : str ,
554
- media : Optional [List [Union [str , Path ]]] = None ,
441
+ media : Optional [Sequence [Union [str , Path ]]] = None ,
555
442
** kwargs : Any ,
556
443
) -> Union [str , Iterator [Optional [str ]]]:
557
444
content : List [Union [TextBlockParam , ImageBlockParam ]] = [
0 commit comments