Skip to content

Commit

Permalink
Calculate likelihoods for non-contempory nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Mar 3, 2020
1 parent ac19376 commit de1ae7f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 23 deletions.
29 changes: 26 additions & 3 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_larger_find_node_tip_weights(self):
self.verify_weights(ts)

def test_dangling_nodes_warn(self):
ts = utility_functions.single_tree_ts_n3_dangling()
ts = utility_functions.single_tree_ts_n2_dangling()
with self.assertLogs(level="WARNING") as log:
self.verify_weights(ts)
self.assertGreater(len(log.output), 0)
Expand Down Expand Up @@ -434,6 +434,17 @@ def test_simple_non_contemporaneous(self):
self.assertTrue(
np.allclose(mixture_prior[4, self.alpha_beta], [0.11111, 0.55555]))

def test_simulated_non_contemporaneous(self):
samples = [
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0)
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
self.get_mixture_prior_params(ts, 'lognorm')
self.get_mixture_prior_params(ts, 'gamma')


class TestPriorVals(unittest.TestCase):
def verify_prior_vals(self, ts, prior_distr):
Expand Down Expand Up @@ -490,6 +501,18 @@ def test_simple_non_contemporaneous(self):
prior_vals = self.verify_prior_vals(ts, 'gamma')
self.assertEqual(prior_vals.fixed_time(2), ts.node(2).time)

def test_simulated_non_contemporaneous(self):
samples = [
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0)
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
prior_vals = self.verify_prior_vals(ts, 'gamma')
print(prior_vals.timepoints)
raise


class TestLikelihoodClass(unittest.TestCase):
def poisson(self, l, x, normalize=True):
Expand Down Expand Up @@ -789,7 +812,7 @@ def test_nonmatching_prior_vs_lik_timepoints(self):

def test_nonmatching_prior_vs_lik_fixednodes(self):
ts1 = utility_functions.single_tree_ts_n3()
ts2 = utility_functions.single_tree_ts_n3_dangling()
ts2 = utility_functions.single_tree_ts_n2_dangling()
timepoints = np.array([0, 1.2, 2])
prior = tsdate.build_prior_grid(ts1, timepoints)
lls = Likelihoods(ts2, prior.timepoints)
Expand Down Expand Up @@ -901,7 +924,7 @@ def test_two_tree_mutation_ts(self):
self.assertTrue(np.allclose(algo.inside[5], np.array([0, 7.06320034e-11, 1])))

def test_dangling_fails(self):
ts = utility_functions.single_tree_ts_n3_dangling()
ts = utility_functions.single_tree_ts_n2_dangling()
print(ts.draw_text())
print("Samples:", ts.samples())
prior = tsdate.build_prior_grid(ts, timepoints=np.array([0, 1.2, 2]))
Expand Down
11 changes: 9 additions & 2 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_unary_warning(self):
self.assertEqual(len(log.output), 1)
self.assertIn("unary nodes", log.output[0])

def test_fails_with_recombination(self):
def test_fails_with_recombination_clock(self):
ts = utility_functions.two_tree_mutation_ts()
for probability_space in (LOG, LIN):
self.assertRaises(
Expand All @@ -58,6 +58,12 @@ def test_fails_with_recombination(self):
NotImplementedError, tsdate.date, ts, Ne=1, recombination_rate=1,
probability_space=probability_space, mutation_rate=1)

def test_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
theta = 2
ts = msprime.mutate(ts, rate=theta)
tsdate.date(ts, Ne=1, mutation_rate=theta, probability_space=LIN)

# def test_simple_ts_n2(self):
# ts = utility_functions.single_tree_ts_n2()
# dated_ts = tsdate.date(ts, Ne=10000)
Expand Down Expand Up @@ -209,7 +215,8 @@ def test_non_contemporaneous(self):
msprime.Sample(population=0, time=0),
msprime.Sample(population=0, time=1.0)
]
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2)
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
print(ts.draw_text())
self.assertRaises(NotImplementedError, tsdate.date, ts, 1, 2)

