diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py
index fca883402..307bd0967 100644
--- a/pychunkedgraph/app/segmentation/common.py
+++ b/pychunkedgraph/app/segmentation/common.py
@@ -189,11 +189,13 @@ def handle_find_minimal_covering_nodes(table_id, is_binary=True):
     ):  # Process from higher layers to lower layers
         if len(node_queue[layer]) == 0:
             continue
-
+
         current_nodes = list(node_queue[layer])

         # Call handle_roots to find parents
-        parents = cg.get_roots(current_nodes, stop_layer=layer + 1, time_stamp=timestamp)
+        parents = cg.get_roots(
+            current_nodes, stop_layer=layer + 1, time_stamp=timestamp
+        )
         unique_parents = np.unique(parents)
         parent_layers = np.array(
             [cg.get_chunk_layer(parent) for parent in unique_parents]
@@ -312,7 +314,11 @@ def str2bool(v):


 def publish_edit(
-    table_id: str, user_id: str, result: GraphEditOperation.Result, is_priority=True
+    table_id: str,
+    user_id: str,
+    result: GraphEditOperation.Result,
+    is_priority: bool = True,
+    remesh: bool = True,
 ):
     import pickle

@@ -322,6 +328,7 @@ def publish_edit(
         "table_id": table_id,
         "user_id": user_id,
         "remesh_priority": "true" if is_priority else "false",
+        "remesh": "true" if remesh else "false",
     }
     payload = {
         "operation_id": int(result.operation_id),
@@ -343,6 +350,7 @@ def handle_merge(table_id, allow_same_segment_merge=False):
     nodes = json.loads(request.data)

     is_priority = request.args.get("priority", True, type=str2bool)
+    remesh = request.args.get("remesh", True, type=str2bool)
     chebyshev_distance = request.args.get("chebyshev_distance", 3, type=int)

     current_app.logger.debug(nodes)
@@ -391,7 +399,7 @@ def handle_merge(table_id, allow_same_segment_merge=False):
     current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

     if len(ret.new_lvl2_ids) > 0:
-        publish_edit(table_id, user_id, ret, is_priority=is_priority)
+        publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

     return ret

@@ -405,6 +413,7 @@ def handle_split(table_id):
     data = json.loads(request.data)

     is_priority = request.args.get("priority", True, type=str2bool)
+    remesh = request.args.get("remesh", True, type=str2bool)
     mincut = request.args.get("mincut", True, type=str2bool)

     current_app.logger.debug(data)
@@ -457,7 +466,7 @@ def handle_split(table_id):
     current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

     if len(ret.new_lvl2_ids) > 0:
-        publish_edit(table_id, user_id, ret, is_priority=is_priority)
+        publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

     return ret

@@ -470,6 +479,7 @@ def handle_undo(table_id):
     data = json.loads(request.data)

     is_priority = request.args.get("priority", True, type=str2bool)
+    remesh = request.args.get("remesh", True, type=str2bool)
     user_id = str(g.auth_user.get("id", current_app.user_id))

     current_app.logger.debug(data)
@@ -489,7 +499,7 @@ def handle_undo(table_id):
     current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

     if ret.new_lvl2_ids.size > 0:
-        publish_edit(table_id, user_id, ret, is_priority=is_priority)
+        publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

     return ret

@@ -502,6 +512,7 @@ def handle_redo(table_id):
     data = json.loads(request.data)

     is_priority = request.args.get("priority", True, type=str2bool)
+    remesh = request.args.get("remesh", True, type=str2bool)
     user_id = str(g.auth_user.get("id", current_app.user_id))

     current_app.logger.debug(data)
@@ -521,7 +532,7 @@ def handle_redo(table_id):
     current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

     if ret.new_lvl2_ids.size > 0:
-        publish_edit(table_id, user_id, ret, is_priority=is_priority)
+        publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

     return ret

@@ -536,6 +547,7 @@ def handle_rollback(table_id):
     target_user_id = request.args["user_id"]

     is_priority = request.args.get("priority", True, type=str2bool)
+    remesh = request.args.get("remesh", True, type=str2bool)
     skip_operation_ids = np.array(
         json.loads(request.args.get("skip_operation_ids", "[]")), dtype=np.uint64
     )
@@ -562,7 +574,7 @@ def handle_rollback(table_id):
             raise cg_exceptions.BadRequest(str(e))

         if ret.new_lvl2_ids.size > 0:
-            publish_edit(table_id, user_id, ret, is_priority=is_priority)
+            publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

     return user_operations
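With this change, every edit handler (merge, split, undo, redo, rollback) accepts a `remesh` query parameter (parsed via `str2bool`, defaulting to true), and `publish_edit` forwards it to the mesh worker as a pub/sub message attribute. A hypothetical client sketch follows; the host and route prefix are illustrative, not taken from this patch:

    import requests

    nodes = [...]  # merge payload as expected by handle_merge
    resp = requests.post(
        "https://cg.example.com/segmentation/api/v1/table/my_table/merge",
        params={"remesh": "false"},  # hypothetical call: skip remeshing for this edit
        json=nodes,
    )
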
diff --git a/pychunkedgraph/app/segmentation/v1/routes.py b/pychunkedgraph/app/segmentation/v1/routes.py
index 5aee72d94..f0be4d6c4 100644
--- a/pychunkedgraph/app/segmentation/v1/routes.py
+++ b/pychunkedgraph/app/segmentation/v1/routes.py
@@ -15,6 +15,7 @@
 )

 from pychunkedgraph.app import common as app_common
+from pychunkedgraph.app import app_utils
 from pychunkedgraph.app.app_utils import (
     jsonify_with_kwargs,
     remap_public,
@@ -626,3 +627,21 @@ def valid_nodes(table_id):
     resp = common.valid_nodes(table_id, is_binary=is_binary)

     return jsonify_with_kwargs(resp, int64_as_str=int64_as_str)
+
+
+@bp.route("/table/<table_id>/supervoxel_lookup", methods=["POST"])
+@auth_requires_permission("view")
+@remap_public(edit=False)
+def handle_supervoxel_lookup(table_id):
+    int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean)
+
+    nodes = json.loads(request.data)
+    cg = app_utils.get_cg(table_id)
+    node_ids = []
+    coords = []
+    for node in nodes:
+        node_ids.append(node[0])
+        coords.append(np.array(node[1:]) / cg.segmentation_resolution)
+
+    atomic_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids)
+    return jsonify_with_kwargs(atomic_ids, int64_as_str=int64_as_str)
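The new endpoint expects a JSON list of `[node_id, x, y, z]` entries: `node[0]` is kept as the node id, and the coordinates are divided by `cg.segmentation_resolution` before the supervoxel lookup. A hypothetical request sketch; host, table name, and values are illustrative:

    import requests

    payload = [[648518346349539862, 35200.0, 30976.0, 20160.0]]  # [node_id, x, y, z]
    resp = requests.post(
        "https://cg.example.com/segmentation/api/v1/table/my_table/supervoxel_lookup",
        json=payload,  # coordinates in the units implied by segmentation_resolution
    )
    atomic_ids = resp.json()
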
diff --git a/pychunkedgraph/meshing/meshgen.py b/pychunkedgraph/meshing/meshgen.py
index a8da89b1f..b537b97ae 100644
--- a/pychunkedgraph/meshing/meshgen.py
+++ b/pychunkedgraph/meshing/meshgen.py
@@ -5,9 +5,12 @@
 import numpy as np
 import time
 import collections
+from itertools import combinations
 from functools import lru_cache
+from copy import copy
 import datetime
 import pytz
+import networkx as nx

 from scipy import ndimage
 from multiwrapper import multiprocessing_utils as mu
@@ -26,8 +29,6 @@

 UTC = pytz.UTC

-# Change below to true if debugging and want to see results in stdout
-PRINT_FOR_DEBUGGING = False
 # Change below to false if debugging and do not need to write to cloud (warning: do not deploy w/ below set to false)
 WRITING_TO_CLOUD = True

@@ -193,8 +194,6 @@ def _lower_remaps(ks):
     assert cg.get_chunk_layer(chunk_id) >= 2
     assert cg.get_chunk_layer(chunk_id) <= cg.meta.layer_count

-    print(f"\n{chunk_id} ----------------\n")
-
     lower_remaps = {}
     if cg.get_chunk_layer(chunk_id) > 2:
         for lower_chunk_id in cg.get_chunk_child_ids(chunk_id):
@@ -322,7 +321,6 @@ def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1):
         time_stamp = UTC.localize(time_stamp)

     stop_layer, neigh_chunk_ids = calculate_stop_layer(cg, chunk_id)
-    print(f"Stop layer: {stop_layer}")

     # Find the parent in the lowest common chunk for each l2 id. These parent
     # ids are referred to as root ids even though they are not necessarily the
@@ -337,8 +335,6 @@ def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1):

     # This loop is the main bottleneck
     for neigh_chunk_id in neigh_chunk_ids:
-        print(f"Neigh: {neigh_chunk_id} --------------")
-
         lx_ids, root_ids, lx_id_remap = get_root_lx_remapping(
             cg, neigh_chunk_id, stop_layer, time_stamp=time_stamp, n_threads=n_threads
         )
@@ -480,7 +476,6 @@ def get_lx_overlapping_remappings_for_nodes_and_svs(
         time_stamp = UTC.localize(time_stamp)

     stop_layer, _ = calculate_stop_layer(cg, chunk_id)
-    print(f"Stop layer: {stop_layer}")

     # Find the parent in the lowest common chunk for each node id and sv id. These parent
     # ids are referred to as root ids even though they are not necessarily the
@@ -851,8 +846,6 @@ def add_nodes_to_l2_chunk_dict(ids):
     add_nodes_to_l2_chunk_dict(l2_node_ids)

     for chunk_id, node_ids in l2_chunk_dict.items():
-        if PRINT_FOR_DEBUGGING:
-            print("remeshing", chunk_id, node_ids)
         try:
             l2_time_stamp = _get_timestamp_from_node_ids(cg, node_ids)
         except ValueError:
@@ -889,8 +882,6 @@ def add_nodes_to_l2_chunk_dict(ids):
         cur_chunk_dict = chunk_dicts[layer - 3]
     for chunk_dict in chunk_dicts:
         for chunk_id, node_ids in chunk_dict.items():
-            if PRINT_FOR_DEBUGGING:
-                print("remeshing", chunk_id, node_ids)
             # Stitch the meshes of the parents we found in the previous loop
             chunk_stitch_remeshing_task(
                 None,
@@ -904,6 +895,43 @@ def add_nodes_to_l2_chunk_dict(ids):
     )


+def _get_independent_node_groups(cg, ids_to_mesh: set):
+    """
+    Iterates over ids to create connected components (ccs) of overlapping ids.
+    Then creates a list of groups with non-overlapping ids by iterating over the ccs.
+    """
+    nodes = list(ids_to_mesh)
+    edges = []
+    children_d = cg.get_children(node_id_or_ids=nodes)
+    for n1, n2 in combinations(nodes, 2):
+        if np.intersect1d(children_d[n1], children_d[n2]).size:
+            edges.append([n1, n2])
+            ids_to_mesh.discard(int(n1))
+            ids_to_mesh.discard(int(n2))
+
+    graph = nx.Graph()
+    graph.add_edges_from(edges)
+    ccs = {
+        next(iter(cc)): collections.deque(cc) for cc in nx.connected_components(graph)
+    }
+    independent_groups = [list(ids_to_mesh)]
+    while True:
+        if len(ccs) == 0:
+            break
+
+        group = []
+        empty = []
+        for k, v in ccs.items():
+            group.append(v.pop())
+            if len(v) == 0:
+                empty.append(k)
+
+        for k in empty:
+            del ccs[k]
+        independent_groups.append(group)
+    return independent_groups
+
+
 def chunk_initial_mesh_task(
     cg_name,
     chunk_id,
@@ -919,6 +947,7 @@ def chunk_initial_mesh_task(
     cg=None,
     sharded=False,
     cache=True,
+    return_meshes=False,
 ):
     if cg is None:
         cg = ChunkedGraph(graph_id=cg_name)
@@ -931,6 +960,7 @@ def chunk_initial_mesh_task(
     assert layer == 2
     assert mip >= cg.meta.cv.mip

+    merged_meshes = {}
     if sharded:
         cv = CloudVolume(
             f"graphene://https://localhost/segmentation/table/dummy",
@@ -938,7 +968,6 @@ def chunk_initial_mesh_task(
         )
         sharding_info = cv.mesh.meta.info["sharding"]["2"]
         sharding_spec = ShardingSpecification.from_dict(sharding_info)
-        merged_meshes = {}
         mesh_dst = os.path.join(
             cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer)
         )
@@ -946,18 +975,24 @@ def chunk_initial_mesh_task(
         mesh_dst = cv_unsharded_mesh_path

     result.append((chunk_id, layer, cx, cy, cz))
-    print(
-        "Retrieving remap table for chunk %s -- (%s, %s, %s, %s)"
-        % (chunk_id, layer, cx, cy, cz)
-    )
     mesher = zmesh.Mesher(cg.meta.cv.mip_resolution(mip))
     draco_encoding_settings = get_draco_encoding_settings_for_chunk(
         cg, chunk_id, mip, high_padding
     )
+    ids_to_mesh = set()
     if node_id_subset is None:
         seg = get_remapped_segmentation(
             cg, chunk_id, mip, overlap_vx=high_padding, time_stamp=time_stamp
         )
+        try:
+            ts = cg.meta.custom_data["mesh"]["initial_ts"]
+            mesh_ts = datetime.datetime.fromtimestamp(ts)
+        except KeyError:
+            mesh_ts = None
+        range_read = cg.range_read_chunk(
+            chunk_id, properties=attributes.Hierarchy.Child, time_stamp=mesh_ts
+        )
+        ids_to_mesh = set([int(x) for x in range_read.keys()])
     else:
         seg = get_remapped_seg_for_lvl2_nodes(
             cg,
@@ -974,9 +1009,7 @@ def chunk_initial_mesh_task(
     mesher.mesh(seg)
     del seg
     cf = CloudFiles(mesh_dst)
-    if PRINT_FOR_DEBUGGING:
-        print("cv path", mesh_dst)
-        print("num ids", len(mesher.ids()))
+    result.append(len(mesher.ids()))

     for obj_id in mesher.ids():
         mesh = mesher.get(obj_id, reduction_factor=100, max_error=max_err)
@@ -998,8 +1031,9 @@ def chunk_initial_mesh_task(
         else:
             file_contents = mesh.to_precomputed()
             compress = True
+        ids_to_mesh.discard(int(obj_id))
         if WRITING_TO_CLOUD:
-            if sharded:
+            if sharded or return_meshes:
                 merged_meshes[int(obj_id)] = file_contents
             else:
                 cf.put(
@@ -1008,6 +1042,29 @@ def chunk_initial_mesh_task(
                     compress=compress,
                     cache_control=cache_string,
                 )
+
+    if return_meshes:
+        # children = cg.get_children(node_id_subset, flatten=True)
+        # print(len(node_id_subset), len(children), len(merged_meshes))
+        return copy(merged_meshes)
+
+    if len(ids_to_mesh) > 0:
+        # can't mesh overlapping ids (shared supervoxels)
+        independent_groups = _get_independent_node_groups(cg, ids_to_mesh)
+        for group in independent_groups:
+            meshes_from_edits = chunk_initial_mesh_task(
+                None,
+                chunk_id,
+                mip=mip,
+                node_id_subset=group,
+                cg=cg,
+                cv_unsharded_mesh_path="file://",
+                max_err=max_err,
+                sharded=False,
+                return_meshes=True,
+            )
+            merged_meshes.update(meshes_from_edits)
+
     if sharded and WRITING_TO_CLOUD:
         shard_binary = sharding_spec.synthesize_shard(merged_meshes)
         shard_filename = cv.mesh.readers[layer].get_filename(chunk_id)
@@ -1018,15 +1075,15 @@ def chunk_initial_mesh_task(
             compress=False,
             cache_control=cache_string,
         )
-    if PRINT_FOR_DEBUGGING:
-        print(", ".join(str(x) for x in result))
     return result


-def get_multi_child_nodes(cg, chunk_id, node_id_subset=None, chunk_bbox_string=False):
+def get_multi_child_nodes(
+    cg, chunk_id, node_id_subset=None, chunk_bbox_string=False, time_stamp=None
+):
     if node_id_subset is None:
         range_read = cg.range_read_chunk(
-            chunk_id, properties=attributes.Hierarchy.Child
+            chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp
         )
     else:
         range_read = cg.client.read_nodes(
@@ -1040,7 +1097,8 @@ def get_multi_child_nodes(cg, chunk_id, node_id_subset=None, chunk_bbox_string=F
             fragment.value
             for child_fragments_for_node in node_rows
             for fragment in child_fragments_for_node
-        ], dtype=object
+        ],
+        dtype=object,
     )
     # Filter out node ids that do not have roots (caused by failed ingest tasks)
     root_ids = cg.get_roots(node_ids, fail_to_zero=True)
@@ -1231,9 +1289,6 @@ def chunk_stitch_remeshing_task(

     manifest_cache = ManifestCache(cg.graph_id, initial=False)
     manifest_cache.set_fragments(fragments_d)
-
-    if PRINT_FOR_DEBUGGING:
-        print(", ".join(str(x) for x in result))
     return ", ".join(str(x) for x in result)


@@ -1246,8 +1301,16 @@ def chunk_initial_sharded_stitching_task(

     cache_string = "public" if cache else "no-cache"

+    try:
+        ts = cg.meta.custom_data["mesh"]["initial_ts"]
+        mesh_ts = datetime.datetime.fromtimestamp(ts)
+    except KeyError:
+        mesh_ts = None
+
     layer = cg.get_chunk_layer(chunk_id)
-    multi_child_nodes, multi_child_descendants = get_multi_child_nodes(cg, chunk_id)
+    multi_child_nodes, multi_child_descendants = get_multi_child_nodes(
+        cg, chunk_id, time_stamp=mesh_ts
+    )
     chunk_to_id_dict = collections.defaultdict(list)
     for child_node in multi_child_descendants:
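The `_get_independent_node_groups` helper above partitions `ids_to_mesh` so that no two ids sharing supervoxels land in the same meshing batch: overlapping ids form connected components, and each round pops at most one member per component. A standalone sketch of the same grouping logic (not part of the patch; `cg.get_children` is replaced by a plain dict here):

    import collections
    from itertools import combinations

    import networkx as nx
    import numpy as np

    # Fake children map: A/B overlap via supervoxel 2, C/D via 4, E is independent.
    children_d = {
        "A": np.array([1, 2]),
        "B": np.array([2, 3]),
        "C": np.array([4]),
        "D": np.array([4, 5]),
        "E": np.array([6]),
    }
    ids_to_mesh = set(children_d)
    edges = [
        (n1, n2)
        for n1, n2 in combinations(children_d, 2)
        if np.intersect1d(children_d[n1], children_d[n2]).size
    ]
    for n1, n2 in edges:
        ids_to_mesh.discard(n1)
        ids_to_mesh.discard(n2)

    graph = nx.Graph()
    graph.add_edges_from(edges)
    ccs = {
        next(iter(cc)): collections.deque(cc)
        for cc in nx.connected_components(graph)
    }
    groups = [list(ids_to_mesh)]  # ids with no overlaps can share one batch
    while ccs:
        group, done = [], []
        for k, v in ccs.items():
            group.append(v.pop())  # at most one member per component per batch
            if not v:
                done.append(k)
        for k in done:
            del ccs[k]
        groups.append(group)

    print(groups)  # e.g. [['E'], ['B', 'D'], ['A', 'C']]
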
diff --git a/pychunkedgraph/meshing/meshgen_utils.py b/pychunkedgraph/meshing/meshgen_utils.py
index 711c09322..429fb438f 100644
--- a/pychunkedgraph/meshing/meshgen_utils.py
+++ b/pychunkedgraph/meshing/meshgen_utils.py
@@ -150,6 +150,7 @@ def get_json_info(cg):
     return loads(info_str)


+@lru_cache(maxsize=1)
 def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1):
     cv = CloudVolume(cg.meta.cv.cloudpath, mip=mip, fill_missing=True)
     mip_diff = mip - cg.meta.cv.mip
diff --git a/pychunkedgraph/meshing/meshing_batch.py b/pychunkedgraph/meshing/meshing_batch.py
index a5acd1f1b..537645039 100644
--- a/pychunkedgraph/meshing/meshing_batch.py
+++ b/pychunkedgraph/meshing/meshing_batch.py
@@ -1,8 +1,12 @@
+import argparse, os
+import numpy as np
+from cloudvolume import CloudVolume
+from cloudfiles import CloudFiles
 from taskqueue import TaskQueue, LocalTaskQueue
-import argparse
+
 from pychunkedgraph.graph.chunkedgraph import ChunkedGraph  # noqa
-import numpy as np
 from pychunkedgraph.meshing.meshing_sqs import MeshTask
+from pychunkedgraph.meshing import meshgen_utils  # noqa

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -13,11 +17,22 @@
     parser.add_argument('--layer', type=int)
     parser.add_argument('--mip', type=int)
     parser.add_argument('--skip_cache', action='store_true')
+    parser.add_argument('--overwrite', type=bool, default=False)

     args = parser.parse_args()
     cache = not args.skip_cache
     cg = ChunkedGraph(graph_id=args.cg_name)
+    cv = CloudVolume(
+        f"graphene://https://localhost/segmentation/table/dummy",
+        info=meshgen_utils.get_json_info(cg),
+    )
+    dst = os.path.join(
+        cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(args.layer)
+    )
+    cf = CloudFiles(dst)
+    if len(list(cf.list())) > 0 and not args.overwrite:
+        raise ValueError(f"Destination {dst} is not empty. Use `--overwrite true` to proceed anyway.")

     chunks_arr = []
     for x in range(args.chunk_start[0],args.chunk_end[0]):
diff --git a/workers/mesh_worker.py b/workers/mesh_worker.py
index 238bad7a9..b484d355b 100644
--- a/workers/mesh_worker.py
+++ b/workers/mesh_worker.py
@@ -22,6 +22,10 @@ def callback(payload):
     op_id = int(data["operation_id"])
     l2ids = np.array(data["new_lvl2_ids"], dtype=basetypes.NODE_ID)
     table_id = payload.attributes["table_id"]
+    remesh = payload.attributes["remesh"]
+
+    if remesh == "false":
+        return

     try:
         cg = PCG_CACHE[table_id]
@@ -56,7 +60,6 @@ def callback(payload):
     except KeyError:
         return

-    logging.log(INFO_HIGH, f"remeshing {l2ids}; graph {table_id} operation {op_id}.")
     meshgen.remeshing(
         cg,