Skip to content

Commit

Permalink
Merge branch 'main' into import_external_tracks_menu
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Feb 18, 2025
2 parents 53ba4ab + 2c5dac2 commit b4cd72e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/motile_tracker/data_model/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def __init__(
super().__init__(tracks)
self.nodes = nodes
user_attrs = attributes.copy()
self.times = attributes.get(tracks.time_attr, None)
if tracks.time_attr in attributes:
del user_attrs[tracks.time_attr]
self.times = attributes.get(NodeAttr.TIME.value, None)
if NodeAttr.TIME.value in attributes:
del user_attrs[NodeAttr.TIME.value]
self.positions = attributes.get(tracks.pos_attr, None)
if tracks.pos_attr in attributes:
del user_attrs[tracks.pos_attr]
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(
super().__init__(tracks)
self.nodes = nodes
self.attributes = {
self.tracks.time_attr: self.tracks.get_times(nodes),
NodeAttr.TIME.value: self.tracks.get_times(nodes),
self.tracks.pos_attr: self.tracks.get_positions(nodes),
NodeAttr.TRACK_ID.value: self.tracks._get_nodes_attr(
nodes, NodeAttr.TRACK_ID.value
Expand Down Expand Up @@ -227,7 +227,7 @@ def __init__(
"""
super().__init__(tracks)
protected_attrs = [
NodeAttr.TIME.value,
tracks.time_attr,
NodeAttr.AREA.value,
NodeAttr.TRACK_ID.value,
]
Expand Down
13 changes: 5 additions & 8 deletions src/motile_tracker/data_model/tracks_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ def is_valid(self, edge) -> tuple[bool, TracksAction | None]:
True if the edge is valid, false if invalid"""

# make sure that the node2 is downstream of node1
time1 = self.tracks.graph.nodes[edge[0]][NodeAttr.TIME.value]
time2 = self.tracks.graph.nodes[edge[1]][NodeAttr.TIME.value]
time1 = self.tracks.get_time(edge[0])
time2 = self.tracks.get_time(edge[1])

if time1 > time2:
edge = (edge[1], edge[0])
Expand All @@ -415,10 +415,7 @@ def is_valid(self, edge) -> tuple[bool, TracksAction | None]:
return False, action

# reject if edge is horizontal
elif (
self.tracks.graph.nodes[edge[0]][NodeAttr.TIME.value]
== self.tracks.graph.nodes[edge[1]][NodeAttr.TIME.value]
):
elif self.tracks.get_time(edge[0]) == self.tracks.get_time(edge[1]):
show_warning("Edge is rejected because it is horizontal.")
return False, action

Expand Down Expand Up @@ -464,7 +461,7 @@ def is_valid(self, edge) -> tuple[bool, TracksAction | None]:
nodes = [
n
for n, attr in self.tracks.graph.nodes(data=True)
if attr.get(NodeAttr.TIME.value) == t
if attr.get(self.tracks.time_attr) == t
and attr.get(NodeAttr.TRACK_ID.value) == track_id2
]
if len(nodes) > 0:
Expand Down Expand Up @@ -550,7 +547,7 @@ def update_segmentations(
times = [pix[0][0] for pix in pixels]
attributes = {
NodeAttr.TRACK_ID.value: track_ids,
self.tracks.time_attr: times,
NodeAttr.TIME.value: times,
"node_id": nodes,
}

Expand Down

0 comments on commit b4cd72e

Please sign in to comment.