Skip to content

Commit

Permalink
feat: add dict to map func name to func docstring (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
wuyiqunLu authored Aug 26, 2024
1 parent 31af305 commit 66ed891
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
6 changes: 4 additions & 2 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
TOOL_DOCSTRING,
TOOLS,
TOOLS_DF,
TOOLS_INFO,
UTILITIES_DOCSTRING,
blip_image_caption,
clip,
Expand Down Expand Up @@ -52,15 +53,16 @@ def register_tool(imports: Optional[List] = None) -> Callable:
def decorator(tool: Callable) -> Callable:
import inspect

from .tools import get_tool_descriptions, get_tools_df
from .tools import get_tool_descriptions, get_tools_df, get_tools_info

global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING
global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING, TOOLS_INFO

if tool not in TOOLS:
TOOLS.append(tool)
TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
TOOLS_INFO = get_tools_info(TOOLS) # type: ignore

globals()[tool.__name__] = tool
if imports is not None:
Expand Down
13 changes: 13 additions & 0 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,16 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
data["doc"].append(doc)

return pd.DataFrame(data) # type: ignore


def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
data: Dict[str, str] = {}

for func in funcs:
desc = func.__doc__
if desc is None:
desc = ""

data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

return data
2 changes: 2 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
get_tool_descriptions,
get_tool_documentation,
get_tools_df,
get_tools_info,
)
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
Expand Down Expand Up @@ -1317,6 +1318,7 @@ def overlay_heat_map(
TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
TOOLS_INFO = get_tools_info(TOOLS) # type: ignore
UTILITIES_DOCSTRING = get_tool_documentation(
[
save_json,
Expand Down

0 comments on commit 66ed891

Please sign in to comment.