Skip to content

Commit

Permalink
Make sure wandb is initialized before defining metrics, also for pytorch-lightning >= 2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 9, 2023
1 parent 6377d44 commit 9912ece
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch.nn as nn
import numpy as np
from tueplots import bundles, figsizes
import wandb

from neural_lam import constants

Expand Down Expand Up @@ -196,10 +195,11 @@ def fractional_plot_bundle(fraction):
bundle["figure.figsize"] = (original_figsize[0]/fraction, original_figsize[1])
return bundle

def init_wandb_metrics():
def init_wandb_metrics(wandb_logger):
"""
Set up wandb metrics to track
"""
wandb.define_metric("val_mean_loss", summary="min")
experiment = wandb_logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
for step in constants.val_step_log_errors:
wandb.define_metric(f"val_loss_unroll{step}", summary="min")
experiment.define_metric(f"val_loss_unroll{step}", summary="min")
2 changes: 1 addition & 1 deletion train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def main():

# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics() # Do after wandb.init
utils.init_wandb_metrics(logger) # Do after wandb.init

if args.eval:
if args.eval == "val":
Expand Down

0 comments on commit 9912ece

Please sign in to comment.