Skip to content

Commit

Permalink
Wd-determinism (#22)
Browse files Browse the repository at this point in the history
* WIP

* A tentative fix

* remove testing comments

* discard find_trees changes

* format and lint

* revert to slow count_topologies for gctree
  • Loading branch information
willdumm authored Jul 5, 2022
1 parent 22e7e9c commit aa251c5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
61 changes: 47 additions & 14 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ def __eq__(self, other: object) -> bool:
else:
raise NotImplementedError

def __le__(self, other: object) -> bool:
    """Less-than-or-equal comparison between two nodes.

    Nodes are ordered by the tuple ``(label, sorted_partitions())``.
    Raises NotImplementedError if ``other`` is not a HistoryDagNode."""
    if not isinstance(other, HistoryDagNode):
        raise NotImplementedError
    # Build both keys in the same left-to-right order before comparing.
    own_key = (self.label, self.sorted_partitions())
    their_key = (other.label, other.sorted_partitions())
    return own_key <= their_key

def __lt__(self, other: object) -> bool:
    """Strict less-than comparison between two nodes.

    Nodes are ordered by the tuple ``(label, sorted_partitions())``.
    Raises NotImplementedError if ``other`` is not a HistoryDagNode."""
    if not isinstance(other, HistoryDagNode):
        raise NotImplementedError
    # Build both keys in the same left-to-right order before comparing.
    own_key = (self.label, self.sorted_partitions())
    their_key = (other.label, other.sorted_partitions())
    return own_key < their_key

def __gt__(self, other: object) -> bool:
    """Strict greater-than comparison, computed as the negation of
    ``self.__le__(other)``."""
    at_most_other = self.__le__(other)
    return not at_most_other

def __ge__(self, other: object) -> bool:
    """Greater-than-or-equal comparison, computed as the negation of
    ``self.__lt__(other)``."""
    strictly_below_other = self.__lt__(other)
    return not strictly_below_other

def node_self(self) -> "HistoryDagNode":
"""Returns a HistoryDagNode object with the same clades and label, but
no descendant edges."""
Expand Down Expand Up @@ -94,6 +118,11 @@ def partitions(self) -> frozenset:
frozenset if this node is a UANode."""
return frozenset(self.clades.keys())

def sorted_partitions(self) -> tuple:
    """Returns a canonical form of the node's child clades: each clade
    becomes a sorted tuple of its leaf labels, and those tuples are
    themselves sorted and returned as a tuple."""
    canonical_clades = [
        tuple(sorted(leaf_labels)) for leaf_labels in self.clades.keys()
    ]
    canonical_clades.sort()
    return tuple(canonical_clades)

def children(
self, clade: Set[Label] = None
) -> Generator["HistoryDagNode", None, None]:
Expand Down Expand Up @@ -342,22 +371,21 @@ def __getstate__(self) -> Dict:
* label_list: labels used in nodes, without duplicates. Indices are
mapped to nodes in node_list
* node_list: node tuples containing
(node label index in label_list, frozenset of frozensets of leaf label indices, node.attr).
(node label index in label_list, tuple of frozensets of leaf label indices, node.attr).
* edge_list: a tuple for each edge:
(origin node index, target node index, edge weight, edge probability)"""
label_fields = list(self.dagroot.children())[0].label._fields
label_list: List[Optional[Tuple]] = []
node_list: List[Tuple] = []
edge_list: List[Tuple] = []
label_indices: Dict[Label, int] = {}
node_indices = {id(node): idx for idx, node in enumerate(self.postorder())}
node_indices = {node: idx for idx, node in enumerate(self.postorder())}

def cladesets(node):
clades = {
return tuple(
frozenset({label_indices[label] for label in clade})
for clade in node.clades
}
return frozenset(clades)
)

for node in self.postorder():
if node.label not in label_indices:
Expand All @@ -374,7 +402,7 @@ def cladesets(node):
edge_list.append(
(
node_idx,
node_indices[id(target)],
node_indices[target],
eset.weights[idx],
eset.probs[idx],
)
Expand Down Expand Up @@ -434,9 +462,12 @@ def get_trees(self) -> Generator["HistoryDag", None, None]:
yield HistoryDag(cladetree)

def sample(self) -> "HistoryDag":
r"""Samples a clade tree from the history DAG.
(A clade tree is a sub-history DAG containing the root and all
leaf nodes). Returns a new HistoryDagNode object."""
r"""Samples a history from the history DAG.
(A history is a sub-history DAG containing the root and all
leaf nodes)
For reproducibility, set ``random.seed`` before sampling.
Returns a new HistoryDag object."""
return HistoryDag(self.dagroot._sample())

def unlabel(self) -> "HistoryDag":
Expand Down Expand Up @@ -794,7 +825,7 @@ def is_ambiguous(label):
# Add all edges into and out of node to newnode
for target in node.children():
newnode.add_edge(target)
for parent in node.parents:
for parent in sorted(node.parents):
parent.add_edge(newnode)
# Delete old node
node.remove_node(nodedict=nodedict)
Expand Down Expand Up @@ -1017,7 +1048,7 @@ def to_newicks(self, **kwargs):
newicks = self.weight_count(**utils.make_newickcountfuncs(**kwargs)).elements()
return [newick[1:-1] + ";" for newick in newicks]

def count_topologies_with_newicks(self, collapse_leaves: bool = False) -> int:
def count_topologies(self, collapse_leaves: bool = False) -> int:
"""Counts the number of unique topologies in the history DAG. This is
achieved by counting the number of unique newick strings with only
leaves labeled.
Expand All @@ -1042,11 +1073,14 @@ def count_topologies_with_newicks(self, collapse_leaves: bool = False) -> int:
)
return len(self.weight_count(**kwargs))

def count_topologies(self) -> int:
def count_topologies_fast(self) -> int:
"""Counts the number of unique topologies in the history DAG.
This is achieved by creating a new history DAG in which all
internal nodes have matching labels.
This is only guaranteed to match the output of ``count_topologies``
if the DAG has all allowed edges added.
"""
return self.unlabel().count_trees()

Expand Down Expand Up @@ -1383,7 +1417,7 @@ def convert_to_collapsed(self):
# no need for recursion here, all of its parents had
# edges added to new parent from the same clade.
upclade = parent.under_clade()
for grandparent in parent.parents:
for grandparent in sorted(parent.parents):
grandparent.remove_edge_by_clade_and_id(parent, upclade)
for child2 in parent.children():
child2.parents.remove(parent)
Expand Down Expand Up @@ -1633,7 +1667,6 @@ def _unrooted_from_tree(tree):

dag = _unrooted_from_tree(tree)
dagroot = UANode(EdgeSet([dag], weights=[tree.dist]))
dagroot.add_edge(dag, weight=0)
return HistoryDag(dagroot)


Expand Down
6 changes: 3 additions & 3 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def test_count_topologies():
for tree in dag.get_trees()
}
print(checkset)
assert dag.count_topologies() == len(checkset)
assert dag.count_topologies_fast() == len(checkset)


def test_count_topologies_equals_newicks():
for dag in dags:
assert dag.count_topologies() == dag.count_topologies_with_newicks()
assert dag.count_topologies_fast() == dag.count_topologies()


def test_parsimony():
Expand Down Expand Up @@ -293,7 +293,7 @@ def test_topology_count_collapse():
)
)
)
assert dag.count_topologies_with_newicks(collapse_leaves=True) == 2
assert dag.count_topologies(collapse_leaves=True) == 2


# this tests that each of the indexed trees is a valid subtree
Expand Down

0 comments on commit aa251c5

Please sign in to comment.