Skip to content

Commit

Permalink
Switch to only storing fixed node data in Prior
Browse files Browse the repository at this point in the history
Prior is now a subclass of NodeGridValues, with the ability to store a time (*not* a likelihood) for fixed nodes.
  • Loading branch information
hyanwong committed Feb 29, 2020
1 parent 6c297ef commit c6ae1a6
Show file tree
Hide file tree
Showing 4 changed files with 317 additions and 211 deletions.
181 changes: 106 additions & 75 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
import tsinfer

import tsdate
from tsdate.date import (SpansBySamples, PriorParams,
from tsdate.date import (SpansBySamples, PriorParams, LIN, LOG,
ConditionalCoalescentTimes, fill_prior, Likelihoods,
LogLikelihoods, LogLikelihoodsStreaming, InOutAlgorithms,
NodeGridValues, gamma_approx, constrain_ages_topo) # NOQA
Prior, gamma_approx, constrain_ages_topo) # NOQA

from tests import utility_functions

Expand Down Expand Up @@ -192,6 +192,13 @@ def test_dangling_nodes_fail(self):
tables.nodes.flags = flags
self.assertRaises(ValueError, self.verify_weights, tables.tree_sequence())

def test_simple_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
n = len([s for s in ts.samples() if ts.node(s).time == 0])
span_data = self.verify_weights(ts)
self.assertEqual(span_data.lookup_weight(4, n, 2), 0.2) # 2 contemporanous tips
self.assertEqual(span_data.lookup_weight(4, n, 1), 0.8) # only 1 contemporanous

@unittest.skip("YAN to fix")
def test_truncated_nodes(self):
Ne = 1e2
Expand Down Expand Up @@ -334,9 +341,10 @@ class TestMixturePrior(unittest.TestCase):
def get_mixture_prior_params(self, ts, prior_distr):
span_data = SpansBySamples(ts)
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
priors.add(ts.num_samples, approximate=False)
for total_fixed in span_data.total_fixed_at_0_counts:
priors.add(total_fixed, approximate=False)
mixture_prior = priors.get_mixture_prior_params(span_data)
return(mixture_prior)
return mixture_prior

def test_one_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
Expand Down Expand Up @@ -417,12 +425,19 @@ def test_two_tree_mutation_ts(self):
self.assertTrue(
np.allclose(mixture_prior[5, self.alpha_beta], [1.6, 1.2]))

def test_simple_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
mixture_prior = self.get_mixture_prior_params(ts, 'gamma')
self.assertTrue(
np.allclose(mixture_prior[4, self.alpha_beta], [0.11111, 0.55555]))


class TestPriorVals(unittest.TestCase):
def verify_prior_vals(self, ts, prior_distr):
span_data = SpansBySamples(ts)
priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr)
priors.add(ts.num_samples, approximate=False)
for total_fixed in span_data.total_fixed_at_0_counts:
priors.add(total_fixed, approximate=False)
grid = np.linspace(0, 3, 3)
mixture_prior = priors.get_mixture_prior_params(span_data)
nodes_to_date = span_data.nodes_to_date
Expand Down Expand Up @@ -469,6 +484,11 @@ def test_tree_with_unary_nodes(self):
self.assertTrue(np.allclose(prior_vals[4], [0, 1, 0.093389]))
self.assertTrue(np.allclose(prior_vals[3], [0, 1, 0.011109]))

def test_simple_non_contemporaneous(self):
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
prior_vals = self.verify_prior_vals(ts, 'gamma')
self.assertEqual(prior_vals.fixed_time(2), ts.node(2).time)


class TestLikelihoodClass(unittest.TestCase):
def poisson(self, l, x, normalize=True):
Expand Down Expand Up @@ -669,103 +689,114 @@ def test_logsumexp_streaming(self):
np.log(ll_sum)))


