diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 3a370c5e..0453624b 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -127,7 +127,12 @@ def write_plans( user_request = chat[-1]["content"] context = USER_REQ.format(user_request=user_request) - prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory) + prompt = PLAN.format( + context=context, + tool_desc=tool_desc, + feedback=working_memory, + ) + print(prompt) chat[-1]["content"] = prompt return extract_json(model(chat, stream=False)) # type: ignore @@ -610,6 +615,8 @@ def __init__( also None, the local python runtime environment will be used. """ + print("testing1") + self.planner = ( OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner ) @@ -661,6 +668,7 @@ def chat_with_workflow( chat: List[Message], test_multi_plan: bool = True, display_visualization: bool = False, + customized_tool_names: List[str] = [], ) -> Dict[str, Any]: """Chat with VisionAgentCoder and return intermediate information regarding the task. @@ -676,12 +684,16 @@ def chat_with_workflow( with the first plan. display_visualization (bool): If True, it opens a new window locally to show the image(s) created by visualization code (if there is any). + customized_tool_names (List[str]): A list of customized tools for agent to pick and use. + If not provided, default to full tool set from vision_agent.tools. Returns: Dict[str, Any]: A dictionary containing the code, test, test result, plan, and working memory of the agent. 
def get_tool_descriptions_by_names(
    tool_name: List[str],
    funcs: List[Callable[..., Any]],
    util_funcs: List[Callable[..., Any]],
) -> str:
    """Describe a caller-selected subset of tools plus the utility tools.

    Args:
        tool_name: Names (matched against ``__name__``) of the functions in
            ``funcs`` to keep. When empty (or None), every function in
            ``funcs`` is kept.
        funcs: Candidate tool functions that can be selected by name.
        util_funcs: Functions that are always appended after the selected
            ones, regardless of ``tool_name``.

    Returns:
        The string produced by ``get_tool_descriptions`` for the selected
        functions followed by ``util_funcs``.

    Raises:
        ValueError: If any requested name is not the ``__name__`` of a
            function in ``funcs``. Names that only appear in ``util_funcs``
            are rejected as well.
    """
    requested = list(tool_name) if tool_name else []

    # Build the lookup set once instead of once per requested name.
    known_names = {func.__name__ for func in funcs}
    invalid_names = [name for name in requested if name not in known_names]
    if invalid_names:
        raise ValueError(f"Invalid customized tool names: {', '.join(invalid_names)}")

    if requested:
        wanted = set(requested)
        # Iterate over funcs (not the request) so the original tool order is kept.
        selected = [func for func in funcs if func.__name__ in wanted]
    else:
        selected = list(funcs)

    return get_tool_descriptions(selected + list(util_funcs))
+ + def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: data: Dict[str, List[str]] = {"desc": [], "doc": []} diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 0254a455..212b15fa 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2,33 +2,34 @@ import json import logging import tempfile -from pathlib import Path from importlib import resources +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union, cast import cv2 -import requests import numpy as np -from pytube import YouTube # type: ignore +import requests from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageFont from pillow_heif import register_heif_opener # type: ignore +from pytube import YouTube # type: ignore from vision_agent.tools.tool_utils import ( - send_inference_request, get_tool_descriptions, + get_tool_descriptions_by_names, get_tool_documentation, get_tools_df, + send_inference_request, ) from vision_agent.utils import extract_frames_from_video from vision_agent.utils.execute import FileSerializer, MimeType from vision_agent.utils.image_utils import ( b64_to_pil, + convert_quad_box_to_bbox, convert_to_b64, denormalize_bbox, get_image_size, normalize_bbox, - convert_quad_box_to_bbox, rle_decode, ) @@ -1285,7 +1286,17 @@ def overlay_heat_map( return np.array(combined) -TOOLS = [ +UTIL_TOOLS = [ + save_json, + load_image, + save_image, + save_video, + overlay_bounding_boxes, + overlay_segmentation_masks, + overlay_heat_map, +] + +FUNCTION_TOOLS = [ owl_v2, grounding_sam, extract_frames, @@ -1305,15 +1316,11 @@ def overlay_heat_map( generate_pose_image, closest_mask_distance, closest_box_distance, - save_json, - load_image, - save_image, - save_video, - overlay_bounding_boxes, - overlay_segmentation_masks, - overlay_heat_map, template_match, ] + +TOOLS = FUNCTION_TOOLS + UTIL_TOOLS + TOOLS_DF = get_tools_df(TOOLS) # type: ignore TOOL_DESCRIPTIONS = 
get_tool_descriptions(TOOLS) # type: ignore TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore