From b35072d551b9698099441296615354b43a7b3cac Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 2 Dec 2024 16:25:01 +0100 Subject: [PATCH] Fix graph loading and boundary mask --- neural_lam/build_rectangular_graph.py | 10 ++++++---- neural_lam/models/ar_model.py | 4 ++-- neural_lam/models/base_graph_model.py | 2 +- neural_lam/plot_graph.py | 6 ++++-- neural_lam/train_model.py | 2 +- neural_lam/weather_dataset.py | 4 ++-- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/neural_lam/build_rectangular_graph.py b/neural_lam/build_rectangular_graph.py index e4570397..df7f8ba8 100644 --- a/neural_lam/build_rectangular_graph.py +++ b/neural_lam/build_rectangular_graph.py @@ -30,7 +30,7 @@ def main(input_args=None): help="Path to the configuration for neural-lam", ) parser.add_argument( - "--name", + "--graph_name", type=str, help="Name to save graph as (default: multiscale)", ) @@ -74,8 +74,8 @@ def main(input_args=None): args.config_path is not None ), "Specify your config with --config_path" assert ( - args.name is not None - ), "Specify the name to save graph as with --name" + args.graph_name is not None + ), "Specify the name to save graph as with --graph_name" _, datastore = load_config_and_datastore(config_path=args.config_path) @@ -124,7 +124,9 @@ def main(input_args=None): print(f"{name}: {subgraph}") # Save graph - graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name) + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) os.makedirs(graph_dir_path, exist_ok=True) for component, graph in graph_comp.items(): # This seems like a bit of a hack, maybe better if saving in wmg diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f48da665..7fca0b67 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -70,12 +70,12 @@ def __init__( static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - static_features_torch[self.boundary_mask.to(torch.bool)], + static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], persistent=False, ) self.register_buffer( "boundary_static_features", - static_features_torch[self.interior_mask.to(torch.bool)], + static_features_torch[self.interior_mask[:, 0].to(torch.bool)], persistent=False, ) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 1feec63d..52f2d7a3 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -19,7 +19,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): super().__init__(args, config=config, datastore=datastore) # Load graph with static features - graph_dir_path = datastore.root_path / "graphs" / args.graph + graph_dir_path = datastore.root_path / "graphs" / args.graph_name self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path ) diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 9d04f3e3..11bd795a 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -21,7 +21,7 @@ def main(): help="Path to the configuration for neural-lam", ) parser.add_argument( - "--name", + "--graph_name", type=str, default="multiscale", help="Name of saved graph to plot (default: multiscale)", @@ -46,7 +46,9 @@ def main(): _, datastore = load_config_and_datastore(config_path=args.config_path) # Load graph data - graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name) + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( graph_ldict["g2m_edge_index"], diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 615b10d6..1214d3ba 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -78,7 +78,7 @@ def main(input_args=None): # Model architecture parser.add_argument( - "--graph", + "--graph_name", type=str, default="multiscale", help="Graph to load and use in graph-based model " diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 27be3e7d..850aeda9 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -116,9 +116,9 @@ def __init__( # Load border/interior mask for splitting border_mask_float = torch.tensor( - self.datastore.boundary_mask, dtype=torch.float32 + self.datastore.boundary_mask.to_numpy(), dtype=torch.float32 ) - self.border_mask = border_mask_float.to(torch.bool)[:, 0] + self.border_mask = border_mask_float.to(torch.bool) self.interior_mask = torch.logical_not(self.border_mask) def __len__(self):