diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 3b0b0a4f..110a9a17 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -622,6 +622,7 @@ def __call__( def chat_with_workflow( self, chat: List[Message], + test_multi_plan: bool = True, display_visualization: bool = False, ) -> Dict[str, Any]: """Chat with Vision Agent and return intermediate information regarding the task. @@ -691,7 +692,7 @@ def chat_with_workflow( self.planner, ) - if self.verbosity >= 1: + if self.verbosity >= 1 and test_multi_plan: for p in plans: _LOGGER.info( f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" @@ -703,14 +704,19 @@ def chat_with_workflow( self.log_progress, self.verbosity, ) - best_plan, tool_output_str = pick_plan( - int_chat, - plans, - tool_infos["all"], - self.coder, - code_interpreter, - verbosity=self.verbosity, - ) + + if test_multi_plan: + best_plan, tool_output_str = pick_plan( + int_chat, + plans, + tool_infos["all"], + self.coder, + code_interpreter, + verbosity=self.verbosity, + ) + else: + best_plan = list(plans.keys())[0] + tool_output_str = "" if best_plan in plans and best_plan in tool_infos: plan_i = plans[best_plan]