11import json
22import logging
33from pathlib import Path
4- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4+ from typing import Any , Callable , Dict , List , Mapping , Optional , Tuple , Union
55
66import pandas as pd
77from rich .console import Console
3333_CONSOLE = Console ()
3434
3535
def build_working_memory(working_memory: Mapping[str, List[str]]) -> Sim:
    """Turn a working-memory mapping into a ``Sim`` keyed on the description.

    Each entry of ``working_memory`` becomes one row of a two-column
    DataFrame: ``desc`` holds the mapping key and ``doc`` holds that key's
    list of strings joined into a single string.

    Args:
        working_memory: Mapping from a description to the strings recorded
            under it.

    Returns:
        A ``Sim`` constructed from the DataFrame with ``sim_key="desc"``.
    """
    entries = list(working_memory.items())
    columns: Mapping[str, List[str]] = {
        "desc": [description for description, _ in entries],
        "doc": ["\n ".join(notes) for _, notes in entries],
    }
    frame = pd.DataFrame(columns)  # type: ignore
    return Sim(frame, sim_key="desc")
43+
44+
3645def extract_code (code : str ) -> str :
3746 if "```python" in code :
3847 code = code [code .find ("```python" ) + len ("```python" ) :]
@@ -41,12 +50,21 @@ def extract_code(code: str) -> str:
4150
4251
def write_plan(
    chat: List[Dict[str, str]],
    plan: Optional[List[Dict[str, Any]]],
    tool_desc: str,
    model: LLM,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Ask the planner model for a plan that answers the latest user message.

    The content of the last chat message is formatted into
    ``USER_REQ_CONTEXT`` and then, together with ``plan`` and ``tool_desc``,
    into the ``PLAN`` prompt. That prompt replaces the last message's content
    in the conversation sent to ``model.chat``. The reply has every "```"
    removed, is stripped, and is parsed as JSON.

    Args:
        chat: Conversation as a list of ``{"role": ..., "content": ...}``
            dicts. The last entry must have role ``"user"``. The list and its
            dicts are not modified.
        plan: A previous plan to show the planner, or ``None``. It is
            rendered into the prompt with ``str(plan)``.
        tool_desc: Tool description text inserted into the prompt.
        model: Object whose ``chat(messages)`` method returns the reply text.

    Returns:
        A tuple of the reply's ``"user_req"`` and ``"plan"`` values.

    Raises:
        ValueError: If ``chat`` is empty or its last message is not from the
            user.
        json.JSONDecodeError: If the cleaned reply is not valid JSON.
        KeyError: If the parsed reply lacks ``"user_req"`` or ``"plan"``.
    """
    if not chat:
        raise ValueError("Chat must contain at least one message.")

    last_message = chat[-1]
    if last_message["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    context = USER_REQ_CONTEXT.format(user_requirement=last_message["content"])
    # NOTE(review): a missing plan is rendered as the text "None" here.
    prompt = PLAN.format(context=context, plan=str(plan), tool_desc=tool_desc)

    # Swap the prompt in on a copy so the caller's chat history keeps the
    # user's original message instead of the full planner prompt.
    planner_chat = [*chat[:-1], {**last_message, "content": prompt}]

    reply = model.chat(planner_chat)
    parsed = json.loads(reply.replace("```", "").strip())
    return parsed["user_req"], parsed["plan"]  # type: ignore
5068
5169
5270def write_code (
@@ -123,7 +141,7 @@ def write_and_exec_code(
123141 user_req : str ,
124142 subtask : str ,
125143 orig_code : str ,
126- code_writer_call : Callable ,
144+ code_writer_call : Callable [..., str ] ,
127145 model : LLM ,
128146 tool_info : str ,
129147 exec : Execute ,
@@ -191,6 +209,7 @@ def run_plan(
191209 current_test = ""
192210 retrieved_ltm = ""
193211 working_memory : Dict [str , List [str ]] = {}
212+
194213 for task in active_plan :
195214 _LOGGER .info (
196215 f"""
@@ -209,7 +228,7 @@ def run_plan(
209228 user_req ,
210229 task ["instruction" ],
211230 current_code ,
212- write_code if task ["type" ] == "code" else write_test ,
231+ write_code if task ["type" ] == "code" else write_test , # type: ignore
213232 coder ,
214233 tool_info ,
215234 exec ,
@@ -277,43 +296,55 @@ def __init__(
277296 if "doc" not in long_term_memory .df .columns :
278297 raise ValueError ("Long term memory must have a 'doc' column." )
279298 self .long_term_memory = long_term_memory
299+ self .max_retries = 3
280300 if self .verbose :
281301 _LOGGER .setLevel (logging .INFO )
282302
def __call__(
    self,
    input: Union[List[Dict[str, str]], str],
    image: Optional[Union[str, Path]] = None,
    plan: Optional[List[Dict[str, Any]]] = None,
) -> str:
    """Run the workflow on ``input`` and return only the resulting code.

    Args:
        input: Either a chat (list of ``{"role": ..., "content": ...}``
            dicts) or a plain string. A string is wrapped as a single
            ``"user"`` message.
        image: Optional image name or path, forwarded to
            ``chat_with_workflow``.
        plan: Optional previous plan, forwarded to ``chat_with_workflow``.

    Returns:
        The ``"code"`` entry of the dict returned by ``chat_with_workflow``.
    """
    if isinstance(input, str):
        input = [{"role": "user", "content": input}]
    results = self.chat_with_workflow(input, image, plan)
    # The result dict is keyed by "code" (no leading space).
    return results["code"]  # type: ignore
292313
293- def chat_with_tests (
314+ def chat_with_workflow (
294315 self ,
295316 chat : List [Dict [str , str ]],
296317 image : Optional [Union [str , Path ]] = None ,
297- ) -> Tuple [str , str ]:
318+ plan : Optional [List [Dict [str , Any ]]] = None ,
319+ ) -> Dict [str , Any ]:
298320 if len (chat ) == 0 :
299321 raise ValueError ("Input cannot be empty." )
300322
301- user_req = chat [0 ]["content" ]
302323 if image is not None :
303- user_req += f" Image name { image } "
324+ # append file names to all user messages
325+ for chat_i in chat :
326+ if chat_i ["role" ] == "user" :
327+ chat_i ["content" ] += f" Image name { image } "
328+
329+ working_code = ""
330+ if plan is not None :
331+ # grab the latest working code from a previous plan
332+ for task in plan :
333+ if "success" in task and "code" in task and task ["success" ]:
334+ working_code = task ["code" ]
304335
305- plan = write_plan (user_req , TOOL_DESCRIPTIONS , self .planner )
336+ user_req , plan = write_plan (chat , plan , TOOL_DESCRIPTIONS , self .planner )
306337 _LOGGER .info (
307338 f"""Plan:
308339{ tabulate (tabular_data = plan , headers = "keys" , tablefmt = "mixed_grid" , maxcolwidths = _MAX_TABULATE_COL_WIDTH )} """
309340 )
310341
311- working_code = ""
312342 working_test = ""
343+ working_memory : Dict [str , List [str ]] = {}
313344 success = False
345+ retries = 0
314346
315- __import__ ("ipdb" ).set_trace ()
316- while not success :
347+ while not success and retries < self .max_retries :
317348 working_code , working_test , plan , working_memory_i = run_plan (
318349 user_req ,
319350 plan ,
@@ -325,22 +356,21 @@ def chat_with_tests(
325356 self .verbose ,
326357 )
327358 success = all (task ["success" ] for task in plan )
328- self . _working_memory .update (working_memory_i )
359+ working_memory .update (working_memory_i )
329360
330361 if not success :
331- # TODO: ask for feedback and replan
362+ # return to user and request feedback
332363 break
333364
334- return working_code , working_test
365+ retries += 1
335366
336- @property
337- def working_memory (self ) -> Sim :
338- data : Dict [str , List [str ]] = {"desc" : [], "doc" : []}
339- for key , value in self ._working_memory .items ():
340- data ["desc" ].append (key )
341- data ["doc" ].append ("\n " .join (value ))
342- df = pd .DataFrame (data )
343- return Sim (df , sim_key = "desc" )
367+ return {
368+ "code" : working_code ,
369+ "test" : working_test ,
370+ "success" : success ,
371+ "working_memory" : build_working_memory (working_memory ),
372+ "plan" : plan ,
373+ }
344374
def log_progress(self, description: str) -> None:
    """No-op: accept a progress description and do nothing with it."""
0 commit comments