Skip to content

Commit

Permalink
Merge pull request #127 from jeromekelleher/exclude-centromere
Browse files Browse the repository at this point in the history
Added method to snip out the centromeres of a tree sequence.
  • Loading branch information
jeromekelleher authored Oct 9, 2018
2 parents 79521c3 + b19cc18 commit 343e3c1
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 147 deletions.
216 changes: 181 additions & 35 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import msprime
import numpy as np
import numpy.ma as ma

import tsinfer

Expand Down Expand Up @@ -854,7 +853,8 @@ def simple_node_span(self, ts):
for tree in ts.trees():
length = tree.interval[1] - tree.interval[0]
for u in tree.nodes():
S[u] += length
if tree.num_samples(u) > 0:
S[u] += length
return S

def verify(self, ts):
Expand Down Expand Up @@ -922,50 +922,54 @@ def simple_mean_sample_ancestry(self, ts, sample_sets):
Straightforward implementation of mean sample ancestry by iterating
over the trees and nodes in each tree.
"""
S = tsinfer.node_span(ts)
A = np.zeros((len(sample_sets), ts.num_nodes))
for set_index in range(len(sample_sets)):
# Set everything to -1 to detect the trees in which the node is not
# present. We use a numpy mask to exclude these below.
A_pop = np.zeros((ts.num_nodes, ts.num_trees)) - 1
for tree in ts.trees(tracked_samples=sample_sets[set_index]):
left, right = tree.interval
for node in tree.nodes():
num_samples = tree.num_samples(node)
if num_samples > 0:
f = tree.num_tracked_samples(node) / num_samples
# Each fraction is weighted by the distance along this tree.
w = (right - left)
A_pop[node][tree.index] = f * w
x = ma.array(A_pop, mask=A_pop < 0)
# The final value for each node is the mean ancestry fraction for this
# population over the trees that it was defined in, divided by the span
# of that node.
A[set_index] = np.sum(x, axis=1) / S
S = np.zeros(ts.num_nodes)
tree_iters = [ts.trees(tracked_samples=sample_set) for sample_set in sample_sets]
for _ in range(ts.num_trees):
trees = [next(tree_iter) for tree_iter in tree_iters]
left, right = trees[0].interval
length = right - left
for node in trees[0].nodes():
total_samples = sum(tree.num_tracked_samples(node) for tree in trees)
if total_samples > 0:
for j, tree in enumerate(trees):
f = tree.num_tracked_samples(node) / total_samples
A[j, node] += f * length
S[node] += length

# The final value for each node is the mean ancestry fraction for this
# population over the trees that it was defined in, divided by the span
# of that node.
index = S != 0
A[:, index] /= S[index]
return A

def verify(self, ts, sample_sets):
    """
    Check tsinfer.mean_sample_ancestry against the simple implementation.

    Asserts that both implementations return arrays of the same shape and
    (approximately) the same values, and that the per-node totals over all
    sample sets are 1. When the sample sets do not cover every sample in
    the tree sequence, nodes with a total of zero are excluded from the
    totals check.

    :param ts: the tree sequence to evaluate.
    :param sample_sets: list of lists of sample node IDs.
    :return: the array computed by simple_mean_sample_ancestry.
    """
    expected = self.simple_mean_sample_ancestry(ts, sample_sets)
    observed = tsinfer.mean_sample_ancestry(ts, sample_sets)
    self.assertEqual(expected.shape, observed.shape)

    covers_all_samples = set(itertools.chain(*sample_sets)) == set(ts.samples())
    node_totals = np.sum(expected, axis=0)
    if not covers_all_samples:
        # Some nodes may have no tracked samples beneath them at all, in
        # which case their total stays at zero; only check the others.
        node_totals = node_totals[node_totals != 0]
    self.assertTrue(np.allclose(node_totals, 1))

    self.assertTrue(np.allclose(expected, observed))
    return expected

def two_populations_high_migration_example(self, mutation_rate=10):
    """
    Simulate a two-population example with migration between the populations.

    Two populations of 8 samples each are simulated with a symmetric
    migration matrix, a recombination rate of 3 and a fixed random seed, so
    the result is reproducible. The test fails if the simulation yields only
    a single tree.

    :param mutation_rate: mutation rate passed to msprime.simulate
        (defaults to 10; pass 0 for a tree sequence without mutations).
    :return: the simulated tree sequence.
    """
    ts = msprime.simulate(
        population_configurations=[
            msprime.PopulationConfiguration(8),
            msprime.PopulationConfiguration(8)],
        migration_matrix=[[0, 1], [1, 0]],
        recombination_rate=3,
        mutation_rate=mutation_rate,
        random_seed=5)
    # The callers rely on there being several trees along the sequence.
    self.assertGreater(ts.num_trees, 1)
    return ts
Expand All @@ -984,14 +988,51 @@ def test_two_populations_high_migration(self):
total = np.sum(A, axis=0)
self.assertTrue(np.allclose(total[total != 0], 1))

@unittest.skip("TODO: we should probably be taking the fraction *within*")
def test_two_populations_high_migration_no_centromere(self):
    """
    Mean sample ancestry on the two-population example after the region
    [0.4, 0.6) has been snipped out with tsinfer.snip_centromere.

    Currently skipped; see the TODO in the skip reason.
    """
    simulated = self.two_populations_high_migration_example(mutation_rate=0)
    snipped = tsinfer.snip_centromere(simulated, 0.4, 0.6)
    # Simplify the output to get rid of unreferenced nodes.
    simplified = snipped.simplify()
    population_samples = [simplified.samples(0), simplified.samples(1)]
    ancestry = self.verify(simplified, population_samples)
    node_totals = np.sum(ancestry, axis=0)
    nonzero_totals = node_totals[node_totals != 0]
    self.assertTrue(np.allclose(nonzero_totals, 1))

def test_span_zero_nodes(self):
    """
    Nodes that are not referenced by any edge must get zero ancestry.

    Two extra nodes are appended to the node table of a 10-sample
    simulation, and each of the samples 0..9 is used as its own sample set.
    """
    base_ts = msprime.simulate(10, random_seed=1)
    tables = base_ts.dump_tables()
    # Add in a few unreferenced nodes: one with flags=0 and one with
    # flags=1 (presumably the sample flag -- TODO confirm).
    first_extra = tables.nodes.add_row(flags=0, time=1234)
    second_extra = tables.nodes.add_row(flags=1, time=1234)
    ts = tables.tree_sequence()

    singleton_sets = [[sample] for sample in range(10)]
    expected = self.simple_mean_sample_ancestry(ts, singleton_sets)
    observed = tsinfer.mean_sample_ancestry(ts, singleton_sets)
    node_totals = np.sum(expected, axis=0)

    self.assertTrue(np.allclose(expected, observed))
    # Every node that existed before the extra rows were added sums to 1.
    self.assertTrue(np.allclose(node_totals[:first_extra], 1))
    for extra_node in (first_extra, second_extra):
        self.assertTrue(np.all(expected[:, extra_node] == 0))

def test_two_populations_incomplete_samples(self):
    """
    Mean sample ancestry when the sample sets cover only some of the samples.

    The two sets are the first two and the last two samples of the
    two-population example, so they are disjoint. For every node with a
    non-zero total, the ancestry fractions must sum to 1.
    """
    ts = self.two_populations_high_migration_example()
    samples = ts.samples()
    # samples[-2:] (the last two), not samples[:-2]: the latter would
    # overlap with samples[:2], and overlapping sets are expected to be
    # rejected with a ValueError.
    A = self.verify(ts, [samples[:2], samples[-2:]])
    total = np.sum(A, axis=0)
    self.assertTrue(np.allclose(total[total != 0], 1))

def test_single_tree_incomplete_samples(self):
    """
    Mean sample ancestry on a simulation without recombination, using two
    sample sets ([0, 1] and [2, 3]) that leave the other samples out.
    """
    tree_sequence = msprime.simulate(10, random_seed=1)
    partial_sets = [[0, 1], [2, 3]]
    ancestry = self.verify(tree_sequence, partial_sets)
    node_totals = np.sum(ancestry, axis=0)
    nonzero_totals = node_totals[node_totals != 0]
    self.assertTrue(np.allclose(nonzero_totals, 1))

def test_two_populations_overlapping_samples(self):
    """
    tsinfer.mean_sample_ancestry must raise ValueError when a sample ID is
    repeated, either across sample sets or within a single set.
    """
    ts = self.two_populations_high_migration_example()
    invalid_sample_sets = (
        [[1], [1]],      # same sample in two different sets
        [[1, 1], [2]],   # same sample twice within one set
    )
    for sample_sets in invalid_sample_sets:
        self.assertRaises(
            ValueError, tsinfer.mean_sample_ancestry, ts, sample_sets)

def test_two_populations_high_migration_inferred(self):
ts = self.two_populations_high_migration_example()
samples = tsinfer.SampleData.from_tree_sequence(ts)
Expand All @@ -1014,10 +1055,10 @@ def test_random_data_inferred(self):
self.verify(inferred_ts, [samples[: n // 2], samples[n // 2:]])

def test_random_data_inferred_no_simplify(self):
    """
    Mean sample ancestry on a tree sequence inferred from random data
    without simplification.

    Three samples are generated over 20 sites; the first sample forms one
    sample set and the remaining samples form the other.
    """
    samples = self.get_random_data_example(num_sites=20, num_samples=3)
    inferred_ts = tsinfer.infer(samples, simplify=False)
    samples = inferred_ts.samples()
    # Split so that both sets are non-empty for the three samples.
    self.verify(inferred_ts, [samples[:1], samples[1:]])

def test_many_groups(self):
ts = msprime.simulate(32, recombination_rate=10, random_seed=4)
Expand All @@ -1029,3 +1070,108 @@ def test_many_groups(self):
for j in range(ts.num_samples // group_size)]
self.verify(ts, sample_sets)
group_size *= 2


class TestSnipCentromere(unittest.TestCase):
    """
    Tests that tsinfer.snip_centromere removes the requested region from a
    tree sequence, by comparing it against a simple reference implementation.
    """

    def snip_centromere(self, ts, left, right):
        """
        Reference implementation: rebuild the edge table so that no edge
        covers any part of the interval [left, right).

        Edges entirely outside the interval are copied unchanged; edges that
        overlap it are replaced by the parts (if any) lying to its left and
        to its right.
        """
        assert 0 < left < right < ts.sequence_length
        tables = ts.dump_tables()
        tables.edges.clear()
        for edge in ts.edges():
            if edge.left >= right or edge.right <= left:
                # No overlap with the snipped interval: keep as is.
                pieces = [(edge.left, edge.right)]
            else:
                pieces = []
                if edge.left < left:
                    pieces.append((edge.left, left))
                if right < edge.right:
                    pieces.append((right, edge.right))
            for piece_left, piece_right in pieces:
                tables.edges.add_row(
                    piece_left, piece_right, edge.parent, edge.child)
        tables.sort()
        return tables.tree_sequence()

    def verify(self, ts, left, right):
        """
        Snip [left, right) out of ts with both implementations and check that
        the resulting tables are equal once provenances are cleared.

        Also checks that the reference output contains a tree whose interval
        is exactly (left, right), in which every node has no parent.

        Returns the tree sequence produced by the reference implementation.
        """
        reference_ts = self.snip_centromere(ts, left, right)
        library_ts = tsinfer.snip_centromere(ts, left, right)
        reference_tables = reference_ts.dump_tables()
        library_tables = library_ts.dump_tables()
        # Provenance records are not part of the comparison.
        for tables in (reference_tables, library_tables):
            tables.provenances.clear()
        self.assertEqual(reference_tables, library_tables)

        snipped_interval = (left, right)
        gap_tree_seen = False
        for tree in reference_ts.trees():
            if tree.interval == snipped_interval:
                gap_tree_seen = True
                for node in reference_ts.nodes():
                    self.assertEqual(tree.parent(node.id), msprime.NULL_NODE)
                break
        self.assertTrue(gap_tree_seen)
        return reference_ts

    def test_single_tree(self):
        """Snipping the middle of a single tree gives three trees."""
        original = msprime.simulate(10, random_seed=1)
        snipped = self.verify(original, 0.5, 0.6)
        self.assertEqual(snipped.num_trees, 3)

    def test_many_trees(self):
        """Snipping works on a simulation with more than two trees."""
        original = msprime.simulate(
            10, length=10, recombination_rate=1, random_seed=1)
        self.assertGreater(original.num_trees, 2)
        self.verify(original, 5, 6)

    def get_random_data_example(self, position, num_samples, seed=100):
        """
        Build a tsinfer.SampleData with one site per entry of ``position``
        and random 0/1 genotypes for ``num_samples`` samples.

        The global numpy random state is seeded with ``seed`` first.
        """
        np.random.seed(seed)
        num_sites = position.shape[0]
        genotypes = np.random.randint(2, size=(num_sites, num_samples))
        genotypes = genotypes.astype(np.uint8)
        with tsinfer.SampleData() as sample_data:
            for site_position, site_genotypes in zip(position, genotypes):
                sample_data.add_site(site_position, site_genotypes)
        return sample_data

    def test_random_data_inferred_no_simplify(self):
        """Snipping an unsimplified inferred tree sequence keeps genotypes."""
        positions = 10 * np.arange(10)
        samples = self.get_random_data_example(positions, num_samples=10, seed=2)
        inferred_ts = tsinfer.infer(samples, simplify=False)
        snipped = self.verify(inferred_ts, 55, 57)
        snipped_genotypes = snipped.genotype_matrix()
        self.assertTrue(
            np.array_equal(snipped_genotypes, inferred_ts.genotype_matrix()))

    def test_random_data_inferred_simplify(self):
        """Snipping a simplified inferred tree sequence keeps genotypes."""
        positions = 5 * np.arange(10)
        samples = self.get_random_data_example(positions, num_samples=10, seed=2)
        inferred_ts = tsinfer.infer(samples, simplify=True)
        snipped = self.verify(inferred_ts, 12, 15)
        snipped_genotypes = snipped.genotype_matrix()
        self.assertTrue(
            np.array_equal(snipped_genotypes, inferred_ts.genotype_matrix()))

    def test_coordinate_errors(self):
        """Out-of-range, empty or reversed intervals raise ValueError."""
        ts = msprime.simulate(2, length=10, recombination_rate=1, random_seed=1)
        bad_intervals = [(-1, 5), (0, 5), (1, 10), (1, 11), (6, 5), (5, 5)]
        for left, right in bad_intervals:
            self.assertRaises(
                ValueError, tsinfer.snip_centromere, ts, left, right)

    def test_position_errors(self):
        """Intervals that would swallow a site position raise ValueError."""
        ts = msprime.simulate(
            2, length=10, recombination_rate=1, random_seed=1, mutation_rate=2)
        site_positions = ts.tables.sites.position
        self.assertGreater(site_positions.shape[0], 3)
        first = site_positions[0]
        third = site_positions[2]
        bad_intervals = [
            # Left cannot be on a site position.
            (first, first + 0.001),
            # Cannot go either side of a position
            (first - 0.001, first + 0.001),
            # Cannot cover multiple positions
            (first - 0.001, third + 0.001),
        ]
        for left, right in bad_intervals:
            self.assertRaises(
                ValueError, tsinfer.snip_centromere, ts, left, right)

    def test_right_on_position(self):
        """The right end of the interval may coincide with a site position."""
        original = msprime.simulate(
            2, length=10, recombination_rate=1, random_seed=1, mutation_rate=2)
        site_positions = original.tables.sites.position
        self.assertGreater(site_positions.shape[0], 1)
        first = site_positions[0]
        snipped = self.verify(original, first - 0.001, first)
        original_genotypes = original.genotype_matrix()
        self.assertTrue(
            np.array_equal(original_genotypes, snipped.genotype_matrix()))
Loading

0 comments on commit 343e3c1

Please sign in to comment.