Skip to content

Commit

Permalink
Fix data loading test
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 12, 2024
1 parent 42be03f commit ce204b6
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/test_mllam_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytest

# First-party
from neural_lam.config import Config
from neural_lam.build_graph import main as build_graph
from neural_lam.config import Config
from neural_lam.train_model import main as train_model
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
Expand Down Expand Up @@ -91,14 +91,14 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
assert boundary_forcing.shape == (
n_prediction_timesteps,
n_boundary,
2 * n_grid + n_forcing_features, # TODO Adjust dimensionality
2 * n_state_features + n_forcing_features, # TODO Adjust dimensionality
)

static_data = load_static_data(dataset_name=dataset_name)

required_props = {
"border_mask",
"boundary_mask",
"interior_mask",
"grid_static_features",
"boundary_static_features",
"step_diff_mean",
"step_diff_std",
"data_mean",
Expand All @@ -107,7 +107,9 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
}

# check the sizes of the props
assert static_data["border_mask"].shape == (n_grid, 1)
# TODO Should this config not be for only interior?
nx, ny = config.values["grid_shape_state"]
assert n_grid + n_boundary == nx * ny
assert static_data["grid_static_features"].shape == (
n_grid,
n_grid_static_features,
Expand Down

0 comments on commit ce204b6

Please sign in to comment.