From 79199956225277cb88b255a514be1a72634926c5 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 21 Nov 2024 07:09:52 +0100
Subject: [PATCH] cleanup analysis based time-slicing

---
 neural_lam/weather_dataset.py | 85 +++++++++++++++++------------------
 1 file changed, 42 insertions(+), 43 deletions(-)

diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index c8806d1c..bbfb5705 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -245,11 +245,12 @@ def __len__(self):
     def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
         """
         Produce time slices of the given dataarrays `da_state` (state) and
-        `da_forcing` (forcing). For the state data, slicing is done as before
-        based on `idx`. For the forcing data, nearest neighbor matching is
-        performed based on the state times. Additionally, the time difference
-        between the matched forcing times and state times (in multiples of state
-        time steps) is added to the forcing dataarray.
+        `da_forcing` (forcing). For the state data, slicing is done based on
+        `idx`. For the forcing data, nearest neighbor matching is performed
+        based on the state times. Additionally, the time difference between the
+        matched forcing times and state times (in multiples of state time steps)
+        is added to the forcing dataarray. This will be used as an additional
+        feature in the model (temporal embedding).

         Parameters
         ----------
@@ -269,9 +270,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
             The sliced state dataarray with dims ('time', 'grid_index',
             'state_feature').
         da_forcing_matched : xr.DataArray
-            The forcing dataarray matched to state times with an added
-            coordinate 'time_diff', representing the time difference to state
-            times in multiples of state time steps.
+            The sliced forcing dataarray with dims ('time', 'grid_index',
+            'forcing_feature_windowed').
""" # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -308,9 +308,9 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): if da_forcing is None: return da_state_sliced, None - # Get the state times for matching + # Get the state times and its temporal resolution for matching with + # forcing data state_times = da_state_sliced["time"] - # Calculate time differences in multiples of state time steps state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor @@ -369,39 +369,29 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] - # Compute time differences + # Compute time differences between forcing and state times + # (in multiples of state time steps) + # Retrieve the indices of the closest times in the forcing data time_deltas = ( - state_times.values[np.newaxis, :] - - forcing_times.values[:, np.newaxis] - ) + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) - time_diff_steps = xr.DataArray( - np.stack( - [ - np.diagonal(time_deltas, offset=offset)[ - -len(state_times) + init_steps : - ] - / state_time_step - for offset in range( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ) - ], - axis=1, - ), - dims=["time", "window"], - coords={ - "time": state_times.isel(time=slice(init_steps, None)), - "window": np.arange( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ), - }, - name="time_diff_steps", + time_diff_steps = np.stack( + [ + time_deltas[ + idx_i + - self.num_past_forcing_steps : idx_i + + self.num_future_forcing_steps + + 1, + init_steps + step_i, + ] + for (step_i, idx_i) in enumerate(idx_min[init_steps:]) + ], ) - # Create window dimension using rolling + # Create window dimension for forcing data to stack later window_size = ( self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) @@ -412,9 +402,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): time=idx_min[init_steps:] ) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=time_diff_steps + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, ) return da_state_sliced, da_forcing_matched @@ -423,13 +415,19 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" stacked_dim = "forcing_feature_windowed" if da_windowed is not None: - # Stack the 'feature' and 'window' dimensions + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) + da_windowed = xr.concat( + [da_windowed, da_windowed.time_diff_steps], + dim="forcing_feature_windowed", + ) else: # Create empty DataArray with the correct dimensions and coordinates - return xr.DataArray( + da_windowed = xr.DataArray( data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), dims=("time", "grid_index", f"{stacked_dim}"), coords={ @@ -438,6 +436,7 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): f"{stacked_dim}": [], }, 
             )
+        return da_windowed

     def _build_item_dataarrays(self, idx):
         """
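
Note (not part of the patch): the sketch below illustrates the nearest-neighbour matching and the `time_diff_steps` computation that the analysis branch of `_slice_time` now performs, using plain NumPy and toy hourly state times against 3-hourly forcing times. The values of `num_past_forcing_steps`, `num_future_forcing_steps` and `init_steps` are assumed here purely for illustration; the real implementation reads them from the dataset object and works on xarray coordinates.

import numpy as np

# Assumed toy stand-ins for the dataset attributes used in the patch.
num_past_forcing_steps = 1
num_future_forcing_steps = 1
init_steps = 2

# Hourly state times and 3-hourly forcing times (illustrative only).
state_times = np.arange(
    np.datetime64("2024-01-01T00"),
    np.datetime64("2024-01-01T12"),
    np.timedelta64(1, "h"),
)
forcing_times = np.arange(
    np.datetime64("2023-12-31T18"),
    np.datetime64("2024-01-01T18"),
    np.timedelta64(3, "h"),
)

state_time_step = state_times[1] - state_times[0]

# Signed offset of every forcing time to every state time, expressed in
# multiples of the state time step; shape (n_forcing, n_state).
time_deltas = (
    forcing_times[:, np.newaxis] - state_times[np.newaxis, :]
) / state_time_step

# Index of the forcing time closest to each state time.
idx_min = np.abs(time_deltas).argmin(axis=0)

# For each target state time (after the initial steps), collect the offsets
# of the past/current/future forcing window around the matched forcing index.
time_diff_steps = np.stack(
    [
        time_deltas[
            idx_i - num_past_forcing_steps : idx_i + num_future_forcing_steps + 1,
            init_steps + step_i,
        ]
        for (step_i, idx_i) in enumerate(idx_min[init_steps:])
    ]
)

print(time_diff_steps.shape)  # (10, 3): one window per target time
print(time_diff_steps[0])     # [-2.  1.  4.] for this toy setup

Keeping the offsets signed is what makes them useful as a temporal embedding: concatenated onto the windowed forcing features, they tell the model whether each forcing slice lies before or after the state time and by how many state steps.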