Skip to content

Commit

Permalink
fixed tree label bug
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Aug 21, 2024
1 parent 2a6ad22 commit 5998540
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ and this project adheres to [Semantic Versioning][].

### Added

### Changed

- `obst` and `vart` create local copy of `nx.DiGraphs` that are added (#26)

### Fixed

- Fixed bug which caused key to be listed twice in `tree_label` column after value update in `obst` or `vart` (#26)

## [0.0.2] - 2024-06-18

### Changed
Expand Down
6 changes: 3 additions & 3 deletions src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _update_tree_labels(self):
if self.parent.allow_overlap:
mapping = {k: ",".join(map(str, v)) for k, v in self._leaf_to_tree.items()}
else:
mapping = {k: v[0] for k, v in self._leaf_to_tree.items()}
mapping = {k: next(iter(v)) for k, v in self._leaf_to_tree.items()}
getattr(self.parent, self.dim)[self.parent._tree_label] = getattr(self.parent, f"{self.dim}_names").map(
mapping
)
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
self._axis = axis
self._data = {}
self._tree_to_leaf = defaultdict(set)
self._leaf_to_tree = defaultdict(list)
self._leaf_to_tree = defaultdict(set)
if vals is not None:
self.update(vals)

Expand All @@ -156,7 +156,7 @@ def __setitem__(self, key: str, value: nx.DiGraph):
value, leaves = self._validate_tree(value, key)

for leaf in leaves:
self._leaf_to_tree[leaf].append(key)
self._leaf_to_tree[leaf].add(key)
self._tree_to_leaf[key] = leaves

if not self.parent.is_view:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def test_tree_label(X, tree, dim):
df = pd.DataFrame({"tree": ["bad", "bad", "bad"]})
with pytest.warns(UserWarning):
tdata = td.TreeData(X, label="tree", obs=df, var=df)
# Test tree label with updata
tdata = td.TreeData(X, obst={"0": tree, "1": tree}, label="tree", vart={"0": tree, "1": tree}, allow_overlap=True)
tdata.obst["0"] = tree
assert getattr(tdata, dim).loc["0", "tree"] == "0,1"


def test_tree_overlap(X, tree):
Expand Down

0 comments on commit 5998540

Please sign in to comment.