diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 83610a6c..5457bed9 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -295,7 +295,7 @@ def get_xy_extent(self, category: str) -> List[float]: The extent of the x, y coordinates. """ - xy = self.get_xy(category, stacked=False) + xy = self.get_xy(category) extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] return [float(v) for v in extent] @@ -463,7 +463,22 @@ def stack_grid_coords( if "grid_index" in da_or_ds.dims: return da_or_ds - return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) + da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) + # find the feature dimension, which has named with the format + # `{category}_feature` + potential_feature_dims = [ + d for d in da_or_ds_stacked.dims if d.endswith("_feature") + ] + if not len(potential_feature_dims) == 1: + raise ValueError( + "Expected exactly one feature dimension in the stacked data, " + f"got {potential_feature_dims}" + ) + feature_dim = potential_feature_dims[0] + + # ensure that grid_index is the first dimension, and the feature + # dimension is the second + return da_or_ds_stacked.transpose("grid_index", feature_dim, ...) @property @functools.lru_cache diff --git a/tests/test_datastores.py b/tests/test_datastores.py index c0d69ec0..00cd508e 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -310,22 +310,20 @@ def get_grid_shape_state(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_stacking_grid_coords(datastore_name): +@pytest.mark.parametrize("category", ["state", "forcing", "static"]) +def test_stacking_grid_coords(datastore_name, category): """Check that the `datastore.stack_grid_coords` method is implemented.""" datastore = init_datastore_example(datastore_name) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip("Datastore does not implement `BaseCartesianDatastore`") - da_static = datastore.get_dataarray("static", split=None) + da_static = datastore.get_dataarray(category=category, split=None) da_static_unstacked = datastore.unstack_grid_coords(da_static).load() da_static_test = datastore.stack_grid_coords(da_static_unstacked) - # XXX: for the moment unstacking doesn't guarantee the order of the - # dimensions maybe we should enforce this? - da_static_test = da_static_test.transpose(*da_static.dims) - + assert da_static.dims == da_static_test.dims xr.testing.assert_equal(da_static, da_static_test)