Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
  • Loading branch information
MingruiZhang committed Aug 26, 2024
1 parent 31af305 commit 8b344aa
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
18 changes: 16 additions & 2 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ def write_plans(

user_request = chat[-1]["content"]
context = USER_REQ.format(user_request=user_request)
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
prompt = PLAN.format(
context=context,
tool_desc=tool_desc,
feedback=working_memory,
)
print(prompt)
chat[-1]["content"] = prompt
return extract_json(model(chat, stream=False)) # type: ignore

Expand Down Expand Up @@ -610,6 +615,8 @@ def __init__(
also None, the local python runtime environment will be used.
"""

print("testing1")

self.planner = (
OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
)
Expand Down Expand Up @@ -661,6 +668,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
customized_tool_names: List[str] = [],
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
Expand All @@ -676,12 +684,16 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
If not provided, default to full tool set from vision_agent.tools.
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
and working memory of the agent.
"""

print("chat with workflow - start")

if not chat:
raise ValueError("Chat cannot be empty.")

Expand Down Expand Up @@ -729,7 +741,9 @@ def chat_with_workflow(
)
plans = write_plans(
int_chat,
T.TOOL_DESCRIPTIONS,
T.get_tool_descriptions_by_names(
customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS
),
format_memory(working_memory),
self.planner,
)
Expand Down
5 changes: 4 additions & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .meta_tools import META_TOOL_DOCSTRING, florencev2_fine_tuning
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import (
FUNCTION_TOOLS,
TOOL_DESCRIPTIONS,
TOOL_DOCSTRING,
TOOLS,
TOOLS_DF,
UTIL_TOOLS,
UTILITIES_DOCSTRING,
blip_image_caption,
clip,
Expand All @@ -18,10 +20,11 @@
extract_frames,
florencev2_image_caption,
florencev2_object_detection,
florencev2_roberta_vqa,
florencev2_ocr,
florencev2_roberta_vqa,
generate_pose_image,
generate_soft_edge_image,
get_tool_descriptions_by_names,
get_tool_documentation,
git_vqa_v2,
grounding_dino,
Expand Down
23 changes: 23 additions & 0 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
return descriptions


def get_tool_descriptions_by_names(
tool_name: List[str],
funcs: List[Callable[..., Any]],
util_funcs: List[
Callable[..., Any]
], # util_funcs will always be added to the list of functions
) -> str:

invalid_names = [
name for name in tool_name if name not in {func.__name__ for func in funcs}
]

if invalid_names:
raise ValueError(f"Invalid customized tool names: {', '.join(invalid_names)}")

filtered_funcs = (
funcs
if not tool_name
else [func for func in funcs if func.__name__ in tool_name]
)
return get_tool_descriptions(filtered_funcs + util_funcs)


def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
data: Dict[str, List[str]] = {"desc": [], "doc": []}

Expand Down
33 changes: 20 additions & 13 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,34 @@
import json
import logging
import tempfile
from pathlib import Path
from importlib import resources
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import cv2
import requests
import numpy as np
from pytube import YouTube # type: ignore
import requests
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore

from vision_agent.tools.tool_utils import (
send_inference_request,
get_tool_descriptions,
get_tool_descriptions_by_names,
get_tool_documentation,
get_tools_df,
send_inference_request,
)
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
b64_to_pil,
convert_quad_box_to_bbox,
convert_to_b64,
denormalize_bbox,
get_image_size,
normalize_bbox,
convert_quad_box_to_bbox,
rle_decode,
)

Expand Down Expand Up @@ -1285,7 +1286,17 @@ def overlay_heat_map(
return np.array(combined)


TOOLS = [
UTIL_TOOLS = [
save_json,
load_image,
save_image,
save_video,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
]

FUNCTION_TOOLS = [
owl_v2,
grounding_sam,
extract_frames,
Expand All @@ -1305,15 +1316,11 @@ def overlay_heat_map(
generate_pose_image,
closest_mask_distance,
closest_box_distance,
save_json,
load_image,
save_image,
save_video,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
template_match,
]

TOOLS = FUNCTION_TOOLS + UTIL_TOOLS

TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
Expand Down

0 comments on commit 8b344aa

Please sign in to comment.