Skip to content

Commit

Permalink
Merge pull request #13 from YosefLab/leaf-data-patch
Browse files Browse the repository at this point in the history
Leaf data patch
  • Loading branch information
colganwi authored Oct 24, 2024
2 parents 19eebb2 + 362ecbf commit 667e95a
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ __pycache__/

# Node modules
node_modules/

# Environment
environment.yml
15 changes: 15 additions & 0 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,18 @@ def _series_to_rgb_array(series, colors, vmin=None, vmax=None, na_color="#808080
else:
raise ValueError("cmap must be either a dictionary or a ListedColormap.")
return rgb_array


def _check_tree_overlap(tdata, tree_keys):
    """Validate that only a single tree is selected when overlap is allowed.

    Parameters
    ----------
    tdata
        Object exposing ``allow_overlap`` and ``obst`` (a mapping keyed by tree name).
    tree_keys
        ``None`` (select every tree in ``tdata.obst``), a single tree key,
        or a sequence of tree keys.

    Raises
    ------
    ValueError
        If ``tdata.allow_overlap`` is true and more than one tree is selected,
        or if ``tree_keys`` is not a string, a sequence, or ``None``.
    """
    if tree_keys is None:
        # No explicit selection means "all trees", which is only ambiguous
        # when several trees are stored.
        if tdata.allow_overlap and len(tdata.obst.keys()) > 1:
            raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
        return
    if isinstance(tree_keys, str):
        return
    if not isinstance(tree_keys, Sequence):
        raise ValueError("Tree keys must be a string, list of strings, or None.")
    # A one-element sequence still names a single tree, so only reject
    # selections that really contain multiple trees.
    if tdata.allow_overlap and len(tree_keys) > 1:
        raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
6 changes: 4 additions & 2 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ._docs import _doc_params, doc_common_plot_args
from ._utils import (
_check_tree_overlap,
_get_categorical_colors,
_get_categorical_markers,
_series_to_rgb_array,
Expand All @@ -34,7 +35,7 @@ def branches(
extend_branches: bool = False,
angled_branches: bool = False,
color: str = "black",
linewidth: int | float | str = .5,
linewidth: int | float | str = 0.5,
depth_key: str = "depth",
tree: str | Sequence[str] | None = None,
cmap: str | mcolors.Colormap = "viridis",
Expand Down Expand Up @@ -79,6 +80,7 @@ def branches(
""" # noqa: D205
# Setup
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
if ax is None:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
elif (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
Expand Down Expand Up @@ -498,7 +500,7 @@ def tree(
angled_branches: bool = False,
depth_key: str = "depth",
branch_color: str = "black",
branch_linewidth: int | float | str = .5,
branch_linewidth: int | float | str = 0.5,
node_color: str = "black",
node_style: str = "o",
node_size: int | float = 10,
Expand Down
18 changes: 17 additions & 1 deletion src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ def _add_depth(tree, depth_key):
nx.set_node_attributes(tree, depths, depth_key)


def _check_tree_overlap(tdata, tree_keys):
    """Validate that only a single tree is selected when overlap is allowed.

    Parameters
    ----------
    tdata
        Object exposing ``allow_overlap`` and ``obst`` (a mapping keyed by tree name).
    tree_keys
        ``None`` (select every tree in ``tdata.obst``), a single tree key,
        or a sequence of tree keys.

    Raises
    ------
    ValueError
        If ``tdata.allow_overlap`` is true and more than one tree is selected,
        or if ``tree_keys`` is not a string, a sequence, or ``None``.
    """
    if tree_keys is None:
        # No explicit selection means "all trees", which is only ambiguous
        # when several trees are stored.
        if tdata.allow_overlap and len(tdata.obst.keys()) > 1:
            raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
        return
    if isinstance(tree_keys, str):
        return
    if not isinstance(tree_keys, Sequence):
        raise ValueError("Tree keys must be a string, list of strings, or None.")
    # A one-element sequence still names a single tree, so only reject
    # selections that really contain multiple trees.
    if tdata.allow_overlap and len(tree_keys) > 1:
        raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")


def add_depth(
tdata: td.TreeData, key_added: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
) -> None | pd.DataFrame:
Expand Down Expand Up @@ -44,9 +59,10 @@ def add_depth(
- Distance from the root node.
"""
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
_add_depth(tree, key_added)
tdata.obs[key_added] = get_keyed_leaf_data(tdata, key_added)[key_added]
tdata.obs[key_added] = get_keyed_leaf_data(tdata, key_added, tree_keys)[key_added]
if copy:
return get_keyed_node_data(tdata, key_added, tree_keys)
15 changes: 15 additions & 0 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,18 @@ def _assert_param_xor(params):
if n_set == 0:
raise ValueError(f"At least one of {param_text} must be set.")
return None


def _check_tree_overlap(tdata, tree_keys):
    """Validate that only a single tree is selected when overlap is allowed.

    Parameters
    ----------
    tdata
        Object exposing ``allow_overlap`` and ``obst`` (a mapping keyed by tree name).
    tree_keys
        ``None`` (select every tree in ``tdata.obst``), a single tree key,
        or a sequence of tree keys.

    Raises
    ------
    ValueError
        If ``tdata.allow_overlap`` is true and more than one tree is selected,
        or if ``tree_keys`` is not a string, a sequence, or ``None``.
    """
    if tree_keys is None:
        # No explicit selection means "all trees", which is only ambiguous
        # when several trees are stored.
        if tdata.allow_overlap and len(tdata.obst.keys()) > 1:
            raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
        return
    if isinstance(tree_keys, str):
        return
    if not isinstance(tree_keys, Sequence):
        raise ValueError("Tree keys must be a string, list of strings, or None.")
    # A one-element sequence still names a single tree, so only reject
    # selections that really contain multiple trees.
    if tdata.allow_overlap and len(tree_keys) > 1:
        raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
3 changes: 3 additions & 0 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees

from ._utils import _check_tree_overlap


def _most_common(arr):
"""Finds the most common element in a list."""
Expand Down Expand Up @@ -256,6 +258,7 @@ def ancestral_states(
if len(keys) != len(keys_added):
raise ValueError("Length of keys must match length of keys_added.")
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
data, is_array = get_keyed_obs_data(tdata, keys)
Expand Down
3 changes: 3 additions & 0 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from pycea.utils import check_tree_has_key, get_keyed_leaf_data, get_root, get_trees

from ._utils import _check_tree_overlap


def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
"""Recursively finds nodes at a given depth."""
Expand Down Expand Up @@ -100,6 +102,7 @@ def clades(
"""
# Setup
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
if clades and len(trees) > 1:
raise ValueError("Multiple trees are present. Must specify a single tree if clades are given.")
Expand Down
2 changes: 2 additions & 0 deletions src/pycea/tl/tree_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._metrics import _get_tree_metric, _TreeMetric
from ._utils import (
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_format_keys,
_set_distances_and_connectivities,
Expand Down Expand Up @@ -164,6 +165,7 @@ def tree_distance(
key_added = key_added or "tree"
connect_key = _format_keys(connect_key, "connectivities")
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
metric_fn = _get_tree_metric(metric)
single_obs = False
Expand Down
2 changes: 2 additions & 0 deletions src/pycea/tl/tree_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._utils import (
_assert_param_xor,
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_set_distances_and_connectivities,
_set_random_state,
Expand Down Expand Up @@ -142,6 +143,7 @@ def tree_neighbors(
_assert_param_xor({"n_neighbors": n_neighbors, "max_dist": max_dist})
_ = _get_tree_metric(metric)
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
if update:
_check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"])
# Neighbors of a single leaf
Expand Down
14 changes: 8 additions & 6 deletions src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def check_tree_has_key(tree: nx.DiGraph, key: str):


def get_keyed_edge_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
tdata: td.TreeData, keys: str | Sequence[str], tree: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets edge data for a given key from a tree or set of trees."""
tree_keys = tree
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
Expand All @@ -65,9 +66,10 @@ def get_keyed_edge_data(


def get_keyed_node_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
tdata: td.TreeData, keys: str | Sequence[str], tree: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets node data for a given key from a tree or set of trees."""
tree_keys = tree
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
Expand All @@ -86,9 +88,10 @@ def get_keyed_node_data(


def get_keyed_leaf_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
tdata: td.TreeData, keys: str | Sequence[str], tree: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets node data for a given key from a tree or set of trees."""
tree_keys = tree
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
Expand Down Expand Up @@ -156,16 +159,15 @@ def get_keyed_obsm_data(tdata: td.TreeData, key: str) -> sp.sparse.csr_matrix:
return X


def get_trees(tdata: td.TreeData, tree_keys: str | Sequence[str] | None) -> Mapping[str, nx.DiGraph]:
def get_trees(tdata: td.TreeData, tree: str | Sequence[str] | None) -> Mapping[str, nx.DiGraph]:
"""Gets tree data for a given key from a tree."""
trees = {}
tree_keys = tree
if tree_keys is None:
tree_keys = tdata.obst.keys()
elif isinstance(tree_keys, str):
tree_keys = [tree_keys]
elif isinstance(tree_keys, Sequence):
if tdata.allow_overlap:
raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
tree_keys = list(tree_keys)
else:
raise ValueError("Tree keys must be a string, list of strings, or None.")
Expand Down
18 changes: 18 additions & 0 deletions tests/test_setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def tdata():
yield tdata


@pytest.fixture
def tdata_with_overlap():
    """Provide a TreeData storing one graph under two tree keys with overlap allowed."""
    edges = [("root", "A"), ("root", "B"), ("B", "C"), ("B", "D")]
    shared_tree = nx.DiGraph(edges)
    leaves = pd.DataFrame(index=["A", "C", "D"])
    # The same graph object is registered twice so both keys cover the same leaves.
    yield td.TreeData(obs=leaves, obst={"tree1": shared_tree, "tree2": shared_tree}, allow_overlap=True)


def test_add_depth(tdata):
depths = add_depth(tdata, key_added="depth", copy=True)
assert depths.loc[("tree1", "root"), "depth"] == 0
Expand All @@ -25,5 +34,14 @@ def test_add_depth(tdata):
assert tdata.obs.loc["C", "depth"] == 2


def test_add_depth_overlap(tdata_with_overlap):
    """add_depth raises without an explicit tree under overlap, and works per tree."""
    # Leaving ``tree`` unset is ambiguous when the stored trees overlap.
    with pytest.raises(ValueError):
        add_depth(tdata_with_overlap, key_added="depth", copy=True)
    # Each tree can still be processed on its own.
    for tree_key in ("tree1", "tree2"):
        depths = add_depth(tdata_with_overlap, key_added="depth", tree=tree_key, copy=True)
        assert depths.loc[(tree_key, "C"), "depth"] == 2


if __name__ == "__main__":
pytest.main(["-v", __file__])
16 changes: 11 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def tree():
def tdata(tree):
tdata = td.TreeData(
obs=pd.DataFrame({"value": ["1", "2"]}, index=["D", "E"]),
obst={"tree": tree},
obst={"tree": tree, "tree2": tree},
obsm={"spatial": pd.DataFrame([[0, 0], [1, 1]], index=["D", "E"])},
allow_overlap=True,
)
yield tdata

Expand Down Expand Up @@ -64,22 +65,27 @@ def test_get_subtree_leaves(tree):


def test_get_keyed_edge_data(tdata):
data = get_keyed_edge_data(tdata, ["weight", "color"])
data = get_keyed_edge_data(tdata, ["weight", "color"], tree="tree")
assert data.columns.tolist() == ["weight", "color"]
assert data.index.names == ["tree", "edge"]
assert data["weight"].to_list() == [5, 3, 4]
data = get_keyed_edge_data(tdata, ["weight", "color"])
assert data.shape[0] == 6
assert data.index.get_level_values("tree").unique().tolist() == ["tree", "tree2"]


def test_get_keyed_node_data(tdata):
data = get_keyed_node_data(tdata, ["value", "color"])
data = get_keyed_node_data(tdata, ["value", "color"], tree="tree")
assert data.columns.tolist() == ["value", "color"]
assert data.index.names == ["tree", "node"]
assert data["value"].to_list() == [1, 2, 2, 4, 4]
data = get_keyed_node_data(tdata, ["value", "color"])
assert data.shape[0] == 10
assert data.index.get_level_values("tree").unique().tolist() == ["tree", "tree2"]


def test_get_keyed_leaf_data(tdata):
data = get_keyed_leaf_data(tdata, ["value", "color"])
print(data)
data = get_keyed_leaf_data(tdata, ["value", "color"], tree="tree")
assert data.columns.tolist() == ["value", "color"]
assert data["value"].tolist() == [4, 4]
assert data["color"].tolist() == ["blue", "blue"]
Expand Down

0 comments on commit 667e95a

Please sign in to comment.