diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4680c37..2a4f539b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)
 
 ### Added
+
+- Added `rank_zero_print` function to `utils.py` for printing in multi-node distributed training
+  [\#16](https://github.com/mllam/neural-lam/pull/16)
+  @sadamov
+
 - Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data to speed up running tests.
   [\#38](https://github.com/mllam/neural-lam/pull/38)
   @SimonKamuk
@@ -30,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Initialization of wandb is now robust for multi-node distributed training and config files are saved to wandb
+  [\#16](https://github.com/mllam/neural-lam/pull/16)
+  @sadamov
+
 - Robust restoration of optimizer and scheduler using `ckpt_path`
   [\#17](https://github.com/mllam/neural-lam/pull/17)
   @sadamov
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 9448edae..9995e10a 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -213,6 +213,11 @@ def training_step(self, batch):
         )
         return batch_loss
 
+    def on_train_start(self):
+        """Save data config file to wandb at start of training"""
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
+
     def all_gather_cat(self, tensor_to_gather):
         """
         Gather tensors across all ranks, and concatenate across dim. 0
@@ -521,6 +526,11 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
             wandb.log(log_dict)  # Log all
             plt.close("all")  # Close all figs
 
+    def on_test_start(self):
+        """Save data config file to wandb at start of test"""
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
+
     def on_test_epoch_end(self):
         """
         Compute test metrics and make plots at the end of test epoch.
@@ -597,7 +607,3 @@ def on_load_checkpoint(self, checkpoint):
         if not self.restore_opt:
             opt = self.configure_optimizers()
             checkpoint["optimizer_states"] = [opt.state_dict()]
-
-    def on_run_end(self):
-        if self.trainer.is_global_zero:
-            wandb.save("neural_lam/data_config.yaml")
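
Note: the changelog entry above names a `rank_zero_print` helper in `utils.py`, but the diff does not show its implementation. A minimal sketch of such a helper, assuming it wraps the built-in print with PyTorch Lightning's public `rank_zero_only` decorator (the body below is an illustration, not the verbatim PR code):

    # Assumed implementation sketch: rank_zero_only turns the function into a
    # no-op on every rank except global rank 0, so only one process per run
    # writes to stdout in multi-node distributed training.
    from pytorch_lightning.utilities import rank_zero_only


    @rank_zero_only
    def rank_zero_print(*args, **kwargs):
        """Print only from the global rank-0 process."""
        print(*args, **kwargs)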
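The "robust for multi-node distributed training" wandb initialization is likewise only described in the changelog, not shown in the diff. A hedged sketch of the usual pattern, assuming the robustness comes from restricting wandb calls to the global-zero rank (the helper name and project name below are hypothetical):

    # Hypothetical sketch: guard wandb.init with rank_zero_only so only one
    # process per run creates a wandb run; on all other ranks the call
    # silently does nothing, avoiding duplicate or conflicting runs.
    import wandb
    from pytorch_lightning.utilities import rank_zero_only


    @rank_zero_only
    def init_wandb(config: dict):
        """Initialize wandb from rank 0 only (hypothetical helper)."""
        wandb.init(project="neural-lam", config=config)

The same reasoning applies to the hook change in `ar_model.py`: the `trainer.is_global_zero` guard makes `wandb.save("neural_lam/data_config.yaml")` run once per job rather than once per rank, and moving the call from `on_run_end` to `on_train_start`/`on_test_start` places it in hooks that Lightning actually invokes on a LightningModule (`on_run_end` is not part of that hook API, so the guarded save would never have executed).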