From 5df1bff46f22f818a01389f2d8bf5148d822bde9 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 19:46:37 +0100 Subject: [PATCH 001/103] add datastore_boundary to neural_lam --- neural_lam/train_model.py | 22 ++++++++++++++++++++++ neural_lam/weather_dataset.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..37bf6db7 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,6 +34,11 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) + parser.add_argument( + "--config_path_boundary", + type=str, + help="Path to the configuration for boundary conditions", + ) parser.add_argument( "--model", type=str, @@ -212,6 +217,9 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" + assert ( + args.config_path_boundary is not None + ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -227,10 +235,24 @@ def main(input_args=None): # Load neural-lam configuration and datastore to use config, datastore = load_config_and_datastore(config_path=args.config_path) + config_boundary, datastore_boundary = load_config_and_datastore( + config_path=args.config_path_boundary + ) + + # TODO this should not be required, make more flexible + assert ( + datastore.num_past_forcing_steps + == datastore_boundary.num_past_forcing_steps + ), "Mismatch in num_past_forcing_steps" + assert ( + datastore.num_future_forcing_steps + == datastore_boundary.num_future_forcing_steps + ), "Mismatch in num_future_forcing_steps" # Create datamodule data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, standardize=True, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..51256e41 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -22,6 +22,8 @@ class WeatherDataset(torch.utils.data.Dataset): ---------- datastore : BaseDatastore The datastore to load the data from (e.g. mdp). + datastore_boundary : BaseDatastore + The boundary datastore to load the data from (e.g. mdp). split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional @@ -43,6 +45,7 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, num_past_forcing_steps=1, @@ -54,6 +57,7 @@ def __init__( self.split = split self.ar_steps = ar_steps self.datastore = datastore + self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps @@ -605,6 +609,7 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, standardize=True, @@ -615,6 +620,7 @@ def __init__( ): super().__init__() self._datastore = datastore + self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps self.ar_steps_train = ar_steps_train @@ -626,6 +632,7 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: + # BUG: There also seem to be issues with "spawn", to be investigated # default to spawn for now, as the default on linux "fork" hangs # when using dask (which the npyfilesmeps datastore uses) self.multiprocessing_context = "spawn" @@ -636,6 +643,7 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, standardize=self.standardize, @@ -644,6 +652,7 @@ def setup(self, stage=None): ) self.val_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, standardize=self.standardize, @@ -654,6 +663,7 @@ def setup(self, stage=None): if stage == "test" or stage is None: self.test_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, standardize=self.standardize, From 46590efc277cb809d788ce5af44133f8b95eb279 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:41 +0100 Subject: [PATCH 002/103] complete integration of boundary in weatherDataset --- neural_lam/weather_dataset.py | 55 ++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 51256e41..10b74086 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,6 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + self.da_boundary = self.datastore_boundary.get_dataarray( + category="boundary", split=self.split + ) # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -118,6 +121,15 @@ def __init__( self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + if self.da_boundary is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="boundary" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.boundary_mean + self.da_boundary_std = self.ds_boundary_stats.boundary_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -352,6 +364,8 @@ def _build_item_dataarrays(self, idx): The dataarray for the target states. da_forcing_windowed : xr.DataArray The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -381,6 +395,11 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None + if self.da_boundary is not None: + da_boundary = self.da_boundary + else: + da_boundary = None + # handle time sampling in a way that is compatible with both analysis # and forecast data da_state = self._slice_state_time( @@ -390,11 +409,17 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed = self._slice_forcing_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) + if da_boundary is not None: + da_boundary_windowed = self._slice_forcing_time( + da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + ) # load the data into memory da_state.load() if da_forcing is not None: da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() da_init_states = da_state.isel(time=slice(0, 2)) da_target_states = da_state.isel(time=slice(2, None)) @@ -417,6 +442,11 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed - self.da_forcing_mean ) / self.da_forcing_std + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + if da_forcing is not None: # stack the `forcing_feature` and `window_sample` dimensions into a # single `forcing_feature` dimension @@ -436,11 +466,31 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: + # stack the `forcing_feature` and `window_sample` dimensions into a + # single `forcing_feature` dimension + da_boundary_windowed = da_boundary_windowed.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + else: + # create an empty forcing tensor with the right shape + da_boundary_windowed = xr.DataArray( + data=np.empty( + (self.ar_steps, da_state.grid_index.size, 0), + ), + dims=("time", "grid_index", "boundary_feature"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + "boundary_feature": [], + }, + ) return ( da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) @@ -475,6 +525,7 @@ def __getitem__(self, idx): da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) = self._build_item_dataarrays(idx=idx) @@ -491,13 +542,15 @@ def __getitem__(self, idx): ) forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) - return init_states, target_states, forcing, target_times + return init_states, target_states, forcing, boundary, target_times def __iter__(self): """ From b990f4941bd7167160a2f265b1e9fe17026ed31e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:55 +0100 Subject: [PATCH 003/103] Add test to check timestep length and spacing --- neural_lam/weather_dataset.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 10b74086..97d9f9c3 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -101,6 +101,82 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # Check time coverage for forcing and boundary data + if self.da_forcing is not None or self.da_boundary is not None: + state_times = self.da_state.time + state_time_min = state_times.min().values + state_time_max = state_times.max().values + + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + if self.da_forcing is not None: + forcing_times = self.da_forcing.time + forcing_time_step = get_time_step(forcing_times.values) + forcing_time_min = forcing_times.min().values + forcing_time_max = forcing_times.max().values + + # Calculate required bounds for forcing using its time step + forcing_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * forcing_time_step + ) + forcing_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * forcing_time_step + ) + + if forcing_time_min > forcing_required_time_min: + raise ValueError( + f"Forcing data starts too late." + f"Required start: {forcing_required_time_min}, " + f"but forcing starts at {forcing_time_min}." + ) + + if forcing_time_max < forcing_required_time_max: + raise ValueError( + f"Forcing data ends too early." + f"Required end: {forcing_required_time_max}," + f"but forcing ends at {forcing_time_max}." + ) + + if self.da_boundary is not None: + boundary_times = self.da_boundary.time + boundary_time_step = get_time_step(boundary_times.values) + boundary_time_min = boundary_times.min().values + boundary_time_max = boundary_times.max().values + + # Calculate required bounds for boundary using its time step + boundary_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * boundary_time_step + ) + boundary_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * boundary_time_step + ) + + if boundary_time_min > boundary_required_time_min: + raise ValueError( + f"Boundary data starts too late." + f"Required start: {boundary_required_time_min}, " + f"but boundary starts at {boundary_time_min}." + ) + + if boundary_time_max < boundary_required_time_max: + raise ValueError( + f"Boundary data ends too early." + f"Required end: {boundary_required_time_max}, " + f"but boundary ends at {boundary_time_max}." + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize From 3fd1d6be82d0174b106922a7ff9c74255bac5a35 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:43:57 +0100 Subject: [PATCH 004/103] setting default mdp boundary to 0 gridcells --- neural_lam/datastore/mdp.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a82..8c67fe58 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -26,7 +26,7 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): + def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at `config_path`. A boundary mask is created with `n_boundary_points` @@ -335,19 +335,22 @@ def boundary_mask(self) -> xr.DataArray: boundary point and 0 is not. """ - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + if self._n_boundary_points > 0: + ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) + da_state_variable = ( + ds_unstacked["state"].isel(time=0).isel(state_feature=0) + ) + da_domain_allzero = xr.zeros_like(da_state_variable) + ds_unstacked["boundary_mask"] = da_domain_allzero.isel( + x=slice(self._n_boundary_points, -self._n_boundary_points), + y=slice(self._n_boundary_points, -self._n_boundary_points), + ) + ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( + 1 + ).astype(int) + return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + else: + return None @property def coords_projection(self) -> ccrs.Projection: From 1f2499c3b3fb8493b89d2be97ff301181c756f72 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:44:54 +0100 Subject: [PATCH 005/103] implement time-based slicing combine two slicing fcts into one --- neural_lam/weather_dataset.py | 300 ++++++++++++++++++---------------- 1 file changed, 161 insertions(+), 139 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 97d9f9c3..5d35a4b7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,8 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + # XXX For now boundary data is always considered forcing data self.da_boundary = self.datastore_boundary.get_dataarray( - category="boundary", split=self.split + category="forcing", split=self.split ) # check that with the provided data-arrays and ar_steps that we have a @@ -200,7 +201,7 @@ def get_time_step(times): if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( - category="boundary" + category="forcing" ) ) self.da_boundary_mean = self.ds_boundary_stats.boundary_mean @@ -252,175 +253,156 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_state_time(self, da_state, idx, n_steps: int): + def _slice_time(self, da_state, da_forcing, idx, n_steps: int): """ - Produce a time slice of the given dataarray `da_state` (state) starting - at `idx` and with `n_steps` steps. An `offset`is calculated based on the - `num_past_forcing_steps` class attribute. `Offset` is used to offset the - start of the sample, to assert that enough previous time steps are - available for the 2 initial states and any corresponding forcings - (calculated in `_slice_forcing_time`). + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing` (forcing). For the state data, slicing is done as before + based on `idx`. For the forcing data, nearest neighbor matching is + performed based on the state times. Additionally, the time difference + between the matched forcing times and state times (in multiples of state + time steps) is added to the forcing dataarray. Parameters ---------- da_state : xr.DataArray - The dataarray to slice. This is expected to have a `time` dimension - if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. + The state dataarray to slice. + da_forcing : xr.DataArray + The forcing dataarray to slice. idx : int - The index of the time step to start the sample from. + The index of the time step to start the sample from in the state + data. n_steps : int The number of time steps to include in the sample. Returns ------- - da_sliced : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). + da_forcing_matched : xr.DataArray + The forcing dataarray matched to state times with an added + coordinate 'time_diff', representing the time difference to state + times in multiples of state time steps. """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). + # Number of initial steps required (e.g., for initializing models) init_steps = 2 - # slice the dataarray to include the required number of time steps + + # Slice the state data as before if self.datastore.is_forecast: + # Calculate start and end indices for slicing start_idx = max(0, self.num_past_forcing_steps - init_steps) end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps - # this implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select a analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast, always starting at forecast time 2. - da_sliced = da_state.isel( + + # Slice the state data over the elapsed forecast duration + da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - # create a new time dimension so that the produced sample has a - # `time` dimension, similarly to the analysis only data - da_sliced["time"] = ( - da_sliced.analysis_time + da_sliced.elapsed_forecast_duration + + # Create a new 'time' dimension + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration ) - da_sliced = da_sliced.swap_dims( + da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + else: - # For analysis data we slice the time dimension directly. The offset - # is only relevant for the very first (and last) samples in the - # dataset. + # For analysis data, slice the time dimension directly start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) end_idx = ( idx + max(init_steps, self.num_past_forcing_steps) + n_steps ) - da_sliced = da_state.isel(time=slice(start_idx, end_idx)) - return da_sliced + da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - def _slice_forcing_time(self, da_forcing, idx, n_steps: int): - """ - Produce a time slice of the given dataarray `da_forcing` (forcing) - starting at `idx` and with `n_steps` steps. An `offset` is calculated - based on the `num_past_forcing_steps` class attribute. It is used to - offset the start of the sample, to ensure that enough previous time - steps are available for the forcing data. The forcing data is windowed - around the current autoregressive time step to include the past and - future forcings. - - Parameters - ---------- - da_forcing : xr.DataArray - The forcing dataarray to slice. This is expected to have a `time` - dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. - idx : int - The index of the time step to start the sample from. - n_steps : int - The number of time steps to include in the sample. - - Returns - ------- - da_concat : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', - 'window', 'forcing_feature'). - """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). The forcing data is windowed around the - # current autregressive time step. The two `init_steps` can also be used - # as past forcings. - init_steps = 2 - da_list = [] + # Get the state times for matching + state_times = da_state_sliced["time"] + # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: - # This implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select an analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast. - # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - current_time = ( - da_forcing.analysis_time[idx] - + da_forcing.elapsed_forecast_duration[offset + step] - ) - - da_sliced = da_forcing.isel( - analysis_time=idx, - elapsed_forecast_duration=slice(start_idx, end_idx + 1), - ) - - da_sliced = da_sliced.rename( - {"elapsed_forecast_duration": "window"} - ) + # Calculate all possible forcing times + forcing_times = ( + da_forcing.analysis_time + da_forcing.elapsed_forecast_duration + ) + forcing_times_flat = forcing_times.stack( + forecast_time=("analysis_time", "elapsed_forecast_duration") + ) - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # Compute time differences + time_deltas = ( + forcing_times_flat.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) + + # Retrieve corresponding indices for analysis_time and + # elapsed_forecast_duration + forecast_time_index = forcing_times_flat["forecast_time"][idx_min] + analysis_time_indices = forecast_time_index["analysis_time"] + elapsed_forecast_duration_indices = forecast_time_index[ + "elapsed_forecast_duration" + ] + + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel( + analysis_time=("time", analysis_time_indices), + elapsed_forecast_duration=( + "time", + elapsed_forecast_duration_indices, + ), + ) - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Assign matched state times to the forcing data + da_forcing_matched["time"] = state_times + da_forcing_matched = da_forcing_matched.swap_dims( + {"elapsed_forecast_duration": "time"} + ) - da_list.append(da_sliced) + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) else: - # For analysis data, we slice the time dimension directly. The - # offset is only relevant for the very first (and last) samples in - # the dataset. - offset = idx + max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - # Slice the data over the desired time window - da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1)) - - da_sliced = da_sliced.rename({"time": "window"}) - - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # For analysis data, match directly using the 'time' coordinate + forcing_times = da_forcing["time"] - # Add a 'time' dimension to keep track of steps using actual - # time coordinates - current_time = da_forcing.time[offset + step] - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Compute time differences + time_deltas = ( + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) - da_list.append(da_sliced) + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel(time=idx_min) + da_forcing_matched = da_forcing_matched.assign_coords( + time=state_times + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - return da_concat + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) + + return da_state_sliced, da_forcing_matched def _build_item_dataarrays(self, idx): """ @@ -442,6 +424,7 @@ def _build_item_dataarrays(self, idx): The dataarray for the forcing data, windowed for the sample. da_boundary_windowed : xr.DataArray The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -478,15 +461,15 @@ def _build_item_dataarrays(self, idx): # handle time sampling in a way that is compatible with both analysis # and forecast data - da_state = self._slice_state_time( + da_state = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps ) if da_forcing is not None: - da_forcing_windowed = self._slice_forcing_time( + da_forcing_windowed = self._slice_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) if da_boundary is not None: - da_boundary_windowed = self._slice_forcing_time( + da_boundary_windowed = self._slice_time( da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps ) @@ -524,13 +507,32 @@ def _build_item_dataarrays(self, idx): ) / self.da_boundary_std if da_forcing is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # Expand 'time_diff' to align with 'forcing_feature' and 'window' + # dimensions 'time_diff' has dimension ('time'), expand to ('time', + # 'forcing_feature', 'window') + time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( + forcing_feature=da_forcing_windowed["forcing_feature"], + window=da_forcing_windowed["window"], + ) + + # Stack 'forcing_feature' and 'window' into a single + # 'forcing_feature_windowed' dimension da_forcing_windowed = da_forcing_windowed.stack( forcing_feature_windowed=("forcing_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + forcing_feature_windowed=("forcing_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' + da_forcing_windowed = da_forcing_windowed.assign_coords( + time_diff=( + "forcing_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty forcing tensor with the right shape da_forcing_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), @@ -542,14 +544,34 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # If 'da_boundary_windowed' also has 'time_diff', process similarly + # Expand 'time_diff' to align with 'boundary_feature' and 'window' + # dimensions + time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( + boundary_feature=da_boundary_windowed["boundary_feature"], + window=da_boundary_windowed["window"], + ) + + # Stack 'boundary_feature' and 'window' into a single + # 'boundary_feature_windowed' dimension da_boundary_windowed = da_boundary_windowed.stack( boundary_feature_windowed=("boundary_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' + da_boundary_windowed = da_boundary_windowed.assign_coords( + time_diff=( + "boundary_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty boundary tensor with the right shape da_boundary_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), From 1af1481e6884f89ccf39befa37e0d61ed16bbcc3 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 06:26:54 +0100 Subject: [PATCH 006/103] remove all interior_mask and boundary_mask --- neural_lam/datastore/base.py | 17 ------- neural_lam/datastore/mdp.py | 34 -------------- neural_lam/datastore/npyfilesmeps/store.py | 28 ------------ neural_lam/models/ar_model.py | 53 ++++------------------ neural_lam/vis.py | 16 ------- tests/dummy_datastore.py | 22 --------- tests/test_datastores.py | 21 --------- 7 files changed, 10 insertions(+), 181 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e5..5aeedb2e 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -228,23 +228,6 @@ def get_dataarray( """ pass - @cached_property - @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - pass - @abc.abstractmethod def get_xy(self, category: str) -> np.ndarray: """ diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 8c67fe58..5365c723 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -318,40 +318,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) return ds_stats - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Produce a 0/1 mask for the boundary points of the dataset, these will - sit at the edges of the domain (in x/y extent) and will be used to mask - out the boundary points from the loss function and to overwrite the - boundary points from the prediction. For now this is created when the - mask is requested, but in the future this could be saved to the zarr - file. - - Returns - ------- - xr.DataArray - A 0/1 mask for the boundary points of the dataset, where 1 is a - boundary point and 0 is not. - - """ - if self._n_boundary_points > 0: - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) - else: - return None - @property def coords_projection(self) -> ccrs.Projection: """ diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e80706..146b0627 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -668,34 +668,6 @@ def grid_shape_state(self) -> CartesianGridShape: ny, nx = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions `[grid_index]`. - - """ - xy = self.get_xy(category="state", stacked=False) - xs = xy[:, :, 0] - ys = xy[:, :, 1] - # Check if x-coordinates are constant along columns - assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" - # Check if y-coordinates are constant along rows - assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" - # Extract unique x and y coordinates - x = xs[:, 0] # Unique x-coordinates (changes along the first axis) - y = ys[0, :] # Unique y-coordinates (changes along the second axis) - values = np.load(self.root_path / "static" / "border_mask.npy") - da_mask = xr.DataArray( - values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" - ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) - return da_mask_stacked_xy - def get_standardization_dataarray(self, category: str) -> xr.Dataset: """Return the standardization dataarray for the given category. This should contain a `{category}_mean` and `{category}_std` variable for diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c6719..4ab73cc7 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -42,7 +42,6 @@ def __init__( da_state_stats = datastore.get_standardization_dataarray( category="state" ) - da_boundary_mask = datastore.boundary_mask num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps @@ -115,18 +114,6 @@ def __init__( # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { "mse": [], } @@ -153,13 +140,6 @@ def configure_optimizers(self): ) return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -191,7 +171,6 @@ def unroll_prediction(self, init_states, forcing_features, true_states): for i in range(pred_steps): forcing = forcing_features[:, i] - border_state = true_states[:, i] pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing @@ -199,19 +178,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states): # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) + prediction_list.append(pred_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 @@ -249,12 +222,14 @@ def training_step(self, batch): """ prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( @@ -287,9 +262,7 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -314,7 +287,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -341,9 +313,7 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) @@ -372,16 +342,13 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b39..31de8f32 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -86,13 +86,6 @@ def plot_prediction( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, axes = plt.subplots( 1, 2, @@ -112,7 +105,6 @@ def plot_prediction( data_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="plasma", @@ -147,13 +139,6 @@ def plot_spatial_error( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, @@ -170,7 +155,6 @@ def plot_spatial_error( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 9075d404..d62c7356 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -148,12 +148,6 @@ def __init__( times = [self.T0 + dt * i for i in range(n_timesteps)] self.ds.coords["time"] = times - # Add boundary mask - self.ds["boundary_mask"] = xr.DataArray( - np.random.choice([0, 1], size=(n_points_1d, n_points_1d)), - dims=["x", "y"], - ) - # Stack the spatial dimensions into grid_index self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) @@ -342,22 +336,6 @@ def get_dataarray( dim_order = self.expected_dim_order(category=category) return self.ds[category].transpose(*dim_order) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - return self.ds["boundary_mask"] - def get_xy(self, category: str, stacked: bool) -> ndarray: """Return the x, y coordinates of the dataset. diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4a4b1100..a91f6245 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -18,8 +18,6 @@ dataarray for the given category. - `get_dataarray` (method): Return the processed data (as a single `xr.DataArray`) for the given category and test/train/val-split. -- `boundary_mask` (property): Return the boundary mask for the dataset, - with spatial dimensions stacked. - `config` (property): Return the configuration of the datastore. In addition BaseRegularGridDatastore must have the following methods and @@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name): assert n_features["train"] == n_features["val"] == n_features["test"] -@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" - datastore = init_datastore_example(datastore_name) - da_mask = datastore.boundary_mask - - assert isinstance(da_mask, xr.DataArray) - assert set(da_mask.dims) == {"grid_index"} - assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size - - if isinstance(datastore, BaseRegularGridDatastore): - grid_shape = datastore.grid_shape_state - assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y - - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): """Check that the `datastore.get_xy_extent` method is implemented and that From d545cb7576de020b7d721c08741e784bc2b69c24 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:55:56 +0100 Subject: [PATCH 007/103] added gcsfs dependency for era5 weatherbench download --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f0bc0851..5bbe4d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard>=0.22.3", - "mllam-data-prep>=0.5.0", + "gcsfs>=2021.10.0", + "mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep@temp/for-neural-lam-datastores", ] requires-python = ">=3.9" From 5c1a7d7cf9a4befb874ce847424787e818cced75 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:57:57 +0100 Subject: [PATCH 008/103] added new era5 datastore config for boundary --- tests/conftest.py | 19 +++- .../mdp/era5_1000hPa_winds/.gitignore | 2 + .../mdp/era5_1000hPa_winds/config.yaml | 3 + .../era5_1000hPa_winds/era5.datastore.yaml | 90 +++++++++++++++++++ 4 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 6f579621..be5cf3e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,15 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) +DATASTORES_BOUNDARY_EXAMPLES = dict( + mdp=( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_winds" + / "era5.datastore.yaml" + ) +) + DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore @@ -102,5 +111,13 @@ def init_datastore_example(datastore_kind): datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], ) - return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml new file mode 100644 index 00000000..5d1e05f2 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml @@ -0,0 +1,3 @@ +datastore: + kind: mdp + config_path: era5.datastore.yaml diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml new file mode 100644 index 00000000..36b39501 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml @@ -0,0 +1,90 @@ +#TODO: What do these versions mean? Should they be updated? +schema_version: v0.2.0+dev +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-02T00:00 + end: 1990-09-10T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-02T00:00 + end: 1990-09-07T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-05T00:00 + end: 1990-09-08T00:00 + test: + start: 1990-09-06T00:00 + end: 1990-09-10T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + v_component_of_wind: + level: + values: [1000, ] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 From 30e4f05e1c9cc726180868450286d9cf8279ce07 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:58:36 +0100 Subject: [PATCH 009/103] removed left-over boundary-mask references --- neural_lam/datastore/mdp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 5365c723..fd9acb4e 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -26,11 +26,10 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): + def __init__(self, config_path, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded + `config_path`. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified since the zarr was created), otherwise it is created from the configuration file. @@ -41,8 +40,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file` method to then call `mllam_data_prep.create_dataset` to create the dataset. - n_boundary_points : int - The number of boundary points to use in the boundary mask. reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. @@ -69,7 +66,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): if self._ds is None: self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) - self._n_boundary_points = n_boundary_points print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: From 6a8c593f422c2844545feb2cc7e57de520dc1062 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:12 +0100 Subject: [PATCH 010/103] make check for existing category in datastore more flexible (for boundary) --- neural_lam/datastore/mdp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index fd9acb4e..67aaa9d0 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -153,8 +153,8 @@ def get_vars_units(self, category: str) -> List[str]: The units of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_units"].values.tolist() @@ -172,8 +172,8 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature"].values.tolist() @@ -192,8 +192,8 @@ def get_vars_long_names(self, category: str) -> List[str]: The long names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_long_name"].values.tolist() @@ -248,9 +248,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The xarray DataArray object with processed dataset. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") - return None + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] da_category = self._ds[category] From 17c920d36848d61153fd53781d8ec3ac90e5de56 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 20 Nov 2024 16:00:15 +0100 Subject: [PATCH 011/103] implement xarray based (mostly) time slicing and windowing --- neural_lam/weather_dataset.py | 255 +++++++++++++++------------------- 1 file changed, 111 insertions(+), 144 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5d35a4b7..c8806d1c 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -64,10 +64,16 @@ def __init__( self.da_state = self.datastore.get_dataarray( category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) - # XXX For now boundary data is always considered forcing data + # XXX For now boundary data is always considered mdp-forcing data self.da_boundary = self.datastore_boundary.get_dataarray( category="forcing", split=self.split ) @@ -102,53 +108,36 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + # Check time step consistency in state data + _ = get_time_step(self.da_state.time.values) + # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values - def get_time_step(times): - """Calculate the time step from the data""" - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] - if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data forcing_times = self.da_forcing.time - forcing_time_step = get_time_step(forcing_times.values) - forcing_time_min = forcing_times.min().values - forcing_time_max = forcing_times.max().values - - # Calculate required bounds for forcing using its time step - forcing_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * forcing_time_step - ) - forcing_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * forcing_time_step - ) - - if forcing_time_min > forcing_required_time_min: - raise ValueError( - f"Forcing data starts too late." - f"Required start: {forcing_required_time_min}, " - f"but forcing starts at {forcing_time_min}." - ) - - if forcing_time_max < forcing_required_time_max: - raise ValueError( - f"Forcing data ends too early." - f"Required end: {forcing_required_time_max}," - f"but forcing ends at {forcing_time_max}." - ) + _ = get_time_step(forcing_times.values) if self.da_boundary is not None: + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values @@ -204,8 +193,8 @@ def get_time_step(times): category="forcing" ) ) - self.da_boundary_mean = self.ds_boundary_stats.boundary_mean - self.da_boundary_std = self.ds_boundary_stats.boundary_std + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): if self.datastore.is_forecast: @@ -253,7 +242,7 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, da_forcing, idx, n_steps: int): + def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and `da_forcing` (forcing). For the state data, slicing is done as before @@ -316,8 +305,13 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): ) da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) + if da_forcing is None: + return da_state_sliced, None + # Get the state times for matching state_times = da_state_sliced["time"] + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: @@ -371,39 +365,80 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): da_forcing_matched = da_forcing_matched.assign_coords( time_diff=("time", time_diff_steps) ) - else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] # Compute time differences time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + state_times.values[np.newaxis, :] + - forcing_times.values[:, np.newaxis] + ) + idx_min = np.abs(time_deltas).argmin(axis=0) + + time_diff_steps = xr.DataArray( + np.stack( + [ + np.diagonal(time_deltas, offset=offset)[ + -len(state_times) + init_steps : + ] + / state_time_step + for offset in range( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ) + ], + axis=1, + ), + dims=["time", "window"], + coords={ + "time": state_times.isel(time=slice(init_steps, None)), + "window": np.arange( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ), + }, + name="time_diff_steps", ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel(time=idx_min) - da_forcing_matched = da_forcing_matched.assign_coords( - time=state_times + # Create window dimension using rolling + window_size = ( + self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) - - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step + da_forcing_windowed = da_forcing.rolling( + time=window_size, center=True + ).construct(window_dim="window") + da_forcing_matched = da_forcing_windowed.isel( + time=idx_min[init_steps:] ) # Add time difference as a new coordinate da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + time_diff=time_diff_steps ) return da_state_sliced, da_forcing_matched + def _process_windowed_data(self, da_windowed, da_state, da_target_times): + """Helper function to process windowed data after standardization.""" + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + # Stack the 'feature' and 'window' dimensions + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + return xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + def _build_item_dataarrays(self, idx): """ Create the dataarrays for the initial states, target states and forcing @@ -459,18 +494,21 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # handle time sampling in a way that is compatible with both analysis - # and forecast data - da_state = self._slice_time( - da_state=da_state, idx=idx, n_steps=self.ar_steps + # if da_forcing is None, the function will return None for + # da_forcing_windowed + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, ) - if da_forcing is not None: - da_forcing_windowed = self._slice_time( - da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps - ) + if da_boundary is not None: - da_boundary_windowed = self._slice_time( - da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_boundary, ) # load the data into memory @@ -506,83 +544,12 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std - if da_forcing is not None: - # Expand 'time_diff' to align with 'forcing_feature' and 'window' - # dimensions 'time_diff' has dimension ('time'), expand to ('time', - # 'forcing_feature', 'window') - time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( - forcing_feature=da_forcing_windowed["forcing_feature"], - window=da_forcing_windowed["window"], - ) - - # Stack 'forcing_feature' and 'window' into a single - # 'forcing_feature_windowed' dimension - da_forcing_windowed = da_forcing_windowed.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' - da_forcing_windowed = da_forcing_windowed.assign_coords( - time_diff=( - "forcing_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty forcing tensor with the right shape - da_forcing_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "forcing_feature": [], - }, - ) - - if da_boundary is not None: - # If 'da_boundary_windowed' also has 'time_diff', process similarly - # Expand 'time_diff' to align with 'boundary_feature' and 'window' - # dimensions - time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( - boundary_feature=da_boundary_windowed["boundary_feature"], - window=da_boundary_windowed["window"], - ) - - # Stack 'boundary_feature' and 'window' into a single - # 'boundary_feature_windowed' dimension - da_boundary_windowed = da_boundary_windowed.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' - da_boundary_windowed = da_boundary_windowed.assign_coords( - time_diff=( - "boundary_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty boundary tensor with the right shape - da_boundary_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "boundary_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "boundary_feature": [], - }, - ) + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, da_state, da_target_times + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, da_state, da_target_times + ) return ( da_init_states, From 79199956225277cb88b255a514be1a72634926c5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 07:09:52 +0100 Subject: [PATCH 012/103] cleanup analysis based time-slicing --- neural_lam/weather_dataset.py | 85 +++++++++++++++++------------------ 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c8806d1c..bbfb5705 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -245,11 +245,12 @@ def __len__(self): def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done as before - based on `idx`. For the forcing data, nearest neighbor matching is - performed based on the state times. Additionally, the time difference - between the matched forcing times and state times (in multiples of state - time steps) is added to the forcing dataarray. + `da_forcing` (forcing). For the state data, slicing is done based on + `idx`. For the forcing data, nearest neighbor matching is performed + based on the state times. Additionally, the time difference between the + matched forcing times and state times (in multiples of state time steps) + is added to the forcing dataarray. This will be used as an additional + feature in the model (temporal embedding). Parameters ---------- @@ -269,9 +270,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). da_forcing_matched : xr.DataArray - The forcing dataarray matched to state times with an added - coordinate 'time_diff', representing the time difference to state - times in multiples of state time steps. + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -308,9 +308,9 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): if da_forcing is None: return da_state_sliced, None - # Get the state times for matching + # Get the state times and its temporal resolution for matching with + # forcing data state_times = da_state_sliced["time"] - # Calculate time differences in multiples of state time steps state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor @@ -369,39 +369,29 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] - # Compute time differences + # Compute time differences between forcing and state times + # (in multiples of state time steps) + # Retrieve the indices of the closest times in the forcing data time_deltas = ( - state_times.values[np.newaxis, :] - - forcing_times.values[:, np.newaxis] - ) + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) - time_diff_steps = xr.DataArray( - np.stack( - [ - np.diagonal(time_deltas, offset=offset)[ - -len(state_times) + init_steps : - ] - / state_time_step - for offset in range( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ) - ], - axis=1, - ), - dims=["time", "window"], - coords={ - "time": state_times.isel(time=slice(init_steps, None)), - "window": np.arange( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ), - }, - name="time_diff_steps", + time_diff_steps = np.stack( + [ + time_deltas[ + idx_i + - self.num_past_forcing_steps : idx_i + + self.num_future_forcing_steps + + 1, + init_steps + step_i, + ] + for (step_i, idx_i) in enumerate(idx_min[init_steps:]) + ], ) - # Create window dimension using rolling + # Create window dimension for forcing data to stack later window_size = ( self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) @@ -412,9 +402,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): time=idx_min[init_steps:] ) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=time_diff_steps + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, ) return da_state_sliced, da_forcing_matched @@ -423,13 +415,19 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" stacked_dim = "forcing_feature_windowed" if da_windowed is not None: - # Stack the 'feature' and 'window' dimensions + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) + da_windowed = xr.concat( + [da_windowed, da_windowed.time_diff_steps], + dim="forcing_feature_windowed", + ) else: # Create empty DataArray with the correct dimensions and coordinates - return xr.DataArray( + da_windowed = xr.DataArray( data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), dims=("time", "grid_index", f"{stacked_dim}"), coords={ @@ -438,6 +436,7 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): f"{stacked_dim}": [], }, ) + return da_windowed def _build_item_dataarrays(self, idx): """ From 9bafceec0480ead53e4cdd32b24be669c195316c Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:42 +0100 Subject: [PATCH 013/103] implement datastore_boundary in existing tests --- tests/test_datasets.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece0..67eac70e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,12 +14,19 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_shapes(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): """Check that the `datastore.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different @@ -31,6 +38,9 @@ def test_dataset_item_shapes(datastore_name): """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_gridpoints = datastore.num_grid_points N_pred_steps = 4 @@ -38,6 +48,7 @@ def test_dataset_item_shapes(datastore_name): num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -48,7 +59,7 @@ def test_dataset_item_shapes(datastore_name): # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - init_states, target_states, forcing, target_times = item + init_states, target_states, forcing, boundary, target_times = item # initial states assert init_states.ndim == 3 @@ -81,14 +92,23 @@ def test_dataset_item_shapes(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_create_dataarray_from_tensor(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_create_dataarray_from_tensor( + datastore_name, datastore_boundary_name +): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -158,13 +178,19 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_single_batch(datastore_name, split): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): """Check that the `datastore.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) device_name = ( torch.device("cuda") if torch.cuda.is_available() else "cpu" @@ -210,7 +236,9 @@ def _create_graph(): ) ) - dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) model = GraphLAM(args=args, datastore=datastore, config=config) # noqa From ce06bbc24dc4765944c0b937ace0dc4d0f11f364 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:39:27 +0100 Subject: [PATCH 014/103] allow for grid shape retrieval from forcing data --- neural_lam/datastore/mdp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 67aaa9d0..57a3249f 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -377,8 +377,17 @@ def grid_shape_state(self): The shape of the cartesian grid for the state variables. """ - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) + da_x, da_y = ds_forcing.x, ds_forcing.y + else: + ds_state = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) From 884b5c623117cb18c405ac869caaff028625e5fb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:40:47 +0100 Subject: [PATCH 015/103] rearrange time slicing, boundary first --- neural_lam/weather_dataset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index bbfb5705..32add37a 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -495,13 +495,6 @@ def _build_item_dataarrays(self, idx): # if da_forcing is None, the function will return None for # da_forcing_windowed - da_state, da_forcing_windowed = self._slice_time( - da_state=da_state, - idx=idx, - n_steps=self.ar_steps, - da_forcing=da_forcing, - ) - if da_boundary is not None: _, da_boundary_windowed = self._slice_time( da_state=da_state, @@ -509,6 +502,12 @@ def _build_item_dataarrays(self, idx): n_steps=self.ar_steps, da_forcing=da_boundary, ) + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, + ) # load the data into memory da_state.load() From 5904cbe9da67d3e98eaab0cebd501a2ad0ded7f3 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 25 Nov 2024 16:42:21 +0100 Subject: [PATCH 016/103] identified issue, cleanup next --- neural_lam/datastore/base.py | 9 ++++- neural_lam/datastore/mdp.py | 5 ++- neural_lam/models/ar_model.py | 46 ++++++++++++++++++++-- neural_lam/train_model.py | 2 +- neural_lam/vis.py | 73 +++++++++++++++++++++++++---------- 5 files changed, 107 insertions(+), 28 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e5..b0055e39 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]: The extent of the x, y coordinates. """ - xy = self.get_xy(category, stacked=False) - extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] + xy = self.get_xy(category, stacked=True) + extent = [ + xy[:, 0].min(), + xy[:, 0].max(), + xy[:, 1].min(), + xy[:, 1].max(), + ] return [float(v) for v in extent] @property diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a82..0d1aac7b 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -1,4 +1,5 @@ # Standard library +import copy import warnings from functools import cached_property from pathlib import Path @@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection: class_name = projection_info["class_name"] ProjectionClass = getattr(ccrs, class_name) - kwargs = projection_info["kwargs"] + # need to copy otherwise we modify the dict stored in the dataclass + # in-place + kwargs = copy.deepcopy(projection_info["kwargs"]) globe_kwargs = kwargs.pop("globe", {}) if len(globe_kwargs) > 0: diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c6719..b55143f0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -7,12 +7,14 @@ import pytorch_lightning as pl import torch import wandb +from loguru import logger # Local from .. import metrics, vis from ..config import NeuralLAMConfig from ..datastore import BaseDatastore from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset class ARModel(pl.LightningModule): @@ -147,6 +149,14 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + def _create_dataarray_from_tensor(self, tensor, time, split, category): + weather_dataset = WeatherDataset(datastore=self._datastore, split=split) + time = np.array(time, dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -406,10 +416,13 @@ def test_step(self, batch, batch_idx): ) self.plot_examples( - batch, n_additional_examples, prediction=prediction + batch, + n_additional_examples, + prediction=prediction, + split="test", ) - def plot_examples(self, batch, n_examples, prediction=None): + def plot_examples(self, batch, n_examples, split, prediction=None): """ Plot the first n_examples forecasts from batch @@ -422,18 +435,34 @@ def plot_examples(self, batch, n_examples, prediction=None): prediction, target, _, _ = self.common_step(batch) target = batch[1] + time = batch[3] # Rescale to original data scale prediction_rescaled = prediction * self.state_std + self.state_mean target_rescaled = target * self.state_std + self.state_mean # Iterate over the examples - for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], target_rescaled[:n_examples] + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], ): # Each slice is (pred_steps, num_grid_nodes, d_f) self.plotted_examples += 1 # Increment already here + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -465,6 +494,10 @@ def plot_examples(self, batch, n_examples, prediction=None): title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i + ).squeeze(), + da_target=da_target.isel(state_feature=var_i).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( @@ -476,6 +509,11 @@ def plot_examples(self, batch, n_examples, prediction=None): ] example_i = self.plotted_examples + for i, fig in enumerate(var_figs): + fn = f"example_{i}_{example_i}_t{t_i}.png" + fig.savefig(fn) + logger.info(f"Saved example plot to {fn}") + wandb.log( { f"{var_name}_example_{example_i}": wandb.Image(fig) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..9d1d5039 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -23,7 +23,7 @@ } -@logger.catch +@logger.catch(reraise=True) def main(input_args=None): """Main function for training and evaluating models.""" parser = ArgumentParser( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b39..357a8977 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -68,6 +68,8 @@ def plot_prediction( pred, target, datastore: BaseRegularGridDatastore, + da_prediction=None, + da_target=None, title=None, vrange=None, ): @@ -88,10 +90,8 @@ def plot_prediction( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + mask_values = np.invert(da_mask.values.astype(bool)).astype(float) + pixel_alpha = mask_values.clip(0.7, 1) # Faded border region fig, axes = plt.subplots( 1, @@ -100,29 +100,62 @@ def plot_prediction( subplot_kw={"projection": datastore.coords_projection}, ) + use_xarray = True + # Plot pred and target - for ax, data in zip(axes, (target, pred)): + + if not use_xarray: + for ax, data in zip(axes, (target, pred)): + ax.coastlines() # Add coastline outlines + data_grid = ( + data.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() + .numpy() + ) + im = ax.imshow( + data_grid, + origin="lower", + extent=extent, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + cmap="plasma", + ) + + cbar = fig.colorbar(im, aspect=30) + cbar.ax.tick_params(labelsize=10) + + x = da_target.x.values + y = da_target.y.values + extent = [x.min(), x.max(), y.min(), y.max()] + for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() - .numpy() - ) - im = ax.imshow( - data_grid, + im = da.plot.imshow( + ax=ax, origin="lower", + x="x", extent=extent, - alpha=pixel_alpha, + alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", + transform=datastore.coords_projection, ) + # da.plot.pcolormesh( + # ax=ax, + # x="x", + # vmin=vmin, + # vmax=vmax, + # transform=datastore.coords_projection, + # cmap="plasma", + # ) + # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) if title: fig.suptitle(title, size=20) @@ -150,9 +183,7 @@ def plot_spatial_error( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region fig, ax = plt.subplots( figsize=(5, 4.8), @@ -161,8 +192,10 @@ def plot_spatial_error( ax.coastlines() # Add coastline outlines error_grid = ( - error.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() .numpy() ) From efe03027842a22139d6554d68ffee7b6ebe0ad73 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 13:46:05 +0100 Subject: [PATCH 017/103] use xarray plot only --- neural_lam/models/ar_model.py | 47 +++++++++++++++++++++++++++-------- neural_lam/vis.py | 43 +++----------------------------- 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index b55143f0..0af25367 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,5 +1,6 @@ # Standard library import os +from typing import List, Union # Third-party import matplotlib.pyplot as plt @@ -7,7 +8,7 @@ import pytorch_lightning as pl import torch import wandb -from loguru import logger +import xarray as xr # Local from .. import metrics, vis @@ -149,7 +150,35 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] - def _create_dataarray_from_tensor(self, tensor, time, split, category): + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[int, List[int]], + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature] + time : Union[int,List[int]] + The time index or indices for the data, given as integers or a list + of integers representing epoch time in nanoseconds. + split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) time = np.array(time, dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( @@ -482,14 +511,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None): var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate( - zip(pred_slice, target_slice), start=1 - ): + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): # Create one figure per variable at this time step var_figs = [ vis.plot_prediction( - pred=pred_t[:, var_i], - target=target_t[:, var_i], datastore=self._datastore, title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", @@ -509,10 +534,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ] example_i = self.plotted_examples - for i, fig in enumerate(var_figs): - fn = f"example_{i}_{example_i}_t{t_i}.png" - fig.savefig(fn) - logger.info(f"Saved example plot to {fn}") + # for i, fig in enumerate(var_figs): + # fn = f"example_{i}_{example_i}_t{t_i}.png" + # fig.savefig(fn) + # logger.info(f"Saved example plot to {fn}") wandb.log( { diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 357a8977..47c68e4f 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -65,8 +65,6 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, - target, datastore: BaseRegularGridDatastore, da_prediction=None, da_target=None, @@ -81,8 +79,8 @@ def plot_prediction( """ # Get common scale for values if vrange is None: - vmin = min(vals.min().cpu().item() for vals in (pred, target)) - vmax = max(vals.max().cpu().item() for vals in (pred, target)) + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) else: vmin, vmax = vrange @@ -100,39 +98,13 @@ def plot_prediction( subplot_kw={"projection": datastore.coords_projection}, ) - use_xarray = True - # Plot pred and target - - if not use_xarray: - for ax, data in zip(axes, (target, pred)): - ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape( - [datastore.grid_shape_state.x, datastore.grid_shape_state.y] - ) - .T.cpu() - .numpy() - ) - im = ax.imshow( - data_grid, - origin="lower", - extent=extent, - alpha=pixel_alpha, - vmin=vmin, - vmax=vmax, - cmap="plasma", - ) - - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) - x = da_target.x.values y = da_target.y.values extent = [x.min(), x.max(), y.min(), y.max()] for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - im = da.plot.imshow( + da.plot.imshow( ax=ax, origin="lower", x="x", @@ -144,15 +116,6 @@ def plot_prediction( transform=datastore.coords_projection, ) - # da.plot.pcolormesh( - # ax=ax, - # x="x", - # vmin=vmin, - # vmax=vmax, - # transform=datastore.coords_projection, - # cmap="plasma", - # ) - # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) From a489c2ed974397ea230d2e61b842d8d9384867dc Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 14:07:06 +0100 Subject: [PATCH 018/103] don't reraise --- neural_lam/train_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 9d1d5039..74146c89 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -23,7 +23,7 @@ } -@logger.catch(reraise=True) +@logger.catch def main(input_args=None): """Main function for training and evaluating models.""" parser = ArgumentParser( From 242d08bcb5374cdd90aecfd49f501ed233f1ce0c Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 14:50:03 +0100 Subject: [PATCH 019/103] remove debug plot --- neural_lam/models/ar_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0af25367..c875688b 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -534,10 +534,6 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ] example_i = self.plotted_examples - # for i, fig in enumerate(var_figs): - # fn = f"example_{i}_{example_i}_t{t_i}.png" - # fig.savefig(fn) - # logger.info(f"Saved example plot to {fn}") wandb.log( { From c1f706c29542d770ed49e910f8b9bd5caff1fdec Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 16:04:24 +0100 Subject: [PATCH 020/103] remove extent calc used in diagnosing issue --- neural_lam/vis.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 47c68e4f..c814aacf 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -99,9 +99,6 @@ def plot_prediction( ) # Plot pred and target - x = da_target.x.values - y = da_target.y.values - extent = [x.min(), x.max(), y.min(), y.max()] for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines da.plot.imshow( From cf8e3e4c1be93a6ec074368aaf6f91c8042b5278 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 14:51:36 +0100 Subject: [PATCH 021/103] add type annotation --- neural_lam/vis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index c814aacf..d6b57f88 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import xarray as xr # Local from . import utils @@ -66,8 +67,8 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( datastore: BaseRegularGridDatastore, - da_prediction=None, - da_target=None, + da_prediction: xr.DataArray = None, + da_target: xr.DataArray = None, title=None, vrange=None, ): From 85160cecf13ecfc9fc6a589ac1a9e3542da45e23 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 15:03:06 +0100 Subject: [PATCH 022/103] ensure tensor copy to cpu mem before data-array creation --- neural_lam/models/ar_model.py | 10 ++++++---- neural_lam/weather_dataset.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index c875688b..0d8e6e3c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -167,10 +167,12 @@ def _create_dataarray_from_tensor( ---------- tensor : torch.Tensor The tensor to convert to a `xr.DataArray` with dimensions [time, - grid_index, feature] + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. time : Union[int,List[int]] The time index or indices for the data, given as integers or a list - of integers representing epoch time in nanoseconds. + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. split : str The split of the data, either 'train', 'val', or 'test' category : str @@ -180,9 +182,9 @@ def _create_dataarray_from_tensor( # not how this should be done but whether WeatherDataset should be # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) - time = np.array(time, dtype="datetime64[ns]") + time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor, time=time, category=category + tensor=tensor.cpu().numpy(), time=time, category=category ) return da diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..b5f85580 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -529,7 +529,8 @@ def create_dataarray_from_tensor( tensor : torch.Tensor The tensor to construct the DataArray from, this assumed to have the same dimension ordering as returned by the __getitem__ method - (i.e. time, grid_index, {category}_feature). + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. time : datetime.datetime or list[datetime.datetime] The time or times of the tensor. category : str @@ -581,7 +582,7 @@ def _is_listlike(obj): coords["time"] = time da = xr.DataArray( - tensor.numpy(), + tensor.cpu().numpy(), dims=dims, coords=coords, ) From 52c452879f56c7f982cfd5d55a5259f37cb6b030 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 15:05:36 +0100 Subject: [PATCH 023/103] apply time-indexing to support ar_steps_val > 1 --- neural_lam/models/ar_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0d8e6e3c..44baf9c2 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -522,9 +522,11 @@ def plot_examples(self, batch, n_examples, split, prediction=None): f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, da_prediction=da_prediction.isel( - state_feature=var_i + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 ).squeeze(), - da_target=da_target.isel(state_feature=var_i).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( From b96d8ebc0c5c22f980e22384efafcd08db20577f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:42:05 +0100 Subject: [PATCH 024/103] renaming test datastores --- tests/datastore_examples/.gitignore | 3 +- .../.gitignore | 0 .../era5_1000hPa_danra_100m_winds/config.yaml | 12 +++ .../danra.datastore.yaml | 99 +++++++++++++++++++ .../era5.datastore.yaml | 23 ++--- .../mdp/era5_1000hPa_winds/config.yaml | 3 - 6 files changed, 122 insertions(+), 18 deletions(-) rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/.gitignore (100%) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/era5.datastore.yaml (80%) delete mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index e84e6493..4fbd2326 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,3 @@ npyfilesmeps/*.zip -npyfilesmeps/meps_example_reduced/ +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore similarity index 100% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 00000000..a158bee3 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..3edf1267 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml similarity index 80% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index 36b39501..c97da4bc 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -1,5 +1,4 @@ -#TODO: What do these versions mean? Should they be updated? -schema_version: v0.2.0+dev +schema_version: v0.5.0 dataset_version: v1.0.0 output: @@ -7,8 +6,8 @@ output: forcing: [time, grid_index, forcing_feature] coord_ranges: time: - start: 1990-09-02T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 step: PT6H chunking: time: 1 @@ -16,17 +15,17 @@ output: dim: time splits: train: - start: 1990-09-02T00:00 - end: 1990-09-07T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 compute_statistics: ops: [mean, std, diff_mean, diff_std] dims: [grid_index, time] val: - start: 1990-09-05T00:00 - end: 1990-09-08T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 test: - start: 1990-09-06T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: @@ -37,10 +36,6 @@ inputs: level: values: [1000,] units: hPa - v_component_of_wind: - level: - values: [1000, ] - units: hPa dim_mapping: time: method: rename diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml deleted file mode 100644 index 5d1e05f2..00000000 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml +++ /dev/null @@ -1,3 +0,0 @@ -datastore: - kind: mdp - config_path: era5.datastore.yaml From 72da25fd15d46a4497728935e9767c34330f1ccc Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:15 +0100 Subject: [PATCH 025/103] adding num_past/future_boundary_step args --- neural_lam/train_model.py | 37 +++++++++++++++------------------ tests/test_datasets.py | 43 +++++++++++++++++++++++++++++++++------ tests/test_training.py | 24 ++++++++++++++++++++-- 3 files changed, 75 insertions(+), 29 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 37bf6db7..2a61e86c 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,11 +34,6 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) - parser.add_argument( - "--config_path_boundary", - type=str, - help="Path to the configuration for boundary conditions", - ) parser.add_argument( "--model", type=str, @@ -208,6 +203,18 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -217,9 +224,6 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" - assert ( - args.config_path_boundary is not None - ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -234,21 +238,10 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore = load_config_and_datastore(config_path=args.config_path) - config_boundary, datastore_boundary = load_config_and_datastore( - config_path=args.config_path_boundary + config, datastore, datastore_boundary = load_config_and_datastore( + config_path=args.config_path ) - # TODO this should not be required, make more flexible - assert ( - datastore.num_past_forcing_steps - == datastore_boundary.num_past_forcing_steps - ), "Mismatch in num_past_forcing_steps" - assert ( - datastore.num_future_forcing_steps - == datastore_boundary.num_future_forcing_steps - ), "Mismatch in num_future_forcing_steps" - # Create datamodule data_module = WeatherDataModule( datastore=datastore, @@ -258,6 +251,8 @@ def main(input_args=None): standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, batch_size=args.batch_size, num_workers=args.num_workers, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 67eac70e..5fbe4a5d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -42,10 +42,13 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): datastore_boundary_name ) N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -53,6 +56,8 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) item = dataset[0] @@ -77,8 +82,23 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * ( - num_past_forcing_steps + num_future_forcing_steps + 1 + # each stacked forcing feature has one corresponding temporal embedding + assert ( + forcing.shape[2] + == datastore.get_num_data_vars("forcing") + * (num_past_forcing_steps + num_future_forcing_steps + 1) + * 2 + ) + + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert ( + boundary.shape[2] + == datastore_boundary.get_num_data_vars("forcing") + * (num_past_boundary_steps + num_future_boundary_steps + 1) + * 2 ) # batch times @@ -88,6 +108,7 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length + dataset[len(dataset) - 1] @@ -106,6 +127,9 @@ def test_dataset_item_create_dataarray_from_tensor( N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -113,16 +137,22 @@ def test_dataset_item_create_dataarray_from_tensor( ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - _, target_states, _, target_times_arr = dataset[idx] - _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays( - idx=idx - ) + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) target_times = np.array(target_times_arr, dtype="datetime64[ns]") np.testing.assert_equal(target_times, da_target_times_true.values) @@ -272,6 +302,7 @@ def test_dataset_length(dataset_config): dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=dataset_config["ar_steps"], num_past_forcing_steps=dataset_config["past"], diff --git a/tests/test_training.py b/tests/test_training.py index 1ed1847d..28566a4b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,18 +14,33 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataModule -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_training(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( f"Skipping test for {datastore_name} as it is not a regular " "grid datastore." ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a regular " + "grid datastore." + ) if torch.cuda.is_available(): device_name = "cuda" @@ -59,6 +74,7 @@ def test_training(datastore_name): data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=3, ar_steps_eval=5, standardize=True, @@ -66,6 +82,8 @@ def test_training(datastore_name): num_workers=1, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, ) class ModelArgs: @@ -85,6 +103,8 @@ class ModelArgs: metrics_watch = [] num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 model_args = ModelArgs() From 244f1ccb77e9d12852e3a59feddff5034f54ef95 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:51 +0100 Subject: [PATCH 026/103] using combined config file --- neural_lam/config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..914ebb38 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -168,4 +168,15 @@ def load_config_and_datastore( datastore_kind=config.datastore.kind, config_path=datastore_config_path ) - return config, datastore + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary From a9cc36e23de294f21fce15f903a4ba7d0a8496a6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:12 +0100 Subject: [PATCH 027/103] proper handling of state/forcing/boundary in dataset --- neural_lam/weather_dataset.py | 304 +++++++++++++++++++--------------- 1 file changed, 167 insertions(+), 137 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 32add37a..b717c40a 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -38,6 +38,16 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time t + Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before + t, given num_past_forcing_steps) are included as boundary inputs at time + t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -50,6 +60,8 @@ def __init__( ar_steps=3, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, standardize=True, ): super().__init__() @@ -60,10 +72,10 @@ def __init__( self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray( - category="state", split=self.split - ) + self.da_state = self.datastore.get_dataarray(category="state", split=self.split) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -74,9 +86,12 @@ def __init__( category="forcing", split=self.split ) # XXX For now boundary data is always considered mdp-forcing data - self.da_boundary = self.datastore_boundary.get_dataarray( - category="forcing", split=self.split - ) + if self.datastore_boundary is not None: + self.da_boundary = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -97,9 +112,7 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order( - category=part - ) + expected_dim_order = self.datastore.expected_dim_order(category=part) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -108,6 +121,23 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + else: + self.da_state = self.da_state + def get_time_step(times): """Calculate the time step from the data""" time_diffs = np.diff(times) @@ -119,11 +149,18 @@ def get_time_step(times): return time_diffs[0] # Check time step consistency in state data - _ = get_time_step(self.da_state.time.values) + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + _ = get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: - state_times = self.da_state.time + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values @@ -131,26 +168,30 @@ def get_time_step(times): # Forcing data is part of the same datastore as state data # During creation the time dimension of the forcing data # is matched to the state data - forcing_times = self.da_forcing.time - _ = get_time_step(forcing_times.values) + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + else: + forcing_times = self.da_forcing.time + get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore # The boundary data is allowed to have a different time_step # Check that the boundary data covers the required time range - boundary_times = self.da_boundary.time + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary.analysis_time + else: + boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * boundary_time_step + state_time_min - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * boundary_time_step + state_time_max + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -179,10 +220,8 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = ( - self.datastore.get_standardization_dataarray( - category="forcing" - ) + self.ds_forcing_stats = self.datastore.get_standardization_dataarray( + category="forcing" ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -208,7 +247,7 @@ def __len__(self): warnings.warn( "only using first ensemble member, so dataset size is " " effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", + f"({self.datastore._num_ensemble_members})", UserWarning, ) @@ -242,36 +281,50 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing_boundary=None, + num_past_steps=None, + num_future_steps=None, + ): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done based on - `idx`. For the forcing data, nearest neighbor matching is performed - based on the state times. Additionally, the time difference between the - matched forcing times and state times (in multiples of state time steps) - is added to the forcing dataarray. This will be used as an additional - feature in the model (temporal embedding). + `da_forcing_boundary`. For the state data, slicing is done + based on `idx`. For the forcing/boundary data, nearest neighbor matching + is performed based on the state times. Additionally, the time difference + between the matched forcing/boundary times and state times (in multiples + of state time steps) is added to the forcing dataarray. This will be + used as an additional feature in the model (temporal embedding). Parameters ---------- da_state : xr.DataArray The state dataarray to slice. - da_forcing : xr.DataArray - The forcing dataarray to slice. idx : int The index of the time step to start the sample from in the state data. n_steps : int The number of time steps to include in the sample. + da_forcing_boundary : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. Returns ------- da_state_sliced : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). - da_forcing_matched : xr.DataArray + da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', - 'forcing_feature_windowed'). + 'forcing/boundary_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -279,8 +332,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # Slice the state data as before if self.datastore.is_forecast: # Calculate start and end indices for slicing - start_idx = max(0, self.num_past_forcing_steps - init_steps) - end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + start_idx = max(0, num_past_steps - init_steps) + end_idx = max(init_steps, num_past_steps) + n_steps # Slice the state data over the elapsed forecast duration da_state_sliced = da_state.isel( @@ -299,13 +352,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): else: # For analysis data, slice the time dimension directly - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = ( - idx + max(init_steps, self.num_past_forcing_steps) + n_steps - ) + start_idx = idx + max(0, num_past_steps - init_steps) + end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - if da_forcing is None: + if da_forcing_boundary is None: return da_state_sliced, None # Get the state times and its temporal resolution for matching with @@ -313,78 +364,66 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] - # Match forcing data to state times based on nearest neighbor - if self.datastore.is_forecast: - # Calculate all possible forcing times - forcing_times = ( - da_forcing.analysis_time + da_forcing.elapsed_forecast_duration - ) - forcing_times_flat = forcing_times.stack( - forecast_time=("analysis_time", "elapsed_forecast_duration") - ) + if "analysis_time" in da_forcing_boundary.dims: + idx = np.abs( + da_forcing_boundary.analysis_time.values + - self.da_state.analysis_time.values[idx] + ).argmin() + # Add a 'time' dimension using the actual forecast times + offset = max(init_steps, num_past_steps) + da_list = [] + for step in range(n_steps): + start_idx = offset + step - num_past_steps + end_idx = offset + step + num_future_steps + + current_time = ( + da_forcing_boundary.analysis_time[idx] + + da_forcing_boundary.elapsed_forecast_duration[offset + step] + ) - # Compute time differences - time_deltas = ( - forcing_times_flat.values[:, np.newaxis] - - state_times.values[np.newaxis, :] - ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - - # Retrieve corresponding indices for analysis_time and - # elapsed_forecast_duration - forecast_time_index = forcing_times_flat["forecast_time"][idx_min] - analysis_time_indices = forecast_time_index["analysis_time"] - elapsed_forecast_duration_indices = forecast_time_index[ - "elapsed_forecast_duration" - ] - - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel( - analysis_time=("time", analysis_time_indices), - elapsed_forecast_duration=( - "time", - elapsed_forecast_duration_indices, - ), - ) + da_sliced = da_forcing_boundary.isel( + analysis_time=idx, + elapsed_forecast_duration=slice(start_idx, end_idx + 1), + ) - # Assign matched state times to the forcing data - da_forcing_matched["time"] = state_times - da_forcing_matched = da_forcing_matched.swap_dims( - {"elapsed_forecast_duration": "time"} - ) + da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step - ) + da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + + da_list.append(da_sliced) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + # Concatenate the list of DataArrays along the 'time' dimension + da_forcing_boundary_matched = xr.concat(da_list, dim="time") + forcing_time_step = ( + da_forcing_boundary_matched.time.values[1] + - da_forcing_boundary_matched.time.values[0] ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( + forcing_time_step / state_time_step + ) + time_diff_steps = da_forcing_boundary_matched.isel( + grid_index=0, forcing_feature=0 + ).data + else: # For analysis data, match directly using the 'time' coordinate - forcing_times = da_forcing["time"] + forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) time_diff_steps = np.stack( [ time_deltas[ - idx_i - - self.num_past_forcing_steps : idx_i - + self.num_future_forcing_steps - + 1, + idx_i - num_past_steps : idx_i + num_future_steps + 1, init_steps + step_i, ] for (step_i, idx_i) in enumerate(idx_min[init_steps:]) @@ -392,24 +431,22 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): ) # Create window dimension for forcing data to stack later - window_size = ( - self.num_past_forcing_steps + self.num_future_forcing_steps + 1 - ) - da_forcing_windowed = da_forcing.rolling( - time=window_size, center=True + window_size = num_past_steps + num_future_steps + 1 + da_forcing_boundary_windowed = da_forcing_boundary.rolling( + time=window_size, center=False ).construct(window_dim="window") - da_forcing_matched = da_forcing_windowed.isel( + da_forcing_boundary_matched = da_forcing_boundary_windowed.isel( time=idx_min[init_steps:] ) - # Add time difference as a new coordinate to concatenate to the - # forcing features later - da_forcing_matched["time_diff_steps"] = ( - ("time", "window"), - time_diff_steps, - ) + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_boundary_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, + ) - return da_state_sliced, da_forcing_matched + return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" @@ -462,23 +499,7 @@ def _build_item_dataarrays(self, idx): da_target_times : xr.DataArray The dataarray for the target times. """ - # handling ensemble data - if self.datastore.is_ensemble: - # for the now the strategy is to only include the first ensemble - # member - # XXX: this could be changed to include all ensemble members by - # splitting `idx` into two parts, one for the analysis time and one - # for the ensemble member and then increasing self.__len__ to - # include all ensemble members - warnings.warn( - "only use of ensemble member 0 (the first member) is " - "implemented for ensemble data" - ) - i_ensemble = 0 - da_state = self.da_state.isel(ensemble_member=i_ensemble) - else: - da_state = self.da_state - + da_state = self.da_state if self.da_forcing is not None: if "ensemble_member" in self.da_forcing.dims: raise NotImplementedError( @@ -500,13 +521,19 @@ def _build_item_dataarrays(self, idx): da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_boundary, + da_forcing_boundary=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, ) + else: + da_boundary_windowed = None da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_forcing, + da_forcing_boundary=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, ) # load the data into memory @@ -521,9 +548,7 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = ( - da_init_states - self.da_state_mean - ) / self.da_state_std + da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -595,9 +620,7 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor( - da_target_states.values, dtype=tensor_dtype - ) + target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -707,10 +730,7 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if ( - grid_coord in da_datastore_state.coords - and grid_coord not in da.coords - ): + if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: @@ -731,6 +751,8 @@ def __init__( standardize=True, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, batch_size=4, num_workers=16, ): @@ -739,6 +761,8 @@ def __init__( self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize @@ -765,6 +789,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) self.val_dataset = WeatherDataset( datastore=self._datastore, @@ -774,6 +800,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) if stage == "test" or stage is None: @@ -785,6 +813,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) def train_dataloader(self): From dcc0b46861ff1263c688301eca265bd62803616f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:35 +0100 Subject: [PATCH 028/103] datastore_boundars=None introduced --- .../datastore/npyfilesmeps/compute_standardization_stats.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index f2c80e8a..4207812f 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -172,6 +172,7 @@ def main( ar_steps = 63 ds = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=False, From a3b3bde9ed1b044b32afde7e4b12bc8e4a1593e6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:02 +0100 Subject: [PATCH 029/103] bug fix for file retrieval per member --- neural_lam/datastore/npyfilesmeps/store.py | 51 +++++++++------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 146b0627..7ee583be 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,9 +244,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray( - features=[feature], split=split - ) + self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features ] da = xr.concat(das, dim="feature") @@ -259,9 +257,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = ( - da.analysis_time + da.elapsed_forecast_duration - ).chunk({"elapsed_forecast_duration": 1}) + da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( + {"elapsed_forecast_duration": 1} + ) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -339,10 +337,7 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if ( - set(features).difference(self.get_vars_names(category="static")) - == set() - ): + if set(features).difference(self.get_vars_names(category="static")) == set(): assert split in ( "train", "val", @@ -356,12 +351,8 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names( - category="state" - ): - raise ValueError( - "Member can only be specified for the 'state' category" - ) + if member is not None and features != self.get_vars_names(category="state"): + raise ValueError("Member can only be specified for the 'state' category") concat_axis = 0 @@ -377,9 +368,7 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones( - len(features) + n_to_drop, dtype=bool - ) + feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -445,7 +434,7 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split) + coord_values = self._get_analysis_times(split=split, member_id=member) elif d == "y": coord_values = y elif d == "x": @@ -464,9 +453,7 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format( - analysis_time=analysis_time, **file_params - ) + / filename_format.format(analysis_time=analysis_time, **file_params) for analysis_time in coords["analysis_time"] ] else: @@ -505,7 +492,7 @@ def _get_single_timeseries_dataarray( return da - def _get_analysis_times(self, split) -> List[np.datetime64]: + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: """Get the analysis times for the given split by parsing the filenames of all the files found for the given split. @@ -513,6 +500,8 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ---------- split : str The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. Returns ------- @@ -520,8 +509,12 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: The analysis times for the given split. """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) - pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) sample_dir = self.root_path / "samples" / split sample_files = sample_dir.glob(pattern) @@ -531,9 +524,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError( - f"No files found in {sample_dir} with pattern {pattern}" - ) + raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") return times @@ -690,9 +681,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load( - self.root_path / "static" / fn, weights_only=True - ).numpy() + return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() mean_diff_values = None std_diff_values = None From 3ffc413e2f669dafd4c745a50b9b723fff231316 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:17 +0100 Subject: [PATCH 030/103] rename datastore for tests --- tests/conftest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index be5cf3e7..90a86d0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,14 +94,14 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) -DATASTORES_BOUNDARY_EXAMPLES = dict( - mdp=( +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( DATASTORE_EXAMPLES_ROOT_PATH / "mdp" - / "era5_1000hPa_winds" + / "era5_1000hPa_danra_100m_winds" / "era5.datastore.yaml" - ) -) + ), +} DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore From 85aad66c8e9eec4e0b4e95cabb753d8492a0c49a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:31 +0100 Subject: [PATCH 031/103] aligned time with danra for easier boundary testing --- tests/dummy_datastore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index d62c7356..a958b8f5 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -28,7 +28,7 @@ class DummyDatastore(BaseRegularGridDatastore): """ SHORT_NAME = "dummydata" - T0 = isodate.parse_datetime("2021-01-01T00:00:00") + T0 = isodate.parse_datetime("1990-09-02T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) CARTESIAN_COORDS = ["x", "y"] From 64f057f78b713e39496abfc3962affa794666369 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:50 +0100 Subject: [PATCH 032/103] Fixed test for temporal embedding --- tests/test_time_slicing.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..2f5ed96c 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,9 +40,7 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) + da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -78,10 +76,8 @@ def get_vars_long_names(self, category): def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -93,6 +89,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,9 +98,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ - tensor.numpy() for tensor in sample - ] + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] expected_init_states = [0, 1] if ar_steps == 3: @@ -130,7 +125,7 @@ def test_time_slicing_analysis( # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) assert init_states.shape == (2, 1, 1) assert init_states[:, 0, 0].tolist() == expected_init_states @@ -141,6 +136,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( 3, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + # Factor 2 because each window step has a temporal embedding + (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, + ) + np.testing.assert_equal( + forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], + np.array(expected_forcing_values), ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values)) From 6205dbd88f1b208118d93da6d12c0a1be672caef Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 2 Dec 2024 10:26:54 +0100 Subject: [PATCH 033/103] pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f0bc0851..fdcb7f3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "torch>=2.3.0", "torch-geometric==2.3.1", "parse>=1.20.2", - "dataclass-wizard>=0.22.3", + "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 551cd267235a82378ab28f2b1a4db90523d87ea8 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:48 +0100 Subject: [PATCH 034/103] allow boundary as input to ar_model.common_step --- neural_lam/models/ar_model.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 4ab73cc7..4a08306d 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -107,7 +107,9 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - + num_forcing_vars + # Factor 2 because of temporal embedding or windowed features + + 2 + * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) @@ -200,19 +202,20 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states + Predict on single batch batch consists of: + init_states: (B, 2,num_grid_nodes, d_features) + target_states: (B, pred_steps,num_grid_nodes, d_features) + forcing_features: (B, pred_steps,num_grid_nodes, d_forcing) + boundary_features: (B, pred_steps,num_grid_nodes, d_boundaries) + batch_times: (B, pred_steps) """ - (init_states, target_states, forcing_features, batch_times) = batch + (init_states, target_states, forcing_features, _, batch_times) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) + ) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std, batch_times From fc95350a28cbdb81419962b203e0bb08e36520dd Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:56 +0100 Subject: [PATCH 035/103] linting --- neural_lam/datastore/npyfilesmeps/store.py | 43 ++++++++---- neural_lam/weather_dataset.py | 66 ++++++++++++------- .../era5.datastore.yaml | 2 +- tests/test_time_slicing.py | 12 +++- tests/test_training.py | 17 ++--- 5 files changed, 91 insertions(+), 49 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 7ee583be..24349e7e 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,7 +244,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray(features=[feature], split=split) + self._get_single_timeseries_dataarray( + features=[feature], split=split + ) for feature in features ] da = xr.concat(das, dim="feature") @@ -257,9 +259,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( - {"elapsed_forecast_duration": 1} - ) + da_forecast_time = ( + da.analysis_time + da.elapsed_forecast_duration + ).chunk({"elapsed_forecast_duration": 1}) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -337,7 +339,10 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if set(features).difference(self.get_vars_names(category="static")) == set(): + if ( + set(features).difference(self.get_vars_names(category="static")) + == set() + ): assert split in ( "train", "val", @@ -351,8 +356,12 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names(category="state"): - raise ValueError("Member can only be specified for the 'state' category") + if member is not None and features != self.get_vars_names( + category="state" + ): + raise ValueError( + "Member can only be specified for the 'state' category" + ) concat_axis = 0 @@ -368,7 +377,9 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) + feature_dim_mask = np.ones( + len(features) + n_to_drop, dtype=bool + ) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -434,7 +445,9 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split, member_id=member) + coord_values = self._get_analysis_times( + split=split, member_id=member + ) elif d == "y": coord_values = y elif d == "x": @@ -453,7 +466,9 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format(analysis_time=analysis_time, **file_params) + / filename_format.format( + analysis_time=analysis_time, **file_params + ) for analysis_time in coords["analysis_time"] ] else: @@ -524,7 +539,9 @@ def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") + raise ValueError( + f"No files found in {sample_dir} with pattern {pattern}" + ) return times @@ -681,7 +698,9 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() + return torch.load( + self.root_path / "static" / fn, weights_only=True + ).numpy() mean_diff_values = None std_diff_values = None diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b717c40a..b3d86292 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -41,13 +41,13 @@ class WeatherDataset(torch.utils.data.Dataset): num_past_boundary_steps: int, optional Number of past time steps to include in boundary input. If set to i, boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, - given num_future_forcing_steps) are included as boundary inputs at time t - Default is 1. + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. num_future_boundary_steps: int, optional Number of future time steps to include in boundary input. If set to j, - boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before - t, given num_past_forcing_steps) are included as boundary inputs at time - t. Default is 1. + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -75,7 +75,9 @@ def __init__( self.num_past_boundary_steps = num_past_boundary_steps self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray(category="state", split=self.split) + self.da_state = self.datastore.get_dataarray( + category="state", split=self.split + ) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -112,7 +114,9 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order(category=part) + expected_dim_order = self.datastore.expected_dim_order( + category=part + ) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -188,10 +192,12 @@ def get_time_step(times): # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - self.num_past_forcing_steps * boundary_time_step + state_time_min + - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max + self.num_future_forcing_steps * boundary_time_step + state_time_max + + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -220,8 +226,10 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = self.datastore.get_standardization_dataarray( - category="forcing" + self.ds_forcing_stats = ( + self.datastore.get_standardization_dataarray( + category="forcing" + ) ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -378,7 +386,9 @@ def _slice_time( current_time = ( da_forcing_boundary.analysis_time[idx] - + da_forcing_boundary.elapsed_forecast_duration[offset + step] + + da_forcing_boundary.elapsed_forecast_duration[ + offset + step + ] ) da_sliced = da_forcing_boundary.isel( @@ -386,12 +396,16 @@ def _slice_time( elapsed_forecast_duration=slice(start_idx, end_idx + 1), ) - da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.rename( + {"elapsed_forecast_duration": "window"} + ) da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) - da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + da_sliced = da_sliced.expand_dims( + dim={"time": [current_time.values]} + ) da_list.append(da_sliced) @@ -401,13 +415,13 @@ def _slice_time( da_forcing_boundary_matched.time.values[1] - da_forcing_boundary_matched.time.values[0] ) - da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( - forcing_time_step / state_time_step - ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ + "window" + ] * (forcing_time_step / state_time_step) time_diff_steps = da_forcing_boundary_matched.isel( grid_index=0, forcing_feature=0 ).data - + else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing_boundary["time"] @@ -416,7 +430,8 @@ def _slice_time( # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) @@ -548,7 +563,9 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std + da_init_states = ( + da_init_states - self.da_state_mean + ) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -620,7 +637,9 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) + target_states = torch.tensor( + da_target_states.values, dtype=tensor_dtype + ) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -730,7 +749,10 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: + if ( + grid_coord in da_datastore_state.coords + and grid_coord not in da.coords + ): da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index c97da4bc..7c5ffb3b 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -25,7 +25,7 @@ output: end: 2022-09-30T00:00 test: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 2f5ed96c..4a59c81e 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,7 +40,9 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -77,7 +79,9 @@ def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) + time_values = np.datetime64("2020-01-01") + np.arange( + len(ANALYSIS_STATE_VALUES) + ) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -98,7 +102,9 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] expected_init_states = [0, 1] if ar_steps == 3: diff --git a/tests/test_training.py b/tests/test_training.py index 28566a4b..7a1b4717 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,6 +5,7 @@ import pytest import pytorch_lightning as pl import torch + import wandb # First-party @@ -22,14 +23,10 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) +@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) + datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -38,15 +35,13 @@ def test_training(datastore_name, datastore_boundary_name): ) if not isinstance(datastore_boundary, BaseRegularGridDatastore): pytest.skip( - f"Skipping test for {datastore_boundary_name} as it is not a regular " - "grid datastore." + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." ) if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision( - "high" - ) # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s else: device_name = "cpu" From 01fa807bc5ce47270e3b4568db8df8ce3b436953 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:10:29 +0100 Subject: [PATCH 036/103] improved docstrings and added some assertions --- neural_lam/weather_dataset.py | 105 ++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b3d86292..991965d9 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -143,7 +143,13 @@ def __init__( self.da_state = self.da_state def get_time_step(times): - """Calculate the time step from the data""" + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ time_diffs = np.diff(times) if not np.all(time_diffs == time_diffs[0]): raise ValueError( @@ -234,6 +240,7 @@ def get_time_step(times): self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + # XXX: Again, the boundary data is considered forcing data for now if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( @@ -305,7 +312,7 @@ def _slice_time( is performed based on the state times. Additionally, the time difference between the matched forcing/boundary times and state times (in multiples of state time steps) is added to the forcing dataarray. This will be - used as an additional feature in the model (temporal embedding). + used as an additional input feature in the model (temporal embedding). Parameters ---------- @@ -333,23 +340,26 @@ def _slice_time( da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. """ - # Number of initial steps required (e.g., for initializing models) + # The current implementation requires at least 2 time steps for the + # initial state (see GraphCast). init_steps = 2 - - # Slice the state data as before + # slice the dataarray to include the required number of time steps if self.datastore.is_forecast: - # Calculate start and end indices for slicing - start_idx = max(0, num_past_steps - init_steps) - end_idx = max(init_steps, num_past_steps) + n_steps - - # Slice the state data over the elapsed forecast duration + start_idx = max(0, self.num_past_forcing_steps - init_steps) + end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + # this implies that the data will have both `analysis_time` and + # `elapsed_forecast_duration` dimensions for forecasts. We for now + # simply select a analysis time and the first `n_steps` forecast + # times (given no offset). Note that this means that we get one + # sample per forecast, always starting at forecast time 2. da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - - # Create a new 'time' dimension + # create a new time dimension so that the produced sample has a + # `time` dimension, similarly to the analysis only data da_state_sliced["time"] = ( da_state_sliced.analysis_time + da_state_sliced.elapsed_forecast_duration @@ -357,9 +367,13 @@ def _slice_time( da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + # Asserting that the forecast time step is consistent + self.get_time_step(da_state_sliced.time) else: - # For analysis data, slice the time dimension directly + # For analysis data we slice the time dimension directly. The offset + # is only relevant for the very first (and last) samples in the + # dataset. start_idx = idx + max(0, num_past_steps - init_steps) end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) @@ -372,7 +386,13 @@ def _slice_time( state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary if "analysis_time" in da_forcing_boundary.dims: + # Select the closest analysis time in the forcing/boundary data + # This is mostly relevant for boundary data where the time steps + # are not necessarily the same as the state data. But still fast + # enough for forcing data where the time steps are the same. idx = np.abs( da_forcing_boundary.analysis_time.values - self.da_state.analysis_time.values[idx] @@ -399,6 +419,8 @@ def _slice_time( da_sliced = da_sliced.rename( {"elapsed_forecast_duration": "window"} ) + + # Assign the 'window' coordinate to be relative positions da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) @@ -409,7 +431,10 @@ def _slice_time( da_list.append(da_sliced) - # Concatenate the list of DataArrays along the 'time' dimension + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time. da_forcing_boundary_matched = xr.concat(da_list, dim="time") forcing_time_step = ( da_forcing_boundary_matched.time.values[1] @@ -423,7 +448,9 @@ def _slice_time( ).data else: - # For analysis data, match directly using the 'time' coordinate + # For analysis data, we slice the time dimension directly. The + # offset is only relevant for the very first (and last) samples in + # the dataset. forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times @@ -455,7 +482,7 @@ def _slice_time( ) # Add time difference as a new coordinate to concatenate to the - # forcing features later + # forcing features later as temporal embedding da_forcing_boundary_matched["time_diff_steps"] = ( ("time", "window"), time_diff_steps, @@ -464,7 +491,26 @@ def _slice_time( return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): - """Helper function to process windowed data after standardization.""" + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + differences to the existing features as a temporal embedding. + + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. + + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. + + """ stacked_dim = "forcing_feature_windowed" if da_windowed is not None: # Stack the 'feature' and 'window' dimensions and add the @@ -492,8 +538,8 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): def _build_item_dataarrays(self, idx): """ - Create the dataarrays for the initial states, target states and forcing - data for the sample at index `idx`. + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. Parameters ---------- @@ -529,7 +575,7 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # if da_forcing is None, the function will return None for + # if da_forcing_boundary is None, the function will return None for # da_forcing_windowed if da_boundary is not None: _, da_boundary_windowed = self._slice_time( @@ -542,6 +588,9 @@ def _build_item_dataarrays(self, idx): ) else: da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, @@ -584,6 +633,10 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std + # This function handles the stacking of the forcing and boundary data + # and adds the time step differences as a temporal embedding. + # It can handle `None` inputs for the forcing and boundary data + # (and simlpy return an empty DataArray in that case). da_forcing_windowed = self._process_windowed_data( da_forcing_windowed, da_state, da_target_times ) @@ -655,6 +708,11 @@ def __getitem__(self, idx): # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + return init_states, target_states, forcing, boundary, target_times def __iter__(self): @@ -794,9 +852,10 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: - # BUG: There also seem to be issues with "spawn", to be investigated - # default to spawn for now, as the default on linux "fork" hangs - # when using dask (which the npyfilesmeps datastore uses) + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) self.multiprocessing_context = "spawn" else: self.multiprocessing_context = None From 5a749f3ab55d79ce27ebe5bf439815d0cbf78093 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:10:42 +0100 Subject: [PATCH 037/103] update mdp dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5bbe4d92..ef75c8d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "parse>=1.20.2", "dataclass-wizard>=0.22.3", "gcsfs>=2021.10.0", - "mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep@temp/for-neural-lam-datastores", + "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 45ba60782066cfc94d621f07119f23266556a374 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:11:32 +0100 Subject: [PATCH 038/103] remove boundary datastore from tests that don't need it --- tests/test_datasets.py | 17 ++--------------- tests/test_training.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5fbe4a5d..063ec147 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -108,37 +108,24 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length - dataset[len(dataset) - 1] @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) -def test_dataset_item_create_dataarray_from_tensor( - datastore_name, datastore_boundary_name -): +def test_dataset_item_create_dataarray_from_tensor(datastore_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 - num_past_boundary_steps = 1 - num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, - datastore_boundary=datastore_boundary, + datastore_boundary=None, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, - num_past_boundary_steps=num_past_boundary_steps, - num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 diff --git a/tests/test_training.py b/tests/test_training.py index 7a1b4717..ca0ebf41 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,7 +5,6 @@ import pytest import pytorch_lightning as pl import torch - import wandb # First-party @@ -23,10 +22,14 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -41,7 +44,9 @@ def test_training(datastore_name, datastore_boundary_name): if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s else: device_name = "cpu" From f36f36040dcbfa40380880d4cc9fa03f6632da43 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:42:43 +0100 Subject: [PATCH 039/103] fix scope of _get_slice_time --- neural_lam/weather_dataset.py | 40 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 991965d9..4bc9d5c7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,28 +142,14 @@ def __init__( else: self.da_state = self.da_state - def get_time_step(times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] + # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: state_times = self.da_state.time - _ = get_time_step(state_times) + _ = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -182,7 +168,7 @@ def get_time_step(times): forcing_times = self.da_forcing.analysis_time else: forcing_times = self.da_forcing.time - get_time_step(forcing_times.values) + self._get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -192,7 +178,7 @@ def get_time_step(times): boundary_times = self.da_boundary.analysis_time else: boundary_times = self.da_boundary.time - boundary_time_step = get_time_step(boundary_times.values) + boundary_time_step = self._get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values @@ -296,6 +282,22 @@ def __len__(self): - self.num_future_forcing_steps ) + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + def _slice_time( self, da_state, @@ -368,7 +370,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent - self.get_time_step(da_state_sliced.time) + self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset From 105108e9bd144c64075e0f5588f15176fc1fde52 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:43:01 +0100 Subject: [PATCH 040/103] fix scope of _get_time_step --- neural_lam/weather_dataset.py | 40 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 991965d9..4bc9d5c7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,28 +142,14 @@ def __init__( else: self.da_state = self.da_state - def get_time_step(times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] + # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: state_times = self.da_state.time - _ = get_time_step(state_times) + _ = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -182,7 +168,7 @@ def get_time_step(times): forcing_times = self.da_forcing.analysis_time else: forcing_times = self.da_forcing.time - get_time_step(forcing_times.values) + self._get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -192,7 +178,7 @@ def get_time_step(times): boundary_times = self.da_boundary.analysis_time else: boundary_times = self.da_boundary.time - boundary_time_step = get_time_step(boundary_times.values) + boundary_time_step = self._get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values @@ -296,6 +282,22 @@ def __len__(self): - self.num_future_forcing_steps ) + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + def _slice_time( self, da_state, @@ -368,7 +370,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent - self.get_time_step(da_state_sliced.time) + self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset From ae0cf764bd23adfde2befa4bef8ef89122975688 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 16:58:46 +0100 Subject: [PATCH 041/103] added information about optional boundary datastore --- README.md | 22 +++++++++++++--------- neural_lam/weather_dataset.py | 2 -- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e21b7c24..7a5e5caf 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th interface that provides the data in a data-structure that can be used within neural-lam. A datastore is used to create a `pytorch.Dataset`-derived class that samples the data in time to create individual samples for - training, validation and testing. + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. 2. **The graph structure** is used to define message-passing GNN layers, that are trained to emulate fluid flow in the atmosphere over time. The @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model. The path you provide to the neural-lam config (`config.yaml`) also sets the root directory relative to which all other paths are resolved, as in the parent -directory of the config becomes the root directory. Both the datastore and +directory of the config becomes the root directory. Both the datastores and graphs you generate are then stored in subdirectories of this root directory. Exactly how and where a specific datastore expects its source data to be stored and where it stores its derived data is up to the implementation of the @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`): data/ ├── config.yaml - Configuration file for neural-lam ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml └── graphs/ - Directory containing graphs for training ``` @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like: datastore: kind: mdp config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines two things: +1) the kind of datastores and the path to their config +2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha # Contact If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. -There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots). -You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4bc9d5c7..8d82229f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,8 +142,6 @@ def __init__( else: self.da_state = self.da_state - - # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time From 9af27e0741894319860d11eb22cd9e9fd398e1ec Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 19:46:37 +0100 Subject: [PATCH 042/103] add datastore_boundary to neural_lam --- neural_lam/train_model.py | 22 ++++++++++++++++++++++ neural_lam/weather_dataset.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..37bf6db7 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,6 +34,11 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) + parser.add_argument( + "--config_path_boundary", + type=str, + help="Path to the configuration for boundary conditions", + ) parser.add_argument( "--model", type=str, @@ -212,6 +217,9 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" + assert ( + args.config_path_boundary is not None + ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -227,10 +235,24 @@ def main(input_args=None): # Load neural-lam configuration and datastore to use config, datastore = load_config_and_datastore(config_path=args.config_path) + config_boundary, datastore_boundary = load_config_and_datastore( + config_path=args.config_path_boundary + ) + + # TODO this should not be required, make more flexible + assert ( + datastore.num_past_forcing_steps + == datastore_boundary.num_past_forcing_steps + ), "Mismatch in num_past_forcing_steps" + assert ( + datastore.num_future_forcing_steps + == datastore_boundary.num_future_forcing_steps + ), "Mismatch in num_future_forcing_steps" # Create datamodule data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, standardize=True, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b5f85580..75f7e04e 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -22,6 +22,8 @@ class WeatherDataset(torch.utils.data.Dataset): ---------- datastore : BaseDatastore The datastore to load the data from (e.g. mdp). + datastore_boundary : BaseDatastore + The boundary datastore to load the data from (e.g. mdp). split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional @@ -43,6 +45,7 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, num_past_forcing_steps=1, @@ -54,6 +57,7 @@ def __init__( self.split = split self.ar_steps = ar_steps self.datastore = datastore + self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps @@ -606,6 +610,7 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, standardize=True, @@ -616,6 +621,7 @@ def __init__( ): super().__init__() self._datastore = datastore + self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps self.ar_steps_train = ar_steps_train @@ -627,6 +633,7 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: + # BUG: There also seem to be issues with "spawn", to be investigated # default to spawn for now, as the default on linux "fork" hangs # when using dask (which the npyfilesmeps datastore uses) self.multiprocessing_context = "spawn" @@ -637,6 +644,7 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, standardize=self.standardize, @@ -645,6 +653,7 @@ def setup(self, stage=None): ) self.val_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, standardize=self.standardize, @@ -655,6 +664,7 @@ def setup(self, stage=None): if stage == "test" or stage is None: self.test_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, standardize=self.standardize, From c25fb30ab6b9fc8038227a590b5551f1660dbe19 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:41 +0100 Subject: [PATCH 043/103] complete integration of boundary in weatherDataset --- neural_lam/weather_dataset.py | 55 ++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 75f7e04e..7585207c 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,6 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + self.da_boundary = self.datastore_boundary.get_dataarray( + category="boundary", split=self.split + ) # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -118,6 +121,15 @@ def __init__( self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + if self.da_boundary is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="boundary" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.boundary_mean + self.da_boundary_std = self.ds_boundary_stats.boundary_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -352,6 +364,8 @@ def _build_item_dataarrays(self, idx): The dataarray for the target states. da_forcing_windowed : xr.DataArray The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -381,6 +395,11 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None + if self.da_boundary is not None: + da_boundary = self.da_boundary + else: + da_boundary = None + # handle time sampling in a way that is compatible with both analysis # and forecast data da_state = self._slice_state_time( @@ -390,11 +409,17 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed = self._slice_forcing_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) + if da_boundary is not None: + da_boundary_windowed = self._slice_forcing_time( + da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + ) # load the data into memory da_state.load() if da_forcing is not None: da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() da_init_states = da_state.isel(time=slice(0, 2)) da_target_states = da_state.isel(time=slice(2, None)) @@ -417,6 +442,11 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed - self.da_forcing_mean ) / self.da_forcing_std + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + if da_forcing is not None: # stack the `forcing_feature` and `window_sample` dimensions into a # single `forcing_feature` dimension @@ -436,11 +466,31 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: + # stack the `forcing_feature` and `window_sample` dimensions into a + # single `forcing_feature` dimension + da_boundary_windowed = da_boundary_windowed.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + else: + # create an empty forcing tensor with the right shape + da_boundary_windowed = xr.DataArray( + data=np.empty( + (self.ar_steps, da_state.grid_index.size, 0), + ), + dims=("time", "grid_index", "boundary_feature"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + "boundary_feature": [], + }, + ) return ( da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) @@ -475,6 +525,7 @@ def __getitem__(self, idx): da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) = self._build_item_dataarrays(idx=idx) @@ -491,13 +542,15 @@ def __getitem__(self, idx): ) forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) - return init_states, target_states, forcing, target_times + return init_states, target_states, forcing, boundary, target_times def __iter__(self): """ From 505ceeb589c3398d37100a6073fa5590e7d786c2 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:55 +0100 Subject: [PATCH 044/103] Add test to check timestep length and spacing --- neural_lam/weather_dataset.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 7585207c..8e55d4a5 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -101,6 +101,82 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # Check time coverage for forcing and boundary data + if self.da_forcing is not None or self.da_boundary is not None: + state_times = self.da_state.time + state_time_min = state_times.min().values + state_time_max = state_times.max().values + + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + if self.da_forcing is not None: + forcing_times = self.da_forcing.time + forcing_time_step = get_time_step(forcing_times.values) + forcing_time_min = forcing_times.min().values + forcing_time_max = forcing_times.max().values + + # Calculate required bounds for forcing using its time step + forcing_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * forcing_time_step + ) + forcing_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * forcing_time_step + ) + + if forcing_time_min > forcing_required_time_min: + raise ValueError( + f"Forcing data starts too late." + f"Required start: {forcing_required_time_min}, " + f"but forcing starts at {forcing_time_min}." + ) + + if forcing_time_max < forcing_required_time_max: + raise ValueError( + f"Forcing data ends too early." + f"Required end: {forcing_required_time_max}," + f"but forcing ends at {forcing_time_max}." + ) + + if self.da_boundary is not None: + boundary_times = self.da_boundary.time + boundary_time_step = get_time_step(boundary_times.values) + boundary_time_min = boundary_times.min().values + boundary_time_max = boundary_times.max().values + + # Calculate required bounds for boundary using its time step + boundary_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * boundary_time_step + ) + boundary_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * boundary_time_step + ) + + if boundary_time_min > boundary_required_time_min: + raise ValueError( + f"Boundary data starts too late." + f"Required start: {boundary_required_time_min}, " + f"but boundary starts at {boundary_time_min}." + ) + + if boundary_time_max < boundary_required_time_max: + raise ValueError( + f"Boundary data ends too early." + f"Required end: {boundary_required_time_max}, " + f"but boundary ends at {boundary_time_max}." + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize From e7330664661bd336caf40842dfb46a406b120721 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:43:57 +0100 Subject: [PATCH 045/103] setting default mdp boundary to 0 gridcells --- neural_lam/datastore/mdp.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 0d1aac7b..b6f1676c 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -27,7 +27,7 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): + def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at `config_path`. A boundary mask is created with `n_boundary_points` @@ -336,19 +336,22 @@ def boundary_mask(self) -> xr.DataArray: boundary point and 0 is not. """ - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + if self._n_boundary_points > 0: + ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) + da_state_variable = ( + ds_unstacked["state"].isel(time=0).isel(state_feature=0) + ) + da_domain_allzero = xr.zeros_like(da_state_variable) + ds_unstacked["boundary_mask"] = da_domain_allzero.isel( + x=slice(self._n_boundary_points, -self._n_boundary_points), + y=slice(self._n_boundary_points, -self._n_boundary_points), + ) + ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( + 1 + ).astype(int) + return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + else: + return None @property def coords_projection(self) -> ccrs.Projection: From d8349a4801654c152f14924aa86d08c4ab952468 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:44:54 +0100 Subject: [PATCH 046/103] implement time-based slicing combine two slicing fcts into one --- neural_lam/weather_dataset.py | 300 ++++++++++++++++++---------------- 1 file changed, 161 insertions(+), 139 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 8e55d4a5..5559e838 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,8 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + # XXX For now boundary data is always considered forcing data self.da_boundary = self.datastore_boundary.get_dataarray( - category="boundary", split=self.split + category="forcing", split=self.split ) # check that with the provided data-arrays and ar_steps that we have a @@ -200,7 +201,7 @@ def get_time_step(times): if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( - category="boundary" + category="forcing" ) ) self.da_boundary_mean = self.ds_boundary_stats.boundary_mean @@ -252,175 +253,156 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_state_time(self, da_state, idx, n_steps: int): + def _slice_time(self, da_state, da_forcing, idx, n_steps: int): """ - Produce a time slice of the given dataarray `da_state` (state) starting - at `idx` and with `n_steps` steps. An `offset`is calculated based on the - `num_past_forcing_steps` class attribute. `Offset` is used to offset the - start of the sample, to assert that enough previous time steps are - available for the 2 initial states and any corresponding forcings - (calculated in `_slice_forcing_time`). + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing` (forcing). For the state data, slicing is done as before + based on `idx`. For the forcing data, nearest neighbor matching is + performed based on the state times. Additionally, the time difference + between the matched forcing times and state times (in multiples of state + time steps) is added to the forcing dataarray. Parameters ---------- da_state : xr.DataArray - The dataarray to slice. This is expected to have a `time` dimension - if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. + The state dataarray to slice. + da_forcing : xr.DataArray + The forcing dataarray to slice. idx : int - The index of the time step to start the sample from. + The index of the time step to start the sample from in the state + data. n_steps : int The number of time steps to include in the sample. Returns ------- - da_sliced : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). + da_forcing_matched : xr.DataArray + The forcing dataarray matched to state times with an added + coordinate 'time_diff', representing the time difference to state + times in multiples of state time steps. """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). + # Number of initial steps required (e.g., for initializing models) init_steps = 2 - # slice the dataarray to include the required number of time steps + + # Slice the state data as before if self.datastore.is_forecast: + # Calculate start and end indices for slicing start_idx = max(0, self.num_past_forcing_steps - init_steps) end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps - # this implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select a analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast, always starting at forecast time 2. - da_sliced = da_state.isel( + + # Slice the state data over the elapsed forecast duration + da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - # create a new time dimension so that the produced sample has a - # `time` dimension, similarly to the analysis only data - da_sliced["time"] = ( - da_sliced.analysis_time + da_sliced.elapsed_forecast_duration + + # Create a new 'time' dimension + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration ) - da_sliced = da_sliced.swap_dims( + da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + else: - # For analysis data we slice the time dimension directly. The offset - # is only relevant for the very first (and last) samples in the - # dataset. + # For analysis data, slice the time dimension directly start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) end_idx = ( idx + max(init_steps, self.num_past_forcing_steps) + n_steps ) - da_sliced = da_state.isel(time=slice(start_idx, end_idx)) - return da_sliced + da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - def _slice_forcing_time(self, da_forcing, idx, n_steps: int): - """ - Produce a time slice of the given dataarray `da_forcing` (forcing) - starting at `idx` and with `n_steps` steps. An `offset` is calculated - based on the `num_past_forcing_steps` class attribute. It is used to - offset the start of the sample, to ensure that enough previous time - steps are available for the forcing data. The forcing data is windowed - around the current autoregressive time step to include the past and - future forcings. - - Parameters - ---------- - da_forcing : xr.DataArray - The forcing dataarray to slice. This is expected to have a `time` - dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. - idx : int - The index of the time step to start the sample from. - n_steps : int - The number of time steps to include in the sample. - - Returns - ------- - da_concat : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', - 'window', 'forcing_feature'). - """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). The forcing data is windowed around the - # current autregressive time step. The two `init_steps` can also be used - # as past forcings. - init_steps = 2 - da_list = [] + # Get the state times for matching + state_times = da_state_sliced["time"] + # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: - # This implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select an analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast. - # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - current_time = ( - da_forcing.analysis_time[idx] - + da_forcing.elapsed_forecast_duration[offset + step] - ) - - da_sliced = da_forcing.isel( - analysis_time=idx, - elapsed_forecast_duration=slice(start_idx, end_idx + 1), - ) - - da_sliced = da_sliced.rename( - {"elapsed_forecast_duration": "window"} - ) + # Calculate all possible forcing times + forcing_times = ( + da_forcing.analysis_time + da_forcing.elapsed_forecast_duration + ) + forcing_times_flat = forcing_times.stack( + forecast_time=("analysis_time", "elapsed_forecast_duration") + ) - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # Compute time differences + time_deltas = ( + forcing_times_flat.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) + + # Retrieve corresponding indices for analysis_time and + # elapsed_forecast_duration + forecast_time_index = forcing_times_flat["forecast_time"][idx_min] + analysis_time_indices = forecast_time_index["analysis_time"] + elapsed_forecast_duration_indices = forecast_time_index[ + "elapsed_forecast_duration" + ] + + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel( + analysis_time=("time", analysis_time_indices), + elapsed_forecast_duration=( + "time", + elapsed_forecast_duration_indices, + ), + ) - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Assign matched state times to the forcing data + da_forcing_matched["time"] = state_times + da_forcing_matched = da_forcing_matched.swap_dims( + {"elapsed_forecast_duration": "time"} + ) - da_list.append(da_sliced) + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) else: - # For analysis data, we slice the time dimension directly. The - # offset is only relevant for the very first (and last) samples in - # the dataset. - offset = idx + max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - # Slice the data over the desired time window - da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1)) - - da_sliced = da_sliced.rename({"time": "window"}) - - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # For analysis data, match directly using the 'time' coordinate + forcing_times = da_forcing["time"] - # Add a 'time' dimension to keep track of steps using actual - # time coordinates - current_time = da_forcing.time[offset + step] - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Compute time differences + time_deltas = ( + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) - da_list.append(da_sliced) + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel(time=idx_min) + da_forcing_matched = da_forcing_matched.assign_coords( + time=state_times + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - return da_concat + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) + + return da_state_sliced, da_forcing_matched def _build_item_dataarrays(self, idx): """ @@ -442,6 +424,7 @@ def _build_item_dataarrays(self, idx): The dataarray for the forcing data, windowed for the sample. da_boundary_windowed : xr.DataArray The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -478,15 +461,15 @@ def _build_item_dataarrays(self, idx): # handle time sampling in a way that is compatible with both analysis # and forecast data - da_state = self._slice_state_time( + da_state = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps ) if da_forcing is not None: - da_forcing_windowed = self._slice_forcing_time( + da_forcing_windowed = self._slice_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) if da_boundary is not None: - da_boundary_windowed = self._slice_forcing_time( + da_boundary_windowed = self._slice_time( da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps ) @@ -524,13 +507,32 @@ def _build_item_dataarrays(self, idx): ) / self.da_boundary_std if da_forcing is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # Expand 'time_diff' to align with 'forcing_feature' and 'window' + # dimensions 'time_diff' has dimension ('time'), expand to ('time', + # 'forcing_feature', 'window') + time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( + forcing_feature=da_forcing_windowed["forcing_feature"], + window=da_forcing_windowed["window"], + ) + + # Stack 'forcing_feature' and 'window' into a single + # 'forcing_feature_windowed' dimension da_forcing_windowed = da_forcing_windowed.stack( forcing_feature_windowed=("forcing_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + forcing_feature_windowed=("forcing_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' + da_forcing_windowed = da_forcing_windowed.assign_coords( + time_diff=( + "forcing_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty forcing tensor with the right shape da_forcing_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), @@ -542,14 +544,34 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # If 'da_boundary_windowed' also has 'time_diff', process similarly + # Expand 'time_diff' to align with 'boundary_feature' and 'window' + # dimensions + time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( + boundary_feature=da_boundary_windowed["boundary_feature"], + window=da_boundary_windowed["window"], + ) + + # Stack 'boundary_feature' and 'window' into a single + # 'boundary_feature_windowed' dimension da_boundary_windowed = da_boundary_windowed.stack( boundary_feature_windowed=("boundary_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' + da_boundary_windowed = da_boundary_windowed.assign_coords( + time_diff=( + "boundary_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty boundary tensor with the right shape da_boundary_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), From fd791bfb51c3c751ff4af8d74eaa47c81b63a1eb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 06:26:54 +0100 Subject: [PATCH 047/103] remove all interior_mask and boundary_mask --- neural_lam/datastore/base.py | 17 ---- neural_lam/datastore/mdp.py | 34 -------- neural_lam/datastore/npyfilesmeps/store.py | 28 ------ neural_lam/models/ar_model.py | 53 +++--------- neural_lam/vis.py | 12 --- .../config.yaml | 18 ++++ .../era5.datastore.yaml | 85 +++++++++++++++++++ .../meps_example_reduced.datastore.yaml | 44 ++++++++++ tests/dummy_datastore.py | 22 ----- tests/test_datastores.py | 21 ----- 10 files changed, 157 insertions(+), 177 deletions(-) create mode 100644 tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml create mode 100644 tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml create mode 100644 tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index b0055e39..e2d21404 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -228,23 +228,6 @@ def get_dataarray( """ pass - @cached_property - @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - pass - @abc.abstractmethod def get_xy(self, category: str) -> np.ndarray: """ diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index b6f1676c..e662cb63 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -319,40 +319,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) return ds_stats - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Produce a 0/1 mask for the boundary points of the dataset, these will - sit at the edges of the domain (in x/y extent) and will be used to mask - out the boundary points from the loss function and to overwrite the - boundary points from the prediction. For now this is created when the - mask is requested, but in the future this could be saved to the zarr - file. - - Returns - ------- - xr.DataArray - A 0/1 mask for the boundary points of the dataset, where 1 is a - boundary point and 0 is not. - - """ - if self._n_boundary_points > 0: - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) - else: - return None - @property def coords_projection(self) -> ccrs.Projection: """ diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e80706..146b0627 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -668,34 +668,6 @@ def grid_shape_state(self) -> CartesianGridShape: ny, nx = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions `[grid_index]`. - - """ - xy = self.get_xy(category="state", stacked=False) - xs = xy[:, :, 0] - ys = xy[:, :, 1] - # Check if x-coordinates are constant along columns - assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" - # Check if y-coordinates are constant along rows - assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" - # Extract unique x and y coordinates - x = xs[:, 0] # Unique x-coordinates (changes along the first axis) - y = ys[0, :] # Unique y-coordinates (changes along the second axis) - values = np.load(self.root_path / "static" / "border_mask.npy") - da_mask = xr.DataArray( - values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" - ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) - return da_mask_stacked_xy - def get_standardization_dataarray(self, category: str) -> xr.Dataset: """Return the standardization dataarray for the given category. This should contain a `{category}_mean` and `{category}_std` variable for diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 44baf9c2..710efcec 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -45,7 +45,6 @@ def __init__( da_state_stats = datastore.get_standardization_dataarray( category="state" ) - da_boundary_mask = datastore.boundary_mask num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps @@ -118,18 +117,6 @@ def __init__( # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { "mse": [], } @@ -194,13 +181,6 @@ def configure_optimizers(self): ) return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -232,7 +212,6 @@ def unroll_prediction(self, init_states, forcing_features, true_states): for i in range(pred_steps): forcing = forcing_features[:, i] - border_state = true_states[:, i] pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing @@ -240,19 +219,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states): # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) + prediction_list.append(pred_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 @@ -290,12 +263,14 @@ def training_step(self, batch): """ prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( @@ -328,9 +303,7 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -355,7 +328,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -382,9 +354,7 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) @@ -413,16 +383,13 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times diff --git a/neural_lam/vis.py b/neural_lam/vis.py index d6b57f88..efab20bf 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,11 +87,6 @@ def plot_prediction( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_values = np.invert(da_mask.values.astype(bool)).astype(float) - pixel_alpha = mask_values.clip(0.7, 1) # Faded border region - fig, axes = plt.subplots( 1, 2, @@ -107,7 +102,6 @@ def plot_prediction( origin="lower", x="x", extent=extent, - alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", @@ -141,11 +135,6 @@ def plot_spatial_error( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, @@ -164,7 +153,6 @@ def plot_spatial_error( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml new file mode 100644 index 00000000..27cc9764 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml @@ -0,0 +1,18 @@ +datastore: + kind: npyfilesmeps + config_path: meps_example_reduced.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + nlwrs_0: 1.0 + nswrs_0: 1.0 + pres_0g: 1.0 + pres_0s: 1.0 + r_2: 1.0 + r_65: 1.0 + t_2: 1.0 + t_65: 1.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml new file mode 100644 index 00000000..600a1845 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -0,0 +1,85 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml new file mode 100644 index 00000000..3d88d4a4 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml @@ -0,0 +1,44 @@ +dataset: + name: meps_example_reduced + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + var_units: + - Pa + - Pa + - W/m**2 + - W/m**2 + - '' + - '' + - K + - K + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 +grid_shape_state: +- 134 +- 119 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 9075d404..d62c7356 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -148,12 +148,6 @@ def __init__( times = [self.T0 + dt * i for i in range(n_timesteps)] self.ds.coords["time"] = times - # Add boundary mask - self.ds["boundary_mask"] = xr.DataArray( - np.random.choice([0, 1], size=(n_points_1d, n_points_1d)), - dims=["x", "y"], - ) - # Stack the spatial dimensions into grid_index self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) @@ -342,22 +336,6 @@ def get_dataarray( dim_order = self.expected_dim_order(category=category) return self.ds[category].transpose(*dim_order) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - return self.ds["boundary_mask"] - def get_xy(self, category: str, stacked: bool) -> ndarray: """Return the x, y coordinates of the dataset. diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4a4b1100..a91f6245 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -18,8 +18,6 @@ dataarray for the given category. - `get_dataarray` (method): Return the processed data (as a single `xr.DataArray`) for the given category and test/train/val-split. -- `boundary_mask` (property): Return the boundary mask for the dataset, - with spatial dimensions stacked. - `config` (property): Return the configuration of the datastore. In addition BaseRegularGridDatastore must have the following methods and @@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name): assert n_features["train"] == n_features["val"] == n_features["test"] -@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" - datastore = init_datastore_example(datastore_name) - da_mask = datastore.boundary_mask - - assert isinstance(da_mask, xr.DataArray) - assert set(da_mask.dims) == {"grid_index"} - assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size - - if isinstance(datastore, BaseRegularGridDatastore): - grid_shape = datastore.grid_shape_state - assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y - - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): """Check that the `datastore.get_xy_extent` method is implemented and that From ae82cdb8360d899b063bdf48a877a42184306cab Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:55:56 +0100 Subject: [PATCH 048/103] added gcsfs dependency for era5 weatherbench download --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index fdcb7f3e..38e7cb0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard<0.31.0", + "gcsfs>=2021.10.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 34a6cc7d24ffb218b2aef909cac7db06ffbef618 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:57:57 +0100 Subject: [PATCH 049/103] added new era5 datastore config for boundary --- tests/conftest.py | 19 +++- .../mdp/era5_1000hPa_winds/.gitignore | 2 + .../mdp/era5_1000hPa_winds/config.yaml | 3 + .../era5_1000hPa_winds/era5.datastore.yaml | 90 +++++++++++++++++++ 4 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 6f579621..be5cf3e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,15 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) +DATASTORES_BOUNDARY_EXAMPLES = dict( + mdp=( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_winds" + / "era5.datastore.yaml" + ) +) + DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore @@ -102,5 +111,13 @@ def init_datastore_example(datastore_kind): datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], ) - return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml new file mode 100644 index 00000000..5d1e05f2 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml @@ -0,0 +1,3 @@ +datastore: + kind: mdp + config_path: era5.datastore.yaml diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml new file mode 100644 index 00000000..36b39501 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml @@ -0,0 +1,90 @@ +#TODO: What do these versions mean? Should they be updated? +schema_version: v0.2.0+dev +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-02T00:00 + end: 1990-09-10T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-02T00:00 + end: 1990-09-07T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-05T00:00 + end: 1990-09-08T00:00 + test: + start: 1990-09-06T00:00 + end: 1990-09-10T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + v_component_of_wind: + level: + values: [1000, ] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 From 2dc67a02e2acad0665452bfe336384de1cc34b4e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:58:36 +0100 Subject: [PATCH 050/103] removed left-over boundary-mask references --- neural_lam/datastore/mdp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index e662cb63..b28d2650 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -27,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): + def __init__(self, config_path, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded + `config_path`. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified since the zarr was created), otherwise it is created from the configuration file. @@ -42,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file` method to then call `mllam_data_prep.create_dataset` to create the dataset. - n_boundary_points : int - The number of boundary points to use in the boundary mask. reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. @@ -70,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): if self._ds is None: self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) - self._n_boundary_points = n_boundary_points print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: From 9f8628e03487a80ab3313656857b5fde3e6fde45 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:12 +0100 Subject: [PATCH 051/103] make check for existing category in datastore more flexible (for boundary) --- neural_lam/datastore/mdp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index b28d2650..7b947c20 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -154,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]: The units of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_units"].values.tolist() @@ -173,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature"].values.tolist() @@ -193,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]: The long names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_long_name"].values.tolist() @@ -249,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The xarray DataArray object with processed dataset. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") - return None + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] da_category = self._ds[category] From 388c79df3fdbbaa24ef025621a09dd25ac567ac5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 20 Nov 2024 16:00:15 +0100 Subject: [PATCH 052/103] implement xarray based (mostly) time slicing and windowing --- neural_lam/weather_dataset.py | 255 +++++++++++++++------------------- 1 file changed, 111 insertions(+), 144 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5559e838..555f2c35 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -64,10 +64,16 @@ def __init__( self.da_state = self.datastore.get_dataarray( category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) - # XXX For now boundary data is always considered forcing data + # XXX For now boundary data is always considered mdp-forcing data self.da_boundary = self.datastore_boundary.get_dataarray( category="forcing", split=self.split ) @@ -102,53 +108,36 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + # Check time step consistency in state data + _ = get_time_step(self.da_state.time.values) + # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values - def get_time_step(times): - """Calculate the time step from the data""" - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] - if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data forcing_times = self.da_forcing.time - forcing_time_step = get_time_step(forcing_times.values) - forcing_time_min = forcing_times.min().values - forcing_time_max = forcing_times.max().values - - # Calculate required bounds for forcing using its time step - forcing_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * forcing_time_step - ) - forcing_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * forcing_time_step - ) - - if forcing_time_min > forcing_required_time_min: - raise ValueError( - f"Forcing data starts too late." - f"Required start: {forcing_required_time_min}, " - f"but forcing starts at {forcing_time_min}." - ) - - if forcing_time_max < forcing_required_time_max: - raise ValueError( - f"Forcing data ends too early." - f"Required end: {forcing_required_time_max}," - f"but forcing ends at {forcing_time_max}." - ) + _ = get_time_step(forcing_times.values) if self.da_boundary is not None: + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values @@ -204,8 +193,8 @@ def get_time_step(times): category="forcing" ) ) - self.da_boundary_mean = self.ds_boundary_stats.boundary_mean - self.da_boundary_std = self.ds_boundary_stats.boundary_std + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): if self.datastore.is_forecast: @@ -253,7 +242,7 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, da_forcing, idx, n_steps: int): + def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and `da_forcing` (forcing). For the state data, slicing is done as before @@ -316,8 +305,13 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): ) da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) + if da_forcing is None: + return da_state_sliced, None + # Get the state times for matching state_times = da_state_sliced["time"] + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: @@ -371,39 +365,80 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): da_forcing_matched = da_forcing_matched.assign_coords( time_diff=("time", time_diff_steps) ) - else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] # Compute time differences time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + state_times.values[np.newaxis, :] + - forcing_times.values[:, np.newaxis] + ) + idx_min = np.abs(time_deltas).argmin(axis=0) + + time_diff_steps = xr.DataArray( + np.stack( + [ + np.diagonal(time_deltas, offset=offset)[ + -len(state_times) + init_steps : + ] + / state_time_step + for offset in range( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ) + ], + axis=1, + ), + dims=["time", "window"], + coords={ + "time": state_times.isel(time=slice(init_steps, None)), + "window": np.arange( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ), + }, + name="time_diff_steps", ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel(time=idx_min) - da_forcing_matched = da_forcing_matched.assign_coords( - time=state_times + # Create window dimension using rolling + window_size = ( + self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) - - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step + da_forcing_windowed = da_forcing.rolling( + time=window_size, center=True + ).construct(window_dim="window") + da_forcing_matched = da_forcing_windowed.isel( + time=idx_min[init_steps:] ) # Add time difference as a new coordinate da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + time_diff=time_diff_steps ) return da_state_sliced, da_forcing_matched + def _process_windowed_data(self, da_windowed, da_state, da_target_times): + """Helper function to process windowed data after standardization.""" + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + # Stack the 'feature' and 'window' dimensions + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + return xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + def _build_item_dataarrays(self, idx): """ Create the dataarrays for the initial states, target states and forcing @@ -459,18 +494,21 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # handle time sampling in a way that is compatible with both analysis - # and forecast data - da_state = self._slice_time( - da_state=da_state, idx=idx, n_steps=self.ar_steps + # if da_forcing is None, the function will return None for + # da_forcing_windowed + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, ) - if da_forcing is not None: - da_forcing_windowed = self._slice_time( - da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps - ) + if da_boundary is not None: - da_boundary_windowed = self._slice_time( - da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_boundary, ) # load the data into memory @@ -506,83 +544,12 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std - if da_forcing is not None: - # Expand 'time_diff' to align with 'forcing_feature' and 'window' - # dimensions 'time_diff' has dimension ('time'), expand to ('time', - # 'forcing_feature', 'window') - time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( - forcing_feature=da_forcing_windowed["forcing_feature"], - window=da_forcing_windowed["window"], - ) - - # Stack 'forcing_feature' and 'window' into a single - # 'forcing_feature_windowed' dimension - da_forcing_windowed = da_forcing_windowed.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' - da_forcing_windowed = da_forcing_windowed.assign_coords( - time_diff=( - "forcing_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty forcing tensor with the right shape - da_forcing_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "forcing_feature": [], - }, - ) - - if da_boundary is not None: - # If 'da_boundary_windowed' also has 'time_diff', process similarly - # Expand 'time_diff' to align with 'boundary_feature' and 'window' - # dimensions - time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( - boundary_feature=da_boundary_windowed["boundary_feature"], - window=da_boundary_windowed["window"], - ) - - # Stack 'boundary_feature' and 'window' into a single - # 'boundary_feature_windowed' dimension - da_boundary_windowed = da_boundary_windowed.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' - da_boundary_windowed = da_boundary_windowed.assign_coords( - time_diff=( - "boundary_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty boundary tensor with the right shape - da_boundary_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "boundary_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "boundary_feature": [], - }, - ) + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, da_state, da_target_times + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, da_state, da_target_times + ) return ( da_init_states, From 2529969b12eb7babdcfd3311d6eae3045fe1fe15 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 07:09:52 +0100 Subject: [PATCH 053/103] cleanup analysis based time-slicing --- neural_lam/weather_dataset.py | 85 +++++++++++++++++------------------ 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 555f2c35..fd40a2c8 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -245,11 +245,12 @@ def __len__(self): def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done as before - based on `idx`. For the forcing data, nearest neighbor matching is - performed based on the state times. Additionally, the time difference - between the matched forcing times and state times (in multiples of state - time steps) is added to the forcing dataarray. + `da_forcing` (forcing). For the state data, slicing is done based on + `idx`. For the forcing data, nearest neighbor matching is performed + based on the state times. Additionally, the time difference between the + matched forcing times and state times (in multiples of state time steps) + is added to the forcing dataarray. This will be used as an additional + feature in the model (temporal embedding). Parameters ---------- @@ -269,9 +270,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). da_forcing_matched : xr.DataArray - The forcing dataarray matched to state times with an added - coordinate 'time_diff', representing the time difference to state - times in multiples of state time steps. + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -308,9 +308,9 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): if da_forcing is None: return da_state_sliced, None - # Get the state times for matching + # Get the state times and its temporal resolution for matching with + # forcing data state_times = da_state_sliced["time"] - # Calculate time differences in multiples of state time steps state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor @@ -369,39 +369,29 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] - # Compute time differences + # Compute time differences between forcing and state times + # (in multiples of state time steps) + # Retrieve the indices of the closest times in the forcing data time_deltas = ( - state_times.values[np.newaxis, :] - - forcing_times.values[:, np.newaxis] - ) + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) - time_diff_steps = xr.DataArray( - np.stack( - [ - np.diagonal(time_deltas, offset=offset)[ - -len(state_times) + init_steps : - ] - / state_time_step - for offset in range( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ) - ], - axis=1, - ), - dims=["time", "window"], - coords={ - "time": state_times.isel(time=slice(init_steps, None)), - "window": np.arange( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ), - }, - name="time_diff_steps", + time_diff_steps = np.stack( + [ + time_deltas[ + idx_i + - self.num_past_forcing_steps : idx_i + + self.num_future_forcing_steps + + 1, + init_steps + step_i, + ] + for (step_i, idx_i) in enumerate(idx_min[init_steps:]) + ], ) - # Create window dimension using rolling + # Create window dimension for forcing data to stack later window_size = ( self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) @@ -412,9 +402,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): time=idx_min[init_steps:] ) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=time_diff_steps + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, ) return da_state_sliced, da_forcing_matched @@ -423,13 +415,19 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" stacked_dim = "forcing_feature_windowed" if da_windowed is not None: - # Stack the 'feature' and 'window' dimensions + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) + da_windowed = xr.concat( + [da_windowed, da_windowed.time_diff_steps], + dim="forcing_feature_windowed", + ) else: # Create empty DataArray with the correct dimensions and coordinates - return xr.DataArray( + da_windowed = xr.DataArray( data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), dims=("time", "grid_index", f"{stacked_dim}"), coords={ @@ -438,6 +436,7 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): f"{stacked_dim}": [], }, ) + return da_windowed def _build_item_dataarrays(self, idx): """ From 179a035ac8b976a74e54ce4f38102addf06ed318 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:42 +0100 Subject: [PATCH 054/103] implement datastore_boundary in existing tests --- tests/test_datasets.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece0..67eac70e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,12 +14,19 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_shapes(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): """Check that the `datastore.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different @@ -31,6 +38,9 @@ def test_dataset_item_shapes(datastore_name): """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_gridpoints = datastore.num_grid_points N_pred_steps = 4 @@ -38,6 +48,7 @@ def test_dataset_item_shapes(datastore_name): num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -48,7 +59,7 @@ def test_dataset_item_shapes(datastore_name): # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - init_states, target_states, forcing, target_times = item + init_states, target_states, forcing, boundary, target_times = item # initial states assert init_states.ndim == 3 @@ -81,14 +92,23 @@ def test_dataset_item_shapes(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_create_dataarray_from_tensor(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_create_dataarray_from_tensor( + datastore_name, datastore_boundary_name +): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -158,13 +178,19 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_single_batch(datastore_name, split): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): """Check that the `datastore.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) device_name = ( torch.device("cuda") if torch.cuda.is_available() else "cpu" @@ -210,7 +236,9 @@ def _create_graph(): ) ) - dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) model = GraphLAM(args=args, datastore=datastore, config=config) # noqa From 2daeb1642d276730496cc7ab183203ed5abba6ce Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:39:27 +0100 Subject: [PATCH 055/103] allow for grid shape retrieval from forcing data --- neural_lam/datastore/mdp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 7b947c20..809bbdb8 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -380,8 +380,17 @@ def grid_shape_state(self): The shape of the cartesian grid for the state variables. """ - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) + da_x, da_y = ds_forcing.x, ds_forcing.y + else: + ds_state = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) From cbcdcaee71039977090a66ec2b8b1116063cf2a4 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:40:47 +0100 Subject: [PATCH 056/103] rearrange time slicing, boundary first --- neural_lam/weather_dataset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index fd40a2c8..f172d47f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -495,13 +495,6 @@ def _build_item_dataarrays(self, idx): # if da_forcing is None, the function will return None for # da_forcing_windowed - da_state, da_forcing_windowed = self._slice_time( - da_state=da_state, - idx=idx, - n_steps=self.ar_steps, - da_forcing=da_forcing, - ) - if da_boundary is not None: _, da_boundary_windowed = self._slice_time( da_state=da_state, @@ -509,6 +502,12 @@ def _build_item_dataarrays(self, idx): n_steps=self.ar_steps, da_forcing=da_boundary, ) + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, + ) # load the data into memory da_state.load() From e6ace2727038d5a472a18e7eab7e6a26b6362fbb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:42:05 +0100 Subject: [PATCH 057/103] renaming test datastores --- tests/datastore_examples/.gitignore | 3 +- .../.gitignore | 0 .../era5_1000hPa_danra_100m_winds/config.yaml | 12 +++ .../danra.datastore.yaml | 99 +++++++++++++++++++ .../era5.datastore.yaml | 23 ++--- .../mdp/era5_1000hPa_winds/config.yaml | 3 - 6 files changed, 122 insertions(+), 18 deletions(-) rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/.gitignore (100%) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/era5.datastore.yaml (80%) delete mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index e84e6493..4fbd2326 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,3 @@ npyfilesmeps/*.zip -npyfilesmeps/meps_example_reduced/ +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore similarity index 100% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 00000000..a158bee3 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..3edf1267 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml similarity index 80% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index 36b39501..c97da4bc 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -1,5 +1,4 @@ -#TODO: What do these versions mean? Should they be updated? -schema_version: v0.2.0+dev +schema_version: v0.5.0 dataset_version: v1.0.0 output: @@ -7,8 +6,8 @@ output: forcing: [time, grid_index, forcing_feature] coord_ranges: time: - start: 1990-09-02T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 step: PT6H chunking: time: 1 @@ -16,17 +15,17 @@ output: dim: time splits: train: - start: 1990-09-02T00:00 - end: 1990-09-07T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 compute_statistics: ops: [mean, std, diff_mean, diff_std] dims: [grid_index, time] val: - start: 1990-09-05T00:00 - end: 1990-09-08T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 test: - start: 1990-09-06T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: @@ -37,10 +36,6 @@ inputs: level: values: [1000,] units: hPa - v_component_of_wind: - level: - values: [1000, ] - units: hPa dim_mapping: time: method: rename diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml deleted file mode 100644 index 5d1e05f2..00000000 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml +++ /dev/null @@ -1,3 +0,0 @@ -datastore: - kind: mdp - config_path: era5.datastore.yaml From 42818f0e91ccebb03c506b00f42e05e7d8d6fdfa Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:15 +0100 Subject: [PATCH 058/103] adding num_past/future_boundary_step args --- neural_lam/train_model.py | 37 +++++++++++++++------------------ tests/test_datasets.py | 43 +++++++++++++++++++++++++++++++++------ tests/test_training.py | 24 ++++++++++++++++++++-- 3 files changed, 75 insertions(+), 29 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 37bf6db7..2a61e86c 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,11 +34,6 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) - parser.add_argument( - "--config_path_boundary", - type=str, - help="Path to the configuration for boundary conditions", - ) parser.add_argument( "--model", type=str, @@ -208,6 +203,18 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -217,9 +224,6 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" - assert ( - args.config_path_boundary is not None - ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -234,21 +238,10 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore = load_config_and_datastore(config_path=args.config_path) - config_boundary, datastore_boundary = load_config_and_datastore( - config_path=args.config_path_boundary + config, datastore, datastore_boundary = load_config_and_datastore( + config_path=args.config_path ) - # TODO this should not be required, make more flexible - assert ( - datastore.num_past_forcing_steps - == datastore_boundary.num_past_forcing_steps - ), "Mismatch in num_past_forcing_steps" - assert ( - datastore.num_future_forcing_steps - == datastore_boundary.num_future_forcing_steps - ), "Mismatch in num_future_forcing_steps" - # Create datamodule data_module = WeatherDataModule( datastore=datastore, @@ -258,6 +251,8 @@ def main(input_args=None): standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, batch_size=args.batch_size, num_workers=args.num_workers, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 67eac70e..5fbe4a5d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -42,10 +42,13 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): datastore_boundary_name ) N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -53,6 +56,8 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) item = dataset[0] @@ -77,8 +82,23 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * ( - num_past_forcing_steps + num_future_forcing_steps + 1 + # each stacked forcing feature has one corresponding temporal embedding + assert ( + forcing.shape[2] + == datastore.get_num_data_vars("forcing") + * (num_past_forcing_steps + num_future_forcing_steps + 1) + * 2 + ) + + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert ( + boundary.shape[2] + == datastore_boundary.get_num_data_vars("forcing") + * (num_past_boundary_steps + num_future_boundary_steps + 1) + * 2 ) # batch times @@ -88,6 +108,7 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length + dataset[len(dataset) - 1] @@ -106,6 +127,9 @@ def test_dataset_item_create_dataarray_from_tensor( N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -113,16 +137,22 @@ def test_dataset_item_create_dataarray_from_tensor( ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - _, target_states, _, target_times_arr = dataset[idx] - _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays( - idx=idx - ) + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) target_times = np.array(target_times_arr, dtype="datetime64[ns]") np.testing.assert_equal(target_times, da_target_times_true.values) @@ -272,6 +302,7 @@ def test_dataset_length(dataset_config): dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=dataset_config["ar_steps"], num_past_forcing_steps=dataset_config["past"], diff --git a/tests/test_training.py b/tests/test_training.py index 1ed1847d..28566a4b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,18 +14,33 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataModule -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_training(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( f"Skipping test for {datastore_name} as it is not a regular " "grid datastore." ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a regular " + "grid datastore." + ) if torch.cuda.is_available(): device_name = "cuda" @@ -59,6 +74,7 @@ def test_training(datastore_name): data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=3, ar_steps_eval=5, standardize=True, @@ -66,6 +82,8 @@ def test_training(datastore_name): num_workers=1, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, ) class ModelArgs: @@ -85,6 +103,8 @@ class ModelArgs: metrics_watch = [] num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 model_args = ModelArgs() From 0103b6e70927cb53e59b77c30245d3fa8139f8ed Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:51 +0100 Subject: [PATCH 059/103] using combined config file --- neural_lam/config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..914ebb38 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -168,4 +168,15 @@ def load_config_and_datastore( datastore_kind=config.datastore.kind, config_path=datastore_config_path ) - return config, datastore + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary From 089634447df0c2704670df900fc4733a727fce38 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:12 +0100 Subject: [PATCH 060/103] proper handling of state/forcing/boundary in dataset --- neural_lam/weather_dataset.py | 304 +++++++++++++++++++--------------- 1 file changed, 167 insertions(+), 137 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index f172d47f..7dbe0567 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -38,6 +38,16 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time t + Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before + t, given num_past_forcing_steps) are included as boundary inputs at time + t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -50,6 +60,8 @@ def __init__( ar_steps=3, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, standardize=True, ): super().__init__() @@ -60,10 +72,10 @@ def __init__( self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray( - category="state", split=self.split - ) + self.da_state = self.datastore.get_dataarray(category="state", split=self.split) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -74,9 +86,12 @@ def __init__( category="forcing", split=self.split ) # XXX For now boundary data is always considered mdp-forcing data - self.da_boundary = self.datastore_boundary.get_dataarray( - category="forcing", split=self.split - ) + if self.datastore_boundary is not None: + self.da_boundary = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -97,9 +112,7 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order( - category=part - ) + expected_dim_order = self.datastore.expected_dim_order(category=part) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -108,6 +121,23 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + else: + self.da_state = self.da_state + def get_time_step(times): """Calculate the time step from the data""" time_diffs = np.diff(times) @@ -119,11 +149,18 @@ def get_time_step(times): return time_diffs[0] # Check time step consistency in state data - _ = get_time_step(self.da_state.time.values) + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + _ = get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: - state_times = self.da_state.time + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values @@ -131,26 +168,30 @@ def get_time_step(times): # Forcing data is part of the same datastore as state data # During creation the time dimension of the forcing data # is matched to the state data - forcing_times = self.da_forcing.time - _ = get_time_step(forcing_times.values) + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + else: + forcing_times = self.da_forcing.time + get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore # The boundary data is allowed to have a different time_step # Check that the boundary data covers the required time range - boundary_times = self.da_boundary.time + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary.analysis_time + else: + boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * boundary_time_step + state_time_min - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * boundary_time_step + state_time_max + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -179,10 +220,8 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = ( - self.datastore.get_standardization_dataarray( - category="forcing" - ) + self.ds_forcing_stats = self.datastore.get_standardization_dataarray( + category="forcing" ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -208,7 +247,7 @@ def __len__(self): warnings.warn( "only using first ensemble member, so dataset size is " " effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", + f"({self.datastore._num_ensemble_members})", UserWarning, ) @@ -242,36 +281,50 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing_boundary=None, + num_past_steps=None, + num_future_steps=None, + ): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done based on - `idx`. For the forcing data, nearest neighbor matching is performed - based on the state times. Additionally, the time difference between the - matched forcing times and state times (in multiples of state time steps) - is added to the forcing dataarray. This will be used as an additional - feature in the model (temporal embedding). + `da_forcing_boundary`. For the state data, slicing is done + based on `idx`. For the forcing/boundary data, nearest neighbor matching + is performed based on the state times. Additionally, the time difference + between the matched forcing/boundary times and state times (in multiples + of state time steps) is added to the forcing dataarray. This will be + used as an additional feature in the model (temporal embedding). Parameters ---------- da_state : xr.DataArray The state dataarray to slice. - da_forcing : xr.DataArray - The forcing dataarray to slice. idx : int The index of the time step to start the sample from in the state data. n_steps : int The number of time steps to include in the sample. + da_forcing_boundary : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. Returns ------- da_state_sliced : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). - da_forcing_matched : xr.DataArray + da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', - 'forcing_feature_windowed'). + 'forcing/boundary_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -279,8 +332,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # Slice the state data as before if self.datastore.is_forecast: # Calculate start and end indices for slicing - start_idx = max(0, self.num_past_forcing_steps - init_steps) - end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + start_idx = max(0, num_past_steps - init_steps) + end_idx = max(init_steps, num_past_steps) + n_steps # Slice the state data over the elapsed forecast duration da_state_sliced = da_state.isel( @@ -299,13 +352,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): else: # For analysis data, slice the time dimension directly - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = ( - idx + max(init_steps, self.num_past_forcing_steps) + n_steps - ) + start_idx = idx + max(0, num_past_steps - init_steps) + end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - if da_forcing is None: + if da_forcing_boundary is None: return da_state_sliced, None # Get the state times and its temporal resolution for matching with @@ -313,78 +364,66 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] - # Match forcing data to state times based on nearest neighbor - if self.datastore.is_forecast: - # Calculate all possible forcing times - forcing_times = ( - da_forcing.analysis_time + da_forcing.elapsed_forecast_duration - ) - forcing_times_flat = forcing_times.stack( - forecast_time=("analysis_time", "elapsed_forecast_duration") - ) + if "analysis_time" in da_forcing_boundary.dims: + idx = np.abs( + da_forcing_boundary.analysis_time.values + - self.da_state.analysis_time.values[idx] + ).argmin() + # Add a 'time' dimension using the actual forecast times + offset = max(init_steps, num_past_steps) + da_list = [] + for step in range(n_steps): + start_idx = offset + step - num_past_steps + end_idx = offset + step + num_future_steps + + current_time = ( + da_forcing_boundary.analysis_time[idx] + + da_forcing_boundary.elapsed_forecast_duration[offset + step] + ) - # Compute time differences - time_deltas = ( - forcing_times_flat.values[:, np.newaxis] - - state_times.values[np.newaxis, :] - ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - - # Retrieve corresponding indices for analysis_time and - # elapsed_forecast_duration - forecast_time_index = forcing_times_flat["forecast_time"][idx_min] - analysis_time_indices = forecast_time_index["analysis_time"] - elapsed_forecast_duration_indices = forecast_time_index[ - "elapsed_forecast_duration" - ] - - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel( - analysis_time=("time", analysis_time_indices), - elapsed_forecast_duration=( - "time", - elapsed_forecast_duration_indices, - ), - ) + da_sliced = da_forcing_boundary.isel( + analysis_time=idx, + elapsed_forecast_duration=slice(start_idx, end_idx + 1), + ) - # Assign matched state times to the forcing data - da_forcing_matched["time"] = state_times - da_forcing_matched = da_forcing_matched.swap_dims( - {"elapsed_forecast_duration": "time"} - ) + da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step - ) + da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + + da_list.append(da_sliced) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + # Concatenate the list of DataArrays along the 'time' dimension + da_forcing_boundary_matched = xr.concat(da_list, dim="time") + forcing_time_step = ( + da_forcing_boundary_matched.time.values[1] + - da_forcing_boundary_matched.time.values[0] ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( + forcing_time_step / state_time_step + ) + time_diff_steps = da_forcing_boundary_matched.isel( + grid_index=0, forcing_feature=0 + ).data + else: # For analysis data, match directly using the 'time' coordinate - forcing_times = da_forcing["time"] + forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) time_diff_steps = np.stack( [ time_deltas[ - idx_i - - self.num_past_forcing_steps : idx_i - + self.num_future_forcing_steps - + 1, + idx_i - num_past_steps : idx_i + num_future_steps + 1, init_steps + step_i, ] for (step_i, idx_i) in enumerate(idx_min[init_steps:]) @@ -392,24 +431,22 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): ) # Create window dimension for forcing data to stack later - window_size = ( - self.num_past_forcing_steps + self.num_future_forcing_steps + 1 - ) - da_forcing_windowed = da_forcing.rolling( - time=window_size, center=True + window_size = num_past_steps + num_future_steps + 1 + da_forcing_boundary_windowed = da_forcing_boundary.rolling( + time=window_size, center=False ).construct(window_dim="window") - da_forcing_matched = da_forcing_windowed.isel( + da_forcing_boundary_matched = da_forcing_boundary_windowed.isel( time=idx_min[init_steps:] ) - # Add time difference as a new coordinate to concatenate to the - # forcing features later - da_forcing_matched["time_diff_steps"] = ( - ("time", "window"), - time_diff_steps, - ) + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_boundary_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, + ) - return da_state_sliced, da_forcing_matched + return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" @@ -462,23 +499,7 @@ def _build_item_dataarrays(self, idx): da_target_times : xr.DataArray The dataarray for the target times. """ - # handling ensemble data - if self.datastore.is_ensemble: - # for the now the strategy is to only include the first ensemble - # member - # XXX: this could be changed to include all ensemble members by - # splitting `idx` into two parts, one for the analysis time and one - # for the ensemble member and then increasing self.__len__ to - # include all ensemble members - warnings.warn( - "only use of ensemble member 0 (the first member) is " - "implemented for ensemble data" - ) - i_ensemble = 0 - da_state = self.da_state.isel(ensemble_member=i_ensemble) - else: - da_state = self.da_state - + da_state = self.da_state if self.da_forcing is not None: if "ensemble_member" in self.da_forcing.dims: raise NotImplementedError( @@ -500,13 +521,19 @@ def _build_item_dataarrays(self, idx): da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_boundary, + da_forcing_boundary=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, ) + else: + da_boundary_windowed = None da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_forcing, + da_forcing_boundary=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, ) # load the data into memory @@ -521,9 +548,7 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = ( - da_init_states - self.da_state_mean - ) / self.da_state_std + da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -595,9 +620,7 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor( - da_target_states.values, dtype=tensor_dtype - ) + target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -708,10 +731,7 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if ( - grid_coord in da_datastore_state.coords - and grid_coord not in da.coords - ): + if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: @@ -732,6 +752,8 @@ def __init__( standardize=True, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, batch_size=4, num_workers=16, ): @@ -740,6 +762,8 @@ def __init__( self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize @@ -766,6 +790,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) self.val_dataset = WeatherDataset( datastore=self._datastore, @@ -775,6 +801,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) if stage == "test" or stage is None: @@ -786,6 +814,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) def train_dataloader(self): From 355423c8412677823db63d34ad4b2649abcf1478 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:35 +0100 Subject: [PATCH 061/103] datastore_boundars=None introduced --- .../datastore/npyfilesmeps/compute_standardization_stats.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index f2c80e8a..4207812f 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -172,6 +172,7 @@ def main( ar_steps = 63 ds = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=False, From 121d460930fd24ae0ff90dd0d07279c75a15b1d5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:02 +0100 Subject: [PATCH 062/103] bug fix for file retrieval per member --- neural_lam/datastore/npyfilesmeps/store.py | 51 +++++++++------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 146b0627..7ee583be 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,9 +244,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray( - features=[feature], split=split - ) + self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features ] da = xr.concat(das, dim="feature") @@ -259,9 +257,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = ( - da.analysis_time + da.elapsed_forecast_duration - ).chunk({"elapsed_forecast_duration": 1}) + da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( + {"elapsed_forecast_duration": 1} + ) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -339,10 +337,7 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if ( - set(features).difference(self.get_vars_names(category="static")) - == set() - ): + if set(features).difference(self.get_vars_names(category="static")) == set(): assert split in ( "train", "val", @@ -356,12 +351,8 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names( - category="state" - ): - raise ValueError( - "Member can only be specified for the 'state' category" - ) + if member is not None and features != self.get_vars_names(category="state"): + raise ValueError("Member can only be specified for the 'state' category") concat_axis = 0 @@ -377,9 +368,7 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones( - len(features) + n_to_drop, dtype=bool - ) + feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -445,7 +434,7 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split) + coord_values = self._get_analysis_times(split=split, member_id=member) elif d == "y": coord_values = y elif d == "x": @@ -464,9 +453,7 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format( - analysis_time=analysis_time, **file_params - ) + / filename_format.format(analysis_time=analysis_time, **file_params) for analysis_time in coords["analysis_time"] ] else: @@ -505,7 +492,7 @@ def _get_single_timeseries_dataarray( return da - def _get_analysis_times(self, split) -> List[np.datetime64]: + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: """Get the analysis times for the given split by parsing the filenames of all the files found for the given split. @@ -513,6 +500,8 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ---------- split : str The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. Returns ------- @@ -520,8 +509,12 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: The analysis times for the given split. """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) - pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) sample_dir = self.root_path / "samples" / split sample_files = sample_dir.glob(pattern) @@ -531,9 +524,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError( - f"No files found in {sample_dir} with pattern {pattern}" - ) + raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") return times @@ -690,9 +681,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load( - self.root_path / "static" / fn, weights_only=True - ).numpy() + return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() mean_diff_values = None std_diff_values = None From 7e82eef5d797c76a7667271603e5ea94a3485ac2 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:17 +0100 Subject: [PATCH 063/103] rename datastore for tests --- tests/conftest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index be5cf3e7..90a86d0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,14 +94,14 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) -DATASTORES_BOUNDARY_EXAMPLES = dict( - mdp=( +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( DATASTORE_EXAMPLES_ROOT_PATH / "mdp" - / "era5_1000hPa_winds" + / "era5_1000hPa_danra_100m_winds" / "era5.datastore.yaml" - ) -) + ), +} DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore From 320d7c4826e4055fef0edfa748c3e7b6704c589a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:31 +0100 Subject: [PATCH 064/103] aligned time with danra for easier boundary testing --- tests/dummy_datastore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index d62c7356..a958b8f5 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -28,7 +28,7 @@ class DummyDatastore(BaseRegularGridDatastore): """ SHORT_NAME = "dummydata" - T0 = isodate.parse_datetime("2021-01-01T00:00:00") + T0 = isodate.parse_datetime("1990-09-02T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) CARTESIAN_COORDS = ["x", "y"] From f18dcc2340434ce96f709ba987af482d063de4e5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:50 +0100 Subject: [PATCH 065/103] Fixed test for temporal embedding --- tests/test_time_slicing.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..2f5ed96c 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,9 +40,7 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) + da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -78,10 +76,8 @@ def get_vars_long_names(self, category): def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -93,6 +89,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,9 +98,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ - tensor.numpy() for tensor in sample - ] + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] expected_init_states = [0, 1] if ar_steps == 3: @@ -130,7 +125,7 @@ def test_time_slicing_analysis( # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) assert init_states.shape == (2, 1, 1) assert init_states[:, 0, 0].tolist() == expected_init_states @@ -141,6 +136,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( 3, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + # Factor 2 because each window step has a temporal embedding + (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, + ) + np.testing.assert_equal( + forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], + np.array(expected_forcing_values), ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values)) From e6327d88373bb2708733f6331aebe407facc1f67 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:48 +0100 Subject: [PATCH 066/103] allow boundary as input to ar_model.common_step --- neural_lam/models/ar_model.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 710efcec..331966e4 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -110,7 +110,9 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - + num_forcing_vars + # Factor 2 because of temporal embedding or windowed features + + 2 + * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) @@ -241,19 +243,20 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states + Predict on single batch batch consists of: + init_states: (B, 2,num_grid_nodes, d_features) + target_states: (B, pred_steps,num_grid_nodes, d_features) + forcing_features: (B, pred_steps,num_grid_nodes, d_forcing) + boundary_features: (B, pred_steps,num_grid_nodes, d_boundaries) + batch_times: (B, pred_steps) """ - (init_states, target_states, forcing_features, batch_times) = batch + (init_states, target_states, forcing_features, _, batch_times) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) + ) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std, batch_times From 1374a1976f002ffba86c7c203c6fbb2bea83fb0e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:56 +0100 Subject: [PATCH 067/103] linting --- neural_lam/datastore/npyfilesmeps/store.py | 43 ++++++++---- neural_lam/weather_dataset.py | 66 ++++++++++++------- .../era5.datastore.yaml | 2 +- tests/test_time_slicing.py | 12 +++- tests/test_training.py | 17 ++--- 5 files changed, 91 insertions(+), 49 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 7ee583be..24349e7e 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,7 +244,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray(features=[feature], split=split) + self._get_single_timeseries_dataarray( + features=[feature], split=split + ) for feature in features ] da = xr.concat(das, dim="feature") @@ -257,9 +259,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( - {"elapsed_forecast_duration": 1} - ) + da_forecast_time = ( + da.analysis_time + da.elapsed_forecast_duration + ).chunk({"elapsed_forecast_duration": 1}) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -337,7 +339,10 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if set(features).difference(self.get_vars_names(category="static")) == set(): + if ( + set(features).difference(self.get_vars_names(category="static")) + == set() + ): assert split in ( "train", "val", @@ -351,8 +356,12 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names(category="state"): - raise ValueError("Member can only be specified for the 'state' category") + if member is not None and features != self.get_vars_names( + category="state" + ): + raise ValueError( + "Member can only be specified for the 'state' category" + ) concat_axis = 0 @@ -368,7 +377,9 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) + feature_dim_mask = np.ones( + len(features) + n_to_drop, dtype=bool + ) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -434,7 +445,9 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split, member_id=member) + coord_values = self._get_analysis_times( + split=split, member_id=member + ) elif d == "y": coord_values = y elif d == "x": @@ -453,7 +466,9 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format(analysis_time=analysis_time, **file_params) + / filename_format.format( + analysis_time=analysis_time, **file_params + ) for analysis_time in coords["analysis_time"] ] else: @@ -524,7 +539,9 @@ def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") + raise ValueError( + f"No files found in {sample_dir} with pattern {pattern}" + ) return times @@ -681,7 +698,9 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() + return torch.load( + self.root_path / "static" / fn, weights_only=True + ).numpy() mean_diff_values = None std_diff_values = None diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 7dbe0567..60f8d316 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -41,13 +41,13 @@ class WeatherDataset(torch.utils.data.Dataset): num_past_boundary_steps: int, optional Number of past time steps to include in boundary input. If set to i, boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, - given num_future_forcing_steps) are included as boundary inputs at time t - Default is 1. + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. num_future_boundary_steps: int, optional Number of future time steps to include in boundary input. If set to j, - boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before - t, given num_past_forcing_steps) are included as boundary inputs at time - t. Default is 1. + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -75,7 +75,9 @@ def __init__( self.num_past_boundary_steps = num_past_boundary_steps self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray(category="state", split=self.split) + self.da_state = self.datastore.get_dataarray( + category="state", split=self.split + ) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -112,7 +114,9 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order(category=part) + expected_dim_order = self.datastore.expected_dim_order( + category=part + ) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -188,10 +192,12 @@ def get_time_step(times): # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - self.num_past_forcing_steps * boundary_time_step + state_time_min + - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max + self.num_future_forcing_steps * boundary_time_step + state_time_max + + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -220,8 +226,10 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = self.datastore.get_standardization_dataarray( - category="forcing" + self.ds_forcing_stats = ( + self.datastore.get_standardization_dataarray( + category="forcing" + ) ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -378,7 +386,9 @@ def _slice_time( current_time = ( da_forcing_boundary.analysis_time[idx] - + da_forcing_boundary.elapsed_forecast_duration[offset + step] + + da_forcing_boundary.elapsed_forecast_duration[ + offset + step + ] ) da_sliced = da_forcing_boundary.isel( @@ -386,12 +396,16 @@ def _slice_time( elapsed_forecast_duration=slice(start_idx, end_idx + 1), ) - da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.rename( + {"elapsed_forecast_duration": "window"} + ) da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) - da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + da_sliced = da_sliced.expand_dims( + dim={"time": [current_time.values]} + ) da_list.append(da_sliced) @@ -401,13 +415,13 @@ def _slice_time( da_forcing_boundary_matched.time.values[1] - da_forcing_boundary_matched.time.values[0] ) - da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( - forcing_time_step / state_time_step - ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ + "window" + ] * (forcing_time_step / state_time_step) time_diff_steps = da_forcing_boundary_matched.isel( grid_index=0, forcing_feature=0 ).data - + else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing_boundary["time"] @@ -416,7 +430,8 @@ def _slice_time( # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) @@ -548,7 +563,9 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std + da_init_states = ( + da_init_states - self.da_state_mean + ) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -620,7 +637,9 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) + target_states = torch.tensor( + da_target_states.values, dtype=tensor_dtype + ) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -731,7 +750,10 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: + if ( + grid_coord in da_datastore_state.coords + and grid_coord not in da.coords + ): da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index c97da4bc..7c5ffb3b 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -25,7 +25,7 @@ output: end: 2022-09-30T00:00 test: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 2f5ed96c..4a59c81e 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,7 +40,9 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -77,7 +79,9 @@ def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) + time_values = np.datetime64("2020-01-01") + np.arange( + len(ANALYSIS_STATE_VALUES) + ) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -98,7 +102,9 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] expected_init_states = [0, 1] if ar_steps == 3: diff --git a/tests/test_training.py b/tests/test_training.py index 28566a4b..7a1b4717 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,6 +5,7 @@ import pytest import pytorch_lightning as pl import torch + import wandb # First-party @@ -22,14 +23,10 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) +@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) + datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -38,15 +35,13 @@ def test_training(datastore_name, datastore_boundary_name): ) if not isinstance(datastore_boundary, BaseRegularGridDatastore): pytest.skip( - f"Skipping test for {datastore_boundary_name} as it is not a regular " - "grid datastore." + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." ) if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision( - "high" - ) # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s else: device_name = "cpu" From 779f3e9ed31d9525851793fae409cc145a30e15a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:10:29 +0100 Subject: [PATCH 068/103] improved docstrings and added some assertions --- neural_lam/weather_dataset.py | 105 ++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 60f8d316..c65ec468 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -143,7 +143,13 @@ def __init__( self.da_state = self.da_state def get_time_step(times): - """Calculate the time step from the data""" + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ time_diffs = np.diff(times) if not np.all(time_diffs == time_diffs[0]): raise ValueError( @@ -234,6 +240,7 @@ def get_time_step(times): self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + # XXX: Again, the boundary data is considered forcing data for now if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( @@ -305,7 +312,7 @@ def _slice_time( is performed based on the state times. Additionally, the time difference between the matched forcing/boundary times and state times (in multiples of state time steps) is added to the forcing dataarray. This will be - used as an additional feature in the model (temporal embedding). + used as an additional input feature in the model (temporal embedding). Parameters ---------- @@ -333,23 +340,26 @@ def _slice_time( da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. """ - # Number of initial steps required (e.g., for initializing models) + # The current implementation requires at least 2 time steps for the + # initial state (see GraphCast). init_steps = 2 - - # Slice the state data as before + # slice the dataarray to include the required number of time steps if self.datastore.is_forecast: - # Calculate start and end indices for slicing - start_idx = max(0, num_past_steps - init_steps) - end_idx = max(init_steps, num_past_steps) + n_steps - - # Slice the state data over the elapsed forecast duration + start_idx = max(0, self.num_past_forcing_steps - init_steps) + end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + # this implies that the data will have both `analysis_time` and + # `elapsed_forecast_duration` dimensions for forecasts. We for now + # simply select a analysis time and the first `n_steps` forecast + # times (given no offset). Note that this means that we get one + # sample per forecast, always starting at forecast time 2. da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - - # Create a new 'time' dimension + # create a new time dimension so that the produced sample has a + # `time` dimension, similarly to the analysis only data da_state_sliced["time"] = ( da_state_sliced.analysis_time + da_state_sliced.elapsed_forecast_duration @@ -357,9 +367,13 @@ def _slice_time( da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + # Asserting that the forecast time step is consistent + self.get_time_step(da_state_sliced.time) else: - # For analysis data, slice the time dimension directly + # For analysis data we slice the time dimension directly. The offset + # is only relevant for the very first (and last) samples in the + # dataset. start_idx = idx + max(0, num_past_steps - init_steps) end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) @@ -372,7 +386,13 @@ def _slice_time( state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary if "analysis_time" in da_forcing_boundary.dims: + # Select the closest analysis time in the forcing/boundary data + # This is mostly relevant for boundary data where the time steps + # are not necessarily the same as the state data. But still fast + # enough for forcing data where the time steps are the same. idx = np.abs( da_forcing_boundary.analysis_time.values - self.da_state.analysis_time.values[idx] @@ -399,6 +419,8 @@ def _slice_time( da_sliced = da_sliced.rename( {"elapsed_forecast_duration": "window"} ) + + # Assign the 'window' coordinate to be relative positions da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) @@ -409,7 +431,10 @@ def _slice_time( da_list.append(da_sliced) - # Concatenate the list of DataArrays along the 'time' dimension + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time. da_forcing_boundary_matched = xr.concat(da_list, dim="time") forcing_time_step = ( da_forcing_boundary_matched.time.values[1] @@ -423,7 +448,9 @@ def _slice_time( ).data else: - # For analysis data, match directly using the 'time' coordinate + # For analysis data, we slice the time dimension directly. The + # offset is only relevant for the very first (and last) samples in + # the dataset. forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times @@ -455,7 +482,7 @@ def _slice_time( ) # Add time difference as a new coordinate to concatenate to the - # forcing features later + # forcing features later as temporal embedding da_forcing_boundary_matched["time_diff_steps"] = ( ("time", "window"), time_diff_steps, @@ -464,7 +491,26 @@ def _slice_time( return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): - """Helper function to process windowed data after standardization.""" + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + differences to the existing features as a temporal embedding. + + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. + + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. + + """ stacked_dim = "forcing_feature_windowed" if da_windowed is not None: # Stack the 'feature' and 'window' dimensions and add the @@ -492,8 +538,8 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): def _build_item_dataarrays(self, idx): """ - Create the dataarrays for the initial states, target states and forcing - data for the sample at index `idx`. + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. Parameters ---------- @@ -529,7 +575,7 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # if da_forcing is None, the function will return None for + # if da_forcing_boundary is None, the function will return None for # da_forcing_windowed if da_boundary is not None: _, da_boundary_windowed = self._slice_time( @@ -542,6 +588,9 @@ def _build_item_dataarrays(self, idx): ) else: da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, @@ -584,6 +633,10 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std + # This function handles the stacking of the forcing and boundary data + # and adds the time step differences as a temporal embedding. + # It can handle `None` inputs for the forcing and boundary data + # (and simlpy return an empty DataArray in that case). da_forcing_windowed = self._process_windowed_data( da_forcing_windowed, da_state, da_target_times ) @@ -655,6 +708,11 @@ def __getitem__(self, idx): # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + return init_states, target_states, forcing, boundary, target_times def __iter__(self): @@ -795,9 +853,10 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: - # BUG: There also seem to be issues with "spawn", to be investigated - # default to spawn for now, as the default on linux "fork" hangs - # when using dask (which the npyfilesmeps datastore uses) + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) self.multiprocessing_context = "spawn" else: self.multiprocessing_context = None From f126ec27b6c7d8534893850f07427e3737418216 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:11:32 +0100 Subject: [PATCH 069/103] remove boundary datastore from tests that don't need it --- tests/test_datasets.py | 17 ++--------------- tests/test_training.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5fbe4a5d..063ec147 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -108,37 +108,24 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length - dataset[len(dataset) - 1] @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) -def test_dataset_item_create_dataarray_from_tensor( - datastore_name, datastore_boundary_name -): +def test_dataset_item_create_dataarray_from_tensor(datastore_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 - num_past_boundary_steps = 1 - num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, - datastore_boundary=datastore_boundary, + datastore_boundary=None, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, - num_past_boundary_steps=num_past_boundary_steps, - num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 diff --git a/tests/test_training.py b/tests/test_training.py index 7a1b4717..ca0ebf41 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,7 +5,6 @@ import pytest import pytorch_lightning as pl import torch - import wandb # First-party @@ -23,10 +22,14 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -41,7 +44,9 @@ def test_training(datastore_name, datastore_boundary_name): if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s else: device_name = "cpu" From 4b656da04526d3d38d71881deab18ee69519b29d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:43:01 +0100 Subject: [PATCH 070/103] fix scope of _get_time_step --- neural_lam/weather_dataset.py | 40 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c65ec468..3685e227 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,28 +142,14 @@ def __init__( else: self.da_state = self.da_state - def get_time_step(times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] + # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: state_times = self.da_state.time - _ = get_time_step(state_times) + _ = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -182,7 +168,7 @@ def get_time_step(times): forcing_times = self.da_forcing.analysis_time else: forcing_times = self.da_forcing.time - get_time_step(forcing_times.values) + self._get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -192,7 +178,7 @@ def get_time_step(times): boundary_times = self.da_boundary.analysis_time else: boundary_times = self.da_boundary.time - boundary_time_step = get_time_step(boundary_times.values) + boundary_time_step = self._get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values @@ -296,6 +282,22 @@ def __len__(self): - self.num_future_forcing_steps ) + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + def _slice_time( self, da_state, @@ -368,7 +370,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent - self.get_time_step(da_state_sliced.time) + self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset From 75db4b8a5ac0769dab7be8837e707b734c62ff92 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 16:58:46 +0100 Subject: [PATCH 071/103] added information about optional boundary datastore --- README.md | 22 +++++++++++++--------- neural_lam/weather_dataset.py | 2 -- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e21b7c24..7a5e5caf 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th interface that provides the data in a data-structure that can be used within neural-lam. A datastore is used to create a `pytorch.Dataset`-derived class that samples the data in time to create individual samples for - training, validation and testing. + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. 2. **The graph structure** is used to define message-passing GNN layers, that are trained to emulate fluid flow in the atmosphere over time. The @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model. The path you provide to the neural-lam config (`config.yaml`) also sets the root directory relative to which all other paths are resolved, as in the parent -directory of the config becomes the root directory. Both the datastore and +directory of the config becomes the root directory. Both the datastores and graphs you generate are then stored in subdirectories of this root directory. Exactly how and where a specific datastore expects its source data to be stored and where it stores its derived data is up to the implementation of the @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`): data/ ├── config.yaml - Configuration file for neural-lam ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml └── graphs/ - Directory containing graphs for training ``` @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like: datastore: kind: mdp config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines two things: +1) the kind of datastores and the path to their config +2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha # Contact If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. -There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots). -You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 3685e227..f02cfbd4 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,8 +142,6 @@ def __init__( else: self.da_state = self.da_state - - # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time From 4c175452af54fa4833fd9ac67bb4b1b36cdaa777 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 05:14:38 +0100 Subject: [PATCH 072/103] moved gcsfs to dev group --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38e7cb0e..f556ef6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,13 +26,12 @@ dependencies = [ "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard<0.31.0", - "gcsfs>=2021.10.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" [project.optional-dependencies] -dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] +dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2", "gcsfs>=2021.10.0"] [tool.setuptools] py-modules = ["neural_lam"] From a700350f9c0b6161ffefa06b7fa7fc7151e51f23 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 05:14:44 +0100 Subject: [PATCH 073/103] linting --- .../era5.datastore.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml index 600a1845..7c5ffb3b 100644 --- a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -7,7 +7,7 @@ output: coord_ranges: time: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 step: PT6H chunking: time: 1 @@ -16,16 +16,16 @@ output: splits: train: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 compute_statistics: ops: [mean, std, diff_mean, diff_std] dims: [grid_index, time] val: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 test: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: From 16d5d04bbd9e49a1fe53e56ff95e22d565326c67 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 12:23:15 +0100 Subject: [PATCH 074/103] Fixed issue with temporal encoding dimensions + some more comments --- neural_lam/weather_dataset.py | 38 +++++++++++++++++++++++------------ tests/test_datasets.py | 19 +++++++----------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index f02cfbd4..93988ed7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -287,6 +287,11 @@ def _get_time_step(self, times): ---------- times : xr.DataArray The time dataarray to calculate the time step from. + + Returns + ------- + time_step : float + The time step in the the format of the times dataarray. """ time_diffs = np.diff(times) if not np.all(time_diffs == time_diffs[0]): @@ -368,6 +373,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent + # In init this was only done for the analysis_time self._get_time_step(da_state_sliced.time) else: @@ -382,7 +388,8 @@ def _slice_time( return da_state_sliced, None # Get the state times and its temporal resolution for matching with - # forcing data + # forcing data. No need to self._get_time_step as we have already + # checked the time step consistency in the state data. state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] @@ -440,12 +447,14 @@ def _slice_time( da_forcing_boundary_matched.time.values[1] - da_forcing_boundary_matched.time.values[0] ) + # Since all time, grid_index and forcing_features share the same + # temporal_embedding we can just use the first one da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ "window" ] * (forcing_time_step / state_time_step) time_diff_steps = da_forcing_boundary_matched.isel( - grid_index=0, forcing_feature=0 - ).data + grid_index=0, forcing_feature=0, time=0 + ).window.values else: # For analysis data, we slice the time dimension directly. The @@ -462,15 +471,16 @@ def _slice_time( ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) - time_diff_steps = np.stack( - [ - time_deltas[ - idx_i - num_past_steps : idx_i + num_future_steps + 1, - init_steps + step_i, - ] - for (step_i, idx_i) in enumerate(idx_min[init_steps:]) - ], - ) + # Get the time differences for windowed time steps - they are + # used as temporal embeddings and concatenated to the forcing + # features later. All features share the same temporal embedding + time_diff_steps = time_deltas[ + idx_min[init_steps] + - num_past_steps : idx_min[init_steps] + + num_future_steps + + 1, + init_steps, + ] # Create window dimension for forcing data to stack later window_size = num_past_steps + num_future_steps + 1 @@ -484,7 +494,7 @@ def _slice_time( # Add time difference as a new coordinate to concatenate to the # forcing features later as temporal embedding da_forcing_boundary_matched["time_diff_steps"] = ( - ("time", "window"), + ("window"), time_diff_steps, ) @@ -519,6 +529,8 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) + # All data variables share the same temporal embedding, hence + # only the first one is used da_windowed = xr.concat( [da_windowed, da_windowed.time_diff_steps], dim="forcing_feature_windowed", diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 063ec147..aa7b645d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -82,24 +82,19 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - # each stacked forcing feature has one corresponding temporal embedding - assert ( - forcing.shape[2] - == datastore.get_num_data_vars("forcing") - * (num_past_forcing_steps + num_future_forcing_steps + 1) - * 2 + # each time step in the window has one corresponding temporal embedding + # that is shared across all grid points, times and variables + assert forcing.shape[2] == (datastore.get_num_data_vars("forcing") + 1) * ( + num_past_forcing_steps + num_future_forcing_steps + 1 ) # boundary assert boundary.ndim == 3 assert boundary.shape[0] == N_pred_steps assert boundary.shape[1] == N_gridpoints_boundary - assert ( - boundary.shape[2] - == datastore_boundary.get_num_data_vars("forcing") - * (num_past_boundary_steps + num_future_boundary_steps + 1) - * 2 - ) + assert boundary.shape[2] == ( + datastore_boundary.get_num_data_vars("forcing") + 1 + ) * (num_past_boundary_steps + num_future_boundary_steps + 1) # batch times assert target_times.ndim == 1 From f1f3f73e8269ffe20bc7acb771037fa8b9410d4f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 12:23:23 +0100 Subject: [PATCH 075/103] format docstrings --- neural_lam/models/ar_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 331966e4..6074a024 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -193,18 +193,18 @@ def expand_to_batch(x, batch_size): def predict_step(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, - num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, - forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) """ raise NotImplementedError("No prediction step implemented") def unroll_prediction(self, init_states, forcing_features, true_states): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, - pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, - num_grid_nodes, d_f) + init_states: (B, 2, num_grid_nodes, d_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + true_states: (B, pred_steps, num_grid_nodes, d_f) """ prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] From 8fd7a10fc1d01c998641baf228dfa66f8630280c Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 12:23:33 +0100 Subject: [PATCH 076/103] introduced time slicing test for forecast type data --- tests/test_time_slicing.py | 186 +++++++++++++++++++++++++++++++------ 1 file changed, 158 insertions(+), 28 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 4a59c81e..57e468db 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -16,40 +16,76 @@ class SinglePointDummyDatastore(BaseDatastore): root_path = None def __init__(self, time_values, state_data, forcing_data, is_forecast): - self._time_values = np.array(time_values) - self._state_data = np.array(state_data) - self._forcing_data = np.array(forcing_data) self.is_forecast = is_forecast - if is_forecast: - assert self._state_data.ndim == 2 + self._analysis_times, self._forecast_times = time_values + self._state_data = np.array(state_data) + self._forcing_data = np.array(forcing_data) + # state_data and forcing_data should be 2D arrays with shape + # (n_analysis_times, n_forecast_times) else: - assert self._state_data.ndim == 1 + self._time_values = np.array(time_values) + self._state_data = np.array(state_data) + self._forcing_data = np.array(forcing_data) + + if is_forecast: + assert self._state_data.ndim == 2 + else: + assert self._state_data.ndim == 1 def get_num_data_vars(self, category): return 1 def get_dataarray(self, category, split): - if category == "state": - values = self._state_data - elif category == "forcing": - values = self._forcing_data - else: - raise NotImplementedError(category) - if self.is_forecast: - raise NotImplementedError() + if category == "state": + # Create DataArray with dims ('analysis_time', + # 'elapsed_forecast_duration') + da = xr.DataArray( + self._state_data, + dims=["analysis_time", "elapsed_forecast_duration"], + coords={ + "analysis_time": self._analysis_times, + "elapsed_forecast_duration": self._forecast_times, + }, + ) + elif category == "forcing": + da = xr.DataArray( + self._forcing_data, + dims=["analysis_time", "elapsed_forecast_duration"], + coords={ + "analysis_time": self._analysis_times, + "elapsed_forecast_duration": self._forecast_times, + }, + ) + else: + raise NotImplementedError(category) + # Add 'grid_index' and '{category}_feature' dimensions + da = da.expand_dims("grid_index") + da = da.expand_dims(f"{category}_feature") + dim_order = self.expected_dim_order(category=category) + return da.transpose(*dim_order) else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) - # add `{category}_feature` and `grid_index` dimensions + if category == "state": + values = self._state_data + elif category == "forcing": + values = self._forcing_data + else: + raise NotImplementedError(category) + + if self.is_forecast: + raise NotImplementedError() + else: + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) + # add `{category}_feature` and `grid_index` dimensions - da = da.expand_dims("grid_index") - da = da.expand_dims(f"{category}_feature") + da = da.expand_dims("grid_index") + da = da.expand_dims(f"{category}_feature") - dim_order = self.expected_dim_order(category=category) - return da.transpose(*dim_order) + dim_order = self.expected_dim_order(category=category) + return da.transpose(*dim_order) def get_standardization_dataarray(self, category): raise NotImplementedError() @@ -70,6 +106,32 @@ def get_vars_long_names(self, category): ANALYSIS_STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] +# Constants for forecast data +FORECAST_ANALYSIS_TIMES = np.datetime64("2020-01-01") + np.arange(3) +FORECAST_FORECAST_TIMES = np.timedelta64(0, "D") + np.arange(7) + +FORECAST_STATE_VALUES = np.array( + [ + # Analysis time 0 + [0, 1, 2, 3, 4, 5, 6], + # Analysis time 1 + [10, 11, 12, 13, 14, 15, 16], + # Analysis time 2 + [20, 21, 22, 23, 24, 25, 26], + ] +) + +FORECAST_FORCING_VALUES = np.array( + [ + # Analysis time 0 + [100, 101, 102, 103, 104, 105, 106], + # Analysis time 1 + [110, 111, 112, 113, 114, 115, 116], + # Analysis time 2 + [120, 121, 122, 123, 124, 125, 126], + ] +) + @pytest.mark.parametrize( "ar_steps,num_past_forcing_steps,num_future_forcing_steps", @@ -79,9 +141,7 @@ def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) + time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -102,9 +162,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _, _ = [ - tensor.numpy() for tensor in sample - ] + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] expected_init_states = [0, 1] if ar_steps == 3: @@ -149,3 +207,75 @@ def test_time_slicing_analysis( forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], np.array(expected_forcing_values), ) + + +@pytest.mark.parametrize( + "ar_steps,num_past_forcing_steps,num_future_forcing_steps", + [ + [3, 0, 0], + [3, 1, 0], + [3, 2, 0], + [3, 0, 1], + [3, 0, 2], + ], +) +def test_time_slicing_forecast( + ar_steps, num_past_forcing_steps, num_future_forcing_steps +): + # Create a dummy datastore with forecast data + time_values = (FORECAST_ANALYSIS_TIMES, FORECAST_FORECAST_TIMES) + datastore = SinglePointDummyDatastore( + state_data=FORECAST_STATE_VALUES, + forcing_data=FORECAST_FORCING_VALUES, + time_values=time_values, + is_forecast=True, + ) + + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + standardize=False, + ) + + # Test the dataset length + assert len(dataset) == len(FORECAST_ANALYSIS_TIMES) + + sample = dataset[0] + + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + + # Expected initial states and target states + expected_init_states = FORECAST_STATE_VALUES[0][:2] + expected_target_states = FORECAST_STATE_VALUES[0][2 : 2 + ar_steps] + + # Expected forcing values + total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 + expected_forcing_values = [] + for i in range(ar_steps): + start_idx = max(0, i + 2 - num_past_forcing_steps) + end_idx = i + 2 + num_future_forcing_steps + 1 + forcing_window = FORECAST_FORCING_VALUES[0][start_idx:end_idx] + expected_forcing_values.append(forcing_window) + + # Assertions + np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) + np.testing.assert_array_equal(target_states[:, 0, 0], expected_target_states) + + # Verify the shape of the forcing data + expected_forcing_shape = ( + ar_steps, + 1, + total_forcing_window * 2, # Each windowed feature includes temporal embedding + ) + assert forcing.shape == expected_forcing_shape + + # Extract the forcing values from the tensor (excluding temporal embeddings) + forcing_values = forcing[:, 0, :total_forcing_window] + + # Compare with expected forcing values + for i in range(ar_steps): + np.testing.assert_array_equal(forcing_values[i], expected_forcing_values[i]) From 252a33cd5903aca793f480e871912cc8b8616df2 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 13:05:05 +0100 Subject: [PATCH 077/103] bugfix temporal embedding dimension --- neural_lam/models/ar_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 6074a024..81d5a623 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -110,9 +110,8 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - # Factor 2 because of temporal embedding or windowed features - + 2 - * num_forcing_vars + # Temporal Embedding counts as one additional forcing_feature + + (num_forcing_vars + 1) * (num_past_forcing_steps + num_future_forcing_steps + 1) ) From 8a9114a3629d0840b6c5d4fb9b18f080f61e751d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 13:05:08 +0100 Subject: [PATCH 078/103] linting --- tests/test_time_slicing.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 57e468db..48860161 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -141,7 +141,9 @@ def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) + time_values = np.datetime64("2020-01-01") + np.arange( + len(ANALYSIS_STATE_VALUES) + ) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -162,7 +164,9 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] expected_init_states = [0, 1] if ar_steps == 3: @@ -246,7 +250,9 @@ def test_time_slicing_forecast( sample = dataset[0] - init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] # Expected initial states and target states expected_init_states = FORECAST_STATE_VALUES[0][:2] @@ -263,13 +269,16 @@ def test_time_slicing_forecast( # Assertions np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) - np.testing.assert_array_equal(target_states[:, 0, 0], expected_target_states) + np.testing.assert_array_equal( + target_states[:, 0, 0], expected_target_states + ) # Verify the shape of the forcing data expected_forcing_shape = ( ar_steps, 1, - total_forcing_window * 2, # Each windowed feature includes temporal embedding + total_forcing_window + * 2, # Each windowed feature includes temporal embedding ) assert forcing.shape == expected_forcing_shape @@ -278,4 +287,6 @@ def test_time_slicing_forecast( # Compare with expected forcing values for i in range(ar_steps): - np.testing.assert_array_equal(forcing_values[i], expected_forcing_values[i]) + np.testing.assert_array_equal( + forcing_values[i], expected_forcing_values[i] + ) From 8c7709a3c4b7bbd14c5736713dee6af1cd6e2b80 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 17:36:05 +0100 Subject: [PATCH 079/103] switched to low-res data --- .../mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml | 4 ++-- .../era5.datastore.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index 7c5ffb3b..587d7879 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -29,7 +29,7 @@ output: inputs: era_height_levels: - path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' dims: [time, longitude, latitude, level] variables: u_component_of_wind: @@ -56,7 +56,7 @@ inputs: target_output_variable: forcing era5_surface: - path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' dims: [time, longitude, latitude, level] variables: - mean_surface_net_short_wave_radiation_flux diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml index 7c5ffb3b..587d7879 100644 --- a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -29,7 +29,7 @@ output: inputs: era_height_levels: - path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' dims: [time, longitude, latitude, level] variables: u_component_of_wind: @@ -56,7 +56,7 @@ inputs: target_output_variable: forcing era5_surface: - path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' dims: [time, longitude, latitude, level] variables: - mean_surface_net_short_wave_radiation_flux From 24cbf13b1e51e42adaf7bd4aeab95a77bb479a1e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 17:36:27 +0100 Subject: [PATCH 080/103] add datastore_boundary as explicit attribute --- neural_lam/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/neural_lam/config.py b/neural_lam/config.py index 914ebb38..f8879811 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -97,11 +97,15 @@ class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard): ---------- datastore : DatastoreSelection The configuration for the datastore to use. + datastore_boundary : Union[DatastoreSelection, None] + The configuration for the boundary datastore to use, if any. If None, + no boundary datastore is used. training : TrainingConfig The configuration for training the model. """ datastore: DatastoreSelection + datastore_boundary: Union[DatastoreSelection, None] = None training: TrainingConfig = dataclasses.field(default_factory=TrainingConfig) class _(dataclass_wizard.JSONWizard.Meta): From 1d53ce7b86ee8c936a1c8c2fd9bad58bb672b844 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 5 Dec 2024 13:26:06 +0100 Subject: [PATCH 081/103] fixing up forecast type data tests, more and better defined scenarios --- tests/test_time_slicing.py | 175 ++++++++++++++++++++++--------------- 1 file changed, 105 insertions(+), 70 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 48860161..21038e7b 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -79,8 +79,8 @@ def get_dataarray(self, category, split): da = xr.DataArray( values, dims=["time"], coords={"time": self._time_values} ) - # add `{category}_feature` and `grid_index` dimensions + # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") da = da.expand_dims(f"{category}_feature") @@ -103,51 +103,55 @@ def get_vars_long_names(self, category): raise NotImplementedError() -ANALYSIS_STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] +INIT_STEPS = 2 -# Constants for forecast data -FORECAST_ANALYSIS_TIMES = np.datetime64("2020-01-01") + np.arange(3) -FORECAST_FORECAST_TIMES = np.timedelta64(0, "D") + np.arange(7) - -FORECAST_STATE_VALUES = np.array( - [ - # Analysis time 0 - [0, 1, 2, 3, 4, 5, 6], - # Analysis time 1 - [10, 11, 12, 13, 14, 15, 16], - # Analysis time 2 - [20, 21, 22, 23, 24, 25, 26], - ] -) +STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] -FORECAST_FORCING_VALUES = np.array( - [ - # Analysis time 0 - [100, 101, 102, 103, 104, 105, 106], - # Analysis time 1 - [110, 111, 112, 113, 114, 115, 116], - # Analysis time 2 - [120, 121, 122, 123, 124, 125, 126], - ] -) +STATE_VALUES_FORECAST = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # Analysis time 0 + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], # Analysis time 1 + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], # Analysis time 2 +] +FORCING_VALUES_FORECAST = [ + [100, 101, 102, 103, 104, 105, 106, 107, 108, 109], # Analysis time 0 + [110, 111, 112, 113, 114, 115, 116, 117, 118, 119], # Analysis time 1 + [120, 121, 122, 123, 124, 125, 126, 127, 128, 129], # Analysis time 2 +] + +SCENARIOS = [ + [3, 0, 0], + [3, 1, 0], + [3, 2, 0], + [3, 3, 0], + [3, 0, 1], + [3, 0, 2], + [3, 0, 3], + [3, 1, 1], + [3, 2, 1], + [3, 3, 1], + [3, 1, 2], + [3, 1, 3], + [3, 2, 2], + [3, 2, 3], + [3, 3, 2], + [3, 3, 3], +] @pytest.mark.parametrize( "ar_steps,num_past_forcing_steps,num_future_forcing_steps", - [[3, 0, 0], [3, 1, 0], [3, 2, 0], [3, 3, 0]], + SCENARIOS, ) def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) - assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) + time_values = np.datetime64("2020-01-01") + np.arange(len(STATE_VALUES)) + assert len(STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( - state_data=ANALYSIS_STATE_VALUES, + state_data=STATE_VALUES, forcing_data=FORCING_VALUES, time_values=time_values, is_forecast=False, @@ -168,12 +172,10 @@ def test_time_slicing_analysis( tensor.numpy() for tensor in sample ] + # Some scenarios for the human reader expected_init_states = [0, 1] if ar_steps == 3: expected_target_states = [2, 3, 4] - else: - raise NotImplementedError() - if num_past_forcing_steps == num_future_forcing_steps == 0: expected_forcing_values = [[12], [13], [14]] elif num_past_forcing_steps == 1 and num_future_forcing_steps == 0: @@ -188,49 +190,72 @@ def test_time_slicing_analysis( [11, 12, 13, 14], [12, 13, 14, 15], ] - else: - raise NotImplementedError() + + # Compute expected initial states and target states based on ar_steps + offset = max(0, num_past_forcing_steps - INIT_STEPS) + init_idx = INIT_STEPS + offset + # Compute expected forcing values based on num_past_forcing_steps and + # num_future_forcing_steps for all scenarios + expected_init_states = STATE_VALUES[offset:init_idx] + expected_target_states = STATE_VALUES[init_idx : init_idx + ar_steps] + total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 + expected_forcing_values = [] + for i in range(ar_steps): + start_idx = i + init_idx - num_past_forcing_steps + end_idx = i + init_idx + num_future_forcing_steps + 1 + forcing_window = FORCING_VALUES[start_idx:end_idx] + expected_forcing_values.append(forcing_window) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) - assert init_states.shape == (2, 1, 1) - assert init_states[:, 0, 0].tolist() == expected_init_states - assert target_states.shape == (3, 1, 1) - assert target_states[:, 0, 0].tolist() == expected_target_states + # Adjust assertions to use computed expected values + assert init_states.shape == (INIT_STEPS, 1, 1) + np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) + + assert target_states.shape == (ar_steps, 1, 1) + np.testing.assert_array_equal( + target_states[:, 0, 0], expected_target_states + ) assert forcing.shape == ( - 3, + ar_steps, 1, - # Factor 2 because each window step has a temporal embedding - (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, - ) - np.testing.assert_equal( - forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], - np.array(expected_forcing_values), + total_forcing_window + * 2, # Each windowed feature includes temporal embedding ) + # Extract the forcing values from the tensor (excluding temporal embeddings) + forcing_values = forcing[:, 0, :total_forcing_window] + + # Compare with expected forcing values + for i in range(ar_steps): + np.testing.assert_array_equal( + forcing_values[i], expected_forcing_values[i] + ) + @pytest.mark.parametrize( "ar_steps,num_past_forcing_steps,num_future_forcing_steps", - [ - [3, 0, 0], - [3, 1, 0], - [3, 2, 0], - [3, 0, 1], - [3, 0, 2], - ], + SCENARIOS, ) def test_time_slicing_forecast( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): + # Constants for forecast data + ANALYSIS_TIMES = np.datetime64("2020-01-01") + np.arange( + len(STATE_VALUES_FORECAST) + ) + ELAPSED_FORECAST_DURATION = np.timedelta64(0, "D") + np.arange( + len(FORCING_VALUES_FORECAST[0]) + ) # Create a dummy datastore with forecast data - time_values = (FORECAST_ANALYSIS_TIMES, FORECAST_FORECAST_TIMES) + time_values = (ANALYSIS_TIMES, ELAPSED_FORECAST_DURATION) datastore = SinglePointDummyDatastore( - state_data=FORECAST_STATE_VALUES, - forcing_data=FORECAST_FORCING_VALUES, + state_data=STATE_VALUES_FORECAST, + forcing_data=FORCING_VALUES_FORECAST, time_values=time_values, is_forecast=True, ) @@ -246,7 +271,7 @@ def test_time_slicing_forecast( ) # Test the dataset length - assert len(dataset) == len(FORECAST_ANALYSIS_TIMES) + assert len(dataset) == len(ANALYSIS_TIMES) sample = dataset[0] @@ -254,19 +279,29 @@ def test_time_slicing_forecast( tensor.numpy() for tensor in sample ] - # Expected initial states and target states - expected_init_states = FORECAST_STATE_VALUES[0][:2] - expected_target_states = FORECAST_STATE_VALUES[0][2 : 2 + ar_steps] + # Compute expected initial states and target states based on ar_steps + offset = max(0, num_past_forcing_steps - INIT_STEPS) + init_idx = INIT_STEPS + offset + expected_init_states = STATE_VALUES_FORECAST[0][offset:init_idx] + expected_target_states = STATE_VALUES_FORECAST[0][ + init_idx : init_idx + ar_steps + ] - # Expected forcing values + # Compute expected forcing values based on num_past_forcing_steps and + # num_future_forcing_steps total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 expected_forcing_values = [] for i in range(ar_steps): - start_idx = max(0, i + 2 - num_past_forcing_steps) - end_idx = i + 2 + num_future_forcing_steps + 1 - forcing_window = FORECAST_FORCING_VALUES[0][start_idx:end_idx] + start_idx = i + init_idx - num_past_forcing_steps + end_idx = i + init_idx + num_future_forcing_steps + 1 + forcing_window = FORCING_VALUES_FORECAST[INIT_STEPS][start_idx:end_idx] expected_forcing_values.append(forcing_window) + # init_states: (2, N_grid, d_features) + # target_states: (ar_steps, N_grid, d_features) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) + # target_times: (ar_steps,) + # Assertions np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) np.testing.assert_array_equal( @@ -275,9 +310,9 @@ def test_time_slicing_forecast( # Verify the shape of the forcing data expected_forcing_shape = ( - ar_steps, - 1, - total_forcing_window + ar_steps, # Number of AR steps + 1, # Number of grid points + total_forcing_window # Total number of forcing steps in the window * 2, # Each windowed feature includes temporal embedding ) assert forcing.shape == expected_forcing_shape From cfe1e278ae16a4ec8e19dc3f2db79e976e8014d6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 5 Dec 2024 13:28:34 +0100 Subject: [PATCH 082/103] time step can and should be retrieved in __init__ match of state with forcing/boundary is now done with .sel and "pad" renaming some variables to make the code easier to read fixing the temporal encoding to only include embeddings for window-size --- neural_lam/weather_dataset.py | 218 +++++++++++++++++++--------------- 1 file changed, 120 insertions(+), 98 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 93988ed7..0ddad878 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -145,9 +145,12 @@ def __init__( # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time + self.forecast_step_state = self._get_time_step( + self.da_state.elapsed_forecast_duration + ) else: state_times = self.da_state.time - _ = self._get_time_step(state_times) + self.time_step_state = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -164,9 +167,14 @@ def __init__( # is matched to the state data if self.datastore.is_forecast: forcing_times = self.da_forcing.analysis_time + self.forecast_step_forcing = self._get_time_step( + self.da_forcing.elapsed_forecast_duration + ) else: forcing_times = self.da_forcing.time - self._get_time_step(forcing_times.values) + self.time_step_forcing = self._get_time_step( + forcing_times.values + ) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -174,20 +182,25 @@ def __init__( # Check that the boundary data covers the required time range if self.datastore_boundary.is_forecast: boundary_times = self.da_boundary.analysis_time + self.forecast_step_boundary = self._get_time_step( + self.da_boundary.elapsed_forecast_duration + ) else: boundary_times = self.da_boundary.time - boundary_time_step = self._get_time_step(boundary_times.values) + self.time_step_boundary = self._get_time_step( + boundary_times.values + ) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values # Calculate required bounds for boundary using its time step boundary_required_time_min = ( state_time_min - - self.num_past_forcing_steps * boundary_time_step + - self.num_past_forcing_steps * self.time_step_boundary ) boundary_required_time_max = ( state_time_max - + self.num_future_forcing_steps * boundary_time_step + + self.num_future_forcing_steps * self.time_step_boundary ) if boundary_time_min > boundary_required_time_min: @@ -306,13 +319,14 @@ def _slice_time( da_state, idx, n_steps: int, - da_forcing_boundary=None, + da_forcing=None, num_past_steps=None, num_future_steps=None, + is_boundary=False, ): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing_boundary`. For the state data, slicing is done + `da_forcing`. For the state data, slicing is done based on `idx`. For the forcing/boundary data, nearest neighbor matching is performed based on the state times. Additionally, the time difference between the matched forcing/boundary times and state times (in multiples @@ -328,7 +342,7 @@ def _slice_time( data. n_steps : int The number of time steps to include in the sample. - da_forcing_boundary : xr.DataArray + da_forcing : xr.DataArray The forcing/boundary dataarray to slice. num_past_steps : int, optional The number of past time steps to include in the forcing/boundary @@ -336,13 +350,15 @@ def _slice_time( num_future_steps : int, optional The number of future time steps to include in the forcing/boundary data. Default is `None`. + is_boundary : bool, optional + Whether the data is boundary data. Default is `False`. Returns ------- da_state_sliced : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). - da_forcing_boundary_matched : xr.DataArray + da_forcing_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'forcing/boundary_feature_windowed'). If no forcing/boundary data is provided, this will be `None`. @@ -372,9 +388,6 @@ def _slice_time( da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) - # Asserting that the forecast time step is consistent - # In init this was only done for the analysis_time - self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset @@ -384,43 +397,36 @@ def _slice_time( end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - if da_forcing_boundary is None: + if da_forcing is None: return da_state_sliced, None # Get the state times and its temporal resolution for matching with - # forcing data. No need to self._get_time_step as we have already - # checked the time step consistency in the state data. + # forcing data. state_times = da_state_sliced["time"] - state_time_step = state_times.values[1] - state_times.values[0] - + da_list = [] # Here we cannot check 'self.datastore.is_forecast' directly because we # might be dealing with a datastore_boundary - if "analysis_time" in da_forcing_boundary.dims: - # Select the closest analysis time in the forcing/boundary data - # This is mostly relevant for boundary data where the time steps - # are not necessarily the same as the state data. But still fast - # enough for forcing data where the time steps are the same. - idx = np.abs( - da_forcing_boundary.analysis_time.values - - self.da_state.analysis_time.values[idx] - ).argmin() - # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, num_past_steps) - da_list = [] - for step in range(n_steps): - start_idx = offset + step - num_past_steps - end_idx = offset + step + num_future_steps + if "analysis_time" in da_forcing.dims: + # For forecast data with analysis_time and elapsed_forecast_duration + # Select the closest analysis_time in the past in the + # forcing/boundary data + offset = max(0, num_past_steps - init_steps) + state_time = state_times[init_steps].values + forcing_analysis_time_idx = da_forcing.analysis_time.get_index( + "analysis_time" + ).get_indexer([state_time], method="pad")[0] + for step_idx in range(init_steps, len(state_times)): + start_idx = offset + step_idx - num_past_steps + end_idx = offset + step_idx + num_future_steps + 1 current_time = ( - da_forcing_boundary.analysis_time[idx] - + da_forcing_boundary.elapsed_forecast_duration[ - offset + step - ] + forcing_analysis_time_idx + + da_forcing.elapsed_forecast_duration[step_idx] ) - da_sliced = da_forcing_boundary.isel( - analysis_time=idx, - elapsed_forecast_duration=slice(start_idx, end_idx + 1), + da_sliced = da_forcing.isel( + analysis_time=forcing_analysis_time_idx, + elapsed_forecast_duration=slice(start_idx, end_idx), ) da_sliced = da_sliced.rename( @@ -438,67 +444,75 @@ def _slice_time( da_list.append(da_sliced) - # Generate temporal embedding `time_diff_steps` for the - # forcing/boundary data. This is the time difference in multiples - # of state time steps between the forcing/boundary time and the - # state time. - da_forcing_boundary_matched = xr.concat(da_list, dim="time") - forcing_time_step = ( - da_forcing_boundary_matched.time.values[1] - - da_forcing_boundary_matched.time.values[0] - ) - # Since all time, grid_index and forcing_features share the same - # temporal_embedding we can just use the first one - da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ - "window" - ] * (forcing_time_step / state_time_step) - time_diff_steps = da_forcing_boundary_matched.isel( - grid_index=0, forcing_feature=0, time=0 - ).window.values + else: + for idx_time in range(init_steps, len(state_times)): + state_time = state_times[idx_time].values + + # Select the closest time in the past from forcing data using + # sel with method="pad" + forcing_time_idx = da_forcing.time.get_index( + "time" + ).get_indexer([state_time], method="pad")[0] + + # Use isel to select the window + da_window = da_forcing.isel( + time=slice( + forcing_time_idx - num_past_steps, + forcing_time_idx + num_future_steps + 1, + ), + ) + da_window = da_window.rename({"time": "window"}) + + # Assign 'window' coordinate + da_window = da_window.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) + + da_window = da_window.expand_dims(dim={"time": [state_time]}) + + da_list.append(da_window) + + da_forcing_matched = xr.concat(da_list, dim="time") + + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time + + if is_boundary: + if self.datastore_boundary.is_forecast: + boundary_time_step = self.forecast_step_boundary + state_time_step = self.forecast_step_state + else: + boundary_time_step = self.time_step_boundary + state_time_step = self.time_step_state + time_diff_steps = ( + da_forcing_matched["window"] + * (boundary_time_step / state_time_step), + ) else: - # For analysis data, we slice the time dimension directly. The - # offset is only relevant for the very first (and last) samples in - # the dataset. - forcing_times = da_forcing_boundary["time"] - - # Compute time differences between forcing and state times - # (in multiples of state time steps) - # Retrieve the indices of the closest times in the forcing data - time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] - ) / state_time_step - idx_min = np.abs(time_deltas).argmin(axis=0) - - # Get the time differences for windowed time steps - they are - # used as temporal embeddings and concatenated to the forcing - # features later. All features share the same temporal embedding - time_diff_steps = time_deltas[ - idx_min[init_steps] - - num_past_steps : idx_min[init_steps] - + num_future_steps - + 1, - init_steps, - ] - - # Create window dimension for forcing data to stack later - window_size = num_past_steps + num_future_steps + 1 - da_forcing_boundary_windowed = da_forcing_boundary.rolling( - time=window_size, center=False - ).construct(window_dim="window") - da_forcing_boundary_matched = da_forcing_boundary_windowed.isel( - time=idx_min[init_steps:] + if self.datastore.is_forecast: + forcing_time_step = self.forecast_step_forcing + state_time_step = self.forecast_step_state + else: + forcing_time_step = self.time_step_forcing + state_time_step = self.time_step_state + time_diff_steps = ( + da_forcing_matched["window"] + * (forcing_time_step / state_time_step), ) - + time_diff_steps = da_forcing_matched.isel( + grid_index=0, forcing_feature=0 + ).window.values # Add time difference as a new coordinate to concatenate to the # forcing features later as temporal embedding - da_forcing_boundary_matched["time_diff_steps"] = ( + da_forcing_matched["time_diff_steps"] = ( ("window"), time_diff_steps, ) - return da_state_sliced, da_forcing_boundary_matched + return da_state_sliced, da_forcing_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data. This function stacks the @@ -523,16 +537,21 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): """ stacked_dim = "forcing_feature_windowed" if da_windowed is not None: + window_size = da_windowed.window.size # Stack the 'feature' and 'window' dimensions and add the # time step differences to the existing features as a temporal # embedding da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) - # All data variables share the same temporal embedding, hence - # only the first one is used + # Add the time step differences as a new feature to the windowed + # data + time_diff_steps = da_windowed["time_diff_steps"].isel( + forcing_feature_windowed=slice(0, window_size) + ) + # All data variables share the same temporal embedding da_windowed = xr.concat( - [da_windowed, da_windowed.time_diff_steps], + [da_windowed, time_diff_steps], dim="forcing_feature_windowed", ) else: @@ -587,16 +606,19 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # if da_forcing_boundary is None, the function will return None for - # da_forcing_windowed + # This function will return a slice of the state data and the forcing + # and boundary data (if provided) for one sample (idx). + # If da_forcing is None, the function will return None for + # da_forcing_windowed. if da_boundary is not None: _, da_boundary_windowed = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing_boundary=da_boundary, + da_forcing=da_boundary, num_future_steps=self.num_future_boundary_steps, num_past_steps=self.num_past_boundary_steps, + is_boundary=True, ) else: da_boundary_windowed = None @@ -607,7 +629,7 @@ def _build_item_dataarrays(self, idx): da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing_boundary=da_forcing, + da_forcing=da_forcing, num_future_steps=self.num_future_forcing_steps, num_past_steps=self.num_past_forcing_steps, ) From e4e4e3789764c1a08270b41fb4c15dcace146fa5 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 4 Dec 2024 17:46:52 +0100 Subject: [PATCH 083/103] Fix dataset issue in npy stat script --- .../datastore/npyfilesmeps/compute_standardization_stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index 4207812f..1f1c6943 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -202,7 +202,7 @@ def main( print("Computing mean and std.-dev. for parameters...") means, squares, flux_means, flux_squares = [], [], [], [] - for init_batch, target_batch, forcing_batch, _ in tqdm(loader): + for init_batch, target_batch, forcing_batch, _, _ in tqdm(loader): if distributed: init_batch, target_batch, forcing_batch = ( init_batch.to(device), @@ -276,6 +276,7 @@ def main( print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=True, @@ -304,7 +305,7 @@ def main( diff_means, diff_squares = [], [] - for init_batch, target_batch, _, _ in tqdm( + for init_batch, target_batch, _, _, _ in tqdm( loader_standard, disable=rank != 0 ): if distributed: From f8613da77e0e2a040a05307210c7439f1660e43d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 5 Dec 2024 15:14:28 +0100 Subject: [PATCH 084/103] added static feature to era5 boundary test datastore --- .../era5.datastore.yaml | 23 ++++++++++++++++++- .../era5.datastore.yaml | 23 ++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index 587d7879..c83489c6 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -3,6 +3,7 @@ dataset_version: v1.0.0 output: variables: + static: [grid_index, static_feature] forcing: [time, grid_index, forcing_feature] coord_ranges: time: @@ -59,7 +60,7 @@ inputs: path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' dims: [time, longitude, latitude, level] variables: - - mean_surface_net_short_wave_radiation_flux + - mean_sea_level_pressure dim_mapping: time: method: rename @@ -78,6 +79,26 @@ inputs: dims: [x, y] target_output_variable: forcing + era5_static: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - land_sea_mask + dim_mapping: + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: static + extra: projection: class_name: PlateCarree diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml index 587d7879..c83489c6 100644 --- a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -3,6 +3,7 @@ dataset_version: v1.0.0 output: variables: + static: [grid_index, static_feature] forcing: [time, grid_index, forcing_feature] coord_ranges: time: @@ -59,7 +60,7 @@ inputs: path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' dims: [time, longitude, latitude, level] variables: - - mean_surface_net_short_wave_radiation_flux + - mean_sea_level_pressure dim_mapping: time: method: rename @@ -78,6 +79,26 @@ inputs: dims: [x, y] target_output_variable: forcing + era5_static: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - land_sea_mask + dim_mapping: + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: static + extra: projection: class_name: PlateCarree From 8cc608dfc4f8b4e97a931c58e5170c2578281fef Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 11:03:37 +0100 Subject: [PATCH 085/103] rename function to represent multiple datastores --- neural_lam/config.py | 6 +++--- neural_lam/create_graph.py | 4 ++-- neural_lam/plot_graph.py | 4 ++-- neural_lam/train_model.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index f8879811..4b57a141 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -140,11 +140,11 @@ class InvalidConfigError(Exception): pass -def load_config_and_datastore( +def load_config_and_datastores( config_path: str, ) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]: """ - Load the neural-lam configuration and the datastore specified in the + Load the neural-lam configuration and the datastores specified in the configuration. Parameters @@ -155,7 +155,7 @@ def load_config_and_datastore( Returns ------- tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]] - The Neural-LAM configuration and the loaded datastore. + The Neural-LAM configuration and the loaded datastores. """ try: config = NeuralLAMConfig.from_yaml_file(config_path) diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index ef979be3..1ab4e1e9 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -13,7 +13,7 @@ from torch_geometric.utils.convert import from_networkx # Local -from .config import load_config_and_datastore +from .config import load_config_and_datastores from .datastore.base import BaseRegularGridDatastore @@ -595,7 +595,7 @@ def cli(input_args=None): ), "Specify your config with --config_path" # Load neural-lam configuration and datastore to use - _, datastore = load_config_and_datastore(config_path=args.config_path) + _, datastore = load_config_and_datastores(config_path=args.config_path) create_graph_from_datastore( datastore=datastore, diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 999c8e53..ad27b5b0 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -9,7 +9,7 @@ # Local from . import utils -from .config import load_config_and_datastore +from .config import load_config_and_datastores MESH_HEIGHT = 0.1 MESH_LEVEL_DIST = 0.2 @@ -43,7 +43,7 @@ def main(): ) args = parser.parse_args() - _, datastore = load_config_and_datastore( + _, datastore = load_config_and_datastores( config_path=args.datastore_config_path ) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 2a61e86c..54017dbb 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -12,7 +12,7 @@ # Local from . import utils -from .config import load_config_and_datastore +from .config import load_config_and_datastores from .models import GraphLAM, HiLAM, HiLAMParallel from .weather_dataset import WeatherDataModule @@ -238,7 +238,7 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore, datastore_boundary = load_config_and_datastore( + config, datastore, datastore_boundary = load_config_and_datastores( config_path=args.config_path ) From 857f7482e9e04c90a7112e8063bc6e449c9971c6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 11:07:26 +0100 Subject: [PATCH 086/103] streamline da_grid_reference variable naming --- neural_lam/datastore/mdp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 809bbdb8..f68bb4d0 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -386,11 +386,10 @@ def grid_shape_state(self): "no state data found in datastore" "returning grid shape from forcing data" ) - ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) - da_x, da_y = ds_forcing.x, ds_forcing.y + da_grid_reference = self.unstack_grid_coords(self._ds["forcing"]) else: - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + da_grid_reference = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = da_grid_reference.x, da_grid_reference.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) From d0a6f2425f473b06db99740f295c0d452d003281 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 11:09:44 +0100 Subject: [PATCH 087/103] updated docstring of WeatherDataset --- neural_lam/weather_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 0ddad878..510a9504 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -21,9 +21,9 @@ class WeatherDataset(torch.utils.data.Dataset): Parameters ---------- datastore : BaseDatastore - The datastore to load the data from (e.g. mdp). + The datastore to load the data from. datastore_boundary : BaseDatastore - The boundary datastore to load the data from (e.g. mdp). + The boundary datastore to load the data from. split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional From ef40a399fe0595dc936ee9f0e6c8028338a0c2f5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 11:11:25 +0100 Subject: [PATCH 088/103] renamed da_boundary -> da_boundary_forcing --- neural_lam/weather_dataset.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 510a9504..f20e3506 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -89,11 +89,11 @@ def __init__( ) # XXX For now boundary data is always considered mdp-forcing data if self.datastore_boundary is not None: - self.da_boundary = self.datastore_boundary.get_dataarray( + self.da_boundary_forcing = self.datastore_boundary.get_dataarray( category="forcing", split=self.split ) else: - self.da_boundary = None + self.da_boundary_forcing = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -153,7 +153,7 @@ def __init__( self.time_step_state = self._get_time_step(state_times) # Check time coverage for forcing and boundary data - if self.da_forcing is not None or self.da_boundary is not None: + if self.da_forcing is not None or self.da_boundary_forcing is not None: if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: @@ -176,17 +176,17 @@ def __init__( forcing_times.values ) - if self.da_boundary is not None: + if self.da_boundary_forcing is not None: # Boundary data is part of a separate datastore # The boundary data is allowed to have a different time_step # Check that the boundary data covers the required time range if self.datastore_boundary.is_forecast: - boundary_times = self.da_boundary.analysis_time + boundary_times = self.da_boundary_forcing.analysis_time self.forecast_step_boundary = self._get_time_step( - self.da_boundary.elapsed_forecast_duration + self.da_boundary_forcing.elapsed_forecast_duration ) else: - boundary_times = self.da_boundary.time + boundary_times = self.da_boundary_forcing.time self.time_step_boundary = self._get_time_step( boundary_times.values ) @@ -238,7 +238,7 @@ def __init__( self.da_forcing_std = self.ds_forcing_stats.forcing_std # XXX: Again, the boundary data is considered forcing data for now - if self.da_boundary is not None: + if self.da_boundary_forcing is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( category="forcing" @@ -601,8 +601,8 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None - if self.da_boundary is not None: - da_boundary = self.da_boundary + if self.da_boundary_forcing is not None: + da_boundary = self.da_boundary_forcing else: da_boundary = None From 71b52b248caef8ffb15a0147988f60db9e41fcda Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 11:31:36 +0100 Subject: [PATCH 089/103] updated docstrings of get_dataarray() --- neural_lam/datastore/base.py | 8 +++----- neural_lam/datastore/mdp.py | 5 ++--- tests/dummy_datastore.py | 8 +++----- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index e2d21404..b9de2da5 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -193,11 +193,9 @@ def get_dataarray( """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcing/static). A - datastore must be able to return for the "state" category, but - "forcing" and "static" are optional (in which case the method should - return `None`). For the "static" category the `split` is allowed to be - `None` because the static data is the same for all splits. + space and time) of a given category (state/forcing/static). For the + "static" category the `split` is allowed to be `None` because the static + data is the same for all splits. The returned dataarray is expected to at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index f68bb4d0..8f488910 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -218,9 +218,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcin g/static). "state" is - the only required category, for other categories, the method will - return `None` if the category is not found in the datastore. + space and time) of a given category (state/forcing/static). The method + will return `None` if the category is not found in the datastore. The returned dataarray will at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have been stacked diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index a958b8f5..1bdbc8c8 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -300,11 +300,9 @@ def get_dataarray( """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcing/static). A - datastore must be able to return for the "state" category, but - "forcing" and "static" are optional (in which case the method should - return `None`). For the "static" category the `split` is allowed to be - `None` because the static data is the same for all splits. + space and time) of a given category (state/forcing/static). For the + "static" category the `split` is allowed to be `None` because the static + data is the same for all splits. The returned dataarray is expected to at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have From b69056341eb288c5a439fcdf038ae3172aee52c3 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 12:17:29 +0100 Subject: [PATCH 090/103] check times in stateless functions from utils.py --- neural_lam/utils.py | 90 +++++++++++++++++++++++++ neural_lam/weather_dataset.py | 120 ++++++++++------------------------ 2 files changed, 124 insertions(+), 86 deletions(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 4a0752e4..f55f17da 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -3,6 +3,7 @@ import shutil # Third-party +import numpy as np import torch from torch import nn from tueplots import bundles, figsizes @@ -241,3 +242,92 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +def get_time_step(self, times): + """Calculate the time step from a time dataarray. + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + + Returns + ------- + time_step : float + The time step in the the datetime-format of the times dataarray. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + +def check_time_overlap( + da1, + da2, + da1_is_forecast=False, + da2_is_forecast=False, + num_past_steps=1, + num_future_steps=1, +): + """Check that the time coverage of two dataarrays overlap. + + Parameters + ---------- + da1 : xr.DataArray + The first dataarray to check. + da2 : xr.DataArray + The second dataarray to check. + da1_is_forecast : bool, optional + Whether the first dataarray is forecast data. + da2_is_forecast : bool, optional + Whether the second dataarray is forecast data. + num_past_steps : int, optional + Number of past forcing steps. + num_future_steps : int, optional + Number of future forcing steps. + + Raises + ------ + ValueError + If the time coverage of the dataarrays does not overlap. + """ + + if da1_is_forecast: + times_da1 = da1.analysis_time + else: + times_da1 = da1.time + time_min_da1 = times_da1.min().values + time_max_da1 = times_da1.max().values + + if da2_is_forecast: + times_da2 = da2.analysis_time + _ = get_time_step(da2.elapsed_forecast_duration) + else: + times_da2 = da2.time + time_step_da2 = get_time_step(times_da2.values) + + time_min_da2 = da2.min().values + time_max_da2 = da2.max().values + + # Calculate required bounds for da2 using its time step + da2_required_time_min = time_min_da1 - num_past_steps * time_step_da2 + da2_required_time_max = time_max_da1 + num_future_steps * time_step_da2 + + if time_min_da2 > da2_required_time_min: + raise ValueError( + f"The second DataArray ('Boundary forcing'?) data starts too late." + f"Required start: {da2_required_time_min}, " + f"but DataArray starts at {time_min_da2}." + ) + + if time_max_da2 < da2_required_time_max: + raise ValueError( + f"The second DataArray ('Boundary forcing'?) ends too early." + f"Required end: {da2_required_time_max}, " + f"but DataArray ends at {time_max_da2}." + ) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index f20e3506..c6b142ec 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -11,6 +11,7 @@ # First-party from neural_lam.datastore.base import BaseDatastore +from neural_lam.utils import check_time_overlap, get_time_step class WeatherDataset(torch.utils.data.Dataset): @@ -142,80 +143,48 @@ def __init__( else: self.da_state = self.da_state - # Check time step consistency in state data + # Check time step consistency in state data and determine time steps + # for state, forcing and boundary data if self.datastore.is_forecast: state_times = self.da_state.analysis_time - self.forecast_step_state = self._get_time_step( + self.forecast_step_state = get_time_step( self.da_state.elapsed_forecast_duration ) else: state_times = self.da_state.time - self.time_step_state = self._get_time_step(state_times) - - # Check time coverage for forcing and boundary data - if self.da_forcing is not None or self.da_boundary_forcing is not None: + self.time_step_state = get_time_step(state_times) + if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data if self.datastore.is_forecast: - state_times = self.da_state.analysis_time - else: - state_times = self.da_state.time - state_time_min = state_times.min().values - state_time_max = state_times.max().values - - if self.da_forcing is not None: - # Forcing data is part of the same datastore as state data - # During creation the time dimension of the forcing data - # is matched to the state data - if self.datastore.is_forecast: - forcing_times = self.da_forcing.analysis_time - self.forecast_step_forcing = self._get_time_step( - self.da_forcing.elapsed_forecast_duration - ) - else: - forcing_times = self.da_forcing.time - self.time_step_forcing = self._get_time_step( - forcing_times.values + forcing_times = self.da_forcing.analysis_time + self.forecast_step_forcing = self._get_time_step( + self.da_forcing.elapsed_forecast_duration ) - - if self.da_boundary_forcing is not None: - # Boundary data is part of a separate datastore - # The boundary data is allowed to have a different time_step - # Check that the boundary data covers the required time range - if self.datastore_boundary.is_forecast: - boundary_times = self.da_boundary_forcing.analysis_time - self.forecast_step_boundary = self._get_time_step( - self.da_boundary_forcing.elapsed_forecast_duration - ) - else: - boundary_times = self.da_boundary_forcing.time - self.time_step_boundary = self._get_time_step( - boundary_times.values - ) - boundary_time_min = boundary_times.min().values - boundary_time_max = boundary_times.max().values - - # Calculate required bounds for boundary using its time step - boundary_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * self.time_step_boundary - ) - boundary_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * self.time_step_boundary + else: + forcing_times = self.da_forcing.time + self.time_step_forcing = self._get_time_step(forcing_times.values) + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary_forcing.analysis_time + self.forecast_step_boundary = self._get_time_step( + self.da_boundary_forcing.elapsed_forecast_duration ) - - if boundary_time_min > boundary_required_time_min: - raise ValueError( - f"Boundary data starts too late." - f"Required start: {boundary_required_time_min}, " - f"but boundary starts at {boundary_time_min}." - ) - - if boundary_time_max < boundary_required_time_max: - raise ValueError( - f"Boundary data ends too early." - f"Required end: {boundary_required_time_max}, " - f"but boundary ends at {boundary_time_max}." - ) + else: + boundary_times = self.da_boundary_forcing.time + self.time_step_boundary = self._get_time_step(boundary_times.values) + + check_time_overlap( + self.da_state, + self.da_boundary_forcing, + da1_is_forecast=self.datastore.is_forecast, + da2_is_forecast=self.datastore_boundary.is_forecast, + num_past_steps=self.num_past_boundary_steps, + num_future_steps=self.num_future_boundary_steps, + ) # Set up for standardization # TODO: This will become part of ar_model.py soon! @@ -293,27 +262,6 @@ def __len__(self): - self.num_future_forcing_steps ) - def _get_time_step(self, times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - - Returns - ------- - time_step : float - The time step in the the format of the times dataarray. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] - def _slice_time( self, da_state, From a37dc3ceddfdb9421767528a03e08147bbe4a185 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:03:22 +0100 Subject: [PATCH 091/103] add num_ensemble_members property to BaseDatastore --- neural_lam/datastore/base.py | 13 +++++++++++++ neural_lam/datastore/npyfilesmeps/store.py | 6 ++++-- neural_lam/weather_dataset.py | 18 +++++++++--------- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index b9de2da5..84600b50 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -298,6 +298,19 @@ def num_grid_points(self) -> int: """ pass + @property + @abc.abstractmethod + def num_ensemble_members(self) -> int: + """Return the number of ensemble members in the dataset. + + Returns + ------- + int + The number of ensemble members in the dataset. + + """ + pass + @cached_property @abc.abstractmethod def state_feature_weights_values(self) -> List[float]: diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 24349e7e..b91f7291 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -166,7 +166,6 @@ def __init__( self._root_path = self._config_path.parent self._config = NpyDatastoreConfig.from_yaml_file(self._config_path) - self._num_ensemble_members = self.config.dataset.num_ensemble_members self._num_timesteps = self.config.dataset.num_timesteps self._step_length = self.config.dataset.step_length self._remove_state_features_with_index = ( @@ -199,6 +198,9 @@ def config(self) -> NpyDatastoreConfig: """ return self._config + def num_ensemble_members(self) -> int: + return self.config.dataset.num_ensemble_members + def get_dataarray(self, category: str, split: str) -> DataArray: """ Get the data array for the given category and split of data. If the @@ -230,7 +232,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: if category == "state": das = [] # for the state category, we need to load all ensemble members - for member in range(self._num_ensemble_members): + for member in range(self.num_ensemble_members): da_member = self._get_single_timeseries_dataarray( features=self.get_vars_names(category="state"), split=split, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c6b142ec..d2fb5921 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -217,21 +217,21 @@ def __init__( self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): + + if self.datastore.is_ensemble: + warnings.warn( + "only using first ensemble member, so dataset size is " + " effectively reduced by the number of ensemble members " + f"({self.datastore.num_ensemble_members})", + UserWarning, + ) + if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time # and then take the first (2 + ar_steps) forecast times. In # addition we only use the first ensemble member (if ensemble data # has been provided). # This means that for each analysis time we get a single sample - - if self.datastore.is_ensemble: - warnings.warn( - "only using first ensemble member, so dataset size is " - " effectively reduced by the number of ensemble members " - f"({self.datastore._num_ensemble_members})", - UserWarning, - ) - # check that there are enough forecast steps available to create # samples given the number of autoregressive steps requested n_forecast_steps = self.da_state.elapsed_forecast_duration.size From 8d1bec6a72ecc628f2dffec405173ee28fa1680f Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:08:07 +0100 Subject: [PATCH 092/103] Update neural_lam/weather_dataset.py Co-authored-by: Leif Denby --- neural_lam/weather_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index d2fb5921..f311b7e8 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -276,7 +276,8 @@ def _slice_time( Produce time slices of the given dataarrays `da_state` (state) and `da_forcing`. For the state data, slicing is done based on `idx`. For the forcing/boundary data, nearest neighbor matching - is performed based on the state times. Additionally, the time difference + is performed based on the state times (assuming constant timestep size). + Additionally, the time difference between the matched forcing/boundary times and state times (in multiples of state time steps) is added to the forcing dataarray. This will be used as an additional input feature in the model (temporal embedding). From 47370f9636c355e10f7290c900955de0085c80eb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:17:18 +0100 Subject: [PATCH 093/103] renaming time_diff_steps to time_deltas --- neural_lam/weather_dataset.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index d2fb5921..02a940b0 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -217,7 +217,6 @@ def __init__( self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): - if self.datastore.is_ensemble: warnings.warn( "only using first ensemble member, so dataset size is " @@ -423,7 +422,7 @@ def _slice_time( da_forcing_matched = xr.concat(da_list, dim="time") - # Generate temporal embedding `time_diff_steps` for the + # Generate temporal embedding `time_deltas` for the # forcing/boundary data. This is the time difference in multiples # of state time steps between the forcing/boundary time and the # state time @@ -435,7 +434,7 @@ def _slice_time( else: boundary_time_step = self.time_step_boundary state_time_step = self.time_step_state - time_diff_steps = ( + time_deltas = ( da_forcing_matched["window"] * (boundary_time_step / state_time_step), ) @@ -446,18 +445,18 @@ def _slice_time( else: forcing_time_step = self.time_step_forcing state_time_step = self.time_step_state - time_diff_steps = ( + time_deltas = ( da_forcing_matched["window"] * (forcing_time_step / state_time_step), ) - time_diff_steps = da_forcing_matched.isel( + time_deltas = da_forcing_matched.isel( grid_index=0, forcing_feature=0 ).window.values # Add time difference as a new coordinate to concatenate to the # forcing features later as temporal embedding - da_forcing_matched["time_diff_steps"] = ( + da_forcing_matched["time_deltas"] = ( ("window"), - time_diff_steps, + time_deltas, ) return da_state_sliced, da_forcing_matched @@ -494,12 +493,12 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): ) # Add the time step differences as a new feature to the windowed # data - time_diff_steps = da_windowed["time_diff_steps"].isel( + time_deltas = da_windowed["time_deltas"].isel( forcing_feature_windowed=slice(0, window_size) ) # All data variables share the same temporal embedding da_windowed = xr.concat( - [da_windowed, time_diff_steps], + [da_windowed, time_deltas], dim="forcing_feature_windowed", ) else: From d52437717a41b8de9324e5987c35572ae768401c Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:23:53 +0100 Subject: [PATCH 094/103] add num_ensemble_members to mdp store --- neural_lam/datastore/mdp.py | 12 ++++++++++++ neural_lam/datastore/npyfilesmeps/store.py | 1 + 2 files changed, 13 insertions(+) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 8f488910..3682a51e 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -140,6 +140,18 @@ def step_length(self) -> int: da_dt = self._ds["time"].diff("time") return (da_dt.dt.seconds[0] // 3600).item() + @property + def num_ensemble_members(self) -> int: + """The number of ensemble members in the dataset. + + Returns + ------- + int + The number of ensemble members in the dataset. + + """ + return None + def get_vars_units(self, category: str) -> List[str]: """Return the units of the variables in the given category. diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index b91f7291..1b0f6065 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -198,6 +198,7 @@ def config(self) -> NpyDatastoreConfig: """ return self._config + @property def num_ensemble_members(self) -> int: return self.config.dataset.num_ensemble_members From 98c54d9856f4dead1ed8fd8951a69795931cb5e6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:31:25 +0100 Subject: [PATCH 095/103] Rename temporal embeddings and diffs to time deltas --- neural_lam/models/ar_model.py | 2 +- neural_lam/weather_dataset.py | 39 ++++++++++++++++------------------- tests/test_datasets.py | 2 +- tests/test_time_slicing.py | 9 ++++---- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 81d5a623..e21faf43 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -110,7 +110,7 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - # Temporal Embedding counts as one additional forcing_feature + # Time deltas count as one additional forcing_feature + (num_forcing_vars + 1) * (num_past_forcing_steps + num_future_forcing_steps + 1) ) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index db2b2c70..2f37b4b3 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -273,13 +273,13 @@ def _slice_time( ): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing`. For the state data, slicing is done - based on `idx`. For the forcing/boundary data, nearest neighbor matching - is performed based on the state times (assuming constant timestep size). - Additionally, the time difference - between the matched forcing/boundary times and state times (in multiples - of state time steps) is added to the forcing dataarray. This will be - used as an additional input feature in the model (temporal embedding). + `da_forcing`. For the state data, slicing is done based on `idx`. For + the forcing/boundary data, nearest neighbor matching is performed based + on the state times (assuming constant timestep size). Additionally, the + time deltas between the matched forcing/boundary times and state times + (in multiples of state time steps) is added to the forcing dataarray. + This will be used as an additional input feature in the model (as + temporal embedding). Parameters ---------- @@ -423,10 +423,9 @@ def _slice_time( da_forcing_matched = xr.concat(da_list, dim="time") - # Generate temporal embedding `time_deltas` for the - # forcing/boundary data. This is the time difference in multiples - # of state time steps between the forcing/boundary time and the - # state time + # Generate time_deltas for the forcing/boundary data. This is the time + # difference in multiples of state time steps between the + # forcing/boundary time and the state time if is_boundary: if self.datastore_boundary.is_forecast: @@ -453,8 +452,8 @@ def _slice_time( time_deltas = da_forcing_matched.isel( grid_index=0, forcing_feature=0 ).window.values - # Add time difference as a new coordinate to concatenate to the - # forcing features later as temporal embedding + # Add time deltas as a new coordinate to concatenate to the + # forcing features later as temporal embedding in the model da_forcing_matched["time_deltas"] = ( ("window"), time_deltas, @@ -465,7 +464,7 @@ def _slice_time( def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data. This function stacks the 'forcing_feature' and 'window' dimensions and adds the time step - differences to the existing features as a temporal embedding. + deltas to the existing features. Parameters ---------- @@ -487,17 +486,16 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): if da_windowed is not None: window_size = da_windowed.window.size # Stack the 'feature' and 'window' dimensions and add the - # time step differences to the existing features as a temporal - # embedding + # time deltas to the existing features da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) - # Add the time step differences as a new feature to the windowed + # Add the time deltas a new feature to the windowed # data time_deltas = da_windowed["time_deltas"].isel( forcing_feature_windowed=slice(0, window_size) ) - # All data variables share the same temporal embedding + # All data variables share the same time deltas da_windowed = xr.concat( [da_windowed, time_deltas], dim="forcing_feature_windowed", @@ -616,9 +614,8 @@ def _build_item_dataarrays(self, idx): ) / self.da_boundary_std # This function handles the stacking of the forcing and boundary data - # and adds the time step differences as a temporal embedding. - # It can handle `None` inputs for the forcing and boundary data - # (and simlpy return an empty DataArray in that case). + # and adds the time deltas. It can handle `None` inputs for the forcing + # and boundary data (and simlpy return an empty DataArray in that case). da_forcing_windowed = self._process_windowed_data( da_forcing_windowed, da_state, da_target_times ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index aa7b645d..6031fc81 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -82,7 +82,7 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - # each time step in the window has one corresponding temporal embedding + # each time step in the window has one corresponding time deltas # that is shared across all grid points, times and variables assert forcing.shape[2] == (datastore.get_num_data_vars("forcing") + 1) * ( num_past_forcing_steps + num_future_forcing_steps + 1 diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 21038e7b..a8afdacd 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -223,11 +223,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( ar_steps, 1, - total_forcing_window - * 2, # Each windowed feature includes temporal embedding + total_forcing_window * 2, # Each windowed feature includes time deltas ) - # Extract the forcing values from the tensor (excluding temporal embeddings) + # Extract the forcing values from the tensor (excluding time deltas) forcing_values = forcing[:, 0, :total_forcing_window] # Compare with expected forcing values @@ -313,11 +312,11 @@ def test_time_slicing_forecast( ar_steps, # Number of AR steps 1, # Number of grid points total_forcing_window # Total number of forcing steps in the window - * 2, # Each windowed feature includes temporal embedding + * 2, # Each windowed feature includes time deltas ) assert forcing.shape == expected_forcing_shape - # Extract the forcing values from the tensor (excluding temporal embeddings) + # Extract the forcing values from the tensor (excluding time deltas) forcing_values = forcing[:, 0, :total_forcing_window] # Compare with expected forcing values From 4a278fd9dea3fd8ce7bce91e88cb6259fa42cc6a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:40:52 +0100 Subject: [PATCH 096/103] Adding some comments about analysis_time indexing --- tests/test_time_slicing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index a8afdacd..daec72f2 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -248,6 +248,7 @@ def test_time_slicing_forecast( len(STATE_VALUES_FORECAST) ) ELAPSED_FORECAST_DURATION = np.timedelta64(0, "D") + np.arange( + # Retrieving the first analysis_time len(FORCING_VALUES_FORECAST[0]) ) # Create a dummy datastore with forecast data @@ -281,6 +282,7 @@ def test_time_slicing_forecast( # Compute expected initial states and target states based on ar_steps offset = max(0, num_past_forcing_steps - INIT_STEPS) init_idx = INIT_STEPS + offset + # Retrieving the first analysis_time expected_init_states = STATE_VALUES_FORECAST[0][offset:init_idx] expected_target_states = STATE_VALUES_FORECAST[0][ init_idx : init_idx + ar_steps @@ -293,6 +295,8 @@ def test_time_slicing_forecast( for i in range(ar_steps): start_idx = i + init_idx - num_past_forcing_steps end_idx = i + init_idx + num_future_forcing_steps + 1 + # Retrieving the analysis_time relevant for forcing-windows (i.e. + # the first analysis_time after the 2 init_steps) forcing_window = FORCING_VALUES_FORECAST[INIT_STEPS][start_idx:end_idx] expected_forcing_values.append(forcing_window) From c82d22ba9749c8f1fcd04c7bd7e1301881951622 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 14:46:12 +0100 Subject: [PATCH 097/103] moved comments around --- neural_lam/weather_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 2f37b4b3..0577f89f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -154,9 +154,6 @@ def __init__( state_times = self.da_state.time self.time_step_state = get_time_step(state_times) if self.da_forcing is not None: - # Forcing data is part of the same datastore as state data - # During creation the time dimension of the forcing data - # is matched to the state data if self.datastore.is_forecast: forcing_times = self.da_forcing.analysis_time self.forecast_step_forcing = self._get_time_step( @@ -165,9 +162,6 @@ def __init__( else: forcing_times = self.da_forcing.time self.time_step_forcing = self._get_time_step(forcing_times.values) - # Boundary data is part of a separate datastore - # The boundary data is allowed to have a different time_step - # Check that the boundary data covers the required time range if self.datastore_boundary.is_forecast: boundary_times = self.da_boundary_forcing.analysis_time self.forecast_step_boundary = self._get_time_step( @@ -177,6 +171,12 @@ def __init__( boundary_times = self.da_boundary_forcing.time self.time_step_boundary = self._get_time_step(boundary_times.values) + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range check_time_overlap( self.da_state, self.da_boundary_forcing, From 6e3f3bd42ad804ba80765369c4688908c701f99c Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Thu, 19 Dec 2024 12:12:01 +0100 Subject: [PATCH 098/103] Make hotfix to make boundary dataset created with mdp work --- neural_lam/datastore/mdp.py | 76 +++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 19 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 3682a51e..4007c192 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -1,5 +1,6 @@ # Standard library import copy +import functools import warnings from functools import cached_property from pathlib import Path @@ -8,9 +9,9 @@ # Third-party import cartopy.crs as ccrs import mllam_data_prep as mdp +import numpy as np import xarray as xr from loguru import logger -from numpy import ndarray # Local from .base import BaseRegularGridDatastore, CartesianGridShape @@ -86,6 +87,8 @@ def __init__(self, config_path, reuse_existing=True): print("With the following splits (over time):") for split in required_splits: da_split = self._ds.splits.sel(split_name=split) + if "grid_index" in da_split.coords: + da_split = da_split.isel(grid_index=0) da_split_start = da_split.sel(split_part="start").load().item() da_split_end = da_split.sel(split_part="end").load().item() print(f" {split:<8s}: {da_split_start} to {da_split_end}") @@ -266,27 +269,15 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: da_category = self._ds[category] - # set units on x y coordinates if missing - for coord in ["x", "y"]: - if "units" not in da_category[coord].attrs: - da_category[coord].attrs["units"] = "m" - # set multi-index for grid-index da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS) if "time" in da_category.dims: - t_start = ( - self._ds.splits.sel(split_name=split) - .sel(split_part="start") - .load() - .item() - ) - t_end = ( - self._ds.splits.sel(split_name=split) - .sel(split_part="end") - .load() - .item() - ) + da_split = self._ds.splits.sel(split_name=split) + if "grid_index" in da_split.coords: + da_split = da_split.isel(grid_index=0) + t_start = da_split.sel(split_part="start").load().item() + t_end = da_split.sel(split_part="end").load().item() da_category = da_category.sel(time=slice(t_start, t_end)) dim_order = self.expected_dim_order(category=category) @@ -324,6 +315,8 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ) ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) + if "grid_index" in ds_stats.coords: + ds_stats = ds_stats.isel(grid_index=0) return ds_stats @property @@ -404,7 +397,7 @@ def grid_shape_state(self): assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) - def get_xy(self, category: str, stacked: bool) -> ndarray: + def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: """Return the x, y coordinates of the dataset. Parameters @@ -449,3 +442,48 @@ def get_xy(self, category: str, stacked: bool) -> ndarray: da_xy = da_xy.transpose(*dims) return da_xy.values + + @functools.lru_cache + def get_lat_lon(self, category: str) -> np.ndarray: + """ + Return the longitude, latitude coordinates of the dataset as numpy + array for a given category of data. + Override in MDP to use lat/lons directly from xr.Dataset, if available. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + np.ndarray + The longitude, latitude coordinates of the dataset + with shape `[n_grid_points, 2]`. + """ + # Check first if lat/lon saved in ds + lookup_ds = self._ds + if "latitude" in lookup_ds.coords and "longitude" in lookup_ds.coords: + lon = lookup_ds.longitude + lat = lookup_ds.latitude + elif "lat" in lookup_ds.coords and "lon" in lookup_ds.coords: + lon = lookup_ds.lon + lat = lookup_ds.lat + else: + # Not saved, use method from BaseDatastore to derive from x/y + return super().get_lat_lon(category) + + coords = np.stack((lon.values, lat.values), axis=1) + return coords + + @property + def num_grid_points(self) -> int: + """Return the number of grid points in the dataset. + + Returns + ------- + int + The number of grid points in the dataset. + + """ + return len(self._ds.grid_index) From 20ca2636c363fd82323aaff31594259cb5f49fbd Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 15:34:20 +0100 Subject: [PATCH 099/103] Bugfixes --- neural_lam/datastore/mdp.py | 2 +- neural_lam/utils.py | 2 +- neural_lam/weather_dataset.py | 8 ++++---- tests/dummy_datastore.py | 12 ++++++++++++ 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 4007c192..3f1e0441 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -153,7 +153,7 @@ def num_ensemble_members(self) -> int: The number of ensemble members in the dataset. """ - return None + return 1 def get_vars_units(self, category: str) -> List[str]: """Return the units of the variables in the given category. diff --git a/neural_lam/utils.py b/neural_lam/utils.py index f55f17da..c2bc3c57 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -244,7 +244,7 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric(f"val_loss_unroll{step}", summary="min") -def get_time_step(self, times): +def get_time_step(times): """Calculate the time step from a time dataarray. Parameters diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 0577f89f..91d68462 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -156,20 +156,20 @@ def __init__( if self.da_forcing is not None: if self.datastore.is_forecast: forcing_times = self.da_forcing.analysis_time - self.forecast_step_forcing = self._get_time_step( + self.forecast_step_forcing = get_time_step( self.da_forcing.elapsed_forecast_duration ) else: forcing_times = self.da_forcing.time - self.time_step_forcing = self._get_time_step(forcing_times.values) + self.time_step_forcing = get_time_step(forcing_times.values) if self.datastore_boundary.is_forecast: boundary_times = self.da_boundary_forcing.analysis_time - self.forecast_step_boundary = self._get_time_step( + self.forecast_step_boundary = get_time_step( self.da_boundary_forcing.elapsed_forecast_duration ) else: boundary_times = self.da_boundary_forcing.time - self.time_step_boundary = self._get_time_step(boundary_times.values) + self.time_step_boundary = get_time_step(boundary_times.values) # Forcing data is part of the same datastore as state data # During creation the time dimension of the forcing data diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 1bdbc8c8..dcc5510f 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -410,6 +410,18 @@ def num_grid_points(self) -> int: """ return self._num_grid_points + @property + def num_ensemble_members(self) -> int: + """Return the number of ensemble members in the dataset. + + Returns + ------- + int + The number of ensemble members in the dataset. + + """ + return 1 + @cached_property def grid_shape_state(self) -> CartesianGridShape: """The shape of the grid for the state variables. From c0c50d5e44580c6a8e486d44cee19be6ed1fd847 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 15:47:16 +0100 Subject: [PATCH 100/103] --- neural_lam/datastore/base.py | 19 +++++++++++++++---- neural_lam/datastore/mdp.py | 12 ------------ neural_lam/datastore/npyfilesmeps/store.py | 10 +++++++++- tests/dummy_datastore.py | 12 ------------ 4 files changed, 24 insertions(+), 29 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 84600b50..8b51b07e 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -47,7 +47,6 @@ class BaseDatastore(abc.ABC): each of the `x` and `y` coordinates. """ - is_ensemble: bool = False is_forecast: bool = False @property @@ -299,17 +298,29 @@ def num_grid_points(self) -> int: pass @property - @abc.abstractmethod def num_ensemble_members(self) -> int: """Return the number of ensemble members in the dataset. Returns ------- int - The number of ensemble members in the dataset. + The number of ensemble members in the dataset (default is 1 - + not an ensemble). """ - pass + return 1 + + @property + def is_ensemble(self) -> bool: + """Return whether the dataset represents ensemble data. + + Returns + ------- + bool + True if the dataset represents ensemble data, False otherwise. + + """ + return self.num_ensemble_members > 1 @cached_property @abc.abstractmethod diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 3f1e0441..b82c9277 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -143,18 +143,6 @@ def step_length(self) -> int: da_dt = self._ds["time"].diff("time") return (da_dt.dt.seconds[0] // 3600).item() - @property - def num_ensemble_members(self) -> int: - """The number of ensemble members in the dataset. - - Returns - ------- - int - The number of ensemble members in the dataset. - - """ - return 1 - def get_vars_units(self, category: str) -> List[str]: """Return the units of the variables in the given category. diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 1b0f6065..b4d93ca3 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -141,7 +141,6 @@ class NpyFilesDatastoreMEPS(BaseRegularGridDatastore): """ SHORT_NAME = "npyfilesmeps" - is_ensemble = True is_forecast = True def __init__( @@ -200,6 +199,15 @@ def config(self) -> NpyDatastoreConfig: @property def num_ensemble_members(self) -> int: + """Return the number of ensemble members in the dataset as defined in + the config file. + + Returns + ------- + int + The number of ensemble members in the dataset. + + """ return self.config.dataset.num_ensemble_members def get_dataarray(self, category: str, split: str) -> DataArray: diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index dcc5510f..1bdbc8c8 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -410,18 +410,6 @@ def num_grid_points(self) -> int: """ return self._num_grid_points - @property - def num_ensemble_members(self) -> int: - """Return the number of ensemble members in the dataset. - - Returns - ------- - int - The number of ensemble members in the dataset. - - """ - return 1 - @cached_property def grid_shape_state(self) -> CartesianGridShape: """The shape of the grid for the state variables. From 94de24018ea75a229245f54f3b8aa17cfc2d79a4 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 15:54:17 +0100 Subject: [PATCH 101/103] Add missing check if boundary_forcing is None --- neural_lam/weather_dataset.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 91d68462..bb9934b4 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -144,7 +144,8 @@ def __init__( self.da_state = self.da_state # Check time step consistency in state data and determine time steps - # for state, forcing and boundary data + # for state, forcing and boundary forcing data + # STATE if self.datastore.is_forecast: state_times = self.da_state.analysis_time self.forecast_step_state = get_time_step( @@ -153,6 +154,7 @@ def __init__( else: state_times = self.da_state.time self.time_step_state = get_time_step(state_times) + # FORCING if self.da_forcing is not None: if self.datastore.is_forecast: forcing_times = self.da_forcing.analysis_time @@ -162,6 +164,8 @@ def __init__( else: forcing_times = self.da_forcing.time self.time_step_forcing = get_time_step(forcing_times.values) + # BOUNDARY FORCING + if self.da_boundary_forcing is not None: if self.datastore_boundary.is_forecast: boundary_times = self.da_boundary_forcing.analysis_time self.forecast_step_boundary = get_time_step( @@ -177,14 +181,15 @@ def __init__( # Boundary data is part of a separate datastore # The boundary data is allowed to have a different time_step # Check that the boundary data covers the required time range - check_time_overlap( - self.da_state, - self.da_boundary_forcing, - da1_is_forecast=self.datastore.is_forecast, - da2_is_forecast=self.datastore_boundary.is_forecast, - num_past_steps=self.num_past_boundary_steps, - num_future_steps=self.num_future_boundary_steps, - ) + if self.da_boundary_forcing is not None: + check_time_overlap( + self.da_state, + self.da_boundary_forcing, + da1_is_forecast=self.datastore.is_forecast, + da2_is_forecast=self.datastore_boundary.is_forecast, + num_past_steps=self.num_past_boundary_steps, + num_future_steps=self.num_future_boundary_steps, + ) # Set up for standardization # TODO: This will become part of ar_model.py soon! From 1d14a157e81cef3b6584758390077b5327a82f61 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 20:23:10 +0100 Subject: [PATCH 102/103] bugfix typo in time check --- neural_lam/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index c2bc3c57..32f92cf2 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -311,8 +311,8 @@ def check_time_overlap( times_da2 = da2.time time_step_da2 = get_time_step(times_da2.values) - time_min_da2 = da2.min().values - time_max_da2 = da2.max().values + time_min_da2 = times_da2.min().values + time_max_da2 = times_da2.max().values # Calculate required bounds for da2 using its time step da2_required_time_min = time_min_da1 - num_past_steps * time_step_da2 From 7e5797e18f5a251441e98ef9ad63f412c06ef409 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 20 Dec 2024 21:39:30 +0100 Subject: [PATCH 103/103] introduce crop_time_if_needed to align interior with boundary data --- neural_lam/utils.py | 77 ++++++++++++++++++++++++++++++++++- neural_lam/weather_dataset.py | 30 ++++++++++---- 2 files changed, 98 insertions(+), 9 deletions(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 32f92cf2..2a8ba6ed 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -320,14 +320,87 @@ def check_time_overlap( if time_min_da2 > da2_required_time_min: raise ValueError( - f"The second DataArray ('Boundary forcing'?) data starts too late." + f"The second DataArray (e.g. 'boundary forcing') starts too late." f"Required start: {da2_required_time_min}, " f"but DataArray starts at {time_min_da2}." ) if time_max_da2 < da2_required_time_max: raise ValueError( - f"The second DataArray ('Boundary forcing'?) ends too early." + f"The second DataArray (e.g. 'boundary forcing') ends too early." f"Required end: {da2_required_time_max}, " f"but DataArray ends at {time_max_da2}." ) + + +def crop_time_if_needed( + da1, da2, da1_is_forecast=False, da2_is_forecast=False, num_past_steps=1 +): + """ + Slice away the first few timesteps from the first DataArray (e.g. 'state') + if the second DataArray (e.g. boundary forcing) does not cover that range + (including num_past_steps). + + Parameters + ---------- + da1 : xr.DataArray + The first DataArray to crop. + da2 : xr.DataArray + The second DataArray to compare against. + da1_is_forecast : bool, optional + Whether the first dataarray is forecast data. + da2_is_forecast : bool, optional + Whether the second dataarray is forecast data. + num_past_steps : int + Number of past time steps to consider. + + Return + ------ + da1 : xr.DataArray + The cropped first DataArray and print a warning if any steps are + removed. + """ + if da1 is None or da2 is None: + return da1 + + try: + check_time_overlap( + da1, + da2, + da1_is_forecast, + da2_is_forecast, + num_past_steps, + num_future_steps=0, + ) + return da1 + except ValueError: + # If da2 coverage is insufficient, remove earliest da1 times + # until coverage is possible. Figure out how many steps to remove. + if da1_is_forecast: + da1_tvals = da1.analysis_time.values + else: + da1_tvals = da1.time.values + if da2_is_forecast: + da2_tvals = da2.analysis_time.values + else: + da2_tvals = da2.time.values + + if da1_tvals[0] < da2_tvals[0]: + # Calculate how many steps to remove skip just enough steps so that: + if da2_is_forecast: + # The windowing for forecast type data happens in the + # elapsed_forecast_duration dimension, so we can omit it here. + required_min = da2_tvals[0] + else: + dt = get_time_step(da2_tvals) + required_min = da2_tvals[0] + num_past_steps * dt + first_valid_idx = (da1_tvals >= required_min).argmax() + n_removed = first_valid_idx + if n_removed > 0: + print( + f"Warning: removing {n_removed} da1 (e.g. 'state') " + f"timesteps to align with da2 (e.g. 'boundary forcing') " + f"coverage." + ) + da1 = da1.isel(time=slice(first_valid_idx, None)) + return da1 diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index bb9934b4..66165b6f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -11,7 +11,11 @@ # First-party from neural_lam.datastore.base import BaseDatastore -from neural_lam.utils import check_time_overlap, get_time_step +from neural_lam.utils import ( + check_time_overlap, + crop_time_if_needed, + get_time_step, +) class WeatherDataset(torch.utils.data.Dataset): @@ -175,12 +179,24 @@ def __init__( boundary_times = self.da_boundary_forcing.time self.time_step_boundary = get_time_step(boundary_times.values) - # Forcing data is part of the same datastore as state data - # During creation the time dimension of the forcing data - # is matched to the state data - # Boundary data is part of a separate datastore - # The boundary data is allowed to have a different time_step - # Check that the boundary data covers the required time range + # Forcing data is part of the same datastore as state data. During + # creation, the time dimension of the forcing data is matched to the + # state data. + # Boundary data is part of a separate datastore The boundary data is + # allowed to have a different time_step Checks that the boundary data + # covers the required time range is required. + + # Crop interior data if boundary coverage is insufficient + if self.da_boundary_forcing is not None: + self.da_state = crop_time_if_needed( + self.da_state, + self.da_boundary_forcing, + da1_is_forecast=self.datastore.is_forecast, + da2_is_forecast=self.datastore_boundary.is_forecast, + num_past_steps=self.num_past_boundary_steps, + ) + + # Now do final overlap check and possibly raise errors if still invalid if self.da_boundary_forcing is not None: check_time_overlap( self.da_state,