Skip to content

Commit

Permalink
Fixed test for temporal embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Nov 30, 2024
1 parent 85aad66 commit 64f057f
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions tests/test_time_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def get_dataarray(self, category, split):
if self.is_forecast:
raise NotImplementedError()
else:
da = xr.DataArray(
values, dims=["time"], coords={"time": self._time_values}
)
da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values})
# add `{category}_feature` and `grid_index` dimensions

da = da.expand_dims("grid_index")
Expand Down Expand Up @@ -78,10 +76,8 @@ def get_vars_long_names(self, category):
def test_time_slicing_analysis(
ar_steps, num_past_forcing_steps, num_future_forcing_steps
):
# state and forcing variables have only on dimension, `time`
time_values = np.datetime64("2020-01-01") + np.arange(
len(ANALYSIS_STATE_VALUES)
)
# state and forcing variables have only one dimension, `time`
time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES))
assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values)

datastore = SinglePointDummyDatastore(
Expand All @@ -93,6 +89,7 @@ def test_time_slicing_analysis(

dataset = WeatherDataset(
datastore=datastore,
datastore_boundary=None,
ar_steps=ar_steps,
num_future_forcing_steps=num_future_forcing_steps,
num_past_forcing_steps=num_past_forcing_steps,
Expand All @@ -101,9 +98,7 @@ def test_time_slicing_analysis(

sample = dataset[0]

init_states, target_states, forcing, _ = [
tensor.numpy() for tensor in sample
]
init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample]

expected_init_states = [0, 1]
if ar_steps == 3:
Expand All @@ -130,7 +125,7 @@ def test_time_slicing_analysis(

# init_states: (2, N_grid, d_features)
# target_states: (ar_steps, N_grid, d_features)
# forcing: (ar_steps, N_grid, d_windowed_forcing)
# forcing: (ar_steps, N_grid, d_windowed_forcing * 2)
# target_times: (ar_steps,)
assert init_states.shape == (2, 1, 1)
assert init_states[:, 0, 0].tolist() == expected_init_states
Expand All @@ -141,6 +136,10 @@ def test_time_slicing_analysis(
assert forcing.shape == (
3,
1,
1 + num_past_forcing_steps + num_future_forcing_steps,
# Factor 2 because each window step has a temporal embedding
(1 + num_past_forcing_steps + num_future_forcing_steps) * 2,
)
np.testing.assert_equal(
forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1],
np.array(expected_forcing_values),
)
np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values))

0 comments on commit 64f057f

Please sign in to comment.