From faae8e1df4312b679fe506c74aa7e80059315c4f Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sat, 6 Jan 2024 20:53:33 +0000 Subject: [PATCH] Use direct memory access for checks --- tsdate/core.py | 34 +++++++++++++++------------------- tsdate/prior.py | 8 +++----- tsdate/util.py | 6 +++--- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index 1dd36fce..b11be21e 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -180,12 +180,12 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0): if e.child not in self.fixednodes } else: - edges = self.ts.tables.edges fixed_nodes = np.array(list(self.fixednodes)) keys = np.unique( np.core.records.fromarrays( - (self.mut_edges, edges.right - edges.left), names="muts,span" - )[np.logical_not(np.isin(edges.child, fixed_nodes))] + (self.mut_edges, self.ts.edges_right - self.ts.edges_left), + names="muts,span", + )[np.logical_not(np.isin(self.ts.edges_child, fixed_nodes))] ) if unique_method == 1: self.unfixed_likelihood_cache = dict.fromkeys({tuple(t) for t in keys}) @@ -602,8 +602,8 @@ def __init__(self, priors, lik, *, progress=False): self.priors.force_probability_space(lik.probability_space) self.spans = np.bincount( - self.ts.tables.edges.child, - weights=self.ts.tables.edges.right - self.ts.tables.edges.left, + self.ts.edges_child, + weights=self.ts.edges_right - self.ts.edges_left, ) self.spans = np.pad(self.spans, (0, self.ts.num_nodes - len(self.spans))) @@ -653,15 +653,15 @@ def edges_by_child_then_parent_desc(self, grouped=True): """ wtype = np.dtype( [ - ("child_age", self.ts.tables.nodes.time.dtype), - ("child_node", self.ts.tables.edges.child.dtype), - ("parent_age", self.ts.tables.nodes.time.dtype), + ("child_age", self.ts.nodes_time.dtype), + ("child_node", self.ts.edges_child.dtype), + ("parent_age", self.ts.nodes_time.dtype), ] ) w = np.empty(self.ts.num_edges, dtype=wtype) - w["child_age"] = self.ts.tables.nodes.time[self.ts.tables.edges.child] - w["child_node"] = self.ts.tables.edges.child - w["parent_age"] = -self.ts.tables.nodes.time[self.ts.tables.edges.parent] + w["child_age"] = self.ts.nodes_time[self.ts.edges_child] + w["child_node"] = self.ts.edges_child + w["parent_age"] = -self.ts.nodes_time[self.ts.edges_parent] sorted_child_parent = ( self.ts.edge(i) for i in reversed( @@ -740,9 +740,7 @@ def inside_pass(self, *, standardize=True, cache_inside=False, progress=None): if standardize: marginal_lik = self.lik.combine(marginal_lik, denominator[parent]) if cache_inside: - self.g_i = self.lik.ratio( - g_i, denominator[self.ts.tables.edges.child, None] - ) + self.g_i = self.lik.ratio(g_i, denominator[self.ts.edges_child, None]) # Keep the results in this object self.inside = inside self.denominator = denominator @@ -791,7 +789,7 @@ def outside_pass( for child, edges in tqdm( self.edges_by_child_desc(), desc="Outside", - total=len(np.unique(self.ts.tables.edges.child)), + total=len(np.unique(self.ts.edges_child)), disable=not progress, ): if child in self.fixednodes: @@ -859,9 +857,7 @@ def outside_maximization(self, *, eps, progress=None): mut_edges = self.lik.mut_edges mrcas = np.where( - np.isin( - np.arange(self.ts.num_nodes), self.ts.tables.edges.child, invert=True - ) + np.isin(np.arange(self.ts.num_nodes), self.ts.edges_child, invert=True) )[0] for i in mrcas: if i not in self.fixednodes: @@ -870,7 +866,7 @@ def outside_maximization(self, *, eps, progress=None): for child, edges in tqdm( self.edges_by_child_then_parent_desc(), desc="Maximization", - total=len(np.unique(self.ts.tables.edges.child)), + total=len(np.unique(self.ts.edges_child)), disable=not progress, ): if child in self.fixednodes: diff --git a/tsdate/prior.py b/tsdate/prior.py index 72a7087e..e25ab3c9 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -467,7 +467,7 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False): self.ts = tree_sequence self.sample_node_set = set(self.ts.samples()) - if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0): + if np.any(self.ts.nodes_time[self.ts.samples()] != 0): raise ValueError( "The SpansBySamples class needs a tree seq with all samples at time 0" ) @@ -1032,7 +1032,7 @@ def fill_priors( # convert timepoints to generational timescale prior_times = base.NodeGridValues( ts.num_nodes, - datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32), + datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32), population_size.to_natural_timescale(timepoints), ) @@ -1167,9 +1167,7 @@ def make_parameter_grid(self, population_size, progress=False): prior_pars = base.NodeGridValues( self.tree_sequence.num_nodes, - datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype( - np.int32 - ), + datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32), np.array([0, np.inf]), ) prior_pars.probability_space = base.GAMMA_PAR diff --git a/tsdate/util.py b/tsdate/util.py index 5b99ae3f..228f8e43 100644 --- a/tsdate/util.py +++ b/tsdate/util.py @@ -39,7 +39,7 @@ def reduce_to_contemporaneous(ts): Simplify the ts to only the contemporaneous samples, and return the new ts + node map """ samples = ts.samples() - contmpr_samples = samples[ts.tables.nodes.time[samples] == 0] + contmpr_samples = samples[ts.nodes_time[samples] == 0] return ts.simplify( contmpr_samples, map_nodes=True, @@ -187,7 +187,7 @@ def nodes_time_unconstrained(tree_sequence): stored in the node metadata). Will produce an error if the tree sequence does not contain this information. """ - nodes_time = tree_sequence.tables.nodes.time.copy() + nodes_time = tree_sequence.nodes_time.copy() metadata = tree_sequence.tables.nodes.metadata metadata_offset = tree_sequence.tables.nodes.metadata_offset for index, met in enumerate(tskit.unpack_bytes(metadata, metadata_offset)): @@ -270,7 +270,7 @@ def sites_time_from_ts( e.args += "Try calling sites_time_from_ts() with unconstrained=False." raise else: - nodes_time = tree_sequence.tables.nodes.time + nodes_time = tree_sequence.nodes_time sites_time = np.full(tree_sequence.num_sites, np.nan) for tree in tree_sequence.trees():