From 80e04052025234134967110284e3f5ce8da9648a Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 11:20:57 +0800 Subject: [PATCH 1/8] fix: add back some logging for ui --- vision_agent/agent/vision_agent.py | 39 +++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 110a9a17..f3d3977d 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -491,15 +491,8 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( plans: Dict[str, List[Dict[str, str]]], tool_recommender: Sim, - log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, ) -> Dict[str, str]: - log_progress( - { - "type": "tools", - "status": "started", - } - ) tool_info = [] tool_desc = [] tool_lists: Dict[str, List[Dict[str, str]]] = {} @@ -524,7 +517,7 @@ def retrieve_tools( ) all_tools = "\n\n".join(set(tool_info)) tool_lists_unique["all"] = all_tools - return tool_lists_unique + return tool_lists_unique, tool_lists class VisionAgent(Agent): @@ -692,16 +685,30 @@ def chat_with_workflow( self.planner, ) + if not test_multi_plan: + self.log_progress( + { + "type": "plans", + "status": "completed", + "payload": plans[list(plans.keys())[0]], + } + ) + self.log_progress( + { + "type": "tools", + "status": "started", + } + ) + if self.verbosity >= 1 and test_multi_plan: for p in plans: _LOGGER.info( f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - tool_infos = retrieve_tools( + tool_infos, tool_lists = retrieve_tools( plans, self.tool_recommender, - self.log_progress, self.verbosity, ) @@ -730,11 +737,19 @@ def chat_with_workflow( plan_i = plans[k] tool_info = tool_infos[k] + if test_multi_plan: + self.log_progress( + { + "type": "plans", + "status": "completed", + "payload": tool_lists[best_plan], + } + ) self.log_progress( { - "type": "plans", + "type": "tools", "status": "completed", - "payload": plan_i, + "payload": tool_lists[best_plan], } ) if self.verbosity >= 1: From 6d1ccba42a561842f0c126f78edd272ca31a0eed Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 11:52:04 +0800 Subject: [PATCH 2/8] simplifier --- vision_agent/agent/vision_agent.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index f3d3977d..e5c86f99 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -156,6 +156,7 @@ def write_plans( tool_desc: str, working_memory: str, model: LMM, + test_multi_plan: bool, ) -> Dict[str, Any]: chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": @@ -737,14 +738,6 @@ def chat_with_workflow( plan_i = plans[k] tool_info = tool_infos[k] - if test_multi_plan: - self.log_progress( - { - "type": "plans", - "status": "completed", - "payload": tool_lists[best_plan], - } - ) self.log_progress( { "type": "tools", From 487ec4415b8323f1ce38e36e6d47797d7ecc6946 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 12:49:49 +0800 Subject: [PATCH 3/8] resolve lint error --- vision_agent/agent/vision_agent.py | 96 +++++++++++++++--------------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index e5c86f99..8881fb5c 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -156,7 +156,6 @@ def write_plans( tool_desc: str, working_memory: str, model: LMM, - test_multi_plan: bool, ) -> Dict[str, Any]: chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": @@ -173,19 +172,25 @@ def write_plans( def pick_plan( chat: List[Message], plans: Dict[str, Any], - tool_info: str, + tool_infos: Dict[str, str], model: LMM, code_interpreter: CodeInterpreter, + test_multi_plan: bool, verbosity: int = 0, max_retries: int = 3, ) -> Tuple[str, str]: + if not test_multi_plan: + k = list(plans.keys())[0] + return plans[k], tool_infos[k], "" + + all_tool_info = tool_infos["all"] chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": raise ValueError("Last chat message must be from the user.") plan_str = format_plans(plans) prompt = TEST_PLANS.format( - docstring=tool_info, plans=plan_str, previous_attempts="" + docstring=all_tool_info, plans=plan_str, previous_attempts="" ) code = extract_code(model(prompt)) @@ -202,7 +207,7 @@ def pick_plan( count = 0 while (not tool_output.success or tool_output_str == "") and count < max_retries: prompt = TEST_PLANS.format( - docstring=tool_info, + docstring=all_tool_info, plans=plan_str, previous_attempts=PREVIOUS_FAILED.format( code=code, error=tool_output.text() @@ -238,7 +243,17 @@ def pick_plan( best_plan = extract_json(model(chat)) if verbosity >= 1: _LOGGER.info(f"Best plan:\n{best_plan}") - return best_plan["best_plan"], tool_output_str + + plan = best_plan["best_plan"] + if plan in plans and best_plan in tool_infos: + return plans[best_plan], tool_infos[best_plan], tool_output_str + else: + if verbosity >= 1: + _LOGGER.warning( + f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." + ) + k = list(plans.keys())[0] + return plans[k], tool_infos[k], tool_output_str @traceable @@ -492,8 +507,15 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( plans: Dict[str, List[Dict[str, str]]], tool_recommender: Sim, + log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, ) -> Dict[str, str]: + log_progress( + { + "type": "tools", + "status": "started", + } + ) tool_info = [] tool_desc = [] tool_lists: Dict[str, List[Dict[str, str]]] = {} @@ -518,7 +540,14 @@ def retrieve_tools( ) all_tools = "\n\n".join(set(tool_info)) tool_lists_unique["all"] = all_tools - return tool_lists_unique, tool_lists + log_progress( + { + "type": "tools", + "status": "completed", + "payload": tool_lists[list(plans.keys())[0]], + } + ) + return tool_lists_unique class VisionAgent(Agent): @@ -694,12 +723,6 @@ def chat_with_workflow( "payload": plans[list(plans.keys())[0]], } ) - self.log_progress( - { - "type": "tools", - "status": "started", - } - ) if self.verbosity >= 1 and test_multi_plan: for p in plans: @@ -707,53 +730,32 @@ def chat_with_workflow( f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - tool_infos, tool_lists = retrieve_tools( + tool_infos = retrieve_tools( plans, self.tool_recommender, + self.log_progress, self.verbosity, ) - if test_multi_plan: - best_plan, tool_output_str = pick_plan( - int_chat, - plans, - tool_infos["all"], - self.coder, - code_interpreter, - verbosity=self.verbosity, - ) - else: - best_plan = list(plans.keys())[0] - tool_output_str = "" - - if best_plan in plans and best_plan in tool_infos: - plan_i = plans[best_plan] - tool_info = tool_infos[best_plan] - else: - if self.verbosity >= 1: - _LOGGER.warning( - f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." - ) - k = list(plans.keys())[0] - plan_i = plans[k] - tool_info = tool_infos[k] - - self.log_progress( - { - "type": "tools", - "status": "completed", - "payload": tool_lists[best_plan], - } + best_plan, best_tool_info, tool_output_str = pick_plan( + int_chat, + plans, + tool_infos, + self.coder, + code_interpreter, + test_multi_plan, + verbosity=self.verbosity, ) + if self.verbosity >= 1: _LOGGER.info( - f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + f"Picked best plan:\n{tabulate(tabular_data=best_plan, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) results = write_and_test_code( chat=[{"role": c["role"], "content": c["content"]} for c in int_chat], - plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]), - tool_info=tool_info, + plan="\n-" + "\n-".join([e["instructions"] for e in best_plan]), + tool_info=best_tool_info, tool_output=tool_output_str, tool_utils=T.UTILITIES_DOCSTRING, working_memory=working_memory, From 1b0b382152e509466f7b688f5be836e4da492935 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 12:54:01 +0800 Subject: [PATCH 4/8] fix --- vision_agent/agent/vision_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 8881fb5c..dbd751c6 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -771,7 +771,7 @@ def chat_with_workflow( code = cast(str, results["code"]) test = cast(str, results["test"]) working_memory.extend(results["working_memory"]) # type: ignore - plan.append({"code": code, "test": test, "plan": plan_i}) + plan.append({"code": code, "test": test, "plan": best_plan}) execution_result = cast(Execution, results["test_result"]) self.log_progress( From 80e9c379da43b54da77b7d3faf7511ec92db67e7 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 12:56:28 +0800 Subject: [PATCH 5/8] fix type error --- vision_agent/agent/vision_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index dbd751c6..1e7a358f 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -178,7 +178,7 @@ def pick_plan( test_multi_plan: bool, verbosity: int = 0, max_retries: int = 3, -) -> Tuple[str, str]: +) -> Tuple[str, str, str]: if not test_multi_plan: k = list(plans.keys())[0] return plans[k], tool_infos[k], "" From 41b136152a3bb2aa16624a9638d56bac4b60c170 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 12:59:30 +0800 Subject: [PATCH 6/8] fix type error --- vision_agent/agent/vision_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 1e7a358f..c96ab70a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -245,12 +245,12 @@ def pick_plan( _LOGGER.info(f"Best plan:\n{best_plan}") plan = best_plan["best_plan"] - if plan in plans and best_plan in tool_infos: - return plans[best_plan], tool_infos[best_plan], tool_output_str + if plan in plans and plan in tool_infos: + return plans[plan], tool_infos[plan], tool_output_str else: if verbosity >= 1: _LOGGER.warning( - f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." + f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info." ) k = list(plans.keys())[0] return plans[k], tool_infos[k], tool_output_str From 995b53cfb1dcc08f0cd5e19ebcce01ff056a8bf6 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 13:03:00 +0800 Subject: [PATCH 7/8] fix type error --- vision_agent/agent/vision_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index c96ab70a..0dd4a195 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -178,7 +178,7 @@ def pick_plan( test_multi_plan: bool, verbosity: int = 0, max_retries: int = 3, -) -> Tuple[str, str, str]: +) -> Tuple[Any, str, str]: if not test_multi_plan: k = list(plans.keys())[0] return plans[k], tool_infos[k], "" From 9855e28de3da3bff688459342e1b944fb0dba23d Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Mon, 22 Jul 2024 13:57:48 +0800 Subject: [PATCH 8/8] remove check --- vision_agent/agent/vision_agent.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 0dd4a195..d8fc079a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -715,14 +715,13 @@ def chat_with_workflow( self.planner, ) - if not test_multi_plan: - self.log_progress( - { - "type": "plans", - "status": "completed", - "payload": plans[list(plans.keys())[0]], - } - ) + self.log_progress( + { + "type": "plans", + "status": "completed", + "payload": plans[list(plans.keys())[0]], + } + ) if self.verbosity >= 1 and test_multi_plan: for p in plans: