Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: customized_tool_names param to VisionAgentCoder #210

Merged
merged 9 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def write_plans(

user_request = chat[-1]["content"]
context = USER_REQ.format(user_request=user_request)
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
prompt = PLAN.format(
context=context,
tool_desc=tool_desc,
feedback=working_memory,
)
chat[-1]["content"] = prompt
return extract_json(model(chat, stream=False)) # type: ignore

Expand Down Expand Up @@ -674,6 +678,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
customized_tool_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
Expand All @@ -689,6 +694,8 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
If not provided, default to full tool set from vision_agent.tools.

Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
Expand Down Expand Up @@ -742,7 +749,9 @@ def chat_with_workflow(
)
plans = write_plans(
int_chat,
T.TOOL_DESCRIPTIONS,
T.get_tool_descriptions_by_names(
customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
),
format_memory(working_memory),
self.planner,
)
Expand All @@ -754,7 +763,6 @@ def chat_with_workflow(
_LOGGER.info(
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_infos = retrieve_tools(
plans,
self.tool_recommender,
Expand Down
7 changes: 4 additions & 3 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Callable, List, Optional

from .meta_tools import (
META_TOOL_DOCSTRING,
)
from .meta_tools import META_TOOL_DOCSTRING
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tool_utils import get_tool_descriptions_by_names
from .tools import (
FUNCTION_TOOLS,
TOOL_DESCRIPTIONS,
TOOL_DOCSTRING,
TOOLS,
TOOLS_DF,
TOOLS_INFO,
UTIL_TOOLS,
UTILITIES_DOCSTRING,
blip_image_caption,
clip,
Expand Down
25 changes: 25 additions & 0 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
return descriptions


def get_tool_descriptions_by_names(
tool_name: Optional[List[str]],
funcs: List[Callable[..., Any]],
util_funcs: List[
Callable[..., Any]
], # util_funcs will always be added to the list of functions
) -> str:
if tool_name is None:
return get_tool_descriptions(funcs + util_funcs)

invalid_names = [
name for name in tool_name if name not in {func.__name__ for func in funcs}
]

if invalid_names:
raise ValueError(f"Invalid customized tool names: {', '.join(invalid_names)}")

filtered_funcs = (
funcs
if not tool_name
else [func for func in funcs if func.__name__ in tool_name]
)
return get_tool_descriptions(filtered_funcs + util_funcs)


def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
data: Dict[str, List[str]] = {"desc": [], "doc": []}

Expand Down
8 changes: 7 additions & 1 deletion vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ def florencev2_fine_tuned_object_detection(
return return_data


TOOLS = [
FUNCTION_TOOLS = [
owl_v2,
extract_frames,
ocr,
Expand All @@ -1671,6 +1671,9 @@ def florencev2_fine_tuned_object_detection(
generate_pose_image,
closest_mask_distance,
closest_box_distance,
]

UTIL_TOOLS = [
save_json,
load_image,
save_image,
Expand All @@ -1679,6 +1682,9 @@ def florencev2_fine_tuned_object_detection(
overlay_segmentation_masks,
overlay_heat_map,
]

TOOLS = FUNCTION_TOOLS + UTIL_TOOLS

TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
Expand Down
Loading