@@ -172,19 +172,25 @@ def write_plans(
172
172
def pick_plan (
173
173
chat : List [Message ],
174
174
plans : Dict [str , Any ],
175
- tool_info : str ,
175
+ tool_infos : Dict [ str , str ] ,
176
176
model : LMM ,
177
177
code_interpreter : CodeInterpreter ,
178
+ test_multi_plan : bool ,
178
179
verbosity : int = 0 ,
179
180
max_retries : int = 3 ,
180
- ) -> Tuple [str , str ]:
181
+ ) -> Tuple [Any , str , str ]:
182
+ if not test_multi_plan :
183
+ k = list (plans .keys ())[0 ]
184
+ return plans [k ], tool_infos [k ], ""
185
+
186
+ all_tool_info = tool_infos ["all" ]
181
187
chat = copy .deepcopy (chat )
182
188
if chat [- 1 ]["role" ] != "user" :
183
189
raise ValueError ("Last chat message must be from the user." )
184
190
185
191
plan_str = format_plans (plans )
186
192
prompt = TEST_PLANS .format (
187
- docstring = tool_info , plans = plan_str , previous_attempts = ""
193
+ docstring = all_tool_info , plans = plan_str , previous_attempts = ""
188
194
)
189
195
190
196
code = extract_code (model (prompt ))
@@ -201,7 +207,7 @@ def pick_plan(
201
207
count = 0
202
208
while (not tool_output .success or tool_output_str == "" ) and count < max_retries :
203
209
prompt = TEST_PLANS .format (
204
- docstring = tool_info ,
210
+ docstring = all_tool_info ,
205
211
plans = plan_str ,
206
212
previous_attempts = PREVIOUS_FAILED .format (
207
213
code = code , error = tool_output .text ()
@@ -237,7 +243,17 @@ def pick_plan(
237
243
best_plan = extract_json (model (chat ))
238
244
if verbosity >= 1 :
239
245
_LOGGER .info (f"Best plan:\n { best_plan } " )
240
- return best_plan ["best_plan" ], tool_output_str
246
+
247
+ plan = best_plan ["best_plan" ]
248
+ if plan in plans and plan in tool_infos :
249
+ return plans [plan ], tool_infos [plan ], tool_output_str
250
+ else :
251
+ if verbosity >= 1 :
252
+ _LOGGER .warning (
253
+ f"Best plan { plan } not found in plans or tool_infos. Using the first plan and tool info."
254
+ )
255
+ k = list (plans .keys ())[0 ]
256
+ return plans [k ], tool_infos [k ], tool_output_str
241
257
242
258
243
259
@traceable
@@ -524,6 +540,13 @@ def retrieve_tools(
524
540
)
525
541
all_tools = "\n \n " .join (set (tool_info ))
526
542
tool_lists_unique ["all" ] = all_tools
543
+ log_progress (
544
+ {
545
+ "type" : "tools" ,
546
+ "status" : "completed" ,
547
+ "payload" : tool_lists [list (plans .keys ())[0 ]],
548
+ }
549
+ )
527
550
return tool_lists_unique
528
551
529
552
@@ -692,6 +715,14 @@ def chat_with_workflow(
692
715
self .planner ,
693
716
)
694
717
718
+ self .log_progress (
719
+ {
720
+ "type" : "plans" ,
721
+ "status" : "completed" ,
722
+ "payload" : plans [list (plans .keys ())[0 ]],
723
+ }
724
+ )
725
+
695
726
if self .verbosity >= 1 and test_multi_plan :
696
727
for p in plans :
697
728
_LOGGER .info (
@@ -705,47 +736,25 @@ def chat_with_workflow(
705
736
self .verbosity ,
706
737
)
707
738
708
- if test_multi_plan :
709
- best_plan , tool_output_str = pick_plan (
710
- int_chat ,
711
- plans ,
712
- tool_infos ["all" ],
713
- self .coder ,
714
- code_interpreter ,
715
- verbosity = self .verbosity ,
716
- )
717
- else :
718
- best_plan = list (plans .keys ())[0 ]
719
- tool_output_str = ""
720
-
721
- if best_plan in plans and best_plan in tool_infos :
722
- plan_i = plans [best_plan ]
723
- tool_info = tool_infos [best_plan ]
724
- else :
725
- if self .verbosity >= 1 :
726
- _LOGGER .warning (
727
- f"Best plan { best_plan } not found in plans or tool_infos. Using the first plan and tool info."
728
- )
729
- k = list (plans .keys ())[0 ]
730
- plan_i = plans [k ]
731
- tool_info = tool_infos [k ]
732
-
733
- self .log_progress (
734
- {
735
- "type" : "plans" ,
736
- "status" : "completed" ,
737
- "payload" : plan_i ,
738
- }
739
+ best_plan , best_tool_info , tool_output_str = pick_plan (
740
+ int_chat ,
741
+ plans ,
742
+ tool_infos ,
743
+ self .coder ,
744
+ code_interpreter ,
745
+ test_multi_plan ,
746
+ verbosity = self .verbosity ,
739
747
)
748
+
740
749
if self .verbosity >= 1 :
741
750
_LOGGER .info (
742
- f"Picked best plan:\n { tabulate (tabular_data = plan_i , headers = 'keys' , tablefmt = 'mixed_grid' , maxcolwidths = _MAX_TABULATE_COL_WIDTH )} "
751
+ f"Picked best plan:\n { tabulate (tabular_data = best_plan , headers = 'keys' , tablefmt = 'mixed_grid' , maxcolwidths = _MAX_TABULATE_COL_WIDTH )} "
743
752
)
744
753
745
754
results = write_and_test_code (
746
755
chat = [{"role" : c ["role" ], "content" : c ["content" ]} for c in int_chat ],
747
- plan = "\n -" + "\n -" .join ([e ["instructions" ] for e in plan_i ]),
748
- tool_info = tool_info ,
756
+ plan = "\n -" + "\n -" .join ([e ["instructions" ] for e in best_plan ]),
757
+ tool_info = best_tool_info ,
749
758
tool_output = tool_output_str ,
750
759
tool_utils = T .UTILITIES_DOCSTRING ,
751
760
working_memory = working_memory ,
@@ -761,7 +770,7 @@ def chat_with_workflow(
761
770
code = cast (str , results ["code" ])
762
771
test = cast (str , results ["test" ])
763
772
working_memory .extend (results ["working_memory" ]) # type: ignore
764
- plan .append ({"code" : code , "test" : test , "plan" : plan_i })
773
+ plan .append ({"code" : code , "test" : test , "plan" : best_plan })
765
774
766
775
execution_result = cast (Execution , results ["test_result" ])
767
776
self .log_progress (
0 commit comments