Store TreeData attributes outside of `uns` for h5ad files and rework copying.

- write_h5ad serialises the tree attributes as JSON into a dedicated
  "raw.treedata_attrs" HDF5 dataset; read_h5ad loads them back from there.
- read_zarr still takes the attributes from uns["treedata_attrs"], and now
  falls back to None when that key is absent.
- TreeData.copy is rebuilt on top of a new _mutated_copy helper.
- h5py is added as a runtime dependency (copy/pathlib are standard library
  modules and are therefore not listed).

diff --git a/pyproject.toml b/pyproject.toml
index 7c22e90..b7d1803 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,8 +20,9 @@ urls.Source = "https://github.com/YosefLab/treedata"
 urls.Home-page = "https://github.com/YosefLab/treedata"
 dependencies = [
     "anndata",
+    "h5py",
     "numpy",
     "pandas",
     "pyarrow",
     "networkx",
     "session-info",
diff --git a/src/treedata/_core/read.py b/src/treedata/_core/read.py
index ad756b6..a7b7a79 100755
--- a/src/treedata/_core/read.py
+++ b/src/treedata/_core/read.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 from collections.abc import MutableMapping, Sequence
 from pathlib import Path
 from typing import (
@@ -7,6 +8,7 @@
 )
 
 import anndata as ad
+import h5py
 import zarr
 from scipy import sparse
 
@@ -15,16 +17,14 @@
 from treedata._utils import dict_to_digraph
 
 
-def _tdata_from_adata(tdata) -> TreeData:
+def _tdata_from_adata(tdata, treedata_attrs=None) -> TreeData:
     """Create a TreeData object parsing attribute from AnnData uns field."""
     tdata.__class__ = TreeData
-    if "treedata_attrs" in tdata.uns.keys():
-        treedata_attrs = tdata.uns["treedata_attrs"]
+    if treedata_attrs is not None:
         tdata._tree_label = treedata_attrs["label"] if "label" in treedata_attrs.keys() else None
         tdata._allow_overlap = bool(treedata_attrs["allow_overlap"])
         tdata._obst = AxisTrees(tdata, 0, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["obst"].items()})
         tdata._vart = AxisTrees(tdata, 1, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["vart"].items()})
-        del tdata.uns["treedata_attrs"]
     else:
         tdata._tree_label = None
         tdata._allow_overlap = False
@@ -71,7 +71,12 @@ def read_h5ad(
         as_sparse_fmt=as_sparse_fmt,
         chunk_size=chunk_size,
     )
-    return _tdata_from_adata(adata)
+    with h5py.File(filename, "r") as f:
+        if "raw.treedata_attrs" in f:
+            treedata_attrs = json.loads(f["raw.treedata_attrs"][()])
+        else:
+            treedata_attrs = None
+    return _tdata_from_adata(adata, treedata_attrs)
 
 
 def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
@@ -83,4 +88,5 @@ def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
         The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
     """
     adata = ad.read_zarr(store)
-    return _tdata_from_adata(adata)
+    treedata_attrs = adata.uns.pop("treedata_attrs", None)
+    return _tdata_from_adata(adata, treedata_attrs)
diff --git a/src/treedata/_core/treedata.py b/src/treedata/_core/treedata.py
index 99955bb..a83b384 100755
--- a/src/treedata/_core/treedata.py
+++ b/src/treedata/_core/treedata.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
-import warnings
+import json
 from collections.abc import Iterable, Mapping, MutableMapping, Sequence
+from copy import deepcopy
+from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -9,10 +11,12 @@
 )
 
 import anndata as ad
+import h5py
 import networkx as nx
 import numpy as np
 import pandas as pd
-from anndata._core.index import Index, Index1D
+from anndata._core.index import Index, Index1D, _subset
+from anndata._io.h5ad import write_h5ad
 from scipy import sparse
 
 from treedata._utils import digraph_to_dict
@@ -282,31 +286,57 @@ def _treedata_attrs(self) -> dict:
             "allow_overlap": self.allow_overlap,
         }
 
+    def _mutated_copy(self, **kwargs):
+        """Create a TreeData copy with attributes optionally overridden via kwargs."""
+        if self.isbacked:
+            if "X" not in kwargs or (self.raw is not None and "raw" not in kwargs):
+                raise NotImplementedError(
+                    "This function does not currently handle backed objects "
+                    "internally, this should be dealt with before."
+                )
+        new = {}
+        new["label"] = self.label
+        new["allow_overlap"] = self.allow_overlap
+
+        for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "obst", "vart", "layers"]:
+            if key in kwargs:
+                new[key] = kwargs[key]
+            else:
+                new[key] = getattr(self, key).copy()
+        if "X" in kwargs:
+            new["X"] = kwargs["X"]
+        elif self._has_X():
+            new["X"] = self.X.copy()
+        if "uns" in kwargs:
+            new["uns"] = kwargs["uns"]
+        else:
+            new["uns"] = deepcopy(self._uns)
+        if "raw" in kwargs:
+            new["raw"] = kwargs["raw"]
+        elif self.raw is not None:
+            new["raw"] = self.raw.copy()
+
+        return TreeData(**new)
+
     def copy(self, filename: PathLike | None = None) -> TreeData:
-        """Full copy, optionally on disk"""
-        adata = super().copy(filename=filename)
+        """Full copy, optionally on disk."""
         if not self.isbacked:
-            treedata_copy = TreeData(
-                adata,
-                obst=self.obst.copy(),
-                vart=self.vart.copy(),
-                label=self.label,
-                allow_overlap=self.allow_overlap,
-            )
+            if self.is_view and self._has_X():
+                return self._mutated_copy(X=_subset(self._adata_ref.X, (self._oidx, self._vidx)).copy())
+            else:
+                return self._mutated_copy()
         else:
             from .read import read_h5ad
 
             if filename is None:
                 raise ValueError(
-                    "To copy an TreeData object in backed mode, "
+                    "To copy a TreeData object in backed mode, "
                     "pass a filename: `.copy(filename='myfilename.h5ad')`. "
                     "To load the object into memory, use `.to_memory()`."
                 )
             mode = self.file._filemode
-            adata.uns["treedata_attrs"] = self._treedata_attrs()
-            adata.write_h5ad(filename)
-            treedata_copy = read_h5ad(filename, backed=mode)
-        return treedata_copy
+            self.write_h5ad(filename)
+            return read_h5ad(filename, backed=mode)
 
     def transpose(self) -> TreeData:
         """Transpose whole object
