Skip to content

Commit

Permalink
refactor to make function simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 15, 2024
1 parent 9773ace commit e99f838
Showing 1 changed file with 62 additions and 45 deletions.
107 changes: 62 additions & 45 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def self_reflect(
return reflect_model(prompt)


def parse_reflect(reflect: str) -> Dict[str, Any]:
def parse_reflect(reflect: str) -> Any:
try:
return parse_json(reflect)
except Exception:
Expand All @@ -280,6 +280,64 @@ def parse_reflect(reflect: str) -> Dict[str, Any]:
return {"Finish": finish, "Reflection": reflect}


def _handle_extract_frames(image_to_data: Dict[str, Dict], tool_result: Dict) -> Dict[str, Dict]:
image_to_data = image_to_data.copy()
# handle extract_frames_ case, useful if it extracts frames but doesn't do
# any following processing
for video_file_output in tool_result["call_results"]:
for frame, _ in video_file_output:
image = frame
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}
return image_to_data


def _handle_viz_tools(image_to_data: Dict[str, Dict], tool_result: Dict) -> Dict[str, Dict]:
image_to_data = image_to_data.copy()

# handle grounding_sam_ and grounding_dino_
parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# becaus the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
return image_to_data
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
return image_to_data

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
return image_to_data

# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])

return image_to_data


def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
Expand All @@ -292,50 +350,9 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
continue

if tool_result["tool_name"] == "extract_frames_":
for video_file_output in tool_result["call_results"]:
for frame, _ in video_file_output:
image = frame
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}
else: # handle grounding_sam_ and grounding_dino_
parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# becaus the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
continue
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
continue

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
continue

# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])
image_to_data = _handle_extract_frames(image_to_data, tool_result)
else:
image_to_data = _handle_viz_tools(image_to_data, tool_result)

visualized_images = []
for image in image_to_data:
Expand Down

0 comments on commit e99f838

Please sign in to comment.