Skip to content

Improve efficiency of Tracklet Generation #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 128 additions & 18 deletions src/mouse_tracking/utils/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,29 +1070,137 @@ def generate_greedy_tracklets(self, max_cost: float = -np.log(1e-3), rotate_pose
self._tracklet_gen_method = 'greedy'
self._make_tracklets()

def stitch_greedy_tracklets(self, num_tracks: int = None, all_embeds: bool = True, prioritize_long: bool = False):
"""Greedy method that links merges tracklets 1 at a time based on lowest cost.

Args:
num_tracks: number of tracks to produce
all_embeds: bool to include original tracklet centers as merges are made
prioritize_long: bool to adjust cost of linking with length of tracklets
"""
def stitch_greedy_tracklets(
self,
num_tracks: int | None = None,
all_embeds: bool = True,
prioritize_long: bool = False,
):
"""Optimized greedy method that links merges tracklets 1 at a time based on lowest cost.

Args:
num_tracks: number of tracks to produce
all_embeds: bool to include original tracklet centers as merges are made
prioritize_long: bool to adjust cost of linking with length of tracklets

Notes:
Optimized version eliminates O(n³) pandas DataFrame recreation bottleneck.
Uses numpy arrays and incremental cost matrix updates for O(n²) complexity.
"""
if num_tracks is None:
num_tracks = self._avg_observation

# copy original tracklet list, so that we can revert at the end
original_tracklets = self._tracklets

# We can use pandas to do slightly easier searching
current_costs = pd.DataFrame(self._get_transition_costs(all_embeds, True, longer_track_priority=float(prioritize_long)))
while not np.all(np.isinf(current_costs.to_numpy(na_value=np.inf))):
t1, t2 = np.unravel_index(np.argmin(current_costs.to_numpy(na_value=np.inf)), current_costs.shape)
tracklet_1 = current_costs.index[t1]
tracklet_2 = current_costs.columns[t2]
new_tracklet = Tracklet.from_tracklets([self._tracklets[tracklet_1], self._tracklets[tracklet_2]], True)
self._tracklets = [x for i, x in enumerate(self._tracklets) if i not in [tracklet_1, tracklet_2]] + [new_tracklet]
current_costs = pd.DataFrame(self._get_transition_costs(all_embeds, True, longer_track_priority=float(prioritize_long)))
# Early exit if no tracklets or only one tracklet
if len(self._tracklets) <= 1:
self._stitch_translation = {0: 0}
self._tracklets = original_tracklets
self._tracklet_stitch_method = "greedy"
return

# Get initial transition costs as dict and convert to numpy matrix
cost_dict = self._get_transition_costs(
all_embeds, True, longer_track_priority=float(prioritize_long)
)

# Build numpy cost matrix - work with a copy of tracklets for merging
working_tracklets = list(
self._tracklets
) # Copy for modifications during merging
n_tracklets = len(working_tracklets)

# Initialize cost matrix with infinity
cost_matrix = np.full((n_tracklets, n_tracklets), np.inf, dtype=np.float64)

# Fill cost matrix from cost_dict
for i, costs_for_i in cost_dict.items():
for j, cost in costs_for_i.items():
cost_matrix[i, j] = cost
cost_matrix[j, i] = cost # Matrix should be symmetric

# Track which tracklets are still active (not merged)
active_tracklets = set(range(n_tracklets))

# Main stitching loop - continues until no more valid merges
while len(active_tracklets) > 1:
# Find minimum cost among active tracklets
min_cost = np.inf
best_pair = None

for i in active_tracklets:
for j in active_tracklets:
if i < j and cost_matrix[i, j] < min_cost:
min_cost = cost_matrix[i, j]
best_pair = (i, j)

# If no finite cost found, break (no more valid merges)
if best_pair is None or np.isinf(min_cost):
break

tracklet_1_idx, tracklet_2_idx = best_pair

# Create new merged tracklet
new_tracklet = Tracklet.from_tracklets(
[working_tracklets[tracklet_1_idx], working_tracklets[tracklet_2_idx]],
True,
)

# Remove merged tracklets from active set
active_tracklets.remove(tracklet_1_idx)
active_tracklets.remove(tracklet_2_idx)

# Add new tracklet to working list and get its index
working_tracklets.append(new_tracklet)
new_tracklet_idx = len(working_tracklets) - 1
active_tracklets.add(new_tracklet_idx)

# Extend cost matrix for new tracklet if needed
if new_tracklet_idx >= cost_matrix.shape[0]:
# Extend matrix size
old_size = cost_matrix.shape[0]
new_size = max(old_size * 2, new_tracklet_idx + 1)
new_matrix = np.full((new_size, new_size), np.inf, dtype=np.float64)
new_matrix[:old_size, :old_size] = cost_matrix
cost_matrix = new_matrix

# Calculate costs for new tracklet with all remaining active tracklets
for other_idx in active_tracklets:
if other_idx != new_tracklet_idx and other_idx < len(working_tracklets):
# Calculate cost between new tracklet and existing tracklet
match_cost = new_tracklet.compare_to(
working_tracklets[other_idx], other_anchors=all_embeds
)

# Apply priority adjustment if enabled
if match_cost is not None and prioritize_long:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm finding the formatting of this block a bit hard to follow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The crazy Github tab size really isn't helping. Does anyone mind if I ruff format the whole file to replace the tabs with spaces?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://docs.github.com/en/account-and-profile/setting-up-and-managing-your-personal-account-on-github/managing-user-account-settings/managing-your-tab-size-rendering-preference
Default is 8. I'd prefer having the formatting stuff not mask the changes just yet (but should definitely happen in the near future!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will wait.

longer_track_length = 100 # Default from _get_transition_costs
sigmoid_length_new = 1 / (
1 + np.exp(longer_track_length - new_tracklet.n_frames)
)
sigmoid_length_other = 1 / (
1
+ np.exp(
longer_track_length
- working_tracklets[other_idx].n_frames
)
)
match_cost += (
1 - sigmoid_length_new * sigmoid_length_other
) * float(prioritize_long)

# Update cost matrix
if match_cost is not None and not np.isinf(match_cost):
cost_matrix[new_tracklet_idx, other_idx] = match_cost
cost_matrix[other_idx, new_tracklet_idx] = match_cost
else:
cost_matrix[new_tracklet_idx, other_idx] = np.inf
cost_matrix[other_idx, new_tracklet_idx] = np.inf

# Update self._tracklets with the merged result for ID assignment
self._tracklets = [working_tracklets[i] for i in active_tracklets]

# Tracklets are formed. Now we should assign the longest ones IDs.
tracklet_lengths = [len(x.frames) for x in self._tracklets]
Expand All @@ -1102,9 +1210,11 @@ def stitch_greedy_tracklets(self, num_tracks: int = None, all_embeds: bool = Tru
for cur_assignment in assignment_order:
ids_to_assign = self._tracklets[cur_assignment].track_id
for cur_tracklet_id in ids_to_assign:
track_to_longterm_id[int(cur_tracklet_id + 1)] = current_id if current_id > 0 else 0
track_to_longterm_id[int(cur_tracklet_id + 1)] = (
current_id if current_id > 0 else 0
)
current_id -= 1

self._stitch_translation = track_to_longterm_id
self._tracklets = original_tracklets
self._tracklet_stitch_method = 'greedy'
self._tracklet_stitch_method = "greedy"
1 change: 1 addition & 0 deletions tests/utils/matching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for the matching utils module."""
1 change: 1 addition & 0 deletions tests/utils/matching/video_observations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for the VideoObservations class."""
Loading