Skip to content

Commit

Permalink
obs_names and var_names add to fields
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescaDr committed Jun 1, 2024
1 parent f2a58ed commit f3d94b0
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/geome/ann2data/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _convert_to_tensor(self, obj):
if obj.dtype.name == "category":
return torch.from_numpy(pd.get_dummies(obj).to_numpy()).to(torch.float)
if not np.issubdtype(obj.dtype, np.number):
return torch.from_numpy(obj.astype(np.float)).to(torch.float)
return torch.from_numpy(obj.astype(np.float64)).to(torch.float)
if isinstance(obj, np.ndarray):
return torch.from_numpy(obj).to(torch.float)
else:
Expand Down
5 changes: 5 additions & 0 deletions src/geome/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def get_from_loc(adata: AnnData, location: str) -> Any:
"""
if location == "X":
return adata.X
elif location == "obs_names":
return adata.obs_names.to_numpy()
elif location == "var_names":
return adata.var_names.to_numpy()

assert len(location.split("/")) == 2, f"Location must have only one delimiter {location}"
axis, key = location.split("/")

Expand Down
15 changes: 11 additions & 4 deletions tests/ann2data/test_ann2data_by_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,22 @@ def test_sample_case_ann2data_basic():
# make sure that there are two clusters of spatial coordinates
# so that the resulting splits number of edges will be the same
# as the sum of the number of edges in each cluster
func_args = {"radius": 4.0, "coord_type": "generic"}
func_args = {"radius": 4.0, "coord_type": "generic", "library_key": "image_id"}
coordinates[:25, 0] += 100
adata_gt = ad.AnnData(
np.random.rand(50, 2),
obs={"cell_type": ["a"] * 25 + ["b"] * 25, "image_id": list("cd" * 25)},
obs={"cell_type": ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5, "image_id": list("xy" * 20) + ["z"] * 10},
obsm={"spatial_init": coordinates},
)
a2d = ann2data.Ann2DataByCategory(
fields={"x": ["X"], "edge_index": ["uns/edge_index"], "edge_weight": ["uns/edge_weight"]},
fields={
"x": ["X"],
"obs_names": ["obs_names"],
"var_names": ["var_names"],
"edge_index": ["uns/edge_index"],
"edge_weight": ["uns/edge_weight"],
"y": ["obs/cell_type"],
},
category="cell_type",
preprocess=transforms.Categorize(keys=["cell_type", "image_id"]),
transform=transforms.AddEdgeIndex(
Expand All @@ -30,7 +37,7 @@ def test_sample_case_ann2data_basic():
),
)
datas = list(a2d(adata_gt.copy()))
assert len(datas) == 2
assert len(datas) == 3
big_adata_tf = transforms.Compose(
[
transforms.Categorize(keys=["cell_type", "image_id"]),
Expand Down

0 comments on commit f3d94b0

Please sign in to comment.