diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 81c2f720..ccf20be4 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,11 +9,13 @@
 import pytorch_lightning as pl
 import torch
 import xarray as xr
+from loguru import logger

 # Local
 from .. import metrics, vis
 from ..config import NeuralLAMConfig
 from ..datastore import BaseDatastore
+from ..datastore.base import BaseRegularGridDatastore
 from ..loss_weighting import get_state_feature_weighting
 from ..weather_dataset import WeatherDataset

@@ -368,6 +370,70 @@ def on_validation_epoch_end(self):
         for metric_list in self.val_metrics.values():
             metric_list.clear()

+    def _save_predictions_to_zarr(
+        self,
+        batch_times: torch.Tensor,
+        batch_predictions: torch.Tensor,
+        batch_idx: int,
+        zarr_output_path: str,
+    ):
+        """
+        Save state predictions for a single batch to a zarr dataset,
+        appending to the existing dataset for batch_idx > 0. The resulting
+        dataset contains a variable named `state` with coordinates
+        (start_time, elapsed_forecast_duration, grid_index, state_feature).
+
+        Parameters
+        ----------
+        batch_times : torch.Tensor[int]
+            The times for the batch, given as epoch time in nanoseconds.
+            Shape is (B, args.pred_steps) where B is the batch size and
+            args.pred_steps is the number of prediction steps.
+        batch_predictions : torch.Tensor[float]
+            The predictions for the batch, given as (B, args.pred_steps,
+            num_grid_nodes, d_f) where B is the batch size, args.pred_steps
+            is the number of prediction steps, num_grid_nodes is the number
+            of grid nodes, and d_f is the number of state features.
+        batch_idx : int
+            The index of the batch in the current epoch.
+        zarr_output_path : str
+            Path of the zarr dataset to write to (batch_idx == 0) or append
+            to (batch_idx > 0).
+        """
+        batch_size = batch_predictions.shape[0]
+        # Convert predictions to DataArrays using
+        # _create_dataarray_from_tensor
+        das_pred = []
+        for i in range(len(batch_times)):
+            da_pred = self._create_dataarray_from_tensor(
+                tensor=batch_predictions[i],
+                time=batch_times[i],
+                split="test",
+                category="state",
+            )
+            # Unstack grid coords if necessary; this also avoids having to
+            # store a MultiIndex zarr dataset, which xarray does not support
+            if isinstance(self._datastore, BaseRegularGridDatastore):
+                da_pred = self._datastore.unstack_grid_coords(da_pred)
+
+            t0 = da_pred.coords["time"].values[0]
+            da_pred.coords["start_time"] = t0
+            da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
+            da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"})
+            da_pred.name = "state"
+            das_pred.append(da_pred)
+
+        da_pred_batch = xr.concat(das_pred, dim="start_time")
+
+        # Apply chunking along start_time so that each batch is saved as a
+        # separate chunk
+        da_pred_batch = da_pred_batch.chunk({"start_time": batch_size})
+
+        if batch_idx == 0:
+            logger.info(f"Saving predictions to {zarr_output_path}")
+            da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
+        else:
+            da_pred_batch.to_zarr(
+                zarr_output_path, mode="a", append_dim="start_time"
+            )
+
     # pylint: disable-next=unused-argument
     def test_step(self, batch, batch_idx):
         """
@@ -375,8 +441,8 @@ def test_step(self, batch, batch_idx):
         """
         # TODO Here batch_times can be used for plotting routines
         prediction, target, pred_std, batch_times = self.common_step(batch)
-        # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
-        # pred_steps, num_grid_nodes, d_f) or (d_f,)
+        # prediction: (B, pred_steps, num_grid_nodes, d_f)
+        # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)

         time_step_loss = torch.mean(
             self.loss(
@@ -432,6 +498,14 @@ def test_step(self, batch, batch_idx):
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)

+        if self.args.save_eval_to_zarr_path:
+            self._save_predictions_to_zarr(
+                batch_times=batch_times,
+                batch_predictions=prediction,
+                batch_idx=batch_idx,
+                zarr_output_path=self.args.save_eval_to_zarr_path,
+            )
+
         # Plot example predictions (on rank 0 only)
         if (
             self.trainer.is_global_zero
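For context, the write-then-append pattern used by `_save_predictions_to_zarr` (batch 0 creates the store with `mode="w"`, later batches append along `start_time`, one chunk per batch) can be reproduced in isolation. The sketch below is not part of the diff: the data is synthetic, `example_state.zarr` is a made-up path, and dask is assumed to be available for the `.chunk()` call; only the dimension and variable names mirror the method's docstring.

```python
# Minimal standalone sketch of the write/append-to-zarr pattern above.
# Synthetic data; "example_state.zarr" is a hypothetical output path.
import numpy as np
import pandas as pd
import xarray as xr

path = "example_state.zarr"
for batch_idx in range(3):
    # two forecast start times per "batch"
    start_times = pd.date_range(
        "2020-01-01", periods=2, freq="6h"
    ) + pd.Timedelta(hours=12 * batch_idx)
    da = xr.DataArray(
        np.random.rand(2, 4, 10, 3),
        dims=(
            "start_time",
            "elapsed_forecast_duration",
            "grid_index",
            "state_feature",
        ),
        coords={"start_time": start_times},
        name="state",
    ).chunk({"start_time": 2})  # one chunk per batch, as in the method
    if batch_idx == 0:
        da.to_zarr(path, mode="w", consolidated=True)
    else:
        da.to_zarr(path, mode="a", append_dim="start_time")

# start_time now has length 6 (3 batches of 2 forecasts each)
print(xr.open_zarr(path)["state"].sizes)
```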
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index e8b402d5..47361ae3 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -24,7 +24,7 @@
 }


-@logger.catch
+@logger.catch(reraise=True)
 def main(input_args=None):
     """Main function for training and evaluating models."""
     parser = ArgumentParser(
@@ -166,6 +166,11 @@ def main(input_args=None):
         help="Eval model on given data split (val/test) "
         "(default: None (train model))",
     )
+    parser.add_argument(
+        "--save_eval_to_zarr_path",
+        type=str,
+        help="Save evaluation results to a zarr dataset at the given path",
+    )
     parser.add_argument(
         "--ar_steps_eval",
         type=int,
diff --git a/pyproject.toml b/pyproject.toml
index 09ff3a67..533de9e4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,15 +29,20 @@ dependencies = [
     "torch-geometric==2.3.1",
     "parse>=1.20.2",
     "dataclass-wizard<0.31.0",
-    "mllam-data-prep>=0.5.0",
     "mlflow>=2.16.2",
     "boto3>=1.35.32",
     "pynvml>=12.0.0",
+    "mllam-data-prep>=0.5.0",
 ]
 requires-python = ">=3.9"

 [project.optional-dependencies]
-dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"]
+dev = [
+    "pre-commit>=3.8.0",
+    "pytest>=8.3.2",
+    "pooch>=1.8.2",
+    "pytest-dependency>=0.6.0",
+]

 [tool.setuptools]
 py-modules = ["neural_lam"]
diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 00000000..c1c81bc8
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,64 @@
+# Third-party
+import pytest
+
+# First-party
+from neural_lam.train_model import main as train_model_main
+from tests.conftest import init_datastore_example
+
+
+@pytest.mark.dependency(depends=["test_training"])
+def test_inference(request):
+    """
+    Run inference on a trained model and save the results to a zarr dataset
+    through the command line interface.
+
+    NB: This test will need refactoring once we clean up the command line
+    interface.
+    """
+    datastore = init_datastore_example("mdp")
+
+    # NB: this is brittle and should be refactored when the command line
+    # interface is cleaned up so that tests point to neural-lam config files
+    # rather than datastore config files
+    nl_config_path = datastore.root_path / "config.yaml"
+
+    # fetch the path to the trained model that was saved by the training test
+    model_path = request.config.cache.get("model_checkpoint_path", None)
+    if model_path is None:
+        raise Exception("training test must be run first")
+
+    args = [
+        "--config_path",
+        nl_config_path,
+        "--model",
+        "graph_lam",
+        "--eval",
+        "test",
+        "--load",
+        model_path,
+        "--hidden_dim",
+        "4",
+        "--hidden_layers",
+        "1",
+        "--processor_layers",
+        "2",
+        "--mesh_aggr",
+        "sum",
+        "--lr",
+        "1.0e-3",
+        "--val_steps_to_log",
+        "1",
+        "3",
+        "--num_past_forcing_steps",
+        "1",
+        "--num_future_forcing_steps",
+        "1",
+        "--n_example_pred",
+        "1",
+        "--graph",
+        "1level",
+        "--save_eval_to_zarr_path",
+        "state_test.zarr",
+    ]
+
+    train_model_main(args)
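The test above only checks that the evaluation run completes; it does not yet inspect the zarr output. A possible follow-up check, not part of the diff, assuming the store layout described in `_save_predictions_to_zarr` and the relative `state_test.zarr` path passed above, could look like this:

```python
# Hypothetical assertions that could be appended after train_model_main(args).
# The store layout follows the _save_predictions_to_zarr docstring: a "state"
# variable appended along "start_time".
import xarray as xr

ds = xr.open_zarr("state_test.zarr")
assert "state" in ds
assert "start_time" in ds["state"].dims
assert "elapsed_forecast_duration" in ds["state"].dims

# select the first prediction step of the first forecast
first_step = ds["state"].isel(start_time=0, elapsed_forecast_duration=0)
print(first_step.sizes)
```

The check deliberately only looks at the time dimensions, since for regular-grid datastores the grid dimension is unstacked into spatial coordinates rather than kept as `grid_index`.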
diff --git a/tests/test_training.py b/tests/test_training.py
index 1ed1847d..b0aec806 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -17,8 +17,9 @@
 from tests.conftest import init_datastore_example


+@pytest.mark.dependency()
 @pytest.mark.parametrize("datastore_name", DATASTORES.keys())
-def test_training(datastore_name):
+def test_training(datastore_name, request):
     datastore = init_datastore_example(datastore_name)

     if not isinstance(datastore, BaseRegularGridDatastore):
@@ -101,3 +102,10 @@ class ModelArgs:
     )
     wandb.init()
     trainer.fit(model=model, datamodule=data_module)
+
+    # save the path to the model checkpoint into the request object so we
+    # can use it in the inference test
+    request.config.cache.set(
+        "model_checkpoint_path",
+        model.trainer.checkpoint_callback.best_model_path,
+    )
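The hand-off between `test_training` and `test_inference` relies on pytest's built-in cache together with pytest-dependency (added to the dev extras above). Below is a stripped-down illustration of that mechanism, not from the PR, with a made-up cache key and value:

```python
# Minimal illustration of the pytest cache hand-off used above: the producer
# stores a value under a key, the dependent consumer reads it back later.
# "example/ckpt_path" and the dummy value are invented for this sketch.
import pytest


@pytest.mark.dependency()
def test_producer(request):
    request.config.cache.set("example/ckpt_path", "/tmp/best.ckpt")


@pytest.mark.dependency(depends=["test_producer"])
def test_consumer(request):
    # value written by the producer in this session (or a previous run)
    ckpt_path = request.config.cache.get("example/ckpt_path", None)
    assert ckpt_path is not None
```

Note that pytest-dependency only skips the dependent test when its dependency failed or did not run; it does not reorder tests. The cache itself is stored in `.pytest_cache` and persists across runs.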