Skip to content

Commit

Permalink
Merge pull request #131 from jeromekelleher/fix-individual-mappings
Browse files Browse the repository at this point in the history
Fixup individual ID refs for ancestors ts.
  • Loading branch information
jeromekelleher authored Oct 20, 2018
2 parents 34fc801 + 5d7b23b commit 15ee495
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 6 deletions.
60 changes: 59 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,6 @@ def verify(self, sample_data, ancestors_ts):
self.assertTrue(np.all(flags[non_samples] == source_flags[non_samples]))

def test_no_flags_changes(self):

ts = msprime.simulate(10, mutation_rate=2, recombination_rate=2, random_seed=233)
samples = tsinfer.SampleData.from_tree_sequence(ts)
ancestors = tsinfer.generate_ancestors(samples)
Expand All @@ -818,6 +817,65 @@ def test_append_nodes(self):
self.verify(samples, tables.tree_sequence())


class TestAncestorsTreeSequenceIndividuals(unittest.TestCase):
    """
    Tests that individuals already present in the ancestors tree sequence
    are kept, ahead of the sample data individuals, in the tree sequence
    returned by match_samples.
    """
    def verify(self, sample_data, ancestors_ts):
        """
        Runs match_samples (with simplify=False) on the specified sample data
        and ancestors tree sequence, and checks the individuals of the result:
        the ancestral individuals must come first and be equal to the
        originals, followed by the individuals from the sample data.
        """
        ts = tsinfer.match_samples(sample_data, ancestors_ts, simplify=False)
        expected_total = ancestors_ts.num_individuals + sample_data.num_individuals
        self.assertEqual(expected_total, ts.num_individuals)

        # A single iterator over the output individuals is shared by both
        # loops below: the first loop consumes the ancestral individuals,
        # the second loop sees whatever is left.
        remaining = ts.individuals()
        for ancestral in ancestors_ts.individuals():
            inferred = next(remaining)
            self.assertEqual(inferred, ancestral)
            # Nodes belonging to an ancestral individual must *not* be samples.
            for node_id in inferred.nodes:
                self.assertFalse(ts.node(node_id).is_sample())

        for inferred, original in zip(remaining, sample_data.individuals()):
            self.assertTrue(np.array_equal(inferred.location, original.location))
            # Metadata in the tree sequence is compared after decoding the
            # stored bytes as JSON.
            decoded = json.loads(inferred.metadata.decode())
            self.assertEqual(decoded, original.metadata)
            # Nodes belonging to a sample data individual *must* be samples.
            for node_id in inferred.nodes:
                self.assertTrue(ts.node(node_id).is_sample())

    def test_zero_individuals(self):
        """
        The ancestors tree sequence comes straight from match_ancestors.
        """
        ts = msprime.simulate(
            10, mutation_rate=2, recombination_rate=2, random_seed=233)
        samples = tsinfer.SampleData.from_tree_sequence(ts)
        ancestors_ts = tsinfer.match_ancestors(
            samples, tsinfer.generate_ancestors(samples))
        self.verify(samples, ancestors_ts)

    def test_diploid_individuals(self):
        """
        The ancestors tree sequence is built by eval_util.make_ancestors_ts
        from a simulated tree sequence in which pairs of nodes have been
        assigned to individuals.
        """
        simulated = msprime.simulate(
            10, mutation_rate=2, recombination_rate=2, random_seed=233)
        tables = simulated.dump_tables()
        num_pairs = simulated.num_samples // 2
        for j in range(num_pairs):
            tables.individuals.add_row(flags=j, location=[j, j], metadata=b"X" * j)
        # Nodes 2k and 2k + 1 are assigned to individual k; every other node
        # keeps the value -1 (presumably the null individual ID).
        node_individual = np.full(simulated.num_nodes, -1, dtype=np.int32)
        pair_ids = np.arange(num_pairs)
        node_individual[2 * pair_ids] = pair_ids
        node_individual[2 * pair_ids + 1] = pair_ids
        tables.nodes.set_columns(
            flags=tables.nodes.flags,
            time=tables.nodes.time,
            individual=node_individual)
        ts = tables.tree_sequence()
        # The sample data gets its own diploid individuals, with locations
        # and metadata that differ from the ones added to the tables above.
        with tsinfer.SampleData() as samples:
            for j in range(ts.num_samples // 2):
                samples.add_individual(ploidy=2, location=[100 * j], metadata={"X": j})
            for var in ts.variants():
                samples.add_site(var.site.position, var.genotypes)
        ancestors_ts = eval_util.make_ancestors_ts(samples, ts)
        self.verify(samples, ancestors_ts)


class AlgorithmsExactlyEqualMixin(object):
"""
For small example tree sequences, check that the Python and C implementations
Expand Down
17 changes: 17 additions & 0 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,17 @@ class Variant(object):
alleles = attr.ib()


@attr.s
class Individual(object):
    """
    A single individual, as yielded by SampleData.individuals().
    """
    # Zero-based position of the individual in iteration order.
    id = attr.ib()
    # The entry of the individuals_location array for this individual.
    location = attr.ib()
    # The entry of the individuals_metadata array for this individual.
    metadata = attr.ib()


class SampleData(DataContainer):
"""
SampleData(sequence_length=0, path=None, num_flush_threads=0, \
Expand Down Expand Up @@ -1244,6 +1255,12 @@ def haplotypes(self, samples=None, inference_sites=None):
yield index, a
j += 1

def individuals(self):
    """
    Returns an iterator over the individuals, as Individual objects.

    Each Individual pairs one entry of individuals_location with the
    corresponding entry of individuals_metadata; its id is the zero-based
    position of that pair.
    """
    locations = self.individuals_location[:]
    metadatas = self.individuals_metadata[:]
    for index, (location, metadata) in enumerate(zip(locations, metadatas)):
        yield Individual(index, location=location, metadata=metadata)


@attr.s
class Ancestor(object):
Expand Down
9 changes: 4 additions & 5 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,16 +777,15 @@ def get_samples_tree_sequence(self):
inference_sites = self.sample_data.sites_inference[:]
position = self.sample_data.sites_position[:]
tables = self.ancestors_ts.dump_tables()
num_ancestral_individuals = len(tables.individuals)

# Currently there's no information about populations etc stored in the
# ancestors ts.
for metadata in self.sample_data.populations_metadata[:]:
tables.populations.add_row(self.encode_metadata(metadata))
for location, metadata in zip(
self.sample_data.individuals_location[:],
self.sample_data.individuals_metadata[:]):
for ind in self.sample_data.individuals():
tables.individuals.add_row(
location=location, metadata=self.encode_metadata(metadata))
location=ind.location, metadata=self.encode_metadata(ind.metadata))

logger.debug("Adding tree sequence nodes")
flags, time = tsb.dump_nodes()
Expand All @@ -813,7 +812,7 @@ def get_samples_tree_sequence(self):
flags=flags[sample_id],
time=time[sample_id],
population=population,
individual=individual,
individual=num_ancestral_individuals + individual,
metadata=self.encode_metadata(metadata))
# Add in the remaining non-sample nodes.
for u in range(self.sample_ids[-1] + 1, tsb.num_nodes):
Expand Down

0 comments on commit 15ee495

Please sign in to comment.