Skip to content

Commit

Permalink
Make it easier to visualize planned tasks in logs (#31)
Browse files Browse the repository at this point in the history
* Make it easier to visualize planned tasks in logs

* Add type stubs

---------

Co-authored-by: Yazhou Cao <[email protected]>
  • Loading branch information
humpydonkey and AsiaCao authored Mar 27, 2024
1 parent ca80cf6 commit e99a9f6
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 5 deletions.
73 changes: 70 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ openai = "1.*"
typing_extensions = "4.*"
moviepy = "1.*"
opencv-python-headless = "4.*"
tabulate = "^0.9.0"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand All @@ -47,6 +48,7 @@ setuptools = "^68.0.0"
mkdocs = "^1.5.3"
mkdocstrings = {extras = ["python"], version = "^0.23.0"}
mkdocs-material = "^9.4.2"
types-tabulate = "^0.9.0.20240106"

[tool.pytest.ini_options]
log_cli = true
Expand Down
13 changes: 11 additions & 2 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tabulate import tabulate

from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -268,6 +270,11 @@ def retrieval(
{"task": question, "tool_name": tool_name, "parameters": parameters}
]

_LOGGER.info(
f"""Going to run the following {len(tool_results)} tool(s) in sequence:
{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results: List[Any] = []
if isinstance(result["parameters"], Dict):
Expand Down Expand Up @@ -298,8 +305,6 @@ def create_tasks(
{k: v["description"] for k, v in tools.items()},
reflections,
)

_LOGGER.info(f"Tasks: {tasks}")
if tasks is not None:
task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
task_list = task_topology(task_model, question, task_list)
Expand All @@ -309,6 +314,10 @@ def create_tasks(
_LOGGER.error(f"Failed topological_sort on: {task_list}")
else:
task_list = []
_LOGGER.info(
f"""Planned tasks:
{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
)
return task_list


Expand Down

0 comments on commit e99a9f6

Please sign in to comment.