From 64d43a61288acbf16765d50d7d99ed2a6f299983 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 10:01:20 +0000
Subject: [PATCH] training working with mllam datastore!

---
 .gitignore                    |  1 +
 neural_lam/models/ar_model.py | 19 ++++----
 neural_lam/vis.py             | 32 +++++++++-----
 tests/conftest.py             |  6 +++
 tests/test_training.py        | 82 +++++++++++++++++++++++++++++++++++
 5 files changed, 119 insertions(+), 21 deletions(-)
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_training.py

diff --git a/.gitignore b/.gitignore
index f5faeb52..8cd4e45d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,6 +82,7 @@ tags
 
 # pdm (https://pdm-project.org/en/stable/)
 .pdm-python
+.venv
 
 # exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm
 pdm.lock

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 59ca1fdc..d18c89ab 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -28,6 +28,7 @@ def __init__(
         super().__init__()
         self.save_hyperparameters()
         self.args = args
+        self._datastore = datastore  # XXX: should this be somewhere else?
 
         split = "train"
         num_state_vars = datastore.get_num_data_vars(category="state")
@@ -429,18 +430,18 @@ def plot_examples(self, batch, n_examples, prediction=None):
             # Create one figure per variable at this time step
             var_figs = [
                 vis.plot_prediction(
-                    pred_t[:, var_i],
-                    target_t[:, var_i],
-                    self.interior_mask[:, 0],
-                    self.datastore,
+                    pred=pred_t[:, var_i],
+                    target=target_t[:, var_i],
+                    obs_mask=self.interior_mask[:, 0],
+                    datastore=self._datastore,
                     title=f"{var_name} ({var_unit}), "
                     f"t={t_i} ({self.step_length * t_i} h)",
                     vrange=var_vrange,
                 )
                 for var_i, (var_name, var_unit, var_vrange) in enumerate(
                     zip(
-                        self.data_config.vars_names("state"),
-                        self.data_config.vars_units("state"),
+                        self._datastore.get_vars_names("state"),
+                        self._datastore.get_vars_units("state"),
                         var_vranges,
                     )
                 )
@@ -451,7 +452,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
                 {
                     f"{var_name}_example_{example_i}": wandb.Image(fig)
                     for var_name, fig in zip(
-                        self.data_config.vars_names("state"), var_figs
+                        self._datastore.get_vars_names("state"), var_figs
                     )
                 }
             )
@@ -485,7 +486,9 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
         """
         log_dict = {}
         metric_fig = vis.plot_error_map(
-            metric_tensor, self.data_config, step_length=self.step_length
+            errors=metric_tensor,
+            datastore=self._datastore,
+            step_length=self.step_length,
         )
         full_log_name = f"{prefix}_{metric_name}"
         log_dict[full_log_name] = wandb.Image(metric_fig)

diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 1edf71e9..98e066c4 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -5,10 +5,13 @@
 # Local
 from . import utils
+from .datastore.base import BaseCartesianDatastore
 
 
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, data_config, title=None, step_length=1):
+def plot_error_map(
+    errors, datastore: BaseCartesianDatastore, title=None, step_length=1
+):
     """
     Plot a heatmap of errors of different variables at different
     prediction horizons
 
@@ -48,11 +51,10 @@
     ax.set_xlabel("Lead time (h)", size=label_size)
 
     ax.set_yticks(np.arange(d_f))
+    var_names = datastore.get_vars_names(category="state")
+    var_units = datastore.get_vars_units(category="state")
     y_ticklabels = [
-        f"{name} ({unit})"
-        for name, unit in zip(
-            data_config.vars_names("state"), data_config.vars_units("state")
-        )
+        f"{name} ({unit})" for name, unit in zip(var_names, var_units)
     ]
     ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
 
@@ -64,7 +66,12 @@
 
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
 def plot_prediction(
-    pred, target, obs_mask, data_config, title=None, vrange=None
+    pred,
+    target,
+    obs_mask,
+    datastore: BaseCartesianDatastore,
+    title=None,
+    vrange=None,
 ):
     """Plot example prediction and ground truth.
 
@@ -77,12 +84,11 @@
     else:
         vmin, vmax = vrange
 
-    extent = data_config.get_xy_extent("state")
+    extent = datastore.get_xy_extent("state")
 
     # Set up masking of border region
-    mask_reshaped = obs_mask.reshape(
-        list(data_config.grid_shape_state.values.values())
-    )
+    da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
+    mask_reshaped = da_mask.values
     pixel_alpha = (
-        mask_reshaped.clamp(0.7, 1).cpu().numpy()
+        mask_reshaped.clip(0.7, 1)  # da_mask.values is a numpy array
     )  # Faded border region
@@ -91,14 +97,14 @@
         1,
         2,
         figsize=(13, 7),
-        subplot_kw={"projection": data_config.coords_projection},
+        subplot_kw={"projection": datastore.coords_projection},
     )
 
     # Plot pred and target
     for ax, data in zip(axes, (target, pred)):
         ax.coastlines()  # Add coastline outlines
         data_grid = (
-            data.reshape(list(data_config.grid_shape_state.values.values()))
+            data.reshape(list(datastore.grid_shape_state.values.values()))
             .cpu()
             .numpy()
         )

diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..0ec7f4b0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,6 @@
+# Standard library
+import os
+
+# Disable weights and biases to avoid unnecessary logging
+# and to avoid having to deal with authentication
+os.environ["WANDB_DISABLED"] = "true"

diff --git a/tests/test_training.py b/tests/test_training.py
new file mode 100644
index 00000000..3767fbc0
--- /dev/null
+++ b/tests/test_training.py
@@ -0,0 +1,82 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import pytest
+import pytorch_lightning as pl
+import torch
+import wandb
+from test_datastores import DATASTORES, init_datastore
+
+# First-party
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.models.graph_lam import GraphLAM
+from neural_lam.weather_dataset import WeatherDataModule
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_training(datastore_name):
+    datastore = init_datastore(datastore_name)
+
+    if torch.cuda.is_available():
+        device_name = "cuda"
+        torch.set_float32_matmul_precision(
+            "high"
+        )  # Allows using Tensor Cores on A100s
+    else:
+        device_name = "cpu"
+
+    trainer = pl.Trainer(
+        max_epochs=3,
+        deterministic=True,
+        strategy="ddp",
+        accelerator=device_name,
+        log_every_n_steps=1,
+    )
+
+    graph_name = "1level"
+
+    graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+    if not graph_dir_path.exists():
+        create_graph_from_datastore(
+            datastore=datastore,
+            output_root_path=str(graph_dir_path),
+            n_max_levels=1,
+        )
+
+    data_module = WeatherDataModule(
+        datastore=datastore,
+        ar_steps_train=3,
+        ar_steps_eval=5,
+        standardize=True,
+        batch_size=2,
+        num_workers=1,
+        forcing_window_size=3,
+    )
+
+    class ModelArgs:
+        output_std = False
+        loss = "mse"
+        restore_opt = False
+        n_example_pred = 1
+        # XXX: this should be superfluous when we have already defined the
+        # model object, no?
+        graph = graph_name
+        hidden_dim = 8
+        hidden_layers = 1
+        processor_layers = 4
+        mesh_aggr = "sum"
+        lr = 1.0e-3
+        val_steps_to_log = [1]
+        metrics_watch = []
+
+    model_args = ModelArgs()
+
+    model = GraphLAM(  # noqa
+        args=model_args,
+        forcing_window_size=data_module.forcing_window_size,
+        datastore=datastore,
+    )
+    wandb.init()
+    trainer.fit(model=model, datamodule=data_module)
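
Usage note: tests/conftest.py sets WANDB_DISABLED before wandb is imported, so the `wandb.init()` call in `test_training` runs without any wandb authentication. As a minimal sketch (not part of the patch itself), the new test can be run for a single datastore via pytest's `-k` selection; the id "mllam" below is an assumption based on the commit subject and must match a key in DATASTORES from tests/test_datastores.py:

    # Standard library
    import os

    # Third-party
    import pytest

    # Mirror tests/conftest.py so no wandb login is required
    os.environ["WANDB_DISABLED"] = "true"

    # Run only the training test, selecting the parametrization whose id
    # contains "mllam"; "-x" stops at the first failure
    raise SystemExit(pytest.main(["tests/test_training.py", "-k", "mllam", "-x"]))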