From 82c9b6f3d594af50e0cb551f817f0cdc493d8862 Mon Sep 17 00:00:00 2001 From: wuyiqunLu <132986242+wuyiqunLu@users.noreply.github.com> Date: Fri, 26 Jul 2024 14:36:40 +0800 Subject: [PATCH] feat: new log with pick plan (#180) * feat: new log with pick plan * fix lint * format * fix lint * address comment --- vision_agent/agent/vision_agent.py | 141 +++++++++++++++++++++-------- 1 file changed, 105 insertions(+), 36 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index d8fc079a..d525962e 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -176,13 +176,29 @@ def pick_plan( model: LMM, code_interpreter: CodeInterpreter, test_multi_plan: bool, + log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, max_retries: int = 3, ) -> Tuple[Any, str, str]: if not test_multi_plan: k = list(plans.keys())[0] + log_progress( + { + "type": "logs", + "log_content": "Plans created", + "status": "completed", + "payload": plans[k], + } + ) return plans[k], tool_infos[k], "" + log_progress( + { + "type": "logs", + "log_content": "Generating code to pick best plan", + "status": "started", + } + ) all_tool_info = tool_infos["all"] chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": @@ -194,6 +210,14 @@ def pick_plan( ) code = extract_code(model(prompt)) + log_progress( + { + "type": "logs", + "log_content": "Executing code to test plan", + "code": code, + "status": "running", + } + ) tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code)) tool_output_str = "" if len(tool_output.logs.stdout) > 0: @@ -203,6 +227,18 @@ def pick_plan( _print_code("Initial code and tests:", code) _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}") + log_progress( + { + "type": "logs", + "log_content": ( + "Code execution succeed" + if tool_output.success + else "Code execution failed" + ), + "payload": tool_output.to_json(), + "status": "completed" if tool_output.success else "failed", + } + ) # retry if the tool output is empty or code fails count = 0 while (not tool_output.success or tool_output_str == "") and count < max_retries: @@ -213,10 +249,33 @@ def pick_plan( code=code, error=tool_output.text() ), ) + log_progress( + { + "type": "logs", + "log_content": "Retry running code", + "code": code, + "status": "running", + } + ) code = extract_code(model(prompt)) tool_output = code_interpreter.exec_isolation( DefaultImports.prepend_imports(code) ) + log_progress( + { + "type": "logs", + "log_content": ( + "Code execution succeed" + if tool_output.success + else "Code execution failed" + ), + "code": code, + "payload": { + "result": tool_output.to_json(), + }, + "status": "completed" if tool_output.success else "failed", + } + ) tool_output_str = "" if len(tool_output.logs.stdout) > 0: tool_output_str = tool_output.logs.stdout[0] @@ -246,14 +305,26 @@ def pick_plan( plan = best_plan["best_plan"] if plan in plans and plan in tool_infos: - return plans[plan], tool_infos[plan], tool_output_str + best_plans = plans[plan] + best_tool_infos = tool_infos[plan] else: if verbosity >= 1: _LOGGER.warning( f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info." ) k = list(plans.keys())[0] - return plans[k], tool_infos[k], tool_output_str + best_plans = plans[k] + best_tool_infos = tool_infos[k] + + log_progress( + { + "type": "logs", + "log_content": "Picked best plan", + "status": "complete", + "payload": best_plans, + } + ) + return best_plans, best_tool_infos, tool_output_str @traceable @@ -323,7 +394,8 @@ def write_and_test_code( ) -> Dict[str, Any]: log_progress( { - "type": "code", + "type": "log", + "log_content": "Generating code", "status": "started", } ) @@ -341,10 +413,11 @@ def write_and_test_code( log_progress( { - "type": "code", + "type": "log", + "log_content": "Running code", "status": "running", + "code": DefaultImports.prepend_imports(code), "payload": { - "code": DefaultImports.prepend_imports(code), "test": test, }, } @@ -354,10 +427,13 @@ def write_and_test_code( ) log_progress( { - "type": "code", + "type": "log", + "log_content": ( + "Code execution succeed" if result.success else "Code execution failed" + ), "status": "completed" if result.success else "failed", + "code": DefaultImports.prepend_imports(code), "payload": { - "code": DefaultImports.prepend_imports(code), "test": test, "result": result.to_json(), }, @@ -507,15 +583,8 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( plans: Dict[str, List[Dict[str, str]]], tool_recommender: Sim, - log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, -) -> Dict[str, str]: - log_progress( - { - "type": "tools", - "status": "started", - } - ) +) -> Tuple[Dict[str, str], Dict[str, List[Dict[str, str]]]]: tool_info = [] tool_desc = [] tool_lists: Dict[str, List[Dict[str, str]]] = {} @@ -526,7 +595,12 @@ def retrieve_tools( tool_info.extend([e["doc"] for e in tools]) tool_desc.extend([e["desc"] for e in tools]) tool_lists[k].extend( - {"description": e["desc"], "documentation": e["doc"]} for e in tools + { + "plan": task["instructions"] if index == 0 else "", + "tool": e["desc"].strip().split()[0], + "documentation": e["doc"], + } + for index, e in enumerate(tools) ) if verbosity == 2: @@ -540,14 +614,7 @@ def retrieve_tools( ) all_tools = "\n\n".join(set(tool_info)) tool_lists_unique["all"] = all_tools - log_progress( - { - "type": "tools", - "status": "completed", - "payload": tool_lists[list(plans.keys())[0]], - } - ) - return tool_lists_unique + return tool_lists_unique, tool_lists class VisionAgent(Agent): @@ -704,7 +771,8 @@ def chat_with_workflow( self.log_progress( { - "type": "plans", + "type": "logs", + "log_content": "Creating plans", "status": "started", } ) @@ -715,27 +783,28 @@ def chat_with_workflow( self.planner, ) - self.log_progress( - { - "type": "plans", - "status": "completed", - "payload": plans[list(plans.keys())[0]], - } - ) - if self.verbosity >= 1 and test_multi_plan: for p in plans: _LOGGER.info( f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - tool_infos = retrieve_tools( + tool_infos, tool_lists = retrieve_tools( plans, self.tool_recommender, - self.log_progress, self.verbosity, ) + if test_multi_plan: + self.log_progress( + { + "type": "logs", + "log_content": "Creating plans", + "status": "completed", + "payload": tool_lists, + } + ) + best_plan, best_tool_info, tool_output_str = pick_plan( int_chat, plans, @@ -777,8 +846,8 @@ def chat_with_workflow( { "type": "final_code", "status": "completed" if success else "failed", + "code": DefaultImports.prepend_imports(code), "payload": { - "code": DefaultImports.prepend_imports(code), "test": test, "result": execution_result.to_json(), },