Skip to content

Commit

Permalink
Binary tree resolutions and fast sampling (#84)
Browse files Browse the repository at this point in the history
* add binary support math and tests

* format

* add tests

* now works on non-multifurcating nodes

* add count binary resolved clades and test

* allow excluding leaves in binary resolutions

* faster sampling infrastructure WIP

* expand use of PreorderTreeBuilder class for sampling and indexing

* format and lint

* fix comb function in python 3.7
  • Loading branch information
willdumm authored Jan 30, 2024
1 parent ea3a068 commit a0514e1
Show file tree
Hide file tree
Showing 9 changed files with 618 additions and 44 deletions.
7 changes: 5 additions & 2 deletions historydag/beast_loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Provides utilities for parsing BEAST outputs and storing sampled trees in HistoryDag objects.
"""Provides utilities for parsing BEAST outputs and storing sampled trees in
HistoryDag objects.
This module uses ``dendropy`` to parse the newick strings found in BEAST output files, since
``ete3`` is incompatible with newick strings containing commas other than those which separate
nodes. The ``historydag`` package does not require ``dendropy``, so to use this module, you must
manually ensure that ``dendropy`` is installed in your environment."""
manually ensure that ``dendropy`` is installed in your environment.
"""

import historydag as hdag
from warnings import warn
Expand Down
14 changes: 9 additions & 5 deletions historydag/compact_genome.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""This module provides a CompactGenome class, intended as a convenient and compact representation of
a nucleotide sequence as a collection of mutations relative to a reference sequence. This object also
provides methods to conveniently mutate CompactGenome objects according to a list of mutations, produce
mutations defining the difference between two CompactGenome objects, and efficiently access the base
at a site (or the entire sequence, as a string) implied by a CompactGenome.
"""This module provides a CompactGenome class, intended as a convenient and
compact representation of a nucleotide sequence as a collection of mutations
relative to a reference sequence.
This object also provides methods to conveniently mutate CompactGenome
objects according to a list of mutations, produce mutations defining the
difference between two CompactGenome objects, and efficiently access the
base at a site (or the entire sequence, as a string) implied by a
CompactGenome.
"""

from frozendict import frozendict
Expand Down
293 changes: 268 additions & 25 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,106 @@ def convert(dag, newclass):
return newclass.from_history_dag(dag)


# Preorder tree creation class

# Placeholder type for whatever object a PreorderTreeBuilder implementation
# uses to represent a node of the tree it is building. It is deliberately
# unconstrained (``Any``): each builder picks the representation it needs.
TreeBuilderNode = Any


class PreorderTreeBuilder:
    """Interface for objects that assemble a sampled tree node by node.

    Any class implementing this interface can be used as a tree sample
    constructor in :meth:`HistoryDag.fast_sample`. The constructor signature
    of a subclass is unrestricted, because the user is responsible for
    creating the instances that are passed in for sampling.

    Subclasses must implement the following methods:

    Methods:
        add_node: Accepts a :class:`HistoryDagNode` object ``dag_node`` and,
            optionally, a TreeBuilderNode instance ``parent`` representing the
            parent of the node being added. Returns a TreeBuilderNode instance
            representing the added node in the sampled tree. TreeBuilderNode
            may be any type that is convenient for the subclass's internal
            implementation. This method can expect to be called on the nodes
            of a sampled tree in a pre-ordering, and a parent will always be
            provided unless ``dag_node`` is the root node.
        get_finished_tree: Takes no arguments and returns the data defining
            the sampled tree, after any necessary clean-up or final
            construction steps. Its return value is the return value of
            :meth:`HistoryDag.fast_sample`.
    """


class EteTreeBuilder(PreorderTreeBuilder):
    """Tree builder that assembles the sampled tree from ``ete3.TreeNode``
    objects.

    Args:
        name_func: Function mapping a DAG node to the ``name`` assigned to the
            corresponding ete3 node. By default every node is named
            ``"unnamed"``.
        features: Names of attributes of each DAG node's ``label`` that are
            copied onto the ete3 node as features of the same name. ``None``
            (default) or an empty list adds no label features.
        feature_funcs: Mapping from feature name to a function that computes
            the feature value from a DAG node. ``None`` (default) means no
            extra features. If a name also appears in ``features``, the label
            attribute takes precedence. The mapping passed in is not modified.
    """

    def __init__(
        self,
        name_func: Callable[[HistoryDagNode], str] = lambda n: "unnamed",
        features: Optional[List[str]] = None,
        feature_funcs: Optional[Mapping[str, Callable[[HistoryDagNode], str]]] = None,
    ):
        self.treeroot = None
        self.name_func = name_func

        # Work on a private copy so that neither the caller's mapping nor a
        # shared default object is mutated when label features are merged in.
        merged_funcs = dict(feature_funcs) if feature_funcs is not None else {}

        for feature in features or ():
            # Bind ``feature`` as a default argument: a plain closure would
            # look the loop variable up at call time and every generated
            # function would read the *last* feature name.
            def feature_func(node, feature=feature):
                return getattr(node.label, feature)

            merged_funcs[feature] = feature_func

        # Frozen into a tuple of (name, function) pairs for iteration in
        # add_node.
        self.feature_funcs = tuple(merged_funcs.items())

    def add_node(
        self,
        dag_node: HistoryDagNode,
        parent: ete3.TreeNode = None,
    ) -> ete3.TreeNode:
        """Create the ete3 node for ``dag_node`` and attach it below
        ``parent``.

        Returns the new ``ete3.TreeNode``, or ``None`` when ``dag_node`` is the
        UA node, which has no counterpart in the ete3 tree. A node added with
        ``parent=None`` becomes the tree root; only one such node is allowed.
        """
        # Skip the UA Node. Returning None means its child is subsequently
        # added with parent=None and therefore becomes the tree root.
        if isinstance(dag_node, UANode):
            return None
        newnode = ete3.TreeNode()
        newnode.name = self.name_func(dag_node)
        for feature, feature_func in self.feature_funcs:
            newnode.add_feature(feature, feature_func(dag_node))
        if parent is None:
            assert self.treeroot is None
            self.treeroot = newnode
        else:
            parent.add_child(child=newnode)
        return newnode

    def get_finished_tree(self):
        """Return the root ``ete3.TreeNode`` of the tree built so far (``None``
        if no root has been added)."""
        return self.treeroot


class PreorderHistoryBuilder(PreorderTreeBuilder):
    """Tree builder that produces a tree-shaped history DAG object.

    Nodes are recorded as empty copies of the DAG nodes passed to
    :meth:`add_node`; the edges between them are only attached when
    :meth:`get_finished_tree` is called.

    Args:
        dag_type: Callable (typically a HistoryDag class) applied to the copied
            root node to produce the finished history.
    """

    def __init__(
        self,
        dag_type,
    ):
        self.dag_type = dag_type
        self.edges = []
        self.root_node = None

    def add_node(
        self,
        dag_node: HistoryDagNode,
        parent: HistoryDagNode = None,
    ) -> HistoryDagNode:
        """Record an empty copy of ``dag_node`` and return that copy.

        With a ``parent``, the (parent, copy) edge is queued for later
        attachment. Without one, the copy becomes the root, and it is an error
        for a root to have been set already.
        """
        copied = dag_node.empty_copy()
        if parent is not None:
            self.edges.append((parent, copied))
            return copied
        assert self.root_node is None
        self.root_node = copied
        return copied

    def get_finished_tree(self):
        """Attach all queued edges, last-recorded first, and wrap the root
        node with ``dag_type``."""
        for source, target in self.edges[::-1]:
            source.add_edge(target)
        return self.dag_type(self.root_node)


class HistoryDag:
r"""An object to represent a collection of internally labeled trees. A
wrapper object to contain exposed HistoryDag methods and point to a
Expand Down Expand Up @@ -286,19 +386,55 @@ def __getitem__(self, key) -> "HistoryDag":
dag.trim_optimal_weight(**key)
return dag
elif isinstance(key, int):
# This call to count_histories is essential because we use
# the cached node counts later. For faster indexing, call this
# method once and call _get_subhistory_by_subid yourself.
length = self.count_histories()
if key < 0:
key = length + key
if not (key >= 0 and key < length):
raise IndexError
self.count_histories()
return self.__class__(self.dagroot._get_subhistory_by_subid(key))
builder = PreorderHistoryBuilder(type(self))
self.dagroot._get_subhistory_by_subid(key, builder)
return builder.get_finished_tree()
else:
raise TypeError(
f"History DAG indices must be integers or utils.HistoryDagFilter"
f" objects, not {type(key)}"
)

def get_histories_by_index(self, key_iterator, tree_builder_func=None):
    """Yield the history at each index produced by ``key_iterator``.

    Retrieving a single history by index is slow, since each retrieval
    requires running :meth:`HistoryDag.count_histories` on the entire DAG to
    populate node counts. This method instead counts histories a single time
    and then yields a history for each requested index.

    This is a generator: nothing is computed, and no error is raised, until
    it is iterated.

    Args:
        key_iterator: An iterator on desired history indices. May be consumable, as
            it will only be used once. Negative indices count from the end.
        tree_builder_func: A function accepting an index and returning a
            :class:`PreorderTreeBuilder` instance to be used to build the history
            with that index. The index passed is the normalized (non-negative)
            one. If None (default), then tree-shaped HistoryDag objects
            will be yielded using :class:`PreorderHistoryBuilder`.

    Raises:
        IndexError: When an index is out of range for the number of histories
            in this DAG.
    """
    if tree_builder_func is None:

        def tree_builder_func(idx):
            return PreorderHistoryBuilder(type(self))

    # Counting also populates the per-node counts used for index lookup.
    length = self.count_histories()

    for key in key_iterator:
        index = key + length if key < 0 else key
        if not 0 <= index < length:
            # Report the index as supplied by the caller, not the value
            # obtained after shifting a negative index.
            raise IndexError(
                f"Invalid index {key} in DAG containing {length} histories"
            )
        builder = tree_builder_func(index)
        self.dagroot._get_subhistory_by_subid(index, builder)
        yield builder.get_finished_tree()

def get_label_type(self) -> type:
"""Return the type for labels on this dag's nodes."""
return type(next(self.dagroot.children()).label)
Expand Down Expand Up @@ -680,17 +816,57 @@ def find_node(
except StopIteration:
raise ValueError("No matching node found.")

def fast_sample(
    self,
    tree_builder: PreorderTreeBuilder = None,
    log_probabilities=False,
):
    """Sample a history without recursion, building it with ``tree_builder``.

    This is a non-recursive alternative to :meth:`HistoryDag.sample`. It is
    likely to be slower on small DAGs, but may allow significant
    optimizations on large DAGs, or when the data format being sampled is
    something other than a tree-shaped HistoryDag object.

    Unlike :meth:`HistoryDag.sample`, there is no ``edge_selector`` argument.
    Any masking of edges should be done before sampling, using
    :meth:`HistoryDag.set_sample_mask` or by modifying the arguments to
    :meth:`HistoryDag.probability_annotate`.

    Args:
        tree_builder: a PreorderTreeBuilder instance to handle construction of the sampled tree.
            If None (default), a :class:`PreorderHistoryBuilder` for this DAG's type is used.
        log_probabilities: Whether edge probabilities annotated on this DAG (using, for example,
            :meth:`HistoryDag.probability_annotate`) are on a log-scale.

    Returns:
        Whatever ``tree_builder.get_finished_tree()`` returns.
    """
    builder = tree_builder
    if builder is None:
        builder = PreorderHistoryBuilder(type(self))

    # Stack of (DAG node, builder's representation of that node). A node is
    # always handed to the builder before any of its sampled children.
    pending = [(self.dagroot, builder.add_node(self.dagroot))]

    while pending:
        dag_parent, built_parent = pending.pop()
        # One child is sampled from the edge set below each child clade.
        for edge_set in dag_parent.clades.values():
            dag_child, _ = edge_set.sample(log_probabilities=log_probabilities)
            built_child = builder.add_node(dag_child, parent=built_parent)
            pending.append((dag_child, built_child))

    return builder.get_finished_tree()

def sample(
self, edge_selector=lambda e: True, log_probabilities=False
) -> "HistoryDag":
r"""Samples a history from the history DAG. (A history is a sub-history
DAG containing the root and all leaf nodes) For reproducibility, set
``random.seed`` before sampling.
When there is an option, edges pointing to nodes on which `selection_func` is True
When there is an option, edges pointing to nodes on which `edge_selector` is True
will always be chosen.
Returns a new HistoryDag object.
To use the more general sampling pattern which allows an arbitrary PreorderTreeBuilder
object, use :meth:`HistoryDag.fast_sample` instead.
"""
return self.__class__(
self.dagroot._sample(
Expand Down Expand Up @@ -1171,29 +1347,20 @@ def to_ete(
"""
# First build a dictionary of ete3 nodes keyed by HDagNodes.
if features is None:
labelfeatures = list(
list(self.dagroot.children())[0].label._asdict().keys()
)
else:
labelfeatures = features

def etenode(node: HistoryDagNode) -> ete3.TreeNode:
newnode = ete3.TreeNode()
newnode.name = name_func(node)
for feature in labelfeatures:
newnode.add_feature(feature, getattr(node.label, feature))
for feature, func in feature_funcs.items():
newnode.add_feature(feature, func(node))
return newnode

nodedict = {node: etenode(node) for node in self.preorder(skip_ua_node=True)}
features = list(list(self.dagroot.children())[0].label._asdict().keys())

for node in nodedict:
for target in node.children():
nodedict[node].add_child(child=nodedict[target])
tree_builder = EteTreeBuilder(
name_func=name_func, features=features, feature_funcs=feature_funcs
)
nodes_to_process = [(self.dagroot, tree_builder.add_node(self.dagroot))]
while len(nodes_to_process) > 0:
node, node_repr = nodes_to_process.pop()
for child in node.children():
nodes_to_process.append(
(child, tree_builder.add_node(child, parent=node_repr))
)

# Since self is cladetree, dagroot can have only one child
return nodedict[list(self.dagroot.children())[0]]
return tree_builder.get_finished_tree()

def to_graphviz(
self,
Expand Down Expand Up @@ -2949,6 +3116,36 @@ def edge_probabilities(

return {key: aggregate_func(val) for key, val in edge_probabilities.items()}

def set_sample_mask(self, edge_selector, log_probabilities=False):
    """Make masked edges ineligible for :meth:`HistoryDag.fast_sample` by
    overwriting their probabilities in place.

    This should be equivalent to passing the same ``edge_selector`` function
    to :meth:`HistoryDag.sample`.

    Args:
        edge_selector: A function accepting an edge (a tuple of HistoryDagNode objects) and
            returning True or False. An edge marked False will be ineligible for sampling, unless
            all other edges in the same edge set are also marked False.
        log_probabilities: Since the mask is applied by modifying edge probabilities, one must specify
            whether those probabilities are on a log scale.

    Take care to verify that you shouldn't instead use :meth:`HistoryDag.probability_annotate` with
    a choice of ``edge_weight_func`` that takes into account the masking preferences.
    """
    # A masked edge gets probability zero, which is -inf on a log scale.
    mask_value = float("-inf") if log_probabilities else 0

    for node in self.preorder():
        for eset in node.clades.values():
            keep = [edge_selector((node, target)) for target in eset.targets]
            # When every edge in the set is rejected, leave the set untouched.
            if not any(keep):
                continue
            for idx, keep_edge in enumerate(keep):
                if not keep_edge:
                    eset.probs[idx] = mask_value

def probability_annotate(
self,
edge_weight_func,
Expand Down Expand Up @@ -3696,7 +3893,7 @@ def history_dag_from_nodes(nodes: Sequence[HistoryDagNode]) -> HistoryDag:
ua_node = UANode(EdgeSet())
if ua_node in nodes:
ua_node = nodes[ua_node].empty_copy()
nodes.pop(ua_node)
nodes.pop(ua_node)
clade_dict = _clade_union_dict(nodes.keys())
edge_dict = {
node: [child for clade in node.clades for child in clade_dict[clade]]
Expand All @@ -3711,3 +3908,49 @@ def history_dag_from_nodes(nodes: Sequence[HistoryDagNode]) -> HistoryDag:
node.add_edge(child)

return HistoryDag(ua_node)


def make_binary_complete_dag(leaf_labels):
    """Produce a history DAG containing all binary topologies on the provided
    iterable of leaf labels.

    Every internal node carries a label of the same type as the leaf labels,
    with each field set to ``Ellipsis``.

    Raises:
        ValueError: When the first provided label is not a tuple instance.
    """
    leaf_labels = list(leaf_labels)
    exemplar = leaf_labels[0]
    if not isinstance(exemplar, tuple):
        raise ValueError(
            "Provided labels must be a historydag Label type (a typing.NamedTuple instance)"
        )
    # Internal nodes share one placeholder label with Ellipsis in every field.
    internal_label = type(exemplar)(*tuple(Ellipsis for _ in exemplar))

    nodes = {UANode(EdgeSet())}
    for clade in utils.powerset(leaf_labels, start_size=1):
        # Enumerate every way of splitting this clade into two child clades.
        member_count = len(clade)
        for child_mask in utils.powerset(
            range(member_count), start_size=1, end_size=member_count - 1
        ):
            chosen = set(child_mask)
            inside = frozenset(
                label for pos, label in enumerate(clade) if pos in chosen
            )
            outside = frozenset(
                label for pos, label in enumerate(clade) if pos not in chosen
            )
            # The node set absorbs the duplicate produced when the same split
            # is reached through the complementary mask.
            nodes.add(
                HistoryDagNode(
                    internal_label,
                    {inside: EdgeSet(), outside: EdgeSet()},
                    {},
                )
            )

    nodes.update(HistoryDagNode(label, {}, {}) for label in leaf_labels)
    return history_dag_from_nodes(nodes)
Loading

0 comments on commit a0514e1

Please sign in to comment.