diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index b06bf67b7..29a44aed0 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1032,6 +1032,7 @@ def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None: return if key not in self.adata.obs: raise KeyError(f"Unable to find temporal key in `adata.obs[{key!r}]`.") + self.adata.obs[key] = self.adata.obs[key].astype("category") col = self.adata.obs[key] if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)): raise TypeError( diff --git a/tests/data/moscot_temporal_tests.h5ad b/tests/data/moscot_temporal_tests.h5ad index 1303ab60f..b40d70ecc 100644 Binary files a/tests/data/moscot_temporal_tests.h5ad and b/tests/data/moscot_temporal_tests.h5ad differ