1
1
import json
2
2
import logging
3
3
from pathlib import Path
4
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4
+ from typing import Any , Callable , Dict , List , Mapping , Optional , Tuple , Union
5
5
6
6
import pandas as pd
7
7
from rich .console import Console
33
33
_CONSOLE = Console ()
34
34
35
35
36
+ def build_working_memory (working_memory : Mapping [str , List [str ]]) -> Sim :
37
+ data : Mapping [str , List [str ]] = {"desc" : [], "doc" : []}
38
+ for key , value in working_memory .items ():
39
+ data ["desc" ].append (key )
40
+ data ["doc" ].append ("\n " .join (value ))
41
+ df = pd .DataFrame (data ) # type: ignore
42
+ return Sim (df , sim_key = "desc" )
43
+
44
+
36
45
def extract_code (code : str ) -> str :
37
46
if "```python" in code :
38
47
code = code [code .find ("```python" ) + len ("```python" ) :]
@@ -41,12 +50,21 @@ def extract_code(code: str) -> str:
41
50
42
51
43
52
def write_plan (
44
- user_requirements : str , tool_desc : str , model : LLM
45
- ) -> List [Dict [str , Any ]]:
53
+ chat : List [Dict [str , str ]],
54
+ plan : Optional [List [Dict [str , Any ]]],
55
+ tool_desc : str ,
56
+ model : LLM ,
57
+ ) -> Tuple [str , List [Dict [str , Any ]]]:
58
+ # Get last user request
59
+ if chat [- 1 ]["role" ] != "user" :
60
+ raise ValueError ("Last chat message must be from the user." )
61
+ user_requirements = chat [- 1 ]["content" ]
62
+
46
63
context = USER_REQ_CONTEXT .format (user_requirement = user_requirements )
47
- prompt = PLAN .format (context = context , plan = "" , tool_desc = tool_desc )
48
- plan = json .loads (model (prompt ).replace ("```" , "" ).strip ())
49
- return plan ["plan" ] # type: ignore
64
+ prompt = PLAN .format (context = context , plan = str (plan ), tool_desc = tool_desc )
65
+ chat [- 1 ]["content" ] = prompt
66
+ plan = json .loads (model .chat (chat ).replace ("```" , "" ).strip ())
67
+ return plan ["user_req" ], plan ["plan" ] # type: ignore
50
68
51
69
52
70
def write_code (
@@ -123,7 +141,7 @@ def write_and_exec_code(
123
141
user_req : str ,
124
142
subtask : str ,
125
143
orig_code : str ,
126
- code_writer_call : Callable ,
144
+ code_writer_call : Callable [..., str ] ,
127
145
model : LLM ,
128
146
tool_info : str ,
129
147
exec : Execute ,
@@ -191,6 +209,7 @@ def run_plan(
191
209
current_test = ""
192
210
retrieved_ltm = ""
193
211
working_memory : Dict [str , List [str ]] = {}
212
+
194
213
for task in active_plan :
195
214
_LOGGER .info (
196
215
f"""
@@ -209,7 +228,7 @@ def run_plan(
209
228
user_req ,
210
229
task ["instruction" ],
211
230
current_code ,
212
- write_code if task ["type" ] == "code" else write_test ,
231
+ write_code if task ["type" ] == "code" else write_test , # type: ignore
213
232
coder ,
214
233
tool_info ,
215
234
exec ,
@@ -277,43 +296,55 @@ def __init__(
277
296
if "doc" not in long_term_memory .df .columns :
278
297
raise ValueError ("Long term memory must have a 'doc' column." )
279
298
self .long_term_memory = long_term_memory
299
+ self .max_retries = 3
280
300
if self .verbose :
281
301
_LOGGER .setLevel (logging .INFO )
282
302
283
303
def __call__ (
284
304
self ,
285
305
input : Union [List [Dict [str , str ]], str ],
286
306
image : Optional [Union [str , Path ]] = None ,
307
+ plan : Optional [List [Dict [str , Any ]]] = None ,
287
308
) -> str :
288
309
if isinstance (input , str ):
289
310
input = [{"role" : "user" , "content" : input }]
290
- code , _ = self .chat_with_tests (input , image )
291
- return code
311
+ results = self .chat_with_workflow (input , image , plan )
312
+ return results [ " code" ] # type: ignore
292
313
293
- def chat_with_tests (
314
+ def chat_with_workflow (
294
315
self ,
295
316
chat : List [Dict [str , str ]],
296
317
image : Optional [Union [str , Path ]] = None ,
297
- ) -> Tuple [str , str ]:
318
+ plan : Optional [List [Dict [str , Any ]]] = None ,
319
+ ) -> Dict [str , Any ]:
298
320
if len (chat ) == 0 :
299
321
raise ValueError ("Input cannot be empty." )
300
322
301
- user_req = chat [0 ]["content" ]
302
323
if image is not None :
303
- user_req += f" Image name { image } "
324
+ # append file names to all user messages
325
+ for chat_i in chat :
326
+ if chat_i ["role" ] == "user" :
327
+ chat_i ["content" ] += f" Image name { image } "
328
+
329
+ working_code = ""
330
+ if plan is not None :
331
+ # grab the latest working code from a previous plan
332
+ for task in plan :
333
+ if "success" in task and "code" in task and task ["success" ]:
334
+ working_code = task ["code" ]
304
335
305
- plan = write_plan (user_req , TOOL_DESCRIPTIONS , self .planner )
336
+ user_req , plan = write_plan (chat , plan , TOOL_DESCRIPTIONS , self .planner )
306
337
_LOGGER .info (
307
338
f"""Plan:
308
339
{ tabulate (tabular_data = plan , headers = "keys" , tablefmt = "mixed_grid" , maxcolwidths = _MAX_TABULATE_COL_WIDTH )} """
309
340
)
310
341
311
- working_code = ""
312
342
working_test = ""
343
+ working_memory : Dict [str , List [str ]] = {}
313
344
success = False
345
+ retries = 0
314
346
315
- __import__ ("ipdb" ).set_trace ()
316
- while not success :
347
+ while not success and retries < self .max_retries :
317
348
working_code , working_test , plan , working_memory_i = run_plan (
318
349
user_req ,
319
350
plan ,
@@ -325,22 +356,21 @@ def chat_with_tests(
325
356
self .verbose ,
326
357
)
327
358
success = all (task ["success" ] for task in plan )
328
- self . _working_memory .update (working_memory_i )
359
+ working_memory .update (working_memory_i )
329
360
330
361
if not success :
331
- # TODO: ask for feedback and replan
362
+ # return to user and request feedback
332
363
break
333
364
334
- return working_code , working_test
365
+ retries += 1
335
366
336
- @property
337
- def working_memory (self ) -> Sim :
338
- data : Dict [str , List [str ]] = {"desc" : [], "doc" : []}
339
- for key , value in self ._working_memory .items ():
340
- data ["desc" ].append (key )
341
- data ["doc" ].append ("\n " .join (value ))
342
- df = pd .DataFrame (data )
343
- return Sim (df , sim_key = "desc" )
367
+ return {
368
+ "code" : working_code ,
369
+ "test" : working_test ,
370
+ "success" : success ,
371
+ "working_memory" : build_working_memory (working_memory ),
372
+ "plan" : plan ,
373
+ }
344
374
345
375
def log_progress (self , description : str ) -> None :
346
376
pass
0 commit comments