class TestNodeGridValuesClass(unittest.TestCase):
# TODO - needs a few more tests in here
class TestPriorClass(unittest.TestCase):
def test_init(self):
num_nodes = 5
ids = np.array([3, 4])
nodetimes = np.ones(5)
nonfixed_ids = np.array([3, 2])
timepoints = np.array(range(10))
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=6)
self.assertEquals(store.grid_data.shape, (len(ids), len(timepoints)))
self.assertEquals(len(store.fixed_data), (num_nodes-len(ids)))
store = Prior(
timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids, fill_value=6)
self.assertEquals(store.grid_data.shape, (len(nonfixed_ids), len(timepoints)))
self.assertEquals(len(store.fixed_times), (len(nodetimes)-len(nonfixed_ids)))
self.assertTrue(np.all(store.grid_data == 6))
self.assertTrue(np.all(store.fixed_data == 6))
self.assertTrue(np.all(store.fixed_times == 1))
for i in range(len(nodetimes)):
if i in nonfixed_ids:
self.assertTrue(np.all(store[i] == 6))
self.assertRaises(IndexError, store.fixed_time, i)
else:
self.assertEqual(store.fixed_time(i), 1)
with self.assertRaises(IndexError):
_ = store[i]

ids = np.array([3, 4], dtype=np.int32)
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=5)
self.assertEquals(store.grid_data.shape, (len(ids), len(timepoints)))
self.assertEquals(len(store.fixed_data), num_nodes-len(ids))
self.assertTrue(np.all(store.fixed_data == 5))
def test_probability_spaces(self):
nodetimes = np.ones(5)
nonfixed_ids = np.array([3, 4])
timepoints = np.array(range(10))
store = Prior(
timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids, fill_value=0.5)
self.assertTrue(np.all(store.grid_data == 0.5))
store.force_probability_space(LIN)
self.assertTrue(np.all(store.grid_data == 0.5))
store.force_probability_space(LOG)
self.assertTrue(np.allclose(store.grid_data, np.log(0.5)))
store.force_probability_space(LOG)
self.assertTrue(np.allclose(store.grid_data, np.log(0.5)))
store.force_probability_space(LIN)
self.assertTrue(np.all(store.grid_data == 0.5))
self.assertRaises(ValueError, store.force_probability_space, "foobar")

def test_set_and_get(self):
num_nodes = 5
grid_size = 2
nodetimes = np.ones(5)
timepoints = [0, 1.1]
fill = {}
for ids in ([3, 4], []):
for nonfixed_ids in ([3, 4], [0]):
np.random.seed(1)
store = NodeGridValues(
num_nodes, np.array(ids, dtype=np.int32), np.array(range(grid_size)))
for i in range(num_nodes):
fill[i] = np.random.random(grid_size if i in ids else None)
store[i] = fill[i]
for i in range(num_nodes):
store = Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
for i in range(len(nodetimes)):
fill[i] = np.random.random(len(store.timepoints))
if i in nonfixed_ids:
store[i] = fill[i]
else:
with self.assertRaises(IndexError):
store[i] = fill[i]
for i in nonfixed_ids:
self.assertTrue(np.all(fill[i] == store[i]))
self.assertRaises(IndexError, store.__getitem__, num_nodes)

def test_bad_init(self):
ids = [3, 4]
self.assertRaises(ValueError, NodeGridValues, 3, np.array(ids),
np.array([0, 1.2, 2]))
self.assertRaises(AttributeError, NodeGridValues, 5, np.array(ids), -1)
self.assertRaises(ValueError, NodeGridValues, 5, np.array([-1]),
np.array([0, 1.2, 2]))
timepoints = [0, 1.2, 2]
nodetimes = np.ones(5)
nonfixed_ids = [4, 0]
Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
# ids > nodetimes
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[4, 5])
# duplicate ids
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[4, 4, 0])
# bad ids
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=nodetimes,
gridnodes=np.array([[1, 4], [2, 0]]))
self.assertRaises(
OverflowError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[-1, 4])
# bad timepoint
self.assertRaises(
ValueError, Prior, [], nodetimes=nodetimes, gridnodes=nonfixed_ids)
# bad nodetimes
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=[], gridnodes=nonfixed_ids)

