Skip to content

Commit

Permalink
Fix minor issues (#169)
Browse files (browse the repository at this point in the history)
* update retries

* removed outer loop

* add more bug context

* improved prints for testing code
  • Loading branch information
dillonalaird authored Jul 16, 2024
1 parent 74f7c00 commit add68ce
Showing 1 changed file with 81 additions and 84 deletions.
165 changes: 81 additions & 84 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def pick_plan(
model: LMM,
code_interpreter: CodeInterpreter,
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[str, str]:
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
Expand All @@ -192,13 +193,13 @@ def pick_plan(
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]

if verbosity >= 1:
if verbosity == 2:
_print_code("Initial code and tests:", code)
_LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")

# retry if the tool output is empty or code fails
count = 1
while (not tool_output.success or tool_output_str == "") and count < 3:
count = 0
while (not tool_output.success or tool_output_str == "") and count < max_retries:
prompt = TEST_PLANS.format(
docstring=tool_info,
plans=plan_str,
Expand All @@ -214,12 +215,15 @@ def pick_plan(
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]

if verbosity == 1:
if verbosity == 2:
_print_code("Code and test after attempted fix:", code)
_LOGGER.info(f"Code execution result after attempte {count}")

count += 1

if verbosity >= 1:
_print_code("Final code:", code)

user_req = chat[-1]["content"]
context = USER_REQ.format(user_request=user_req)
# because the tool picker model gets the image as well, we have to be careful with
Expand Down Expand Up @@ -408,7 +412,7 @@ def debug_code(
FIX_BUG.format(
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
result="\n".join(result.text().splitlines()[-100:]),
feedback=format_memory(working_memory + new_working_memory),
)
)
Expand Down Expand Up @@ -673,92 +677,85 @@ def chat_with_workflow(
working_memory: List[Dict[str, str]] = []
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
retries = 0

while not success and retries < self.max_retries:
self.log_progress(
{
"type": "plans",
"status": "started",
}
)
plans = write_plans(
int_chat,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
)

if self.verbosity >= 1:
for p in plans:
_LOGGER.info(
f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_infos = retrieve_tools(
plans,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
best_plan, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos["all"],
self.coder,
code_interpreter,
verbosity=self.verbosity,
)
self.log_progress(
{
"type": "plans",
"status": "started",
}
)
plans = write_plans(
int_chat,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
)

if best_plan in plans and best_plan in tool_infos:
plan_i = plans[best_plan]
tool_info = tool_infos[best_plan]
else:
if self.verbosity >= 1:
_LOGGER.warning(
f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
plan_i = plans[k]
tool_info = tool_infos[k]

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plan_i,
}
)
if self.verbosity >= 1:
if self.verbosity >= 1:
for p in plans:
_LOGGER.info(
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

results = write_and_test_code(
chat=[
{"role": c["role"], "content": c["content"]} for c in int_chat
],
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
tool_info=tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
coder=self.coder,
tester=self.tester,
debugger=self.debugger,
code_interpreter=code_interpreter,
log_progress=self.log_progress,
verbosity=self.verbosity,
media=media_list,
tool_infos = retrieve_tools(
plans,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
best_plan, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos["all"],
self.coder,
code_interpreter,
verbosity=self.verbosity,
)

if best_plan in plans and best_plan in tool_infos:
plan_i = plans[best_plan]
tool_info = tool_infos[best_plan]
else:
if self.verbosity >= 1:
_LOGGER.warning(
f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
plan_i = plans[k]
tool_info = tool_infos[k]

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plan_i,
}
)
if self.verbosity >= 1:
_LOGGER.info(
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

retries += 1
results = write_and_test_code(
chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
tool_info=tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
coder=self.coder,
tester=self.tester,
debugger=self.debugger,
code_interpreter=code_interpreter,
log_progress=self.log_progress,
verbosity=self.verbosity,
media=media_list,
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

execution_result = cast(Execution, results["test_result"])
self.log_progress(
Expand Down

0 comments on commit add68ce

Please sign in to comment.