From 56fa5d650e93a9638ef6b4c480b19ee21bd0ca4f Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 16 Apr 2024 17:46:19 -0700 Subject: [PATCH] updated vision agent to reflect on multiple images --- vision_agent/agent/vision_agent.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 4c193aae..a3f09b82 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -3,7 +3,7 @@ import sys import tempfile from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from PIL import Image from tabulate import tabulate @@ -264,7 +264,7 @@ def self_reflect( tools: Dict[int, Any], tool_result: List[Dict], final_answer: str, - image: Optional[Union[str, Path]] = None, + images: Optional[Sequence[Union[str, Path]]] = None, ) -> str: prompt = VISION_AGENT_REFLECTION.format( question=question, @@ -275,10 +275,10 @@ def self_reflect( ) if ( issubclass(type(reflect_model), LMM) - and image is not None - and Path(image).suffix in [".jpg", ".jpeg", ".png"] + and images is not None + and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images]) ): - return reflect_model(prompt, image=image) # type: ignore + return reflect_model(prompt, images=images) # type: ignore return reflect_model(prompt) @@ -357,7 +357,7 @@ def _handle_viz_tools( return image_to_data -def visualize_result(all_tool_results: List[Dict]) -> List[str]: +def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]: image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: # only handle bbox/mask tools or frame extraction @@ -407,7 +407,7 @@ def __init__( task_model: Optional[Union[LLM, LMM]] = None, answer_model: Optional[Union[LLM, LMM]] = None, reflect_model: Optional[Union[LLM, LMM]] = None, - max_retries: int = 2, + max_retries: int = 3, verbose: bool = False, report_progress_callback: Optional[Callable[[str], None]] = None, ): @@ -519,13 +519,19 @@ def chat_with_workflow( visualized_output = visualize_result(all_tool_results) all_tool_results.append({"visualized_output": visualized_output}) + if len(visualized_output) > 0: + reflection_images = visualized_output + elif image is not None: + reflection_images = [image] + else: + reflection_images = None reflection = self_reflect( self.reflect_model, question, self.tools, all_tool_results, final_answer, - visualized_output[0] if len(visualized_output) > 0 else image, + reflection_images, ) self.log_progress(f"Reflection: {reflection}") parsed_reflection = parse_reflect(reflection)