From ee7f3b573360352b10f579c1320990bec3ee5299 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Wed, 16 Oct 2024 09:37:09 +0100 Subject: [PATCH] Implement tree.ancestors Fixes #2706 --- docs/python-api.md | 1 + python/CHANGELOG.rst | 1 + python/tests/test_balance_metrics.py | 11 +---------- python/tests/test_highlevel.py | 17 +++++++++++++++++ python/tskit/trees.py | 10 ++++++++++ 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/docs/python-api.md b/docs/python-api.md index 1713c1344d..12f40ada6b 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -551,6 +551,7 @@ Iterator access .. autosummary:: Tree.nodes + Tree.ancestors Array access .. autosummary:: diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index a166f46646..ff20dd06b7 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -89,6 +89,7 @@ - Add comma separation to all display numbers. (:user:`benjeffery`, :issue:`3017`, :pr:`3018`) +- Add ``Tree.ancestors(u)`` method. (:user:`hyanwong`, :issue:`2706`, :pr:`3021`) - Add ``resources`` section to provenance schema. (:user:`benjeffery`, :pr:`3016`) diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py index dc77f95e6b..eed20477b2 100644 --- a/python/tests/test_balance_metrics.py +++ b/python/tests/test_balance_metrics.py @@ -35,15 +35,6 @@ # we can remove this. 
-def node_path(tree, u): - path = [] - u = tree.parent(u) - while u != tskit.NULL: - path.append(u) - u = tree.parent(u) - return path - - def sackin_index_definition(tree): return sum(tree.depth(u) for u in tree.leaves()) @@ -79,7 +70,7 @@ def b2_index_definition(tree, base=10): if tree.num_roots != 1: raise ValueError("B2 index is only defined for trees with one root") proba = [ - np.prod([1 / tree.num_children(u) for u in node_path(tree, leaf)]) + np.prod([1 / tree.num_children(u) for u in tree.ancestors(leaf)]) for leaf in tree.leaves() ] return -sum(p * math.log(p, base) for p in proba) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index e560fbfe40..5929daa6ae 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -3822,6 +3822,23 @@ def test_num_children(self): for u in tree.nodes(): assert tree.num_children(u) == len(tree.children(u)) + def test_ancestors(self): + tree = tskit.Tree.generate_balanced(10, arity=3) + ancestors_arrays = {u: [] for u in np.arange(tree.tree_sequence.num_nodes)} + ancestors_arrays[-1] = [] + for u in tree.nodes(order="preorder"): + parent = tree.parent(u) + if parent != tskit.NULL: + ancestors_arrays[u] = [parent] + ancestors_arrays[tree.parent(u)] + for u in tree.nodes(): + assert list(tree.ancestors(u)) == ancestors_arrays[u] + + def test_ancestors_empty(self): + ts = tskit.Tree.generate_comb(10).tree_sequence + tree = ts.delete_intervals([[0, 1]]).first() + for u in ts.samples(): + assert len(list(tree.ancestors(u))) == 0 + @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_virtual_root_semantics(self, ts): for tree in ts.trees(): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 79ca182d67..ce046c9e87 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1091,6 +1091,16 @@ def parent_array(self): """ return self._parent_array + def ancestors(self, u): + """ + Returns an iterator over the ancestors of node ``u`` in this 
tree + (i.e. the chain of parents from ``u`` to the root). + """ + u = self.parent(u) + while u != -1: + yield u + u = self.parent(u) + # Quintuply linked tree structure. def left_child(self, u):