Skip to content

Commit

Permalink
make imports easier, pass more code info
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 22, 2024
1 parent 957ed56 commit 152ac13
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class DefaultImports:
"""Container for default imports used in the code execution."""

common_imports = [
"import os",
"import numpy as np",
"from vision_agent.tools import *",
"from typing import *",
"from pillow_heif import register_heif_opener",
"register_heif_opener()",
Expand Down Expand Up @@ -174,7 +177,10 @@ def pick_plan(

# retry if the tool output is empty or code fails
count = 0
while (not tool_output.success or tool_output_str == "") and count < max_retries:
while (
not tool_output.success
or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
) and count < max_retries:
prompt = TEST_PLANS.format(
docstring=tool_info,
plans=plan_str,
Expand Down Expand Up @@ -213,6 +219,7 @@ def pick_plan(
if verbosity == 2:
_print_code("Code and test after attempted fix:", code)
_LOGGER.info(f"Code execution result after attempt {count + 1}")
_LOGGER.info(f"{tool_output_str}")

count += 1

Expand Down Expand Up @@ -247,8 +254,12 @@ def pick_plan(
or "best_plan" not in plan_thoughts
or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans)
):
_LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
plan_thoughts = {"best_plan": list(plans.keys())[0]}

if "thoughts" not in plan_thoughts:
plan_thoughts["thoughts"] = ""

if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{plan_thoughts}")
log_progress(
Expand All @@ -259,7 +270,7 @@ def pick_plan(
"payload": plans[plan_thoughts["best_plan"]],
}
)
return plan_thoughts, tool_output_str
return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str


def write_code(
Expand Down Expand Up @@ -844,7 +855,8 @@ def chat_with_workflow(
"code": DefaultImports.prepend_imports(code),
"test": test,
"test_result": execution_result,
"plan": plan,
"plans": plans,
"plan_thoughts": plan_thoughts_str,
"working_memory": working_memory,
}

Expand Down

0 comments on commit 152ac13

Please sign in to comment.