Skip to content

Commit

Permalink
feat: customized_tool_names param to VisionAgentCoder (#210)
Browse files Browse the repository at this point in the history
Allow users to pass customized_tool_names to VisionAgentCoder so they can limit and customize which tools it uses.

The implementation is non-destructive: it only changes the input of write_plans, so the planner only sees the provided customized_tool_names plus a list of util tools (such as save_image). The rest is still up to VisionAgentCoder.

This means the tools the user picks are not mandatory: VisionAgentCoder might still ignore the customized_tool_names if it doesn't think they solve the issue, even though they are the only choice it has.
  • Loading branch information
MingruiZhang authored Aug 27, 2024
1 parent 7e149d1 commit b302e94
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 7 deletions.
14 changes: 11 additions & 3 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def write_plans(

user_request = chat[-1]["content"]
context = USER_REQ.format(user_request=user_request)
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
prompt = PLAN.format(
context=context,
tool_desc=tool_desc,
feedback=working_memory,
)
chat[-1]["content"] = prompt
return extract_json(model(chat, stream=False)) # type: ignore

Expand Down Expand Up @@ -674,6 +678,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
customized_tool_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
Expand All @@ -689,6 +694,8 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
If not provided, default to full tool set from vision_agent.tools.
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
Expand Down Expand Up @@ -742,7 +749,9 @@ def chat_with_workflow(
)
plans = write_plans(
int_chat,
T.TOOL_DESCRIPTIONS,
T.get_tool_descriptions_by_names(
customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
),
format_memory(working_memory),
self.planner,
)
Expand All @@ -754,7 +763,6 @@ def chat_with_workflow(
_LOGGER.info(
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_infos = retrieve_tools(
plans,
self.tool_recommender,
Expand Down
7 changes: 4 additions & 3 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Callable, List, Optional

from .meta_tools import (
META_TOOL_DOCSTRING,
)
from .meta_tools import META_TOOL_DOCSTRING
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tool_utils import get_tool_descriptions_by_names
from .tools import (
FUNCTION_TOOLS,
TOOL_DESCRIPTIONS,
TOOL_DOCSTRING,
TOOLS,
TOOLS_DF,
TOOLS_INFO,
UTIL_TOOLS,
UTILITIES_DOCSTRING,
blip_image_caption,
clip,
Expand Down
25 changes: 25 additions & 0 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
return descriptions


def get_tool_descriptions_by_names(
    tool_name: Optional[List[str]],
    funcs: List[Callable[..., Any]],
    util_funcs: List[
        Callable[..., Any]
    ],  # util_funcs will always be added to the list of functions
) -> str:
    """Build tool descriptions for a caller-selected subset of ``funcs``.

    Args:
        tool_name: Names (matched against ``func.__name__``) of the functions
            in ``funcs`` to describe. ``None`` or an empty list selects every
            function in ``funcs``.
        funcs: Candidate functions the caller may select from.
        util_funcs: Functions that are always appended after the selected
            functions, regardless of ``tool_name``.

    Returns:
        The result of ``get_tool_descriptions`` for the selected functions
        (kept in their ``funcs`` order) followed by ``util_funcs``.

    Raises:
        ValueError: If any entry of ``tool_name`` is not the ``__name__`` of a
            function in ``funcs``. Names of ``util_funcs`` are not accepted.
    """
    if tool_name is None:
        return get_tool_descriptions(funcs + util_funcs)

    # Build the lookup set once instead of once per requested name.
    known_names = {func.__name__ for func in funcs}
    invalid_names = [name for name in tool_name if name not in known_names]
    if invalid_names:
        raise ValueError(f"Invalid customized tool names: {', '.join(invalid_names)}")

    # An empty selection is treated the same as no restriction at all.
    if not tool_name:
        return get_tool_descriptions(funcs + util_funcs)

    requested = set(tool_name)
    filtered_funcs = [func for func in funcs if func.__name__ in requested]
    return get_tool_descriptions(filtered_funcs + util_funcs)


def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
data: Dict[str, List[str]] = {"desc": [], "doc": []}

Expand Down
8 changes: 7 additions & 1 deletion vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ def florencev2_fine_tuned_object_detection(
return return_data


TOOLS = [
FUNCTION_TOOLS = [
owl_v2,
extract_frames,
ocr,
Expand All @@ -1671,6 +1671,9 @@ def florencev2_fine_tuned_object_detection(
generate_pose_image,
closest_mask_distance,
closest_box_distance,
]

UTIL_TOOLS = [
save_json,
load_image,
save_image,
Expand All @@ -1679,6 +1682,9 @@ def florencev2_fine_tuned_object_detection(
overlay_segmentation_masks,
overlay_heat_map,
]

TOOLS = FUNCTION_TOOLS + UTIL_TOOLS

TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
Expand Down

0 comments on commit b302e94

Please sign in to comment.