Skip to content

Commit

Permalink
updated vision agent to reflect on multiple images
Browse files: browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 17, 2024
1 parent ee80ba2 commit 56fa5d6
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions vision_agent/agent/vision_agent.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from PIL import Image
from tabulate import tabulate
Expand Down Expand Up @@ -264,7 +264,7 @@ def self_reflect(
tools: Dict[int, Any],
tool_result: List[Dict],
final_answer: str,
image: Optional[Union[str, Path]] = None,
images: Optional[Sequence[Union[str, Path]]] = None,
) -> str:
prompt = VISION_AGENT_REFLECTION.format(
question=question,
Expand All @@ -275,10 +275,10 @@ def self_reflect(
)
if (
issubclass(type(reflect_model), LMM)
and image is not None
and Path(image).suffix in [".jpg", ".jpeg", ".png"]
and images is not None
and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
):
return reflect_model(prompt, image=image) # type: ignore
return reflect_model(prompt, images=images) # type: ignore
return reflect_model(prompt)


Expand Down Expand Up @@ -357,7 +357,7 @@ def _handle_viz_tools(
return image_to_data


def visualize_result(all_tool_results: List[Dict]) -> List[str]:
def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
# only handle bbox/mask tools or frame extraction
Expand Down Expand Up @@ -407,7 +407,7 @@ def __init__(
task_model: Optional[Union[LLM, LMM]] = None,
answer_model: Optional[Union[LLM, LMM]] = None,
reflect_model: Optional[Union[LLM, LMM]] = None,
max_retries: int = 2,
max_retries: int = 3,
verbose: bool = False,
report_progress_callback: Optional[Callable[[str], None]] = None,
):
Expand Down Expand Up @@ -519,13 +519,19 @@ def chat_with_workflow(

visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
if len(visualized_output) > 0:
reflection_images = visualized_output
elif image is not None:
reflection_images = [image]
else:
reflection_images = None
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
visualized_output[0] if len(visualized_output) > 0 else image,
reflection_images,
)
self.log_progress(f"Reflection: {reflection}")
parsed_reflection = parse_reflect(reflection)
Expand Down

0 comments on commit 56fa5d6

Please sign in to comment.