fix(edits): incorrect order of operations; documentation
akhileshh committed Sep 25, 2023
1 parent ab8a358 commit ae1dd24
210 changes: 108 additions & 102 deletions pychunkedgraph/graph/edits.py
@@ -5,6 +5,7 @@
from typing import List
from typing import Tuple
from typing import Iterable
from typing import Set
from collections import defaultdict

import fastremap
@@ -24,15 +25,13 @@


def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None):
new_old_id_d = defaultdict(set)
old_new_id_d = defaultdict(set)
old_hierarchy_d = {id_: {2: id_} for id_ in l2ids}
for id_ in l2ids:
layer_parent_d = cg.get_all_parents_dict(id_, time_stamp=parent_ts)
old_hierarchy_d[id_].update(layer_parent_d)
for parent in layer_parent_d.values():
old_hierarchy_d[parent] = old_hierarchy_d[id_]
return new_old_id_d, old_new_id_d, old_hierarchy_d
return old_hierarchy_d
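
For orientation, here is a minimal sketch of the structure the reworked _init_old_hierarchy now returns, using hypothetical toy IDs (20, 30, 40) in place of real chunkedgraph node IDs; note that a node and all of its old ancestors share the same layer-to-ancestor dict:

# Minimal sketch, assuming a toy hierarchy: L2 node 20 with parent 30
# at layer 3 and root 40 at layer 4 (all IDs hypothetical).
layer_parents = {20: {3: 30, 4: 40}}  # stand-in for cg.get_all_parents_dict
old_hierarchy_d = {20: {2: 20}}
old_hierarchy_d[20].update(layer_parents[20])
for parent in layer_parents[20].values():
    old_hierarchy_d[parent] = old_hierarchy_d[20]  # shared dict object
assert old_hierarchy_d[30] is old_hierarchy_d[40] is old_hierarchy_d[20]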


def _analyze_affected_edges(
@@ -176,64 +175,6 @@ def check_fake_edges(
return atomic_edges, rows


def _update_neighbor_cross_edges_single(
cg, new_id: int, cx_edges_d: dict, node_map: dict, *, parent_ts
) -> dict:
node_layer = cg.get_chunk_layer(new_id)
counterparts = []
for layer in range(node_layer, cg.meta.layer_count):
layer_edges = cx_edges_d.get(layer, types.empty_2d)
counterparts.extend(layer_edges[:, 1])

cp_cx_edges_d = cg.get_cross_chunk_edges(
counterparts, time_stamp=parent_ts, raw_only=True
)
updated_counterparts = {}
for counterpart, edges_d in cp_cx_edges_d.items():
val_dict = {}
for layer in range(2, cg.meta.layer_count):
edges = edges_d.get(layer, types.empty_2d)
if edges.size == 0:
continue
assert np.all(edges[:, 0] == counterpart)
edges = fastremap.remap(edges, node_map, preserve_missing_labels=True)
edges_d[layer] = edges
val_dict[attributes.Connectivity.CrossChunkEdge[layer]] = edges
if not val_dict:
continue
cg.cache.cross_chunk_edges_cache[counterpart] = edges_d
updated_counterparts[counterpart] = val_dict
return updated_counterparts


def _update_neighbor_cross_edges(
cg, new_ids: List[int], new_old_id_d: dict, old_new_id_d, *, time_stamp, parent_ts
) -> List:
node_map = {}
for k, v in old_new_id_d.items():
node_map[k] = next(iter(v))

updated_counterparts = {}
newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts)
for new_id in new_ids:
cx_edges_d = newid_cx_edges_d[new_id]
temp_map = {
old_id: new_id for old_id in _get_flipped_ids(new_old_id_d, [new_id])
}
node_map.update(temp_map)
result = _update_neighbor_cross_edges_single(
cg, new_id, cx_edges_d, node_map, parent_ts=parent_ts
)
updated_counterparts.update(result)

updated_entries = []
for node, val_dict in updated_counterparts.items():
rowkey = serialize_uint64(node)
row = cg.client.mutate_row(rowkey, val_dict, time_stamp=time_stamp)
updated_entries.append(row)
return updated_entries


def add_edges(
cg,
*,
@@ -250,9 +191,10 @@ def add_edges(
if not allow_same_segment_merge:
roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)
assert np.unique(roots).size == 2, "L2 IDs must belong to different roots."
new_old_id_d, old_new_id_d, old_hierarchy_d = _init_old_hierarchy(
cg, l2ids, parent_ts=parent_ts
)

new_old_id_d = defaultdict(set)
old_new_id_d = defaultdict(set)
old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts)
atomic_children_d = cg.get_children(l2ids)
cross_edges_d = merge_cross_edge_dicts(
cg.get_cross_chunk_edges(l2ids, time_stamp=parent_ts), l2_cross_edges_d
@@ -288,14 +230,6 @@
new_cx_edges_d[layer] = edges
assert np.all(edges[:, 0] == new_id)
cg.cache.cross_chunk_edges_cache[new_id] = new_cx_edges_d
updated_entries = _update_neighbor_cross_edges(
cg,
new_l2_ids,
new_old_id_d,
old_new_id_d,
time_stamp=time_stamp,
parent_ts=parent_ts,
)

create_parents = CreateParentNodes(
cg,
@@ -313,7 +247,7 @@
l2c = get_l2children(cg, new_root)
assert len(l2c) == np.unique(l2c).size, f"inconsistent result op {operation_id}"
create_parents.create_new_entries()
return new_roots, new_l2_ids, updated_entries + create_parents.new_entries
return new_roots, new_l2_ids, create_parents.new_entries
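
The loop above rebuilds each new ID's cross-chunk edge dict, which maps layer to an Nx2 array of [node, partner] pairs whose first column must point at the owning node. A hedged, self-contained sketch of that invariant, with hypothetical IDs:

# Sketch only: cross edge dicts map layer -> Nx2 [node, partner] arrays.
import numpy as np

new_id = 99                                       # hypothetical merged L2 id
cx_edges_d = {3: np.array([[11, 55], [12, 56]])}  # edges of the old ids
new_cx_edges_d = {}
for layer, edges in cx_edges_d.items():
    edges = edges.copy()
    edges[:, 0] = new_id            # repoint the owning column at new_id
    new_cx_edges_d[layer] = edges
    assert np.all(edges[:, 0] == new_id)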


def _process_l2_agglomeration(
@@ -385,9 +319,9 @@ def remove_edges(
roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)
assert np.unique(roots).size == 1, "L2 IDs must belong to the same root."

new_old_id_d, old_new_id_d, old_hierarchy_d = _init_old_hierarchy(
cg, l2ids, parent_ts=parent_ts
)
new_old_id_d = defaultdict(set)
old_new_id_d = defaultdict(set)
old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts)
chunk_id_map = dict(zip(l2ids.tolist(), cg.get_chunk_ids_from_node_ids(l2ids)))

removed_edges = np.concatenate([atomic_edges, atomic_edges[:, ::-1]], axis=0)
@@ -424,14 +358,6 @@
new_cx_edges_d[layer] = edges
assert np.all(edges[:, 0] == new_id)
cg.cache.cross_chunk_edges_cache[new_id] = new_cx_edges_d
updated_entries = _update_neighbor_cross_edges(
cg,
new_l2_ids,
new_old_id_d,
old_new_id_d,
time_stamp=time_stamp,
parent_ts=parent_ts,
)

create_parents = CreateParentNodes(
cg,
@@ -448,7 +374,7 @@
l2c = get_l2children(cg, new_root)
assert len(l2c) == np.unique(l2c).size, f"inconsistent result op {operation_id}"
create_parents.create_new_entries()
return new_roots, new_l2_ids, updated_entries + create_parents.new_entries
return new_roots, new_l2_ids, create_parents.new_entries


def _get_flipped_ids(id_map, node_ids):
@@ -463,6 +389,82 @@ def _get_flipped_ids(id_map, node_ids):
return np.concatenate(ids)
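
The body of _get_flipped_ids is collapsed in this view; judging only from its call sites (e.g. _get_flipped_ids(new_old_id_d, [new_id])) and the visible return, a plausible reading, labeled hypothetical, is:

# Hypothetical reconstruction, not the verbatim body: gather the ids
# mapped to each node_id and flatten them into one array.
import numpy as np

def _get_flipped_ids_sketch(id_map: dict, node_ids) -> np.ndarray:
    ids = [np.fromiter(id_map[n], dtype=np.uint64) for n in node_ids]
    return np.concatenate(ids)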


def _update_neighbor_cross_edges_single(
cg, new_id: int, cx_edges_d: dict, node_map: dict, *, parent_ts
) -> dict:
"""
Get the counterparts of new_id and update their cross chunk edges.
Some counterparts may be updated multiple times, so collect the changes
first and write them to storage as consolidated mutations.
Returns updated counterparts.
"""
node_layer = cg.get_chunk_layer(new_id)
counterparts = []
for layer in range(node_layer, cg.meta.layer_count):
layer_edges = cx_edges_d.get(layer, types.empty_2d)
counterparts.extend(layer_edges[:, 1])

cp_cx_edges_d = cg.get_cross_chunk_edges(
counterparts, time_stamp=parent_ts, raw_only=True
)
updated_counterparts = {}
for counterpart, edges_d in cp_cx_edges_d.items():
val_dict = {}
for layer in range(2, cg.meta.layer_count):
edges = edges_d.get(layer, types.empty_2d)
if edges.size == 0:
continue
assert np.all(edges[:, 0] == counterpart)
edges = fastremap.remap(edges, node_map, preserve_missing_labels=True)
edges_d[layer] = edges
val_dict[attributes.Connectivity.CrossChunkEdge[layer]] = edges
if not val_dict:
continue
cg.cache.cross_chunk_edges_cache[counterpart] = edges_d
updated_counterparts[counterpart] = val_dict
return updated_counterparts
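
The key step above is fastremap.remap with preserve_missing_labels=True: partner IDs that appear in node_map are rewritten to the new IDs, while everything else passes through untouched. A small illustration with hypothetical IDs:

import numpy as np
import fastremap

node_map = {11: 99}                  # old id 11 was replaced by new id 99
edges = np.array([[7, 11], [7, 42]], dtype=np.uint64)
edges = fastremap.remap(edges, node_map, preserve_missing_labels=True)
# edges is now [[7, 99], [7, 42]]; 42 has no mapping and is preserved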


def _update_neighbor_cross_edges(
cg,
new_ids: List[int],
new_old_id: dict,
old_new_id,
*,
time_stamp,
parent_ts,
) -> List:
"""
For each new_id, get its counterparts and update their cross chunk edges.
Some counterparts may be updated multiple times, so collect the changes
first and write them to storage as consolidated mutations.
Returns mutations to updated counterparts/partner nodes.
"""
updated_counterparts = {}
newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts)

node_map = {}
for k, v in old_new_id.items():
if len(v) == 1:
node_map[k] = next(iter(v))

for new_id in new_ids:
cx_edges_d = newid_cx_edges_d[new_id]
m = {old_id: new_id for old_id in _get_flipped_ids(new_old_id, [new_id])}
node_map.update(m)
result = _update_neighbor_cross_edges_single(
cg, new_id, cx_edges_d, node_map, parent_ts=parent_ts
)
updated_counterparts.update(result)

updated_entries = []
for node, val_dict in updated_counterparts.items():
rowkey = serialize_uint64(node)
row = cg.client.mutate_row(rowkey, val_dict, time_stamp=time_stamp)
updated_entries.append(row)
return updated_entries
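
Note how node_map is built in two passes: old IDs with exactly one successor are safe to map globally, while an old ID split into several new IDs is ambiguous and only resolved per new_id inside the loop. A minimal sketch with hypothetical IDs:

# Sketch: 11 was merged into 99 (unambiguous); 12 was split into 77 and 78.
old_new_id = {11: {99}, 12: {77, 78}}
node_map = {}
for k, v in old_new_id.items():
    if len(v) == 1:
        node_map[k] = next(iter(v))
assert node_map == {11: 99} and 12 not in node_map  # 12 stays ambiguous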


class CreateParentNodes:
def __init__(
self,
@@ -471,8 +473,8 @@ def __init__(
new_l2_ids: Iterable,
operation_id: basetypes.OPERATION_ID,
time_stamp: datetime.datetime,
new_old_id_d: Dict[np.uint64, Iterable[np.uint64]] = None,
old_new_id_d: Dict[np.uint64, Iterable[np.uint64]] = None,
new_old_id_d: Dict[np.uint64, Set[np.uint64]] = None,
old_new_id_d: Dict[np.uint64, Set[np.uint64]] = None,
old_hierarchy_d: Dict[np.uint64, Dict[int, np.uint64]] = None,
parent_ts: datetime.datetime = None,
):
@@ -544,12 +546,15 @@ def _update_cross_edge_cache(self, parent, children):
updates cross chunk edges in cache;
this can only be done after all new components at a layer have IDs
"""
parent_layer = self.cg.get_chunk_layer(parent)
if parent_layer == 2:
# l2 cross edges have already been updated by this point
return
cx_edges_d = self.cg.get_cross_chunk_edges(
children, time_stamp=self._last_successful_ts
)
cx_edges_d = concatenate_cross_edge_dicts(cx_edges_d.values())

parent_layer = self.cg.get_chunk_layer(parent)
edge_nodes = np.unique(np.concatenate([*cx_edges_d.values(), types.empty_2d]))
edge_parents = self.cg.get_roots(
edge_nodes,
@@ -600,28 +605,15 @@ def _create_new_parents(self, layer: int):
self.cg.get_parent_chunk_id(cc_ids[0], parent_layer),
root_chunk=parent_layer == self.cg.meta.layer_count,
)
new_parent_ids.append(parent_id)
self._new_ids_d[parent_layer].append(parent_id)
self._update_id_lineage(parent_id, cc_ids, layer, parent_layer)
new_parent_ids.append(parent_id)

self.cg.cache.children_cache[parent_id] = cc_ids
cache_utils.update(
self.cg.cache.parents_cache,
cc_ids,
parent_id,
)
for new_id in new_parent_ids:
children = self.cg.get_children(new_id)
self._update_cross_edge_cache(new_id, children)
entries = _update_neighbor_cross_edges(
self.cg,
new_parent_ids,
self._new_old_id_d,
self._old_new_id_d,
time_stamp=self._time_stamp,
parent_ts=self._last_successful_ts,
)
self.new_entries.extend(entries)

def run(self) -> Iterable:
"""
@@ -637,6 +629,20 @@
self.cg.graph_id,
self._operation_id,
):
# all new IDs in this layer have been created
# update their cross chunk edges and their neighbors'
for new_id in self._new_ids_d[layer]:
children = self.cg.get_children(new_id)
self._update_cross_edge_cache(new_id, children)
entries = _update_neighbor_cross_edges(
self.cg,
self._new_ids_d[layer],
self._new_old_id_d,
self._old_new_id_d,
time_stamp=self._time_stamp,
parent_ts=self._last_successful_ts,
)
self.new_entries.extend(entries)
self._create_new_parents(layer)
return self._new_ids_d[self.cg.meta.layer_count]
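
This reordering is the core of the fix: cross-chunk edges for a layer's new IDs (and their neighbors') are now refreshed inside the per-layer loop, before parents at the next layer are created. A condensed, runnable sketch of that loop, with all graph operations stubbed out:

# Runnable sketch; names and layer numbers are illustrative stubs.
new_ids_by_layer = {2: ["l2-new"], 3: [], 4: []}
log = []

def update_cross_edges(ids):        # stands in for both update steps
    log.extend(("cross", i) for i in ids)

def create_parents(layer):          # stands in for _create_new_parents
    new_ids_by_layer[layer + 1] = [f"l{layer + 1}-new"]
    log.append(("parents", layer))

for layer in range(2, 4):
    update_cross_edges(new_ids_by_layer[layer])   # first: finalize edges
    create_parents(layer)                         # then: build next layer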

