From a37dc3ceddfdb9421767528a03e08147bbe4a185 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:03:22 +0100 Subject: [PATCH] add num_ensemble_members property to BaseDatastore --- neural_lam/datastore/base.py | 13 +++++++++++++ neural_lam/datastore/npyfilesmeps/store.py | 8 ++++++-- neural_lam/weather_dataset.py | 18 +++++++++--------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index b9de2da5..84600b50 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -298,6 +298,19 @@ def num_grid_points(self) -> int: """ pass + @property + @abc.abstractmethod + def num_ensemble_members(self) -> int: + """Return the number of ensemble members in the dataset. + + Returns + ------- + int + The number of ensemble members in the dataset. + + """ + pass + @cached_property @abc.abstractmethod def state_feature_weights_values(self) -> List[float]: diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 24349e7e..b91f7291 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -166,7 +166,6 @@ def __init__( self._root_path = self._config_path.parent self._config = NpyDatastoreConfig.from_yaml_file(self._config_path) - self._num_ensemble_members = self.config.dataset.num_ensemble_members self._num_timesteps = self.config.dataset.num_timesteps self._step_length = self.config.dataset.step_length self._remove_state_features_with_index = ( @@ -199,6 +198,11 @@ def config(self) -> NpyDatastoreConfig: """ return self._config + @property + def num_ensemble_members(self) -> int: + """Return the number of ensemble members configured for the dataset.""" + return self.config.dataset.num_ensemble_members + def get_dataarray(self, category: str, split: str) -> DataArray: """ Get the data array for the given category and split of data.
If the @@ -230,7 +234,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: if category == "state": das = [] # for the state category, we need to load all ensemble members - for member in range(self._num_ensemble_members): + for member in range(self.num_ensemble_members): da_member = self._get_single_timeseries_dataarray( features=self.get_vars_names(category="state"), split=split, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c6b142ec..d2fb5921 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -217,21 +217,21 @@ def __init__( self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): + + if self.datastore.is_ensemble: + warnings.warn( + "only using first ensemble member, so dataset size is " + " effectively reduced by the number of ensemble members " + f"({self.datastore.num_ensemble_members})", + UserWarning, + ) + if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time # and then take the first (2 + ar_steps) forecast times. In # addition we only use the first ensemble member (if ensemble data # has been provided). # This means that for each analysis time we get a single sample - - if self.datastore.is_ensemble: - warnings.warn( - "only using first ensemble member, so dataset size is " - " effectively reduced by the number of ensemble members " - f"({self.datastore._num_ensemble_members})", - UserWarning, - ) - # check that there are enough forecast steps available to create # samples given the number of autoregressive steps requested n_forecast_steps = self.da_state.elapsed_forecast_duration.size