V3 migration (#484)
* feat: convert edges to ocdbt

* feat: worker function to convert edges to ocdbt

* feat: ocdbt option, consolidate ingest cli

* fix(ingest): move fn to utils

* fix(ingest): move ocdbt setup to a fn

* add tensorstore req, fix build kaniko cache

* feat: copy fake_edges to column family 4

* feat: upgrade atomic chunks

* fix: rename abstract module to parent

* feat: upgrade higher layers, docs

* feat: upgrade cli, move common fns to utils

* add copy_fake_edges in upgrade fn

* handle earliest_timestamp, add test flag to upgrade

* fix: fake_edges serialize np.uint64

* add get_operation method, fix timestamp in repair, check for parent

* check for l2 ids invalidated by edit retries

* remove unnecessary parent assert

* remove unused vars

* ignore invalid ids, assert parent after earliest_ts

* check for ids invalidated by retries in higher layers

* parallelize update_cross_edges

* overwrite graph version, create col family 4

* improve status print formatting

* remove unused code, consolidate small common module

* efficient check for chunks not done

* check for empty chunks, use get_parents

* efficient get_edit_ts call by batching all children

* reduce earliest_ts calls

* combine bigtable calls, use numpy unique

* add completion rate command

* fix: ignore children without cross edges

* add span option to rate calculation

* reduce mem usage with global vars

* optimize cross edge reading

* use existing layer var

* limit cx edge reading above given layer

* fix: read for earliest_ts check only if true

* filter cross edges fn with timestamps

* remove git from dockerignore, print stats

* shuffle for better distribution of ids

* fix: use different var name for layer

* increase bigtable read timeout

* add message with assert

* fix: make span option int

* handle skipped connections

* fix: read cross edges at layer >= node_layer

* handle another case of skipped nodes

* check for unique parent count

* update graph_id in meta

* uncomment line

* make repair easier to use

* add sanity check for edits

* add sanity check for each layer

* add layers flag for cx edges

* use better names for functions and vars, update types, fix docs
akhileshh authored May 12, 2024
1 parent bc1faa5 commit aaf528e
Showing 29 changed files with 1,109 additions and 412 deletions.
2 changes: 2 additions & 0 deletions pychunkedgraph/app/__init__.py
@@ -105,6 +105,8 @@ def configure_app(app):
with app.app_context():
from ..ingest.rq_cli import init_rq_cmds
from ..ingest.cli import init_ingest_cmds
from ..ingest.cli_upgrade import init_upgrade_cmds

init_rq_cmds(app)
init_ingest_cmds(app)
init_upgrade_cmds(app)
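For context, a hypothetical sketch of how a Flask CLI group like init_upgrade_cmds is typically registered, assuming it mirrors the existing ingest commands; the group and command names below are illustrative, not taken from cli_upgrade.py.

from flask.cli import AppGroup

upgrade_group = AppGroup("upgrade")

@upgrade_group.command("status")
def upgrade_status():
    # placeholder body; the real commands drive the v3 upgrade workers
    ...

def init_upgrade_cmds(app):
    app.cli.add_command(upgrade_group)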
23 changes: 23 additions & 0 deletions pychunkedgraph/debug/utils.py
@@ -35,3 +35,26 @@ def get_l2children(cg, node: np.uint64) -> np.ndarray:
l2children.append(children[layers == 2])
nodes = children[layers > 2]
return np.concatenate(l2children)


def sanity_check(cg, new_roots, operation_id):
"""
Check for duplicates in hierarchy, useful for debugging.
"""
print(f"{len(new_roots)} new ids from {operation_id}")
l2c_d = {}
for new_root in new_roots:
l2c_d[new_root] = get_l2children(cg, new_root)
success = True
for k, v in l2c_d.items():
success = success and (len(v) == np.unique(v).size)
print(f"{k}: {np.unique(v).size}, {len(v)}")
if not success:
raise RuntimeError("Some ids are not valid.")


def sanity_check_single(cg, node, operation_id):
v = get_l2children(cg, node)
msg = f"invalid node {node}:"
msg += f" found {len(v)} l2 ids, must be {np.unique(v).size}"
assert np.unique(v).size == len(v), f"{msg}, from {operation_id}."
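A minimal usage sketch (not part of the diff) for these debug helpers, assuming cg is a ChunkedGraph instance and new_roots/operation_id come from an edit:

import numpy as np
from pychunkedgraph.debug.utils import sanity_check, sanity_check_single

sanity_check(cg, new_roots, operation_id)                       # raises RuntimeError if any new root has duplicate l2 ids
sanity_check_single(cg, np.uint64(new_roots[0]), operation_id)  # same check for a single node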
6 changes: 6 additions & 0 deletions pychunkedgraph/graph/attributes.py
@@ -120,6 +120,12 @@ class Connectivity:
),
)

FakeEdgesCF3 = _Attribute(
key=b"fake_edges",
family_id="3",
serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)),
)

FakeEdges = _Attribute(
key=b"fake_edges",
family_id="4",
78 changes: 59 additions & 19 deletions pychunkedgraph/graph/chunkedgraph.py
@@ -19,6 +19,7 @@
from .meta import ChunkedGraphMeta
from .utils import basetypes
from .utils import id_helpers
from .utils import serializers
from .utils import generic as misc_utils
from .edges import Edges
from .edges import utils as edge_utils
@@ -76,7 +77,7 @@ def version(self) -> str:
return self.client.read_graph_version()

@property
def client(self) -> base.SimpleClient:
def client(self) -> BigTableClient:
return self._client

@property
@@ -287,9 +288,11 @@ def _get_children_multiple(
node_ids=node_ids, properties=attributes.Hierarchy.Child
)
return {
x: node_children_d[x][0].value
if x in node_children_d
else types.empty_1d.copy()
x: (
node_children_d[x][0].value
if x in node_children_d
else types.empty_1d.copy()
)
for x in node_ids
}
return self.cache.children_multiple(node_ids)
@@ -322,6 +325,7 @@ def get_cross_chunk_edges(
node_ids: typing.Iterable,
*,
raw_only=False,
all_layers=True,
time_stamp: typing.Optional[datetime.datetime] = None,
) -> typing.Dict:
"""
@@ -334,21 +338,24 @@
node_ids = np.array(node_ids, dtype=basetypes.NODE_ID)
if node_ids.size == 0:
return result
attrs = [
attributes.Connectivity.CrossChunkEdge[l]
for l in range(2, max(3, self.meta.layer_count))
]
layers = range(2, max(3, self.meta.layer_count))
attrs = [attributes.Connectivity.CrossChunkEdge[l] for l in layers]
node_edges_d_d = self.client.read_nodes(
node_ids=node_ids,
properties=attrs,
end_time=time_stamp,
end_time_inclusive=True,
)
for id_ in node_ids:
layers = self.get_chunk_layers(node_ids)
valid_layer = lambda x, y: x >= y
if not all_layers:
valid_layer = lambda x, y: x == y
for layer, id_ in zip(layers, node_ids):
try:
result[id_] = {
prop.index: val[0].value.copy()
for prop, val in node_edges_d_d[id_].items()
if valid_layer(prop.index, layer)
}
except KeyError:
result[id_] = {}
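A call sketch (not part of the diff) for the new all_layers flag: with the default all_layers=True, cross-edge columns at every layer >= the node's layer are returned; with all_layers=False only the column at exactly the node's layer is kept.

cx_all = cg.get_cross_chunk_edges(l2_ids)                    # l2_ids is a placeholder array; layers >= node layer
cx_own = cg.get_cross_chunk_edges(l2_ids, all_layers=False)  # only the node's own layer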
@@ -631,9 +638,24 @@ def get_fake_edges(
edges = np.concatenate(
[np.array(e.value, dtype=basetypes.NODE_ID, copy=False) for e in val]
)
result[id_] = Edges(edges[:, 0], edges[:, 1], fake_edges=True)
result[id_] = Edges(edges[:, 0], edges[:, 1])
return result

def copy_fake_edges(self, chunk_id: np.uint64) -> None:
_edges = self.client.read_node(
node_id=chunk_id,
properties=attributes.Connectivity.FakeEdgesCF3,
end_time_inclusive=True,
fake_edges=True,
)
mutations = []
_id = serializers.serialize_uint64(chunk_id, fake_edges=True)
for e in _edges:
val_dict = {attributes.Connectivity.FakeEdges: e.value}
row = self.client.mutate_row(_id, val_dict, time_stamp=e.timestamp)
mutations.append(row)
self.client.write(mutations)
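A hedged sketch of how copy_fake_edges might be driven during the migration; the coordinate loop and the get_chunk_id helper usage are assumptions, not code from this commit.

for cx, cy, cz in layer2_chunk_coords:                     # placeholder iterable of layer-2 chunk coords
    chunk_id = cg.get_chunk_id(layer=2, x=cx, y=cy, z=cz)  # assumed existing helper
    cg.copy_fake_edges(chunk_id)                           # copies fake_edges cells from column family 3 to 4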

def get_l2_agglomerations(
self, level2_ids: np.ndarray, edges_only: bool = False
) -> typing.Tuple[typing.Dict[int, types.Agglomeration], typing.Tuple[Edges]]:
@@ -690,13 +712,15 @@ def get_l2_agglomerations(
)
return (
agglomeration_d,
(self.mock_edges,)
if self.mock_edges is not None
else (in_edges, out_edges, cross_edges),
(
(self.mock_edges,)
if self.mock_edges is not None
else (in_edges, out_edges, cross_edges)
),
)

def get_node_timestamps(
self, node_ids: typing.Sequence[np.uint64], return_numpy=True
self, node_ids: typing.Sequence[np.uint64], return_numpy=True, normalize=False
) -> typing.Iterable:
"""
The timestamp of the children column can be assumed
@@ -710,17 +734,22 @@ def get_node_timestamps(
if return_numpy:
return np.array([], dtype=np.datetime64)
return []
result = []
earliest_ts = self.get_earliest_timestamp()
for n in node_ids:
ts = children[n][0].timestamp
if normalize:
ts = earliest_ts if ts < earliest_ts else ts
result.append(ts)
if return_numpy:
return np.array(
[children[x][0].timestamp for x in node_ids], dtype=np.datetime64
)
return [children[x][0].timestamp for x in node_ids]
return np.array(result, dtype=np.datetime64)
return result
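A short call sketch (not from the diff) for the new normalize flag, which clamps any cell timestamp older than the graph's earliest operation to get_earliest_timestamp():

ts = cg.get_node_timestamps(l2_ids, return_numpy=True, normalize=True)  # np.datetime64 array, clamped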

# OPERATIONS
def add_edges(
self,
user_id: str,
atomic_edges: typing.Sequence[np.uint64],
atomic_edges: typing.Sequence[typing.Sequence[np.uint64]],
*,
affinities: typing.Sequence[np.float32] = None,
source_coords: typing.Sequence[int] = None,
@@ -935,3 +964,14 @@ def get_earliest_timestamp(self):
_, timestamp = self.client.read_log_entry(op_id)
if timestamp is not None:
return timestamp - timedelta(milliseconds=500)

def get_operation_ids(self, node_ids: typing.Sequence):
response = self.client.read_nodes(node_ids=node_ids)
result = {}
for node in node_ids:
try:
operations = response[node][attributes.OperationLogs.OperationID]
result[node] = [(x.value, x.timestamp) for x in operations]
except KeyError:
...
return result
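A usage sketch (assumption, not from the diff) for the new get_operation_ids helper, which returns per-node (operation id, cell timestamp) pairs:

ops_d = cg.get_operation_ids(l2_ids)   # l2_ids: sequence of np.uint64 node ids
for node, ops in ops_d.items():
    for op_id, timestamp in ops:
        print(node, op_id, timestamp)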
2 changes: 1 addition & 1 deletion pychunkedgraph/graph/client/base.py
@@ -13,7 +13,7 @@ def create_graph(self) -> None:
"""Initialize the graph and store associated meta."""

@abstractmethod
def add_graph_version(self, version):
def add_graph_version(self, version: str, overwrite: bool = False):
"""Add a version to the graph."""

@abstractmethod
37 changes: 23 additions & 14 deletions pychunkedgraph/graph/client/bigtable/client.py
@@ -19,7 +19,7 @@
from google.cloud.bigtable.column_family import MaxVersionsGCRule
from google.cloud.bigtable.table import Table
from google.cloud.bigtable.row_set import RowSet
from google.cloud.bigtable.row_data import PartialRowData
from google.cloud.bigtable.row_data import DEFAULT_RETRY_READ_ROWS, PartialRowData
from google.cloud.bigtable.row_filters import RowFilter

from . import utils
@@ -97,8 +97,9 @@ def create_graph(self, meta: ChunkedGraphMeta, version: str) -> None:
self.add_graph_version(version)
self.update_graph_meta(meta)

def add_graph_version(self, version: str):
assert self.read_graph_version() is None, "Graph has already been versioned."
def add_graph_version(self, version: str, overwrite: bool = False):
if not overwrite:
assert self.read_graph_version() is None, self.read_graph_version()
self._version = version
row = self.mutate_row(
attributes.GraphVersion.key,
@@ -160,18 +161,25 @@ def read_nodes(
# when all IDs in a block are within a range
node_ids = np.sort(node_ids)
rows = self._read_byte_rows(
start_key=serialize_uint64(start_id, fake_edges=fake_edges)
if start_id is not None
else None,
end_key=serialize_uint64(end_id, fake_edges=fake_edges)
if end_id is not None
else None,
start_key=(
serialize_uint64(start_id, fake_edges=fake_edges)
if start_id is not None
else None
),
end_key=(
serialize_uint64(end_id, fake_edges=fake_edges)
if end_id is not None
else None
),
end_key_inclusive=end_id_inclusive,
row_keys=(
serialize_uint64(node_id, fake_edges=fake_edges) for node_id in node_ids
)
if node_ids is not None
else None,
(
serialize_uint64(node_id, fake_edges=fake_edges)
for node_id in node_ids
)
if node_ids is not None
else None
),
columns=properties,
start_time=start_time,
end_time=end_time,
@@ -819,7 +827,8 @@ def _execute_read_thread(self, args: typing.Tuple[Table, RowSet, RowFilter]):
# Check for everything falsy, because Bigtable considers even empty
# lists of row_keys as no upper/lower bound!
return {}
range_read = table.read_rows(row_set=row_set, filter_=row_filter)
retry = DEFAULT_RETRY_READ_ROWS.with_timeout(180)
range_read = table.read_rows(row_set=row_set, filter_=row_filter, retry=retry)
res = {v.row_key: utils.partial_row_data_to_column_dict(v) for v in range_read}
return res
