From 78e3b7207b2b6aae1fa5a18aa328a5dc013e2af9 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 25 Apr 2024 21:00:56 -0700 Subject: [PATCH] Add mask key for visualization (#66) * add mask to necessary keys * fix heat map append --- vision_agent/agent/vision_agent.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 5afd3f5c..d01569fc 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -346,7 +346,9 @@ def _handle_viz_tools( # 2. return a dictionary but not have the necessary keys if not isinstance(call_result, dict) or ( - "bboxes" not in call_result and "heat_map" not in call_result + "bboxes" not in call_result + and "mask" not in call_result + and "heat_map" not in call_result ): return image_to_data @@ -366,7 +368,8 @@ def _handle_viz_tools( image_to_data[image]["scores"].extend(call_result.get("scores", [])) image_to_data[image]["masks"].extend(call_result.get("masks", [])) # only single heatmap is returned - image_to_data[image]["heat_map"].append(call_result.get("heat_map", [])) + if "heat_map" in call_result: + image_to_data[image]["heat_map"].append(call_result["heat_map"]) if "mask_shape" in call_result: image_to_data[image]["mask_shape"] = call_result["mask_shape"]