Skip to content

Commit

Permalink
pass plan thoughts to coder
Browse files (browse the repository at this point in the history)
  • Loading branch information
dillonalaird committed Sep 19, 2024
1 parent 70b63fa commit d244cf3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
32 changes: 20 additions & 12 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def pick_plan(
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[str, str]:
) -> Tuple[Dict[str, str], str]:
log_progress(
{
"type": "log",
Expand Down Expand Up @@ -233,10 +233,10 @@ def pick_plan(
chat[-1]["content"] = prompt

count = 0
best_plan = None
while best_plan is None and count < max_retries:
plan_thoughts = None
while plan_thoughts is None and count < max_retries:
try:
best_plan = extract_json(model(chat, stream=False)) # type: ignore
plan_thoughts = extract_json(model(chat, stream=False)) # type: ignore
except JSONDecodeError as e:
_LOGGER.exception(
f"Error while extracting JSON during picking best plan {str(e)}"
Expand All @@ -245,30 +245,31 @@ def pick_plan(
count += 1

if (
best_plan is None
or "best_plan" not in best_plan
or ("best_plan" in best_plan and best_plan["best_plan"] not in plans)
plan_thoughts is None
or "best_plan" not in plan_thoughts
or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans)
):
best_plan = {"best_plan": list(plans.keys())[0]}
plan_thoughts = {"best_plan": list(plans.keys())[0]}

if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
_LOGGER.info(f"Best plan:\n{plan_thoughts}")
log_progress(
{
"type": "log",
"log_content": "Picked best plan",
"status": "completed",
"payload": plans[best_plan["best_plan"]],
"payload": plans[plan_thoughts["best_plan"]],
}
)
return best_plan["best_plan"], tool_output_str
return plan_thoughts, tool_output_str


def write_code(
coder: LMM,
chat: List[Message],
plan: str,
tool_info: str,
plan_thoughts: str,
tool_output: str,
feedback: str,
) -> str:
Expand All @@ -281,6 +282,7 @@ def write_code(
docstring=tool_info,
question=FULL_TASK.format(user_request=user_request, subtasks=plan),
tool_output=tool_output,
plan_thoughts=plan_thoughts,
feedback=feedback,
)
chat[-1]["content"] = prompt
Expand Down Expand Up @@ -316,6 +318,7 @@ def write_and_test_code(
plan: str,
tool_info: str,
tool_output: str,
plan_thoughts: str,
tool_utils: str,
working_memory: List[Dict[str, str]],
coder: LMM,
Expand All @@ -340,6 +343,7 @@ def write_and_test_code(
plan,
tool_info,
tool_output,
plan_thoughts,
format_memory(working_memory),
)
test = write_test(
Expand Down Expand Up @@ -760,7 +764,7 @@ def chat_with_workflow(
)

if test_multi_plan:
best_plan, tool_output_str = pick_plan(
plan_thoughts, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos["all"],
Expand All @@ -770,9 +774,12 @@ def chat_with_workflow(
self.log_progress,
verbosity=self.verbosity,
)
best_plan = plan_thoughts["best_plan"]
plan_thoughts = plan_thoughts["thoughts"]
else:
best_plan = list(plans.keys())[0]
tool_output_str = ""
plan_thoughts = ""

if best_plan in plans and best_plan in tool_infos:
plan_i = plans[best_plan]
Expand Down Expand Up @@ -807,6 +814,7 @@ def chat_with_workflow(
+ "\n-".join([e for e in plan_i["instructions"]]),
tool_info=tool_info,
tool_output=tool_output_str,
plan_thoughts=plan_thoughts,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
coder=self.coder,
Expand Down
13 changes: 9 additions & 4 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,14 @@
```python
import numpy as np
from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, florence2_sam2_video_tracking
# sample at 1 FPS and use the first 10 frames to reduce processing time
frames = extract_frames("video.mp4", 1)
frames = [f[0] for f in frames][:10]
# import numpy for the remove_arrays auxiliary function
import numpy as np
def remove_arrays(o):
if isinstance(o, list):
return [remove_arrays(e) for e in o]
Expand Down Expand Up @@ -179,7 +180,7 @@ def remove_arrays(o):
3. Output a JSON object with the following format:
{{
"predicted_answer": str # the answer you would expect from the best plan
"thoughts": str # your thought process for choosing the best plan
"thoughts": str # your thought process for choosing the best plan and any adjustments you would make to the plan
"best_plan": str # the best plan you have chosen
}}
"""
Expand All @@ -202,15 +203,19 @@ def remove_arrays(o):
**User Instructions**:
{question}
**Tool Output**:
**Tool Outputs**:
{tool_output}
**Tool Output Thoughts**:
{plan_thoughts}
**Previous Feedback**:
{feedback}
**Instructions**:
1. **Understand and Clarify**: Make sure you understand the task.
2. **Algorithm/Method Selection**: Decide on the most efficient method, use the tool output to guide your decision.
2. **Algorithm/Method Selection**: Decide on the most efficient method, use the tool outputs to guide your decision.
3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode.
4. **Code Generation**: Translate your pseudocode into executable Python code. Ensure you use correct arguments, remember coordinates are always returned normalized from `vision_agent.tools`. All images from `vision_agent.tools` are in RGB format, red is (255, 0, 0) and blue is (0, 0, 255).
"""
Expand Down

0 comments on commit d244cf3

Please sign in to comment.