From 02ef28cf8068d253d26895750b503c78de5b4dc2 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Wed, 17 Jul 2024 18:58:42 +0800 Subject: [PATCH] feat: add logs for picking plans --- vision_agent/agent/vision_agent.py | 52 ++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 3b0b0a4f..2806c38c 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -217,7 +217,7 @@ def pick_plan( if verbosity == 2: _print_code("Code and test after attempted fix:", code) - _LOGGER.info(f"Code execution result after attempte {count}") + _LOGGER.info(f"Code execution result after attempt {count}") count += 1 @@ -491,15 +491,8 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( plans: Dict[str, List[Dict[str, str]]], tool_recommender: Sim, - log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, ) -> Dict[str, str]: - log_progress( - { - "type": "tools", - "status": "started", - } - ) tool_info = [] tool_desc = [] tool_lists: Dict[str, List[Dict[str, str]]] = {} @@ -510,7 +503,12 @@ def retrieve_tools( tool_info.extend([e["doc"] for e in tools]) tool_desc.extend([e["desc"] for e in tools]) tool_lists[k].extend( - {"description": e["desc"], "documentation": e["doc"]} for e in tools + { + "plan": task["instructions"] if index == 0 else "", + "tool": e["desc"].strip().split()[0], + "documentation": e["doc"], + } + for index, e in enumerate(tools) ) if verbosity == 2: @@ -524,7 +522,7 @@ def retrieve_tools( ) all_tools = "\n\n".join(set(tool_info)) tool_lists_unique["all"] = all_tools - return tool_lists_unique + return tool_lists_unique, tool_lists class VisionAgent(Agent): @@ -691,16 +689,40 @@ def chat_with_workflow( self.planner, ) + unique_instructions_set = set() + unique_instructions = [] + + for plan in plans.values(): + for item in plan: + instruction = item["instructions"] + if instruction not in unique_instructions_set: + unique_instructions_set.add(instruction) + unique_instructions.append(item) + + self.log_progress( + { + "type": "plans", + "status": "completed", + "payload": unique_instructions, + } + ) + if self.verbosity >= 1: for p in plans: _LOGGER.info( f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - tool_infos = retrieve_tools( + self.log_progress( + { + "type": "pick_plans", + "status": "started", + } + ) + + tool_infos, tool_lists = retrieve_tools( plans, self.tool_recommender, - self.log_progress, self.verbosity, ) best_plan, tool_output_str = pick_plan( @@ -715,6 +737,7 @@ def chat_with_workflow( if best_plan in plans and best_plan in tool_infos: plan_i = plans[best_plan] tool_info = tool_infos[best_plan] + tool_list = tool_lists[best_plan] else: if self.verbosity >= 1: _LOGGER.warning( @@ -723,12 +746,13 @@ def chat_with_workflow( k = list(plans.keys())[0] plan_i = plans[k] tool_info = tool_infos[k] + tool_list = tool_lists[k] self.log_progress( { - "type": "plans", + "type": "pick_plans", "status": "completed", - "payload": plan_i, + "payload": tool_list, } ) if self.verbosity >= 1: