diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 3f445d80..6bba2905 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -128,7 +128,11 @@ def write_plans( user_request = chat[-1]["content"] context = USER_REQ.format(user_request=user_request) - prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory) + prompt = PLAN.format( + context=context, + tool_desc=tool_desc, + feedback=working_memory, + ) chat[-1]["content"] = prompt return extract_json(model(chat, stream=False)) # type: ignore @@ -674,6 +678,7 @@ def chat_with_workflow( chat: List[Message], test_multi_plan: bool = True, display_visualization: bool = False, + customized_tool_names: Optional[List[str]] = None, ) -> Dict[str, Any]: """Chat with VisionAgentCoder and return intermediate information regarding the task. @@ -689,6 +694,8 @@ def chat_with_workflow( with the first plan. display_visualization (bool): If True, it opens a new window locally to show the image(s) created by visualization code (if there is any). + customized_tool_names (List[str]): A list of customized tools for agent to pick and use. + If not provided, default to full tool set from vision_agent.tools. Returns: Dict[str, Any]: A dictionary containing the code, test, test result, plan, @@ -742,7 +749,9 @@ def chat_with_workflow( ) plans = write_plans( int_chat, - T.TOOL_DESCRIPTIONS, + T.get_tool_descriptions_by_names( + customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore + ), format_memory(working_memory), self.planner, ) @@ -754,7 +763,6 @@ def chat_with_workflow( _LOGGER.info( f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - tool_infos = retrieve_tools( plans, self.tool_recommender, diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 1e587ce7..a90b7181 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,15 +1,16 @@ from typing import Callable, List, Optional -from .meta_tools import ( - META_TOOL_DOCSTRING, -) +from .meta_tools import META_TOOL_DOCSTRING from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT +from .tool_utils import get_tool_descriptions_by_names from .tools import ( + FUNCTION_TOOLS, TOOL_DESCRIPTIONS, TOOL_DOCSTRING, TOOLS, TOOLS_DF, TOOLS_INFO, + UTIL_TOOLS, UTILITIES_DOCSTRING, blip_image_caption, clip, diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 4f1b1f34..185563a4 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -142,6 +142,31 @@ def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str: return descriptions +def get_tool_descriptions_by_names( + tool_name: Optional[List[str]], + funcs: List[Callable[..., Any]], + util_funcs: List[ + Callable[..., Any] + ], # util_funcs will always be added to the list of functions +) -> str: + if tool_name is None: + return get_tool_descriptions(funcs + util_funcs) + + invalid_names = [ + name for name in tool_name if name not in {func.__name__ for func in funcs} + ] + + if invalid_names: + raise ValueError(f"Invalid customized tool names: {', '.join(invalid_names)}") + + filtered_funcs = ( + funcs + if not tool_name + else [func for func in funcs if func.__name__ in tool_name] + ) + return get_tool_descriptions(filtered_funcs + util_funcs) + + def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: data: Dict[str, List[str]] = {"desc": [], "doc": []} diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 2dade7f7..594fcf6d 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1650,7 +1650,7 @@ def florencev2_fine_tuned_object_detection( return return_data -TOOLS = [ +FUNCTION_TOOLS = [ owl_v2, extract_frames, ocr, @@ -1671,6 +1671,9 @@ def florencev2_fine_tuned_object_detection( generate_pose_image, closest_mask_distance, closest_box_distance, +] + +UTIL_TOOLS = [ save_json, load_image, save_image, @@ -1679,6 +1682,9 @@ def florencev2_fine_tuned_object_detection( overlay_segmentation_masks, overlay_heat_map, ] + +TOOLS = FUNCTION_TOOLS + UTIL_TOOLS + TOOLS_DF = get_tools_df(TOOLS) # type: ignore TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore