Skip to content

Commit

Permalink
Merge pull request #1249 from mikel-brostrom/centroid-asso-support
Browse files Browse the repository at this point in the history
Centroid association support for OCSORT and DeepOCSORT
  • Loading branch information
mikel-brostrom authored Jan 12, 2024
2 parents 6e17dac + 52f2424 commit 807be8a
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 60 deletions.
2 changes: 1 addition & 1 deletion boxmot/tracker_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def create_tracker(tracker_type, tracker_config, reid_weights, device, half, per
det_thresh=cfg.det_thresh,
max_age=cfg.max_age,
min_hits=cfg.min_hits,
iou_threshold=cfg.iou_thresh,
asso_threshold=cfg.iou_thresh,
delta_t=cfg.delta_t,
asso_func=cfg.asso_func,
inertia=cfg.inertia,
Expand Down
3 changes: 3 additions & 0 deletions boxmot/trackers/deepocsort/deep_ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,13 @@ def update(self, dets, img):
matched, unmatched_dets, unmatched_trks = associate(
dets[:, 0:5],
trks,
self.asso_func,
self.iou_threshold,
velocities,
k_observations,
self.inertia,
img.shape[1], # w
img.shape[0], # h
stage1_emb_cost,
self.w_association_emb,
self.aw_off,
Expand Down
24 changes: 13 additions & 11 deletions boxmot/trackers/ocsort/ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from boxmot.motion.kalman_filters.ocsort_kf import KalmanFilter
from boxmot.utils.association import associate, linear_assignment
from boxmot.utils.iou import get_asso_func
from boxmot.utils.iou import run_asso_func


def k_previous_obs(observations, cur_age, k):
Expand Down Expand Up @@ -193,7 +194,7 @@ def __init__(
det_thresh=0.2,
max_age=30,
min_hits=3,
iou_threshold=0.3,
asso_threshold=0.3,
delta_t=3,
asso_func="iou",
inertia=0.2,
Expand All @@ -204,7 +205,7 @@ def __init__(
"""
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.asso_threshold = asso_threshold
self.trackers = []
self.frame_count = 0
self.det_thresh = det_thresh
Expand All @@ -214,7 +215,7 @@ def __init__(
self.use_byte = use_byte
KalmanBoxTracker.count = 0

def update(self, dets, _):
def update(self, dets, img):
"""
Params:
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
Expand All @@ -235,6 +236,7 @@ def update(self, dets, _):
), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6"

self.frame_count += 1
h, w = img.shape[0:2]

dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
confs = dets[:, 4]
Expand Down Expand Up @@ -279,7 +281,7 @@ def update(self, dets, _):
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets[:, 0:5], trks, self.iou_threshold, velocities, k_observations, self.inertia
dets[:, 0:5], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h
)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5], dets[m[0], 6])
Expand All @@ -294,17 +296,17 @@ def update(self, dets, _):
dets_second, u_trks
) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
if iou_left.max() > self.asso_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
if iou_left[m[0], m[1]] < self.asso_threshold:
continue
self.trackers[trk_ind].update(
dets_second[det_ind, :5], dets_second[det_ind, 5], dets_second[det_ind, 6]
Expand All @@ -317,11 +319,11 @@ def update(self, dets, _):
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = self.asso_func(left_dets, left_trks)
iou_left = run_asso_func(self.asso_func, left_dets, left_trks, w, h)
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
if iou_left.max() > self.asso_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
Expand All @@ -330,7 +332,7 @@ def update(self, dets, _):
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
if iou_left[m[0], m[1]] < self.asso_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5], dets[det_ind, 6])
to_remove_det_indices.append(det_ind)
Expand Down
2 changes: 1 addition & 1 deletion boxmot/trackers/strongsort/sort/linear_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from scipy.optimize import linear_sum_assignment

from ....utils.matching import chi2inv95
from boxmot.utils.matching import chi2inv95

INFTY_COST = 1e5

Expand Down
9 changes: 7 additions & 2 deletions boxmot/utils/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from boxmot.utils.iou import iou_batch
from boxmot.utils.iou import iou_batch, centroid_batch, run_asso_func


def speed_direction_batch(dets, tracks):
Expand Down Expand Up @@ -111,14 +111,18 @@ def compute_aw_max_metric(emb_cost, w_association_emb, bottom=0.5):
def associate(
detections,
trackers,
asso_func,
iou_threshold,
velocities,
previous_obs,
vdc_weight,
w,
h,
emb_cost=None,
w_assoc_emb=None,
aw_off=None,
aw_param=None,

):
if len(trackers) == 0:
return (
Expand All @@ -139,7 +143,8 @@ def associate(
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

iou_matrix = iou_batch(detections, trackers)
iou_matrix = run_asso_func(asso_func, detections, trackers, w, h)
#iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick sometiems works, we don't encourage this
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
Expand Down
34 changes: 29 additions & 5 deletions boxmot/utils/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


def iou_batch(bboxes1, bboxes2):
def iou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""
Expand All @@ -25,7 +25,7 @@ def iou_batch(bboxes1, bboxes2):
return o


def giou_batch(bboxes1, bboxes2):
def giou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
Expand Down Expand Up @@ -62,7 +62,7 @@ def giou_batch(bboxes1, bboxes2):
return giou


def diou_batch(bboxes1, bboxes2):
def diou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
Expand Down Expand Up @@ -105,7 +105,7 @@ def diou_batch(bboxes1, bboxes2):
return (diou + 1) / 2.0 # resize from (-1,1) to (0,1)


def ciou_batch(bboxes1, bboxes2):
def ciou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
Expand Down Expand Up @@ -161,7 +161,7 @@ def ciou_batch(bboxes1, bboxes2):
return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1)


def centroid_batch(bboxes1, bboxes2, w, h):
def centroid_batch(bboxes1, bboxes2, w, h) -> np.ndarray:
"""
Computes the normalized centroid distance between two sets of bounding boxes.
Bounding boxes are in the format [x1, y1, x2, y2].
Expand All @@ -188,6 +188,30 @@ def centroid_batch(bboxes1, bboxes2, w, h):
return 1 - normalized_distances


def run_asso_func(func, *args):
"""
Wrapper function that checks the inputs to the association functions
and then call either one of the iou association functions or centroid.
Parameters:
func: The batch function to call (either *iou*_batch or centroid_batch).
*args: Variable length argument list, containing either bounding boxes and optionally size parameters.
"""
if func not in [iou_batch, giou_batch, diou_batch, ciou_batch, centroid_batch]:
raise ValueError("Invalid function specified. Must be either '(g,d,c, )iou_batch' or 'centroid_batch'.")

if func in (iou_batch, giou_batch, diou_batch, ciou_batch):
if len(args) != 4 or not all(isinstance(arg, (list, np.ndarray)) for arg in args[0:2]):
raise ValueError("Invalid arguments for iou_batch. Expected two bounding boxes.")
return func(*args[0:2])
elif func is centroid_batch:
if len(args) != 4 or not all(isinstance(arg, (list, np.ndarray)) for arg in args[:2]) or not all(isinstance(arg, (int)) for arg in args[2:]):
raise ValueError("Invalid arguments for centroid_batch. Expected two bounding boxes and two size parameters.")
return func(*args)
else:
raise ValueError("No such association method")


def get_asso_func(asso_mode):
ASSO_FUNCS = {
"iou": iou_batch,
Expand Down
48 changes: 8 additions & 40 deletions boxmot/utils/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scipy
import torch
from scipy.spatial.distance import cdist
from boxmot.utils.iou import iou_batch

"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
Expand Down Expand Up @@ -107,7 +108,13 @@ def iou_distance(atracks, btracks):
else:
atlbrs = [track.xyxy for track in atracks]
btlbrs = [track.xyxy for track in btracks]
_ious = ious(atlbrs, btlbrs)

ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if ious.size == 0:
return ious
_ious = iou_batch(atlbrs, btlbrs)
print(_ious)

cost_matrix = 1 - _ious

return cost_matrix
Expand Down Expand Up @@ -215,45 +222,6 @@ def fuse_score(cost_matrix, detections):
return fuse_cost


def bbox_ious(boxes, query_boxes):
"""
Parameters
----------
boxes: (N, 4) ndarray of float
query_boxes: (K, 4) ndarray of float
Returns
-------
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
N = boxes.shape[0]
K = query_boxes.shape[0]
overlaps = np.zeros((N, K), dtype=np.float32)

for k in range(K):
box_area = (query_boxes[k, 2] - query_boxes[k, 0] + 1) * (
query_boxes[k, 3] - query_boxes[k, 1] + 1
)
for n in range(N):
iw = (
min(boxes[n, 2], query_boxes[k, 2]) -
max(boxes[n, 0], query_boxes[k, 0]) + 1
)
if iw > 0:
ih = (
min(boxes[n, 3], query_boxes[k, 3]) -
max(boxes[n, 1], query_boxes[k, 1]) + 1
)
if ih > 0:
ua = float(
(boxes[n, 2] - boxes[n, 0] + 1) *
(boxes[n, 3] - boxes[n, 1] + 1) +
box_area -
iw * ih
)
overlaps[n, k] = iw * ih / ua
return overlaps


def _pdist(a, b):
"""Compute pair-wise squared distance between points in `a` and `b`.
Parameters
Expand Down

0 comments on commit 807be8a

Please sign in to comment.