Skip to content

Commit

Permalink
clades dtype argument
Browse files (browse the repository at this point in the history)
  • Loading branch information
colganwi committed Oct 24, 2024
1 parent 667e95a commit 465ef18
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.DS_Store
*~
buck-out/
.ipynb_checkpoints/

# Compiled files
.venv/
Expand Down
13 changes: 10 additions & 3 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
return nodes


def _clade_name_generator():
def _clade_name_generator(dtype=int):
"""Generates clade names."""
valid_dtypes = {"str": str, "int": int, "float": float, str: str, int: int, float: float}
if dtype not in valid_dtypes:
raise ValueError("dtype must be one of str, int, or float")
converter = valid_dtypes[dtype]
i = 0
while True:
yield str(i)
yield converter(i)
i += 1


Expand Down Expand Up @@ -65,6 +69,7 @@ def clades(
clades: str | Sequence[str] = None,
key_added: str = "clade",
update: bool = False,
dtype: type | str = str,
tree: str | Sequence[str] | None = None,
copy: bool = False,
) -> None | Mapping:
Expand All @@ -84,6 +89,8 @@ def clades(
Key to store clades in.
update
If True, updates existing clades instead of overwriting.
dtype
Data type of clade names. One of `str`, `int`, or `float`.
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
Expand All @@ -107,7 +114,7 @@ def clades(
if clades and len(trees) > 1:
raise ValueError("Multiple trees are present. Must specify a single tree if clades are given.")
# Identify clades
name_generator = _clade_name_generator()
name_generator = _clade_name_generator(dtype=dtype)
lcas = []
for key, tree in trees.items():
tree_lcas = _clades(tree, depth, depth_key, clades, key_added, name_generator, update)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ def test_clades_multiple_trees():
assert pd.isna(tdata.obs.loc["B", "test"])


def test_clades_dtype(tdata):
    """Check that the ``dtype`` argument of ``clades`` sets the type of the clade labels."""
    # (depth, dtype argument, expected column dtype, node to inspect, expected label)
    cases = [
        (0, int, int, "A", 0),
        (0, "int", int, "A", 0),
        (1, float, float, "C", 1.0),
    ]
    for depth, dtype_arg, expected_dtype, node, expected_label in cases:
        clades(tdata, depth=depth, dtype=dtype_arg)
        # The label must have the requested type both in the obs column
        # and on the tree node attribute.
        assert tdata.obs["clade"].dtype == expected_dtype
        assert tdata.obst["tree"].nodes[node]["clade"] == expected_label


def test_clades_invalid(tdata):
with pytest.raises(ValueError):
clades(td.TreeData(), clades={"A": 0}, depth=0)
Expand Down

0 comments on commit 465ef18

Please sign in to comment.