diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..1a4fc983 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,28 @@ +name: Run pre-commit in blueprint + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + blueprint-pre-commit: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.11.7 + - name: Install pre-commit hooks + run: | + pip install -r requirements.txt + - name: Run pre-commit hooks + run: | + pre-commit run --all-files diff --git a/.gitignore b/.gitignore index 605cf968..7bb826a2 100644 --- a/.gitignore +++ b/.gitignore @@ -72,4 +72,3 @@ tags # Coc configuration directory .vim - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..f48eca67 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,51 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-ast + - id: check-case-conflict + - id: check-docstring-first + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: local + hooks: + - id: codespell + name: codespell + description: Check for spelling errors + language: system + entry: codespell +- repo: local + hooks: + - id: black + name: black + description: Format Python code + language: system + entry: black + types_or: [python, pyi] +- repo: local + hooks: + - id: isort + name: isort + description: Group and sort Python imports + language: system + entry: isort + types_or: [python, pyi, cython] +- repo: local + hooks: + - id: flake8 + name: flake8 + description: Check Python code for correctness, consistency and adherence to best practices + language: system + entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503 + 
types: [python] +- repo: local + hooks: + - id: pylint + name: pylint + entry: pylint -rn -sn + language: system + types: [python] diff --git a/README.md b/README.md index 5413c8cd..67d9d9b1 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,9 @@ See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://git Below follows instructions on how to use Neural-LAM to train and evaluate models. ## Installation -Follow the steps below to create the neccesary python environment. +Follow the steps below to create the necessary python environment. -1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is neccesary for the Cartopy requirement. +1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement. 2. Use python 3.9. 3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system. 4. Install required packages specified in `requirements.txt`. @@ -160,7 +160,7 @@ python train_model.py --model hi_lam --graph hierarchical ... ``` ### Hi-LAM-Parallel -A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in paralell. +A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is run in parallel. Not included in the paper as initial experiments showed worse results than Hi-LAM, but could be interesting to try in more settings. To train Hi-LAM-Parallel use @@ -270,6 +270,16 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels). Entries 0 in these lists describe edges between the lowest levels 1 and 2. 
+# Development and Contributing +Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks. +These hooks will run a series of checks on the code, like formatting and linting. +If any of these checks fail the push or PR will be rejected. +To test whether your code passes these checks before pushing, run +``` bash +pre-commit run --all-files +``` +from the root directory of the repository. + # Contact If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/create_grid_features.py b/create_grid_features.py index 869d0fa2..c9038103 100644 --- a/create_grid_features.py +++ b/create_grid_features.py @@ -1,42 +1,59 @@ +# Standard library import os -from tqdm import tqdm from argparse import ArgumentParser + +# Third-party import numpy as np import torch + def main(): - parser = ArgumentParser(description='Training arguments') - parser.add_argument('--dataset', type=str, default="meps_example", - help='Dataset to compute weights for (default: meps_example)') + """ + Pre-compute all static features related to the grid nodes + """ + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset to compute weights for (default: meps_example)", + ) args = parser.parse_args() static_dir_path = os.path.join("data", args.dataset, "static") # -- Static grid node features -- - grid_xy = torch.tensor(np.load(os.path.join(static_dir_path, "nwp_xy.npy") - )) # (2, N_x, N_y) - grid_xy = grid_xy.flatten(1,2).T # (N_grid, 2) + grid_xy = torch.tensor( + np.load(os.path.join(static_dir_path, "nwp_xy.npy")) + ) # (2, N_x, N_y) + grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2) pos_max = torch.max(torch.abs(grid_xy)) grid_xy = grid_xy / 
pos_max # Divide by maximum coordinate - geopotential = torch.tensor(np.load(os.path.join(static_dir_path, - "surface_geopotential.npy"))) # (N_x, N_y) - geopotential = geopotential.flatten(0,1).unsqueeze(1) # (N_grid,1) + geopotential = torch.tensor( + np.load(os.path.join(static_dir_path, "surface_geopotential.npy")) + ) # (N_x, N_y) + geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1) gp_min = torch.min(geopotential) gp_max = torch.max(geopotential) # Rescale geopotential to [0,1] - geopotential = (geopotential - gp_min)/(gp_max - gp_min) # (N_grid, 1) + geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1) - grid_border_mask = torch.tensor(np.load(os.path.join(static_dir_path, - "border_mask.npy")), dtype=torch.int64) # (N_x, N_y) - grid_border_mask = grid_border_mask.flatten(0, 1).to( - torch.float).unsqueeze(1) # (N_grid, 1) + grid_border_mask = torch.tensor( + np.load(os.path.join(static_dir_path, "border_mask.npy")), + dtype=torch.int64, + ) # (N_x, N_y) + grid_border_mask = ( + grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1) + ) # (N_grid, 1) # Concatenate grid features - grid_features = torch.cat((grid_xy, geopotential, grid_border_mask), - dim=1) # (N_grid, 4) + grid_features = torch.cat( + (grid_xy, geopotential, grid_border_mask), dim=1 + ) # (N_grid, 4) torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt")) + if __name__ == "__main__": main() diff --git a/create_mesh.py b/create_mesh.py index 27c54500..cb524cd6 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -1,16 +1,20 @@ +# Standard library +import os +from argparse import ArgumentParser + +# Third-party +import matplotlib +import matplotlib.pyplot as plt import networkx import numpy as np -import matplotlib.pyplot as plt -import matplotlib -from argparse import ArgumentParser import scipy.spatial import torch import torch_geometric as pyg -import os from torch_geometric.utils.convert import from_networkx + def 
plot_graph(graph, title=None): - fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H + fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H edge_index = graph.edge_index pos = graph.pos @@ -20,24 +24,37 @@ def plot_graph(graph, title=None): if pyg.utils.is_undirected(edge_index): # Keep only 1 direction of edge_index - edge_index = edge_index[:,edge_index[0] < edge_index[1]] # (2, M/2) + edge_index = edge_index[:, edge_index[0] < edge_index[1]] # (2, M/2) # TODO: indicate direction of directed edges # Move all to cpu and numpy, compute (in)-degrees - degrees = pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() + degrees = ( + pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() + ) edge_index = edge_index.cpu().numpy() pos = pos.cpu().numpy() # Plot edges - from_pos = pos[edge_index[0]] # (M/2, 2) - to_pos = pos[edge_index[1]] # (M/2, 2) + from_pos = pos[edge_index[0]] # (M/2, 2) + to_pos = pos[edge_index[1]] # (M/2, 2) edge_lines = np.stack((from_pos, to_pos), axis=1) - axis.add_collection(matplotlib.collections.LineCollection(edge_lines, lw=0.4, - colors="black", zorder=1)) + axis.add_collection( + matplotlib.collections.LineCollection( + edge_lines, lw=0.4, colors="black", zorder=1 + ) + ) # Plot nodes - node_scatter = axis.scatter(pos[:,0], pos[:,1], c=degrees, s=3, marker="o", zorder=2, - cmap="viridis", clim=None) + node_scatter = axis.scatter( + pos[:, 0], + pos[:, 1], + c=degrees, + s=3, + marker="o", + zorder=2, + cmap="viridis", + clim=None, + ) plt.colorbar(node_scatter, aspect=50) @@ -46,90 +63,127 @@ def plot_graph(graph, title=None): return fig, axis + def sort_nodes_internally(nx_graph): # For some reason the networkx .nodes() return list can not be sorted, - # but this is the ordering used by pyg when converting. This function fixes this + # but this is the ordering used by pyg when converting. + # This function fixes this. 
H = networkx.DiGraph() H.add_nodes_from(sorted(nx_graph.nodes(data=True))) H.add_edges_from(nx_graph.edges(data=True)) return H + def save_edges(graph, name, base_path): - torch.save(graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")) - edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), - dim=1).to(torch.float32) # Save as float32 + torch.save( + graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt") + ) + edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( + torch.float32 + ) # Save as float32 torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) + def save_edges_list(graphs, name, base_path): - torch.save([graph.edge_index for graph in graphs], - os.path.join(base_path, f"{name}_edge_index.pt")) - edge_features = [torch.cat((graph.len.unsqueeze(1), graph.vdiff), - dim=1).to(torch.float32) for graph in graphs] # Save as float32 + torch.save( + [graph.edge_index for graph in graphs], + os.path.join(base_path, f"{name}_edge_index.pt"), + ) + edge_features = [ + torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( + torch.float32 + ) + for graph in graphs + ] # Save as float32 torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) + def from_networkx_with_start_index(nx_graph, start_index): pyg_graph = from_networkx(nx_graph) pyg_graph.edge_index += start_index return pyg_graph + def mk_2d_graph(xy, nx, ny): - xm,xM = np.amin(xy[0][0,:]), np.amax(xy[0][0,:]) - ym,yM = np.amin(xy[1][:,0]), np.amax(xy[1][:,0]) + xm, xM = np.amin(xy[0][0, :]), np.amax(xy[0][0, :]) + ym, yM = np.amin(xy[1][:, 0]), np.amax(xy[1][:, 0]) # avoid nodes on border - dx = (xM-xm)/nx - dy = (yM-ym)/ny - lx = np.linspace(xm+dx/2, xM-dx/2, nx) - ly = np.linspace(ym+dy/2, yM-dy/2, ny) + dx = (xM - xm) / nx + dy = (yM - ym) / ny + lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) + ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) - mg = np.meshgrid(lx,ly) - g = 
networkx.grid_2d_graph(len(ly),len(lx)) + mg = np.meshgrid(lx, ly) + g = networkx.grid_2d_graph(len(ly), len(lx)) for node in g.nodes: - g.nodes[node]['pos'] = np.array([mg[0][node],mg[1][node]]) + g.nodes[node]["pos"] = np.array([mg[0][node], mg[1][node]]) # add diagonal edges - g.add_edges_from([ - ((x, y), (x+1, y+1)) - for x in range(nx-1) - for y in range(ny-1) - ] + [ - ((x+1, y), (x, y+1)) - for x in range(nx-1) - for y in range(ny-1) - ]) + g.add_edges_from( + [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] + + [ + ((x + 1, y), (x, y + 1)) + for x in range(nx - 1) + for y in range(ny - 1) + ] + ) # turn into directed graph dg = networkx.DiGraph(g) - for (u, v) in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]['pos']-g.nodes[v]['pos'])**2)) - dg.edges[u,v]['len'] = d - dg.edges[u,v]['vdiff'] = g.nodes[u]['pos']-g.nodes[v]['pos'] + for u, v in g.edges(): + d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) + dg.edges[u, v]["len"] = d + dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] dg.add_edge(v, u) - dg.edges[v,u]['len'] = d - dg.edges[v,u]['vdiff'] = g.nodes[v]['pos']-g.nodes[u]['pos'] + dg.edges[v, u]["len"] = d + dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] return dg + def prepend_node_index(graph, new_index): # Relabel node indices in graph, insert (graph_level, i, j) - ijk= [tuple((new_index,)+x) for x in graph.nodes] + ijk = [tuple((new_index,) + x) for x in graph.nodes] to_mapping = dict(zip(graph.nodes, ijk)) return networkx.relabel_nodes(graph, to_mapping, copy=True) + def main(): - parser = ArgumentParser(description='Graph genreation arguments') - parser.add_argument('--dataset', type=str, default="meps_example", - help='Dataset to load grid point coordinates from (default: meps_example)') - parser.add_argument('--graph', type=str, default="multiscale", - help='Name to save graph as (default: multiscale)') - parser.add_argument('--plot', type=int, default=0, - help='If graphs 
should be plotted during generation (default: 0 (false))') - parser.add_argument('--levels', type=int, - help='Limit multi-scale mesh to given number of levels, ' - 'from bottom up (default: None (no limit))') - parser.add_argument('--hierarchical', type=int, default=0, - help='Generate hierarchical mesh graph (default: 0, no)') + parser = ArgumentParser(description="Graph generation arguments") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset to load grid point coordinates from " + "(default: meps_example)", + ) + parser.add_argument( + "--graph", + type=str, + default="multiscale", + help="Name to save graph as (default: multiscale)", + ) + parser.add_argument( + "--plot", + type=int, + default=0, + help="If graphs should be plotted during generation " + "(default: 0 (false))", + ) + parser.add_argument( + "--levels", + type=int, + help="Limit multi-scale mesh to given number of levels, " + "from bottom up (default: None (no limit))", + ) + parser.add_argument( + "--hierarchical", + type=int, + default=0, + help="Generate hierarchical mesh graph (default: 0, no)", + ) args = parser.parse_args() # Load grid positions @@ -147,11 +201,11 @@ def main(): # # graph geometry - nx = 3 # number of children = nx**2 - nlev = int(np.log(max(xy.shape))/np.log(nx)) - nleaf = nx**nlev # leaves at the bottom = nleaf**2 + nx = 3 # number of children = nx**2 + nlev = int(np.log(max(xy.shape)) / np.log(nx)) + nleaf = nx**nlev # leaves at the bottom = nleaf**2 - mesh_levels = nlev-1 + mesh_levels = nlev - 1 if args.levels: # Limit the levels in mesh graph mesh_levels = min(mesh_levels, args.levels) @@ -160,8 +214,8 @@ def main(): # multi resolution tree levels G = [] - for lev in range(1, mesh_levels+1): - n = int(nleaf/(nx**lev)) + for lev in range(1, mesh_levels + 1): + n = int(nleaf / (nx**lev)) g = mk_2d_graph(xy, n, n) if args.plot: plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}") @@ -171,24 +225,27 @@ def main(): if 
args.hierarchical: # Relabel nodes of each level with level index first - G = [prepend_node_index(graph, level_i) for level_i, graph in enumerate(G)] + G = [ + prepend_node_index(graph, level_i) + for level_i, graph in enumerate(G) + ] num_nodes_level = np.array([len(g_level.nodes) for g_level in G]) - # First node index in each level in the hierarcical graph - first_index_level = np.concatenate(( - np.zeros(1, dtype=int), - np.cumsum(num_nodes_level[:-1]))) + # First node index in each level in the hierarchical graph + first_index_level = np.concatenate( + (np.zeros(1, dtype=int), np.cumsum(num_nodes_level[:-1])) + ) # Create inter-level mesh edges up_graphs = [] down_graphs = [] for from_level, to_level, G_from, G_to, start_index in zip( - range(1, mesh_levels), - range(0, mesh_levels-1), - G[1:], - G[:-1], - first_index_level[:mesh_levels-1]): - + range(1, mesh_levels), + range(0, mesh_levels - 1), + G[1:], + G[:-1], + first_index_level[: mesh_levels - 1], + ): # start out from graph at from level G_down = G_from.copy() G_down.clear_edges() @@ -201,30 +258,38 @@ def main(): # order in vm should be same as in vm_xy v_to_list = list(G_to.nodes) v_from_list = list(G_from.nodes) - v_from_xy = np.array([xy for _, xy in G_from.nodes.data('pos')]) + v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) kdt_m = scipy.spatial.KDTree(v_from_xy) # add edges from mesh to grid for v in v_to_list: # find 1(?) 
nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]['pos'], 1)[1] + neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] u = v_from_list[neigh_idx] # add edge from mesh to grid G_down.add_edge(u, v) - d = np.sqrt(np.sum((G_down.nodes[u]['pos']-G_down.nodes[v]['pos'])**2)) - G_down.edges[u,v]['len'] = d - G_down.edges[u,v]['vdiff'] = G_down.nodes[u]['pos']-G_down.nodes[v]['pos'] + d = np.sqrt( + np.sum( + (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 + ) + ) + G_down.edges[u, v]["len"] = d + G_down.edges[u, v]["vdiff"] = ( + G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] + ) # relabel nodes to integers (sorted) - G_down_int = networkx.convert_node_labels_to_integers(G_down, - first_label=start_index, ordering='sorted') # Issue with sorting here + G_down_int = networkx.convert_node_labels_to_integers( + G_down, first_label=start_index, ordering="sorted" + ) # Issue with sorting here G_down_int = sort_nodes_internally(G_down_int) - pyg_down = from_networkx_with_start_index(G_down_int , start_index) + pyg_down = from_networkx_with_start_index(G_down_int, start_index) # Create up graph, invert downwards edges - up_edges = torch.stack((pyg_down.edge_index[1], pyg_down.edge_index[0]), - dim=0) + up_edges = torch.stack( + (pyg_down.edge_index[1], pyg_down.edge_index[0]), dim=0 + ) pyg_up = pyg_down.clone() pyg_up.edge_index = up_edges @@ -232,10 +297,14 @@ def main(): down_graphs.append(pyg_down) if args.plot: - plot_graph(pyg_down, title=f"Down graph, {from_level} -> {to_level}") + plot_graph( + pyg_down, title=f"Down graph, {from_level} -> {to_level}" + ) plt.show() - plot_graph(pyg_down, title=f"Up graph, {to_level} -> {from_level}") + plot_graph( + pyg_down, title=f"Up graph, {to_level} -> {from_level}" + ) plt.show() # Save up and down edges @@ -243,10 +312,15 @@ def main(): save_edges_list(down_graphs, "mesh_down", graph_dir_path) # Extract intra-level edges for m2m - m2m_graphs = [from_networkx_with_start_index( - 
networkx.convert_node_labels_to_integers(level_graph, - first_label=start_index, ordering='sorted'), start_index) - for level_graph, start_index in zip(G, first_index_level)] + m2m_graphs = [ + from_networkx_with_start_index( + networkx.convert_node_labels_to_integers( + level_graph, first_label=start_index, ordering="sorted" + ), + start_index, + ) + for level_graph, start_index in zip(G, first_index_level) + ] mesh_pos = [graph.pos.to(torch.float32) for graph in m2m_graphs] @@ -259,10 +333,14 @@ def main(): else: # combine all levels to one graph G_tot = G[0] - for lev in range(1,len(G)): - nodes = list(G[lev-1].nodes) + for lev in range(1, len(G)): + nodes = list(G[lev - 1].nodes) n = int(np.sqrt(len(nodes))) - ij = np.array(nodes).reshape((n,n,2))[1::nx,1::nx,:].reshape(int(n/nx)**2,2) + ij = ( + np.array(nodes) + .reshape((n, n, 2))[1::nx, 1::nx, :] + .reshape(int(n / nx) ** 2, 2) + ) ij = [tuple(x) for x in ij] G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) G_tot = networkx.compose(G_tot, G[lev]) @@ -271,8 +349,9 @@ def main(): G_tot = prepend_node_index(G_tot, 0) # relabel nodes to integers (sorted) - G_int = networkx.convert_node_labels_to_integers(G_tot, first_label=0, - ordering='sorted') + G_int = networkx.convert_node_labels_to_integers( + G_tot, first_label=0, ordering="sorted" + ) # Graph to use in g2m and m2g G_bottom_mesh = G_tot @@ -291,11 +370,12 @@ def main(): save_edges_list(m2m_graphs, "m2m", graph_dir_path) # Divide mesh node pos by max coordinate of grid cell - mesh_pos = [pos/pos_max for pos in mesh_pos] + mesh_pos = [pos / pos_max for pos in mesh_pos] # Save mesh positions - torch.save(mesh_pos, os.path.join(graph_dir_path, - "mesh_features.pt")) # mesh pos, in float32 + torch.save( + mesh_pos, os.path.join(graph_dir_path, "mesh_features.pt") + ) # mesh pos, in float32 # # Grid2Mesh @@ -307,9 +387,11 @@ def main(): # mesh nodes on lowest level vm = G_bottom_mesh.nodes - vm_xy = np.array([xy for _, xy in 
vm.data('pos')]) + vm_xy = np.array([xy for _, xy in vm.data("pos")]) # distance between mesh nodes - dm = np.sqrt(np.sum((vm.data('pos')[(0,1,0)] - vm.data('pos')[(0,0,0)])**2)) + dm = np.sqrt( + np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) + ) # grid nodes Ny, Nx = xy.shape[1:] @@ -320,16 +402,16 @@ def main(): # vg features (only pos introduced here) for node in G_grid.nodes: # pos is in feature but here explicit for convenience - G_grid.nodes[node]['pos'] = np.array([xy[0][node],xy[1][node]]) + G_grid.nodes[node]["pos"] = np.array([xy[0][node], xy[1][node]]) - # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes (i,j) - # and impose sorting order such that vm are the first nodes + # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes + # (i,j) and impose sorting order such that vm are the first nodes G_grid = prepend_node_index(G_grid, 1000) # build kd tree for grid point pos # order in vg_list should be same as in vg_xy vg_list = list(G_grid.nodes) - vg_xy = np.array([[xy[0][node[1:]],xy[1][node[1:]]] for node in vg_list]) + vg_xy = np.array([[xy[0][node[1:]], xy[1][node[1:]]] for node in vg_list]) kdt_g = scipy.spatial.KDTree(vg_xy) # now add (all) mesh nodes, include features (pos) @@ -346,14 +428,18 @@ def main(): # add edges for v in vm: # find neighbours (index to vg_xy) - neigh_idxs = kdt_g.query_ball_point(vm[v]['pos'], dm*DM_SCALE) + neigh_idxs = kdt_g.query_ball_point(vm[v]["pos"], dm * DM_SCALE) for i in neigh_idxs: u = vg_list[i] # add edge from grid to mesh G_g2m.add_edge(u, v) - d = np.sqrt(np.sum((G_g2m.nodes[u]['pos']-G_g2m.nodes[v]['pos'])**2)) - G_g2m.edges[u,v]['len'] = d - G_g2m.edges[u,v]['vdiff'] = G_g2m.nodes[u]['pos']-G_g2m.nodes[v]['pos'] + d = np.sqrt( + np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2) + ) + G_g2m.edges[u, v]["len"] = d + G_g2m.edges[u, v]["vdiff"] = ( + G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] + ) pyg_g2m = from_networkx(G_g2m) @@ -377,18 
+463,23 @@ def main(): # add edges from mesh to grid for v in vg_list: # find 4 nearest neighbours (index to vm_xy) - neigh_idxs = kdt_m.query(G_m2g.nodes[v]['pos'], 4)[1] + neigh_idxs = kdt_m.query(G_m2g.nodes[v]["pos"], 4)[1] for i in neigh_idxs: u = vm_list[i] # add edge from mesh to grid G_m2g.add_edge(u, v) - d = np.sqrt(np.sum((G_m2g.nodes[u]['pos']-G_m2g.nodes[v]['pos'])**2)) - G_m2g.edges[u,v]['len'] = d - G_m2g.edges[u,v]['vdiff'] = G_m2g.nodes[u]['pos']-G_m2g.nodes[v]['pos'] + d = np.sqrt( + np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2) + ) + G_m2g.edges[u, v]["len"] = d + G_m2g.edges[u, v]["vdiff"] = ( + G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] + ) # relabel nodes to integers (sorted) - G_m2g_int = networkx.convert_node_labels_to_integers(G_m2g, first_label=0, - ordering='sorted') + G_m2g_int = networkx.convert_node_labels_to_integers( + G_m2g, first_label=0, ordering="sorted" + ) pyg_m2g = from_networkx(G_m2g_int) if args.plot: @@ -401,5 +492,6 @@ def main(): # m2g save_edges(pyg_m2g, "m2g", graph_dir_path) + if __name__ == "__main__": main() diff --git a/create_parameter_weights.py b/create_parameter_weights.py index 1df33360..0fddfd15 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -1,63 +1,107 @@ +# Standard library import os -from tqdm import tqdm from argparse import ArgumentParser + +# Third-party import numpy as np import torch +from tqdm import tqdm -from neural_lam.weather_dataset import WeatherDataset +# First-party from neural_lam import constants +from neural_lam.weather_dataset import WeatherDataset + def main(): - parser = ArgumentParser(description='Training arguments') - parser.add_argument('--dataset', type=str, default="meps_example", - help='Dataset to compute weights for (default: meps_example)') - parser.add_argument('--batch_size', type=int, default=32, - help='Batch size when iterating over the dataset') - parser.add_argument('--step_length', type=int, default=3, - help='Step 
length in hours to consider single time step (default: 3)') - parser.add_argument('--n_workers', type=int, default=4, - help='Number of workers in data loader (default: 4)') + """ + Pre-compute parameter weights to be used in loss function + """ + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset to compute weights for (default: meps_example)", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size when iterating over the dataset", + ) + parser.add_argument( + "--step_length", + type=int, + default=3, + help="Step length in hours to consider single time step (default: 3)", + ) + parser.add_argument( + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", + ) args = parser.parse_args() static_dir_path = os.path.join("data", args.dataset, "static") # Create parameter weights based on height # based on fig A.1 in graph cast paper - w_par = np.zeros((len(constants.param_names),)) - w_dict = {'2': 1.0, '0': 0.1, '65': 0.065, '1000': 0.1, '850': 0.05, '500': 0.03} - w_list = np.array([w_dict[par.split('_')[-2]] for par in constants.param_names]) + w_dict = { + "2": 1.0, + "0": 0.1, + "65": 0.065, + "1000": 0.1, + "850": 0.05, + "500": 0.03, + } + w_list = np.array( + [w_dict[par.split("_")[-2]] for par in constants.PARAM_NAMES] + ) print("Saving parameter weights...") - np.save(os.path.join(static_dir_path, 'parameter_weights.npy'), - w_list.astype('float32')) + np.save( + os.path.join(static_dir_path, "parameter_weights.npy"), + w_list.astype("float32"), + ) # Load dataset without any subsampling - ds = WeatherDataset(args.dataset, split="train", subsample_step=1, pred_length=63, - standardize=False) # Without standardization - loader = torch.utils.data.DataLoader(ds, args.batch_size, shuffle=False, - num_workers=args.n_workers) - # Compute mean and std.-dev. 
of each parameter (+ flux forcing) across full dataset + ds = WeatherDataset( + args.dataset, + split="train", + subsample_step=1, + pred_length=63, + standardize=False, + ) # Without standardization + loader = torch.utils.data.DataLoader( + ds, args.batch_size, shuffle=False, num_workers=args.n_workers + ) + # Compute mean and std.-dev. of each parameter (+ flux forcing) + # across full dataset print("Computing mean and std.-dev. for parameters...") means = [] squares = [] flux_means = [] flux_squares = [] for init_batch, target_batch, _, forcing_batch in tqdm(loader): - batch = torch.cat((init_batch, target_batch), - dim=1) # (N_batch, N_t, N_grid, d_features) - means.append(torch.mean(batch, dim=(1,2))) # (N_batch, d_features,) - squares.append(torch.mean(batch**2, dim=(1,2))) # (N_batch, d_features,) - - flux_batch = forcing_batch[:,:,:,0] # Flux is first index - flux_means.append(torch.mean(flux_batch)) # (,) - flux_squares.append(torch.mean(flux_batch**2)) # (,) - - mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) + batch = torch.cat( + (init_batch, target_batch), dim=1 + ) # (N_batch, N_t, N_grid, d_features) + means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) + squares.append( + torch.mean(batch**2, dim=(1, 2)) + ) # (N_batch, d_features,) + + flux_batch = forcing_batch[:, :, :, 0] # Flux is first index + flux_means.append(torch.mean(flux_batch)) # (,) + flux_squares.append(torch.mean(flux_batch**2)) # (,) + + mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) - std = torch.sqrt(second_moment - mean**2) # (d_features) + std = torch.sqrt(second_moment - mean**2) # (d_features) - flux_mean = torch.mean(torch.stack(flux_means)) # (,) - flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,) - flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + flux_mean = torch.mean(torch.stack(flux_means)) # (,) + flux_second_moment = 
torch.mean(torch.stack(flux_squares)) # (,) + flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) flux_stats = torch.stack((flux_mean, flux_std)) print("Saving mean, std.-dev, flux_stats...") @@ -67,36 +111,53 @@ def main(): # Compute mean and std.-dev. of one-step differences across the dataset print("Computing mean and std.-dev. for one-step differences...") - ds_standard = WeatherDataset(args.dataset, split="train", subsample_step=1, - pred_length=63, standardize=True) # Re-load with standardization - loader_standard = torch.utils.data.DataLoader(ds_standard, args.batch_size, - shuffle=False, num_workers=args.n_workers) - used_subsample_len = (65//args.step_length)*args.step_length + ds_standard = WeatherDataset( + args.dataset, + split="train", + subsample_step=1, + pred_length=63, + standardize=True, + ) # Re-load with standardization + loader_standard = torch.utils.data.DataLoader( + ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers + ) + used_subsample_len = (65 // args.step_length) * args.step_length diff_means = [] diff_squares = [] for init_batch, target_batch, _, _ in tqdm(loader_standard): - batch = torch.cat((init_batch, target_batch), - dim=1) # (N_batch, N_t', N_grid, d_features) + batch = torch.cat( + (init_batch, target_batch), dim=1 + ) # (N_batch, N_t', N_grid, d_features) # Note: batch contains only 1h-steps - stepped_batch = torch.cat([batch[:,ss_i:used_subsample_len:args.step_length] - for ss_i in range(args.step_length)], dim=0) - # (N_batch', N_t, N_grid, d_features), N_batch' = args.step_length*N_batch - - batch_diffs = stepped_batch[:,1:] - stepped_batch[:,:-1] + stepped_batch = torch.cat( + [ + batch[:, ss_i : used_subsample_len : args.step_length] + for ss_i in range(args.step_length) + ], + dim=0, + ) + # (N_batch', N_t, N_grid, d_features), + # N_batch' = args.step_length*N_batch + + batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] # (N_batch', N_t-1, N_grid, d_features) - 
diff_means.append(torch.mean(batch_diffs, dim=(1,2))) # (N_batch', d_features,) - diff_squares.append(torch.mean(batch_diffs**2, - dim=(1,2))) # (N_batch', d_features,) + diff_means.append( + torch.mean(batch_diffs, dim=(1, 2)) + ) # (N_batch', d_features,) + diff_squares.append( + torch.mean(batch_diffs**2, dim=(1, 2)) + ) # (N_batch', d_features,) - diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features) + diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features) diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) - diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) + diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) print("Saving one-step difference mean and std.-dev...") torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) + if __name__ == "__main__": main() diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 9448a2db..9fbe3910 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -1,114 +1,121 @@ +# Third-party import cartopy import numpy as np -wandb_project = "neural-lam" +WANDB_PROJECT = "neural-lam" -seconds_in_year = 365*24*60*60 # Assuming no leap years in dataset (2024 is next) +SECONDS_IN_YEAR = ( + 365 * 24 * 60 * 60 +) # Assuming no leap years in dataset (2024 is next) # Log prediction error for these lead times -val_step_log_errors = np.array([1, 2, 3, 5, 10, 15, 19]) +VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19]) -# Log these metrics to wandb as scalar values for specific variables and lead times +# Log these metrics to wandb as scalar values for +# specific variables and lead times # List of metrics to watch, including any prefix (e.g. 
val_rmse) -metrics_watch = [ -] +METRICS_WATCH = [] # Dict with variables and lead times to log watched metrics for -# Format is a dictionary that maps from a variable index to a list of lead time steps -var_leads_metrics_watch = { - 6: [2, 19], # t_2 - 14: [2, 19], # wvint_0 - 15: [2, 19], # z_1000 +# Format is a dictionary that maps from a variable index to +# a list of lead time steps +VAR_LEADS_METRICS_WATCH = { + 6: [2, 19], # t_2 + 14: [2, 19], # wvint_0 + 15: [2, 19], # z_1000 } # Variable names -param_names = [ - 'pres_heightAboveGround_0_instant', - 'pres_heightAboveSea_0_instant', - 'nlwrs_heightAboveGround_0_accum', - 'nswrs_heightAboveGround_0_accum', - 'r_heightAboveGround_2_instant', - 'r_hybrid_65_instant', - 't_heightAboveGround_2_instant', - 't_hybrid_65_instant', - 't_isobaricInhPa_500_instant', - 't_isobaricInhPa_850_instant', - 'u_hybrid_65_instant', - 'u_isobaricInhPa_850_instant', - 'v_hybrid_65_instant', - 'v_isobaricInhPa_850_instant', - 'wvint_entireAtmosphere_0_instant', - 'z_isobaricInhPa_1000_instant', - 'z_isobaricInhPa_500_instant' +PARAM_NAMES = [ + "pres_heightAboveGround_0_instant", + "pres_heightAboveSea_0_instant", + "nlwrs_heightAboveGround_0_accum", + "nswrs_heightAboveGround_0_accum", + "r_heightAboveGround_2_instant", + "r_hybrid_65_instant", + "t_heightAboveGround_2_instant", + "t_hybrid_65_instant", + "t_isobaricInhPa_500_instant", + "t_isobaricInhPa_850_instant", + "u_hybrid_65_instant", + "u_isobaricInhPa_850_instant", + "v_hybrid_65_instant", + "v_isobaricInhPa_850_instant", + "wvint_entireAtmosphere_0_instant", + "z_isobaricInhPa_1000_instant", + "z_isobaricInhPa_500_instant", ] -param_names_short = [ - 'pres_0g', - 'pres_0s', - 'nlwrs_0', - 'nswrs_0', - 'r_2', - 'r_65', - 't_2', - 't_65', - 't_500', - 't_850', - 'u_65', - 'u_850', - 'v_65', - 'v_850', - 'wvint_0', - 'z_1000', - 'z_500' +PARAM_NAMES_SHORT = [ + "pres_0g", + "pres_0s", + "nlwrs_0", + "nswrs_0", + "r_2", + "r_65", + "t_2", + "t_65", + "t_500", + "t_850", + 
"u_65", + "u_850", + "v_65", + "v_850", + "wvint_0", + "z_1000", + "z_500", ] -param_units = [ - 'Pa', - 'Pa', - 'W/m\\textsuperscript{2}', - 'W/m\\textsuperscript{2}', - '-', # unitless - '-', - 'K', - 'K', - 'K', - 'K', - 'm/s', - 'm/s', - 'm/s', - 'm/s', - 'kg/m\\textsuperscript{2}', - 'm\\textsuperscript{2}/s\\textsuperscript{2}', - 'm\\textsuperscript{2}/s\\textsuperscript{2}' +PARAM_UNITS = [ + "Pa", + "Pa", + "W/m\\textsuperscript{2}", + "W/m\\textsuperscript{2}", + "-", # unitless + "-", + "K", + "K", + "K", + "K", + "m/s", + "m/s", + "m/s", + "m/s", + "kg/m\\textsuperscript{2}", + "m\\textsuperscript{2}/s\\textsuperscript{2}", + "m\\textsuperscript{2}/s\\textsuperscript{2}", ] # Projection and grid -# TODO Do not hard code this, make part of static dataset files -grid_shape = (268, 238) # (y, x) +# Hard coded for now, but should eventually be part of dataset desc. files +GRID_SHAPE = (268, 238) # (y, x) -lambert_proj_params = { - 'a': 6367470, - 'b': 6367470, - 'lat_0': 63.3, - 'lat_1': 63.3, - 'lat_2': 63.3, - 'lon_0': 15.0, - 'proj': 'lcc' - } +LAMBERT_PROJ_PARAMS = { + "a": 6367470, + "b": 6367470, + "lat_0": 63.3, + "lat_1": 63.3, + "lat_2": 63.3, + "lon_0": 15.0, + "proj": "lcc", +} -grid_limits = [ # In projection - -1059506.5523409774, # min x - 1310493.4476590226, # max x - -1331732.4471934352, # min y - 1338267.5528065648, # max y +GRID_LIMITS = [ # In projection + -1059506.5523409774, # min x + 1310493.4476590226, # max x + -1331732.4471934352, # min y + 1338267.5528065648, # max y ] # Create projection -lambert_proj = cartopy.crs.LambertConformal( - central_longitude=lambert_proj_params['lon_0'], - central_latitude=lambert_proj_params['lat_0'], - standard_parallels=(lambert_proj_params['lat_1'], - lambert_proj_params['lat_2'])) +LAMBERT_PROJ = cartopy.crs.LambertConformal( + central_longitude=LAMBERT_PROJ_PARAMS["lon_0"], + central_latitude=LAMBERT_PROJ_PARAMS["lat_0"], + standard_parallels=( + LAMBERT_PROJ_PARAMS["lat_1"], + 
LAMBERT_PROJ_PARAMS["lat_2"], + ), +) # Data dimensions -batch_static_feature_dim = 1 # Only open water -grid_forcing_dim = 5*3 # 5 features for 3 time-step window -grid_state_dim = 17 +BATCH_STATIC_FEATURE_DIM = 1 # Only open water +GRID_FORCING_DIM = 5 * 3 # 5 features for 3 time-step window +GRID_STATE_DIM = 17 diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index a31e3d95..663f27e4 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -1,27 +1,48 @@ +# Third-party import torch -from torch import nn import torch_geometric as pyg +from torch import nn +# First-party from neural_lam import utils + class InteractionNet(pyg.nn.MessagePassing): """ - Implementation of a generic Interaction Network, from Battaglia et al. (2016) + Implementation of a generic Interaction Network, + from Battaglia et al. (2016) """ - def __init__(self, edge_index, input_dim, update_edges=True, hidden_layers=1, - hidden_dim=None, edge_chunk_sizes=None, aggr_chunk_sizes=None, aggr="sum"): + + # pylint: disable=arguments-differ + # Disable to override args/kwargs from superclass + + def __init__( + self, + edge_index, + input_dim, + update_edges=True, + hidden_layers=1, + hidden_dim=None, + edge_chunk_sizes=None, + aggr_chunk_sizes=None, + aggr="sum", + ): """ Create a new InteractionNet edge_index: (2,M), Edges in pyg format - input_dim: Dimensionality of input representations, for both nodes and edges - update_edges: If new edge representations should be computed and returned + input_dim: Dimensionality of input representations, + for both nodes and edges + update_edges: If new edge representations should be computed + and returned hidden_layers: Number of hidden layers in MLPs - hidden_dim: Dimensionality of hidden layers, if None then same as input_dim - edge_chunk_sizes: List of chunks sizes to split edge representation into and - use separate MLPs for (None = no chunking, same MLP) - aggr_chunk_sizes: List of chunks sizes to split 
aggregated node representation + hidden_dim: Dimensionality of hidden layers, if None then same + as input_dim + edge_chunk_sizes: List of chunks sizes to split edge representation into and use separate MLPs for (None = no chunking, same MLP) + aggr_chunk_sizes: List of chunks sizes to split aggregated node + representation into and use separate MLPs for + (None = no chunking, same MLP) aggr: Message aggregation method (sum/mean) """ assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}" @@ -35,31 +56,37 @@ def __init__(self, edge_index, input_dim, update_edges=True, hidden_layers=1, edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 - edge_index[0] = edge_index[0] + self.num_rec # Make sender indices after rec + edge_index[0] = ( + edge_index[0] + self.num_rec + ) # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) # Create MLPs - edge_mlp_recipe = [3*input_dim] + [hidden_dim]*(hidden_layers + 1) - aggr_mlp_recipe = [2*input_dim] + [hidden_dim]*(hidden_layers + 1) + edge_mlp_recipe = [3 * input_dim] + [hidden_dim] * (hidden_layers + 1) + aggr_mlp_recipe = [2 * input_dim] + [hidden_dim] * (hidden_layers + 1) if edge_chunk_sizes is None: self.edge_mlp = utils.make_mlp(edge_mlp_recipe) else: - self.edge_mlp = SplitMLPs([utils.make_mlp(edge_mlp_recipe) for _ in - edge_chunk_sizes], edge_chunk_sizes) + self.edge_mlp = SplitMLPs( + [utils.make_mlp(edge_mlp_recipe) for _ in edge_chunk_sizes], + edge_chunk_sizes, + ) if aggr_chunk_sizes is None: self.aggr_mlp = utils.make_mlp(aggr_mlp_recipe) else: - self.aggr_mlp = SplitMLPs([utils.make_mlp(aggr_mlp_recipe) for _ in - aggr_chunk_sizes], aggr_chunk_sizes) + self.aggr_mlp = SplitMLPs( + [utils.make_mlp(aggr_mlp_recipe) for _ in aggr_chunk_sizes], + aggr_chunk_sizes, + ) self.update_edges = update_edges def forward(self, send_rep, rec_rep, edge_rep): """ - 
Apply interaction network to update the representations of receiver nodes, - and optionally the edge representations. + Apply interaction network to update the representations of receiver + nodes, and optionally the edge representations. send_rep: (N_send, d_h), vector representations of sender nodes rec_rep: (N_rec, d_h), vector representations of receiver nodes @@ -67,13 +94,15 @@ def forward(self, send_rep, rec_rep, edge_rep): Returns: rec_rep: (N_rec, d_h), updated vector representations of receiver nodes - (optionally) edge_rep: (M, d_h), updated vector representations of edges + (optionally) edge_rep: (M, d_h), updated vector representations + of edges """ - # Always concatenate to [rec_nodes, send_nodes] for propagation, but only - # aggregate to rec_nodes + # Always concatenate to [rec_nodes, send_nodes] for propagation, + # but only aggregate to rec_nodes node_reps = torch.cat((rec_rep, send_rep), dim=-2) - edge_rep_aggr, edge_diff = self.propagate(self.edge_index, x=node_reps, - edge_attr=edge_rep) + edge_rep_aggr, edge_diff = self.propagate( + self.edge_index, x=node_reps, edge_attr=edge_rep + ) rec_diff = self.aggr_mlp(torch.cat((rec_rep, edge_rep_aggr), dim=-1)) # Residual connections @@ -91,14 +120,16 @@ def message(self, x_j, x_i, edge_attr): """ return self.edge_mlp(torch.cat((edge_attr, x_j, x_i), dim=-1)) - def aggregate(self, messages, index, ptr, dim_size): + # pylint: disable-next=signature-differs + def aggregate(self, inputs, index, ptr, dim_size): """ Overridden aggregation function to: * return both aggregated and original messages, * only aggregate to number of receiver nodes. """ - aggr = super().aggregate(messages, index, ptr, self.num_rec) - return aggr, messages + aggr = super().aggregate(inputs, index, ptr, self.num_rec) + return aggr, inputs + class SplitMLPs(nn.Module): """ @@ -106,10 +137,12 @@ class SplitMLPs(nn.Module): Split up input along dim -2 using given chunk sizes and feeds each chunk through separate MLPs. 
""" + def __init__(self, mlps, chunk_sizes): super().__init__() - assert len(mlps) == len(chunk_sizes), ( - "Number of MLPs must match the number of chunks") + assert len(mlps) == len( + chunk_sizes + ), "Number of MLPs must match the number of chunks" self.mlps = nn.ModuleList(mlps) self.chunk_sizes = chunk_sizes @@ -121,8 +154,10 @@ def forward(self, x): x: (..., N, d), where N = sum(chunk_sizes) Returns: - joined_output: (..., N, d), concatenated results from the different MLPs + joined_output: (..., N, d), concatenated results from the MLPs """ chunks = torch.split(x, self.chunk_sizes, dim=-2) - chunk_outputs = [mlp(chunk_input) for mlp, chunk_input in zip(self.mlps, chunks)] + chunk_outputs = [ + mlp(chunk_input) for mlp, chunk_input in zip(self.mlps, chunks) + ] return torch.cat(chunk_outputs, dim=-2) diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py index 5b3270ab..93014fd3 100644 --- a/neural_lam/metrics.py +++ b/neural_lam/metrics.py @@ -1,5 +1,7 @@ +# Third-party import torch + def get_metric(metric_name): """ Get a defined metric with given name @@ -10,202 +12,256 @@ def get_metric(metric_name): metric: function implementing the metric """ metric_name_lower = metric_name.lower() - assert metric_name_lower in DEFINED_METRICS, f"Unknown metric: {metric_name}" + assert ( + metric_name_lower in DEFINED_METRICS + ), f"Unknown metric: {metric_name}" return DEFINED_METRICS[metric_name_lower] + def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars): """ Masks and (optionally) reduces entry-wise metric values - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable metric_entry_vals: (..., N, d_state), prediction mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced 
(sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. """ # Only keep grid nodes in mask if mask is not None: - metric_entry_vals = metric_entry_vals[...,mask,:] # (..., N', d_state) + metric_entry_vals = metric_entry_vals[ + ..., mask, : + ] # (..., N', d_state) # Optionally reduce last two dimensions - if average_grid: # Reduce grid first - metric_entry_vals = torch.mean(metric_entry_vals, dim=-2) # (..., d_state) - if sum_vars: # Reduce vars second - metric_entry_vals = torch.sum(metric_entry_vals, dim=-1) # (..., N) or (...,) + if average_grid: # Reduce grid first + metric_entry_vals = torch.mean( + metric_entry_vals, dim=-2 + ) # (..., d_state) + if sum_vars: # Reduce vars second + metric_entry_vals = torch.sum( + metric_entry_vals, dim=-1 + ) # (..., N) or (...,) return metric_entry_vals + def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ Weighted Mean Squared Error - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. 
+ metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. """ - entry_mse = torch.nn.functional.mse_loss(pred, target, - reduction='none') # (..., N, d_state) - entry_mse_weighted = entry_mse / (pred_std**2) # (..., N, d_state) + entry_mse = torch.nn.functional.mse_loss( + pred, target, reduction="none" + ) # (..., N, d_state) + entry_mse_weighted = entry_mse / (pred_std**2) # (..., N, d_state) + + return mask_and_reduce_metric( + entry_mse_weighted, + mask=mask, + average_grid=average_grid, + sum_vars=sum_vars, + ) - return mask_and_reduce_metric(entry_mse_weighted, mask=mask, - average_grid=average_grid, sum_vars=sum_vars) def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ (Unweighted) Mean Squared Error - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. 
""" # Replace pred_std with constant ones - return wmse(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars) + return wmse( + pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars + ) + def rmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ Root Mean Squared Error - Note: here take sqrt only after spatial averaging, averaging the RMSE of forecasts. + Note: here take sqrt only after spatial averaging, averaging the RMSE + of forecasts. This is consistent with Weatherbench and others. Because of this, averaging over grid must be set to true. - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), depending on reduction arguments. 
+ metric_val: One of (...,), (..., d_state), depending on reduction arguments """ assert average_grid, "Can not compute RMSE without averaging grid" # Spatially averaged mse, masking is also performed here - averaged_mse = mse(pred, target, pred_std, mask, average_grid=True, - sum_vars=False) # (..., d_state) - entry_rmse = torch.sqrt(averaged_mse) # (..., d_state) + averaged_mse = mse( + pred, target, pred_std, mask, average_grid=True, sum_vars=False + ) # (..., d_state) + entry_rmse = torch.sqrt(averaged_mse) # (..., d_state) # Optionally sum over variables here manually if sum_vars: - return torch.sum(entry_rmse, dim=-1) # (...,) + return torch.sum(entry_rmse, dim=-1) # (...,) + + return entry_rmse # (..., d_state) - return entry_rmse # (..., d_state) def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ Weighted Mean Absolute Error - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. 
""" - entry_mae = torch.nn.functional.l1_loss(pred, target, - reduction='none') # (..., N, d_state) - entry_mae_weighted = entry_mae / pred_std # (..., N, d_state) + entry_mae = torch.nn.functional.l1_loss( + pred, target, reduction="none" + ) # (..., N, d_state) + entry_mae_weighted = entry_mae / pred_std # (..., N, d_state) + + return mask_and_reduce_metric( + entry_mae_weighted, + mask=mask, + average_grid=average_grid, + sum_vars=sum_vars, + ) - return mask_and_reduce_metric(entry_mae_weighted, mask=mask, - average_grid=average_grid, sum_vars=sum_vars) def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ (Unweighted) Mean Absolute Error - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. 
""" # Replace pred_std with constant ones - return wmae(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars) + return wmae( + pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars + ) + def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ Negative Log Likelihood loss, for isotropic Gaussian likelihood - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. 
""" # Broadcast pred_std if shaped (d_state,), done internally in Normal class - dist = torch.distributions.Normal(pred, pred_std) # (..., N, d_state) - entry_nll = -dist.log_prob(target) # (..., N, d_state) + dist = torch.distributions.Normal(pred, pred_std) # (..., N, d_state) + entry_nll = -dist.log_prob(target) # (..., N, d_state) - return mask_and_reduce_metric(entry_nll, mask=mask, average_grid=average_grid, - sum_vars=sum_vars) + return mask_and_reduce_metric( + entry_nll, mask=mask, average_grid=average_grid, sum_vars=sum_vars + ) -def crps_gauss(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + +def crps_gauss( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True +): """ (Negative) Continuous Ranked Probability Score (CRPS) Closed-form expression based on Gaussian predictive distribution - (...,) is any number of batch dimensions, potentially different but broadcastable + (...,) is any number of batch dimensions, potentially different + but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) - sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) Returns: - metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending - on reduction arguments. + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. 
""" - std_normal = torch.distributions.Normal(torch.zeros((), device=pred.device), - torch.ones((), device=pred.device)) - target_standard = (target - pred)/pred_std # (..., N, d_state) + std_normal = torch.distributions.Normal( + torch.zeros((), device=pred.device), torch.ones((), device=pred.device) + ) + target_standard = (target - pred) / pred_std # (..., N, d_state) + + entry_crps = -pred_std * ( + torch.pi ** (-0.5) + - 2 * torch.exp(std_normal.log_prob(target_standard)) + - target_standard * (2 * std_normal.cdf(target_standard) - 1) + ) # (..., N, d_state) - entry_crps = -pred_std*( - torch.pi**(-0.5) - -2*torch.exp(std_normal.log_prob(target_standard)) - -target_standard*(2*std_normal.cdf(target_standard) - 1) - ) # (..., N, d_state) + return mask_and_reduce_metric( + entry_crps, mask=mask, average_grid=average_grid, sum_vars=sum_vars + ) - return mask_and_reduce_metric(entry_crps, mask=mask, average_grid=average_grid, - sum_vars=sum_vars) DEFINED_METRICS = { "mse": mse, diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 1700373a..d28aa36c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,18 +1,26 @@ +# Standard library import os -import torch + +# Third-party +import matplotlib.pyplot as plt import numpy as np -from torch import nn import pytorch_lightning as pl +import torch import wandb -import matplotlib.pyplot as plt -from neural_lam import utils, vis, constants, metrics +# First-party +from neural_lam import constants, metrics, utils, vis + class ARModel(pl.LightningModule): """ Generic auto-regressive weather model. Abstract class that can be extended. 
""" + + # pylint: disable=arguments-differ + # Disable to override args/kwargs from superclass + def __init__(self, args): super().__init__() self.save_hyperparameters() @@ -21,33 +29,51 @@ def __init__(self, args): # Load static features for grid/data static_data_dict = utils.load_static_data(args.dataset) for static_data_name, static_data_tensor in static_data_dict.items(): - self.register_buffer(static_data_name, static_data_tensor, persistent=False) + self.register_buffer( + static_data_name, static_data_tensor, persistent=False + ) # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) if self.output_std: - self.grid_output_dim = 2*constants.grid_state_dim # Pred. dim. in grid cell + self.grid_output_dim = ( + 2 * constants.GRID_STATE_DIM + ) # Pred. dim. in grid cell else: - self.grid_output_dim = constants.grid_state_dim # Pred. dim. in grid cell + self.grid_output_dim = ( + constants.GRID_STATE_DIM + ) # Pred. dim. in grid cell # Store constant per-variable std.-dev. 
weighting - # Note that this is the inverse of the multiplicative weighting in wMSE/wMAE - self.register_buffer("per_var_std", - self.step_diff_std/torch.sqrt(self.param_weights), persistent=False) + # Note that this is the inverse of the multiplicative weighting + # in wMSE/wMAE + self.register_buffer( + "per_var_std", + self.step_diff_std / torch.sqrt(self.param_weights), + persistent=False, + ) # grid_dim from data + static + batch_static - self.N_grid, grid_static_dim = self.grid_static_features.shape # 63784 = 268x238 - self.grid_dim = 2*constants.grid_state_dim + grid_static_dim +\ - constants.grid_forcing_dim + constants.batch_static_feature_dim + ( + self.num_grid_nodes, + grid_static_dim, + ) = self.grid_static_features.shape # 63784 = 268x238 + self.grid_dim = ( + 2 * constants.GRID_STATE_DIM + + grid_static_dim + + constants.GRID_FORCING_DIM + + constants.BATCH_STATIC_FEATURE_DIM + ) # Instantiate loss function self.loss = metrics.get_metric(args.loss) # Pre-compute interior mask for use in loss function - self.register_buffer("interior_mask", 1. - self.border_mask, - persistent=False) # (N_grid, 1), 1 for non-border + self.register_buffer( + "interior_mask", 1.0 - self.border_mask, persistent=False + ) # (num_grid_nodes, 1), 1 for non-border - self.step_length = args.step_length # Number of hours per pred. step + self.step_length = args.step_length # Number of hours per pred. 
step self.val_metrics = { "rmse": [], } @@ -56,8 +82,7 @@ def __init__(self, args): "mae": [], } if self.output_std: - self.test_metrics["output_std"] = [] # Treat as metric - + self.test_metrics["output_std"] = [] # Treat as metric # For making restoring of optimizer state optional (slight hack) self.opt_state = None @@ -70,7 +95,9 @@ def __init__(self, args): self.spatial_loss_maps = [] def configure_optimizers(self): - opt = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.95)) + opt = torch.optim.AdamW( + self.parameters(), lr=self.lr, betas=(0.9, 0.95) + ) if self.opt_state: opt.load_state_dict(self.opt_state) @@ -81,7 +108,7 @@ def interior_mask_bool(self): """ Get the interior mask as a boolean (N,) mask. """ - return self.interior_mask[:,0].to(torch.bool) + return self.interior_mask[:, 0].to(torch.bool) @staticmethod def expand_to_batch(x, batch_size): @@ -90,27 +117,30 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def predict_step(self, prev_state, prev_prev_state, batch_static_features, forcing): + def predict_step( + self, prev_state, prev_prev_state, batch_static_features, forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, N_grid, feature_dim), X_t - prev_prev_state: (B, N_grid, feature_dim), X_{t-1} - batch_static_features: (B, N_grid, batch_static_feature_dim) - forcing: (B, N_grid, forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + batch_static_features: (B, num_grid_nodes, batch_static_feature_dim) + forcing: (B, num_grid_nodes, forcing_dim) """ raise NotImplementedError("No prediction step implemented") - def unroll_prediction(self, init_states, batch_static_features, forcing_features, - true_states): + def unroll_prediction( + self, init_states, batch_static_features, forcing_features, true_states + ): """ Roll out prediction taking multiple autoregressive 
steps with model - init_states: (B, 2, N_grid, d_f) - batch_static_features: (B, N_grid, d_static_f) - forcing_features: (B, pred_steps, N_grid, d_static_f) - true_states: (B, pred_steps, N_grid, d_f) + init_states: (B, 2, num_grid_nodes, d_f) + batch_static_features: (B, num_grid_nodes, d_static_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + true_states: (B, pred_steps, num_grid_nodes, d_f) """ - prev_prev_state = init_states[:,0] - prev_state = init_states[:,1] + prev_prev_state = init_states[:, 0] + prev_state = init_states[:, 1] prediction_list = [] pred_std_list = [] pred_steps = forcing_features.shape[1] @@ -119,28 +149,35 @@ def unroll_prediction(self, init_states, batch_static_features, forcing_features forcing = forcing_features[:, i] border_state = true_states[:, i] - pred_state, pred_std = self.predict_step(prev_state, prev_prev_state, - batch_static_features, forcing) - # state: (B, N_grid, d_f) - # pred_std: (B, N_grid, d_f) or None + pred_state, pred_std = self.predict_step( + prev_state, prev_prev_state, batch_static_features, forcing + ) + # state: (B, num_grid_nodes, d_f) + # pred_std: (B, num_grid_nodes, d_f) or None # Overwrite border with true state - new_state = self.border_mask*border_state +\ - self.interior_mask*pred_state + new_state = ( + self.border_mask * border_state + + self.interior_mask * pred_state + ) prediction_list.append(new_state) if self.output_std: pred_std_list.append(pred_std) - # Upate conditioning states + # Update conditioning states prev_prev_state = prev_state prev_state = new_state - prediction = torch.stack(prediction_list, dim=1) # (B, pred_steps, N_grid, d_f) + prediction = torch.stack( + prediction_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) if self.output_std: - pred_std = torch.stack(pred_std_list, dim=1) # (B, pred_steps, N_grid, d_f) + pred_std = torch.stack( + pred_std_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) else: - pred_std = self.per_var_std # (d_f,) + 
pred_std = self.per_var_std # (d_f,) return prediction, pred_std @@ -149,18 +186,25 @@ def common_step(self, batch): Predict on single batch batch = time_series, batch_static_features, forcing_features - init_states: (B, 2, N_grid, d_features) - target_states: (B, pred_steps, N_grid, d_features) - batch_static_features: (B, N_grid, d_static_f), for example open water - forcing_features: (B, pred_steps, N_grid, d_forcing), where index 0 - corresponds to index 1 of init_states + init_states: (B, 2, num_grid_nodes, d_features) + target_states: (B, pred_steps, num_grid_nodes, d_features) + batch_static_features: (B, num_grid_nodes, d_static_f), + for example open water + forcing_features: (B, pred_steps, num_grid_nodes, d_forcing), + where index 0 corresponds to index 1 of init_states """ - init_states, target_states, batch_static_features, forcing_features = batch - - prediction, pred_std = self.unroll_prediction(init_states, batch_static_features, - forcing_features, target_states) # (B, pred_steps, N_grid, d_f) - # prediction: (B, pred_steps, N_grid, d_f) - # pred_std: (B, pred_steps, N_grid, d_f) or (d_f,) + ( + init_states, + target_states, + batch_static_features, + forcing_features, + ) = batch + + prediction, pred_std = self.unroll_prediction( + init_states, batch_static_features, forcing_features, target_states + ) # (B, pred_steps, num_grid_nodes, d_f) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std @@ -171,18 +215,22 @@ def training_step(self, batch): prediction, target, pred_std = self.common_step(batch) # Compute loss - batch_loss = torch.mean(self.loss(prediction, target, pred_std, - mask=self.interior_mask_bool)) # mean over unrolled times and batch + batch_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ) + ) # mean over unrolled times and batch log_dict = {"train_loss": batch_loss} - 
self.log_dict(log_dict, prog_bar=True, on_step=True, on_epoch=True, - sync_dist=True) + self.log_dict( + log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True + ) return batch_loss def all_gather_cat(self, tensor_to_gather): """ - Gather tensors across all ranks, and concatenate across dim. 0 (instead of - stacking in new dim. 0) + Gather tensors across all ranks, and concatenate across dim. 0 + (instead of stacking in new dim. 0) tensor_to_gather: (d1, d2, ...), distributed over K ranks @@ -190,26 +238,38 @@ def all_gather_cat(self, tensor_to_gather): """ return self.all_gather(tensor_to_gather).flatten(0, 1) - def validation_step(self, batch, batch_idx): + def validation_step(self, batch): """ Run validation on single batch """ prediction, target, pred_std = self.common_step(batch) - time_step_loss = torch.mean(self.loss(prediction, - target, pred_std, mask=self.interior_mask_bool), dim=0) # (time_steps-1) + time_step_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ), + dim=0, + ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) # Log loss per time step forward and mean - val_log_dict = {f"val_loss_unroll{step}": time_step_loss[step-1] - for step in constants.val_step_log_errors} + val_log_dict = { + f"val_loss_unroll{step}": time_step_loss[step - 1] + for step in constants.VAL_STEP_LOG_ERRORS + } val_log_dict["val_mean_loss"] = mean_loss - self.log_dict(val_log_dict, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict( + val_log_dict, on_step=False, on_epoch=True, sync_dist=True + ) # Store RMSEs - entry_rmses = metrics.rmse(prediction, target, pred_std, - mask=self.interior_mask_bool, - sum_vars=False) # (B, pred_steps, d_f) + entry_rmses = metrics.rmse( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) self.val_metrics["rmse"].append(entry_rmses) def on_validation_epoch_end(self): @@ -223,53 +283,76 @@ def 
on_validation_epoch_end(self): for metric_list in self.val_metrics.values(): metric_list.clear() - def test_step(self, batch, batch_idx): + def test_step(self, batch): """ Run test on single batch """ prediction, target, pred_std = self.common_step(batch) - # prediction: (B, pred_steps, N_grid, d_f) - # pred_std: (B, pred_steps, N_grid, d_f) or (d_f,) - - time_step_loss = torch.mean(self.loss(prediction, target, pred_std, - mask=self.interior_mask_bool), dim=0) # (time_steps-1,) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + + time_step_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ), + dim=0, + ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) # Log loss per time step forward and mean - test_log_dict = {f"test_loss_unroll{step}": time_step_loss[step-1] - for step in constants.val_step_log_errors} + test_log_dict = { + f"test_loss_unroll{step}": time_step_loss[step - 1] + for step in constants.VAL_STEP_LOG_ERRORS + } test_log_dict["test_mean_loss"] = mean_loss - self.log_dict(test_log_dict, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict( + test_log_dict, on_step=False, on_epoch=True, sync_dist=True + ) # Compute all evaluation metrics for error maps - # Note: explicitly list metrics here, as test_metrics can contain additional - # ones, computed differently, but that should be aggregated on_test_epoch_end + # Note: explicitly list metrics here, as test_metrics can contain + # additional ones, computed differently, but that should be aggregated + # on_test_epoch_end for metric_name in ("rmse", "mae"): metric_func = metrics.get_metric(metric_name) - batch_metric_vals = metric_func(prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False) # (B, pred_steps, d_f) + batch_metric_vals = metric_func( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) 
self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: - # Store ouput std. per variable, spatially averaged - mean_pred_std = torch.mean(pred_std[..., self.interior_mask_bool,:], - dim=-2) # (B, pred_steps, d_f) + # Store output std. per variable, spatially averaged + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times - spatial_loss = self.loss(prediction, target, pred_std, - average_grid=False) # (B, pred_steps, N_grid) - log_spatial_losses = spatial_loss[:,constants.val_step_log_errors-1] - self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, N_grid) + spatial_loss = self.loss( + prediction, target, pred_std, average_grid=False + ) # (B, pred_steps, num_grid_nodes) + log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1] + self.spatial_loss_maps.append(log_spatial_losses) + # (B, N_log, num_grid_nodes) # Plot example predictions (on rank 0 only) - if self.trainer.is_global_zero and self.plotted_examples < self.n_example_pred: + if ( + self.trainer.is_global_zero + and self.plotted_examples < self.n_example_pred + ): # Need to plot more example predictions - n_additional_examples = min(prediction.shape[0], self.n_example_pred - - self.plotted_examples) + n_additional_examples = min( + prediction.shape[0], self.n_example_pred - self.plotted_examples + ) - self.plot_examples(batch, n_additional_examples, prediction=prediction) + self.plot_examples( + batch, n_additional_examples, prediction=prediction + ) def plot_examples(self, batch, n_examples, prediction=None): """ @@ -277,7 +360,8 @@ def plot_examples(self, batch, n_examples, prediction=None): batch: batch with data to plot corresponding forecasts for n_examples: number of forecasts to plot - prediction: (B, pred_steps, N_grid, d_f), existing prediction. Generate if None. 
+ prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. + Generate if None. """ if prediction is None: prediction, target = self.common_step(batch) @@ -285,54 +369,83 @@ def plot_examples(self, batch, n_examples, prediction=None): target = batch[1] # Rescale to original data scale - prediction_rescaled = prediction*self.data_std + self.data_mean - target_rescaled = target*self.data_std + self.data_mean + prediction_rescaled = prediction * self.data_std + self.data_mean + target_rescaled = target * self.data_std + self.data_mean # Iterate over the examples for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], - target_rescaled[:n_examples]): - # Each slice is (pred_steps, N_grid, d_f) - self.plotted_examples += 1 # Increment already here - - var_vmin = torch.minimum( - pred_slice.flatten(0,1).min(dim=0)[0], - target_slice.flatten(0,1).min(dim=0)[0]).cpu().numpy() # (d_f,) - var_vmax = torch.maximum( - pred_slice.flatten(0,1).max(dim=0)[0], - target_slice.flatten(0,1).max(dim=0)[0]).cpu().numpy() # (d_f,) + prediction_rescaled[:n_examples], target_rescaled[:n_examples] + ): + # Each slice is (pred_steps, num_grid_nodes, d_f) + self.plotted_examples += 1 # Increment already here + + var_vmin = ( + torch.minimum( + pred_slice.flatten(0, 1).min(dim=0)[0], + target_slice.flatten(0, 1).min(dim=0)[0], + ) + .cpu() + .numpy() + ) # (d_f,) + var_vmax = ( + torch.maximum( + pred_slice.flatten(0, 1).max(dim=0)[0], + target_slice.flatten(0, 1).max(dim=0)[0], + ) + .cpu() + .numpy() + ) # (d_f,) var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate(zip(pred_slice, target_slice), - start=1): + for t_i, (pred_t, target_t) in enumerate( + zip(pred_slice, target_slice), start=1 + ): # Create one figure per variable at this time step - var_figs = [vis.plot_prediction( - pred_t[:,var_i], target_t[:,var_i], - self.interior_mask[:,0], - title=f"{var_name} ({var_unit}), " + 
var_figs = [ + vis.plot_prediction( + pred_t[:, var_i], + target_t[:, var_i], + self.interior_mask[:, 0], + title=f"{var_name} ({var_unit}), " f"t={t_i} ({self.step_length*t_i} h)", - vrange=var_vrange) - for var_i, (var_name, var_unit, var_vrange) in - enumerate(zip( - constants.param_names_short, - constants.param_units, - var_vranges - )) - ] - - wandb.log({ - f"{var_name}_example_{self.plotted_examples}": - wandb.Image(fig) - for var_name, fig in zip(constants.param_names_short, var_figs) - }) - plt.close("all") # Close all figs for this time step, saves memory + vrange=var_vrange, + ) + for var_i, (var_name, var_unit, var_vrange) in enumerate( + zip( + constants.PARAM_NAMES_SHORT, + constants.PARAM_UNITS, + var_vranges, + ) + ) + ] + + example_i = self.plotted_examples + wandb.log( + { + f"{var_name}_example_{example_i}": wandb.Image(fig) + for var_name, fig in zip( + constants.PARAM_NAMES_SHORT, var_figs + ) + } + ) + plt.close( + "all" + ) # Close all figs for this time step, saves memory # Save pred and target as .pt files - torch.save(pred_slice.cpu(),os.path.join( - wandb.run.dir, f'example_pred_{self.plotted_examples}.pt')) - torch.save(target_slice.cpu(),os.path.join( - wandb.run.dir, f'example_target_{self.plotted_examples}.pt')) + torch.save( + pred_slice.cpu(), + os.path.join( + wandb.run.dir, f"example_pred_{self.plotted_examples}.pt" + ), + ) + torch.save( + target_slice.cpu(), + os.path.join( + wandb.run.dir, f"example_target_{self.plotted_examples}.pt" + ), + ) def create_metric_log_dict(self, metric_tensor, prefix, metric_name): """ @@ -347,28 +460,36 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): log_dict: dict with everything to log for given metric """ log_dict = {} - metric_fig = vis.plot_error_map(metric_tensor, step_length=self.step_length) + metric_fig = vis.plot_error_map( + metric_tensor, step_length=self.step_length + ) full_log_name = f"{prefix}_{metric_name}" log_dict[full_log_name] = wandb.Image(metric_fig) 
if prefix == "test": # Save pdf - metric_fig.savefig(os.path.join(wandb.run.dir, - f"{full_log_name}.pdf")) + metric_fig.savefig( + os.path.join(wandb.run.dir, f"{full_log_name}.pdf") + ) # Save errors also as csv - np.savetxt(os.path.join(wandb.run.dir, - f"{full_log_name}.csv"), metric_tensor.cpu().numpy(), - delimiter=",") + np.savetxt( + os.path.join(wandb.run.dir, f"{full_log_name}.csv"), + metric_tensor.cpu().numpy(), + delimiter=",", + ) # Check if metrics are watched, log exact values for specific vars - if full_log_name in constants.metrics_watch: - for var_i, timesteps in constants.var_leads_metrics_watch.items(): - var_name = constants.param_names_short[var_i] - log_dict.update({ - f"{full_log_name}_{var_name}_step_{step}": - metric_tensor[step-1, var_i] # 1-indexed in constants - for step in timesteps - }) + if full_log_name in constants.METRICS_WATCH: + for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items(): + var = constants.PARAM_NAMES_SHORT[var_i] + log_dict.update( + { + f"{full_log_name}_{var}_step_{step}": metric_tensor[ + step - 1, var_i + ] # 1-indexed in constants + for step in timesteps + } + ) return log_dict @@ -376,24 +497,31 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): """ Aggregate and create error map plots for all metrics in metrics_dict - metrics_dict: dictionary with metric_names and list of tensors with step-evals. + metrics_dict: dictionary with metric_names and list of tensors + with step-evals. 
prefix: string, prefix to use for logging """ log_dict = {} for metric_name, metric_val_list in metrics_dict.items(): - metric_tensor = self.all_gather_cat(torch.cat(metric_val_list, - dim=0)) # (N_eval, pred_steps, d_f) + metric_tensor = self.all_gather_cat( + torch.cat(metric_val_list, dim=0) + ) # (N_eval, pred_steps, d_f) if self.trainer.is_global_zero: # Note: we here assume rescaling for all metrics is linear - metric_rescaled = torch.mean(metric_tensor, dim=0) * self.data_std + metric_rescaled = ( + torch.mean(metric_tensor, dim=0) * self.data_std + ) # (pred_steps, d_f) - log_dict.update(self.create_metric_log_dict(metric_rescaled, prefix, - metric_name)) + log_dict.update( + self.create_metric_log_dict( + metric_rescaled, prefix, metric_name + ) + ) if self.trainer.is_global_zero and not self.trainer.sanity_checking: - wandb.log(log_dict) # Log all - plt.close("all") # Close all figs + wandb.log(log_dict) # Log all + plt.close("all") # Close all figs def on_test_epoch_end(self): """ @@ -404,45 +532,66 @@ def on_test_epoch_end(self): self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") # Plot spatial loss maps - spatial_loss_tensor = self.all_gather_cat(torch.cat(self.spatial_loss_maps, - dim=0)) # (N_test, N_log, N_grid) + spatial_loss_tensor = self.all_gather_cat( + torch.cat(self.spatial_loss_maps, dim=0) + ) # (N_test, N_log, num_grid_nodes) if self.trainer.is_global_zero: - mean_spatial_loss = torch.mean(spatial_loss_tensor, dim=0) # (N_log, N_grid) - - loss_map_figs = [vis.plot_spatial_error(loss_map, self.interior_mask[:,0], - title=f"Test loss, t={t_i} ({self.step_length*t_i} h)") - for t_i, loss_map in zip(constants.val_step_log_errors, - mean_spatial_loss)] + mean_spatial_loss = torch.mean( + spatial_loss_tensor, dim=0 + ) # (N_log, num_grid_nodes) + + loss_map_figs = [ + vis.plot_spatial_error( + loss_map, + self.interior_mask[:, 0], + title=f"Test loss, t={t_i} ({self.step_length*t_i} h)", + ) + for t_i, loss_map in zip( + 
constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss + ) + ] # log all to same wandb key, sequentially for fig in loss_map_figs: wandb.log({"test_loss": wandb.Image(fig)}) # also make without title and save as pdf - pdf_loss_map_figs = [vis.plot_spatial_error(loss_map, - self.interior_mask[:,0]) for loss_map in mean_spatial_loss] + pdf_loss_map_figs = [ + vis.plot_spatial_error(loss_map, self.interior_mask[:, 0]) + for loss_map in mean_spatial_loss + ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(constants.val_step_log_errors,pdf_loss_map_figs): + for t_i, fig in zip( + constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs + ): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also - torch.save(mean_spatial_loss.cpu(),os.path.join( - wandb.run.dir, 'mean_spatial_loss.pt')) + torch.save( + mean_spatial_loss.cpu(), + os.path.join(wandb.run.dir, "mean_spatial_loss.pt"), + ) self.spatial_loss_maps.clear() - def on_load_checkpoint(self, ckpt): + def on_load_checkpoint(self, checkpoint): """ Perform any changes to state dict before loading checkpoint """ - loaded_state_dict = ckpt["state_dict"] + loaded_state_dict = checkpoint["state_dict"] - # Fix for loading older models after IneractionNet refactoring, where the - # grid MLP was moved outside the encoder InteractionNet class + # Fix for loading older models after IneractionNet refactoring, where + # the grid MLP was moved outside the encoder InteractionNet class if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict: - replace_keys = list(filter(lambda key: key.startswith("g2m_gnn.grid_mlp"), - loaded_state_dict.keys())) + replace_keys = list( + filter( + lambda key: key.startswith("g2m_gnn.grid_mlp"), + loaded_state_dict.keys(), + ) + ) for old_key in replace_keys: - new_key = old_key.replace("g2m_gnn.grid_mlp", "encoding_grid_mlp") + new_key = old_key.replace( + "g2m_gnn.grid_mlp", 
"encoding_grid_mlp" + ) loaded_state_dict[new_key] = loaded_state_dict[old_key] del loaded_state_dict[old_key] diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index cbf9f8ab..31b3c79d 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -1,20 +1,24 @@ +# Third-party import torch -import torch_geometric as pyg -from neural_lam import utils, constants -from neural_lam.models.ar_model import ARModel +# First-party +from neural_lam import utils from neural_lam.interaction_net import InteractionNet +from neural_lam.models.ar_model import ARModel + class BaseGraphModel(ARModel): """ Base (abstract) class for graph-based models building on the encode-process-decode idea. """ + def __init__(self, args): super().__init__(args) # Load graph with static features - # NOTE: (IMPORTANT!) mesh nodes MUST have the first N_mesh indices, + # NOTE: (IMPORTANT!) mesh nodes MUST have the first + # num_mesh_nodes indices, self.hierarchical, graph_ldict = utils.load_graph(args.graph) for name, attr_value in graph_ldict.items(): # Make BufferLists module members and register tensors as buffers @@ -24,9 +28,11 @@ def __init__(self, args): setattr(self, name, attr_value) # Specify dimensions of data - self.N_mesh, N_mesh_ignore = self.get_num_mesh() - print(f"Loaded graph with {self.N_grid + self.N_mesh} nodes "+ - f"({self.N_grid} grid, {self.N_mesh} mesh)") + self.num_mesh_nodes, _ = self.get_num_mesh() + print( + f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} " + f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" + ) # grid_dim from data + static + batch_static self.g2m_edges, g2m_dim = self.g2m_features.shape @@ -34,28 +40,39 @@ def __init__(self, args): # Define sub-models # Feature embedders for grid - self.mlp_blueprint_end = [args.hidden_dim]*(args.hidden_layers + 1) - self.grid_embedder = utils.make_mlp([self.grid_dim] + - self.mlp_blueprint_end) - 
self.g2m_embedder = utils.make_mlp([g2m_dim] + - self.mlp_blueprint_end) - self.m2g_embedder = utils.make_mlp([m2g_dim] + - self.mlp_blueprint_end) + self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) + self.grid_embedder = utils.make_mlp( + [self.grid_dim] + self.mlp_blueprint_end + ) + self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) + self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) # GNNs # encoder - self.g2m_gnn = InteractionNet(self.g2m_edge_index, - args.hidden_dim, hidden_layers=args.hidden_layers, update_edges=False) - self.encoding_grid_mlp = utils.make_mlp([args.hidden_dim] - + self.mlp_blueprint_end) + self.g2m_gnn = InteractionNet( + self.g2m_edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + update_edges=False, + ) + self.encoding_grid_mlp = utils.make_mlp( + [args.hidden_dim] + self.mlp_blueprint_end + ) # decoder - self.m2g_gnn = InteractionNet(self.m2g_edge_index, - args.hidden_dim, hidden_layers=args.hidden_layers, update_edges=False) + self.m2g_gnn = InteractionNet( + self.m2g_edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + update_edges=False, + ) # Output mapping (hidden_dim -> output_dim) - self.output_map = utils.make_mlp([args.hidden_dim]*(args.hidden_layers + 1) +\ - [self.grid_output_dim], layer_norm=False) # No layer norm on this one + self.output_map = utils.make_mlp( + [args.hidden_dim] * (args.hidden_layers + 1) + + [self.grid_output_dim], + layer_norm=False, + ) # No layer norm on this one def get_num_mesh(self): """ @@ -66,8 +83,8 @@ def get_num_mesh(self): def embedd_mesh_nodes(self): """ - Embedd static mesh features - Returns tensor of shape (N_mesh, d_h) + Embed static mesh features + Returns tensor of shape (num_mesh_nodes, d_h) """ raise NotImplementedError("embedd_mesh_nodes not implemented") @@ -76,63 +93,86 @@ def process_step(self, mesh_rep): Process step of embedd-process-decode framework Processes the representation on the 
mesh, possible in multiple steps - mesh_rep: has shape (B, N_mesh, d_h) - Returns mesh_rep: (B, N_mesh, d_h) + mesh_rep: has shape (B, num_mesh_nodes, d_h) + Returns mesh_rep: (B, num_mesh_nodes, d_h) """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, batch_static_features, forcing): + def predict_step( + self, prev_state, prev_prev_state, batch_static_features, forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, N_grid, feature_dim), X_t - prev_prev_state: (B, N_grid, feature_dim), X_{t-1} - batch_static_features: (B, N_grid, batch_static_feature_dim) - forcing: (B, N_grid, forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + batch_static_features: (B, num_grid_nodes, batch_static_feature_dim) + forcing: (B, num_grid_nodes, forcing_dim) """ batch_size = prev_state.shape[0] - # Create full grid node features of shape (B, N_grid, grid_dim) - grid_features = torch.cat((prev_state, prev_prev_state, batch_static_features, - forcing, self.expand_to_batch(self.grid_static_features, batch_size)), - dim=-1) - - # Embedd all features - grid_emb = self.grid_embedder(grid_features) # (B, N_grid, d_h) - g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) - m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) + # Create full grid node features of shape (B, num_grid_nodes, grid_dim) + grid_features = torch.cat( + ( + prev_state, + prev_prev_state, + batch_static_features, + forcing, + self.expand_to_batch(self.grid_static_features, batch_size), + ), + dim=-1, + ) + + # Embed all features + grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) + g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) + m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() # Map from grid to mesh - mesh_emb_expanded = 
self.expand_to_batch(mesh_emb, batch_size) # (B, N_mesh, d_h) + mesh_emb_expanded = self.expand_to_batch( + mesh_emb, batch_size + ) # (B, num_mesh_nodes, d_h) g2m_emb_expanded = self.expand_to_batch(g2m_emb, batch_size) # This also splits representation into grid and mesh - mesh_rep = self.g2m_gnn(grid_emb, mesh_emb_expanded, - g2m_emb_expanded) # (B, N_mesh, d_h) + mesh_rep = self.g2m_gnn( + grid_emb, mesh_emb_expanded, g2m_emb_expanded + ) # (B, num_mesh_nodes, d_h) # Also MLP with residual for grid representation - grid_rep = grid_emb + self.encoding_grid_mlp(grid_emb) # (B, N_grid, d_h) + grid_rep = grid_emb + self.encoding_grid_mlp( + grid_emb + ) # (B, num_grid_nodes, d_h) # Run processor step mesh_rep = self.process_step(mesh_rep) # Map back from mesh to grid m2g_emb_expanded = self.expand_to_batch(m2g_emb, batch_size) - grid_rep = self.m2g_gnn(mesh_rep, grid_rep, m2g_emb_expanded) # (B, N_grid, d_h) + grid_rep = self.m2g_gnn( + mesh_rep, grid_rep, m2g_emb_expanded + ) # (B, num_grid_nodes, d_h) # Map to output dimension, only for grid - net_output = self.output_map(grid_rep) # (B, N_grid, d_grid_out) + net_output = self.output_map( + grid_rep + ) # (B, num_grid_nodes, d_grid_out) if self.output_std: - pred_delta_mean, pred_std_raw = net_output.chunk(2, - dim=-1) # both (B, N_grid, d_f) + pred_delta_mean, pred_std_raw = net_output.chunk( + 2, dim=-1 + ) # both (B, num_grid_nodes, d_f) # Note: The predicted std. 
is not scaled in any way here + # linter for some reason does not think softplus is callable + # pylint: disable-next=not-callable pred_std = torch.nn.functional.softplus(pred_std_raw) else: pred_delta_mean = net_output pred_std = None # Rescale with one-step difference statistics - rescaled_delta_mean = pred_delta_mean*self.step_diff_std + self.step_diff_mean + rescaled_delta_mean = ( + pred_delta_mean * self.step_diff_std + self.step_diff_mean + ) # Residual connection for full state return prev_state + rescaled_delta_mean, pred_std diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index cb52d213..8ce87030 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -1,36 +1,43 @@ -import torch +# Third-party from torch import nn +# First-party from neural_lam import utils -from neural_lam.models.base_graph_model import BaseGraphModel from neural_lam.interaction_net import InteractionNet +from neural_lam.models.base_graph_model import BaseGraphModel + + class BaseHiGraphModel(BaseGraphModel): """ Base class for hierarchical graph models. 
""" + def __init__(self, args): super().__init__(args) # Track number of nodes, edges on each level # Flatten lists for efficient embedding - self.N_levels = len(self.mesh_static_features) + self.num_levels = len(self.mesh_static_features) # Number of mesh nodes at each level - self.N_mesh_levels = [mesh_feat.shape[0] for mesh_feat in - self.mesh_static_features] # Needs as python list for later - N_mesh_levels_torch = torch.tensor(self.N_mesh_levels) + self.level_mesh_sizes = [ + mesh_feat.shape[0] for mesh_feat in self.mesh_static_features + ] # Needs as python list for later # Print some useful info - print("Loaded hierachical graph with structure:") - for l, N_level in enumerate(self.N_mesh_levels): - same_level_edges = self.m2m_features[l].shape[0] - print(f"level {l} - {N_level} nodes, {same_level_edges} same-level edges") - - if l < (self.N_levels-1): - up_edges = self.mesh_up_features[l].shape[0] - down_edges = self.mesh_down_features[l].shape[0] - print(f" {l}<->{l+1} - {up_edges} up edges, {down_edges} down edges") - + print("Loaded hierarchical graph with structure:") + for level_index, level_mesh_size in enumerate(self.level_mesh_sizes): + same_level_edges = self.m2m_features[level_index].shape[0] + print( + f"level {level_index} - {level_mesh_size} nodes, " + f"{same_level_edges} same-level edges" + ) + + if level_index < (self.num_levels - 1): + up_edges = self.mesh_up_features[level_index].shape[0] + down_edges = self.mesh_down_features[level_index].shape[0] + print(f" {level_index}<->{level_index+1}") + print(f" - {up_edges} up edges, {down_edges} down edges") # Embedders # Assume all levels have same static feature dimensionality mesh_dim = self.mesh_static_features[0].shape[1] @@ -39,41 +46,76 @@ def __init__(self, args): mesh_down_dim = self.mesh_down_features[0].shape[1] # Separate mesh node embedders for each level - self.mesh_embedders = nn.ModuleList([utils.make_mlp([mesh_dim] + - self.mlp_blueprint_end) for _ in range(self.N_levels)]) - 
self.mesh_same_embedders = nn.ModuleList([utils.make_mlp([mesh_same_dim] + - self.mlp_blueprint_end) for _ in range(self.N_levels)]) - self.mesh_up_embedders = nn.ModuleList([utils.make_mlp([mesh_up_dim] + - self.mlp_blueprint_end) for _ in range(self.N_levels-1)]) - self.mesh_down_embedders = nn.ModuleList([utils.make_mlp([mesh_down_dim] + - self.mlp_blueprint_end) for _ in range(self.N_levels-1)]) + self.mesh_embedders = nn.ModuleList( + [ + utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels) + ] + ) + self.mesh_same_embedders = nn.ModuleList( + [ + utils.make_mlp([mesh_same_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels) + ] + ) + self.mesh_up_embedders = nn.ModuleList( + [ + utils.make_mlp([mesh_up_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels - 1) + ] + ) + self.mesh_down_embedders = nn.ModuleList( + [ + utils.make_mlp([mesh_down_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels - 1) + ] + ) # Instantiate GNNs # Init GNNs - self.mesh_init_gnns = nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.mesh_up_edge_index]) + self.mesh_init_gnns = nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.mesh_up_edge_index + ] + ) # Read out GNNs - self.mesh_read_gnns = nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, - update_edges=False) - for edge_index in self.mesh_down_edge_index]) + self.mesh_read_gnns = nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + update_edges=False, + ) + for edge_index in self.mesh_down_edge_index + ] + ) def get_num_mesh(self): """ Compute number of mesh nodes from loaded features, and number of mesh nodes that should be ignored in encoding/decoding """ - N_mesh = sum(node_feat.shape[0] for node_feat in 
self.mesh_static_features) - N_mesh_ignore = N_mesh - self.mesh_static_features[0].shape[0] - return N_mesh, N_mesh_ignore + num_mesh_nodes = sum( + node_feat.shape[0] for node_feat in self.mesh_static_features + ) + num_mesh_nodes_ignore = ( + num_mesh_nodes - self.mesh_static_features[0].shape[0] + ) + return num_mesh_nodes, num_mesh_nodes_ignore def embedd_mesh_nodes(self): """ - Embedd static mesh features - This embedds only bottom level, rest is done at beginning of processing step - Returns tensor of shape (N_mesh[0], d_h) + Embed static mesh features + This embeds only bottom level, rest is done at beginning of + processing step + Returns tensor of shape (num_mesh_nodes[0], d_h) """ return self.mesh_embedders[0](self.mesh_static_features[0]) @@ -82,75 +124,106 @@ def process_step(self, mesh_rep): Process step of embedd-process-decode framework Processes the representation on the mesh, possible in multiple steps - mesh_rep: has shape (B, N_mesh, d_h) - Returns mesh_rep: (B, N_mesh, d_h) + mesh_rep: has shape (B, num_mesh_nodes, d_h) + Returns mesh_rep: (B, num_mesh_nodes, d_h) """ batch_size = mesh_rep.shape[0] - # EMBEDD REMAINING MESH NODES (levels >= 1) - + # EMBED REMAINING MESH NODES (levels >= 1) - # Create list of mesh node representations for each level, - # each of size (B, N_mesh[l], d_h) - mesh_rep_levels = [mesh_rep] + [self.expand_to_batch( - emb(node_static_features), batch_size) for - emb, node_static_features in - zip(list(self.mesh_embedders)[1:], list(self.mesh_static_features)[1:])] - - # - EMBEDD EDGES - - # Embedd edges, expand with batch dimension - mesh_same_rep = [self.expand_to_batch(emb(edge_feat), batch_size) for - emb, edge_feat in zip(self.mesh_same_embedders, self.m2m_features)] - mesh_up_rep = [self.expand_to_batch(emb(edge_feat), batch_size) for - emb, edge_feat in zip(self.mesh_up_embedders, self.mesh_up_features)] - mesh_down_rep = [self.expand_to_batch(emb(edge_feat), batch_size) for - emb, edge_feat in 
zip(self.mesh_down_embedders, self.mesh_down_features)] + # each of size (B, num_mesh_nodes[l], d_h) + mesh_rep_levels = [mesh_rep] + [ + self.expand_to_batch(emb(node_static_features), batch_size) + for emb, node_static_features in zip( + list(self.mesh_embedders)[1:], + list(self.mesh_static_features)[1:], + ) + ] + + # - EMBED EDGES - + # Embed edges, expand with batch dimension + mesh_same_rep = [ + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_same_embedders, self.m2m_features + ) + ] + mesh_up_rep = [ + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_up_embedders, self.mesh_up_features + ) + ] + mesh_down_rep = [ + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_down_embedders, self.mesh_down_features + ) + ] # - MESH INIT. - # Let level_l go from 1 to L for level_l, gnn in enumerate(self.mesh_init_gnns, start=1): # Extract representations - send_node_rep = mesh_rep_levels[level_l-1] # (B, N_mesh[l-1], d_h) - rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) - edge_rep = mesh_up_rep[level_l-1] + send_node_rep = mesh_rep_levels[ + level_l - 1 + ] # (B, num_mesh_nodes[l-1], d_h) + rec_node_rep = mesh_rep_levels[ + level_l + ] # (B, num_mesh_nodes[l], d_h) + edge_rep = mesh_up_rep[level_l - 1] # Apply GNN - new_node_rep, new_edge_rep = gnn(send_node_rep, rec_node_rep, edge_rep) + new_node_rep, new_edge_rep = gnn( + send_node_rep, rec_node_rep, edge_rep + ) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = new_node_rep # (B, N_mesh[l], d_h) - mesh_up_rep[level_l-1] = new_edge_rep # (B, M_up[l-1], d_h) + mesh_rep_levels[level_l] = ( + new_node_rep # (B, num_mesh_nodes[l], d_h) + ) + mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h) # - PROCESSOR - - mesh_rep_levels, _, _, mesh_down_rep = self.hi_processor_step(mesh_rep_levels, - mesh_same_rep, mesh_up_rep, mesh_down_rep) + mesh_rep_levels, _, _, 
mesh_down_rep = self.hi_processor_step( + mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ) # - MESH READ OUT. - # Let level_l go from L-1 to 0 for level_l, gnn in zip( - range(self.N_levels-2, -1, -1), - reversed(self.mesh_read_gnns)): + range(self.num_levels - 2, -1, -1), reversed(self.mesh_read_gnns) + ): # Extract representations - send_node_rep = mesh_rep_levels[level_l+1] # (B, N_mesh[l+1], d_h) - rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) + send_node_rep = mesh_rep_levels[ + level_l + 1 + ] # (B, num_mesh_nodes[l+1], d_h) + rec_node_rep = mesh_rep_levels[ + level_l + ] # (B, num_mesh_nodes[l], d_h) edge_rep = mesh_down_rep[level_l] # Apply GNN new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = new_node_rep # (B, N_mesh[l], d_h) + mesh_rep_levels[level_l] = ( + new_node_rep # (B, num_mesh_nodes[l], d_h) + ) # Return only bottom level representation - return mesh_rep_levels[0] # (B, N_mesh[0], d_h) + return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h) - def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, - mesh_down_rep): + def hi_processor_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ): """ Internal processor step of hierarchical graph models. Between mesh init and read out. 
Each input is list with representations, each with shape - mesh_rep_levels: (B, N_mesh[l], d_h) + mesh_rep_levels: (B, num_mesh_nodes[l], d_h) mesh_same_rep: (B, M_same[l], d_h) mesh_up_rep: (B, M_up[l -> l+1], d_h) mesh_down_rep: (B, M_down[l <- l+1], d_h) diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index 316f765c..f767fba0 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -1,42 +1,58 @@ -import torch +# Third-party import torch_geometric as pyg +# First-party from neural_lam import utils -from neural_lam.models.base_graph_model import BaseGraphModel from neural_lam.interaction_net import InteractionNet +from neural_lam.models.base_graph_model import BaseGraphModel + class GraphLAM(BaseGraphModel): """ - Full graph-based LAM model that can be used with different (non-hierarchical )graphs. - Mainly based on GraphCast, but the model from Keisler (2022) almost identical. - Used for GC-LAM and L1-LAM in Oskarsson et al. (2023). + Full graph-based LAM model that can be used with different + (non-hierarchical) graphs. Mainly based on GraphCast, but the model from + Keisler (2022) is almost identical. Used for GC-LAM and L1-LAM in + Oskarsson et al. (2023).
""" + def __init__(self, args): super().__init__(args) - assert not self.hierarchical, "GraphLAM does not use a hierarchical mesh graph" + assert ( + not self.hierarchical + ), "GraphLAM does not use a hierarchical mesh graph" # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] m2m_edges, m2m_dim = self.m2m_features.shape - print(f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " - f"m2g={self.m2g_edges}") + print( + f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " + f"m2g={self.m2g_edges}" + ) # Define sub-models # Feature embedders for mesh - self.mesh_embedder = utils.make_mlp([mesh_dim] + - self.mlp_blueprint_end) - self.m2m_embedder = utils.make_mlp([m2m_dim] + - self.mlp_blueprint_end) + self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) + self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end) # GNNs # processor - processor_nets = [InteractionNet(self.m2m_edge_index, - args.hidden_dim, hidden_layers=args.hidden_layers, aggr=args.mesh_aggr) - for _ in range(args.processor_layers)] - self.processor = pyg.nn.Sequential("mesh_rep, edge_rep", [ + processor_nets = [ + InteractionNet( + self.m2m_edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + aggr=args.mesh_aggr, + ) + for _ in range(args.processor_layers) + ] + self.processor = pyg.nn.Sequential( + "mesh_rep, edge_rep", + [ (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") - for net in processor_nets]) + for net in processor_nets + ], + ) def get_num_mesh(self): """ @@ -47,10 +63,10 @@ def get_num_mesh(self): def embedd_mesh_nodes(self): """ - Embedd static mesh features + Embed static mesh features Returns tensor of shape (N_mesh, d_h) """ - return self.mesh_embedder(self.mesh_static_features) # (N_mesh, d_h) + return self.mesh_embedder(self.mesh_static_features) # (N_mesh, d_h) def process_step(self, mesh_rep): """ @@ -60,10 +76,14 @@ def process_step(self, mesh_rep): mesh_rep: has 
shape (B, N_mesh, d_h) Returns mesh_rep: (B, N_mesh, d_h) """ - # Embedd m2m here first + # Embed m2m here first batch_size = mesh_rep.shape[0] - m2m_emb = self.m2m_embedder(self.m2m_features) # (M_mesh, d_h) - m2m_emb_expanded = self.expand_to_batch(m2m_emb, batch_size) # (B, M_mesh, d_h) + m2m_emb = self.m2m_embedder(self.m2m_features) # (M_mesh, d_h) + m2m_emb_expanded = self.expand_to_batch( + m2m_emb, batch_size + ) # (B, M_mesh, d_h) - mesh_rep, _ = self.processor(mesh_rep, m2m_emb_expanded) # (B, N_mesh, d_h) + mesh_rep, _ = self.processor( + mesh_rep, m2m_emb_expanded + ) # (B, N_mesh, d_h) return mesh_rep diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index 53367d58..4d7eb94c 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -1,120 +1,168 @@ -import torch +# Third-party from torch import nn -from neural_lam import utils -from neural_lam.models.base_hi_graph_model import BaseHiGraphModel +# First-party from neural_lam.interaction_net import InteractionNet +from neural_lam.models.base_hi_graph_model import BaseHiGraphModel + class HiLAM(BaseHiGraphModel): """ - Hierarchical graph model with message passing that goes sequentially down and up - the hierarchy during processing. + Hierarchical graph model with message passing that goes sequentially down + and up the hierarchy during processing. The Hi-LAM model from Oskarsson et al. 
(2023) """ + def __init__(self, args): super().__init__(args) # Make down GNNs, both for down edges and same level - self.mesh_down_gnns = nn.ModuleList([self.make_down_gnns(args) for _ in - range(args.processor_layers)]) # Nested lists (proc_steps, N_levels-1) - self.mesh_down_same_gnns = nn.ModuleList([self.make_same_gnns(args) for _ in - range(args.processor_layers)]) # Nested lists (proc_steps, N_levels) + self.mesh_down_gnns = nn.ModuleList( + [self.make_down_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels-1) + self.mesh_down_same_gnns = nn.ModuleList( + [self.make_same_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels) # Make up GNNs, both for up edges and same level - self.mesh_up_gnns = nn.ModuleList([self.make_up_gnns(args) for _ in - range(args.processor_layers)]) # Nested lists (proc_steps, N_levels-1) - self.mesh_up_same_gnns = nn.ModuleList([self.make_same_gnns(args) for _ in - range(args.processor_layers)]) # Nested lists (proc_steps, N_levels) + self.mesh_up_gnns = nn.ModuleList( + [self.make_up_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels-1) + self.mesh_up_same_gnns = nn.ModuleList( + [self.make_same_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels) def make_same_gnns(self, args): """ Make intra-level GNNs. """ - return nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.m2m_edge_index]) + return nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.m2m_edge_index + ] + ) def make_up_gnns(self, args): """ Make GNNs for processing steps up through the hierarchy. 
""" - return nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.mesh_up_edge_index]) + return nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.mesh_up_edge_index + ] + ) def make_down_gnns(self, args): """ Make GNNs for processing steps down through the hierarchy. """ - return nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.mesh_down_edge_index]) - - def mesh_down_step(self, mesh_rep_levels, mesh_same_rep, mesh_down_rep, down_gnns, - same_gnns): + return nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.mesh_down_edge_index + ] + ) + + def mesh_down_step( + self, + mesh_rep_levels, + mesh_same_rep, + mesh_down_rep, + down_gnns, + same_gnns, + ): """ - Run down-part of vertical processing, sequentially alternating between processing - using down edges and same-level edges. + Run down-part of vertical processing, sequentially alternating between + processing using down edges and same-level edges. 
""" # Run same level processing on level L - mesh_rep_levels[-1], mesh_same_rep[-1] = same_gnns[-1](mesh_rep_levels[-1], - mesh_rep_levels[-1], mesh_same_rep[-1]) + mesh_rep_levels[-1], mesh_same_rep[-1] = same_gnns[-1]( + mesh_rep_levels[-1], mesh_rep_levels[-1], mesh_same_rep[-1] + ) # Let level_l go from L-1 to 0 for level_l, down_gnn, same_gnn in zip( - range(self.N_levels-2, -1, -1), - reversed(down_gnns), reversed(same_gnns[:-1])): + range(self.num_levels - 2, -1, -1), + reversed(down_gnns), + reversed(same_gnns[:-1]), + ): # Extract representations - send_node_rep = mesh_rep_levels[level_l+1] # (B, N_mesh[l+1], d_h) - rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) + send_node_rep = mesh_rep_levels[ + level_l + 1 + ] # (B, N_mesh[l+1], d_h) + rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) down_edge_rep = mesh_down_rep[level_l] same_edge_rep = mesh_same_rep[level_l] # Apply down GNN - new_node_rep, mesh_down_rep[level_l] = down_gnn(send_node_rep, rec_node_rep, - down_edge_rep) + new_node_rep, mesh_down_rep[level_l] = down_gnn( + send_node_rep, rec_node_rep, down_edge_rep + ) # Run same level processing on level l - mesh_rep_levels[level_l], mesh_same_rep[level_l] = same_gnn(new_node_rep, - new_node_rep, same_edge_rep) + mesh_rep_levels[level_l], mesh_same_rep[level_l] = same_gnn( + new_node_rep, new_node_rep, same_edge_rep + ) # (B, N_mesh[l], d_h) and (B, M_same[l], d_h) return mesh_rep_levels, mesh_same_rep, mesh_down_rep - def mesh_up_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, - same_gnns): + def mesh_up_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, same_gnns + ): """ - Run up-part of vertical processing, sequentially alternating between processing - using up edges and same-level edges. + Run up-part of vertical processing, sequentially alternating between + processing using up edges and same-level edges. 
""" # Run same level processing on level 0 - mesh_rep_levels[0], mesh_same_rep[0] = same_gnns[0](mesh_rep_levels[0], - mesh_rep_levels[0], mesh_same_rep[0]) + mesh_rep_levels[0], mesh_same_rep[0] = same_gnns[0]( + mesh_rep_levels[0], mesh_rep_levels[0], mesh_same_rep[0] + ) # Let level_l go from 1 to L - for level_l, (up_gnn, same_gnn) in enumerate(zip(up_gnns, same_gnns[1:]), - start=1): + for level_l, (up_gnn, same_gnn) in enumerate( + zip(up_gnns, same_gnns[1:]), start=1 + ): # Extract representations - send_node_rep = mesh_rep_levels[level_l-1] # (B, N_mesh[l-1], d_h) - rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) - up_edge_rep = mesh_up_rep[level_l-1] + send_node_rep = mesh_rep_levels[ + level_l - 1 + ] # (B, N_mesh[l-1], d_h) + rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) + up_edge_rep = mesh_up_rep[level_l - 1] same_edge_rep = mesh_same_rep[level_l] # Apply up GNN - new_node_rep, mesh_up_rep[level_l-1] = up_gnn(send_node_rep, rec_node_rep, - up_edge_rep) + new_node_rep, mesh_up_rep[level_l - 1] = up_gnn( + send_node_rep, rec_node_rep, up_edge_rep + ) # (B, N_mesh[l], d_h) and (B, M_up[l-1], d_h) # Run same level processing on level l - mesh_rep_levels[level_l], mesh_same_rep[level_l] = same_gnn(new_node_rep, - new_node_rep, same_edge_rep) + mesh_rep_levels[level_l], mesh_same_rep[level_l] = same_gnn( + new_node_rep, new_node_rep, same_edge_rep + ) # (B, N_mesh[l], d_h) and (B, M_same[l], d_h) return mesh_rep_levels, mesh_same_rep, mesh_up_rep - def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, - mesh_down_rep): + def hi_processor_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ): """ Internal processor step of hierarchical graph models. Between mesh init and read out. 
@@ -128,17 +176,29 @@ def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, Returns same lists """ - for down_gnns, down_same_gnns, up_gnns, up_same_gnns in zip(self.mesh_down_gnns, - self.mesh_down_same_gnns, self.mesh_up_gnns, self.mesh_up_same_gnns): + for down_gnns, down_same_gnns, up_gnns, up_same_gnns in zip( + self.mesh_down_gnns, + self.mesh_down_same_gnns, + self.mesh_up_gnns, + self.mesh_up_same_gnns, + ): # Down mesh_rep_levels, mesh_same_rep, mesh_down_rep = self.mesh_down_step( - mesh_rep_levels, mesh_same_rep, mesh_down_rep, down_gnns, - down_same_gnns) + mesh_rep_levels, + mesh_same_rep, + mesh_down_rep, + down_gnns, + down_same_gnns, + ) # Up mesh_rep_levels, mesh_same_rep, mesh_up_rep = self.mesh_up_step( - mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, - up_same_gnns) + mesh_rep_levels, + mesh_same_rep, + mesh_up_rep, + up_gnns, + up_same_gnns, + ) # Note: We return all, even though only down edges really are used later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index 225b3095..740824e1 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -1,41 +1,58 @@ +# Third-party import torch import torch_geometric as pyg -from neural_lam import utils -from neural_lam.models.base_hi_graph_model import BaseHiGraphModel +# First-party from neural_lam.interaction_net import InteractionNet +from neural_lam.models.base_hi_graph_model import BaseHiGraphModel + class HiLAMParallel(BaseHiGraphModel): """ - Version of HiLAM where all message passing in the hierarchical mesh (up, down, - inter-level) is ran in paralell. + Version of HiLAM where all message passing in the hierarchical mesh (up, + down, inter-level) is run in parallel. - This is a somewhat simpler alternative to the sequential message passing of Hi-LAM.
+ This is a somewhat simpler alternative to the sequential message passing + of Hi-LAM. """ + def __init__(self, args): super().__init__(args) # Processor GNNs - # Create the complete total edge_index combining all edges for processing - total_edge_index_list = list(self.m2m_edge_index) +\ - list(self.mesh_up_edge_index) + list(self.mesh_down_edge_index) + # Create the complete edge_index combining all edges for processing + total_edge_index_list = ( + list(self.m2m_edge_index) + + list(self.mesh_up_edge_index) + + list(self.mesh_down_edge_index) + ) total_edge_index = torch.cat(total_edge_index_list, dim=1) self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list] if args.processor_layers == 0: - self.processor = (lambda x, edge_attr: (x, edge_attr)) + self.processor = lambda x, edge_attr: (x, edge_attr) else: - processor_nets = [InteractionNet(total_edge_index, args.hidden_dim, - hidden_layers=args.hidden_layers, - edge_chunk_sizes=self.edge_split_sections, - aggr_chunk_sizes=self.N_mesh_levels) - for _ in range(args.processor_layers)] - self.processor = pyg.nn.Sequential("mesh_rep, edge_rep", [ - (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") - for net in processor_nets]) - - def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, - mesh_down_rep): + processor_nets = [ + InteractionNet( + total_edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + edge_chunk_sizes=self.edge_split_sections, + aggr_chunk_sizes=self.level_mesh_sizes, + ) + for _ in range(args.processor_layers) + ] + self.processor = pyg.nn.Sequential( + "mesh_rep, edge_rep", + [ + (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") + for net in processor_nets + ], + ) + + def hi_processor_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ): """ Internal processor step of hierarchical graph models. Between mesh init and read out. 
@@ -51,23 +68,29 @@ def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, """ # First join all node and edge representations to single tensors - mesh_rep = torch.cat(mesh_rep_levels, dim=1) # (B, N_mesh, d_h) - mesh_edge_rep = torch.cat(mesh_same_rep + mesh_up_rep + mesh_down_rep, - axis=1) # (B, M_mesh, d_h) + mesh_rep = torch.cat(mesh_rep_levels, dim=1) # (B, N_mesh, d_h) + mesh_edge_rep = torch.cat( + mesh_same_rep + mesh_up_rep + mesh_down_rep, axis=1 + ) # (B, M_mesh, d_h) # Here, update mesh_*_rep and mesh_rep mesh_rep, mesh_edge_rep = self.processor(mesh_rep, mesh_edge_rep) # Split up again for read-out step - mesh_rep_levels = list(torch.split(mesh_rep, self.N_mesh_levels, dim=1)) - mesh_edge_rep_sections = torch.split(mesh_edge_rep, self.edge_split_sections, - dim=1) + mesh_rep_levels = list( + torch.split(mesh_rep, self.level_mesh_sizes, dim=1) + ) + mesh_edge_rep_sections = torch.split( + mesh_edge_rep, self.edge_split_sections, dim=1 + ) - mesh_same_rep = mesh_edge_rep_sections[:self.N_levels] + mesh_same_rep = mesh_edge_rep_sections[: self.num_levels] mesh_up_rep = mesh_edge_rep_sections[ - self.N_levels:self.N_levels+(self.N_levels-1)] + self.num_levels : self.num_levels + (self.num_levels - 1) + ] mesh_down_rep = mesh_edge_rep_sections[ - self.N_levels+(self.N_levels-1):] # Last are down edges + self.num_levels + (self.num_levels - 1) : + ] # Last are down edges # Note: We return all, even though only down edges really are used later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 7c7de96b..31715502 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,20 +1,31 @@ +# Standard library import os -import torch -import torch.nn as nn + +# Third-party import numpy as np +import torch +from torch import nn from tueplots import bundles, figsizes +# First-party from neural_lam import constants + def load_dataset_stats(dataset_name, device="cpu"): + 
""" + Load arrays with stored dataset statistics from pre-processing + """ static_dir_path = os.path.join("data", dataset_name, "static") - loads_file = lambda fn: torch.load(os.path.join(static_dir_path, fn), - map_location=device) - data_mean = loads_file("parameter_mean.pt") # (d_features,) - data_std = loads_file("parameter_std.pt") # (d_features,) + def loads_file(fn): + return torch.load( + os.path.join(static_dir_path, fn), map_location=device + ) - flux_stats = loads_file("flux_stats.pt") # (2,) + data_mean = loads_file("parameter_mean.pt") # (d_features,) + data_std = loads_file("parameter_std.pt") # (d_features,) + + flux_stats = loads_file("flux_stats.pt") # (2,) flux_mean, flux_std = flux_stats return { @@ -24,30 +35,44 @@ def load_dataset_stats(dataset_name, device="cpu"): "flux_std": flux_std, } + def load_static_data(dataset_name, device="cpu"): + """ + Load static files related to dataset + """ static_dir_path = os.path.join("data", dataset_name, "static") - loads_file = lambda fn: torch.load(os.path.join(static_dir_path, fn), - map_location=device) + + def loads_file(fn): + return torch.load( + os.path.join(static_dir_path, fn), map_location=device + ) # Load border mask, 1. if node is part of border, else 0. 
border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy")) - border_mask = torch.tensor(border_mask_np, dtype=torch.float32, - device=device).flatten(0,1).unsqueeze(1) # (N_grid, 1) + border_mask = ( + torch.tensor(border_mask_np, dtype=torch.float32, device=device) + .flatten(0, 1) + .unsqueeze(1) + ) # (N_grid, 1) - grid_static_features = loads_file("grid_features.pt") # (N_grid, d_grid_static) + grid_static_features = loads_file( + "grid_features.pt" + ) # (N_grid, d_grid_static) # Load step diff stats - step_diff_mean = loads_file("diff_mean.pt") # (d_f,) - step_diff_std = loads_file("diff_std.pt") # (d_f,) + step_diff_mean = loads_file("diff_mean.pt") # (d_f,) + step_diff_std = loads_file("diff_std.pt") # (d_f,) # Load parameter std for computing validation errors in original data scale - data_mean = loads_file("parameter_mean.pt") # (d_features,) - data_std = loads_file("parameter_std.pt") # (d_features,) + data_mean = loads_file("parameter_mean.pt") # (d_features,) + data_std = loads_file("parameter_std.pt") # (d_features,) # Load loss weighting vectors - param_weights = torch.tensor(np.load(os.path.join(static_dir_path, - "parameter_weights.npy")), dtype=torch.float32, - device=device) # (d_f,) + param_weights = torch.tensor( + np.load(os.path.join(static_dir_path, "parameter_weights.npy")), + dtype=torch.float32, + device=device, + ) # (d_f,) return { "border_mask": border_mask, @@ -59,14 +84,16 @@ def load_static_data(dataset_name, device="cpu"): "param_weights": param_weights, } + class BufferList(nn.Module): """ - A list of torch buffer tensors that sit together as a Module with no parameters and only - buffers. + A list of torch buffer tensors that sit together as a Module with no + parameters and only buffers. This should be replaced by a native torch BufferList once implemented. 
See: https://github.com/pytorch/pytorch/issues/37386 """ + def __init__(self, buffer_tensors, persistent=True): super().__init__() self.n_buffers = len(buffer_tensors) @@ -82,20 +109,26 @@ def __len__(self): def __iter__(self): return (self[i] for i in range(len(self))) + def load_graph(graph_name, device="cpu"): + """ + Load all tensors representing the graph + """ # Define helper lambda function graph_dir_path = os.path.join("graphs", graph_name) - loads_file = lambda fn: torch.load(os.path.join(graph_dir_path, fn), - map_location=device) + + def loads_file(fn): + return torch.load(os.path.join(graph_dir_path, fn), map_location=device) # Load edges (edge_index) - m2m_edge_index = BufferList(loads_file("m2m_edge_index.pt"), - persistent=False) # List of (2, M_m2m[l]) + m2m_edge_index = BufferList( + loads_file("m2m_edge_index.pt"), persistent=False + ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) n_levels = len(m2m_edge_index) - hierarchical = n_levels > 1 # Nor just single level mesh graph + hierarchical = n_levels > 1 # Not just single level mesh graph # Load static edge features m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f) @@ -103,63 +136,92 @@ def load_graph(graph_name, device="cpu"): m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f) # Normalize by dividing with longest edge (found in m2m) - longest_edge = max([torch.max(level_features[:,0]) - for level_features in m2m_features]) # Col. 0 is length - m2m_features = BufferList([level_features / longest_edge - for level_features in m2m_features], persistent=False) + longest_edge = max( + torch.max(level_features[:, 0]) for level_features in m2m_features + ) # Col.
0 is length + m2m_features = BufferList( + [level_features / longest_edge for level_features in m2m_features], + persistent=False, + ) g2m_features = g2m_features / longest_edge m2g_features = m2g_features / longest_edge # Load static node features - mesh_static_features = loads_file("mesh_features.pt" - ) # List of (N_mesh[l], d_mesh_static) + mesh_static_features = loads_file( + "mesh_features.pt" + ) # List of (N_mesh[l], d_mesh_static) # Some checks for consistency - assert len(m2m_features) == n_levels, "Inconsistent number of levels in mesh" - assert len(mesh_static_features) == n_levels, "Inconsistent number of levels in mesh" + assert ( + len(m2m_features) == n_levels + ), "Inconsistent number of levels in mesh" + assert ( + len(mesh_static_features) == n_levels + ), "Inconsistent number of levels in mesh" if hierarchical: # Load up and down edges and features - mesh_up_edge_index = BufferList(loads_file("mesh_up_edge_index.pt"), - persistent=False) # List of (2, M_up[l]) - mesh_down_edge_index = BufferList(loads_file("mesh_down_edge_index.pt"), - persistent=False) # List of (2, M_down[l]) - - mesh_up_features = loads_file("mesh_up_features.pt" - ) # List of (M_up[l], d_edge_f) - mesh_down_features = loads_file("mesh_down_features.pt" - ) # List of (M_down[l], d_edge_f) + mesh_up_edge_index = BufferList( + loads_file("mesh_up_edge_index.pt"), persistent=False + ) # List of (2, M_up[l]) + mesh_down_edge_index = BufferList( + loads_file("mesh_down_edge_index.pt"), persistent=False + ) # List of (2, M_down[l]) + + mesh_up_features = loads_file( + "mesh_up_features.pt" + ) # List of (M_up[l], d_edge_f) + mesh_down_features = loads_file( + "mesh_down_features.pt" + ) # List of (M_down[l], d_edge_f) # Rescale - mesh_up_features = BufferList([edge_features / longest_edge - for edge_features in mesh_up_features], persistent=False) - mesh_down_features = BufferList([edge_features / longest_edge - for edge_features in mesh_down_features], persistent=False) - - 
mesh_static_features = BufferList(mesh_static_features, persistent=False) + mesh_up_features = BufferList( + [ + edge_features / longest_edge + for edge_features in mesh_up_features + ], + persistent=False, + ) + mesh_down_features = BufferList( + [ + edge_features / longest_edge + for edge_features in mesh_down_features + ], + persistent=False, + ) + + mesh_static_features = BufferList( + mesh_static_features, persistent=False + ) else: # Extract single mesh level m2m_edge_index = m2m_edge_index[0] m2m_features = m2m_features[0] mesh_static_features = mesh_static_features[0] - mesh_up_edge_index, mesh_down_edge_index, mesh_up_features, mesh_down_features =\ - [], [], [], [] + ( + mesh_up_edge_index, + mesh_down_edge_index, + mesh_up_features, + mesh_down_features, + ) = ([], [], [], []) return hierarchical, { - "g2m_edge_index": g2m_edge_index, - "m2g_edge_index": m2g_edge_index, - "m2m_edge_index": m2m_edge_index, - "mesh_up_edge_index": mesh_up_edge_index, - "mesh_down_edge_index": mesh_down_edge_index, - "g2m_features": g2m_features, - "m2g_features": m2g_features, - "m2m_features": m2m_features, - "mesh_up_features": mesh_up_features, - "mesh_down_features": mesh_down_features, - "mesh_static_features": mesh_static_features, + "g2m_edge_index": g2m_edge_index, + "m2g_edge_index": m2g_edge_index, + "m2m_edge_index": m2m_edge_index, + "mesh_up_edge_index": mesh_up_edge_index, + "mesh_down_edge_index": mesh_down_edge_index, + "g2m_features": g2m_features, + "m2g_features": m2g_features, + "m2m_features": m2m_features, + "mesh_up_features": mesh_up_features, + "mesh_down_features": mesh_down_features, + "mesh_static_features": mesh_static_features, } + def make_mlp(blueprint, layer_norm=True): """ Create MLP from list blueprint, with @@ -177,7 +239,7 @@ def make_mlp(blueprint, layer_norm=True): for layer_i, (dim1, dim2) in enumerate(zip(blueprint[:-1], blueprint[1:])): layers.append(nn.Linear(dim1, dim2)) if layer_i != hidden_layers: - layers.append(nn.SiLU()) # 
Swish activation + layers.append(nn.SiLU()) # Swish activation # Optionally add layer norm to output if layer_norm: @@ -185,21 +247,27 @@ def make_mlp(blueprint, layer_norm=True): return nn.Sequential(*layers) + def fractional_plot_bundle(fraction): """ - Get the tueplots bundle, but with figure width as a fraction of the page width. + Get the tueplots bundle, but with figure width as a fraction of + the page width. """ bundle = bundles.neurips2023(usetex=True, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] - bundle["figure.figsize"] = (original_figsize[0]/fraction, original_figsize[1]) + bundle["figure.figsize"] = ( + original_figsize[0] / fraction, + original_figsize[1], + ) return bundle + def init_wandb_metrics(wandb_logger): """ Set up wandb metrics to track """ experiment = wandb_logger.experiment experiment.define_metric("val_mean_loss", summary="min") - for step in constants.val_step_log_errors: + for step in constants.VAL_STEP_LOG_ERRORS: experiment.define_metric(f"val_loss_unroll{step}", summary="min") diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 71f1e418..cef34a84 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,53 +1,67 @@ -import torch -import numpy as np +# Third-party import matplotlib import matplotlib.pyplot as plt -from tueplots import axes, bundles +import numpy as np + +# First-party +from neural_lam import constants, utils -from neural_lam import utils, constants @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_error_map(errors, title=None, step_length=3): """ - Plot a heatmap of errors of different variables at different predictions horizons + Plot a heatmap of errors of different variables at different + predictions horizons errors: (pred_steps, d_f) """ - errors_np = errors.T.cpu().numpy() # (d_f, pred_steps) + errors_np = errors.T.cpu().numpy() # (d_f, pred_steps) d_f, pred_steps = errors_np.shape # Normalize all errors to [0,1] for color map - 
max_errors = errors_np.max(axis=1) # d_f + max_errors = errors_np.max(axis=1) # d_f errors_norm = errors_np / np.expand_dims(max_errors, axis=1) - fig, ax = plt.subplots(figsize=(15,10)) + fig, ax = plt.subplots(figsize=(15, 10)) - ax.imshow(errors_norm, cmap="OrRd", vmin=0, vmax=1., interpolation="none", - aspect="auto", alpha=0.8) + ax.imshow( + errors_norm, + cmap="OrRd", + vmin=0, + vmax=1.0, + interpolation="none", + aspect="auto", + alpha=0.8, + ) # ax and labels - for (j,i),error in np.ndenumerate(errors_np): + for (j, i), error in np.ndenumerate(errors_np): # Numbers > 9999 will be too large to fit formatted_error = f"{error:.3f}" if error < 9999 else f"{error:.2E}" - ax.text(i,j, formatted_error,ha='center',va='center', usetex=False) + ax.text(i, j, formatted_error, ha="center", va="center", usetex=False) # Ticks and labels - label_size=15 + label_size = 15 ax.set_xticks(np.arange(pred_steps)) - pred_hor_i = np.arange(pred_steps)+1 # Prediction horiz. in index - pred_hor_h = step_length*pred_hor_i # Prediction horiz. in hours + pred_hor_i = np.arange(pred_steps) + 1 # Prediction horiz. in index + pred_hor_h = step_length * pred_hor_i # Prediction horiz. 
in hours ax.set_xticklabels(pred_hor_h, size=label_size) ax.set_xlabel("Lead time (h)", size=label_size) ax.set_yticks(np.arange(d_f)) - y_ticklabels = [f"{name} ({unit})" for name, unit in - zip(constants.param_names_short, constants.param_units)] - ax.set_yticklabels(y_ticklabels , rotation=30, size=label_size) + y_ticklabels = [ + f"{name} ({unit})" + for name, unit in zip( + constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS + ) + ] + ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) if title: ax.set_title(title, size=15) return fig + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction(pred, target, obs_mask, title=None, vrange=None): """ @@ -62,18 +76,28 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*constants.grid_shape) - pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region + mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + pixel_alpha = ( + mask_reshaped.clamp(0.7, 1).cpu().numpy() + ) # Faded border region - fig, axes = plt.subplots(1, 2, figsize=(13,7), - subplot_kw={"projection": constants.lambert_proj}) + fig, axes = plt.subplots( + 1, 2, figsize=(13, 7), subplot_kw={"projection": constants.LAMBERT_PROJ} + ) # Plot pred and target for ax, data in zip(axes, (target, pred)): - ax.coastlines() # Add coastline outlines - data_grid = data.reshape(*constants.grid_shape).cpu().numpy() - im = ax.imshow(data_grid, origin="lower", extent=constants.grid_limits, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="plasma") + ax.coastlines() # Add coastline outlines + data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy() + im = ax.imshow( + data_grid, + origin="lower", + extent=constants.GRID_LIMITS, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + cmap="plasma", + ) # Ticks and labels axes[0].set_title("Ground Truth", size=15) @@ -86,6 +110,7 @@ def plot_prediction(pred, 
target, obs_mask, title=None, vrange=None): return fig + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_spatial_error(error, obs_mask, title=None, vrange=None): """ @@ -100,17 +125,27 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None): vmin, vmax = vrange # Set up masking of border region - mask_reshaped = obs_mask.reshape(*constants.grid_shape) - pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region - - fig, ax = plt.subplots(figsize=(5,4.8), - subplot_kw={"projection": constants.lambert_proj}) - - ax.coastlines() # Add coastline outlines - error_grid = error.reshape(*constants.grid_shape).cpu().numpy() - - im = ax.imshow(error_grid, origin="lower", extent=constants.grid_limits, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd") + mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + pixel_alpha = ( + mask_reshaped.clamp(0.7, 1).cpu().numpy() + ) # Faded border region + + fig, ax = plt.subplots( + figsize=(5, 4.8), subplot_kw={"projection": constants.LAMBERT_PROJ} + ) + + ax.coastlines() # Add coastline outlines + error_grid = error.reshape(*constants.GRID_SHAPE).cpu().numpy() + + im = ax.imshow( + error_grid, + origin="lower", + extent=constants.GRID_LIMITS, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + cmap="OrRd", + ) # Ticks and labels cbar = fig.colorbar(im, aspect=30) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 19438c72..015a140d 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,51 +1,76 @@ -import os +# Standard library +import datetime as dt import glob -import torch +import os + +# Third-party import numpy as np -import datetime as dt +import torch + +# First-party +from neural_lam import constants, utils -from neural_lam import utils, constants class WeatherDataset(torch.utils.data.Dataset): """ For our dataset: N_t' = 65 N_t = 65//subsample_step (= 21 for 3h steps) - N_x = 268 - N_y = 238 + dim_x = 268 + dim_y = 
238 N_grid = 268x238 = 63784 d_features = 17 (d_features' = 18) d_forcing = 5 """ - def __init__(self, dataset_name, pred_length=19, split="train", subsample_step=3, - standardize=True, subset=False, control_only=False): + + def __init__( + self, + dataset_name, + pred_length=19, + split="train", + subsample_step=3, + standardize=True, + subset=False, + control_only=False, + ): super().__init__() assert split in ("train", "val", "test"), "Unknown dataset split" - self.sample_dir_path = os.path.join("data", dataset_name, "samples", split) - - member_file_regexp = "nwp*mbr000.npy" if control_only else "nwp*mbr*.npy" - sample_paths = glob.glob(os.path.join(self.sample_dir_path, member_file_regexp)) + self.sample_dir_path = os.path.join( + "data", dataset_name, "samples", split + ) + + member_file_regexp = ( + "nwp*mbr000.npy" if control_only else "nwp*mbr*.npy" + ) + sample_paths = glob.glob( + os.path.join(self.sample_dir_path, member_file_regexp) + ) self.sample_names = [path.split("/")[-1][4:-4] for path in sample_paths] # Now on form "yyymmddhh_mbrXXX" if subset: - self.sample_names = self.sample_names[:50] # Limit to 50 samples + self.sample_names = self.sample_names[:50] # Limit to 50 samples - self.sample_length = pred_length + 2 # 2 init states + self.sample_length = pred_length + 2 # 2 init states self.subsample_step = subsample_step - self.original_sample_length = 65//self.subsample_step # 21 for 3h steps - assert self.sample_length <= self.original_sample_length, ( - "Requesting too long time series samples") + self.original_sample_length = ( + 65 // self.subsample_step + ) # 21 for 3h steps + assert ( + self.sample_length <= self.original_sample_length + ), "Requesting too long time series samples" # Set up for standardization self.standardize = standardize if standardize: ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - self.data_mean, self.data_std, self.flux_mean, self.flux_std =\ - ds_stats["data_mean"], ds_stats["data_std"], 
ds_stats["flux_mean"], \ - ds_stats["flux_std"] - + self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ds_stats["flux_mean"], + ds_stats["flux_std"], + ) # If subsample index should be sampled (only duing training) self.random_subsample = split == "train" @@ -56,124 +81,170 @@ def __len__(self): def __getitem__(self, idx): # === Sample === sample_name = self.sample_names[idx] - sample_path = os.path.join(self.sample_dir_path, f"nwp_{sample_name}.npy") + sample_path = os.path.join( + self.sample_dir_path, f"nwp_{sample_name}.npy" + ) try: - full_sample = torch.tensor(np.load(sample_path), - dtype=torch.float32) # (N_t', N_x, N_y, d_features') + full_sample = torch.tensor( + np.load(sample_path), dtype=torch.float32 + ) # (N_t', dim_x, dim_y, d_features') except ValueError: print(f"Failed to load {sample_path}") # Only use every ss_step:th time step, sample which of ss_step # possible such time series if self.random_subsample: - subsample_index = torch.randint(0,self.subsample_step,()).item() + subsample_index = torch.randint(0, self.subsample_step, ()).item() else: subsample_index = 0 - subsample_end_index = self.original_sample_length*self.subsample_step - sample = full_sample[subsample_index:subsample_end_index:self.subsample_step] - # (N_t, N_x, N_y, d_features') + subsample_end_index = self.original_sample_length * self.subsample_step + sample = full_sample[ + subsample_index : subsample_end_index : self.subsample_step + ] + # (N_t, dim_x, dim_y, d_features') # Remove feature 15, "z_height_above_ground" - sample = torch.cat((sample[:,:,:,:15], sample[:,:,:,16:]), - dim=3) # (N_t, N_x, N_y, d_features) + sample = torch.cat( + (sample[:, :, :, :15], sample[:, :, :, 16:]), dim=3 + ) # (N_t, dim_x, dim_y, d_features) # Accumulate solar radiation instead of just subsampling - rad_features = full_sample[:,:,:,2:4] # (N_t', N_x, N_y, 2) + rad_features = full_sample[:, :, :, 2:4] # (N_t', dim_x, dim_y, 2) 
# Accumulate for first time step - init_accum_rad = torch.sum(rad_features[:(subsample_index+1)], - dim=0, keepdim=True) # (1, N_x, N_y, 2) + init_accum_rad = torch.sum( + rad_features[: (subsample_index + 1)], dim=0, keepdim=True + ) # (1, dim_x, dim_y, 2) # Accumulate for rest of subsampled sequence - in_subsample_len = subsample_end_index - self.subsample_step + subsample_index +1 - rad_features_in_subsample = rad_features[(subsample_index+1): - in_subsample_len] # (N_t*, N_x, N_y, 2), N_t* = (N_t-1)*ss_step - _, N_x, N_y, _ = sample.shape - rest_accum_rad = torch.sum(rad_features_in_subsample.view( - self.original_sample_length-1, self.subsample_step, N_x, N_y, 2 - ), dim=1) # (N_t-1, N_x, N_y, 2) - accum_rad = torch.cat((init_accum_rad, rest_accum_rad), - dim=0) # (N_t, N_x, N_y, 2) + in_subsample_len = ( + subsample_end_index - self.subsample_step + subsample_index + 1 + ) + rad_features_in_subsample = rad_features[ + (subsample_index + 1) : in_subsample_len + ] # (N_t*, dim_x, dim_y, 2), N_t* = (N_t-1)*ss_step + _, dim_x, dim_y, _ = sample.shape + rest_accum_rad = torch.sum( + rad_features_in_subsample.view( + self.original_sample_length - 1, + self.subsample_step, + dim_x, + dim_y, + 2, + ), + dim=1, + ) # (N_t-1, dim_x, dim_y, 2) + accum_rad = torch.cat( + (init_accum_rad, rest_accum_rad), dim=0 + ) # (N_t, dim_x, dim_y, 2) # Replace in sample - sample[:,:,:,2:4] = accum_rad + sample[:, :, :, 2:4] = accum_rad # Flatten spatial dim - sample = sample.flatten(1,2) # (N_t, N_grid, d_features) + sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) # Uniformly sample time id to start sample from - init_id = torch.randint(0, 1+self.original_sample_length - self.sample_length, - ()) - sample = sample[init_id:(init_id+self.sample_length)] + init_id = torch.randint( + 0, 1 + self.original_sample_length - self.sample_length, () + ) + sample = sample[init_id : (init_id + self.sample_length)] # (sample_length, N_grid, d_features) if self.standardize: # Standardize 
sample - sample = (sample - self.data_mean)/self.data_std + sample = (sample - self.data_mean) / self.data_std # Split up sample in init. states and target states - init_states = sample[:2] # (2, N_grid, d_features) - target_states = sample[2:] # (sample_length-2, N_grid, d_features) + init_states = sample[:2] # (2, N_grid, d_features) + target_states = sample[2:] # (sample_length-2, N_grid, d_features) # === Static batch features === # Load water coverage sample_datetime = sample_name[:10] - water_path = os.path.join(self.sample_dir_path, f"wtr_{sample_datetime}.npy") - static_features = torch.tensor(np.load(water_path), - dtype=torch.float32).unsqueeze(-1) # (N_x, N_y, 1) + water_path = os.path.join( + self.sample_dir_path, f"wtr_{sample_datetime}.npy" + ) + static_features = torch.tensor( + np.load(water_path), dtype=torch.float32 + ).unsqueeze( + -1 + ) # (dim_x, dim_y, 1) # Flatten - static_features = static_features.flatten(0,1) # (N_grid, 1) + static_features = static_features.flatten(0, 1) # (N_grid, 1) # === Forcing features === # Forcing features - flux_path = os.path.join(self.sample_dir_path, - f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy") - flux = torch.tensor(np.load(flux_path), - dtype=torch.float32).unsqueeze(-1) # (N_t', N_x, N_y, 1) + flux_path = os.path.join( + self.sample_dir_path, + f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy", + ) + flux = torch.tensor(np.load(flux_path), dtype=torch.float32).unsqueeze( + -1 + ) # (N_t', dim_x, dim_y, 1) if self.standardize: - flux = (flux - self.flux_mean)/self.flux_std + flux = (flux - self.flux_mean) / self.flux_std # Flatten and subsample flux forcing - flux = flux.flatten(1,2) # (N_t, N_grid, 1) - flux = flux[subsample_index::self.subsample_step] # (N_t, N_grid, 1) - flux = flux[init_id:(init_id+self.sample_length)] # (sample_len, N_grid, 1) + flux = flux.flatten(1, 2) # (N_t, N_grid, 1) + flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1) + flux = flux[ + 
init_id : (init_id + self.sample_length) + ] # (sample_len, N_grid, 1) # Time of day and year - dt_obj = dt.datetime.strptime(sample_datetime, '%Y%m%d%H') - dt_obj = dt_obj + dt.timedelta(hours=2+subsample_index) # Offset for first index + dt_obj = dt.datetime.strptime(sample_datetime, "%Y%m%d%H") + dt_obj = dt_obj + dt.timedelta( + hours=2 + subsample_index + ) # Offset for first index # Extract for initial step init_hour_in_day = dt_obj.hour - start_of_year = dt.datetime(dt_obj.year,1,1) - init_seconds_into_year = (dt_obj-start_of_year).total_seconds() + start_of_year = dt.datetime(dt_obj.year, 1, 1) + init_seconds_into_year = (dt_obj - start_of_year).total_seconds() # Add increments for all steps - hour_inc = torch.arange(self.sample_length)*self.subsample_step # (sample_len,) - hour_of_day = init_hour_in_day + hour_inc # (sample_len,), Can be > 24 but ok - second_into_year = init_seconds_into_year + hour_inc*3600 # (sample_len,) - #can roll over to next year, ok because periodicity + hour_inc = ( + torch.arange(self.sample_length) * self.subsample_step + ) # (sample_len,) + hour_of_day = ( + init_hour_in_day + hour_inc + ) # (sample_len,), Can be > 24 but ok + second_into_year = ( + init_seconds_into_year + hour_inc * 3600 + ) # (sample_len,) + # can roll over to next year, ok because periodicity # Encode as sin/cos - hour_angle = (hour_of_day/12)*torch.pi # (sample_len,) - year_angle = (second_into_year/constants.seconds_in_year - )*2*torch.pi # (sample_len,) - datetime_forcing = torch.stack(( - torch.sin(hour_angle), - torch.cos(hour_angle), - torch.sin(year_angle), - torch.cos(year_angle), - ), dim=1) # (N_t, 4) - datetime_forcing = (datetime_forcing + 1)/2 # Rescale to [0,1] - datetime_forcing = datetime_forcing.unsqueeze(1).expand(-1, - flux.shape[1], -1) # (sample_len, N_grid, 4) + hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,) + year_angle = ( + (second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi + ) # (sample_len,) + 
datetime_forcing = torch.stack( + ( + torch.sin(hour_angle), + torch.cos(hour_angle), + torch.sin(year_angle), + torch.cos(year_angle), + ), + dim=1, + ) # (N_t, 4) + datetime_forcing = (datetime_forcing + 1) / 2 # Rescale to [0,1] + datetime_forcing = datetime_forcing.unsqueeze(1).expand( + -1, flux.shape[1], -1 + ) # (sample_len, N_grid, 4) # Put forcing features together - forcing_features = torch.cat((flux, datetime_forcing), - dim=-1) # (sample_len, N_grid, d_forcing) + forcing_features = torch.cat( + (flux, datetime_forcing), dim=-1 + ) # (sample_len, N_grid, d_forcing) # Combine forcing over each window of 3 time steps - forcing_windowed = torch.cat(( - forcing_features[:-2], - forcing_features[1:-1], - forcing_features[2:], - ), dim=2) # (sample_len-2, N_grid, 3*d_forcing) + forcing_windowed = torch.cat( + ( + forcing_features[:-2], + forcing_features[1:-1], + forcing_features[2:], + ), + dim=2, + ) # (sample_len-2, N_grid, 3*d_forcing) # Now index 0 of ^ corresponds to forcing at index 0-2 of sample return init_states, target_states, static_features, forcing_windowed diff --git a/plot_graph.py b/plot_graph.py index 93531657..48427d5c 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -1,44 +1,78 @@ -import os +# Standard library from argparse import ArgumentParser + +# Third-party import numpy as np import plotly.graph_objects as go import torch_geometric as pyg +# First-party from neural_lam import utils MESH_HEIGHT = 0.1 MESH_LEVEL_DIST = 0.2 GRID_HEIGHT = 0 + def main(): - parser = ArgumentParser(description='Plot graph') - parser.add_argument('--dataset', type=str, default="meps_example", - help='Datast to load grid coordinates from (default: meps_example)') - parser.add_argument('--graph', type=str, default="multiscale", - help='Graph to plot (default: multiscale)') - parser.add_argument('--save', type=str, - help='Name of .html file to save interactive plot to (default: None)') - parser.add_argument('--show_axis', type=int, default=0, - help='If the 
axis should be displayed (default: 0 (No))') + """ + Plot graph structure in 3D using plotly + """ + parser = ArgumentParser(description="Plot graph") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Datast to load grid coordinates from (default: meps_example)", + ) + parser.add_argument( + "--graph", + type=str, + default="multiscale", + help="Graph to plot (default: multiscale)", + ) + parser.add_argument( + "--save", + type=str, + help="Name of .html file to save interactive plot to (default: None)", + ) + parser.add_argument( + "--show_axis", + type=int, + default=0, + help="If the axis should be displayed (default: 0 (No))", + ) args = parser.parse_args() # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) - g2m_edge_index, m2g_edge_index, m2m_edge_index, =\ - graph_ldict["g2m_edge_index"], graph_ldict["m2g_edge_index"],\ - graph_ldict["m2m_edge_index"] - mesh_up_edge_index, mesh_down_edge_index = graph_ldict["mesh_up_edge_index"],\ - graph_ldict["mesh_down_edge_index"] + ( + g2m_edge_index, + m2g_edge_index, + m2m_edge_index, + ) = ( + graph_ldict["g2m_edge_index"], + graph_ldict["m2g_edge_index"], + graph_ldict["m2m_edge_index"], + ) + mesh_up_edge_index, mesh_down_edge_index = ( + graph_ldict["mesh_up_edge_index"], + graph_ldict["mesh_down_edge_index"], + ) mesh_static_features = graph_ldict["mesh_static_features"] - grid_static_features = utils.load_static_data(args.dataset)["grid_static_features"] + grid_static_features = utils.load_static_data(args.dataset)[ + "grid_static_features" + ] # Extract values needed, turn to numpy grid_pos = grid_static_features[:, :2].numpy() # Add in z-dimension - z_grid = GRID_HEIGHT*np.ones((grid_pos.shape[0],)) - grid_pos = np.concatenate((grid_pos, np.expand_dims(z_grid, axis=1)), axis=1) + z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) + grid_pos = np.concatenate( + (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 + ) # List of edges to plot, (edge_index, 
color, line_width, label) edge_plot_list = [ @@ -46,26 +80,39 @@ def main(): (g2m_edge_index.numpy(), "black", 0.4, "G2M"), ] - # Mesh positioning and edges to plot differ if we have a hierachical graph + # Mesh positioning and edges to plot differ if we have a hierarchical graph if hierarchical: - mesh_level_pos = [np.concatenate(( - level_static_features.numpy(), - MESH_HEIGHT + MESH_LEVEL_DIST*height_level*np.ones( - (level_static_features.shape[0], 1)), - ), axis=1) - for height_level, level_static_features - in enumerate(mesh_static_features, start=1)] + mesh_level_pos = [ + np.concatenate( + ( + level_static_features.numpy(), + MESH_HEIGHT + + MESH_LEVEL_DIST + * height_level + * np.ones((level_static_features.shape[0], 1)), + ), + axis=1, + ) + for height_level, level_static_features in enumerate( + mesh_static_features, start=1 + ) + ] mesh_pos = np.concatenate(mesh_level_pos, axis=0) # Add inter-level mesh edges - edge_plot_list += [(level_ei.numpy(), "blue", 1, f"M2M Level {level}") - for level, level_ei in enumerate(m2m_edge_index)] + edge_plot_list += [ + (level_ei.numpy(), "blue", 1, f"M2M Level {level}") + for level, level_ei in enumerate(m2m_edge_index) + ] # Add intra-level mesh edges - up_edges_ei = np.concatenate([level_up_ei.numpy() - for level_up_ei in mesh_up_edge_index], axis=1) - down_edges_ei = np.concatenate([level_down_ei.numpy() - for level_down_ei in mesh_down_edge_index], axis=1) + up_edges_ei = np.concatenate( + [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 + ) + down_edges_ei = np.concatenate( + [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], + axis=1, + ) edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) @@ -74,10 +121,12 @@ def main(): mesh_pos = mesh_static_features.numpy() mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy() - z_mesh = MESH_HEIGHT + 0.01*mesh_degrees - mesh_node_size = mesh_degrees/2 + z_mesh = 
MESH_HEIGHT + 0.01 * mesh_degrees + mesh_node_size = mesh_degrees / 2 - mesh_pos = np.concatenate((mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1) + mesh_pos = np.concatenate( + (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 + ) edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) @@ -86,44 +135,73 @@ def main(): # Add edges data_objs = [] - for ei, col, width, label, in edge_plot_list: - edge_start = node_pos[ei[0]] # (M, 2) - edge_end = node_pos[ei[1]] # (M, 2) + for ( + ei, + col, + width, + label, + ) in edge_plot_list: + edge_start = node_pos[ei[0]] # (M, 2) + edge_end = node_pos[ei[1]] # (M, 2) n_edges = edge_start.shape[0] - x_edges = np.stack((edge_start[:, 0], edge_end[:, 0], np.full(n_edges, None)), - axis=1).flatten() - y_edges = np.stack((edge_start[:, 1], edge_end[:, 1], np.full(n_edges, None)), - axis=1).flatten() - z_edges = np.stack((edge_start[:, 2], edge_end[:, 2], np.full(n_edges, None)), - axis=1).flatten() - - scatter_obj = go.Scatter3d(x=x_edges, y=y_edges, z=z_edges, - mode='lines', line={"color":col, "width":width}, name=label) + x_edges = np.stack( + (edge_start[:, 0], edge_end[:, 0], np.full(n_edges, None)), axis=1 + ).flatten() + y_edges = np.stack( + (edge_start[:, 1], edge_end[:, 1], np.full(n_edges, None)), axis=1 + ).flatten() + z_edges = np.stack( + (edge_start[:, 2], edge_end[:, 2], np.full(n_edges, None)), axis=1 + ).flatten() + + scatter_obj = go.Scatter3d( + x=x_edges, + y=y_edges, + z=z_edges, + mode="lines", + line={"color": col, "width": width}, + name=label, + ) data_objs.append(scatter_obj) # Add node objects - data_objs.append(go.Scatter3d(x=grid_pos[:, 0], y=grid_pos[:, 1], z=grid_pos[:, 2], - mode='markers', marker={"color":"black", "size":1}, name="Grid nodes")) - data_objs.append(go.Scatter3d(x=mesh_pos[:, 0], y=mesh_pos[:, 1], z=mesh_pos[:, 2], - mode='markers', marker={"color":"blue", "size": mesh_node_size}, - name="Mesh nodes")) + data_objs.append( + go.Scatter3d( + x=grid_pos[:, 0], + 
y=grid_pos[:, 1], + z=grid_pos[:, 2], + mode="markers", + marker={"color": "black", "size": 1}, + name="Grid nodes", + ) + ) + data_objs.append( + go.Scatter3d( + x=mesh_pos[:, 0], + y=mesh_pos[:, 1], + z=mesh_pos[:, 2], + mode="markers", + marker={"color": "blue", "size": mesh_node_size}, + name="Mesh nodes", + ) + ) fig = go.Figure(data=data_objs) - fig.update_layout(scene_aspectmode='data') + fig.update_layout(scene_aspectmode="data") fig.update_traces(connectgaps=False) if not args.show_axis: # Hide axis fig.update_layout( - scene = dict( - xaxis = dict(visible=False), - yaxis = dict(visible=False), - zaxis =dict(visible=False) - ) - ) + scene={ + "xaxis": {"visible": False}, + "yaxis": {"visible": False}, + "zaxis": {"visible": False}, + } + ) if args.save: fig.write_html(args.save, include_plotlyjs="cdn") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..2478ddf1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,63 @@ +[tool.black] +line-length = 80 + +[tool.isort] +default_section = "THIRDPARTY" +profile = "black" +# Headings +import_heading_stdlib = "Standard library" +import_heading_thirdparty = "Third-party" +import_heading_firstparty = "First-party" +import_heading_localfolder = "Local" +# Known modules to avoid misclassification +known_standard_library = [ + # Add standard library modules that may be misclassified by isort +] +known_third_party = [ + # Add third-party modules that may be misclassified by isort + "wandb", +] +known_first_party = [ + # Add first-party modules that may be misclassified by isort + "neural_lam", +] + +[tool.flake8] +max-line-length = 80 +ignore = [ + "E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373) + "I002", # Don't check for isort configuration + "W503", # Allow line break before binary operator (PEP 8-compatible) +] +per-file-ignores = [ + "__init__.py: F401", # Allow unused imports +] + +[tool.codespell] +skip = "requirements/*" + +# Pylint config 
+[tool.pylint] +ignore = [ + "create_mesh.py", # Disable linting for now, as major rework is planned/expected +] +# Temporary fix for import neural_lam statements until set up as proper package +init-hook='import sys; sys.path.append(".")' +[tool.pylint.TYPECHECK] +generated-members = [ + "numpy.*", + "torch.*", +] +[tool.pylint.'MESSAGES CONTROL'] +disable = [ + "C0114", # 'missing-module-docstring', Do not require module docstrings + "R0901", # 'too-many-ancestors', Allow many layers of sub-classing + "R0902", # 'too-many-instance-attribtes', Allow many attributes + "R0913", # 'too-many-arguments', Allow many function arguments + "R0914", # 'too-many-locals', Allow many local variables + "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods +] +[tool.pylint.DESIGN] +max-statements=100 # Allow for some more involved functions +[tool.pylint.IMPORTS] +allow-any-import-level="neural_lam" diff --git a/requirements.txt b/requirements.txt index fc9b6843..5a2111b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# for all numpy>=1.24.2 wandb>=0.13.10 matplotlib>=3.7.0 @@ -9,3 +10,10 @@ Cartopy>=0.22.0 pyproj>=3.4.1 tueplots>=0.0.8 plotly>=5.15.0 +# for dev +codespell>=2.0.0 +black>=21.9b0 +isort>=5.9.3 +flake8>=4.0.1 +pylint>=3.0.3 +pre-commit>=2.15.0 diff --git a/train_model.py b/train_model.py index ef9fc30f..96d21a3f 100644 --- a/train_model.py +++ b/train_model.py @@ -1,18 +1,19 @@ +# Standard library import random -import torch +import time +from argparse import ArgumentParser + +# Third-party import pytorch_lightning as pl +import torch from lightning_fabric.utilities import seed -from argparse import ArgumentParser -import time -import matplotlib.pyplot as plt -import wandb +# First-party +from neural_lam import constants, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM from neural_lam.models.hi_lam_parallel import HiLAMParallel - from neural_lam.weather_dataset 
import WeatherDataset -from neural_lam import constants, utils MODELS = { "graph_lam": GraphLAM, @@ -20,71 +21,178 @@ "hi_lam_parallel": HiLAMParallel, } + def main(): - parser = ArgumentParser(description='Train or evaluate NeurWP models for LAM') + """ + Main function for training and evaluating models + """ + parser = ArgumentParser( + description="Train or evaluate NeurWP models for LAM" + ) # General options - parser.add_argument('--dataset', type=str, default="meps_example", - help='Dataset, corresponding to name in data directory (default: meps_example)') - parser.add_argument('--model', type=str, default="graph_lam", - help='Model architecture to train/evaluate (default: graph_lam)') - parser.add_argument('--subset_ds', type=int, default=0, - help='Use only a small subset of the dataset, for debugging (default: 0=false)') - parser.add_argument('--seed', type=int, default=42, - help='random seed (default: 42)') - parser.add_argument('--n_workers', type=int, default=4, - help='Number of workers in data loader (default: 4)') - parser.add_argument('--epochs', type=int, default=200, - help='upper epoch limit (default: 200)') - parser.add_argument('--batch_size', type=int, default=4, - help='batch size (default: 4)') - parser.add_argument('--load', type=str, - help='Path to load model parameters from (default: None)') - parser.add_argument('--restore_opt', type=int, default=0, - help='If optimizer state shoudl be restored with model (default: 0 (false))') - parser.add_argument('--precision', type=str, default=32, - help='Numerical precision to use for model (32/16/bf16) (default: 32)') + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset, corresponding to name in data directory " + "(default: meps_example)", + ) + parser.add_argument( + "--model", + type=str, + default="graph_lam", + help="Model architecture to train/evaluate (default: graph_lam)", + ) + parser.add_argument( + "--subset_ds", + type=int, + default=0, + 
help="Use only a small subset of the dataset, for debugging" + "(default: 0=false)", + ) + parser.add_argument( + "--seed", type=int, default=42, help="random seed (default: 42)" + ) + parser.add_argument( + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", + ) + parser.add_argument( + "--epochs", + type=int, + default=200, + help="upper epoch limit (default: 200)", + ) + parser.add_argument( + "--batch_size", type=int, default=4, help="batch size (default: 4)" + ) + parser.add_argument( + "--load", + type=str, + help="Path to load model parameters from (default: None)", + ) + parser.add_argument( + "--restore_opt", + type=int, + default=0, + help="If optimizer state should be restored with model " + "(default: 0 (false))", + ) + parser.add_argument( + "--precision", + type=str, + default=32, + help="Numerical precision to use for model (32/16/bf16) (default: 32)", + ) # Model architecture - parser.add_argument('--graph', type=str, default="multiscale", - help='Graph to load and use in graph-based model (default: multiscale)') - parser.add_argument('--hidden_dim', type=int, default=64, - help='Dimensionality of all hidden representations (default: 64)') - parser.add_argument('--hidden_layers', type=int, default=1, - help='Number of hidden layers in all MLPs (default: 1)') - parser.add_argument('--processor_layers', type=int, default=4, - help='Number of GNN layers in processor GNN (default: 4)') - parser.add_argument('--mesh_aggr', type=str, default="sum", - help='Aggregation to use for m2m processor GNN layers (sum/mean) (default: sum)') - parser.add_argument('--output_std', type=int, default=0, - help='If models should additionally output std.-dev. 
per output dimensions ' - '(default: 0 (no))') + parser.add_argument( + "--graph", + type=str, + default="multiscale", + help="Graph to load and use in graph-based model " + "(default: multiscale)", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=64, + help="Dimensionality of all hidden representations (default: 64)", + ) + parser.add_argument( + "--hidden_layers", + type=int, + default=1, + help="Number of hidden layers in all MLPs (default: 1)", + ) + parser.add_argument( + "--processor_layers", + type=int, + default=4, + help="Number of GNN layers in processor GNN (default: 4)", + ) + parser.add_argument( + "--mesh_aggr", + type=str, + default="sum", + help="Aggregation to use for m2m processor GNN layers (sum/mean) " + "(default: sum)", + ) + parser.add_argument( + "--output_std", + type=int, + default=0, + help="If models should additionally output std.-dev. per " + "output dimensions " + "(default: 0 (no))", + ) # Training options - parser.add_argument('--ar_steps', type=int, default=1, - help='Number of steps to unroll prediction for in loss (1-19) (default: 1)') - parser.add_argument('--control_only', type=int, default=0, - help='Train only on control member of ensemble data (default: 0 (False))') - parser.add_argument('--loss', type=str, default="wmse", - help='Loss function to use, see metric.py (default: wmse)') - parser.add_argument('--step_length', type=int, default=3, - help='Step length in hours to consider single time step 1-3 (default: 3)') - parser.add_argument('--lr', type=float, default=1e-3, - help='learning rate (default: 0.001)') - parser.add_argument('--val_interval', type=int, default=1, - help='Number of epochs training between each validation run (default: 1)') + parser.add_argument( + "--ar_steps", + type=int, + default=1, + help="Number of steps to unroll prediction for in loss (1-19) " + "(default: 1)", + ) + parser.add_argument( + "--control_only", + type=int, + default=0, + help="Train only on control member of 
ensemble data " + "(default: 0 (False))", + ) + parser.add_argument( + "--loss", + type=str, + default="wmse", + help="Loss function to use, see metric.py (default: wmse)", + ) + parser.add_argument( + "--step_length", + type=int, + default=3, + help="Step length in hours to consider single time step 1-3 " + "(default: 3)", + ) + parser.add_argument( + "--lr", type=float, default=1e-3, help="learning rate (default: 0.001)" + ) + parser.add_argument( + "--val_interval", + type=int, + default=1, + help="Number of epochs training between each validation run " + "(default: 1)", + ) # Evaluation options - parser.add_argument('--eval', type=str, - help='Eval model on given data split (val/test) (default: None (train model))') - parser.add_argument('--n_example_pred', type=int, default=1, - help='Number of example predictions to plot during evaluation (default: 1)') + parser.add_argument( + "--eval", + type=str, + help="Eval model on given data split (val/test) " + "(default: None (train model))", + ) + parser.add_argument( + "--n_example_pred", + type=int, + default=1, + help="Number of example predictions to plot during evaluation " + "(default: 1)", + ) args = parser.parse_args() # Asserts for arguments assert args.model in MODELS, f"Unknown model: {args.model}" assert args.step_length <= 3, "Too high step length" - assert args.eval in (None, "val", "test"), f"Unknown eval setting: {args.eval}" + assert args.eval in ( + None, + "val", + "test", + ), f"Unknown eval setting: {args.eval}" # Get an (actual) random run id as a unique identifier random_run_id = random.randint(0, 9999) @@ -94,21 +202,39 @@ def main(): # Load data train_loader = torch.utils.data.DataLoader( - WeatherDataset(args.dataset, pred_length=args.ar_steps, split="train", - subsample_step=args.step_length, subset=bool(args.subset_ds), - control_only=args.control_only), - args.batch_size, shuffle=True, num_workers=args.n_workers) - max_pred_length = (65 // args.step_length) - 2 # 19 + WeatherDataset( + 
args.dataset, + pred_length=args.ar_steps, + split="train", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + control_only=args.control_only, + ), + args.batch_size, + shuffle=True, + num_workers=args.n_workers, + ) + max_pred_length = (65 // args.step_length) - 2 # 19 val_loader = torch.utils.data.DataLoader( - WeatherDataset(args.dataset, pred_length=max_pred_length, split="val", - subsample_step=args.step_length, subset=bool(args.subset_ds), - control_only=args.control_only), - args.batch_size, shuffle=False, num_workers=args.n_workers) + WeatherDataset( + args.dataset, + pred_length=max_pred_length, + split="val", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + control_only=args.control_only, + ), + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + ) - # Instatiate model + trainer + # Instantiate model + trainer if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s else: device_name = "cpu" @@ -126,37 +252,63 @@ def main(): prefix = "subset-" if args.subset_ds else "" if args.eval: prefix = prefix + f"eval-{args.eval}-" - run_name = f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"\ - f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}" + run_name = ( + f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" + f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}" + ) checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=f"saved_models/{run_name}", filename="min_val_loss", - monitor="val_mean_loss", mode="min", save_last=True) - logger = pl.loggers.WandbLogger(project=constants.wandb_project, name=run_name, - config=args) - trainer = pl.Trainer(max_epochs=args.epochs, deterministic=True, strategy="ddp", - accelerator=device_name, logger=logger, log_every_n_steps=1, - callbacks=[checkpoint_callback], 
check_val_every_n_epoch=args.val_interval, - precision=args.precision) + dirpath=f"saved_models/{run_name}", + filename="min_val_loss", + monitor="val_mean_loss", + mode="min", + save_last=True, + ) + logger = pl.loggers.WandbLogger( + project=constants.WANDB_PROJECT, name=run_name, config=args + ) + trainer = pl.Trainer( + max_epochs=args.epochs, + deterministic=True, + strategy="ddp", + accelerator=device_name, + logger=logger, + log_every_n_steps=1, + callbacks=[checkpoint_callback], + check_val_every_n_epoch=args.val_interval, + precision=args.precision, + ) # Only init once, on rank 0 only if trainer.global_rank == 0: - utils.init_wandb_metrics(logger) # Do after wandb.init + utils.init_wandb_metrics(logger) # Do after wandb.init if args.eval: if args.eval == "val": eval_loader = val_loader - else: # Test - eval_loader = torch.utils.data.DataLoader(WeatherDataset(args.dataset, - pred_length=max_pred_length, split="test", - subsample_step=args.step_length, subset=bool(args.subset_ds)), - args.batch_size, shuffle=False, num_workers=args.n_workers) + else: # Test + eval_loader = torch.utils.data.DataLoader( + WeatherDataset( + args.dataset, + pred_length=max_pred_length, + split="test", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + ), + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + ) print(f"Running evaluation on {args.eval}") trainer.test(model=model, dataloaders=eval_loader) else: # Train model - trainer.fit(model=model, train_dataloaders=train_loader, - val_dataloaders=val_loader) + trainer.fit( + model=model, + train_dataloaders=train_loader, + val_dataloaders=val_loader, + ) + if __name__ == "__main__": main()