From f3d94b037cc1f6c02b8ae5b3c5a9e6f99ddb2321 Mon Sep 17 00:00:00 2001 From: FrancescaDr Date: Sat, 1 Jun 2024 09:19:29 +0200 Subject: [PATCH] obs_names and var_names add to fields --- src/geome/ann2data/basic.py | 2 +- src/geome/utils.py | 5 +++++ tests/ann2data/test_ann2data_by_category.py | 15 +++++++++++---- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/geome/ann2data/basic.py b/src/geome/ann2data/basic.py index 0e169e0..efe86a9 100644 --- a/src/geome/ann2data/basic.py +++ b/src/geome/ann2data/basic.py @@ -87,7 +87,7 @@ def _convert_to_tensor(self, obj): if obj.dtype.name == "category": return torch.from_numpy(pd.get_dummies(obj).to_numpy()).to(torch.float) if not np.issubdtype(obj.dtype, np.number): - return torch.from_numpy(obj.astype(np.float)).to(torch.float) + return torch.from_numpy(obj.astype(np.float64)).to(torch.float) if isinstance(obj, np.ndarray): return torch.from_numpy(obj).to(torch.float) else: diff --git a/src/geome/utils.py b/src/geome/utils.py index fb953ec..d09a624 100644 --- a/src/geome/utils.py +++ b/src/geome/utils.py @@ -21,6 +21,11 @@ def get_from_loc(adata: AnnData, location: str) -> Any: """ if location == "X": return adata.X + elif location == "obs_names": + return adata.obs_names.to_numpy() + elif location == "var_names": + return adata.var_names.to_numpy() + assert len(location.split("/")) == 2, f"Location must have only one delimiter {location}" axis, key = location.split("/") diff --git a/tests/ann2data/test_ann2data_by_category.py b/tests/ann2data/test_ann2data_by_category.py index f06c8fa..1941892 100644 --- a/tests/ann2data/test_ann2data_by_category.py +++ b/tests/ann2data/test_ann2data_by_category.py @@ -10,15 +10,22 @@ def test_sample_case_ann2data_basic(): # make sure that there are two clusters of spatial coordinates # so that the resulting splits number of edges will be the same # as the sum of the number of edges in each cluster - func_args = {"radius": 4.0, "coord_type": "generic"} + func_args = {"radius": 4.0, "coord_type": "generic", "library_key": "image_id"} coordinates[:25, 0] += 100 adata_gt = ad.AnnData( np.random.rand(50, 2), - obs={"cell_type": ["a"] * 25 + ["b"] * 25, "image_id": list("cd" * 25)}, + obs={"cell_type": ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5, "image_id": list("xy" * 20) + ["z"] * 10}, obsm={"spatial_init": coordinates}, ) a2d = ann2data.Ann2DataByCategory( - fields={"x": ["X"], "edge_index": ["uns/edge_index"], "edge_weight": ["uns/edge_weight"]}, + fields={ + "x": ["X"], + "obs_names": ["obs_names"], + "var_names": ["var_names"], + "edge_index": ["uns/edge_index"], + "edge_weight": ["uns/edge_weight"], + "y": ["obs/cell_type"], + }, category="cell_type", preprocess=transforms.Categorize(keys=["cell_type", "image_id"]), transform=transforms.AddEdgeIndex( @@ -30,7 +37,7 @@ def test_sample_case_ann2data_basic(): ), ) datas = list(a2d(adata_gt.copy())) - assert len(datas) == 2 + assert len(datas) == 3 big_adata_tf = transforms.Compose( [ transforms.Categorize(keys=["cell_type", "image_id"]),