diff --git a/examples/custom_tools/run_custom_tool.py b/examples/custom_tools/run_custom_tool.py
index fb0cc5c0..1e61ab6e 100644
--- a/examples/custom_tools/run_custom_tool.py
+++ b/examples/custom_tools/run_custom_tool.py
@@ -1,7 +1,7 @@
 import numpy as np
-from template_match import template_matching_with_rotation
 import vision_agent as va
+from template_match import template_matching_with_rotation
 from vision_agent.utils.image_utils import get_image_size, normalize_bbox
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 013eb950..e421fe16 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -2,9 +2,11 @@
 import json
 import logging
 import sys
+import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
 
+from PIL import Image
 from rich.console import Console
 from rich.style import Style
 from rich.syntax import Syntax
@@ -78,12 +80,35 @@ def extract_json(json_str: str) -> Dict[str, Any]:
     return json_dict  # type: ignore
 
 
+def extract_image(
+    media: Optional[Sequence[Union[str, Path]]]
+) -> Optional[Sequence[Union[str, Path]]]:
+    if media is None:
+        return None
+
+    new_media = []
+    for m in media:
+        m = Path(m)
+        extension = m.suffix
+        if extension in [".jpg", ".jpeg", ".png", ".bmp"]:
+            new_media.append(m)
+        elif extension in [".mp4", ".mov"]:
+            frames = T.extract_frames(m)
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                if len(frames) > 0:
+                    Image.fromarray(frames[0][0]).save(tmp.name)
+                    new_media.append(Path(tmp.name))
+    if len(new_media) == 0:
+        return None
+    return new_media
+
+
 def write_plan(
     chat: List[Dict[str, str]],
     tool_desc: str,
     working_memory: str,
     model: Union[LLM, LMM],
-    media: Optional[List[Union[str, Path]]] = None,
+    media: Optional[Sequence[Union[str, Path]]] = None,
 ) -> List[Dict[str, str]]:
     chat = copy.deepcopy(chat)
     if chat[-1]["role"] != "user":
@@ -94,6 +119,7 @@ def write_plan(
     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
     chat[-1]["content"] = prompt
     if isinstance(model, OpenAILMM):
+        media = extract_image(media)
         return extract_json(model.chat(chat, images=media))["plan"]  # type: ignore
     else:
         return extract_json(model.chat(chat))["plan"]  # type: ignore
@@ -103,7 +129,7 @@ def reflect(
     chat: List[Dict[str, str]],
     plan: str,
     code: str,
-    model: LLM,
+    model: Union[LLM, LMM],
 ) -> Dict[str, Union[str, bool]]:
     chat = copy.deepcopy(chat)
     if chat[-1]["role"] != "user":
@@ -309,7 +335,7 @@ class VisionAgent(Agent):
 
     def __init__(
         self,
-        planner: Optional[LLM] = None,
+        planner: Optional[Union[LLM, LMM]] = None,
         coder: Optional[LLM] = None,
         tester: Optional[LLM] = None,
         debugger: Optional[LLM] = None,
diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py
index 1b3d464e..97f73ebc 100644
--- a/vision_agent/agent/vision_agent_prompts.py
+++ b/vision_agent/agent/vision_agent_prompts.py
@@ -29,14 +29,17 @@
 {feedback}
 
 **Instructions**:
-Based on the context and tools you have available, write a plan of subtasks to achieve the user request utilizing given tools when necessary. Output a list of jsons in the following format:
+1. Based on the context and tools you have available, write a plan of subtasks to achieve the user request.
+2. Go over the user's request step by step and ensure each step is represented as a clear subtask in your plan.
+
+Output a list of JSONs in the following format:
 
 ```json
 {{
     "plan":
         [
             {{
-                "instructions": str # what you should do in this task, one short phrase or sentence
+                "instructions": str # what you should do in this task, associated with a tool
             }}
         ]
 }}
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index be62787b..061da146 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -199,14 +199,15 @@ def extract_frames(
 
 def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     """'ocr' extracts text from an image. It returns a list of detected text, bounding
-    boxes, and confidence scores. The results are sorted from top-left to bottom right
+    boxes with normalized coordinates, and confidence scores. The results are sorted
+    from top-left to bottom-right.
 
     Parameters:
         image (np.ndarray): The image to extract text from.
 
     Returns:
-        List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox,
-        and confidence score.
+        List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox
+        with normalized coordinates, and confidence score.
 
     Example
     -------
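For reference, a minimal sketch of what the new `extract_image` helper does when `write_plan` runs against an `OpenAILMM`. The file names below are hypothetical, and the `(frame, timestamp)` layout of `extract_frames`'s return value is inferred from the `frames[0][0]` indexing in the diff:

```python
from pathlib import Path

from vision_agent.agent.vision_agent import extract_image

# Images pass through unchanged; each video is replaced by a temporary PNG
# holding its first extracted frame, so an image-only LMM can still receive
# the media while planning.
media = extract_image([Path("photo.jpg"), Path("clip.mp4")])
print(media)  # e.g. [PosixPath('photo.jpg'), PosixPath('/tmp/tmpabc123.png')]
```

Pulling only the first frame keeps the planning prompt cheap while still grounding the plan in the video's content; videos from which no frames can be extracted are simply dropped, and the helper returns `None` if nothing usable remains.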