@unittest.skip("YAN to fix")
Expand Down
34 changes: 16 additions & 18 deletions tsdate/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,10 +1066,10 @@ def fill_prior(node_parameters, timepoints, ts, *, prior_distr, progress=False):
datable_nodes = np.ones(ts.num_nodes, dtype=bool)
datable_nodes[ts.samples()] = False
datable_nodes = np.where(datable_nodes)[0]
prior_times = NodeGridValues(
ts.num_nodes,
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
timepoints)
# Sort by time
datable_nodes = datable_nodes[
np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32)
prior_times = NodeGridValues(timepoints, gridnodes=datable_nodes)

# TO DO - this can probably be done in an single numpy step rather than a for loop
for node in tqdm(datable_nodes, desc="Assign Prior to Each Node",
Expand Down Expand Up @@ -1259,10 +1259,12 @@ def get_mut_lik_fixed_node(self, edge):

mutations_on_edge = self.mut_edges[edge.id]
child_time = self.ts.node(edge.child).time
assert child_time == 0
# Temporary hack - we should really take a more precise likelihood
return self._lik(mutations_on_edge, edge_span(edge), self.timediff, self.theta,
normalize=self.normalize)
timediff = self.timediff - child_time
mask = timediff > 0
lik = np.full(len(timediff), self.null_constant, dtype=FLOAT_DTYPE)
lik[mask] = self._lik(mutations_on_edge, edge_span(edge), timediff[mask],
self.theta, normalize=self.normalize)
return lik

def get_mut_lik_lower_tri(self, edge):
"""
Expand Down Expand Up @@ -1531,8 +1533,8 @@ class InOutAlgorithms:
Contains the inside and outside algorithms
"""
def __init__(self, prior, lik, *, progress=False):
if (lik.fixednodes.intersection(prior.nonfixed_nodes) or
len(lik.fixednodes) + len(prior.nonfixed_nodes) != lik.ts.num_nodes):
if (lik.fixednodes.intersection(prior.gridnodes) or
len(lik.fixednodes) + len(prior.gridnodes) != lik.ts.num_nodes):
raise ValueError(
"The prior and likelihood objects disagree on which nodes are fixed")
if not np.allclose(lik.timepoints, prior.timepoints):
Expand Down Expand Up @@ -1641,8 +1643,8 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
if np.ndim(inside_values) == 0 or np.all(np.isnan(inside_values)):
# Child appears fixed, or we have not visited it. Either our
# edge order is wrong (bug) or we have hit a dangling node
raise ValueError("The input tree sequence includes "
"dangling nodes: please simplify it")
raise ValueError("Node {} appears to be dangling: please "
"simplify the tree sequence".format(edge.child))
daughter_val = self.lik.scale_geometric(
spanfrac, self.lik.make_lower_tri(inside[edge.child]))
edge_lik = self.lik.get_inside(daughter_val, edge)
Expand Down Expand Up @@ -1834,7 +1836,8 @@ def build_prior_grid(tree_sequence, timepoints=20, *, approximate_prior=False,
time slices at which to evaluate node age.
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
undated
undated. Currently, only the samples at time 0 are used to create the conditional
coalescent prior.
:param int_or_array_like timepoints: The number of quantiles used to create the
time slices, or manually-specified time slices as a numpy array
:param bool approximate_prior: Whether to use a precalculated approximate prior or
Expand Down Expand Up @@ -1964,11 +1967,6 @@ def get_dates(
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
"""
# Stuff yet to be implemented. These can be deleted once fixed
for sample in tree_sequence.samples():
if tree_sequence.node(sample).time != 0:
raise NotImplementedError(
"Sample {} is not at time 0".format(sample))
fixed_nodes = set(tree_sequence.samples())

# Default to not creating approximate prior unless ts has > 1000 samples
Expand Down

0 comments on commit de1ae7f

Please sign in to comment.