diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py index 74f5e44d..d268e87f 100644 --- a/tests/test_mllam_dataset.py +++ b/tests/test_mllam_dataset.py @@ -7,8 +7,8 @@ import pytest # First-party -from neural_lam.config import Config from neural_lam.build_graph import main as build_graph +from neural_lam.config import Config from neural_lam.train_model import main as train_model from neural_lam.utils import load_static_data from neural_lam.weather_dataset import WeatherDataset @@ -91,14 +91,14 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): assert boundary_forcing.shape == ( n_prediction_timesteps, n_boundary, - 2 * n_grid + n_forcing_features, # TODO Adjust dimensionality + 2 * n_state_features + n_forcing_features, # TODO Adjust dimensionality ) - static_data = load_static_data(dataset_name=dataset_name) - required_props = { - "border_mask", + "boundary_mask", + "interior_mask", "grid_static_features", + "boundary_static_features", "step_diff_mean", "step_diff_std", "data_mean", @@ -107,7 +107,9 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): } # check the sizes of the props - assert static_data["border_mask"].shape == (n_grid, 1) + # TODO Should this config not be for only interior? + nx, ny = config.values["grid_shape_state"] + assert n_grid + n_boundary == nx * ny assert static_data["grid_static_features"].shape == ( n_grid, n_grid_static_features,