From 39f59766ec949ac13cf351de8c264775f32dc0ef Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 1 May 2024 20:15:33 +0200 Subject: [PATCH 01/11] Utilize rank_one_only to init logger and plot --- neural_lam/models/ar_model.py | 4 +- neural_lam/models/base_graph_model.py | 2 +- neural_lam/models/base_hi_graph_model.py | 11 ++++-- neural_lam/models/graph_lam.py | 2 +- neural_lam/utils.py | 50 ++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 8 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7d0a8320..6d160526 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -400,7 +400,7 @@ def plot_examples(self, batch, n_examples, prediction=None): target_t[:, var_i], self.interior_mask[:, 0], title=f"{var_name} ({var_unit}), " - f"t={t_i} ({self.step_length*t_i} h)", + f"t={t_i} ({self.step_length * t_i} h)", vrange=var_vrange, ) for var_i, (var_name, var_unit, var_vrange) in enumerate( @@ -542,7 +542,7 @@ def on_test_epoch_end(self): vis.plot_spatial_error( loss_map, self.interior_mask[:, 0], - title=f"Test loss, t={t_i} ({self.step_length*t_i} h)", + title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", ) for t_i, loss_map in zip( constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 256d4adc..8fd619dc 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -29,7 +29,7 @@ def __init__(self, args): # Specify dimensions of data self.num_mesh_nodes, _ = self.get_num_mesh() - print( + utils.rank_zero_print( f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} " f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" ) diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index 8ce87030..9686d867 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -25,10 +25,10 @@ def __init__(self, args): ] # Needs as python list for later # Print some useful info - print("Loaded hierarchical graph with structure:") + utils.rank_zero_print("Loaded hierarchical graph with structure:") for level_index, level_mesh_size in enumerate(self.level_mesh_sizes): same_level_edges = self.m2m_features[level_index].shape[0] - print( + utils.rank_zero_print( f"level {level_index} - {level_mesh_size} nodes, " f"{same_level_edges} same-level edges" ) @@ -36,8 +36,11 @@ def __init__(self, args): if level_index < (self.num_levels - 1): up_edges = self.mesh_up_features[level_index].shape[0] down_edges = self.mesh_down_features[level_index].shape[0] - print(f" {level_index}<->{level_index+1}") - print(f" - {up_edges} up edges, {down_edges} down edges") + utils.rank_zero_print(f" {level_index}<->{level_index + 1}") + utils.rank_zero_print( + f" - {up_edges} up edges, {down_edges} down edges" + ) + # Embedders # Assume all levels have same static feature dimensionality mesh_dim = self.mesh_static_features[0].shape[1] diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index f767fba0..2c2bb149 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -25,7 +25,7 @@ def __init__(self, args): # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] m2m_edges, m2m_dim = self.m2m_features.shape - print( + utils.rank_zero_print( f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " f"m2g={self.m2g_edges}" ) diff --git 
a/neural_lam/utils.py b/neural_lam/utils.py index 31715502..6e0ec15b 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,9 +1,13 @@ # Standard library import os +import time # Third-party import numpy as np +import pytorch_lightning as pl import torch +import wandb # pylint: disable=wrong-import-order +from pytorch_lightning.utilities import rank_zero_only from torch import nn from tueplots import bundles, figsizes @@ -271,3 +275,49 @@ def init_wandb_metrics(wandb_logger): experiment.define_metric("val_mean_loss", summary="min") for step in constants.VAL_STEP_LOG_ERRORS: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +@rank_zero_only +def rank_zero_print(*args, **kwargs): + """Print only from rank 0 process""" + print(*args, **kwargs) + + +@rank_zero_only +def init_wandb(args): + """Initialize wandb""" + if args.resume_run is None: + prefix = f"subset-{args.subset_ds}-" if args.subset_ds else "" + if args.eval: + prefix = prefix + f"eval-{args.eval}-" + run_name = ( + f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" + f"{time.strftime('%m_%d_%H_%M_%S')}" + ) + wandb.init( + name=run_name, + project=constants.WANDB_PROJECT, + config=args, + ) + logger = pl.loggers.WandbLogger( + project=constants.WANDB_PROJECT, + name=run_name, + config=args, + log_model=True, + ) + wandb.save("neural_lam/constants.py") + else: + wandb.init( + project=constants.WANDB_PROJECT, + config=args, + id=args.resume_run, + resume="must", + ) + logger = pl.loggers.WandbLogger( + project=constants.WANDB_PROJECT, + id=args.resume_run, + config=args, + log_model=True, + ) + + return logger From 96eea7801efefb6c7b7d0bcc7767f1f1e13ca374 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 1 May 2024 20:41:48 +0200 Subject: [PATCH 02/11] new function returns logger directly --- train_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_model.py b/train_model.py index 96d21a3f..bfadbb34 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,7 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import constants, utils +from neural_lam import utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel @@ -263,9 +263,9 @@ def main(): mode="min", save_last=True, ) - logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, name=run_name, config=args - ) + + logger = utils.init_wandb(args) + trainer = pl.Trainer( max_epochs=args.epochs, deterministic=True, From 72272bc5362bdfbcac6facf9fd08979e495ad5bd Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 1 May 2024 22:02:17 +0200 Subject: [PATCH 03/11] new flag to resume wandb run useful after crash or when splitting training into multiple jobs --- train_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train_model.py b/train_model.py index bfadbb34..2c579050 100644 --- a/train_model.py +++ b/train_model.py @@ -74,6 +74,9 @@ def main(): type=str, help="Path to load model parameters from (default: None)", ) + parser.add_argument( + "--resume_run", type=str, help="Run ID to resume (default: None)" + ) parser.add_argument( "--restore_opt", type=int, From 57a396c3d0b0e3eb70dbc57093eb2dd935a5d61a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 25 May 2024 16:07:16 +0200 Subject: [PATCH 04/11] combine two wandb init functions --- neural_lam/utils.py | 11 +++++------ train_model.py | 4 ---- 2 files changed, 5 insertions(+), 10 deletions(-) 
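The recurring idea in patches 01-04 is Lightning's rank_zero_only guard: under distributed (DDP) training every rank executes the module code, so side-effecting calls such as print(), wandb.init() and logger construction must be restricted to rank 0 or they run once per GPU. A minimal, self-contained sketch of the pattern follows; it mirrors the helpers added to neural_lam/utils.py, but the function name init_wandb_logger and its arguments (project, run_name, resume_run, config) are illustrative rather than the exact signature used in the patch.

# Sketch: guard side-effecting setup so it runs on rank 0 only.
import pytorch_lightning as pl
import wandb
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def rank_zero_print(*args, **kwargs):
    """Print only from the rank-0 process; a no-op on all other ranks."""
    print(*args, **kwargs)


@rank_zero_only
def init_wandb_logger(project, run_name, resume_run=None, config=None):
    """Create (or resume) the W&B logger once, on rank 0.

    Non-zero ranks get None back, which is what rank_zero_only returns
    for the processes it skips.
    """
    if resume_run is None:
        return pl.loggers.WandbLogger(
            project=project, name=run_name, config=config, log_model=True
        )
    # resume="must" makes wandb abort if the run id does not exist, which is
    # the behaviour wanted when restarting after a crash or when one training
    # is split across several SLURM jobs.
    wandb.init(project=project, id=resume_run, resume="must")
    return pl.loggers.WandbLogger(project=project, id=resume_run, config=config)

Patch 08 below goes one step further and drops the explicit wandb.init() calls, leaving only the WandbLogger construction; the rank-zero guard itself stays as above.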
diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 6e0ec15b..943fc84e 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -271,10 +271,7 @@ def init_wandb_metrics(wandb_logger): """ Set up wandb metrics to track """ - experiment = wandb_logger.experiment - experiment.define_metric("val_mean_loss", summary="min") - for step in constants.VAL_STEP_LOG_ERRORS: - experiment.define_metric(f"val_loss_unroll{step}", summary="min") + @rank_zero_only @@ -303,7 +300,6 @@ def init_wandb(args): project=constants.WANDB_PROJECT, name=run_name, config=args, - log_model=True, ) wandb.save("neural_lam/constants.py") else: @@ -317,7 +313,10 @@ def init_wandb(args): project=constants.WANDB_PROJECT, id=args.resume_run, config=args, - log_model=True, ) + experiment = logger.experiment + experiment.define_metric("val_mean_loss", summary="min") + for step in constants.VAL_STEP_LOG_ERRORS: + experiment.define_metric(f"val_loss_unroll{step}", summary="min") return logger diff --git a/train_model.py b/train_model.py index 2c579050..bca3f638 100644 --- a/train_model.py +++ b/train_model.py @@ -281,10 +281,6 @@ def main(): precision=args.precision, ) - # Only init once, on rank 0 only - if trainer.global_rank == 0: - utils.init_wandb_metrics(logger) # Do after wandb.init - if args.eval: if args.eval == "val": eval_loader = val_loader From ea43435373c81097eef8ac6b470ba598924fbf58 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 25 May 2024 16:20:45 +0200 Subject: [PATCH 05/11] linter --- neural_lam/utils.py | 19 ++++++------------- train_model.py | 1 - 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index d1602cfd..0c7aba45 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -264,13 +264,6 @@ def fractional_plot_bundle(fraction): return bundle -def init_wandb_metrics(wandb_logger, val_steps): - """ - Set up wandb metrics to track - """ - - - @rank_zero_only def rank_zero_print(*args, **kwargs): """Print only from rank 0 process""" @@ -290,30 +283,30 @@ def init_wandb(args): ) wandb.init( name=run_name, - project=constants.WANDB_PROJECT, + project=args.wandb_project, config=args, ) logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, + project=args.wandb_project, name=run_name, config=args, ) - wandb.save("neural_lam/constants.py") + wandb.save("neural_lam/data_config.yaml") else: wandb.init( - project=constants.WANDB_PROJECT, + project=args.wandb_project, config=args, id=args.resume_run, resume="must", ) logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, + project=args.wandb_project, id=args.resume_run, config=args, ) experiment = logger.experiment experiment.define_metric("val_mean_loss", summary="min") - for step in val_steps: + for step in args.val_steps_to_log: experiment.define_metric(f"val_loss_unroll{step}", summary="min") return logger diff --git a/train_model.py b/train_model.py index 5a106f76..388cbd90 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,6 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import utils from neural_lam import config, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM From 0ed609ab2028b28232ebf8019bc8a6d12ae3b766 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 25 May 2024 16:25:12 +0200 Subject: [PATCH 06/11] adding randint to prevent slurm issues --- neural_lam/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/neural_lam/utils.py 
b/neural_lam/utils.py index 0c7aba45..c6875aa3 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,16 +1,18 @@ # Standard library import os +import random import time # Third-party import numpy as np import pytorch_lightning as pl import torch -import wandb # pylint: disable=wrong-import-order from pytorch_lightning.utilities import rank_zero_only from torch import nn from tueplots import bundles, figsizes +import wandb # pylint: disable=wrong-import-order + def load_dataset_stats(dataset_name, device="cpu"): """ @@ -277,9 +279,10 @@ def init_wandb(args): prefix = f"subset-{args.subset_ds}-" if args.subset_ds else "" if args.eval: prefix = prefix + f"eval-{args.eval}-" + random_int = random.randint(0, 10000) run_name = ( f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" - f"{time.strftime('%m_%d_%H_%M_%S')}" + f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}" ) wandb.init( name=run_name, From 7f99788f5be888627e53aff9ed2a732315e15c62 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 25 May 2024 16:53:04 +0200 Subject: [PATCH 07/11] debug and linting --- neural_lam/utils.py | 3 +-- neural_lam/vis.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index c6875aa3..59a718f4 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -7,12 +7,11 @@ import numpy as np import pytorch_lightning as pl import torch +import wandb # pylint: disable=wrong-import-order from pytorch_lightning.utilities import rank_zero_only from torch import nn from tueplots import bundles, figsizes -import wandb # pylint: disable=wrong-import-order - def load_dataset_stats(dataset_name, device="cpu"): """ diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 2b6abf15..8c9ca77c 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) # Plot pred and target @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): fig, ax = plt.subplots( figsize=(5, 4.8), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) ax.coastlines() # Add coastline outlines From f2a818093d1b5a05db7363d91a839b2baf94d8b0 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 6 Jun 2024 17:27:13 +0200 Subject: [PATCH 08/11] remove double wandb initialization --- neural_lam/models/ar_model.py | 4 ++++ neural_lam/utils.py | 18 +++--------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 6ced211f..9448edae 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -597,3 +597,7 @@ def on_load_checkpoint(self, checkpoint): if not self.restore_opt: opt = self.configure_optimizers() checkpoint["optimizer_states"] = [opt.state_dict()] + + def on_run_end(self): + if self.trainer.is_global_zero: + wandb.save("neural_lam/data_config.yaml") diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 3f7a27c6..19021204 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,14 +1,13 @@ # Standard library import os -import shutil import random +import shutil import time # Third-party import numpy as np import pytorch_lightning as pl import torch -import wandb # pylint: disable=wrong-import-order from pytorch_lightning.utilities import rank_zero_only 
from torch import nn from tueplots import bundles, figsizes @@ -134,7 +133,8 @@ def loads_file(fn): hierarchical = n_levels > 1 # Nor just single level mesh graph # Load static edge features - m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f) + # List of (M_m2m[l], d_edge_f) + m2m_features = loads_file("m2m_features.pt") g2m_features = loads_file("g2m_features.pt") # (M_g2m, d_edge_f) m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f) @@ -288,24 +288,12 @@ def init_wandb(args): f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}" ) - wandb.init( - name=run_name, - project=args.wandb_project, - config=args, - ) logger = pl.loggers.WandbLogger( project=args.wandb_project, name=run_name, config=args, ) - wandb.save("neural_lam/data_config.yaml") else: - wandb.init( - project=args.wandb_project, - config=args, - id=args.resume_run, - resume="must", - ) logger = pl.loggers.WandbLogger( project=args.wandb_project, id=args.resume_run, From d392e53ce71761732cddd32283e544d6a3d31ca4 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 7 Jun 2024 12:39:34 +0200 Subject: [PATCH 09/11] switched hooks for saving config xaml file --- CHANGELOG.md | 9 +++++++++ neural_lam/models/ar_model.py | 14 ++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4680c37..2a4f539b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) ### Added + +- Added `rank_zero_print` function to `utils.py` for printing in multi-node distributed training + [\#16](https://github.com/mllam/neural-lam/pull/16) + @sadamov + - Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data tp speed up running tests. [/#38](https://github.com/mllam/neural-lam/pull/38) @SimonKamuk @@ -30,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Initialization of wandb is now robust for multi-node distributed training and config files are saved to wandb + [\#16](https://github.com/mllam/neural-lam/pull/16) + @sadamov + - Robust restoration of optimizer and scheduler using `ckpt_path` [\#17](https://github.com/mllam/neural-lam/pull/17) @sadamov diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 9448edae..9995e10a 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -213,6 +213,11 @@ def training_step(self, batch): ) return batch_loss + def on_train_start(self): + """Save data config file to wandb at start of training""" + if self.trainer.is_global_zero: + wandb.save("neural_lam/data_config.yaml") + def all_gather_cat(self, tensor_to_gather): """ Gather tensors across all ranks, and concatenate across dim. 0 @@ -521,6 +526,11 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): wandb.log(log_dict) # Log all plt.close("all") # Close all figs + def on_test_start(self): + """Save data config file to wandb at start of test""" + if self.trainer.is_global_zero: + wandb.save("neural_lam/data_config.yaml") + def on_test_epoch_end(self): """ Compute test metrics and make plots at the end of test epoch. 
@@ -597,7 +607,3 @@ def on_load_checkpoint(self, checkpoint): if not self.restore_opt: opt = self.configure_optimizers() checkpoint["optimizer_states"] = [opt.state_dict()] - - def on_run_end(self): - if self.trainer.is_global_zero: - wandb.save("neural_lam/data_config.yaml") From 2ac2a8b1ea1eb644771b2a697f194812c8921d66 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 7 Jun 2024 12:46:45 +0200 Subject: [PATCH 10/11] save user config instead of of default value --- neural_lam/models/ar_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 9995e10a..1dd6b72c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -216,7 +216,7 @@ def training_step(self, batch): def on_train_start(self): """Save data config file to wandb at start of training""" if self.trainer.is_global_zero: - wandb.save("neural_lam/data_config.yaml") + wandb.save(self.args.data_config) def all_gather_cat(self, tensor_to_gather): """ @@ -529,7 +529,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): def on_test_start(self): """Save data config file to wandb at start of test""" if self.trainer.is_global_zero: - wandb.save("neural_lam/data_config.yaml") + wandb.save(self.args.data_config) def on_test_epoch_end(self): """ From 10835e9aaa577569be1a427888750c2966a3f75c Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Fri, 7 Jun 2024 16:02:42 +0200 Subject: [PATCH 11/11] Do not include directory structure of data config in wandb storage --- neural_lam/models/ar_model.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 1dd6b72c..b8f0bfd7 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -215,8 +215,7 @@ def training_step(self, batch): def on_train_start(self): """Save data config file to wandb at start of training""" - if self.trainer.is_global_zero: - wandb.save(self.args.data_config) + self.save_data_config() def all_gather_cat(self, tensor_to_gather): """ @@ -528,8 +527,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): def on_test_start(self): """Save data config file to wandb at start of test""" - if self.trainer.is_global_zero: - wandb.save(self.args.data_config) + self.save_data_config() def on_test_epoch_end(self): """ @@ -607,3 +605,12 @@ def on_load_checkpoint(self, checkpoint): if not self.restore_opt: opt = self.configure_optimizers() checkpoint["optimizer_states"] = [opt.state_dict()] + + def save_data_config(self): + """Save data config file to wandb""" + if self.trainer.is_global_zero: + wandb.save( + self.args.data_config, + base_path=os.path.dirname(self.args.data_config), + policy="now", + )
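Two details of the final state are worth spelling out. The config upload ends up in on_train_start/on_test_start rather than the on_run_end method added in patch 08, apparently because on_run_end is not a hook Lightning invokes on a LightningModule, so that upload would never have fired. And the base_path argument added in patch 11 controls the name under which wandb.save() stores the file: paths are recorded relative to base_path, so pointing it at the config's own directory keeps the run's file listing flat. A small illustration under assumed paths and project name (neither comes from the repository):

import os

import wandb

run = wandb.init(project="neural-lam-sketch")  # illustrative project name
cfg = "configs/meps/data_config.yaml"          # hypothetical config location

# Without base_path the relative folder structure ends up in the run's files
# (the situation patch 11 removes):
#   wandb.save(cfg, policy="now")  ->  configs/meps/data_config.yaml
# With base_path set to the config's directory, only the file name is kept:
wandb.save(cfg, base_path=os.path.dirname(cfg), policy="now")
#   ->  data_config.yaml
run.finish()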