Skip to content

Commit

Permalink
ensure dimension order from BaseRegularGridDatastore.stack_grid_coords
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Nov 12, 2024
1 parent e23e110 commit 2cf64ac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
19 changes: 17 additions & 2 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
xy = self.get_xy(category)
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
return [float(v) for v in extent]

Expand Down Expand Up @@ -463,7 +463,22 @@ def stack_grid_coords(
if "grid_index" in da_or_ds.dims:
return da_or_ds

return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
# find the feature dimension, which has named with the format
# `{category}_feature`
potential_feature_dims = [
d for d in da_or_ds_stacked.dims if d.endswith("_feature")
]
if not len(potential_feature_dims) == 1:
raise ValueError(
"Expected exactly one feature dimension in the stacked data, "
f"got {potential_feature_dims}"
)
feature_dim = potential_feature_dims[0]

# ensure that grid_index is the first dimension, and the feature
# dimension is the second
return da_or_ds_stacked.transpose("grid_index", feature_dim, ...)

@property
@functools.lru_cache
Expand Down
10 changes: 4 additions & 6 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,22 +310,20 @@ def get_grid_shape_state(datastore_name):


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_stacking_grid_coords(datastore_name):
@pytest.mark.parametrize("category", ["state", "forcing", "static"])
def test_stacking_grid_coords(datastore_name, category):
"""Check that the `datastore.stack_grid_coords` method is implemented."""
datastore = init_datastore_example(datastore_name)

if not isinstance(datastore, BaseRegularGridDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")

da_static = datastore.get_dataarray("static", split=None)
da_static = datastore.get_dataarray(category=category, split=None)

da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
da_static_test = datastore.stack_grid_coords(da_static_unstacked)

# XXX: for the moment unstacking doesn't guarantee the order of the
# dimensions maybe we should enforce this?
da_static_test = da_static_test.transpose(*da_static.dims)

assert da_static.dims == da_static_test.dims
xr.testing.assert_equal(da_static, da_static_test)


Expand Down

0 comments on commit 2cf64ac

Please sign in to comment.