diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 8e02db0a..87b600e1 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -268,7 +268,7 @@ def self_reflect(
     return reflect_model(prompt)
 
 
-def parse_reflect(reflect: str) -> Dict[str, Any]:
+def parse_reflect(reflect: str) -> Any:
     try:
         return parse_json(reflect)
     except Exception:
@@ -280,6 +280,97 @@ def parse_reflect(reflect: str) -> Dict[str, Any]:
         return {"Finish": finish, "Reflection": reflect}
 
 
+def _handle_extract_frames(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    """Register each frame returned by ``extract_frames_`` with empty annotations.
+
+    Useful when frames are extracted but no later tool processes them, so they
+    still get an entry to visualize.
+
+    Args:
+        image_to_data: Mapping from image identifier to its "bboxes", "masks",
+            "labels" and "scores" lists.
+        tool_result: Tool record whose "call_results" holds, per video, an
+            iterable of ``(frame, _)`` pairs.
+
+    Returns:
+        A shallow copy of ``image_to_data`` with an empty entry added for every
+        frame that was not already present.
+    """
+    image_to_data = image_to_data.copy()
+    for video_file_output in tool_result["call_results"]:
+        for frame, _ in video_file_output:
+            if frame not in image_to_data:
+                image_to_data[frame] = {
+                    "bboxes": [],
+                    "masks": [],
+                    "labels": [],
+                    "scores": [],
+                }
+    return image_to_data
+
+
+def _handle_viz_tools(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    """Merge detection output from ``grounding_sam_``/``grounding_dino_`` calls.
+
+    Args:
+        image_to_data: Mapping from image identifier to its "bboxes", "masks",
+            "labels" and "scores" lists.
+        tool_result: Tool record with "parameters" (a dict or a list of dicts,
+            built by the LLM and therefore possibly malformed) and the matching
+            "call_results".
+
+    Returns:
+        A shallow copy of ``image_to_data`` extended with every successful
+        call. Entries that already existed share their lists with the input
+        mapping, so those lists are extended in place.
+    """
+    image_to_data = image_to_data.copy()
+
+    parameters = tool_result["parameters"]
+    # parameters can either be a dictionary or list, parameters can also be
+    # malformed because the LLM builds them
+    if isinstance(parameters, dict):
+        if "image" not in parameters:
+            return image_to_data
+        parameters = [parameters]
+    elif isinstance(parameters, list):
+        if len(parameters) < 1 or "image" not in parameters[0]:
+            return image_to_data
+    else:
+        # neither a dict nor a list: nothing here can be paired with a result
+        return image_to_data
+
+    for param, call_result in zip(parameters, tool_result["call_results"]):
+        # calls can fail, so skip only the failed call and keep processing the
+        # remaining ones (returning here would drop later successful results)
+        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+            continue
+        # later list entries are not covered by the first-element check above
+        if not isinstance(param, dict) or "image" not in param:
+            continue
+
+        image = param["image"]
+        if image not in image_to_data:
+            image_to_data[image] = {
+                "bboxes": [],
+                "masks": [],
+                "labels": [],
+                "scores": [],
+            }
+
+        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+        image_to_data[image]["labels"].extend(call_result["labels"])
+        image_to_data[image]["scores"].extend(call_result["scores"])
+        if "masks" in call_result:
+            image_to_data[image]["masks"].extend(call_result["masks"])
+
+    return image_to_data
+
+
 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
     image_to_data: Dict[str, Dict] = {}
     for tool_result in all_tool_results:
@@ -292,50 +383,9 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
             continue
 
         if tool_result["tool_name"] == "extract_frames_":
-            for video_file_output in tool_result["call_results"]:
-                for frame, _ in video_file_output:
-                    image = frame
-                    if image not in image_to_data:
-                        image_to_data[image] = {
-                            "bboxes": [],
-                            "masks": [],
-                            "labels": [],
-                            "scores": [],
-                        }
-        else:  # handle grounding_sam_ and grounding_dino_
-            parameters = tool_result["parameters"]
-            # parameters can either be a dictionary or list, parameters can also be malformed
-            # becaus the LLM builds them
-            if isinstance(parameters, dict):
-                if "image" not in parameters:
-                    continue
-                parameters = [parameters]
-            elif isinstance(tool_result["parameters"], list):
-                if len(tool_result["parameters"]) < 1 or (
-                    "image" not in tool_result["parameters"][0]
-                ):
-                    continue
-
-            for param, call_result in zip(parameters, tool_result["call_results"]):
-                # calls can fail, so we need to check if the call was successful
-                if not isinstance(call_result, dict) or "bboxes" not in call_result:
-                    continue
-
-                # if the call was successful, then we can add the image data
-                image = param["image"]
-                if image not in image_to_data:
-                    image_to_data[image] = {
-                        "bboxes": [],
-                        "masks": [],
-                        "labels": [],
-                        "scores": [],
-                    }
-
-                image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-                image_to_data[image]["labels"].extend(call_result["labels"])
-                image_to_data[image]["scores"].extend(call_result["scores"])
-                if "masks" in call_result:
-                    image_to_data[image]["masks"].extend(call_result["masks"])
+            image_to_data = _handle_extract_frames(image_to_data, tool_result)
+        else:
+            image_to_data = _handle_viz_tools(image_to_data, tool_result)
 
     visualized_images = []
     for image in image_to_data: