diff --git a/neural_lam/utils.py b/neural_lam/utils.py index d1602cfd..0c7aba45 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -264,13 +264,6 @@ def fractional_plot_bundle(fraction): return bundle -def init_wandb_metrics(wandb_logger, val_steps): - """ - Set up wandb metrics to track - """ - - - @rank_zero_only def rank_zero_print(*args, **kwargs): """Print only from rank 0 process""" @@ -290,30 +283,30 @@ def init_wandb(args): ) wandb.init( name=run_name, - project=constants.WANDB_PROJECT, + project=args.wandb_project, config=args, ) logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, + project=args.wandb_project, name=run_name, config=args, ) - wandb.save("neural_lam/constants.py") + wandb.save("neural_lam/data_config.yaml") else: wandb.init( - project=constants.WANDB_PROJECT, + project=args.wandb_project, config=args, id=args.resume_run, resume="must", ) logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, + project=args.wandb_project, id=args.resume_run, config=args, ) experiment = logger.experiment experiment.define_metric("val_mean_loss", summary="min") - for step in val_steps: + for step in args.val_steps_to_log: experiment.define_metric(f"val_loss_unroll{step}", summary="min") return logger diff --git a/train_model.py b/train_model.py index 5a106f76..388cbd90 100644 --- a/train_model.py +++ b/train_model.py @@ -9,7 +9,6 @@ from lightning_fabric.utilities import seed # First-party -from neural_lam import utils from neural_lam import config, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM