Skip to content

Commit

Permalink
doc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gerkone committed Feb 26, 2024
1 parent bfefdfe commit c4d698a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lagrangebench/evaluate/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def infer(
state: Haiku state.
load_ckp: Path to checkpoint directory.
rollout_dir: Path to rollout directory.
cfd_eval_infer: Evaluation configuration for inference mode.
cfg_eval_infer: Evaluation configuration for inference mode.
n_rollout_steps: Number of rollout steps.
seed: Seed.
Expand Down
28 changes: 14 additions & 14 deletions lagrangebench/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,6 @@ class Trainer:
1. Initializes (or restarts a checkpoint) model, optimizer and loss function.
2. Trains the model on data_train, using the given pushforward and noise tricks.
3. Evaluates the model on data_valid on the specified metrics.
Args:
model: (Transformed) Haiku model.
case: Case setup class.
data_train: Training dataset.
data_valid: Validation dataset.
cfg_train: Training configuration.
cfg_eval: Evaluation configuration.
cfg_logging: Logging configuration.
input_seq_length: Input sequence length, i.e. number of past positions.
seed: Random seed for model init, training tricks and dataloading.
Returns:
Configured training function.
"""

def __init__(
Expand All @@ -127,6 +113,20 @@ def __init__(
input_seq_length: int = defaults.model.input_seq_length,
seed: int = defaults.seed,
):
"""Initializes the trainer.
Args:
model: (Transformed) Haiku model.
case: Case setup class.
data_train: Training dataset.
data_valid: Validation dataset.
cfg_train: Training configuration.
cfg_eval: Evaluation configuration.
cfg_logging: Logging configuration.
input_seq_length: Input sequence length, i.e. number of past positions.
seed: Random seed for model init, training tricks and dataloading.
"""

if isinstance(cfg_train, Dict):
cfg_train = OmegaConf.create(cfg_train)

Check warning on line 131 in lagrangebench/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

lagrangebench/train/trainer.py#L131

Added line #L131 was not covered by tests
if isinstance(cfg_eval, Dict):
Expand Down

0 comments on commit c4d698a

Please sign in to comment.