Skip to content

Create a new merge() function, and use that for concatenate #3183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ which perform the same actions but modify the {class}`TableCollection` in place.
.. autosummary::
TreeSequence.simplify
TreeSequence.subset
TreeSequence.merge
TreeSequence.union
TreeSequence.concatenate
TreeSequence.keep_intervals
Expand Down Expand Up @@ -753,6 +754,7 @@ a functional way, returning a new tree sequence while leaving the original uncha
TableCollection.delete_sites
TableCollection.trim
TableCollection.shift
TableCollection.merge
TableCollection.union
TableCollection.delete_older
```
Expand Down
7 changes: 7 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
- ``TreeSequence.map_to_vcf_model`` now also returns the transformed positions and
contig length. (:user:`benjeffery`, :pr:`XXXX`, :issue:`3173`)

- New ``merge`` functions for tree sequences and table collections, to merge another
into the current one (:user:`hyanwong`, :pr:`3183`, :issue:`3181`)

**Bugfixes**

- Fix bug in ``TreeSequence.pair_coalescence_counts`` when ``span_normalise=True``
and a window breakpoint falls within an internal missing interval.
(:user:`nspope`, :pr:`3176`, :issue:`3175`)

- Change ``TreeSequence.concatenate`` to use ``merge``, as ``union`` does not
port edges, sites, or mutations from the added tree sequences if they are associated
with shared nodes (:user:`hyanwong`, :pr:`3183`, :issue:`3168`, :issue:`3182`)

--------------------
[0.6.4] - 2025-05-21
--------------------
Expand Down
281 changes: 268 additions & 13 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io
import itertools
import json
import platform
import random
import sys
import unittest
Expand All @@ -43,6 +44,9 @@
import tskit.provenance as provenance


IS_WINDOWS = platform.system() == "Windows"


def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True):
"""
Simple Python implementation of keep_intervals.
Expand Down Expand Up @@ -7141,18 +7145,223 @@ def test_bad_seq_len(self):
ts.shift(1, sequence_length=1)


class TestMerge:
def test_empty(self):
ts = tskit.TableCollection(2).tree_sequence()
merged_ts = ts.merge(ts, node_mapping=[])
assert merged_ts.num_nodes == 0
assert merged_ts.num_edges == 0
assert merged_ts.sequence_length == 2

def test_overlay(self):
ts1 = tskit.Tree.generate_balanced(4, span=2).tree_sequence
tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables()
tables.populations.add_row()
tables.nodes[5] = tables.nodes[5].replace(
flags=tskit.NODE_IS_SAMPLE, population=0
)
ts2 = tables.tree_sequence()
ts_merge = ts1.merge(ts2, node_mapping=np.full(ts1.num_nodes, tskit.NULL))
assert ts_merge.sequence_length == ts1.sequence_length
assert ts_merge.num_samples == ts1.num_samples + ts2.num_samples
assert ts_merge.num_nodes == ts1.num_nodes + ts2.num_nodes
assert ts_merge.num_edges == ts1.num_edges + ts2.num_edges
assert ts_merge.num_trees == 1
assert ts_merge.num_populations == 1
assert ts_merge.first().num_roots == 2

def test_split_and_merge(self):
# Cut up a single tree into alternating edges and mutations, then merge
ts = tskit.Tree.generate_comb(4, span=10).tree_sequence
ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1)
mut_counts = np.bincount(ts.mutations_site, minlength=ts.num_sites)
assert min(mut_counts) == 1
assert max(mut_counts) > 1
tables1 = ts.dump_tables()
tables1.mutations.clear()
tables2 = tables1.copy()
i = 0
for s in ts.sites():
for m in s.mutations:
i += 1
if i % 2:
tables1.mutations.append(m.replace(parent=tskit.NULL))
else:
tables2.mutations.append(m.replace(parent=tskit.NULL))
tables1.simplify()
tables2.simplify()
assert tables1.sites.num_rows != ts.num_sites
tables1.edges.clear()
tables2.edges.clear()
for e in ts.edges():
if e.id % 2:
tables1.edges.append(e)
else:
tables2.edges.append(e)
ts1 = tables1.tree_sequence()
ts2 = tables2.tree_sequence()
new_ts = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes)).simplify()
assert new_ts.equals(ts, ignore_provenance=True)

def test_multi_tree(self):
ts = msprime.sim_ancestry(
2, sequence_length=4, recombination_rate=1, random_seed=1
)
ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
assert ts.num_trees > 3
assert ts.num_mutations > 4
ts1 = ts.keep_intervals([[0, 1.5]], simplify=False)
ts2 = ts.keep_intervals([[1.5, 4]], simplify=False)
new_ts = ts1.merge(
ts2, node_mapping=np.arange(ts.num_nodes), add_populations=False
)
assert new_ts.num_trees == ts.num_trees + 1
new_ts = new_ts.simplify()
new_ts.equals(ts, ignore_provenance=True)

def test_new_individuals(self):
ts1 = msprime.sim_ancestry(2, sequence_length=1, random_seed=1)
ts2 = msprime.sim_ancestry(2, sequence_length=1, random_seed=2)
tables = ts2.dump_tables()
tables.edges.clear()
ts2 = tables.tree_sequence()
node_map = np.full(ts2.num_nodes, tskit.NULL)
node_map[0:2] = [0, 1] # map first two nodes to themselves
ts_merged = ts1.merge(ts2, node_mapping=node_map)
assert ts_merged.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
assert ts1.num_individuals == 2
assert ts_merged.num_individuals == 3

def test_popcheck(self):
tables = tskit.TableCollection(1)
p1 = tables.populations.add_row(b"foo")
p2 = tables.populations.add_row(b"bar")
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
ts1 = tables.tree_sequence()
tables.populations[0] = tables.populations[0].replace(metadata=b"baz")
ts2 = tables.tree_sequence()
with pytest.raises(ValueError, match="Non-matching populations"):
ts1.merge(ts2, node_mapping=[0, 1])
ts1.merge(ts2, node_mapping=[0, 1], check_populations=False)
# Check with add_populations=False
ts1.merge(ts2, node_mapping=[-1, 1]) # only merge the last one
with pytest.raises(ValueError, match="Non-matching populations"):
ts1.merge(ts2, node_mapping=[-1, 1], add_populations=False)

with pytest.raises(ValueError, match="Non-matching populations"):
ts1.simplify([0]).merge(ts2, node_mapping=[-1, 1])

def test_isolated_mutations(self):
tables = tskit.TableCollection(1)
u = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
s = tables.sites.add_row(0.5, "A")
tables.mutations.add_row(s, u, derived_state="T", time=1, metadata=b"xxx")
ts1 = tables.tree_sequence()
tables.mutations[0] = tables.mutations[0].replace(time=0.5, metadata=b"yyy")
ts2 = tables.tree_sequence()
ts_merge = ts1.merge(ts2, node_mapping=[0])
assert ts_merge.num_sites == 1
assert ts_merge.num_mutations == 2
assert ts_merge.mutation(0).time == 1
assert ts_merge.mutation(0).parent == tskit.NULL
assert ts_merge.mutation(0).metadata == b"xxx"
assert ts_merge.mutation(1).time == 0.5
assert ts_merge.mutation(1).parent == 0
assert ts_merge.mutation(1).metadata == b"yyy"

def test_identity(self):
tables = tskit.TableCollection(1)
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
ts = tables.tree_sequence()
ts_merge = ts.merge(ts, node_mapping=[0])
assert ts.equals(ts_merge, ignore_provenance=True)

@pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows")
def test_migrations(self):
pop_configs = [msprime.PopulationConfiguration(3) for _ in range(2)]
migration_matrix = [[0, 0.001], [0.001, 0]]
ts = msprime.simulate(
population_configurations=pop_configs,
migration_matrix=migration_matrix,
record_migrations=True,
recombination_rate=2,
random_seed=42, # pick a seed that gives min(migrations.left) > 0
end_time=100,
)
# No migration_table.squash() function exists, so we just try to cut on the
# LHS of all the migrations
assert ts.num_migrations > 0
assert ts.migrations_left.min() > 0
cutpoint = ts.migrations_left.min()
ts1 = ts.keep_intervals([[0, cutpoint]], simplify=False)
ts2 = ts.keep_intervals([[cutpoint, ts.sequence_length]], simplify=False)
ts_new = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes))
tables = ts_new.dump_tables()
tables.edges.squash()
tables.sort()
ts_new = tables.tree_sequence()
ts.tables.assert_equals(ts_new.tables, ignore_provenance=True)

def test_provenance(self):
tables = tskit.TableCollection(1)
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
ts = tables.tree_sequence()
ts_merge = ts.merge(ts, node_mapping=[0], record_provenance=False)
assert ts_merge.num_provenances == ts.num_provenances
ts_merge = ts.merge(ts, node_mapping=[0])
assert ts_merge.num_provenances == ts.num_provenances + 1
prov = json.loads(ts_merge.provenance(-1).record)
assert prov["parameters"]["command"] == "merge"
assert prov["parameters"]["node_mapping"] == [0]
assert prov["parameters"]["add_populations"] is True
assert prov["parameters"]["check_populations"] is True

def test_bad_sequence_length(self):
ts1 = tskit.TableCollection(1).tree_sequence()
ts2 = tskit.TableCollection(2).tree_sequence()
with pytest.raises(ValueError, match="sequence length"):
ts1.merge(ts2, node_mapping=[])

def test_bad_node_mapping(self):
ts = tskit.Tree.generate_comb(5).tree_sequence
with pytest.raises(ValueError, match="node_mapping"):
ts.merge(ts, node_mapping=[0, 1, 2])

def test_bad_populations(self):
tables = tskit.TableCollection(1)
tables = tskit.TableCollection(1)
p1 = tables.populations.add_row()
p2 = tables.populations.add_row()
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
ts2 = tables.tree_sequence()
ts1 = ts2.simplify([0, 1])
assert ts1.num_populations == 1
assert ts2.num_populations == 2
ts2.merge(ts1, [0, -1], check_populations=False, add_populations=False)
with pytest.raises(ValueError, match="population not present"):
ts1.merge(ts2, [0, -1, -1], check_populations=False, add_populations=False)


class TestConcatenate:
def test_simple(self):
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
ts1 = msprime.sim_mutations(ts1, rate=1, random_seed=1)
ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
ts2 = msprime.sim_mutations(ts2, rate=1, random_seed=1)
assert ts1.num_samples == ts2.num_samples
assert ts1.num_nodes != ts2.num_nodes
joint_ts = ts1.concatenate(ts2)
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
assert joint_ts.num_samples == ts1.num_samples
assert joint_ts.num_sites == ts1.num_sites + ts2.num_sites
assert joint_ts.num_mutations == ts1.num_mutations + ts2.num_mutations
ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim()
# Have to simplify here, to remove the redundant nodes
ts3.tables.assert_equals(ts1.tables, ignore_provenance=True)
assert ts3.equals(ts1.simplify(), ignore_provenance=True)
ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim()
assert ts4.equals(ts2.simplify(), ignore_provenance=True)
Expand Down Expand Up @@ -7183,6 +7392,13 @@ def test_empty(self):
assert ts.num_nodes == 0
assert ts.sequence_length == 40

def test_check_populations(self):
ts = msprime.sim_ancestry(2)
ts1 = ts.concatenate(ts, ts, check_populations=True)
assert ts1.num_populations == 1
ts2 = ts.concatenate(ts, ts, add_populations=True, check_populations=True)
assert ts2.num_populations == 3

def test_samples_at_end(self):
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
Expand All @@ -7200,22 +7416,58 @@ def test_internal_samples(self):
nodes_flags[:] = tskit.NODE_IS_SAMPLE
nodes_flags[-1] = 0 # Only root is not a sample
tables.nodes.flags = nodes_flags
ts = tables.tree_sequence()
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.5, random_seed=1)
assert ts.num_mutations > 0
assert ts.num_mutations > ts.num_sites
joint_ts = ts.concatenate(ts)
assert joint_ts.num_samples == ts.num_samples
assert joint_ts.num_nodes == ts.num_nodes + 1
assert joint_ts.num_mutations == ts.num_mutations * 2
assert joint_ts.num_sites == ts.num_sites * 2
assert joint_ts.sequence_length == ts.sequence_length * 2

def test_some_shared_samples(self):
ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence
ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence
shared = np.full(ts2.num_nodes, tskit.NULL)
shared[0] = 1
shared[1] = 0
joint_ts = ts1.concatenate(ts2, node_mappings=[shared])
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables()
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
ts1 = tables.tree_sequence()
tables = tskit.Tree.generate_balanced(5).tree_sequence.dump_tables()
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
ts2 = tables.tree_sequence()
assert ts1.num_samples == ts2.num_samples
joint_ts = ts1.concatenate(ts2)
assert joint_ts.num_samples == ts1.num_samples
assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges
for tree in joint_ts.trees():
assert tree.num_roots == 1

@pytest.mark.parametrize("simplify", [True, False])
def test_wf_sim(self, simplify):
# Test that we can split & concat a wf_sim ts, which has internal samples
tables = wf.wf_sim(
6,
5,
seed=3,
deep_history=True,
initial_generation_samples=True,
num_loci=10,
)
tables.sort()
tables.simplify()
ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234)
assert ts.num_trees > 2
assert len(np.unique(ts.nodes_time[ts.samples()])) > 1
ts1 = ts.keep_intervals([[0, 4.5]], simplify=False).trim()
ts2 = ts.keep_intervals([[4.5, ts.sequence_length]], simplify=False).trim()
if simplify:
ts1 = ts1.simplify(filter_nodes=False)
ts2, node_map = ts2.simplify(map_nodes=True)
node_mapping = np.zeros_like(node_map, shape=ts2.num_nodes)
kept = node_map != tskit.NULL
node_mapping[node_map[kept]] = np.arange(len(node_map))[kept]
else:
node_mapping = np.arange(ts.num_nodes)
ts_new = ts1.concatenate(ts2, node_mappings=[node_mapping]).simplify()
ts_new.tables.assert_equals(ts.tables, ignore_provenance=True)

def test_provenance(self):
ts = tskit.Tree.generate_comb(2).tree_sequence
Expand All @@ -7233,9 +7485,12 @@ def test_unequal_samples(self):
with pytest.raises(ValueError, match="must have the same number of samples"):
ts1.concatenate(ts2)

@pytest.mark.skip(
reason="union bug: https://github.com/tskit-dev/tskit/issues/3168"
)
def test_different_sample_numbers(self):
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
ts2 = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence
with pytest.raises(ValueError, match="must have the same number of samples"):
ts1.concatenate(ts2)

def test_duplicate_ts(self):
ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence
ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original
Expand Down
Loading
Loading