From c4d698a86531a86051fb1f8a7c2f8093349f8eed Mon Sep 17 00:00:00 2001 From: gerkone Date: Mon, 26 Feb 2024 20:26:55 +0100 Subject: [PATCH] doc fixes --- lagrangebench/evaluate/rollout.py | 2 +- lagrangebench/train/trainer.py | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/lagrangebench/evaluate/rollout.py b/lagrangebench/evaluate/rollout.py index 81016eb..341b1ab 100644 --- a/lagrangebench/evaluate/rollout.py +++ b/lagrangebench/evaluate/rollout.py @@ -331,7 +331,7 @@ def infer( state: Haiku state. load_ckp: Path to checkpoint directory. rollout_dir: Path to rollout directory. - cfd_eval_infer: Evaluation configuration for inference mode. + cfg_eval_infer: Evaluation configuration for inference mode. n_rollout_steps: Number of rollout steps. seed: Seed. diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py index 9b62a02..575af5e 100644 --- a/lagrangebench/train/trainer.py +++ b/lagrangebench/train/trainer.py @@ -99,20 +99,6 @@ class Trainer: 1. Initializes (or restarts a checkpoint) model, optimizer and loss function. 2. Trains the model on data_train, using the given pushforward and noise tricks. 3. Evaluates the model on data_valid on the specified metrics. - - Args: - model: (Transformed) Haiku model. - case: Case setup class. - data_train: Training dataset. - data_valid: Validation dataset. - cfg_train: Training configuration. - cfg_eval: Evaluation configuration. - cfg_logging: Logging configuration. - input_seq_length: Input sequence length, i.e. number of past positions. - seed: Random seed for model init, training tricks and dataloading. - - Returns: - Configured training function. """ def __init__( @@ -127,6 +113,20 @@ def __init__( input_seq_length: int = defaults.model.input_seq_length, seed: int = defaults.seed, ): + """Initializes the trainer. + + Args: + model: (Transformed) Haiku model. + case: Case setup class. + data_train: Training dataset. + data_valid: Validation dataset. + cfg_train: Training configuration. + cfg_eval: Evaluation configuration. + cfg_logging: Logging configuration. + input_seq_length: Input sequence length, i.e. number of past positions. + seed: Random seed for model init, training tricks and dataloading. + """ + if isinstance(cfg_train, Dict): cfg_train = OmegaConf.create(cfg_train) if isinstance(cfg_eval, Dict):