Add Rank-Zero Printing and Improve Wandb Initialization #16

Open · wants to merge 14 commits into main
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)

### Added

- Added `rank_zero_print` function to `utils.py` for rank-zero-only printing in multi-node distributed training
[\#16](https://github.com/mllam/neural-lam/pull/16)
@sadamov

- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data to speed up running tests.
[\#38](https://github.com/mllam/neural-lam/pull/38)
@SimonKamuk
@@ -30,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Initialization of wandb is now robust for multi-node distributed training, and the data config file is saved to wandb
[\#16](https://github.com/mllam/neural-lam/pull/16)
@sadamov

- Robust restoration of optimizer and scheduler using `ckpt_path`
[\#17](https://github.com/mllam/neural-lam/pull/17)
@sadamov
17 changes: 17 additions & 0 deletions neural_lam/models/ar_model.py
@@ -213,6 +213,10 @@ def training_step(self, batch):
)
return batch_loss

def on_train_start(self):
"""Save data config file to wandb at start of training"""
self.save_data_config()

def all_gather_cat(self, tensor_to_gather):
"""
Gather tensors across all ranks, and concatenate across dim. 0
@@ -521,6 +525,10 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
wandb.log(log_dict) # Log all
plt.close("all") # Close all figs

def on_test_start(self):
"""Save data config file to wandb at start of test"""
self.save_data_config()

def on_test_epoch_end(self):
"""
Compute test metrics and make plots at the end of test epoch.
@@ -597,3 +605,12 @@ def on_load_checkpoint(self, checkpoint):
if not self.restore_opt:
opt = self.configure_optimizers()
checkpoint["optimizer_states"] = [opt.state_dict()]

def save_data_config(self):
"""Save data config file to wandb"""
if self.trainer.is_global_zero:
wandb.save(
self.args.data_config,
base_path=os.path.dirname(self.args.data_config),
policy="now",
)
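The `save_data_config` hook above runs only on the global rank-0 process and uploads the data config file to the active wandb run. Below is a minimal standalone sketch of the same `wandb.save` call; the config path and project name are hypothetical, and a configured wandb account is assumed.

```python
# Standalone sketch of the wandb.save call used in save_data_config above.
# The config path and project name are hypothetical examples.
import os

import wandb

data_config = "data/meps_example/data_config.yaml"  # hypothetical path

run = wandb.init(project="neural-lam-example")
# base_path makes the upload path relative to the config's directory, so the
# file appears as "data_config.yaml" among the run's files; policy="now"
# uploads it immediately instead of waiting for the run to finish.
wandb.save(data_config, base_path=os.path.dirname(data_config), policy="now")
run.finish()
```

Calling this from both `on_train_start` and `on_test_start` means the config is attached to the run whether a full training or only an evaluation is performed.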
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
@@ -29,7 +29,7 @@ def __init__(self, args):

# Specify dimensions of data
self.num_mesh_nodes, _ = self.get_num_mesh()
print(
utils.rank_zero_print(
f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} "
f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)"
)
11 changes: 7 additions & 4 deletions neural_lam/models/base_hi_graph_model.py
@@ -25,19 +25,22 @@ def __init__(self, args):
] # Needs as python list for later

# Print some useful info
print("Loaded hierarchical graph with structure:")
utils.rank_zero_print("Loaded hierarchical graph with structure:")
for level_index, level_mesh_size in enumerate(self.level_mesh_sizes):
same_level_edges = self.m2m_features[level_index].shape[0]
print(
utils.rank_zero_print(
f"level {level_index} - {level_mesh_size} nodes, "
f"{same_level_edges} same-level edges"
)

if level_index < (self.num_levels - 1):
up_edges = self.mesh_up_features[level_index].shape[0]
down_edges = self.mesh_down_features[level_index].shape[0]
print(f" {level_index}<->{level_index + 1}")
print(f" - {up_edges} up edges, {down_edges} down edges")
utils.rank_zero_print(f" {level_index}<->{level_index + 1}")
utils.rank_zero_print(
f" - {up_edges} up edges, {down_edges} down edges"
)

# Embedders
# Assume all levels have same static feature dimensionality
mesh_dim = self.mesh_static_features[0].shape[1]
2 changes: 1 addition & 1 deletion neural_lam/models/graph_lam.py
@@ -25,7 +25,7 @@ def __init__(self, args):
# grid_dim from data + static + batch_static
mesh_dim = self.mesh_static_features.shape[1]
m2m_edges, m2m_dim = self.m2m_features.shape
print(
utils.rank_zero_print(
f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, "
f"m2g={self.m2g_edges}"
)
46 changes: 39 additions & 7 deletions neural_lam/utils.py
@@ -1,10 +1,14 @@
# Standard library
import os
import random
import shutil
import time

# Third-party
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_only
from torch import nn
from tueplots import bundles, figsizes

@@ -129,7 +133,8 @@ def loads_file(fn):
hierarchical = n_levels > 1  # Not just a single-level mesh graph

# Load static edge features
m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f)
# List of (M_m2m[l], d_edge_f)
m2m_features = loads_file("m2m_features.pt")
g2m_features = loads_file("g2m_features.pt") # (M_g2m, d_edge_f)
m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f)

@@ -265,11 +270,38 @@ def fractional_plot_bundle(fraction):
return bundle


def init_wandb_metrics(wandb_logger, val_steps):
"""
Set up wandb metrics to track
"""
experiment = wandb_logger.experiment
@rank_zero_only
def rank_zero_print(*args, **kwargs):
"""Print only from rank 0 process"""
print(*args, **kwargs)


@rank_zero_only
def init_wandb(args):
"""Initialize wandb"""
if args.resume_run is None:
prefix = f"subset-{args.subset_ds}-" if args.subset_ds else ""
if args.eval:
prefix = prefix + f"eval-{args.eval}-"
random_int = random.randint(0, 10000)
run_name = (
f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}"
)
logger = pl.loggers.WandbLogger(
project=args.wandb_project,
name=run_name,
config=args,
)
else:
logger = pl.loggers.WandbLogger(
project=args.wandb_project,
id=args.resume_run,
config=args,
)
experiment = logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
for step in args.val_steps_to_log:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")

return logger
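For context on the `rank_zero_only` gating used by `rank_zero_print` and `init_wandb` above, the sketch below overrides the rank by hand purely for illustration; in a real Lightning run the rank is taken from the launcher's environment. A function wrapped with `rank_zero_only` returns `None` on every rank other than 0, which is also why only the rank-0 process gets a logger back from `init_wandb`.

```python
# Sketch of rank_zero_only behaviour; the rank is normally set from the
# environment by the launcher, and is overridden here only for illustration.
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def report(msg):
    print(msg)
    return len(msg)


rank_zero_only.rank = 0
print(report("hello"))  # prints "hello", then 5

rank_zero_only.rank = 1  # pretend this is a non-zero rank
print(report("hello"))  # nothing printed by report, then None
```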
15 changes: 6 additions & 9 deletions train_model.py
@@ -72,6 +72,9 @@ def main(input_args=None):
type=str,
help="Path to load model parameters from (default: None)",
)
parser.add_argument(
"--resume_run", type=str, help="Run ID to resume (default: None)"
)
parser.add_argument(
"--restore_opt",
type=int,
@@ -285,9 +288,9 @@ def main(input_args=None):
mode="min",
save_last=True,
)
logger = pl.loggers.WandbLogger(
project=args.wandb_project, name=run_name, config=args
)

logger = utils.init_wandb(args)

trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
Expand All @@ -300,12 +303,6 @@ def main(input_args=None):
precision=args.precision,
)

# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics(
logger, args.val_steps_to_log
) # Do after wandb.init

if args.eval:
if args.eval == "val":
eval_loader = val_loader
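With these changes, a previous wandb run can be continued by passing its run ID via `--resume_run`; otherwise a fresh run name is generated from the model settings, a timestamp and a random suffix. A hypothetical invocation is sketched below; the model name, checkpoint path and run ID are placeholders, not values taken from this PR.

```python
# Hypothetical example of resuming training with the new --resume_run flag;
# the model name, checkpoint path and run ID below are placeholders.
from train_model import main

main(
    input_args=[
        "--model", "graph_lam",
        "--load", "saved_models/run-abc123/last.ckpt",  # placeholder checkpoint
        "--resume_run", "abc123",                       # wandb run ID to resume
        "--restore_opt", "1",                           # also restore optimizer state
    ]
)
```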