From 64f057f78b713e39496abfc3962affa794666369 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:50 +0100 Subject: [PATCH] Fixed test for temporal embedding --- tests/test_time_slicing.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..2f5ed96c 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,9 +40,7 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) + da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -78,10 +76,8 @@ def get_vars_long_names(self, category): def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -93,6 +89,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,9 +98,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ - tensor.numpy() for tensor in sample - ] + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] expected_init_states = [0, 1] if ar_steps == 3: @@ -130,7 +125,7 @@ def test_time_slicing_analysis( # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) assert init_states.shape == (2, 1, 1) assert init_states[:, 0, 0].tolist() == expected_init_states @@ -141,6 +136,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( 3, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + # Factor 2 because each window step has a temporal embedding + (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, + ) + np.testing.assert_equal( + forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], + np.array(expected_forcing_values), ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values))