Skip to content

Commit

Permalink
add cfg sanity checks
Browse files Browse the repository at this point in the history
  • Loading branch information
arturtoshev committed Feb 22, 2024
1 parent 0a16248 commit 5e5233f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
15 changes: 15 additions & 0 deletions lagrangebench/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,18 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:


defaults = set_defaults()


def check_cfg(cfg: DictConfig):
"""Check if the configs are valid."""

assert cfg.main.mode in ["train", "infer", "all"]
assert cfg.main.dtype in ["float32", "float64"]
assert cfg.main.data_dir is not None, "main.data_dir must be specified."

assert cfg.eval.train.n_trajs >= -1
assert cfg.eval.infer.n_trajs >= -1
assert set(cfg.eval.train.metrics).issubset(["mse", "e_kin", "sinkhorn"])
assert set(cfg.eval.infer.metrics).issubset(["mse", "e_kin", "sinkhorn"])
assert cfg.eval.train.out_type in ["none", "vtk", "pkl"]
assert cfg.eval.infer.out_type in ["none", "vtk", "pkl"]
4 changes: 3 additions & 1 deletion lagrangebench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
from lagrangebench import Trainer, infer, models
from lagrangebench.case_setup import case_builder
from lagrangebench.data import H5Dataset
from lagrangebench.defaults import check_cfg
from lagrangebench.evaluate import averaged_metrics
from lagrangebench.models.utils import node_irreps
from lagrangebench.utils import NodeType


def train_or_infer(cfg: DictConfig):
# TODO: sanity checks on the passed configs go in here
# sanity check on the passed configs
check_cfg(cfg)

mode = cfg.main.mode
old_model_dir = cfg.main.model_dir
Expand Down

0 comments on commit 5e5233f

Please sign in to comment.