Skip to content

Commit

Permalink
Merge pull request #356 from hyanwong/API-extras
Browse files — browse the repository at this point in the history
Use direct memory access for checks
  • Loading branch information
hyanwong committed Jan 6, 2024
2 parents 9b6b275 + faae8e1 commit 6da3a58
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 27 deletions.
34 changes: 15 additions & 19 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
if e.child not in self.fixednodes
}
else:
edges = self.ts.tables.edges
fixed_nodes = np.array(list(self.fixednodes))
keys = np.unique(
np.core.records.fromarrays(
(self.mut_edges, edges.right - edges.left), names="muts,span"
)[np.logical_not(np.isin(edges.child, fixed_nodes))]
(self.mut_edges, self.ts.edges_right - self.ts.edges_left),
names="muts,span",
)[np.logical_not(np.isin(self.ts.edges_child, fixed_nodes))]
)
if unique_method == 1:
self.unfixed_likelihood_cache = dict.fromkeys({tuple(t) for t in keys})
Expand Down Expand Up @@ -603,8 +603,8 @@ def __init__(self, priors, lik, *, progress=False):
self.priors.force_probability_space(lik.probability_space)

self.spans = np.bincount(
self.ts.tables.edges.child,
weights=self.ts.tables.edges.right - self.ts.tables.edges.left,
self.ts.edges_child,
weights=self.ts.edges_right - self.ts.edges_left,
)
self.spans = np.pad(self.spans, (0, self.ts.num_nodes - len(self.spans)))

Expand Down Expand Up @@ -654,15 +654,15 @@ def edges_by_child_then_parent_desc(self, grouped=True):
"""
wtype = np.dtype(
[
("child_age", self.ts.tables.nodes.time.dtype),
("child_node", self.ts.tables.edges.child.dtype),
("parent_age", self.ts.tables.nodes.time.dtype),
("child_age", self.ts.nodes_time.dtype),
("child_node", self.ts.edges_child.dtype),
("parent_age", self.ts.nodes_time.dtype),
]
)
w = np.empty(self.ts.num_edges, dtype=wtype)
w["child_age"] = self.ts.tables.nodes.time[self.ts.tables.edges.child]
w["child_node"] = self.ts.tables.edges.child
w["parent_age"] = -self.ts.tables.nodes.time[self.ts.tables.edges.parent]
w["child_age"] = self.ts.nodes_time[self.ts.edges_child]
w["child_node"] = self.ts.edges_child
w["parent_age"] = -self.ts.nodes_time[self.ts.edges_parent]
sorted_child_parent = (
self.ts.edge(i)
for i in reversed(
Expand Down Expand Up @@ -741,9 +741,7 @@ def inside_pass(self, *, standardize=True, cache_inside=False, progress=None):
if standardize:
marginal_lik = self.lik.combine(marginal_lik, denominator[parent])
if cache_inside:
self.g_i = self.lik.ratio(
g_i, denominator[self.ts.tables.edges.child, None]
)
self.g_i = self.lik.ratio(g_i, denominator[self.ts.edges_child, None])
# Keep the results in this object
self.inside = inside
self.denominator = denominator
Expand Down Expand Up @@ -792,7 +790,7 @@ def outside_pass(
for child, edges in tqdm(
self.edges_by_child_desc(),
desc="Outside",
total=len(np.unique(self.ts.tables.edges.child)),
total=len(np.unique(self.ts.edges_child)),
disable=not progress,
):
if child in self.fixednodes:
Expand Down Expand Up @@ -860,9 +858,7 @@ def outside_maximization(self, *, eps, progress=None):

mut_edges = self.lik.mut_edges
mrcas = np.where(
np.isin(
np.arange(self.ts.num_nodes), self.ts.tables.edges.child, invert=True
)
np.isin(np.arange(self.ts.num_nodes), self.ts.edges_child, invert=True)
)[0]
for i in mrcas:
if i not in self.fixednodes:
Expand All @@ -871,7 +867,7 @@ def outside_maximization(self, *, eps, progress=None):
for child, edges in tqdm(
self.edges_by_child_then_parent_desc(),
desc="Maximization",
total=len(np.unique(self.ts.tables.edges.child)),
total=len(np.unique(self.ts.edges_child)),
disable=not progress,
):
if child in self.fixednodes:
Expand Down
8 changes: 3 additions & 5 deletions tsdate/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):

self.ts = tree_sequence
self.sample_node_set = set(self.ts.samples())
if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
if np.any(self.ts.nodes_time[self.ts.samples()] != 0):
raise ValueError(
"The SpansBySamples class needs a tree seq with all samples at time 0"
)
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def fill_priors(
# convert timepoints to generational timescale
prior_times = base.NodeGridValues(
ts.num_nodes,
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32),
population_size.to_natural_timescale(timepoints),
)

Expand Down Expand Up @@ -1169,9 +1169,7 @@ def make_parameter_grid(self, population_size, progress=False):

prior_pars = base.NodeGridValues(
self.tree_sequence.num_nodes,
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(
np.int32
),
datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32),
np.array([0, np.inf]),
)
prior_pars.probability_space = base.GAMMA_PAR
Expand Down
6 changes: 3 additions & 3 deletions tsdate/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def reduce_to_contemporaneous(ts):
Simplify the ts to only the contemporaneous samples, and return the new ts + node map
"""
samples = ts.samples()
contmpr_samples = samples[ts.tables.nodes.time[samples] == 0]
contmpr_samples = samples[ts.nodes_time[samples] == 0]
return ts.simplify(
contmpr_samples,
map_nodes=True,
Expand Down Expand Up @@ -187,7 +187,7 @@ def nodes_time_unconstrained(tree_sequence):
stored in the node metadata). Will produce an error if the tree sequence does
not contain this information.
"""
nodes_time = tree_sequence.tables.nodes.time.copy()
nodes_time = tree_sequence.nodes_time.copy()
metadata = tree_sequence.tables.nodes.metadata
metadata_offset = tree_sequence.tables.nodes.metadata_offset
for index, met in enumerate(tskit.unpack_bytes(metadata, metadata_offset)):
Expand Down Expand Up @@ -270,7 +270,7 @@ def sites_time_from_ts(
e.args += "Try calling sites_time_from_ts() with unconstrained=False."
raise
else:
nodes_time = tree_sequence.tables.nodes.time
nodes_time = tree_sequence.nodes_time
sites_time = np.full(tree_sequence.num_sites, np.nan)

for tree in tree_sequence.trees():
Expand Down

0 comments on commit 6da3a58

Please sign in to comment.