diff --git a/neural_lam/config.py b/neural_lam/config.py index 819ce2aa..048241df 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -1,6 +1,7 @@ - +# Standard library import os +# Third-party import cartopy.crs as ccrs import numpy as np import xarray as xr @@ -86,7 +87,10 @@ def open_zarr(self, dataset_name): """Open a dataset specified by the dataset name.""" dataset_path = self.zarrs[dataset_name].path if dataset_path is None or not os.path.exists(dataset_path): - print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") + print( + f"Dataset '{dataset_name}' " + f"not found at path: {dataset_path}" + ) return None dataset = xr.open_zarr(dataset_path, consolidated=True) return dataset diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index fff28632..fc78e638 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -6,21 +6,19 @@ import numpy as np import pytorch_lightning as pl import torch - import wandb # First-party -from neural_lam import metrics, vis +from neural_lam import config, metrics, vis class ARModel(pl.LightningModule): """ - Generic auto-regressive weather model. - Abstract class that can be extended. + Generic auto-regressive weather model. Abstract class that can be extended. """ - # pylint: disable=arguments-differ - # Disable to override args/kwargs from superclass + # pylint: disable=arguments-differ Disable to override args/kwargs from + # superclass def __init__(self, args): super().__init__() @@ -127,18 +125,18 @@ def expand_to_batch(x, batch_size): def predict_step(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t - prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} - forcing: (B, num_grid_nodes, forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, + num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, + forcing_dim) """ raise NotImplementedError("No prediction step implemented") def unroll_prediction(self, init_states, forcing_features, true_states): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) - forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) - true_states: (B, pred_steps, num_grid_nodes, d_f) + init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, + pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, + num_grid_nodes, d_f) """ prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] @@ -153,8 +151,8 @@ def unroll_prediction(self, init_states, forcing_features, true_states): pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing ) - # state: (B, num_grid_nodes, d_f) - # pred_std: (B, num_grid_nodes, d_f) or None + # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, + # d_f) or None # Overwrite border with true state new_state = ( @@ -184,11 +182,10 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch - batch consists of: - init_states: (B, 2, num_grid_nodes, d_features) - target_states: (B, pred_steps, num_grid_nodes, d_features) - forcing_features: (B, pred_steps, num_grid_nodes, d_forcing), + Predict on single batch batch consists of: init_states: (B, 2, + num_grid_nodes, d_features) target_states: (B, pred_steps, + num_grid_nodes, d_features) forcing_features: (B, pred_steps, + num_grid_nodes, d_forcing), where index 0 corresponds to index 1 of init_states """ ( @@ -200,8 +197,8 @@ def common_step(self, batch): prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) - # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, + # pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std @@ -214,9 +211,8 @@ def on_after_batch_transfer(self, batch, dataloader_idx): forcing_features = ( forcing_features - self.forcing_mean ) / self.forcing_std - # boundary_features = ( - # boundary_features - self.boundary_mean - # ) / self.boundary_std + # boundary_features = ( boundary_features - self.boundary_mean ) / + # self.boundary_std batch = ( init_states, target_states, @@ -246,8 +242,8 @@ def training_step(self, batch): def all_gather_cat(self, tensor_to_gather): """ - Gather tensors across all ranks, and concatenate across dim. 0 - (instead of stacking in new dim. 0) + Gather tensors across all ranks, and concatenate across dim. 0 (instead + of stacking in new dim. 0) tensor_to_gather: (d1, d2, ...), distributed over K ranks @@ -308,8 +304,8 @@ def test_step(self, batch, batch_idx): Run test on single batch """ prediction, target, pred_std = self.common_step(batch) - # prediction: (B, pred_steps, num_grid_nodes, d_f) - # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, + # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( self.loss( @@ -330,10 +326,9 @@ def test_step(self, batch, batch_idx): test_log_dict, on_step=False, on_epoch=True, sync_dist=True ) - # Compute all evaluation metrics for error maps - # Note: explicitly list metrics here, as test_metrics can contain - # additional ones, computed differently, but that should be aggregated - # on_test_epoch_end + # Compute all evaluation metrics for error maps Note: explicitly list + # metrics here, as test_metrics can contain additional ones, computed + # differently, but that should be aggregated on_test_epoch_end for metric_name in ("mse", "mae"): metric_func = metrics.get_metric(metric_name) batch_metric_vals = metric_func( @@ -378,9 +373,9 @@ def plot_examples(self, batch, n_examples, prediction=None): """ Plot the first n_examples forecasts from batch - batch: batch with data to plot corresponding forecasts for - n_examples: number of forecasts to plot - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. + batch: batch with data to plot corresponding forecasts for n_examples: + number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes, + d_f), existing prediction. Generate if None. """ if prediction is None: @@ -470,15 +465,14 @@ def plot_examples(self, batch, n_examples, prediction=None): def create_metric_log_dict(self, metric_tensor, prefix, metric_name): """ - Put together a dict with everything to log for one metric. - Also saves plots as pdf and csv if using test prefix. + Put together a dict with everything to log for one metric. Also saves + plots as pdf and csv if using test prefix. metric_tensor: (pred_steps, d_f), metric values per time and variable - prefix: string, prefix to use for logging - metric_name: string, name of the metric + prefix: string, prefix to use for logging metric_name: string, name of + the metric - Return: - log_dict: dict with everything to log for given metric + Return: log_dict: dict with everything to log for given metric """ log_dict = {} metric_fig = vis.plot_error_map( @@ -552,8 +546,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): def on_test_epoch_end(self): """ - Compute test metrics and make plots at the end of test epoch. - Will gather stored tensors and perform plotting and logging on rank 0. + Compute test metrics and make plots at the end of test epoch. Will + gather stored tensors and perform plotting and logging on rank 0. """ # Create error maps for all test metrics self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index 8ce87030..d9a4c676 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -179,9 +179,9 @@ def process_step(self, mesh_rep): ) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = ( - new_node_rep # (B, num_mesh_nodes[l], d_h) - ) + mesh_rep_levels[ + level_l + ] = new_node_rep # (B, num_mesh_nodes[l], d_h) mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h) # - PROCESSOR - @@ -207,9 +207,9 @@ def process_step(self, mesh_rep): new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = ( - new_node_rep # (B, num_mesh_nodes[l], d_h) - ) + mesh_rep_levels[ + level_l + ] = new_node_rep # (B, num_mesh_nodes[l], d_h) # Return only bottom level representation return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h) diff --git a/plot_graph.py b/plot_graph.py index 9b465fd4..d6f297d0 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -10,7 +10,7 @@ from trimesh.primitives import Box # First-party -from neural_lam import utils +from neural_lam import config, utils MESH_HEIGHT = 0.1 MESH_LEVEL_DIST = 0.05