@@ -347,13 +377,23 @@ def write_h5ad(
             Sparse arrays in TreeData object to write as dense. Currently only supports
             `X` and `raw/X`.
         """
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            self.uns["treedata_attrs"] = self._treedata_attrs()
-        super().write_h5ad(
-            filename=filename, compression=compression, compression_opts=compression_opts, as_dense=as_dense
+        if filename is None and not self.isbacked:
+            raise ValueError("Provide a filename!")
+        if filename is None:
+            filename = self.filename
+
+        write_h5ad(
+            Path(filename),
+            self,
+            compression=compression,
+            compression_opts=compression_opts,
+            as_dense=as_dense,
         )
-        self.uns.pop("treedata_attrs")
+
+        with h5py.File(filename, "a") as f:
+            if "raw.treedata_attrs" in f:
+                del f["raw.treedata_attrs"]
+            f.create_dataset("raw.treedata_attrs", data=json.dumps(self._treedata_attrs()))
 
     write = write_h5ad  # a shortcut and backwards compat
 
diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py
index f53f927..241d54b 100755
--- a/tests/test_readwrite.py
+++ b/tests/test_readwrite.py
@@ -75,7 +75,7 @@ def test_read_anndata(tdata, tmp_path):
 def test_h5ad_backing(tdata, tree, tmp_path):
     tdata_copy = tdata.copy()
     assert not tdata.isbacked
-    backing_h5ad = tmp_path / "test.h5ad"
+    backing_h5ad = tmp_path / "test_backed.h5ad"
     tdata.filename = backing_h5ad
     # backing mode
     tdata.write()
@@ -93,13 +93,10 @@ def test_h5ad_backing(tdata, tree, tmp_path):
     with pytest.warns(UserWarning):
         with pytest.raises(ValueError):
             tdata_subset.obs["foo"] = range(3)
-    # with pytest.warns(UserWarning):
-    #     with pytest.raises(ValueError):
-    #         tdata_subset.obst["foo"] = tree
     assert subset_hash == joblib.hash(tdata_subset)
     assert tdata_subset.is_view
     # copy
-    tdata_subset = tdata_subset.copy(tmp_path / "test.subset.h5ad")
+    tdata_subset = tdata_subset.copy(tmp_path / "test_subset.h5ad")
     assert not tdata_subset.is_view
     tdata_subset.obs["foo"] = range(3)
     assert not tdata_subset.is_view
@@ -110,4 +107,8 @@ def test_h5ad_backing(tdata, tree, tmp_path):
     tdata_subset = tdata_subset.to_memory()
     assert not tdata_subset.is_view
     assert not tdata_subset.isbacked
     check_graph_equality(tdata_subset.obst["tree"], tdata.obst["tree"])
+
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
diff --git a/tests/test_views.py b/tests/test_views.py
index bdefbcd..dff59e4 100755
--- a/tests/test_views.py
+++ b/tests/test_views.py
@@ -75,6 +75,22 @@ def test_views_subset_trees():
     assert list(tdata_subset.obst["tree2"].edges) == [("root", "2")]
 
 
+def test_views_copy():
+    # subset tdata with multiple trees
+    tree1 = nx.DiGraph([("root", "0"), ("root", "1")])
+    tree2 = nx.DiGraph([("root", "2"), ("root", "3")])
+    tdata = td.TreeData(X=np.zeros((8, 8)), allow_overlap=False, obst={"tree1": tree1, "tree2": tree2})
+    tdata_subset = tdata.copy()[["0", "1"], :].copy()
+    assert list(tdata_subset.obst["tree1"].edges) == [("root", "0"), ("root", "1")]
+    assert list(tdata_subset.obst["tree2"].edges) == []
+    tdata_subset = tdata.copy()[["2", "3"], :].copy()
+    assert list(tdata_subset.obst["tree1"].edges) == []
+    assert list(tdata_subset.obst["tree2"].edges) == [("root", "2"), ("root", "3")]
+    tdata_subset = tdata.copy()[["0", "1", "2"], :].copy()
+    assert list(tdata_subset.obst["tree1"].edges) == [("root", "0"), ("root", "1")]
+    assert list(tdata_subset.obst["tree2"].edges) == [("root", "2")]
+
+
 def test_views_mutability(tdata):
     # can mutate attributes of graph
     nx.set_node_attributes(tdata.obst["tree"], False, "in_subset")
@@ -140,3 +156,7 @@ def test_views_contains(tdata):
 def test_views_len(tdata):
     tdata_subset = tdata[[0, 1, 4], :]
     assert len(tdata_subset.obst) == 1
+
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])