Skip to content

Commit

Permalink
get_vars_names and units
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Jul 17, 2024
1 parent 3c864b2 commit 1f54b0e
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 23 deletions.
4 changes: 2 additions & 2 deletions neural_lam/datastore/mllam.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def step_length(self) -> int:
return da_dt.dt.seconds[0] // 3600

def get_vars_units(self, category: str) -> List[str]:
return self._ds[f"{category}_unit"].values.tolist()
return self._ds[f"{category}_feature_units"].values.tolist()

def get_vars_names(self, category: str) -> List[str]:
return self._ds[f"{category}_longname"].values.tolist()
return self._ds[f"{category}_feature"].values.tolist()

def get_num_data_vars(self, category: str) -> int:
return self._ds[f"{category}_feature"].count().item()
Expand Down
16 changes: 13 additions & 3 deletions neural_lam/datastore/npyfiles/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,17 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
def get_vars_units(self, category: str) -> torch.List[str]:
if category == "state":
return self.config["dataset"]["var_units"]
elif category == "forcing":
return [
"W/m^2",
"kg/m^2",
"1",
"1",
"1",
"1",
]
elif category == "static":
return ["m^2/s^2", "1", "m", "m"]
else:
raise NotImplementedError(f"Category {category} not supported")

Expand All @@ -471,9 +482,8 @@ def get_vars_names(self, category: str) -> torch.List[str]:
else:
raise NotImplementedError(f"Category {category} not supported")

@property
def get_num_data_vars(self) -> int:
return len(self.get_vars_names(category="state"))
def get_num_data_vars(self, category: str) -> int:
return len(self.get_vars_names(category=category))

def get_xy(self, category: str, stacked: bool) -> np.ndarray:
"""Return the x, y coordinates of the dataset.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
"mllam-data-prep @ git+https://github.com/mllam/mllam-data-prep",
]
requires-python = ">=3.9"

Expand Down
31 changes: 16 additions & 15 deletions tests/datastore_configs/mllam/example.danra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@ output:
step: PT3H
chunking:
time: 1
splitting_dim: time
splits:
train:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
test:
start: 1990-09-07T00:00
end: 1990-09-09T00:00
splitting:
dim: time
splits:
train:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
test:
start: 1990-09-07T00:00
end: 1990-09-09T00:00

inputs:
danra_height_levels:
Expand Down Expand Up @@ -59,7 +60,7 @@ inputs:
dims: [time, x, y]
variables:
# shouldn't really be using sea-surface pressure as "forcing", but don't
# have radiation varibles in danra yet
# have radiation variables in danra yet
- pres_seasurface
dim_mapping:
time:
Expand Down
71 changes: 68 additions & 3 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,30 @@
"""List of methods and attributes that should be implemented in a subclass of
`BaseCartesianDatastore` (these are all decorated with `@abc.abstractmethod`):
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
- [ ] `get_vars_units` (method): Get the units of the variables in the given category.
- [ ] `get_vars_names` (method): Get the names of the variables in the given category.
- [ ] `get_num_data_vars` (method): Get the number of data variables in the
given category.
- [ ] `get_normalization_dataarray` (method): Return the normalization
dataarray for the given category.
- [ ] `get_dataarray` (method): Return the processed data (as a single
`xr.DataArray`) for the given category and test/train/val-split.
- [ ] `boundary_mask` (property): Return the boundary mask for the dataset,
with spatial dimensions stacked.
In addition BaseCartesianDatastore must have the following methods and attributes:
- [ ] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
- [ ] `get_xy` (method): Return the x, y coordinates of the dataset.
- [ ] `coords_projection` (property): Projection object for the coordinates.
- [ ] `grid_shape_state` (property): Shape of the grid for the state variables.
"""

# Third-party
import cartopy.crs as ccrs
import pytest

# First-party
Expand All @@ -22,10 +48,17 @@
)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore(datastore_name):
def _init_datastore(datastore_name):
DatastoreClass = DATASTORES[datastore_name]
datastore = DatastoreClass(**EXAMPLES[datastore_name])
return DatastoreClass(**EXAMPLES[datastore_name])


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore_grid_xy(datastore_name):
"""Use the `datastore.get_xy` method to get the x, y coordinates of the
dataset and check that the shape is correct against the
`datastore.grid_shape_state` property."""
datastore = _init_datastore(datastore_name)

# check the shapes of the xy grid
grid_shape = datastore.grid_shape_state
Expand All @@ -40,3 +73,35 @@ def test_datastore(datastore_name):
assert xy.shape == (2, nx * ny)
else:
assert xy.shape == (2, ny, nx)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_projection(datastore_name):
"""Check that the `datastore.coords_projection` property is implemented."""
datastore = _init_datastore(datastore_name)

assert isinstance(datastore.coords_projection, ccrs.Projection)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_vars(datastore_name):
"""Check that results of.
- `datastore.get_vars_units`
- `datastore.get_vars_names`
- `datastore.get_num_data_vars`
are consistent (as in the number of variables are the same) and that the
return types of each are correct.
"""
datastore = _init_datastore(datastore_name)

for category in ["state", "forcing", "static"]:
units = datastore.get_vars_units(category)
names = datastore.get_vars_names(category)
num_vars = datastore.get_num_data_vars(category)

assert len(units) == len(names) == num_vars
assert isinstance(units, list)
assert isinstance(names, list)
assert isinstance(num_vars, int)

0 comments on commit 1f54b0e

Please sign in to comment.