Skip to content

Commit

Permalink
add num_ensemble_members property to BaseDatastore
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Dec 20, 2024
1 parent b690563 commit a37dc3c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
13 changes: 13 additions & 0 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,19 @@ def num_grid_points(self) -> int:
"""
pass

@property
@abc.abstractmethod
def num_ensemble_members(self) -> int:
    """Number of ensemble members held by the dataset.

    This is an abstract property: the body here is intentionally empty and
    concrete datastores are expected to provide the actual value.

    Returns
    -------
    int
        The number of ensemble members in the dataset.

    """

@cached_property
@abc.abstractmethod
def state_feature_weights_values(self) -> List[float]:
Expand Down
6 changes: 4 additions & 2 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def __init__(
self._root_path = self._config_path.parent
self._config = NpyDatastoreConfig.from_yaml_file(self._config_path)

self._num_ensemble_members = self.config.dataset.num_ensemble_members
self._num_timesteps = self.config.dataset.num_timesteps
self._step_length = self.config.dataset.step_length
self._remove_state_features_with_index = (
Expand Down Expand Up @@ -199,6 +198,9 @@ def config(self) -> NpyDatastoreConfig:
"""
return self._config

@property
def num_ensemble_members(self) -> int:
    """Return the number of ensemble members in the dataset.

    This must be a property rather than a plain method: it is read as an
    attribute (e.g. ``range(self.num_ensemble_members)`` when loading the
    state category), so without the decorator callers would receive a
    bound method instead of an integer.

    Returns
    -------
    int
        The value of ``self.config.dataset.num_ensemble_members``.

    """
    return self.config.dataset.num_ensemble_members

def get_dataarray(self, category: str, split: str) -> DataArray:
"""
Get the data array for the given category and split of data. If the
Expand Down Expand Up @@ -230,7 +232,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
if category == "state":
das = []
# for the state category, we need to load all ensemble members
for member in range(self._num_ensemble_members):
for member in range(self.num_ensemble_members):
da_member = self._get_single_timeseries_dataarray(
features=self.get_vars_names(category="state"),
split=split,
Expand Down
18 changes: 9 additions & 9 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,21 +217,21 @@ def __init__(
self.da_boundary_std = self.ds_boundary_stats.forcing_std

def __len__(self):

if self.datastore.is_ensemble:
warnings.warn(
"only using first ensemble member, so dataset size is "
" effectively reduced by the number of ensemble members "
f"({self.datastore.num_ensemble_members})",
UserWarning,
)

if self.datastore.is_forecast:
# for now we simply create a single sample for each analysis time
# and then take the first (2 + ar_steps) forecast times. In
# addition we only use the first ensemble member (if ensemble data
# has been provided).
# This means that for each analysis time we get a single sample

if self.datastore.is_ensemble:
warnings.warn(
"only using first ensemble member, so dataset size is "
" effectively reduced by the number of ensemble members "
f"({self.datastore._num_ensemble_members})",
UserWarning,
)

# check that there are enough forecast steps available to create
# samples given the number of autoregressive steps requested
n_forecast_steps = self.da_state.elapsed_forecast_duration.size
Expand Down

0 comments on commit a37dc3c

Please sign in to comment.