Skip to content

Commit

Permalink
✨ [Update] BoxMatcher matching criteria
Browse files Browse the repository at this point in the history
Added an additional validity criterion in get_valid_matrix, which masks out anchors for targets that are too large to be predicted with the given reg_max and stride values.

Implemented a new function, ensure_one_anchor, which assigns the single best-suited anchor to each valid target that has no valid anchors. It is a fallback mechanism that allows targets that are too small or too large to still be trained and predicted, even if not perfectly.

Fixed the filter_duplicates function to use the topk-masked iou_mat for the selection; previously it sometimes matched invalid targets to anchors with duplicates.

Updated docstrings across the BoxMatcher functions to match the changes.
  • Loading branch information
Adamusen authored Nov 13, 2024
1 parent 959b9b0 commit 65d34e6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
2 changes: 1 addition & 1 deletion yolo/tools/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80,
self.dfl = DFLoss(vec2box, reg_max)
self.iou = BoxLoss()

self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max)

def separate_anchor(self, anchors):
"""
Expand Down
80 changes: 58 additions & 22 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,28 +143,35 @@ def generate_anchors(image_size: List[int], strides: List[int]):


class BoxMatcher:
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None:
self.class_num = class_num
self.anchors = anchors
self.vec2box = vec2box
self.reg_max = reg_max
for attr_name in cfg:
setattr(self, attr_name, cfg[attr_name])

def get_valid_matrix(self, target_bbox: Tensor):
"""
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor
and is able to correctly predict it with the available reg_max value.
Args:
target_bbox [batch x targets x 4]: The bounding box of each targets.
target_bbox [batch x targets x 4]: The bounding box of each target.
Returns:
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps with anchors.
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps
with the anchors, and the anchor is able to predict the target.
"""
Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
anchors = self.anchors[None, None] # add a axis at first, second dimension
x_min, y_min, x_max, y_max = target_bbox[:, :, None].unbind(3)
anchors = self.vec2box.anchor_grid[None, None] # add a axis at first, second dimension
anchors_x, anchors_y = anchors.unbind(dim=3)
target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
target_on_anchor = target_in_x & target_in_y
return target_on_anchor
x_min_dist, x_max_dist = anchors_x - x_min, x_max - anchors_x
y_min_dist, y_max_dist = anchors_y - y_min, y_max - anchors_y
targets_dist = torch.stack((x_min_dist, y_min_dist, x_max_dist, y_max_dist), dim=-1)
targets_dist /= self.vec2box.scaler[None, None, :, None] # (1, 1, anchors, 1)
min_reg_dist, max_reg_dist = targets_dist.amin(dim=-1), targets_dist.amax(dim=-1)
target_on_anchor = min_reg_dist >= 0
target_in_reg_max = max_reg_dist <= self.reg_max - 1.01
return target_on_anchor & target_in_reg_max

def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
"""
Expand Down Expand Up @@ -194,36 +201,62 @@ def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
"""
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)

def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
def filter_topk(self, target_matrix: Tensor, grid_mask: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
"""
Filter the top-k suitability of targets for each anchor.
Args:
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
grid_mask [batch x targets x anchors]: The match validity for each target to anchors
topk (int, optional): Number of top scores to retain per anchor.
Returns:
topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
"""
values, indices = target_matrix.topk(topk, dim=-1)
masked_target_matrix = grid_mask * target_matrix
values, indices = masked_target_matrix.topk(topk, dim=-1)
topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
topk_targets.scatter_(dim=-1, index=indices, src=values)
topk_masks = topk_targets > 0
return topk_targets, topk_masks
topk_mask = topk_targets > 0
return topk_targets, topk_mask

def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: Tensor) -> Tensor:
    """
    Ensures each valid target gets at least one anchor matched based on the unmasked target matrix,
    which enables an otherwise invalid match. This enables too small or too large targets to be
    learned as well, even if they can't be predicted perfectly.
    Args:
        target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
        topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
    Returns:
        topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
    """
    # Best-scoring anchor of every target, taken from the *unmasked* suitability.
    values, indices = target_matrix.max(dim=-1)
    # F.one_hot yields int64; cast to the mask's dtype so torch.where neither
    # promotes the returned mask to int64 nor fails on a dtype mismatch.
    best_anchor_mask = F.one_hot(indices, target_matrix.size(-1)).to(topk_mask.dtype)
    matched_anchor_num = torch.sum(topk_mask, dim=-1)
    # A target needs the fallback only if it has no matched anchor yet and its
    # best score is positive (all-zero rows are left without any anchor).
    target_without_anchor = (matched_anchor_num == 0) & (values > 0)
    topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask)
    return topk_mask

def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor):
"""
Filter the maximum suitability target index of each anchor based on IoU.
Args:
iou_mat [batch x targets x anchors]: The IoU for each targets-anchors
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
Returns:
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
valid_mask [batch x targets]: Mask indicating the validity of each target
topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
"""
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
masked_iou_mat = topk_mask * iou_mat
max_idx = F.one_hot(masked_iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
topk_mask = torch.where(duplicates, max_idx, topk_mask)
unique_indices = topk_mask.argmax(dim=1)
return unique_indices[..., None], topk_mask.sum(1), topk_mask
Expand Down Expand Up @@ -272,10 +305,13 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens
# get cls matrix (cls prob with each gt class and each predict class)
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)

target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
target_matrix = (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])

# choose topk
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
topk_targets, topk_mask = self.filter_topk(target_matrix, grid_mask, topk=self.topk)

# match best anchor to valid targets without valid anchors
topk_mask = self.ensure_one_anchor(target_matrix, topk_mask)

# delete one anchor pred assign to mutliple gts
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
Expand Down Expand Up @@ -304,7 +340,7 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
self.strides = anchor_cfg.strides
else:
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
self.strides = self.create_auto_anchor(model, image_size)

anchor_grid, scaler = generate_anchors(image_size, self.strides)
Expand Down

0 comments on commit 65d34e6

Please sign in to comment.