From 9d558d1f0d343cfe6e0babaa8d9e6c45b852fe21 Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Fri, 31 May 2024 12:12:58 +0200 Subject: [PATCH] Fix minor bugs in data_config.yaml workflow (#40) ### Summary https://github.com/mllam/neural-lam/pull/31 introduced three minor bugs that are fixed with this PR: - r"" strings are not required in units of `data_config.yaml` - dictionaries cannot be passed as argsparse, rather JSON strings. This bug is related to the flag `var_leads_metrics_watch` --------- Co-authored-by: joeloskarsson --- neural_lam/data_config.yaml | 10 +++++----- neural_lam/models/ar_model.py | 2 +- train_model.py | 13 +++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..f1527849 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -21,8 +21,8 @@ dataset: var_units: - Pa - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" + - $\mathrm{W}/\mathrm{m}^2$ + - $\mathrm{W}/\mathrm{m}^2$ - "" - "" - K @@ -33,9 +33,9 @@ dataset: - m/s - m/s - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" + - $\mathrm{kg}/\mathrm{m}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ var_longnames: - pres_heightAboveGround_0_instant - pres_heightAboveSea_0_instant diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 29b169d4..6ced211f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -473,7 +473,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.dataset.var_nums[var_i] + var = self.config_loader.dataset.var_names[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ diff --git a/train_model.py b/train_model.py index fe064384..cbd787f0 100644 --- a/train_model.py +++ b/train_model.py @@ -1,4 +1,5 @@ # Standard library +import json import random import time from argparse import ArgumentParser @@ -196,17 +197,21 @@ def main(): ) parser.add_argument( "--metrics_watch", - type=list, + nargs="+", default=[], help="List of metrics to watch, including any prefix (e.g. val_rmse)", ) parser.add_argument( "--var_leads_metrics_watch", - type=dict, - default={}, - help="Dict with variables and lead times to log watched metrics for", + type=str, + default="{}", + help="""JSON string with variable-IDs and lead times to log watched + metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", ) args = parser.parse_args() + args.var_leads_metrics_watch = { + int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() + } config_loader = config.Config.from_file(args.data_config)