Skip to content

Commit

Permalink
visualized output/reflection to handle extract_frames_
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 15, 2024
1 parent fa460b8 commit 4001524
Showing 1 changed file with 52 additions and 36 deletions.
88 changes: 52 additions & 36 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,44 +278,59 @@ def parse_reflect(reflect: str) -> bool:
def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
# only handle bbox/mask tools or frame extraction
if tool_result["tool_name"] not in [
"grounding_sam_",
"grounding_dino_",
"extract_frames_",
]:
continue

parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# becaus the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
continue
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
continue

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict):
continue
if "bboxes" not in call_result:
continue

# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])
if tool_result["tool_name"] == "extract_frames_":
for video_file_output in tool_result["call_results"]:
for frame, _ in video_file_output:
image = frame
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}
else: # handle grounding_sam_ and grounding_dino_
parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# becaus the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
continue
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
continue

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
continue

# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])

visualized_images = []
for image in image_to_data:
Expand Down Expand Up @@ -459,6 +474,7 @@ def chat_with_workflow(
self.answer_model, question, answers, reflections
)

__import__("ipdb").set_trace()
visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
reflection = self_reflect(
Expand Down

0 comments on commit 4001524

Please sign in to comment.