From 56bf1d66793a8d3c5d7a68e9c4df5c82248ceacb Mon Sep 17 00:00:00 2001 From: colganwi Date: Tue, 20 Feb 2024 12:56:08 -0500 Subject: [PATCH] added tests --- src/treedata/_core/aligned_mapping.py | 8 +++----- src/treedata/_core/treedata.py | 2 +- src/treedata/_utils.py | 14 ------------- tests/test_base.py | 16 +++++++++++---- tests/test_utils.py | 29 ++------------------------- tests/test_views.py | 23 ++++++++++++++------- 6 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/treedata/_core/aligned_mapping.py b/src/treedata/_core/aligned_mapping.py index fef6caa..7410309 100755 --- a/src/treedata/_core/aligned_mapping.py +++ b/src/treedata/_core/aligned_mapping.py @@ -53,6 +53,8 @@ def _validate_tree(self, tree: nx.DiGraph, key: str) -> nx.DiGraph: for node in tree.nodes: if tree.in_degree(node) == 0: root_count += 1 + if tree.out_degree(node) == 0: + raise ValueError(f"Value for key {key} must be fully connected") elif tree.in_degree(node) > 1: raise ValueError(f"Value for key {key} must be a tree") if tree.out_degree(node) == 0: @@ -196,8 +198,6 @@ def __setitem__(self, key: str, value: nx.DiGraph): ) with view_update(self.parent, self.attrname, ()) as new_mapping: new_mapping[key] = value - print("here2") - print(key) def __delitem__(self, key: str): if key not in self: @@ -222,7 +222,7 @@ def __len__(self) -> int: @contextmanager def view_update(tdata_view: TreeData, attr_name: str, keys: tuple[str, ...]): - """Context manager for updating a view of an AnnData object. + """Context manager for updating a view of an TreeData object. Contains logic for "actualizing" a view. Yields the object to be modified in-place. @@ -241,8 +241,6 @@ def view_update(tdata_view: TreeData, attr_name: str, keys: tuple[str, ...]): """ new = tdata_view.copy() attr = getattr(new, attr_name) - for key in attr: - print(key) container = reduce(lambda d, k: d[k], keys, attr) yield container tdata_view._init_as_actual(new) diff --git a/src/treedata/_core/treedata.py b/src/treedata/_core/treedata.py index 9cf2623..2fffc0f 100755 --- a/src/treedata/_core/treedata.py +++ b/src/treedata/_core/treedata.py @@ -182,7 +182,7 @@ def _init_as_actual( self._vart = AxisTrees(self, 1, vals=vart) def _init_as_view(self, tdata_ref: TreeData, oidx: Index, vidx: Index): - super()._init_as_view(tdata_ref, oidx, vidx) + super()._init_as_view(tdata_ref, oidx=oidx, vidx=vidx) # view of obst and vart self._obst = tdata_ref.obst._view(self, (oidx,)) diff --git a/src/treedata/_utils.py b/src/treedata/_utils.py index a371271..b21a6c5 100755 --- a/src/treedata/_utils.py +++ b/src/treedata/_utils.py @@ -3,20 +3,6 @@ import networkx as nx -def get_leaves(tree: nx.DiGraph) -> list[str]: - """Get the leaves of a tree.""" - leaves = [n for n in tree.nodes if tree.out_degree(n) == 0] - return leaves - - -def get_root(tree: nx.DiGraph) -> str: - """Get the root of a tree.""" - roots = [n for n in tree.nodes if tree.in_degree(n) == 0] - if len(roots) != 1: - raise ValueError(f"Tree must have exactly one root, found {len(roots)}.") - return roots[0] - - def subset_tree(tree: nx.DiGraph, leaves: list[str], asview: bool) -> nx.DiGraph: """Subset tree.""" keep_nodes = set(leaves) diff --git a/tests/test_base.py b/tests/test_base.py index b6bf5f9..582cc85 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -40,10 +40,17 @@ def test_creation(X, adata, tree): assert tdata.X is adata.X -@pytest.mark.parametrize("dim", ["obs", "var"]) -def test_tree_keys(X, tree, dim): +@pytest.mark.parametrize("axis", [0, 1]) +def test_attributes(X, tree, axis): + dim = ["obs", "var"][axis] tdata = td.TreeData(X, obst={"tree": tree}, vart={"tree": tree}, label=None) - check_graph_equality(getattr(tdata, f"{dim}t")["tree"], tree) + assert getattr(tdata, f"{dim}t").axes == (axis,) + assert getattr(tdata, f"{dim}t").attrname == (f"{dim}t") + assert getattr(tdata, f"{dim}t").dim == dim + assert getattr(tdata, f"{dim}t").parent is tdata + assert list(getattr(tdata, f"{dim}t").dim_names) == ["0", "1", "2"] + assert tdata.allow_overlap is False + assert tdata.label is None @pytest.mark.parametrize("dim", ["obs", "var"]) @@ -131,6 +138,7 @@ def test_bad_tree(X): # Has cycle has_cycle = nx.DiGraph() has_cycle.add_edges_from([("0", "1"), ("1", "0")]) + has_cycle.add_node("2") with pytest.raises(ValueError): _ = td.TreeData(X, obst={"tree": has_cycle}) # Not fully connected @@ -145,7 +153,7 @@ def test_bad_tree(X): _ = td.TreeData(X, obst={"tree": bad_leaves}) # Multiple roots multi_root = nx.DiGraph() - multi_root.add_edges_from([("root", "0"), ("bad", "0")]) + multi_root.add_edges_from([("0", "1"), ("1", "0"), ("2", "3")]) with pytest.raises(ValueError): _ = td.TreeData(X, obst={"tree": multi_root}) diff --git a/tests/test_utils.py b/tests/test_utils.py index e06f619..2f0cbbf 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,43 +1,18 @@ import networkx as nx import pytest -from treedata._utils import get_leaves, get_root, subset_tree +from treedata._utils import subset_tree @pytest.fixture def tree(): tree = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) - root = get_root(tree) + root = [n for n, d in tree.in_degree() if d == 0][0] depths = nx.single_source_shortest_path_length(tree, root) nx.set_node_attributes(tree, values=depths, name="depth") yield tree -def test_get_leaves(): - tree = nx.DiGraph() - tree.add_edges_from([("root", "0"), ("root", "1")]) - assert get_leaves(tree) == ["0", "1"] - - -def test_get_root(): - tree = nx.DiGraph() - tree.add_edges_from([("root", "0"), ("root", "1")]) - assert get_root(tree) == "root" - - -def test_get_root_raises(): - # Has cycle - has_cycle = nx.DiGraph() - has_cycle.add_edges_from([("root", "0"), ("0", "root")]) - with pytest.raises(ValueError): - get_root(has_cycle) - # Multiple roots - multi_root = nx.DiGraph() - multi_root.add_edges_from([("root", "0"), ("bad", "0")]) - with pytest.raises(ValueError): - get_root(multi_root) - - def test_subset_tree(tree): # copy subtree = subset_tree(tree, [7, 8, 9], asview=False) diff --git a/tests/test_views.py b/tests/test_views.py index d599313..2943e9e 100755 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -4,15 +4,13 @@ import pytest import treedata as td -from treedata._utils import get_root @pytest.fixture def tree(): tree = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph) tree = nx.relabel_nodes(tree, {i: str(i) for i in tree.nodes}) - root = get_root(tree) - depths = nx.single_source_shortest_path_length(tree, root) + depths = nx.single_source_shortest_path_length(tree, "0") nx.set_node_attributes(tree, values=depths, name="depth") yield tree @@ -36,6 +34,16 @@ def test_views(tdata): assert tdata_subset.obs["test"].tolist() == list(range(2)) +# this test should pass once anndata bug is fixed +# See https://github.com/scverse/anndata/issues/1382 +@pytest.mark.xfail +def test_views_creation(tdata): + tdata_view = td.TreeData(tdata, asview=True) + assert tdata_view.is_view + with pytest.raises(ValueError): + _ = td.TreeData(np.zeros(shape=(3, 3)), asview=False) + + def test_views_subset_tree(tdata): expected_edges = [ ("0", "1"), @@ -58,6 +66,7 @@ def test_views_subset_tree(tdata): tdata_subset = tdata_subset.copy() edges = list(tdata_subset.obst["tree"].edges) assert edges == expected_edges + assert len(tdata.obst["tree"].edges) == 14 def test_views_mutability(tdata): @@ -76,7 +85,7 @@ def test_views_mutability(tdata): tdata_subset.obst["tree"].remove_node("8") -def test_set(tdata): +def test_views_set(tdata): tdata_subset = tdata[[0, 1, 4], :] # bad assignment bad_tree = nx.DiGraph() @@ -94,7 +103,7 @@ def test_set(tdata): assert list(tdata_subset.obst["new_tree"].edges) == [("0", "8")] -def test_del(tdata): +def test_views_del(tdata): tdata_subset = tdata[[0, 1, 4], :] # bad deletion with pytest.raises(KeyError): @@ -107,12 +116,12 @@ def test_del(tdata): assert list(tdata_subset.obst.keys()) == [] -def test_contains(tdata): +def test_views_contains(tdata): tdata_subset = tdata[[0, 1, 4], :] assert "tree" in tdata_subset.obst assert "bad" not in tdata_subset.obst -def test_len(tdata): +def test_views_len(tdata): tdata_subset = tdata[[0, 1, 4], :] assert len(tdata_subset.obst) == 1