From 6f1efd657e76fa1290b33d671c2910cf42602e46 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 10 Sep 2024 17:07:06 +0200 Subject: [PATCH] forcing_window_size from args --- neural_lam/models/ar_model.py | 5 ++++- neural_lam/models/base_graph_model.py | 6 ++---- neural_lam/models/graph_lam.py | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a0a7880c..203b20c5 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -23,7 +23,9 @@ class ARModel(pl.LightningModule): # Disable to override args/kwargs from superclass def __init__( - self, args, datastore: BaseDatastore, forcing_window_size: int + self, + args, + datastore: BaseDatastore, ): super().__init__() self.save_hyperparameters(ignore=["datastore"]) @@ -38,6 +40,7 @@ def __init__( ) da_state_stats = datastore.get_normalization_dataarray(category="state") da_boundary_mask = datastore.boundary_mask + forcing_window_size = args.forcing_window_size # Load static features for grid/data, NB: self.predict_step assumes # dimension order to be (grid_index, static_feature) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 16897e4f..b9dce90f 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -13,10 +13,8 @@ class BaseGraphModel(ARModel): the encode-process-decode idea. """ - def __init__(self, args, datastore, forcing_window_size): - super().__init__( - args, datastore=datastore, forcing_window_size=forcing_window_size - ) + def __init__(self, args, datastore): + super().__init__(args, datastore=datastore) # Load graph with static features # NOTE: (IMPORTANT!) mesh nodes MUST have the first diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index a4c726b1..9288e539 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -15,8 +15,8 @@ class GraphLAM(BaseGraphModel): Oskarsson et al. (2023). """ - def __init__(self, args, datastore, forcing_window_size): - super().__init__(args, datastore, forcing_window_size) + def __init__(self, args, datastore): + super().__init__(args, datastore) assert ( not self.hierarchical