Skip to content

Commit

Permalink
store trees as json
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Jun 13, 2024
1 parent 0931bd8 commit 5c790bf
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 33 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ urls.Source = "https://github.com/YosefLab/treedata"
urls.Home-page = "https://github.com/YosefLab/treedata"
dependencies = [
"anndata",
"copy",
"h5py",
"numpy",
"pandas",
"pathlib",
"pyarrow",
"networkx",
"session-info",
Expand Down
20 changes: 14 additions & 6 deletions src/treedata/_core/read.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import json
from collections.abc import MutableMapping, Sequence
from pathlib import Path
from typing import (
Literal,
)

import anndata as ad
import h5py
import zarr
from scipy import sparse

Expand All @@ -15,16 +17,14 @@
from treedata._utils import dict_to_digraph


def _tdata_from_adata(tdata) -> TreeData:
def _tdata_from_adata(tdata, treedata_attrs=None) -> TreeData:
"""Create a TreeData object parsing attribute from AnnData uns field."""
tdata.__class__ = TreeData
if "treedata_attrs" in tdata.uns.keys():
treedata_attrs = tdata.uns["treedata_attrs"]
if treedata_attrs is not None:
tdata._tree_label = treedata_attrs["label"] if "label" in treedata_attrs.keys() else None
tdata._allow_overlap = bool(treedata_attrs["allow_overlap"])
tdata._obst = AxisTrees(tdata, 0, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["obst"].items()})
tdata._vart = AxisTrees(tdata, 1, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["vart"].items()})
del tdata.uns["treedata_attrs"]
else:
tdata._tree_label = None
tdata._allow_overlap = False
Expand Down Expand Up @@ -71,7 +71,12 @@ def read_h5ad(
as_sparse_fmt=as_sparse_fmt,
chunk_size=chunk_size,
)
return _tdata_from_adata(adata)
with h5py.File(filename, "r") as f:
if "raw.treedata_attrs" in f:
treedata_attrs = json.loads(f["raw.treedata_attrs"][()])
else:
treedata_attrs = None
return _tdata_from_adata(adata, treedata_attrs)


