diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py index a0c675ac..dcbff49d 100644 --- a/neural_lam/build_graph.py +++ b/neural_lam/build_graph.py @@ -16,7 +16,7 @@ } -def main(): +def main(input_args=None): parser = argparse.ArgumentParser( description="Graph generation using WMG", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -61,16 +61,12 @@ def main(): help="Limit multi-scale mesh to given number of levels, " "from bottom up", ) - parser.add_argument( - "--hierarchical", - action="store_true", - help="Generate hierarchical mesh graph (default: False)", - ) - args = parser.parse_args() + args = parser.parse_args(input_args) # Load grid positions config_loader = config.Config.from_file(args.data_config) + # TODO Do not get normalised positions coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy() # (num_nodes_full, 2) @@ -126,13 +122,24 @@ def main(): graph, attr="direction" ) for direction, graph in m2m_direction_comp.items(): - wmg.save.to_pyg( - graph=graph, - name=f"mesh_{direction}", - list_from_attribute="level", - edge_features=["len", "vdiff"], - output_directory=args.output_dir, - ) + if direction == "same": + # Name just m2m to be consistent with non-hierarchical + wmg.save.to_pyg( + graph=graph, + name="m2m", + list_from_attribute="level", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + else: + # up and down directions + wmg.save.to_pyg( + graph=graph, + name=f"mesh_{direction}", + list_from_attribute="levels", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) else: wmg.save.to_pyg( graph=graph, diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 2f45b03f..8b8c5c85 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -30,7 +30,8 @@ def __init__( """ Create a new InteractionNet - edge_index: (2,M), Edges in pyg format + edge_index: (2,M), Edges in pyg format, with boeth sender and receiver + node indices starting at 0 input_dim: Dimensionality of input representations, for both nodes and edges update_edges: If new edge representations should be computed @@ -52,8 +53,7 @@ def __init__( # Default to input dim if not explicitly given hidden_dim = input_dim - # Make both sender and receiver indices of edge_index start at 0 - edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] + # any edge_index used here must start sender and rec. nodes at index 0 # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 edge_index[0] = ( diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 5cfdcba7..6f90c0a5 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -116,6 +116,13 @@ def __iter__(self): return (self[i] for i in range(len(self))) +def zero_index_edge_index(edge_index): + """ + Make both sender and receiver indices of edge_index start at 0 + """ + return edge_index - edge_index.min(dim=1, keepdim=True)[0] + + def load_graph(graph_name, device="cpu"): """ Load all tensors representing the graph @@ -128,11 +135,16 @@ def loads_file(fn): # Load edges (edge_index) m2m_edge_index = BufferList( - loads_file("m2m_edge_index.pt"), persistent=False + [zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")], + persistent=False, ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) + # Change first indices to 0 + g2m_edge_index = zero_index_edge_index(g2m_edge_index) + m2g_edge_index = zero_index_edge_index(m2g_edge_index) + n_levels = len(m2m_edge_index) hierarchical = n_levels > 1 # Nor just single level mesh graph @@ -168,10 +180,18 @@ def loads_file(fn): if hierarchical: # Load up and down edges and features mesh_up_edge_index = BufferList( - loads_file("mesh_up_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_up_edge_index.pt") + ], + persistent=False, ) # List of (2, M_up[l]) mesh_down_edge_index = BufferList( - loads_file("mesh_down_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_down_edge_index.pt") + ], + persistent=False, ) # List of (2, M_down[l]) mesh_up_features = loads_file( diff --git a/plot_graph.py b/plot_graph.py index 46a63a7f..b938bea6 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -47,7 +47,11 @@ def main(): # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) - (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( + ( + g2m_edge_index, + m2g_edge_index, + m2m_edge_index, + ) = ( graph_ldict["g2m_edge_index"], graph_ldict["m2g_edge_index"], graph_ldict["m2m_edge_index"], @@ -66,11 +70,9 @@ def main(): (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 ) - # List of edges to plot, (edge_index, color, line_width, label) - edge_plot_list = [ - (m2g_edge_index.numpy(), "black", 0.4, "M2G"), - (g2m_edge_index.numpy(), "black", 0.4, "G2M"), - ] + # List of edges to plot, (edge_index, from_pos, to_pos, color, + # line_width, label) + edge_plot_list = [] # Mesh positioning and edges to plot differ if we have a hierarchical graph if hierarchical: @@ -89,24 +91,80 @@ def main(): mesh_static_features, start=1 ) ] - mesh_pos = np.concatenate(mesh_level_pos, axis=0) + all_mesh_pos = np.concatenate(mesh_level_pos, axis=0) + grid_con_mesh_pos = mesh_level_pos[0] # Add inter-level mesh edges edge_plot_list += [ - (level_ei.numpy(), "blue", 1, f"M2M Level {level}") - for level, level_ei in enumerate(m2m_edge_index) + ( + level_ei.numpy(), + level_pos, + level_pos, + "blue", + 1, + f"M2M Level {level}", + ) + for level, (level_ei, level_pos) in enumerate( + zip(m2m_edge_index, mesh_level_pos) + ) ] # Add intra-level mesh edges - up_edges_ei = np.concatenate( - [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 + up_edges_ei = [ + level_up_ei.numpy() for level_up_ei in mesh_up_edge_index + ] + down_edges_ei = [ + level_down_ei.numpy() for level_down_ei in mesh_down_edge_index + ] + # Add up edges + for level_i, (up_ei, from_pos, to_pos) in enumerate( + zip(up_edges_ei, mesh_level_pos[:-1], mesh_level_pos[1:]) + ): + edge_plot_list.append( + ( + up_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh up {level_i}-{level_i+1}", + ) + ) + # Add down edges + for level_i, (down_ei, from_pos, to_pos) in enumerate( + zip(down_edges_ei, mesh_level_pos[1:], mesh_level_pos[:-1]) + ): + edge_plot_list.append( + ( + down_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh down {level_i+1}-{level_i}", + ) + ) + + edge_plot_list.append( + ( + m2g_edge_index.numpy(), + grid_con_mesh_pos, + grid_pos, + "black", + 0.4, + "M2G", + ) ) - down_edges_ei = np.concatenate( - [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], - axis=1, + edge_plot_list.append( + ( + g2m_edge_index.numpy(), + grid_pos, + grid_con_mesh_pos, + "black", + 0.4, + "G2M", + ) ) - edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) - edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) mesh_node_size = 2.5 else: @@ -120,21 +178,30 @@ def main(): (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 ) - edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) + edge_plot_list.append( + (m2m_edge_index.numpy(), mesh_pos, mesh_pos, "blue", 1, "M2M") + ) + edge_plot_list.append( + (m2g_edge_index.numpy(), mesh_pos, grid_pos, "black", 0.4, "M2G") + ) + edge_plot_list.append( + (g2m_edge_index.numpy(), grid_pos, mesh_pos, "black", 0.4, "G2M") + ) - # All node positions in one array - node_pos = np.concatenate((mesh_pos, grid_pos), axis=0) + all_mesh_pos = mesh_pos # Add edges data_objs = [] for ( ei, + from_pos, + to_pos, col, width, label, ) in edge_plot_list: - edge_start = node_pos[ei[0]] # (M, 2) - edge_end = node_pos[ei[1]] # (M, 2) + edge_start = from_pos[ei[0]] # (M, 2) + edge_end = to_pos[ei[1]] # (M, 2) n_edges = edge_start.shape[0] x_edges = np.stack( @@ -171,9 +238,9 @@ def main(): ) data_objs.append( go.Scatter3d( - x=mesh_pos[:, 0], - y=mesh_pos[:, 1], - z=mesh_pos[:, 2], + x=all_mesh_pos[:, 0], + y=all_mesh_pos[:, 1], + z=all_mesh_pos[:, 2], mode="markers", marker={"color": "blue", "size": mesh_node_size}, name="Mesh nodes", diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py index 5c8b7aa1..74f5e44d 100644 --- a/tests/test_mllam_dataset.py +++ b/tests/test_mllam_dataset.py @@ -8,7 +8,7 @@ # First-party from neural_lam.config import Config -from neural_lam.create_mesh import main as create_mesh +from neural_lam.build_graph import main as build_graph from neural_lam.train_model import main as train_model from neural_lam.utils import load_static_data from neural_lam.weather_dataset import WeatherDataset @@ -66,14 +66,15 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): n_state_features = len(var_names) n_prediction_timesteps = dataset.sample_length - n_input_steps - nx, ny = config.values["grid_shape_state"] - n_grid = nx * ny + static_data = load_static_data(dataset_name) + n_grid = static_data["interior_mask"].sum().item() + n_boundary = static_data["boundary_mask"].sum().item() # check that the dataset is not empty assert len(dataset) > 0 # get the first item - init_states, target_states, forcing = dataset[0] + init_states, target_states, forcing, boundary_forcing = dataset[0] # check that the shapes of the tensors are correct assert init_states.shape == (n_input_steps, n_grid, n_state_features) @@ -87,6 +88,11 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): n_grid, n_forcing_features, ) + assert boundary_forcing.shape == ( + n_prediction_timesteps, + n_boundary, + 2 * n_grid + n_forcing_features, # TODO Adjust dimensionality + ) static_data = load_static_data(dataset_name=dataset_name) @@ -117,12 +123,14 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): def test_create_graph_reduced_meps_dataset(): args = [ - "--graph=hierarchical", - "--hierarchical", + "--output_dir=graphs/reduced_meps_hierarchical", + "--archetype=hierarchical", "--data_config=data/meps_example_reduced/data_config.yaml", - "--levels=2", + "--max_num_levels=2", + "--mesh_node_distance=0.05", + # Distance for normalized data, might need adjustment ] - create_mesh(args) + build_graph(args) def test_train_model_reduced_meps_dataset(): @@ -131,7 +139,7 @@ def test_train_model_reduced_meps_dataset(): "--data_config=data/meps_example_reduced/data_config.yaml", "--n_workers=4", "--epochs=1", - "--graph=hierarchical", + "--graph=reduced_meps_hierarchical", "--hidden_dim=16", "--hidden_layers=1", "--processor_layers=1",