diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 545911d3..3b0b0a4f 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -176,6 +176,7 @@ def pick_plan( model: LMM, code_interpreter: CodeInterpreter, verbosity: int = 0, + max_retries: int = 3, ) -> Tuple[str, str]: chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": @@ -192,13 +193,13 @@ def pick_plan( if len(tool_output.logs.stdout) > 0: tool_output_str = tool_output.logs.stdout[0] - if verbosity >= 1: + if verbosity == 2: _print_code("Initial code and tests:", code) _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}") # retry if the tool output is empty or code fails - count = 1 - while (not tool_output.success or tool_output_str == "") and count < 3: + count = 0 + while (not tool_output.success or tool_output_str == "") and count < max_retries: prompt = TEST_PLANS.format( docstring=tool_info, plans=plan_str, @@ -214,12 +215,15 @@ def pick_plan( if len(tool_output.logs.stdout) > 0: tool_output_str = tool_output.logs.stdout[0] - if verbosity == 1: + if verbosity == 2: _print_code("Code and test after attempted fix:", code) _LOGGER.info(f"Code execution result after attempte {count}") count += 1 + if verbosity >= 1: + _print_code("Final code:", code) + user_req = chat[-1]["content"] context = USER_REQ.format(user_request=user_req) # because the tool picker model gets the image as well, we have to be careful with @@ -408,7 +412,7 @@ def debug_code( FIX_BUG.format( code=code, tests=test, - result="\n".join(result.text().splitlines()[-50:]), + result="\n".join(result.text().splitlines()[-100:]), feedback=format_memory(working_memory + new_working_memory), ) ) @@ -673,92 +677,85 @@ def chat_with_workflow( working_memory: List[Dict[str, str]] = [] results = {"code": "", "test": "", "plan": []} plan = [] - success = False - retries = 0 - - while not success and retries < self.max_retries: - self.log_progress( - { - "type": "plans", - "status": "started", - } - ) - plans = write_plans( - int_chat, - T.TOOL_DESCRIPTIONS, - format_memory(working_memory), - self.planner, - ) - if self.verbosity >= 1: - for p in plans: - _LOGGER.info( - f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" - ) - - tool_infos = retrieve_tools( - plans, - self.tool_recommender, - self.log_progress, - self.verbosity, - ) - best_plan, tool_output_str = pick_plan( - int_chat, - plans, - tool_infos["all"], - self.coder, - code_interpreter, - verbosity=self.verbosity, - ) + self.log_progress( + { + "type": "plans", + "status": "started", + } + ) + plans = write_plans( + int_chat, + T.TOOL_DESCRIPTIONS, + format_memory(working_memory), + self.planner, + ) - if best_plan in plans and best_plan in tool_infos: - plan_i = plans[best_plan] - tool_info = tool_infos[best_plan] - else: - if self.verbosity >= 1: - _LOGGER.warning( - f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." - ) - k = list(plans.keys())[0] - plan_i = plans[k] - tool_info = tool_infos[k] - - self.log_progress( - { - "type": "plans", - "status": "completed", - "payload": plan_i, - } - ) - if self.verbosity >= 1: + if self.verbosity >= 1: + for p in plans: _LOGGER.info( - f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - results = write_and_test_code( - chat=[ - {"role": c["role"], "content": c["content"]} for c in int_chat - ], - plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]), - tool_info=tool_info, - tool_output=tool_output_str, - tool_utils=T.UTILITIES_DOCSTRING, - working_memory=working_memory, - coder=self.coder, - tester=self.tester, - debugger=self.debugger, - code_interpreter=code_interpreter, - log_progress=self.log_progress, - verbosity=self.verbosity, - media=media_list, + tool_infos = retrieve_tools( + plans, + self.tool_recommender, + self.log_progress, + self.verbosity, + ) + best_plan, tool_output_str = pick_plan( + int_chat, + plans, + tool_infos["all"], + self.coder, + code_interpreter, + verbosity=self.verbosity, + ) + + if best_plan in plans and best_plan in tool_infos: + plan_i = plans[best_plan] + tool_info = tool_infos[best_plan] + else: + if self.verbosity >= 1: + _LOGGER.warning( + f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." + ) + k = list(plans.keys())[0] + plan_i = plans[k] + tool_info = tool_infos[k] + + self.log_progress( + { + "type": "plans", + "status": "completed", + "payload": plan_i, + } + ) + if self.verbosity >= 1: + _LOGGER.info( + f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - success = cast(bool, results["success"]) - code = cast(str, results["code"]) - test = cast(str, results["test"]) - working_memory.extend(results["working_memory"]) # type: ignore - plan.append({"code": code, "test": test, "plan": plan_i}) - retries += 1 + results = write_and_test_code( + chat=[{"role": c["role"], "content": c["content"]} for c in int_chat], + plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]), + tool_info=tool_info, + tool_output=tool_output_str, + tool_utils=T.UTILITIES_DOCSTRING, + working_memory=working_memory, + coder=self.coder, + tester=self.tester, + debugger=self.debugger, + code_interpreter=code_interpreter, + log_progress=self.log_progress, + verbosity=self.verbosity, + media=media_list, + ) + success = cast(bool, results["success"]) + code = cast(str, results["code"]) + test = cast(str, results["test"]) + working_memory.extend(results["working_memory"]) # type: ignore + plan.append({"code": code, "test": test, "plan": plan_i}) execution_result = cast(Execution, results["test_result"]) self.log_progress(