Skip to content

Commit

Permalink
added example template matching use case
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 23, 2024
1 parent 6af8d9e commit c5135c7
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 0 deletions.
Binary file added examples/custom_tools/pid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/custom_tools/pid_template.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
49 changes: 49 additions & 0 deletions examples/custom_tools/run_custom_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import vision_agent as va
from vision_agent.image_utils import get_image_size, normalize_bbox
from vision_agent.tools import Tool, register_tool

from template_match import template_matching_with_rotation


@register_tool
class TemplateMatch(Tool):
name = "template_match_"
description = "'template_match_' takes a template image and finds all locations where that template appears in the input image."
usage = {
"required_parameters": [
{"name": "target_image", "type": "str"},
{"name": "template_image", "type": "str"},
],
"examples": [
{
"scenario": "Can you detect the location of the template in the target image? Image name: target.png Reference image: template.png",
"parameters": {
"target_image": "target.png",
"template_image": "template.png",
},
},
],
}

def __call__(self, target_image: str, template_image: str) -> dict:
image_size = get_image_size(target_image)
matches = template_matching_with_rotation(target_image, template_image)
matches["bboxes"] = [
normalize_bbox(box, image_size) for box in matches["bboxes"]
]
return matches


if __name__ == "__main__":
agent = va.agent.VisionAgent(verbose=True)
resp, tools = agent.chat_with_workflow(
[
{
"role": "user",
"content": "Can you find the locations of the pid_template.png in pid.png and tell me if any are nearby 'NOTE 5'?",
}
],
image="pid.png",
reference_data={"image": "pid_template.png"},
visualize_output=True,
)
96 changes: 96 additions & 0 deletions examples/custom_tools/template_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import cv2
import numpy as np
import torch
from torchvision.ops import nms


def rotate_image(mat, angle):
"""
Rotates an image (angle in degrees) and expands image to avoid cropping
"""

height, width = mat.shape[:2] # image shape has 3 dimensions
image_center = (
width / 2,
height / 2,
) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape

rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)

# rotation calculates the cos and sin, taking absolutes of those.
abs_cos = abs(rotation_mat[0, 0])
abs_sin = abs(rotation_mat[0, 1])

# find the new width and height bounds
bound_w = int(height * abs_sin + width * abs_cos)
bound_h = int(height * abs_cos + width * abs_sin)

# subtract old image center (bringing image back to origo) and adding the new image center coordinates
rotation_mat[0, 2] += bound_w / 2 - image_center[0]
rotation_mat[1, 2] += bound_h / 2 - image_center[1]

# rotate image with the new bounds and translated rotation matrix
rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
return rotated_mat


def template_matching_with_rotation(
main_image_path: str,
template_path: str,
max_rotation: int = 360,
step: int = 90,
threshold: float = 0.75,
visualize: bool = False,
) -> dict:
main_image = cv2.imread(main_image_path)
template = cv2.imread(template_path)
template_height, template_width = template.shape[:2]

# Convert images to grayscale
main_image_gray = cv2.cvtColor(main_image, cv2.COLOR_BGR2GRAY)
template_gray = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)

boxes = []
scores = []

for angle in range(0, max_rotation, step):
# Rotate the template
rotated_template = rotate_image(template_gray, angle)

# Perform template matching
result = cv2.matchTemplate(
main_image_gray,
rotated_template,
cv2.TM_CCOEFF_NORMED,
)

y_coords, x_coords = np.where(result >= threshold)
for x, y in zip(x_coords, y_coords):
boxes.append(
(x, y, x + rotated_template.shape[1], y + rotated_template.shape[0])
)
scores.append(result[y, x])

indices = (
nms(
torch.tensor(boxes).float(),
torch.tensor(scores).float(),
0.2,
)
.numpy()
.tolist()
)
boxes = [boxes[i] for i in indices]
scores = [scores[i] for i in indices]

if visualize:
# Draw a rectangle around the best match
for box in boxes:
cv2.rectangle(main_image, (box[0], box[1]), (box[2], box[3]), 255, 2)

# Display the result
cv2.imshow("Best Match", main_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

return {"bboxes": boxes, "scores": scores}

0 comments on commit c5135c7

Please sign in to comment.