wandb update configs
gerkone committed Feb 23, 2024
1 parent ecbe8fc commit 434f27f
Showing 2 changed files with 36 additions and 26 deletions.
12 changes: 10 additions & 2 deletions lagrangebench/runner.py
@@ -9,6 +9,7 @@
import jax.numpy as jnp
import jmp
import numpy as np
import wandb
from e3nn_jax import Irreps
from jax import config
from jax_md import space
@@ -90,7 +91,13 @@ def train_or_infer(cfg: Union[Dict, DictConfig]):
OmegaConf.save(config=cfg, f=f.name)

# dictionary of configs which will be stored on W&B
wandb_config = OmegaConf.to_container(cfg)
wandb_run = wandb.init(
project=cfg.logging.wandb_project,
entity=cfg.logging.wandb_entity,
name=cfg.logging.run_name,
config=OmegaConf.to_container(cfg),
save_code=True,
)

trainer = Trainer(
model,
@@ -102,8 +109,9 @@ def train_or_infer(cfg: Union[Dict, DictConfig]):
cfg.logging,
input_seq_length=cfg.model.input_seq_length,
seed=cfg.main.seed,
wandb_config=wandb_config,
wandb_run=wandb_run,
)

_, _, _ = trainer.train(
step_max=cfg.train.step_max,
load_checkpoint=old_model_dir,
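As context for this hunk: the commit moves W&B initialization into the runner, so the run is created once, with the full resolved config attached, and then handed to the trainer. A minimal sketch of that pattern follows; the helper name start_run and the simplified config fields are illustrative assumptions, not code from the repository.

from typing import Optional

import wandb
from omegaconf import DictConfig, OmegaConf


def start_run(cfg: DictConfig) -> Optional["wandb.wandb_sdk.wandb_run.Run"]:
    # Create a single W&B run up front so runner and trainer log to the same place.
    if not cfg.logging.wandb:
        return None  # offline or debug runs simply skip W&B
    return wandb.init(
        project=cfg.logging.wandb_project,
        entity=cfg.logging.wandb_entity,
        name=cfg.logging.run_name,
        # the full resolved config is stored once, at run creation
        config=OmegaConf.to_container(cfg),
        save_code=True,
    )

The returned run handle is what the new wandb_run= keyword of Trainer carries; downstream, the trainer only calls log() and finish() on it.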
50 changes: 26 additions & 24 deletions lagrangebench/train/trainer.py
@@ -101,13 +101,13 @@ def __init__(
cfg_logging: Union[Dict, DictConfig] = defaults.logging,
input_seq_length: int = defaults.model.input_seq_length,
seed: int = defaults.main.seed,
**kwargs,
wandb_run: Optional[wandb.wandb_sdk.wandb_run.Run] = None,
) -> Callable:
"""
Builds a function that automates model training and evaluation.
Trainer class.
Given a model, training and validation datasets and a case this function returns
another function that:
Given a model, case setup, training and validation datasets this class
automates training and evaluation.
1. Initializes (or restarts a checkpoint) model, optimizer and loss function.
2. Trains the model on data_train, using the given pushforward and noise tricks.
@@ -159,9 +159,6 @@ def __init__(
# set the number of validation trajectories during training
if self.cfg_eval.train.n_trajs == -1:
self.cfg_eval.train.n_trajs = data_valid.num_samples
if self.cfg_logging.wandb:
self.wandb_config = kwargs["wandb_config"]
self.wandb_config["eval"]["train"]["n_trajs"] = self.cfg_eval.train.n_trajs

# make immutable for jitting
loss_weight = self.cfg_train.loss_weight
@@ -209,6 +206,15 @@ def __init__(
stride=self.cfg_eval.train.metrics_stride,
)

if wandb_run is None and self.cfg_logging.wandb:
wandb_run = wandb.init(
project=self.cfg_logging.wandb_project,
entity=self.cfg_logging.wandb_entity,
name=self.cfg_logging.run_name,
save_code=True,
)
self.wandb_run = wandb_run

def train(
self,
step_max: int = defaults.train.step_max,
@@ -272,24 +278,20 @@ def train(
params, state = model.init(subkey, (features, particle_type[0]))

# start logging
if cfg_logging.wandb:
self.wandb_config["info"] = {
if cfg_logging.wandb and self.wandb_run:
extended_config = self.wandb_run.config
extended_config["eval"]["train"]["n_trajs"] = self.cfg_eval.train.n_trajs

extended_config["info"] = {
"dataset_name": loader_train.dataset.name,
"len_train": len(loader_train.dataset),
"len_eval": len(loader_valid.dataset),
"num_params": get_num_params(params).item(),
"step_start": step,
}

wandb_run = wandb.init(
project=cfg_logging.wandb_project,
entity=cfg_logging.wandb_entity,
name=cfg_logging.run_name,
config=self.wandb_config,
save_code=True,
)
else:
wandb_run = None
self.wandb_run.config.update(extended_config)
self.wandb_run.config.persist()

# initialize optimizer state
if opt_state is None:
@@ -365,8 +367,8 @@ def train(

if step % cfg_logging.log_steps == 0:
loss.block_until_ready()
if wandb_run:
wandb_run.log({"train/loss": loss.item()}, step)
if self.wandb_run:
self.wandb_run.log({"train/loss": loss.item()}, step)
else:
step_str = str(step).zfill(len(str(int(step_max))))
print(f"{step_str}, train/loss: {loss.item():.5f}.")
@@ -397,16 +399,16 @@ def train(
store_checkpoint, params, state, opt_state, metadata_ckp
)

if wandb_run:
wandb_run.log(metrics, step)
if self.wandb_run:
self.wandb_run.log(metrics, step)
else:
print(metrics)

step += 1
if step == step_max + 1:
break

if cfg_logging.wandb:
wandb.finish()
if self.wandb_run:
self.wandb_run.finish()

return params, state, opt_state
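To summarize the new interface, a hedged usage sketch: after this commit the trainer either receives an externally created run (as runner.py now does) or, if none is given and cfg.logging.wandb is enabled, creates its own in __init__. The surrounding objects (model, case, data_train, data_valid, cfg) are assumed to be set up as in runner.py, and the top-level import path for Trainer is an assumption.

import wandb
from lagrangebench import Trainer

# Option 1: reuse a run created elsewhere, e.g. by the runner.
run = wandb.init(project=cfg.logging.wandb_project, name=cfg.logging.run_name)
trainer = Trainer(
    model, case, data_train, data_valid,
    cfg.train, cfg.eval, cfg.logging,
    input_seq_length=cfg.model.input_seq_length,
    seed=cfg.main.seed,
    wandb_run=run,
)

# Option 2: pass no run; the trainer creates its own when
# cfg.logging.wandb is set, and logs to stdout otherwise.
trainer = Trainer(
    model, case, data_train, data_valid,
    cfg.train, cfg.eval, cfg.logging,
    input_seq_length=cfg.model.input_seq_length,
    seed=cfg.main.seed,
)

params, state, opt_state = trainer.train(step_max=cfg.train.step_max)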