def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
Expand All @@ -83,4 +88,7 @@ def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
"""
adata = ad.read_zarr(store)
return _tdata_from_adata(adata)
if "treedata_attrs" in adata.uns.keys():
treedata_attrs = adata.uns["treedata_attrs"]
del adata.uns["treedata_attrs"]
return _tdata_from_adata(adata, treedata_attrs)
84 changes: 62 additions & 22 deletions src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from __future__ import annotations

import warnings
import json
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from copy import deepcopy
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Literal,
)

import anndata as ad
import h5py
import networkx as nx
import numpy as np
import pandas as pd
from anndata._core.index import Index, Index1D
from anndata._core.index import Index, Index1D, _subset
from anndata._io.h5ad import write_h5ad
from scipy import sparse

from treedata._utils import digraph_to_dict
Expand Down Expand Up @@ -282,31 +286,57 @@ def _treedata_attrs(self) -> dict:
"allow_overlap": self.allow_overlap,
}

def _mutated_copy(self, **kwargs):
"""Creating TreeData with attributes optionally specified via kwargs."""
if self.isbacked:
if "X" not in kwargs or (self.raw is not None and "raw" not in kwargs):
raise NotImplementedError(
"This function does not currently handle backed objects "
"internally, this should be dealt with before."
)
new = {}
new["label"] = self.label
new["allow_overlap"] = self.allow_overlap

for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "obst", "vart", "layers"]:
if key in kwargs:
new[key] = kwargs[key]
else:
new[key] = getattr(self, key).copy()
if "X" in kwargs:
new["X"] = kwargs["X"]
elif self._has_X():
new["X"] = self.X.copy()
if "uns" in kwargs:
new["uns"] = kwargs["uns"]
else:
new["uns"] = deepcopy(self._uns)
if "raw" in kwargs:
new["raw"] = kwargs["raw"]
elif self.raw is not None:
new["raw"] = self.raw.copy()

return TreeData(**new)

def copy(self, filename: PathLike | None = None) -> TreeData:
"""Full copy, optionally on disk"""
adata = super().copy(filename=filename)
"""Full copy, optionally on disk."""
if not self.isbacked:
treedata_copy = TreeData(
adata,
obst=self.obst.copy(),
vart=self.vart.copy(),
label=self.label,
allow_overlap=self.allow_overlap,
)
if self.is_view and self._has_X():
return self._mutated_copy(X=_subset(self._adata_ref.X, (self._oidx, self._vidx)).copy())
else:
return self._mutated_copy()
else:
from .read import read_h5ad

if filename is None:
raise ValueError(
"To copy an TreeData object in backed mode, "
"To copy an AnnData object in backed mode, "
"pass a filename: `.copy(filename='myfilename.h5ad')`. "
"To load the object into memory, use `.to_memory()`."
)
mode = self.file._filemode
adata.uns["treedata_attrs"] = self._treedata_attrs()
adata.write_h5ad(filename)
treedata_copy = read_h5ad(filename, backed=mode)
return treedata_copy
self.write_h5ad(filename)
return read_h5ad(filename, backed=mode)

def transpose(self) -> TreeData:
"""Transpose whole object
Expand Down Expand Up @@ -347,13 +377,23 @@ def write_h5ad(
Sparse arrays in TreeData object to write as dense. Currently only
supports `X` and `raw/X`.
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.uns["treedata_attrs"] = self._treedata_attrs()
super().write_h5ad(
filename=filename, compression=compression, compression_opts=compression_opts, as_dense=as_dense
if filename is None and not self.isbacked:
raise ValueError("Provide a filename!")
if filename is None:
filename = self.filename

write_h5ad(
Path(filename),
self,
compression=compression,
compression_opts=compression_opts,
as_dense=as_dense,
)
self.uns.pop("treedata_attrs")

with h5py.File(filename, "a") as f:
if "raw.treedata_attrs" in f:
del f["raw.treedata_attrs"]
f.create_dataset("raw.treedata_attrs", data=json.dumps(self._treedata_attrs()))

write = write_h5ad # a shortcut and backwards compat

Expand Down
12 changes: 7 additions & 5 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_read_anndata(tdata, tmp_path):
def test_h5ad_backing(tdata, tree, tmp_path):
tdata_copy = tdata.copy()
assert not tdata.isbacked
backing_h5ad = tmp_path / "test.h5ad"
backing_h5ad = tmp_path / "test_backed.h5ad"
tdata.filename = backing_h5ad
# backing mode
tdata.write()
Expand All @@ -93,13 +93,10 @@ def test_h5ad_backing(tdata, tree, tmp_path):
with pytest.warns(UserWarning):
with pytest.raises(ValueError):
tdata_subset.obs["foo"] = range(3)
# with pytest.warns(UserWarning):
# with pytest.raises(ValueError):
# tdata_subset.obst["foo"] = tree
assert subset_hash == joblib.hash(tdata_subset)
assert tdata_subset.is_view
# copy
tdata_subset = tdata_subset.copy(tmp_path / "test.subset.h5ad")
tdata_subset = tdata_subset.copy(tmp_path / "test_subset.h5ad")
assert not tdata_subset.is_view
tdata_subset.obs["foo"] = range(3)
assert not tdata_subset.is_view
Expand All @@ -110,4 +107,9 @@ def test_h5ad_backing(tdata, tree, tmp_path):
tdata_subset = tdata_subset.to_memory()
assert not tdata_subset.is_view
assert not tdata_subset.isbacked
print(tdata_subset)
check_graph_equality(tdata_subset.obst["tree"], tdata.obst["tree"])


if __name__ == "__main__":
pytest.main(["-v", __file__])
21 changes: 21 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,23 @@ def test_views_subset_trees():
assert list(tdata_subset.obst["tree2"].edges) == [("root", "2")]


def test_views_copy():
# subset tdata with multiple trees
tree1 = nx.DiGraph([("root", "0"), ("root", "1")])
tree2 = nx.DiGraph([("root", "2"), ("root", "3")])
tdata = td.TreeData(X=np.zeros((8, 8)), allow_overlap=False, obst={"tree1": tree1, "tree2": tree2})
tdata_subset = tdata.copy()[["0", "1"], :].copy()
assert list(tdata_subset.obst["tree1"].edges) == [("root", "0"), ("root", "1")]
assert list(tdata_subset.obst["tree2"].edges) == []
print(tdata_subset)
tdata_subset = tdata.copy()[["2", "3"], :].copy()
assert list(tdata_subset.obst["tree1"].edges) == []
assert list(tdata_subset.obst["tree2"].edges) == [("root", "2"), ("root", "3")]
tdata_subset = tdata.copy()[["0", "1", "2"], :].copy()
assert list(tdata_subset.obst["tree1"].edges) == [("root", "0"), ("root", "1")]
assert list(tdata_subset.obst["tree2"].edges) == [("root", "2")]


def test_views_mutability(tdata):
# can mutate attributes of graph
nx.set_node_attributes(tdata.obst["tree"], False, "in_subset")
Expand Down Expand Up @@ -140,3 +157,7 @@ def test_views_contains(tdata):
def test_views_len(tdata):
tdata_subset = tdata[[0, 1, 4], :]
assert len(tdata_subset.obst) == 1


if __name__ == "__main__":
pytest.main(["-v", __file__])

0 comments on commit 5c790bf

Please sign in to comment.