From 54c44cdfdbcbb8ed887a1ca0f5bdc5d45e34c22d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 25 May 2024 19:31:53 +0200 Subject: [PATCH 1/2] three minor bugfixes --- neural_lam/data_config.yaml | 10 +++++----- neural_lam/vis.py | 4 ++-- train_model.py | 9 ++++++--- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..f1527849 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -21,8 +21,8 @@ dataset: var_units: - Pa - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" + - $\mathrm{W}/\mathrm{m}^2$ + - $\mathrm{W}/\mathrm{m}^2$ - "" - "" - K @@ -33,9 +33,9 @@ dataset: - m/s - m/s - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" + - $\mathrm{kg}/\mathrm{m}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ var_longnames: - pres_heightAboveGround_0_instant - pres_heightAboveSea_0_instant diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 2b6abf15..8c9ca77c 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) # Plot pred and target @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): fig, ax = plt.subplots( figsize=(5, 4.8), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) ax.coastlines() # Add coastline outlines diff --git a/train_model.py b/train_model.py index 390da6d4..df63bcfe 100644 --- a/train_model.py +++ b/train_model.py @@ -1,4 +1,5 @@ # Standard library +import json import random import time from argparse import ArgumentParser @@ -202,11 +203,13 @@ def main(): ) parser.add_argument( "--var_leads_metrics_watch", - type=dict, - default={}, - help="Dict with variables and lead times to log watched metrics for", + type=str, + default="{}", + help="JSON string with variables and lead times to log watched metrics" + # e.g. '{"var1": [1, 2], "var2": [3, 4]}' ) args = parser.parse_args() + args.var_leads_metrics_watch = json.loads(args.var_leads_metrics_watch) config_loader = config.Config.from_file(args.data_config) From 2c07814e55165baf5bbbeef590b2735bc2483e4e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 30 May 2024 16:08:52 +0200 Subject: [PATCH 2/2] Misunderstood lists and dicts in argsparge -> bugfix --- neural_lam/models/ar_model.py | 2 +- train_model.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 29b169d4..6ced211f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -473,7 +473,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.dataset.var_nums[var_i] + var = self.config_loader.dataset.var_names[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ diff --git a/train_model.py b/train_model.py index 1f348e34..cbd787f0 100644 --- a/train_model.py +++ b/train_model.py @@ -197,7 +197,7 @@ def main(): ) parser.add_argument( "--metrics_watch", - type=list, + nargs="+", default=[], help="List of metrics to watch, including any prefix (e.g. val_rmse)", ) @@ -205,11 +205,13 @@ def main(): "--var_leads_metrics_watch", type=str, default="{}", - help="JSON string with variables and lead times to log watched metrics" - # e.g. '{"var1": [1, 2], "var2": [3, 4]}' + help="""JSON string with variable-IDs and lead times to log watched + metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", ) args = parser.parse_args() - args.var_leads_metrics_watch = json.loads(args.var_leads_metrics_watch) + args.var_leads_metrics_watch = { + int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() + } config_loader = config.Config.from_file(args.data_config)