Skip to content

Commit b302e94

Browse files
authored
feat: customized_tool_names param to VisionAgentCoder (#210)
Allow users to pass customized_tool_names to VisionAgentCoder so they can limit and customize which tools are used. The implementation is non-destructive: it only changes the input of write_plans, so the planner only sees the provided customized_tool_names plus a list of util tools (such as save_image); the rest is still up to VisionAgentCoder. This means the tools the user picks are not mandatory: VisionAgentCoder might still ignore the customized_tool_names if it doesn't think they solve the issue, even though they are the only choice it has.
1 parent 7e149d1 commit b302e94

File tree

4 files changed

+47
-7
lines changed

4 files changed

+47
-7
lines changed

vision_agent/agent/vision_agent_coder.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,11 @@ def write_plans(
128128

129129
user_request = chat[-1]["content"]
130130
context = USER_REQ.format(user_request=user_request)
131-
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
131+
prompt = PLAN.format(
132+
context=context,
133+
tool_desc=tool_desc,
134+
feedback=working_memory,
135+
)
132136
chat[-1]["content"] = prompt
133137
return extract_json(model(chat, stream=False)) # type: ignore
134138

@@ -674,6 +678,7 @@ def chat_with_workflow(
674678
chat: List[Message],
675679
test_multi_plan: bool = True,
676680
display_visualization: bool = False,
681+
customized_tool_names: Optional[List[str]] = None,
677682
) -> Dict[str, Any]:
678683
"""Chat with VisionAgentCoder and return intermediate information regarding the
679684
task.
@@ -689,6 +694,8 @@ def chat_with_workflow(
689694
with the first plan.
690695
display_visualization (bool): If True, it opens a new window locally to
691696
show the image(s) created by visualization code (if there is any).
697+
customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
698+
If not provided, default to full tool set from vision_agent.tools.
692699
693700
Returns:
694701
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
@@ -742,7 +749,9 @@ def chat_with_workflow(
742749
)
743750
plans = write_plans(
744751
int_chat,
745-
T.TOOL_DESCRIPTIONS,
752+
T.get_tool_descriptions_by_names(
753+
customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
754+
),
746755
format_memory(working_memory),
747756
self.planner,
748757
)
@@ -754,7 +763,6 @@ def chat_with_workflow(
754763
_LOGGER.info(
755764
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
756765
)
757-
758766
tool_infos = retrieve_tools(
759767
plans,
760768
self.tool_recommender,

vision_agent/tools/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from typing import Callable, List, Optional
22

3-
from .meta_tools import (
4-
META_TOOL_DOCSTRING,
5-
)
3+
from .meta_tools import META_TOOL_DOCSTRING
64
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
5+
from .tool_utils import get_tool_descriptions_by_names
76
from .tools import (
7+
FUNCTION_TOOLS,
88
TOOL_DESCRIPTIONS,
99
TOOL_DOCSTRING,
1010
TOOLS,
1111
TOOLS_DF,
1212
TOOLS_INFO,
13+
UTIL_TOOLS,
1314
UTILITIES_DOCSTRING,
1415
blip_image_caption,
1516
clip,

vision_agent/tools/tool_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,31 @@ def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
142142
return descriptions
143143

144144

145+
def get_tool_descriptions_by_names(
    tool_name: Optional[List[str]],
    funcs: List[Callable[..., Any]],
    util_funcs: List[Callable[..., Any]],
) -> str:
    """Build tool descriptions restricted to a caller-selected subset of tools.

    Parameters:
        tool_name (Optional[List[str]]): Names (``__name__``) of the functions in
            ``funcs`` to keep. If ``None`` or empty, every function in ``funcs``
            is kept.
        funcs (List[Callable[..., Any]]): Candidate tool functions that
            ``tool_name`` selects from.
        util_funcs (List[Callable[..., Any]]): Utility functions that are always
            appended to the selection, regardless of ``tool_name``.

    Returns:
        str: The result of ``get_tool_descriptions`` for the selected functions
            followed by ``util_funcs``.

    Raises:
        ValueError: If any entry of ``tool_name`` does not match the ``__name__``
            of a function in ``funcs``.
    """
    if tool_name is None:
        return get_tool_descriptions(funcs + util_funcs)

    # Build the lookup set once rather than once per requested name.
    available_names = {func.__name__ for func in funcs}
    invalid_names = [name for name in tool_name if name not in available_names]
    if invalid_names:
        raise ValueError(f"Invalid customized tool names: {', '.join(invalid_names)}")

    if not tool_name:
        # An empty selection is treated the same as "no restriction".
        filtered_funcs = funcs
    else:
        # Iterate over funcs (not tool_name) so the original tool order is kept
        # and duplicate names in tool_name do not duplicate descriptions.
        requested_names = set(tool_name)
        filtered_funcs = [func for func in funcs if func.__name__ in requested_names]

    return get_tool_descriptions(filtered_funcs + util_funcs)
168+
169+
145170
def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
146171
data: Dict[str, List[str]] = {"desc": [], "doc": []}
147172

vision_agent/tools/tools.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1650,7 +1650,7 @@ def florencev2_fine_tuned_object_detection(
16501650
return return_data
16511651

16521652

1653-
TOOLS = [
1653+
FUNCTION_TOOLS = [
16541654
owl_v2,
16551655
extract_frames,
16561656
ocr,
@@ -1671,6 +1671,9 @@ def florencev2_fine_tuned_object_detection(
16711671
generate_pose_image,
16721672
closest_mask_distance,
16731673
closest_box_distance,
1674+
]
1675+
1676+
UTIL_TOOLS = [
16741677
save_json,
16751678
load_image,
16761679
save_image,
@@ -1679,6 +1682,9 @@ def florencev2_fine_tuned_object_detection(
16791682
overlay_segmentation_masks,
16801683
overlay_heat_map,
16811684
]
1685+
1686+
TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
1687+
16821688
TOOLS_DF = get_tools_df(TOOLS) # type: ignore
16831689
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
16841690
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore

0 commit comments

Comments
 (0)