Skip to content

Commit

Permalink
cleanup analysis based time-slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Nov 21, 2024
1 parent 17c920d commit 7919995
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,12 @@ def __len__(self):
def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
"""
Produce time slices of the given dataarrays `da_state` (state) and
`da_forcing` (forcing). For the state data, slicing is done as before
based on `idx`. For the forcing data, nearest neighbor matching is
performed based on the state times. Additionally, the time difference
between the matched forcing times and state times (in multiples of state
time steps) is added to the forcing dataarray.
`da_forcing` (forcing). For the state data, slicing is done based on
`idx`. For the forcing data, nearest neighbor matching is performed
based on the state times. Additionally, the time difference between the
matched forcing times and state times (in multiples of state time steps)
is added to the forcing dataarray. This will be used as an additional
feature in the model (temporal embedding).
Parameters
----------
Expand All @@ -269,9 +270,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
The sliced state dataarray with dims ('time', 'grid_index',
'state_feature').
da_forcing_matched : xr.DataArray
The forcing dataarray matched to state times with an added
coordinate 'time_diff', representing the time difference to state
times in multiples of state time steps.
The sliced forcing dataarray with dims ('time', 'grid_index',
'forcing_feature_windowed').
"""
# Number of initial steps required (e.g., for initializing models)
init_steps = 2
Expand Down Expand Up @@ -308,9 +308,9 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
if da_forcing is None:
return da_state_sliced, None

# Get the state times for matching
# Get the state times and their temporal resolution for matching with
# forcing data
state_times = da_state_sliced["time"]
# Calculate time differences in multiples of state time steps
state_time_step = state_times.values[1] - state_times.values[0]

# Match forcing data to state times based on nearest neighbor
Expand Down Expand Up @@ -369,39 +369,29 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
# For analysis data, match directly using the 'time' coordinate
forcing_times = da_forcing["time"]

# Compute time differences
# Compute time differences between forcing and state times
# (in multiples of state time steps)
# Retrieve the indices of the closest times in the forcing data
time_deltas = (
state_times.values[np.newaxis, :]
- forcing_times.values[:, np.newaxis]
)
forcing_times.values[:, np.newaxis]
- state_times.values[np.newaxis, :]
) / state_time_step
idx_min = np.abs(time_deltas).argmin(axis=0)

time_diff_steps = xr.DataArray(
np.stack(
[
np.diagonal(time_deltas, offset=offset)[
-len(state_times) + init_steps :
]
/ state_time_step
for offset in range(
-self.num_past_forcing_steps,
self.num_future_forcing_steps + 1,
)
],
axis=1,
),
dims=["time", "window"],
coords={
"time": state_times.isel(time=slice(init_steps, None)),
"window": np.arange(
-self.num_past_forcing_steps,
self.num_future_forcing_steps + 1,
),
},
name="time_diff_steps",
time_diff_steps = np.stack(
[
time_deltas[
idx_i
- self.num_past_forcing_steps : idx_i
+ self.num_future_forcing_steps
+ 1,
init_steps + step_i,
]
for (step_i, idx_i) in enumerate(idx_min[init_steps:])
],
)

# Create window dimension using rolling
# Create window dimension for forcing data to stack later
window_size = (
self.num_past_forcing_steps + self.num_future_forcing_steps + 1
)
Expand All @@ -412,9 +402,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None):
time=idx_min[init_steps:]
)

# Add time difference as a new coordinate
da_forcing_matched = da_forcing_matched.assign_coords(
time_diff=time_diff_steps
# Add time difference as a new coordinate to concatenate to the
# forcing features later
da_forcing_matched["time_diff_steps"] = (
("time", "window"),
time_diff_steps,
)

return da_state_sliced, da_forcing_matched
Expand All @@ -423,13 +415,19 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times):
"""Helper function to process windowed data after standardization."""
stacked_dim = "forcing_feature_windowed"
if da_windowed is not None:
# Stack the 'feature' and 'window' dimensions
# Stack the 'feature' and 'window' dimensions and add the
# time step differences to the existing features as a temporal
# embedding
da_windowed = da_windowed.stack(
{stacked_dim: ("forcing_feature", "window")}
)
da_windowed = xr.concat(
[da_windowed, da_windowed.time_diff_steps],
dim="forcing_feature_windowed",
)
else:
# Create empty DataArray with the correct dimensions and coordinates
return xr.DataArray(
da_windowed = xr.DataArray(
data=np.empty((self.ar_steps, da_state.grid_index.size, 0)),
dims=("time", "grid_index", f"{stacked_dim}"),
coords={
Expand All @@ -438,6 +436,7 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times):
f"{stacked_dim}": [],
},
)
return da_windowed

def _build_item_dataarrays(self, idx):
"""
Expand Down

0 comments on commit 7919995

Please sign in to comment.