From 5e5233fd18572ee6f74b9f707f7ad3fefe75b6f4 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Thu, 22 Feb 2024 23:49:28 +0000 Subject: [PATCH] add cfg sanity checks --- lagrangebench/defaults.py | 15 +++++++++++++++ lagrangebench/runner.py | 4 +++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/lagrangebench/defaults.py b/lagrangebench/defaults.py index 119adc0..00f7339 100644 --- a/lagrangebench/defaults.py +++ b/lagrangebench/defaults.py @@ -172,3 +172,18 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig: defaults = set_defaults() + + +def check_cfg(cfg: DictConfig): + """Check if the configs are valid.""" + + assert cfg.main.mode in ["train", "infer", "all"] + assert cfg.main.dtype in ["float32", "float64"] + assert cfg.main.data_dir is not None, "main.data_dir must be specified." + + assert cfg.eval.train.n_trajs >= -1 + assert cfg.eval.infer.n_trajs >= -1 + assert set(cfg.eval.train.metrics).issubset(["mse", "e_kin", "sinkhorn"]) + assert set(cfg.eval.infer.metrics).issubset(["mse", "e_kin", "sinkhorn"]) + assert cfg.eval.train.out_type in ["none", "vtk", "pkl"] + assert cfg.eval.infer.out_type in ["none", "vtk", "pkl"] diff --git a/lagrangebench/runner.py b/lagrangebench/runner.py index e83a382..eab6544 100644 --- a/lagrangebench/runner.py +++ b/lagrangebench/runner.py @@ -17,13 +17,15 @@ from lagrangebench import Trainer, infer, models from lagrangebench.case_setup import case_builder from lagrangebench.data import H5Dataset +from lagrangebench.defaults import check_cfg from lagrangebench.evaluate import averaged_metrics from lagrangebench.models.utils import node_irreps from lagrangebench.utils import NodeType def train_or_infer(cfg: DictConfig): - # TODO: sanity checks on the passed configs go in here + # sanity check on the passed configs + check_cfg(cfg) mode = cfg.main.mode old_model_dir = cfg.main.model_dir