diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 62dae768..34be1fc9 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -314,6 +314,7 @@ def _handle_extract_frames( image_to_data[image] = { "bboxes": [], "masks": [], + "heat_map": [], "labels": [], "scores": [], } @@ -343,10 +344,9 @@ def _handle_viz_tools( # Calls can fail, so we need to check if the call was successful. It can either: # 1. return a str or some error that's not a dictionary # 2. return a dictionary but not have the necessary keys + if not isinstance(call_result, dict) or ( - "bboxes" not in call_result - and "masks" not in call_result - and "heat_map" not in call_result + "bboxes" not in call_result and "heat_map" not in call_result ): return image_to_data diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 4f7b69cd..23dc8506 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -238,7 +238,7 @@ def overlay_heat_map( elif isinstance(image, np.ndarray): image = Image.fromarray(image) - if "masks" not in heat_map: + if "heat_map" not in heat_map: return image.convert("RGB") image = image.convert("L")