def test_clone(self):
num_nodes = 10
grid_size = 2
ids = [3, 4]
orig = NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size)))
timepoints = [0, 1]
nodetimes = np.ones(5)
nonfixed_ids = [3, 4]
orig = Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
orig[3] = np.array([1, 2])
orig[4] = np.array([4, 3])
orig[0] = 1.5
orig[9] = 2.5
# test with np.zeros
clone = NodeGridValues.clone_with_new_data(orig, 0)
clone = orig.clone_grid_with_new_data(0)
self.assertEquals(clone.grid_data.shape, orig.grid_data.shape)
self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape)
self.assertTrue(np.all(clone.grid_data == 0))
self.assertTrue(np.all(clone.fixed_data == 0))
# test with something else
clone = NodeGridValues.clone_with_new_data(orig, 5)
clone = orig.clone_grid_with_new_data(5)
self.assertEquals(clone.grid_data.shape, orig.grid_data.shape)
self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape)
self.assertTrue(np.all(clone.grid_data == 5))
self.assertTrue(np.all(clone.fixed_data == 5))
# test with different
scalars = np.arange(num_nodes - len(ids))
clone = NodeGridValues.clone_with_new_data(orig, 0, scalars)
self.assertEquals(clone.grid_data.shape, orig.grid_data.shape)
self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape)
self.assertTrue(np.all(clone.grid_data == 0))
self.assertTrue(np.all(clone.fixed_data == scalars))

clone = NodeGridValues.clone_with_new_data(
orig, np.array([[1, 2], [4, 3]]))
for i in range(num_nodes):
if i in ids:
self.assertTrue(np.all(clone[i] == orig[i]))
else:
self.assertTrue(np.isnan(clone[i]))
clone = NodeGridValues.clone_with_new_data(
orig, np.array([[1, 2], [4, 3]]), 0)
for i in range(num_nodes):
if i in ids:
clone = orig.clone_grid_with_new_data(np.array([[1, 2], [4, 3]]))
for i in range(len(nodetimes)):
if i in nonfixed_ids:
self.assertTrue(np.all(clone[i] == orig[i]))
else:
self.assertEquals(clone[i], 0)
self.assertRaises(IndexError, clone.__getitem__, i)

def test_bad_clone(self):
num_nodes = 10
ids = [3, 4]
orig = NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
self.assertRaises(
ValueError,
NodeGridValues.clone_with_new_data,
orig, np.array([[1, 2, 3], [4, 5, 6]]))
nodetimes = np.zeros(10)
ids = np.array([3, 4])
timepoints = np.array([0, 1.2])
orig = Prior(timepoints, nodetimes=nodetimes, gridnodes=ids)
self.assertRaises(
ValueError,
NodeGridValues.clone_with_new_data,
orig, 0, np.array([[1, 2], [4, 5]]))
ValueError, orig.clone_grid_with_new_data, np.array([[1, 2, 3], [4, 5, 6]]))


class TestInsideAlgorithm(unittest.TestCase):
Expand Down
9 changes: 2 additions & 7 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ class TestPrebuilt(unittest.TestCase):
Tests for tsdate on prebuilt tree sequences
"""
def test_dangling_failure(self):
ts = utility_functions.single_tree_ts_n3()
# Mark node 0 as a non-sample node, which should make it dangling
tables = ts.dump_tables()
flags = tables.nodes.flags
flags[0] = flags[0] & (~tskit.NODE_IS_SAMPLE)
tables.nodes.flags = flags
self.assertRaises(ValueError, tsdate.date, tables.tree_sequence(), Ne=1)
ts = utility_functions.single_tree_ts_n2_dangling()
self.assertRaises(ValueError, tsdate.date, ts, Ne=1)

def test_unary_warning(self):
with self.assertLogs(level="WARNING") as log:
Expand Down
Loading

0 comments on commit c6ae1a6

Please sign in to comment.