Skip to content

Commit b4528ab

Browse files
committed
moved image-related utils to image_utils
1 parent ad9cf58 commit b4528ab

File tree

2 files changed

+70
-136
lines changed

2 files changed

+70
-136
lines changed

vision_agent/lmm/lmm.py

Lines changed: 23 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,41 @@
1-
import base64
2-
import io
31
import json
42
import logging
53
import os
64
from abc import ABC, abstractmethod
75
from pathlib import Path
8-
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast
6+
from typing import Any, Dict, Iterator, List, Optional, Union, cast, Sequence
97

108
import anthropic
119
import requests
1210
from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam
1311
from openai import AzureOpenAI, OpenAI
14-
from PIL import Image
1512

16-
import vision_agent.tools as T
17-
from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
13+
from vision_agent.utils.image_utils import encode_media
1814

1915
from .types import Message
2016

2117
_LOGGER = logging.getLogger(__name__)
2218

2319

24-
def encode_image_bytes(image: bytes) -> str:
25-
image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore
26-
buffer = io.BytesIO()
27-
image.save(buffer, format="PNG") # type: ignore
28-
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
29-
return encoded_image
30-
31-
32-
def encode_media(media: Union[str, Path]) -> str:
33-
if type(media) is str and media.startswith(("http", "https")):
34-
# for mp4 video url, we assume there is a same url but ends with png
35-
# vision-agent-ui will upload this png when uploading the video
36-
if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1:
37-
return media[:-4] + ".png"
38-
return media
39-
extension = "png"
40-
extension = Path(media).suffix
41-
if extension.lower() not in {
42-
".jpg",
43-
".jpeg",
44-
".png",
45-
".webp",
46-
".bmp",
47-
".mp4",
48-
".mov",
49-
}:
50-
raise ValueError(f"Unsupported image extension: {extension}")
51-
52-
image_bytes = b""
53-
if extension.lower() in {".mp4", ".mov"}:
54-
frames = T.extract_frames(media)
55-
image = frames[len(frames) // 2]
56-
buffer = io.BytesIO()
57-
Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG")
58-
image_bytes = buffer.getvalue()
59-
else:
60-
image_bytes = open(media, "rb").read()
61-
return encode_image_bytes(image_bytes)
62-
63-
6420
class LMM(ABC):
6521
@abstractmethod
6622
def generate(
67-
self, prompt: str, media: Optional[List[Union[str, Path]]] = None, **kwargs: Any
23+
self, prompt: str, media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any
6824
) -> Union[str, Iterator[Optional[str]]]:
6925
pass
7026

7127
@abstractmethod
7228
def chat(
7329
self,
74-
chat: List[Message],
30+
chat: Sequence[Message],
7531
**kwargs: Any,
7632
) -> Union[str, Iterator[Optional[str]]]:
7733
pass
7834

7935
@abstractmethod
8036
def __call__(
8137
self,
82-
input: Union[str, List[Message]],
38+
input: Union[str, Sequence[Message]],
8339
**kwargs: Any,
8440
) -> Union[str, Iterator[Optional[str]]]:
8541
pass
@@ -111,7 +67,7 @@ def __init__(
11167

11268
def __call__(
11369
self,
114-
input: Union[str, List[Message]],
70+
input: Union[str, Sequence[Message]],
11571
**kwargs: Any,
11672
) -> Union[str, Iterator[Optional[str]]]:
11773
if isinstance(input, str):
@@ -120,13 +76,13 @@ def __call__(
12076

12177
def chat(
12278
self,
123-
chat: List[Message],
79+
chat: Sequence[Message],
12480
**kwargs: Any,
12581
) -> Union[str, Iterator[Optional[str]]]:
12682
"""Chat with the LMM model.
12783
12884
Parameters:
129-
chat (List[Dict[str, str]]): A list of dictionaries containing the chat
85+
chat (Sequence[Dict[str, str]]): A list of dictionaries containing the chat
13086
messages. The messages can be in the format:
13187
[{"role": "user", "content": "Hello!"}, ...]
13288
or if it contains media, it should be in the format:
@@ -147,6 +103,7 @@ def chat(
147103
"url": (
148104
encoded_media
149105
if encoded_media.startswith(("http", "https"))
106+
or encoded_media.startswith("data:image/")
150107
else f"data:image/png;base64,{encoded_media}"
151108
),
152109
"detail": "low",
@@ -174,7 +131,7 @@ def f() -> Iterator[Optional[str]]:
174131
def generate(
175132
self,
176133
prompt: str,
177-
media: Optional[List[Union[str, Path]]] = None,
134+
media: Optional[Sequence[Union[str, Path]]] = None,
178135
**kwargs: Any,
179136
) -> Union[str, Iterator[Optional[str]]]:
180137
message: List[Dict[str, Any]] = [
@@ -192,7 +149,12 @@ def generate(
192149
{
193150
"type": "image_url",
194151
"image_url": {
195-
"url": f"data:image/png;base64,{encoded_media}",
152+
"url": (
153+
encoded_media
154+
if encoded_media.startswith(("http", "https"))
155+
or encoded_media.startswith("data:image/")
156+
else f"data:image/png;base64,{encoded_media}"
157+
),
196158
"detail": "low",
197159
},
198160
},
@@ -214,81 +176,6 @@ def f() -> Iterator[Optional[str]]:
214176
else:
215177
return cast(str, response.choices[0].message.content)
216178

217-
def generate_classifier(self, question: str) -> Callable:
218-
api_doc = T.get_tool_documentation([T.clip])
219-
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
220-
response = self.client.chat.completions.create(
221-
model=self.model_name,
222-
messages=[
223-
{"role": "system", "content": SYSTEM_PROMPT},
224-
{"role": "user", "content": prompt},
225-
],
226-
response_format={"type": "json_object"},
227-
)
228-
229-
try:
230-
params = json.loads(cast(str, response.choices[0].message.content))[
231-
"Parameters"
232-
]
233-
except json.JSONDecodeError:
234-
_LOGGER.error(
235-
f"Failed to decode response: {response.choices[0].message.content}"
236-
)
237-
raise ValueError("Failed to decode response")
238-
239-
return lambda x: T.clip(x, params["prompt"])
240-
241-
def generate_detector(self, question: str) -> Callable:
242-
api_doc = T.get_tool_documentation([T.owl_v2])
243-
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
244-
response = self.client.chat.completions.create(
245-
model=self.model_name,
246-
messages=[
247-
{"role": "system", "content": SYSTEM_PROMPT},
248-
{"role": "user", "content": prompt},
249-
],
250-
response_format={"type": "json_object"},
251-
)
252-
253-
try:
254-
params = json.loads(cast(str, response.choices[0].message.content))[
255-
"Parameters"
256-
]
257-
except json.JSONDecodeError:
258-
_LOGGER.error(
259-
f"Failed to decode response: {response.choices[0].message.content}"
260-
)
261-
raise ValueError("Failed to decode response")
262-
263-
return lambda x: T.owl_v2(params["prompt"], x)
264-
265-
def generate_segmentor(self, question: str) -> Callable:
266-
api_doc = T.get_tool_documentation([T.grounding_sam])
267-
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
268-
response = self.client.chat.completions.create(
269-
model=self.model_name,
270-
messages=[
271-
{"role": "system", "content": SYSTEM_PROMPT},
272-
{"role": "user", "content": prompt},
273-
],
274-
response_format={"type": "json_object"},
275-
)
276-
277-
try:
278-
params = json.loads(cast(str, response.choices[0].message.content))[
279-
"Parameters"
280-
]
281-
except json.JSONDecodeError:
282-
_LOGGER.error(
283-
f"Failed to decode response: {response.choices[0].message.content}"
284-
)
285-
raise ValueError("Failed to decode response")
286-
287-
return lambda x: T.grounding_sam(params["prompt"], x)
288-
289-
def generate_image_qa_tool(self, question: str) -> Callable:
290-
return lambda x: T.git_vqa_v2(question, x)
291-
292179

293180
class AzureOpenAILMM(OpenAILMM):
294181
def __init__(
@@ -362,7 +249,7 @@ def __init__(
362249

363250
def __call__(
364251
self,
365-
input: Union[str, List[Message]],
252+
input: Union[str, Sequence[Message]],
366253
**kwargs: Any,
367254
) -> Union[str, Iterator[Optional[str]]]:
368255
if isinstance(input, str):
@@ -371,13 +258,13 @@ def __call__(
371258

372259
def chat(
373260
self,
374-
chat: List[Message],
261+
chat: Sequence[Message],
375262
**kwargs: Any,
376263
) -> Union[str, Iterator[Optional[str]]]:
377264
"""Chat with the LMM model.
378265
379266
Parameters:
380-
chat (List[Dict[str, str]]): A list of dictionaries containing the chat
267+
chat (Sequence[Dict[str, str]]): A list of dictionaries containing the chat
381268
messages. The messages can be in the format:
382269
[{"role": "user", "content": "Hello!"}, ...]
383270
or if it contains media, it should be in the format:
@@ -429,7 +316,7 @@ def f() -> Iterator[Optional[str]]:
429316
def generate(
430317
self,
431318
prompt: str,
432-
media: Optional[List[Union[str, Path]]] = None,
319+
media: Optional[Sequence[Union[str, Path]]] = None,
433320
**kwargs: Any,
434321
) -> Union[str, Iterator[Optional[str]]]:
435322
url = f"{self.url}/generate"
@@ -493,7 +380,7 @@ def __init__(
493380

494381
def __call__(
495382
self,
496-
input: Union[str, List[Dict[str, Any]]],
383+
input: Union[str, Sequence[Dict[str, Any]]],
497384
**kwargs: Any,
498385
) -> Union[str, Iterator[Optional[str]]]:
499386
if isinstance(input, str):
@@ -502,7 +389,7 @@ def __call__(
502389

503390
def chat(
504391
self,
505-
chat: List[Dict[str, Any]],
392+
chat: Sequence[Dict[str, Any]],
506393
**kwargs: Any,
507394
) -> Union[str, Iterator[Optional[str]]]:
508395
messages: List[MessageParam] = []
@@ -551,7 +438,7 @@ def f() -> Iterator[Optional[str]]:
551438
def generate(
552439
self,
553440
prompt: str,
554-
media: Optional[List[Union[str, Path]]] = None,
441+
media: Optional[Sequence[Union[str, Path]]] = None,
555442
**kwargs: Any,
556443
) -> Union[str, Iterator[Optional[str]]]:
557444
content: List[Union[TextBlockParam, ImageBlockParam]] = [

vision_agent/utils/image_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from PIL import Image, ImageDraw, ImageFont
1414
from PIL.Image import Image as ImageType
1515

16+
from vision_agent.utils import extract_frames_from_video
17+
1618
COLORS = [
1719
(158, 218, 229),
1820
(219, 219, 141),
@@ -172,6 +174,51 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
172174
)
173175

174176

177+
def encode_image_bytes(image: bytes) -> str:
    """Re-encode raw image bytes as a base64 PNG string.

    The input bytes may be in any format PIL can open (JPEG, BMP, ...);
    the image is converted to RGB and serialized as PNG before encoding.

    Parameters:
        image (bytes): Raw bytes of an image file.

    Returns:
        str: Base64-encoded PNG representation of the image.
    """
    # Bind the decoded image to a new name instead of rebinding the bytes
    # parameter, which previously required `# type: ignore` suppressions.
    pil_image = Image.open(io.BytesIO(image)).convert("RGB")
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
183+
184+
185+
def encode_media(media: Union[str, Path]) -> str:
    """Encode a media reference into a form usable in an LMM request.

    URLs and existing data URIs are passed through unchanged (with a special
    case mapping vision-agent-dev.s3 video URLs to their companion PNG);
    local images are base64-encoded; local videos are represented by their
    middle frame, base64-encoded as PNG.

    Parameters:
        media (Union[str, Path]): An http(s) URL, a ``data:image/`` URI, or a
            path to a local image (.jpg/.jpeg/.png/.webp/.bmp) or video
            (.mp4/.mov).

    Returns:
        str: The URL or data URI unchanged, or a base64-encoded PNG string.

    Raises:
        ValueError: If the file extension is not one of the supported types.
    """
    if isinstance(media, str) and media.startswith(("http", "https")):
        # For an mp4/mov video URL we assume there is a matching URL that
        # ends with .png; vision-agent-ui uploads this png alongside the
        # video. Note ".mov" includes the dot so URLs merely ending in
        # "mov" are not accidentally truncated.
        if media.endswith((".mp4", ".mov")) and media.find("vision-agent-dev.s3") != -1:
            return media[:-4] + ".png"
        return media

    # Already a base64-encoded data URI: return as-is.
    if isinstance(media, str) and media.startswith("data:image/"):
        return media

    extension = Path(media).suffix
    if extension.lower() not in {
        ".jpg",
        ".jpeg",
        ".png",
        ".webp",
        ".bmp",
        ".mp4",
        ".mov",
    }:
        raise ValueError(f"Unsupported image extension: {extension}")

    if extension.lower() in {".mp4", ".mov"}:
        # Represent a video by its middle frame.
        frames = extract_frames_from_video(str(media), fps=1)
        frame = frames[len(frames) // 2]
        buffer = io.BytesIO()
        Image.fromarray(frame[0]).convert("RGB").save(buffer, format="PNG")
        image_bytes = buffer.getvalue()
    else:
        # Context manager ensures the file handle is closed promptly
        # (the original bare open() leaked the handle).
        with open(media, "rb") as f:
            image_bytes = f.read()
    return encode_image_bytes(image_bytes)
220+
221+
175222
def denormalize_bbox(
176223
bbox: List[Union[int, float]], image_size: Tuple[int, ...]
177224
) -> List[float]:

0 commit comments

Comments
 (0)