Skip to content

Commit

Permalink
Merge pull request #123 from jeromekelleher/ancestry-calc
Browse files Browse the repository at this point in the history
Initial draft of the mean ancestry calculation.
  • Loading branch information
jeromekelleher authored Sep 26, 2018
2 parents d06bed4 + 812c0bc commit 79521c3
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 30 deletions.
195 changes: 173 additions & 22 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import msprime
import numpy as np
import numpy.ma as ma

import tsinfer

Expand Down Expand Up @@ -843,38 +844,188 @@ def test_inferred_simplify(self):
self.verify(ts)


class TestNodeSpan(unittest.TestCase):
def simple_node_span(self, ts):
"""
Straightforward implementation of node span calculation by iterating over
over the trees and nodes in each tree.
"""
S = np.zeros(ts.num_nodes)
for tree in ts.trees():
length = tree.interval[1] - tree.interval[0]
for u in tree.nodes():
S[u] += length
return S

def verify(self, ts):
S1 = self.simple_node_span(ts)
S2 = tsinfer.node_span(ts)
self.assertEqual(S1.shape, S2.shape)
self.assertTrue(np.allclose(S1, S2))
self.assertTrue(np.all(S1 > 0))
self.assertTrue(np.all(S1 <= ts.sequence_length))
return S1

def test_single_locus(self):
ts = msprime.simulate(10, random_seed=1)
S = self.verify(ts)
self.assertTrue(np.all(S == 1))

def test_single_locus_sequence_length(self):
ts = msprime.simulate(10, length=100, random_seed=1)
S = self.verify(ts)
self.assertTrue(np.all(S == 100))

def test_simple_recombination(self):
ts = msprime.simulate(20, recombination_rate=5, random_seed=2)
self.assertGreater(ts.num_trees, 2)
S = self.verify(ts)
self.assertFalse(np.all(S == 1))

def test_simple_recombination_sequence_length(self):
ts = msprime.simulate(20, recombination_rate=5, length=10, random_seed=3)
self.assertGreater(ts.num_trees, 2)
S = self.verify(ts)
self.assertFalse(np.all(S == 10))

def test_inferred_no_simplify(self):
ts = msprime.simulate(10, recombination_rate=2, mutation_rate=10, random_seed=3)
samples = tsinfer.SampleData.from_tree_sequence(ts)
ts = tsinfer.infer(samples, simplify=False)
self.verify(ts)

def test_inferred(self):
ts = msprime.simulate(10, recombination_rate=2, mutation_rate=10, random_seed=3)
samples = tsinfer.SampleData.from_tree_sequence(ts)
ts = tsinfer.infer(samples)
self.verify(ts)

def test_inferred_random_data(self):
np.random.seed(10)
num_sites = 40
num_samples = 8
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.uint8)
with tsinfer.SampleData() as sample_data:
for j in range(num_sites):
sample_data.add_site(j, G[j])
ts = tsinfer.infer(sample_data)
self.verify(ts)


class TestMeanSampleAncestry(unittest.TestCase):
    """
    Tests the mean_sample_ancestry function.
    """

    def simple_mean_sample_ancestry(self, ts, sample_sets):
        """
        Straightforward implementation of mean sample ancestry by iterating
        over the trees and nodes in each tree.

        Returns an array of shape ``(len(sample_sets), ts.num_nodes)``.
        """
        S = tsinfer.node_span(ts)
        A = np.zeros((len(sample_sets), ts.num_nodes))
        for set_index, sample_set in enumerate(sample_sets):
            # Set everything to -1 to detect the trees in which the node is not
            # present. We use a numpy mask to exclude these below.
            A_pop = np.full((ts.num_nodes, ts.num_trees), -1.0)
            for tree in ts.trees(tracked_samples=sample_set):
                left, right = tree.interval
                for node in tree.nodes():
                    num_samples = tree.num_samples(node)
                    if num_samples > 0:
                        f = tree.num_tracked_samples(node) / num_samples
                        # Each fraction is weighted by the distance along
                        # this tree.
                        w = right - left
                        A_pop[node][tree.index] = f * w
            x = ma.array(A_pop, mask=A_pop < 0)
            # The final value for each node is the sum of its length-weighted
            # ancestry fractions over the trees that it was defined in,
            # divided by the span of that node.
            A[set_index] = np.sum(x, axis=1) / S
        return A

    def verify(self, ts, sample_sets):
        """
        Checks that tsinfer.mean_sample_ancestry agrees with the reference
        implementation for the specified sample sets, and returns the
        reference result.
        """
        A1 = self.simple_mean_sample_ancestry(ts, sample_sets)
        A2 = tsinfer.mean_sample_ancestry(ts, sample_sets)
        self.assertEqual(A1.shape, A2.shape)
        self.assertTrue(np.allclose(A1, A2))
        return A1

    def two_populations_high_migration_example(self):
        """
        Returns a simulated tree sequence with three samples in each of two
        populations and symmetric migration between them.
        """
        ts = msprime.simulate(
            population_configurations=[
                msprime.PopulationConfiguration(3),
                msprime.PopulationConfiguration(3)],
            migration_matrix=[[0, 1], [1, 0]],
            recombination_rate=1,
            mutation_rate=10,
            random_seed=5)
        self.assertGreater(ts.num_trees, 1)
        return ts

    def get_random_data_example(self, num_sites, num_samples, seed=100):
        """
        Returns a SampleData instance holding random binary genotypes for the
        specified number of sites and samples.
        """
        # Use a private generator rather than seeding the global numpy RNG so
        # that building an example does not alter global random state.
        rng = np.random.RandomState(seed)
        G = rng.randint(2, size=(num_sites, num_samples)).astype(np.uint8)
        with tsinfer.SampleData() as sample_data:
            for j in range(num_sites):
                sample_data.add_site(j, G[j])
        return sample_data

    def test_two_populations_high_migration(self):
        ts = self.two_populations_high_migration_example()
        A = self.verify(ts, [ts.samples(0), ts.samples(1)])
        total = np.sum(A, axis=0)
        self.assertTrue(np.allclose(total[total != 0], 1))

    @unittest.skip("TODO: we should probably be taking the fraction *within*")
    def test_two_populations_incomplete_samples(self):
        ts = self.two_populations_high_migration_example()
        samples = ts.samples()
        # NOTE(review): these two slices overlap (both contain the first two
        # samples); samples[-2:] may have been intended -- confirm before
        # enabling this test.
        A = self.verify(ts, [samples[:2], samples[:-2]])
        total = np.sum(A, axis=0)
        self.assertTrue(np.allclose(total[total != 0], 1))

    def test_two_populations_high_migration_inferred(self):
        ts = self.two_populations_high_migration_example()
        samples = tsinfer.SampleData.from_tree_sequence(ts)
        inferred_ts = tsinfer.infer(samples)
        self.assertEqual(inferred_ts.num_populations, ts.num_populations)
        self.verify(inferred_ts, [inferred_ts.samples(0), inferred_ts.samples(1)])

    def test_two_populations_high_migration_inferred_no_simplify(self):
        ts = self.two_populations_high_migration_example()
        samples = tsinfer.SampleData.from_tree_sequence(ts)
        inferred_ts = tsinfer.infer(samples, simplify=False)
        self.assertEqual(inferred_ts.num_populations, ts.num_populations)
        self.verify(inferred_ts, [inferred_ts.samples(0), inferred_ts.samples(1)])

    def test_random_data_inferred(self):
        n = 20
        samples = self.get_random_data_example(num_sites=52, num_samples=n)
        inferred_ts = tsinfer.infer(samples)
        samples = inferred_ts.samples()
        self.verify(inferred_ts, [samples[: n // 2], samples[n // 2:]])

    def test_random_data_inferred_no_simplify(self):
        samples = self.get_random_data_example(num_sites=20, num_samples=10)
        inferred_ts = tsinfer.infer(samples, simplify=False)
        samples = inferred_ts.samples()
        self.verify(inferred_ts, [samples[:5], samples[5:]])

    def test_many_groups(self):
        ts = msprime.simulate(32, recombination_rate=10, random_seed=4)
        samples = ts.samples()
        # Partition the samples into equal-sized groups, doubling the group
        # size each round until a single group holds every sample.
        group_size = 1
        while group_size <= ts.num_samples:
            sample_sets = [
                samples[j * group_size: (j + 1) * group_size]
                for j in range(ts.num_samples // group_size)]
            self.verify(ts, sample_sets)
            group_size *= 2
5 changes: 5 additions & 0 deletions tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def get_example_ts(self, sample_size, sequence_length):

def verify_data_round_trip(self, ts, input_file):
self.assertGreater(ts.num_sites, 1)
for pop in ts.populations():
input_file.add_population()
for sample in ts.samples():
node = ts.node(sample)
input_file.add_individual(ploidy=1, population=node.population)
for v in ts.variants():
input_file.add_site(v.site.position, v.genotypes, v.alleles)
input_file.record_provenance("verify_data_round_trip")
Expand Down
6 changes: 5 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,13 +1582,17 @@ def verify(self, samples):
t1.provenances.clear()
t2.provenances.clear()

# Population data isn't carried through in ancestors tree sequences
# for now.
t2.populations.clear()

self.assertEqual(t1.nodes, t2.nodes)
self.assertEqual(t1.edges, t2.edges)
self.assertEqual(t1.sites, t2.sites)
self.assertEqual(t1.mutations, t2.mutations)
self.assertEqual(t1.populations, t2.populations)
self.assertEqual(t1.sites, t2.sites)
self.assertEqual(t1.individuals, t2.individuals)
self.assertEqual(t1.sites, t2.sites)

self.assertEqual(t1, t2)

Expand Down
Loading

0 comments on commit 79521c3

Please sign in to comment.