From 6fe19ac71e1023b036a7c41d400e07a0d40fdd46 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 26 Jul 2024 10:36:14 +0200
Subject: [PATCH] npyfiles datastore complete

---
 neural_lam/create_graph.py              |   8 +-
 neural_lam/datastore/base.py            |   3 +-
 neural_lam/datastore/mllam.py           |   8 +-
 neural_lam/datastore/multizarr/store.py |  10 +-
 neural_lam/datastore/npyfiles/config.py | 101 ++++++++-------
 neural_lam/datastore/npyfiles/store.py  |  64 ++++++----
 neural_lam/train_model.py               |  23 ++--
 neural_lam/weather_dataset.py           |  12 +-
 pyproject.toml                          |   1 +
 tests/conftest.py                       |  25 +++-
 tests/test_datastores.py                |   4 +
 tests/test_npy_forecast_dataset.py      | 161 ------------------------
 tests/test_training.py                  |  17 +++
 13 files changed, 172 insertions(+), 265 deletions(-)
 delete mode 100644 tests/test_npy_forecast_dataset.py

diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index c281887d..6b062e3d 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -228,6 +228,8 @@ def create_graph(
     """
     os.makedirs(graph_dir_path, exist_ok=True)
 
+    print(f"Writing graph components to {graph_dir_path}")
+
     grid_xy = torch.tensor(xy)
     pos_max = torch.max(torch.abs(grid_xy))
 
@@ -562,7 +564,7 @@ def cli(input_args=None):
         help="kind of data store to use (default: multizarr)",
     )
     parser.add_argument(
-        "datastore-path",
+        "datastore_path",
         type=str,
         help="path to the data store",
     )
@@ -594,11 +596,11 @@ def cli(input_args=None):
     args = parser.parse_args(input_args)
 
     DatastoreClass = DATASTORES[args.datastore]
-    datastore = DatastoreClass(args.datastore_path)
+    datastore = DatastoreClass(root_path=args.datastore_path)
 
     create_graph_from_datastore(
         datastore=datastore,
-        output_root_path=os.path.join(datastore.root_path, "graphs", args.name),
+        output_root_path=os.path.join(datastore.root_path, "graph", args.name),
         n_max_levels=args.levels,
         hierarchical=args.hierarchical,
         create_plot=args.plot,
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 2a472cbf..c2c2d798 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -266,7 +266,8 @@ def get_xy_extent(self, category: str) -> List[float]:
             The extent of the x, y coordinates.
         """
         xy = self.get_xy(category, stacked=False)
-        return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+        extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+        return [float(v) for v in extent]
 
     def unstack_grid_coords(
         self, da_or_ds: Union[xr.DataArray, xr.Dataset]
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index f91faad9..a83cf31c 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -15,7 +15,7 @@ class MLLAMDatastore(BaseCartesianDatastore):
     """Datastore class for the MLLAM dataset."""
 
-    def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
+    def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
         """Construct a new MLLAMDatastore from the configuration file at
         `config_path`. A boundary mask is created with `n_boundary_points`
         boundary points. If `reuse_existing` is True, the dataset is loaded
@@ -33,7 +33,9 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
         reuse_existing : bool
            Whether to reuse an existing dataset zarr file if it exists.
         """
-        self._config_path = Path(config_path)
+        config_filename = "data_config.yaml"
+        self._root_path = Path(root_path)
+        self._config_path = config_path = self._root_path / config_filename
         self._config = mdp.Config.from_yaml_file(config_path)
         fp_ds = self._config_path.parent / self._config_path.name.replace(
             ".yaml", ".zarr"
         )
@@ -48,7 +50,7 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
 
     @property
     def root_path(self) -> Path:
-        return Path(self._config_path.parent)
+        return self._root_path
 
     @property
     def step_length(self) -> int:
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 37993be5..1f874d6e 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -1,6 +1,7 @@
 # Standard library
 import functools
 import os
+from pathlib import Path
 
 # Third-party
 import cartopy.crs as ccrs
@@ -16,11 +17,16 @@ class MultiZarrDatastore(BaseCartesianDatastore):
     DIMS_TO_KEEP = {"time", "grid_index", "variable"}
 
-    def __init__(self, config_path):
-        self.config_path = config_path
+    def __init__(self, root_path):
+        self._root_path = Path(root_path)
+        config_path = self._root_path / "data_config.yaml"
         with open(config_path, encoding="utf-8", mode="r") as file:
             self._config = yaml.safe_load(file)
 
+    @property
+    def root_path(self):
+        return self._root_path
+
     def _normalize_path(self, path):
         # try to parse path to see if it defines a protocol, e.g. s3://
         if "://" in path or path.startswith("/"):
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
index f3fe25ca..545b4b8b 100644
--- a/neural_lam/datastore/npyfiles/config.py
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -1,10 +1,61 @@
 # Standard library
-import functools
-from pathlib import Path
+from dataclasses import dataclass
+from typing import Any, Dict, List
 
 # Third-party
-import cartopy.crs as ccrs
-import yaml
+import dataclass_wizard
+
+
+@dataclass
+class Projection:
+    """Represents the projection information for a dataset, including the type
+    of projection and its parameters. Capable of creating a cartopy.crs
+    projection object.
+
+    Attributes:
+        class_name: The class name of the projection; this should be a valid
+            cartopy.crs class.
+        kwargs: A dictionary of keyword arguments specific to the projection type.
+    """
+
+    class_name: str  # = field(metadata={'data_key': 'class'})
+    kwargs: Dict[str, Any]
+
+
+@dataclass
+class Dataset:
+    """Contains information about the dataset, including variable names, units,
+    and descriptions.
+
+    Attributes:
+        name: The name of the dataset.
+        var_names: A list of variable names in the dataset.
+        var_units: A list of units for each variable.
+        var_longnames: A list of long, descriptive names for each variable.
+        num_forcing_features: The number of forcing features in the dataset.
+    """
+
+    name: str
+    var_names: List[str]
+    var_units: List[str]
+    var_longnames: List[str]
+    num_forcing_features: int
+
+
+@dataclass
+class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
+    """Configuration for loading and processing a dataset, including dataset
+    details, grid shape, and projection information.
+
+    Attributes:
+        dataset: An instance of Dataset containing details about the dataset.
+        grid_shape_state: A list representing the shape of the grid state.
+        projection: An instance of Projection containing projection details.
+ """ + + dataset: Dataset + grid_shape_state: List[int] + projection: Projection class NpyConfig: @@ -14,48 +65,6 @@ class NpyConfig: its values as attributes. """ - def __init__(self, values): - self.values = values - - @classmethod - def from_file(cls, filepath): - """Load a configuration file.""" - if str(filepath).endswith(".yaml"): - with open(filepath, encoding="utf-8", mode="r") as file: - return cls(values=yaml.safe_load(file)) - else: - raise NotImplementedError(Path(filepath).suffix) - - def __getattr__(self, name): - child, *children = name.split(".") - - value = self.values[child] - if len(children) > 0: - return self.__class__(values=value).get(".".join(children)) - else: - if isinstance(value, dict): - return self.__class__(values=value) - else: - return value - - def __getitem__(self, key): - value = self.values[key] - if isinstance(value, dict): - return self.__class__(values=value) - return value - - def __contains__(self, key): - return key in self.values - def num_data_vars(self): """Return the number of data variables for a given key.""" return len(self.dataset.var_names) - - @functools.cached_property - def coords_projection(self): - """Return the projection.""" - proj_config = self.values["projection"] - proj_class_name = proj_config["class"] - proj_class = getattr(ccrs, proj_class_name) - proj_params = proj_config.get("kwargs", {}) - return proj_class(**proj_params) diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py index 295ef882..02365a46 100644 --- a/neural_lam/datastore/npyfiles/store.py +++ b/neural_lam/datastore/npyfiles/store.py @@ -1,9 +1,11 @@ # Standard library +import functools import re from pathlib import Path from typing import List # Third-party +import cartopy.crs as ccrs import dask import dask.array import dask.delayed @@ -15,7 +17,7 @@ # Local from ..base import BaseCartesianDatastore, CartesianGridShape -from .config import NpyConfig +from .config import NpyDatastoreConfig STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy" TOA_SW_DOWN_FLUX_FILENAME_FORMAT = ( @@ -24,6 +26,13 @@ COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy" +def _load_np(fp, add_feature_dim): + arr = np.load(fp) + if add_feature_dim: + arr = arr[..., np.newaxis] + return arr + + class NpyFilesDatastore(BaseCartesianDatastore): __doc__ = f""" Represents a dataset stored as numpy files on disk. 
The dataset is assumed @@ -133,7 +142,9 @@ def __init__( self._num_ensemble_members = 2 self._root_path = Path(root_path) - self.config = NpyConfig.from_file(self.root_path / "data_config.yaml") + self.config = NpyDatastoreConfig.from_yaml_file( + self.root_path / "data_config.yaml" + ) @property def root_path(self): @@ -157,9 +168,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: xr.DataArray The data array for the given category and split, with dimensions per category: - state: `[elapsed_forecast_time, analysis_time, grid_index, feature, + state: `[elapsed_forecast_duration, analysis_time, grid_index, feature, ensemble_member]` - forcing: `[elapsed_forecast_time, analysis_time, grid_index, feature]` + forcing: `[elapsed_forecast_duration, analysis_time, grid_index, feature]` static: `[grid_index, feature]` """ if category == "state": @@ -188,14 +199,14 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # add datetime forcing as a feature # to do this we create a forecast time variable which has the - # dimensions of (analysis_time, elapsed_forecast_time) with values + # dimensions of (analysis_time, elapsed_forecast_duration) with values # that are the actual forecast time of each time step. By calling - # .chunk({"elapsed_forecast_time": 1}) this time variable is turned + # .chunk({"elapsed_forecast_duration": 1}) this time variable is turned # into a dask array and so execution of the calculation is delayed # until the feature values are actually used. da_forecast_time = ( - da.analysis_time + da.elapsed_forecast_time - ).chunk({"elapsed_forecast_time": 1}) + da.analysis_time + da.elapsed_forecast_duration + ).chunk({"elapsed_forecast_duration": 1}) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -262,7 +273,7 @@ def _get_single_timeseries_dataarray( ------- xr.DataArray The data array for the given category and split, with dimensions - `[elapsed_forecast_time, analysis_time, grid_index, feature]` for + `[elapsed_forecast_duration, analysis_time, grid_index, feature]` for all categories of data """ assert split in ("train", "val", "test"), "Unknown dataset split" @@ -284,12 +295,12 @@ def _get_single_timeseries_dataarray( features_vary_with_analysis_time = True if features == self.get_vars_names(category="state"): filename_format = STATE_FILENAME_FORMAT - file_dims = ["elapsed_forecast_time", "y", "x", "feature"] + file_dims = ["elapsed_forecast_duration", "y", "x", "feature"] # only select one member for now file_params["member_id"] = member elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT - file_dims = ["elapsed_forecast_time", "y", "x", "feature"] + file_dims = ["elapsed_forecast_duration", "y", "x", "feature"] add_feature_dim = True elif features == ["column_water"]: filename_format = COLUMN_WATER_FILENAME_FORMAT @@ -329,7 +340,7 @@ def _get_single_timeseries_dataarray( coords = {} arr_shape = [] for d in dims: - if d == "elapsed_forecast_time": + if d == "elapsed_forecast_duration": coord_values = ( self.step_length * np.arange(self._num_timesteps) @@ -346,16 +357,12 @@ def _get_single_timeseries_dataarray( else: raise NotImplementedError(f"Dimension {d} not supported") - print(f"{d}: {len(coord_values)}") - coords[d] = coord_values if d != "analysis_time": # analysis_time varies across the different files, but not # within a single file arr_shape.append(len(coord_values)) - print(f"{features}: {dims=} {file_dims=} {arr_shape=}") - if 
features_vary_with_analysis_time: filepaths = [ fp_samples @@ -369,16 +376,11 @@ def _get_single_timeseries_dataarray( # use dask.delayed to load the numpy files, so that loading isn't # done until the data is actually needed - @dask.delayed - def _load_np(fp): - arr = np.load(fp) - if add_feature_dim: - arr = arr[..., np.newaxis] - return arr - arrays = [ dask.array.from_delayed( - _load_np(fp), shape=arr_shape, dtype=np.float32 + dask.delayed(_load_np)(fp=fp, add_feature_dim=add_feature_dim), + shape=arr_shape, + dtype=np.float32, ) for fp in filepaths ] @@ -457,7 +459,7 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray): def get_vars_units(self, category: str) -> torch.List[str]: if category == "state": - return self.config["dataset"]["var_units"] + return self.config.dataset.var_units elif category == "forcing": return [ "W/m^2", @@ -474,7 +476,7 @@ def get_vars_units(self, category: str) -> torch.List[str]: def get_vars_names(self, category: str) -> torch.List[str]: if category == "state": - return self.config["dataset"]["var_names"] + return self.config.dataset.var_names elif category == "forcing": # XXX: this really shouldn't be hard-coded here, this should be in # the config @@ -557,7 +559,7 @@ def boundary_mask(self): da_mask = xr.DataArray( values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask) + da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) return da_mask_stacked_xy def get_normalization_dataarray(self, category: str) -> xr.Dataset: @@ -623,3 +625,11 @@ def load_pickled_tensor(fn): ) return ds_norm + + @functools.cached_property + def coords_projection(self): + """Return the projection.""" + proj_class_name = self.config.projection.class_name + ProjectionClass = getattr(ccrs, proj_class_name) + proj_params = self.config.projection.kwargs + return ProjectionClass(**proj_params) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 3ea86716..39f0cbdf 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -7,7 +7,6 @@ # Third-party import pytorch_lightning as pl import torch -import wandb from lightning_fabric.utilities import seed # Local @@ -27,13 +26,13 @@ } -def _init_datastore(datastore_kind, data_config): +def _init_datastore(datastore_kind, path): if datastore_kind == "multizarr": - datastore = MultiZarrDatastore(data_config) + datastore = MultiZarrDatastore(root_path=path) elif datastore_kind == "npyfiles": - datastore = NpyFilesDatastore(data_config) + datastore = NpyFilesDatastore(root_path=path) elif datastore_kind == "mllam": - datastore = MLLAMDatastore(data_config) + datastore = MLLAMDatastore(root_path=path) else: raise ValueError(f"Unknown datastore kind: {datastore_kind}") return datastore @@ -52,10 +51,10 @@ def main(input_args=None): help="Kind of datastore to use (default: multizarr)", ) parser.add_argument( - "--datastore-config", + "--datastore-path", type=str, - default="tests/datastore_configs/multizarr/data_config.yaml", - help="Path to data config file", + default="tests/datastore_configs/multizarr", + help="The root path for the datastore", ) parser.add_argument( "--model", @@ -248,7 +247,9 @@ def main(input_args=None): # Set seed seed.seed_everything(args.seed) # Create datastore - datastore = _init_datastore(args.datastore_kind, args.datastore_config) + datastore = _init_datastore( + datastore_kind=args.datastore_kind, path=args.datastore_path + ) # Create datamodule data_module = WeatherDataModule( 
datastore=datastore, @@ -303,6 +304,7 @@ def main(input_args=None): callbacks=[checkpoint_callback], check_val_every_n_epoch=args.val_interval, precision=args.precision, + devices=1, ) # Only init once, on rank 0 only @@ -310,7 +312,8 @@ def main(input_args=None): utils.init_wandb_metrics( logger, val_steps=args.val_steps_to_log ) # Do after wandb.init - wandb.save(args.datastore_config) + # TODO: should we save the datastore config here? + # wandb.save() if args.eval: trainer.test(model=model, datamodule=data_module, ckpt_path=args.load) else: diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 05607f8f..ceae3663 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -71,7 +71,7 @@ def __len__(self): UserWarning, ) # XXX: we should maybe check that the 2+ar_steps actually fits - # in the elapsed_forecast_time dimension, should that be checked here? + # in the elapsed_forecast_duration dimension, should that be checked here? return self.da_state.analysis_time.size else: # sample_len = 2 + ar_steps <-- 2 initial states + ar_steps target states @@ -93,7 +93,7 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0): da : xr.DataArray The dataarray to sample from. This is expected to have a `time` dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_time` dimensions if the + `analysis_time` and `elapsed_forecast_duration` dimensions if the datastore is providing forecast data. idx : int The index of the time step to start the sample from. @@ -103,19 +103,19 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0): # selecting the time slice if self.datastore.is_forecast: # this implies that the data will have both `analysis_time` and - # `elapsed_forecast_time` dimensions for forecasts we for now + # `elapsed_forecast_duration` dimensions for forecasts we for now # simply select a analysis time and then the next ar_steps forecast # times da = da.isel( analysis_time=idx, - elapsed_forecast_time=slice( + elapsed_forecast_duration=slice( n_timesteps_offset, n_steps + n_timesteps_offset ), ) # create a new time dimension so that the produced sample has a # `time` dimension, similarly to the analysis only data - da["time"] = da.analysis_time + da.elapsed_forecast_time - da = da.swap_dims({"elapsed_forecast_time": "time"}) + da["time"] = da.analysis_time + da.elapsed_forecast_duration + da = da.swap_dims({"elapsed_forecast_duration": "time"}) else: # only `time` dimension for analysis only data da = da.isel( diff --git a/pyproject.toml b/pyproject.toml index 2681831d..075b0146 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "torch-geometric==2.3.1", "mllam-data-prep @ git+https://github.com/mllam/mllam-data-prep", "parse>=1.20.2", + "dataclass-wizard>=0.22.3", ] requires-python = ">=3.9" diff --git a/tests/conftest.py b/tests/conftest.py index 1c1cdd3e..9ff25a91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ # Third-party import pooch +import yaml # First-party from neural_lam.datastore.mllam import MLLAMDatastore @@ -32,7 +33,10 @@ def download_meps_example_reduced_dataset(): # Download and unzip test data into data/meps_example_reduced - root_path = Path("tests/datastores_examples/npy") + root_path = Path("tests/datastore_configs/npy") + dataset_path = root_path / "meps_example_reduced" + will_download = not dataset_path.exists() + pooch.retrieve( url=S3_FULL_PATH, 
         known_hash=TEST_DATA_KNOWN_HASH,
         processor=pooch.Unzip(extract_dir=""),
         path=root_path,
         fname="meps_example_reduced.zip",
     )
-    return root_path / "meps_example_reduced"
+
+    if will_download:
+        # XXX: should update the dataset stored on S3 to include the change below
+        config_path = dataset_path / "data_config.yaml"
+        # rename the `class` key in the `projection` section to `class_name`
+        with open(config_path, "r") as f:
+            config = yaml.safe_load(f)
+        config["projection"]["class_name"] = config["projection"].pop("class")
+        with open(config_path, "w") as f:
+            yaml.dump(config, f)
+
+    return dataset_path
 
 
 DATASTORES_EXAMPLES = dict(
-    multizarr=dict(
-        config_path="tests/datastore_configs/multizarr/data_config.yaml"
-    ),
-    mllam=dict(config_path="tests/datastore_configs/mllam/example.danra.yaml"),
+    multizarr=dict(root_path="tests/datastore_configs/multizarr"),
+    mllam=dict(root_path="tests/datastore_configs/mllam"),
     npyfiles=dict(root_path=download_meps_example_reduced_dataset()),
 )
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index abb41e92..bd378e98 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -150,6 +150,10 @@ def test_get_dataarray(datastore_name):
             "elapsed_forecast_duration",
         ]
 
+        if datastore.is_ensemble and category == "state":
+            # assume that only state variables change with ensemble members
+            expected_dims.append("ensemble_member")
+
         # XXX: for now we only have a single attribute to get the shape of
         # the grid which uses the shape from the "state" category, maybe
         # this should change?
diff --git a/tests/test_npy_forecast_dataset.py b/tests/test_npy_forecast_dataset.py
deleted file mode 100644
index 571565dd..00000000
--- a/tests/test_npy_forecast_dataset.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# Standard library
-import os
-
-# Third-party
-import pooch
-import pytest
-
-# First-party
-from neural_lam.create_graph import create_graph as create_graph
-from neural_lam.datastore.npyfiles import NpyFilesDatastore
-from neural_lam.train_model import main as train_model
-from neural_lam.weather_dataset import WeatherDataset
-
-# Disable weights and biases to avoid unnecessary logging
-# and to avoid having to deal with authentication
-os.environ["WANDB_DISABLED"] = "true"
-
-# Initializing variables for the s3 client
-S3_BUCKET_NAME = "mllam-testdata"
-S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
-S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip"
-S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH])
-TEST_DATA_KNOWN_HASH = (
-    "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7"
-)
-
-
-@pytest.fixture(scope="session")
-def ewc_testdata_path():
-    # Download and unzip test data into data/meps_example_reduced
-    pooch.retrieve(
-        url=S3_FULL_PATH,
-        known_hash=TEST_DATA_KNOWN_HASH,
-        processor=pooch.Unzip(extract_dir=""),
-        path="data",
-        fname="meps_example_reduced.zip",
-    )
-
-    return "data/meps_example_reduced"
-
-
-def test_load_reduced_meps_dataset(ewc_testdata_path):
-    datastore = NpyFilesDatastore(root_path=ewc_testdata_path)
-    datastore.get_xy(category="state", stacked=True)
-
-    datastore.get_dataarray(category="forcing", split="train").unstack(
-        "grid_index"
-    )
-    datastore.get_dataarray(category="state", split="train").unstack(
-        "grid_index"
-    )
-
-    dataset = WeatherDataset(datastore=datastore)
-
-    var_names = datastore.config.values["dataset"]["var_names"]
-    var_units = datastore.config.values["dataset"]["var_units"]
-    var_longnames = 
datastore.config.values["dataset"]["var_longnames"] - - assert len(var_names) == len(var_longnames) - assert len(var_names) == len(var_units) - - # in future the number of grid static features - # will be provided by the Dataset class itself - n_grid_static_features = 4 - # Hardcoded in model - n_input_steps = 2 - - n_forcing_features = datastore.config.values["dataset"][ - "num_forcing_features" - ] - n_state_features = len(var_names) - n_prediction_timesteps = dataset.ar_steps - - nx, ny = datastore.config.values["grid_shape_state"] - n_grid = nx * ny - - # check that the dataset is not empty - assert len(dataset) > 0 - - # get the first item - item = dataset[0] - init_states = item.init_states - target_states = item.target_states - forcing = item.forcing - - # check that the shapes of the tensors are correct - assert init_states.shape == (n_input_steps, n_grid, n_state_features) - assert target_states.shape == ( - n_prediction_timesteps, - n_grid, - n_state_features, - ) - assert forcing.shape == ( - n_prediction_timesteps, - n_grid, - n_forcing_features, - ) - - ds_state_norm = datastore.get_normalization_dataarray(category="state") - - static_data = { - "border_mask": datastore.boundary_mask.values, - "grid_static_features": datastore.get_dataarray( - category="static", split="train" - ).values, - "data_mean": ds_state_norm.state_mean.values, - "data_std": ds_state_norm.state_std.values, - "step_diff_mean": ds_state_norm.state_diff_mean.values, - "step_diff_std": ds_state_norm.state_diff_std.values, - } - - required_props = { - "border_mask", - "grid_static_features", - "step_diff_mean", - "step_diff_std", - "data_mean", - "data_std", - "param_weights", - } - - # check the sizes of the props - assert static_data["border_mask"].shape == (n_grid,) - assert static_data["grid_static_features"].shape == ( - n_grid, - n_grid_static_features, - ) - assert static_data["step_diff_mean"].shape == (n_state_features,) - assert static_data["step_diff_std"].shape == (n_state_features,) - assert static_data["data_mean"].shape == (n_state_features,) - assert static_data["data_std"].shape == (n_state_features,) - assert static_data["param_weights"].shape == (n_state_features,) - - assert set(static_data.keys()) == required_props - - -def test_create_graph_reduced_meps_dataset(): - args = [ - "--graph=hierarchical", - "--hierarchical=1", - "--data_config=data/meps_example_reduced/data_config.yaml", - "--levels=2", - ] - create_graph(args) - - -def test_train_model_reduced_meps_dataset(): - args = [ - "--model=hi_lam", - "--data_config=data/meps_example_reduced/data_config.yaml", - "--n_workers=4", - "--epochs=1", - "--graph=hierarchical", - "--hidden_dim=16", - "--hidden_layers=1", - "--processor_layers=1", - "--ar_steps=1", - "--eval=val", - "--n_example_pred=0", - ] - train_model(args) diff --git a/tests/test_training.py b/tests/test_training.py index 3767fbc0..5e7f4095 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -80,3 +80,20 @@ class ModelArgs: ) wandb.init() trainer.fit(model=model, datamodule=data_module) + + +# def test_train_model_reduced_meps_dataset(): +# args = [ +# "--model=hi_lam", +# "--data_config=data/meps_example_reduced/data_config.yaml", +# "--n_workers=4", +# "--epochs=1", +# "--graph=hierarchical", +# "--hidden_dim=16", +# "--hidden_layers=1", +# "--processor_layers=1", +# "--ar_steps=1", +# "--eval=val", +# "--n_example_pred=0", +# ] +# train_model(args)
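
For illustration, here is a minimal usage sketch of the datastore API after this patch (not part of the patch itself): datastores are now constructed from a root directory containing `data_config.yaml` rather than from a config-file path. The dataset path below is an assumption based on the test fixtures in this patch; any local copy of the reduced MEPS example works the same way.

    # minimal sketch, assuming the reduced MEPS example dataset has been
    # downloaded to the path below (e.g. by the conftest fixture above)
    from neural_lam.datastore.npyfiles import NpyFilesDatastore

    datastore = NpyFilesDatastore(
        root_path="tests/datastore_configs/npy/meps_example_reduced"
    )

    # the YAML config is parsed into typed dataclasses via dataclass-wizard
    print(datastore.config.dataset.var_names)
    # cartopy CRS built from the `projection` section of the config
    print(datastore.coords_projection)

    # lazily loaded (dask-backed) state data for the training split
    da_state = datastore.get_dataarray(category="state", split="train")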