Skip to content

Commit eb4fce8

Browse files
committed
Use struct for node metadata
1 parent 99cad13 commit eb4fce8

File tree

3 files changed

+58
-20
lines changed

3 files changed

+58
-20
lines changed

tests/test_inference.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1094,9 +1094,14 @@ def test_from_standard_tree_sequence(self):
10941094
assert i1.flags == i2.flags
10951095
assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
10961096
# Unless inference is perfect, internal nodes may differ, but sample nodes
1097-
# should be identical
1097+
# should be identical. Node metadata is not transferred, however, and a tsinfer-
1098+
# specific node metadata schema is used (where empty is None rather than b"")
1099+
assert (
1100+
ts_inferred.table_metadata_schemas.node
1101+
== tsinfer.formats.node_metadata_schema()
1102+
)
10981103
for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
1099-
assert ts.node(n1) == ts_inferred.node(n2)
1104+
assert ts.node(n1).replace(metadata=None) == ts_inferred.node(n2)
11001105
# Sites can have metadata added by the inference process, but inferred site
11011106
# metadata should always include all the metadata in the original ts
11021107
for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
@@ -1586,12 +1591,13 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
15861591
ancestors_time = ancestor_data.ancestors_time[:]
15871592
num_ancestor_nodes = 0
15881593
for n in ancestors_ts.nodes():
1589-
md = json.loads(n.metadata) if n.metadata else {}
1594+
md = n.metadata if n.metadata else {}
15901595
if tsinfer.is_pc_ancestor(n.flags):
1591-
assert not ("ancestor_data_id" in md)
1596+
if "tsinfer" in md:
1597+
assert "ancestor_data_id" not in md["tsinfer"]
15921598
else:
1593-
assert "ancestor_data_id" in md
1594-
assert ancestors_time[md["ancestor_data_id"]] == n.time
1599+
assert "tsinfer" in md and "ancestor_data_id" in md["tsinfer"]
1600+
assert ancestors_time[md["tsinfer"]["ancestor_data_id"]] == n.time
15951601
num_ancestor_nodes += 1
15961602
assert num_ancestor_nodes == ancestor_data.num_ancestors
15971603

@@ -3114,8 +3120,7 @@ def verify_augmented_ancestors(
31143120
node = t2.nodes[m + j]
31153121
assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
31163122
assert node.time == 1
3117-
metadata = json.loads(node.metadata.decode())
3118-
assert node_id == metadata["sample_data_id"]
3123+
assert node_id == node.metadata["tsinfer"]["sample_data_id"]
31193124

31203125
t2.nodes.truncate(len(t1.nodes))
31213126
# Adding and subtracting 1 can lead to small diffs, so we compare
@@ -3265,8 +3270,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
32653270
num_sample_ancestors = 0
32663271
for node in final_ts.nodes():
32673272
if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
3268-
metadata = json.loads(node.metadata.decode())
3269-
assert metadata["sample_data_id"] in subset
3273+
assert node.metadata["tsinfer"]["sample_data_id"] in subset
32703274
num_sample_ancestors += 1
32713275
assert expected_sample_ancestors == num_sample_ancestors
32723276
tsinfer.verify(samples, final_ts.simplify())

tsinfer/formats.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,42 @@ def permissive_json_schema():
7474
}
7575

7676

77+
def node_metadata_schema():
78+
# This is fixed by tsinfer: users cannot add to the node metadata
79+
return tskit.MetadataSchema(
80+
{
81+
"codec": "struct",
82+
"type": ["object", "null"],
83+
"properties": {
84+
"tsinfer": {
85+
"description": "Information about node identity "
86+
"from the tsinfer inference process",
87+
"type": "object",
88+
"properties": {
89+
"ancestor_data_id": {
90+
"description": "The corresponding ancestor ID "
91+
"in the ancestors file created by the inference process, "
92+
"or -1 if not applicable",
93+
"type": "number",
94+
"binaryFormat": "i",
95+
"default": -1,
96+
},
97+
"sample_data_id": {
98+
"description": "The corresponding sample ID "
99+
"in the sample data file used for inference, "
100+
"or -1 if not applicable",
101+
"type": "number",
102+
"binaryFormat": "i",
103+
"default": -1,
104+
},
105+
},
106+
},
107+
},
108+
"additionalProperties": False,
109+
}
110+
)
111+
112+
77113
def np_obj_equal(np_obj_array1, np_obj_array2):
78114
"""
79115
A replacement for np.array_equal to test equality of numpy arrays that

tsinfer/inference.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,12 +1411,7 @@ def get_ancestors_tree_sequence(self):
14111411
pc_ancestors = is_pc_ancestor(flags)
14121412
tables.nodes.set_columns(flags=flags, time=times)
14131413

1414-
# # FIXME we should do this as a struct codec?
1415-
# dict_schema = permissive_json_schema()
1416-
# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
1417-
# {"type": "integer"})
1418-
# schema = tskit.MetadataSchema(dict_schema)
1419-
# tables.nodes.schema = schema
1414+
tables.nodes.metadata_schema = formats.node_metadata_schema()
14201415

14211416
# Add metadata for any non-PC node, pointing to the original ancestor
14221417
metadata = []
@@ -1425,7 +1420,11 @@ def get_ancestors_tree_sequence(self):
14251420
if is_pc:
14261421
metadata.append(b"")
14271422
else:
1428-
metadata.append(_encode_raw_metadata({"ancestor_data_id": ancestor}))
1423+
metadata.append(
1424+
tables.nodes.metadata_schema.validate_and_encode_row(
1425+
{"tsinfer": {"ancestor_data_id": ancestor}}
1426+
)
1427+
)
14291428
ancestor += 1
14301429
tables.nodes.packset_metadata(metadata)
14311430
left, right, parent, child = tsb.dump_edges()
@@ -1471,6 +1470,7 @@ def store_output(self):
14711470
tables = tskit.TableCollection(
14721471
sequence_length=self.ancestor_data.sequence_length
14731472
)
1473+
tables.nodes.metadata_schema = formats.node_metadata_schema()
14741474
ts = tables.tree_sequence()
14751475
return ts
14761476

@@ -1830,9 +1830,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
18301830
tables.nodes.add_row(
18311831
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
18321832
time=times[j],
1833-
metadata=_encode_raw_metadata(
1834-
{"sample_data_id": int(sample_indexes[s])}
1835-
),
1833+
metadata={"tsinfer": {"sample_data_id": int(sample_indexes[s])}},
18361834
)
18371835
s += 1
18381836
else:

0 commit comments

Comments
 (0)