Commit
introduce crop_time_if_needed to align interior with boundary data
sadamov committed Dec 20, 2024
1 parent 1d14a15 commit 7e5797e
Showing 2 changed files with 98 additions and 9 deletions.
77 changes: 75 additions & 2 deletions neural_lam/utils.py
@@ -320,14 +320,87 @@ def check_time_overlap(

if time_min_da2 > da2_required_time_min:
raise ValueError(
f"The second DataArray ('Boundary forcing'?) data starts too late."
f"The second DataArray (e.g. 'boundary forcing') starts too late."
f"Required start: {da2_required_time_min}, "
f"but DataArray starts at {time_min_da2}."
)

if time_max_da2 < da2_required_time_max:
raise ValueError(
f"The second DataArray ('Boundary forcing'?) ends too early."
f"The second DataArray (e.g. 'boundary forcing') ends too early."
f"Required end: {da2_required_time_max}, "
f"but DataArray ends at {time_max_da2}."
)


def crop_time_if_needed(
da1, da2, da1_is_forecast=False, da2_is_forecast=False, num_past_steps=1
):
"""
Slice away the first few timesteps from the first DataArray (e.g. 'state')
if the second DataArray (e.g. boundary forcing) does not cover that range
(including num_past_steps).
Parameters
----------
da1 : xr.DataArray
The first DataArray to crop.
da2 : xr.DataArray
The second DataArray to compare against.
da1_is_forecast : bool, optional
Whether the first dataarray is forecast data.
da2_is_forecast : bool, optional
Whether the second dataarray is forecast data.
num_past_steps : int
Number of past time steps to consider.
Returns
-------
da1 : xr.DataArray
The cropped first DataArray. A warning is printed if any timesteps
are removed.
"""
if da1 is None or da2 is None:
return da1

try:
check_time_overlap(
da1,
da2,
da1_is_forecast,
da2_is_forecast,
num_past_steps,
num_future_steps=0,
)
return da1
except ValueError:
# If da2 coverage is insufficient, remove earliest da1 times
# until coverage is possible. Figure out how many steps to remove.
if da1_is_forecast:
da1_tvals = da1.analysis_time.values
else:
da1_tvals = da1.time.values
if da2_is_forecast:
da2_tvals = da2.analysis_time.values
else:
da2_tvals = da2.time.values

if da1_tvals[0] < da2_tvals[0]:
# Determine the earliest da1 time that da2 can cover, then skip
# just enough leading da1 steps to reach it:
if da2_is_forecast:
# The windowing for forecast type data happens in the
# elapsed_forecast_duration dimension, so we can omit it here.
required_min = da2_tvals[0]
else:
dt = get_time_step(da2_tvals)
required_min = da2_tvals[0] + num_past_steps * dt
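# Illustrative (hypothetical) numbers: hourly boundary data starting
# at 2024-01-01T00 with num_past_steps=2 gives required_min =
# 2024-01-01T02, the earliest da1 time for which a full window of
# past boundary steps is available.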
first_valid_idx = (da1_tvals >= required_min).argmax()
n_removed = first_valid_idx
if n_removed > 0:
print(
f"Warning: removing {n_removed} da1 (e.g. 'state') "
f"timesteps to align with da2 (e.g. 'boundary forcing') "
f"coverage."
)
da1 = da1.isel(time=slice(first_valid_idx, None))
return da1
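
As a usage sketch (not part of the commit itself), the snippet below shows the intended effect of crop_time_if_needed on hypothetical hourly, analysis-type DataArrays; the toy data and variable names are illustrative, and it assumes check_time_overlap and get_time_step only need the time coordinates shown here.

import numpy as np
import xarray as xr

from neural_lam.utils import crop_time_if_needed

# Hypothetical hourly data: the interior 'state' starts at 00:00 but the
# boundary forcing only starts at 01:00, so with num_past_steps=1 the
# first two interior timesteps cannot be covered and should be cropped.
state_times = np.arange("2024-01-01T00", "2024-01-01T08", dtype="datetime64[h]")
boundary_times = np.arange("2024-01-01T01", "2024-01-01T09", dtype="datetime64[h]")
da_state = xr.DataArray(np.zeros(8), dims=["time"], coords={"time": state_times})
da_boundary = xr.DataArray(np.zeros(8), dims=["time"], coords={"time": boundary_times})

da_state = crop_time_if_needed(
    da_state,
    da_boundary,
    da1_is_forecast=False,
    da2_is_forecast=False,
    num_past_steps=1,
)
print(da_state.time.values[0])  # expected: 2024-01-01T02 per the cropping rule
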
30 changes: 23 additions & 7 deletions neural_lam/weather_dataset.py
@@ -11,7 +11,11 @@

# First-party
from neural_lam.datastore.base import BaseDatastore
from neural_lam.utils import check_time_overlap, get_time_step
from neural_lam.utils import (
check_time_overlap,
crop_time_if_needed,
get_time_step,
)


class WeatherDataset(torch.utils.data.Dataset):
@@ -175,12 +179,24 @@ def __init__(
boundary_times = self.da_boundary_forcing.time
self.time_step_boundary = get_time_step(boundary_times.values)

# Forcing data is part of the same datastore as state data
# During creation the time dimension of the forcing data
# is matched to the state data
# Boundary data is part of a separate datastore
# The boundary data is allowed to have a different time_step
# Check that the boundary data covers the required time range
# Forcing data is part of the same datastore as state data. During
# creation, the time dimension of the forcing data is matched to the
# state data.
# Boundary data is part of a separate datastore and is allowed to have
# a different time_step, so we must check that the boundary data
# covers the required time range.
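# (Illustrative, hypothetical numbers: interior state at 3 h steps with
# boundary forcing at 6 h steps is acceptable, as long as the boundary
# time range spans the interior times plus the past/future window.)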

# Crop interior data if boundary coverage is insufficient
if self.da_boundary_forcing is not None:
self.da_state = crop_time_if_needed(
self.da_state,
self.da_boundary_forcing,
da1_is_forecast=self.datastore.is_forecast,
da2_is_forecast=self.datastore_boundary.is_forecast,
num_past_steps=self.num_past_boundary_steps,
)

# Now do final overlap check and possibly raise errors if still invalid
if self.da_boundary_forcing is not None:
check_time_overlap(
self.da_state,
