Skip to content

Commit

Permalink
updated plan structure, fixed bug with testing plan tool output
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 24, 2024
1 parent cea438b commit 5e2689c
Showing 1 changed file with 13 additions and 20 deletions.
33 changes: 13 additions & 20 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def format_memory(memory: List[Dict[str, str]]) -> str:
def format_plans(plans: Dict[str, Any]) -> str:
plan_str = ""
for k, v in plans.items():
plan_str += f"{k}:\n"
plan_str += "-" + "\n-".join([e["instructions"] for e in v])
plan_str += "\n" + f"{k}: {v['thoughts']}\n"
plan_str += " -" + "\n -".join([e for e in v["instructions"]])

return plan_str

Expand Down Expand Up @@ -229,9 +229,7 @@ def pick_plan(
"status": "completed" if tool_output.success else "failed",
}
)
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]
tool_output_str = tool_output.text().strip()

if verbosity == 2:
_print_code("Code and test after attempted fix:", code)
Expand All @@ -255,14 +253,15 @@ def pick_plan(

count = 0
best_plan = None
while best_plan is None or count < max_retries:
while best_plan is None and count < max_retries:
try:
best_plan = extract_json(model(chat, stream=False)) # type: ignore
except JSONDecodeError as _:
_LOGGER.exception("Error while extracting JSON during picking best plan")
pass
count += 1

if count == max_retries:
if best_plan is None:
best_plan = {"best_plan": list(plans.keys())[0]}

if verbosity >= 1:
Expand Down Expand Up @@ -537,7 +536,7 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:


def retrieve_tools(
plans: Dict[str, List[Dict[str, str]]],
plans: Dict[str, Dict[str, Any]],
tool_recommender: Sim,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
Expand All @@ -554,8 +553,8 @@ def retrieve_tools(
tool_lists: Dict[str, List[Dict[str, str]]] = {}
for k, plan in plans.items():
tool_lists[k] = []
for task in plan:
tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
for task in plan["instructions"]:
tools = tool_recommender.top_k(task, k=2, thresh=0.3)
tool_info.extend([e["doc"] for e in tools])
tool_desc.extend([e["desc"] for e in tools])
tool_lists[k].extend(
Expand Down Expand Up @@ -749,14 +748,7 @@ def chat_with_workflow(
if self.verbosity >= 1:
for p in plans:
# tabulate will fail if the keys are not the same for all elements
p_fixed = [
{
"instructions": (
e["instructions"] if "instructions" in e else ""
)
}
for e in plans[p]
]
p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
_LOGGER.info(
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)
Expand Down Expand Up @@ -805,13 +797,14 @@ def chat_with_workflow(
)

if self.verbosity >= 1:
plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]]
_LOGGER.info(
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

results = write_and_test_code(
chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
plan=f"\n{plan_i['thoughts']}\n-" + "\n-".join([e for e in plan_i["instructions"]]),
tool_info=tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
Expand Down

0 comments on commit 5e2689c

Please sign in to comment.