From 799d55e3abd8a7ba34507cf1e2d524be070de89f Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 14 Aug 2024 12:30:07 +0000 Subject: [PATCH] linting fixes --- .pre-commit-config.yaml | 3 +- neural_lam/create_graph.py | 62 ++----- neural_lam/datastore/base.py | 155 +++++++++++++----- neural_lam/datastore/mllam.py | 114 +++++++++---- .../multizarr/create_boundary_mask.py | 1 + .../multizarr/create_datetime_forcings.py | 13 +- .../multizarr/create_normalization_stats.py | 17 +- neural_lam/datastore/multizarr/store.py | 116 ++++++++----- neural_lam/datastore/npyfiles/config.py | 16 +- neural_lam/datastore/npyfiles/store.py | 132 +++++++++------ neural_lam/interaction_net.py | 19 ++- neural_lam/metrics.py | 40 ++--- neural_lam/models/ar_model.py | 84 ++++------ neural_lam/models/base_graph_model.py | 31 ++-- neural_lam/models/base_hi_graph_model.py | 63 +++---- neural_lam/models/graph_lam.py | 31 ++-- neural_lam/models/hi_lam.py | 22 ++- neural_lam/models/hi_lam_parallel.py | 17 +- neural_lam/train_model.py | 23 +-- neural_lam/utils.py | 24 +-- neural_lam/vis.py | 25 +-- neural_lam/weather_dataset.py | 64 ++++---- plot_graph.py | 8 +- pyproject.toml | 11 +- tests/conftest.py | 18 +- tests/test_cli.py | 4 +- tests/test_datasets.py | 13 +- tests/test_datastores.py | 25 +-- tests/test_graph_creation.py | 11 +- tests/test_training.py | 4 +- 30 files changed, 625 insertions(+), 541 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd40f4d7..91983d9b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,4 +40,5 @@ repos: rev: v1.7.5 hooks: - id: docformatter - args: [--in-place, --recursive] + args: [--in-place, --recursive, --config, ./pyproject.toml] + additional_dependencies: [tomli] diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index e5eb44a4..6450f134 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -35,9 +35,7 @@ def plot_graph(graph, title=None): # TODO: indicate direction of directed edges # Move all to cpu and numpy, compute (in)-degrees - degrees = ( - pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() - ) + degrees = pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() edge_index = edge_index.cpu().numpy() pos = pos.cpu().numpy() @@ -82,9 +80,7 @@ def sort_nodes_internally(nx_graph): def save_edges(graph, name, base_path): - torch.save( - graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt") - ) + torch.save(graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")) edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( torch.float32 ) # Save as float32 @@ -97,9 +93,7 @@ def save_edges_list(graphs, name, base_path): os.path.join(base_path, f"{name}_edge_index.pt"), ) edge_features = [ - torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) + torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(torch.float32) for graph in graphs ] # Save as float32 torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) @@ -130,11 +124,7 @@ def mk_2d_graph(xy, nx, ny): # add diagonal edges g.add_edges_from( [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] - + [ - ((x + 1, y), (x, y + 1)) - for x in range(nx - 1) - for y in range(ny - 1) - ] + + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] ) # turn into directed graph @@ -164,8 +154,7 @@ def create_graph( hierarchical: bool, create_plot: bool, ): - """Create graph components 
from `xy` grid coordinates and store in - `graph_dir_path`. + """Create graph components from `xy` grid coordinates and store in `graph_dir_path`. Creates the following files for all graphs: - g2m_edge_index.pt [2, N_g2m_edges] @@ -225,6 +214,7 @@ def create_graph( Returns ------- None + """ os.makedirs(graph_dir_path, exist_ok=True) @@ -262,10 +252,7 @@ def create_graph( if hierarchical: # Relabel nodes of each level with level index first - G = [ - prepend_node_index(graph, level_i) - for level_i, graph in enumerate(G) - ] + G = [prepend_node_index(graph, level_i) for level_i, graph in enumerate(G)] num_nodes_level = np.array([len(g_level.nodes) for g_level in G]) # First node index in each level in the hierarchical graph @@ -307,9 +294,7 @@ def create_graph( # add edge from mesh to grid G_down.add_edge(u, v) d = np.sqrt( - np.sum( - (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 - ) + np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2) ) G_down.edges[u, v]["len"] = d G_down.edges[u, v]["vdiff"] = ( @@ -334,14 +319,10 @@ def create_graph( down_graphs.append(pyg_down) if create_plot: - plot_graph( - pyg_down, title=f"Down graph, {from_level} -> {to_level}" - ) + plot_graph(pyg_down, title=f"Down graph, {from_level} -> {to_level}") plt.show() - plot_graph( - pyg_down, title=f"Up graph, {to_level} -> {from_level}" - ) + plot_graph(pyg_down, title=f"Up graph, {to_level} -> {from_level}") plt.show() # Save up and down edges @@ -426,9 +407,7 @@ def create_graph( vm = G_bottom_mesh.nodes vm_xy = np.array([xy for _, xy in vm.data("pos")]) # distance between mesh nodes - dm = np.sqrt( - np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) - ) + dm = np.sqrt(np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)) # grid nodes Ny, Nx = xy.shape[1:] @@ -470,13 +449,9 @@ def create_graph( u = vg_list[i] # add edge from grid to mesh G_g2m.add_edge(u, v) - d = np.sqrt( - np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2) - ) + d = np.sqrt(np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2)) G_g2m.edges[u, v]["len"] = d - G_g2m.edges[u, v]["vdiff"] = ( - G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] - ) + G_g2m.edges[u, v]["vdiff"] = G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] pyg_g2m = from_networkx(G_g2m) @@ -505,13 +480,9 @@ def create_graph( u = vm_list[i] # add edge from mesh to grid G_m2g.add_edge(u, v) - d = np.sqrt( - np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2) - ) + d = np.sqrt(np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2)) G_m2g.edges[u, v]["len"] = d - G_m2g.edges[u, v]["vdiff"] = ( - G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] - ) + G_m2g.edges[u, v]["vdiff"] = G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] # relabel nodes to integers (sorted) G_m2g_int = networkx.convert_node_labels_to_integers( @@ -578,8 +549,7 @@ def cli(input_args=None): "--plot", type=int, default=0, - help="If graphs should be plotted during generation " - "(default: 0 (false))", + help="If graphs should be plotted during generation " "(default: 0 (false))", ) parser.add_argument( "--levels", diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 101a13bc..1b662fa4 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -11,10 +11,17 @@ class BaseDatastore(abc.ABC): - """Base class for weather data used in the neural-lam package. 
A datastore - defines the interface for accessing weather data by providing methods to - access the data in a processed format that can be used for training and - evaluation of neural networks. + """Base class for weather + data used in the neural- + lam package. A datastore + defines the interface for + accessing weather data by + providing methods to + access the data in a + processed format that can + be used for training and + evaluation of neural + networks. NOTE: All methods return either primitive types, `numpy.ndarray`, `xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor` @@ -32,6 +39,7 @@ class BaseDatastore(abc.ABC): If the datastore is used to represent ensemble data, then the `is_ensemble` attribute should be set to True, and returned data from `get_dataarray` is assumed to have an `ensemble_member` dimension. + """ is_ensemble: bool = False @@ -40,13 +48,14 @@ class BaseDatastore(abc.ABC): @property @abc.abstractmethod def root_path(self) -> Path: - """The root path to the datastore. It is relative to this that any - derived files (for example the graph components) are stored. + """The root path to the datastore. It is relative to this that any derived files + (for example the graph components) are stored. Returns ------- pathlib.Path The root path to the datastore. + """ pass @@ -57,6 +66,7 @@ def step_length(self) -> int: Returns: int: The step length in hours. + """ pass @@ -73,6 +83,7 @@ def get_vars_units(self, category: str) -> List[str]: ------- List[str] The units of the variables. + """ pass @@ -89,6 +100,7 @@ def get_vars_names(self, category: str) -> List[str]: ------- List[str] The names of the variables. + """ pass @@ -105,19 +117,39 @@ def get_num_data_vars(self, category: str) -> int: ------- int The number of data variables. + """ pass @abc.abstractmethod def get_normalization_dataarray(self, category: str) -> xr.Dataset: - """Return the normalization dataarray for the given category. This - should contain a `{category}_mean` and `{category}_std` variable for - each variable in the category. For `category=="state"`, the dataarray - should also contain a `state_diff_mean` and `state_diff_std` variable - for the one-step differences of the state variables. The returned - dataarray should at least have dimensions of `({category}_feature)`, - but can also include for example `grid_index` (if the normalisation is - done per grid point for example). + """Return the + normalization + dataarray for the + given category. This + should contain a + `{category}_mean` and + `{category}_std` + variable for each + variable in the + category. For + `category=="state"`, + the dataarray should + also contain a + `state_diff_mean` and + `state_diff_std` + variable for the one- + step differences of + the state variables. + The returned dataarray + should at least have + dimensions of `({categ + ory}_feature)`, but + can also include for + example `grid_index` + (if the normalisation + is done per grid point + for example). Parameters ---------- @@ -130,18 +162,30 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset: The normalization dataarray for the given category, with variables for the mean and standard deviation of the variables (and differences for state variables). 
+ """ pass @abc.abstractmethod - def get_dataarray( - self, category: str, split: str - ) -> Union[xr.DataArray, None]: - """Return the processed data (as a single `xr.DataArray`) for the given - category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcing/static). A datastore - must be able to return for the "state" category, but "forcing" and - "static" are optional (in which case the method should return `None`). + def get_dataarray(self, category: str, split: str) -> Union[xr.DataArray, None]: + """Return the + processed data (as a + single `xr.DataArray`) + for the given category + of data and + test/train/val-split + that covers all the + data (in space and + time) of a given + category (state/forcin + g/static). A datastore + must be able to return + for the "state" + category, but + "forcing" and "static" + are optional (in which + case the method should + return `None`). The returned dataarray is expected to at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have @@ -168,20 +212,29 @@ def get_dataarray( ------- xr.DataArray or None The xarray DataArray object with processed dataset. + """ pass @property @abc.abstractmethod def boundary_mask(self) -> xr.DataArray: - """Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. + """Return the boundary + mask for the dataset, + with spatial + dimensions stacked. + Where the value is 1, + the grid point is a + boundary point, and + where the value is 0, + the grid point is not + a boundary point. Returns ------- xr.DataArray The boundary mask for the dataset, with dimensions `('grid_index',)`. + """ pass @@ -195,12 +248,21 @@ class CartesianGridShape: class BaseCartesianDatastore(BaseDatastore): - """Base class for weather data stored on a Cartesian grid. In addition to - the methods and attributes required for weather data in general (see - `BaseDatastore`) for Cartesian gridded source data each `grid_index` - coordinate value is assume to have an associated `x` and `y`-value so that - the processed data-arrays can be reshaped back into into 2D xy-gridded - arrays. + """Base class for weather + data stored on a Cartesian + grid. In addition to the + methods and attributes + required for weather data + in general (see + `BaseDatastore`) for + Cartesian gridded source + data each `grid_index` + coordinate value is assume + to have an associated `x` + and `y`-value so that the + processed data-arrays can + be reshaped back into into + 2D xy-gridded arrays. In addition the following attributes and methods are required: - `coords_projection` (property): Projection object for the coordinates. @@ -208,6 +270,7 @@ class BaseCartesianDatastore(BaseDatastore): - `get_xy_extent` (method): Return the extent of the x, y coordinates for a given category of data. - `get_xy` (method): Return the x, y coordinates of the dataset. + """ CARTESIAN_COORDS = ["y", "x"] @@ -223,6 +286,7 @@ def coords_projection(self) -> ccrs.Projection: ------- cartopy.crs.Projection: The projection object. + """ pass @@ -236,6 +300,7 @@ def grid_shape_state(self) -> CartesianGridShape: CartesianGridShape: The shape of the grid for the state variables, which has `x` and `y` attributes. 
+ """ pass @@ -257,13 +322,22 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray: value of `stacked`: - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y. - `stacked==False`: shape `(2, N_y, N_x)` + """ pass def get_xy_extent(self, category: str) -> List[float]: - """Return the extent of the x, y coordinates for a given category of - data. The extent should be returned as a list of 4 floats with `[xmin, - xmax, ymin, ymax]` which can then be used to set the extent of a plot. + """Return the extent + of the x, y + coordinates for a + given category of + data. The extent + should be returned as + a list of 4 floats + with `[xmin, xmax, + ymin, ymax]` which can + then be used to set + the extent of a plot. Parameters ---------- @@ -274,6 +348,7 @@ def get_xy_extent(self, category: str) -> List[float]: ------- List[float] The extent of the x, y coordinates. + """ xy = self.get_xy(category, stacked=False) extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] @@ -282,9 +357,8 @@ def get_xy_extent(self, category: str) -> List[float]: def unstack_grid_coords( self, da_or_ds: Union[xr.DataArray, xr.Dataset] ) -> Union[xr.DataArray, xr.Dataset]: - """Stack the spatial grid coordinates into separate `x` and `y` - dimensions (the names can be set by the `CARTESIAN_COORDS` attribute) - to create a 2D grid. + """Stack the spatial grid coordinates into separate `x` and `y` dimensions (the + names can be set by the `CARTESIAN_COORDS` attribute) to create a 2D grid. Parameters ---------- @@ -295,6 +369,7 @@ def unstack_grid_coords( ------- xr.DataArray or xr.Dataset The dataarray or dataset with the grid coordinates unstacked. + """ return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack( "grid_index" @@ -303,9 +378,8 @@ def unstack_grid_coords( def stack_grid_coords( self, da_or_ds: Union[xr.DataArray, xr.Dataset] ) -> Union[xr.DataArray, xr.Dataset]: - """Stack the spatial grid coordinated (by default `x` and `y`, but this - can be set by the `CARTESIAN_COORDS` attribute) into a single - `grid_index` dimension. + """Stack the spatial grid coordinated (by default `x` and `y`, but this can be + set by the `CARTESIAN_COORDS` attribute) into a single `grid_index` dimension. Parameters ---------- @@ -316,5 +390,6 @@ def stack_grid_coords( ------- xr.DataArray or xr.Dataset The dataarray or dataset with the grid coordinates stacked. + """ return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py index ae2c5d53..0d011e5e 100644 --- a/neural_lam/datastore/mllam.py +++ b/neural_lam/datastore/mllam.py @@ -19,11 +19,23 @@ class MLLAMDatastore(BaseCartesianDatastore): """Datastore class for the MLLAM dataset.""" def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): - """Construct a new MLLAMDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded - from a zarr file if it exists (unless the config has been modified - since the zarr was created), otherwise it is created from the + """Construct a new + MLLAMDatastore from + the configuration file + at `config_path`. A + boundary mask is + created with + `n_boundary_points` + boundary points. 
If + `reuse_existing` is + True, the dataset is + loaded from a zarr + file if it exists + (unless the config has + been modified since + the zarr was created), + otherwise it is + created from the configuration file. Parameters @@ -37,13 +49,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. + """ self._config_path = Path(config_path) self._root_path = self._config_path.parent self._config = mdp.Config.from_yaml_file(self._config_path) - fp_ds = self._root_path / self._config_path.name.replace( - ".yaml", ".zarr" - ) + fp_ds = self._root_path / self._config_path.name.replace(".yaml", ".zarr") self._ds = None if reuse_existing and fp_ds.exists(): @@ -71,6 +82,7 @@ def root_path(self) -> Path: ------- Path The root path of the dataset. + """ return self._root_path @@ -82,6 +94,7 @@ def step_length(self) -> int: ------- int The length of the time steps in hours. + """ da_dt = self._ds["time"].diff("time") return (da_dt.dt.seconds[0] // 3600).item() @@ -98,6 +111,7 @@ def get_vars_units(self, category: str) -> List[str]: ------- List[str] The units of the variables in the given category. + """ if category not in self._ds and category == "forcing": warnings.warn("no forcing data found in datastore") @@ -116,6 +130,7 @@ def get_vars_names(self, category: str) -> List[str]: ------- List[str] The names of the variables in the given category. + """ if category not in self._ds and category == "forcing": warnings.warn("no forcing data found in datastore") @@ -134,15 +149,29 @@ def get_num_data_vars(self, category: str) -> int: ------- int The number of variables in the given category. + """ return len(self.get_vars_names(category)) def get_dataarray(self, category: str, split: str) -> xr.DataArray: - """Return the processed data (as a single `xr.DataArray`) for the given - category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcing/static). "state" is - the only required category, for other categories, the method will - return `None` if the category is not found in the datastore. + """Return the + processed data (as a + single `xr.DataArray`) + for the given category + of data and + test/train/val-split + that covers all the + data (in space and + time) of a given + category (state/forcin + g/static). "state" is + the only required + category, for other + categories, the method + will return `None` if + the category is not + found in the + datastore. The returned dataarray will at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have been stacked @@ -169,6 +198,7 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: ------- xr.DataArray or None The xarray DataArray object with processed dataset. + """ if category not in self._ds and category == "forcing": warnings.warn("no forcing data found in datastore") @@ -194,11 +224,24 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: return da_category.sel(time=slice(t_start, t_end)) def get_normalization_dataarray(self, category: str) -> xr.Dataset: - """Return the normalization dataarray for the given category. This - should contain a `{category}_mean` and `{category}_std` variable for - each variable in the category. 
For `category=="state"`, the dataarray - should also contain a `state_diff_mean` and `state_diff_std` variable - for the one-step differences of the state variables. + """Return the + normalization + dataarray for the + given category. This + should contain a + `{category}_mean` and + `{category}_std` + variable for each + variable in the + category. For + `category=="state"`, + the dataarray should + also contain a + `state_diff_mean` and + `state_diff_std` + variable for the one- + step differences of + the state variables. Parameters ---------- @@ -211,6 +254,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset: The normalization dataarray for the given category, with variables for the mean and standard deviation of the variables (and differences for state variables). + """ ops = ["mean", "std"] split = "train" @@ -227,11 +271,23 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset: @property def boundary_mask(self) -> xr.DataArray: - """Produce a 0/1 mask for the boundary points of the dataset, these - will sit at the edges of the domain (in x/y extent) and will be used to - mask out the boundary points from the loss function and to overwrite - the boundary points from the prediction. For now this is created when - the mask is requested, but in the future this could be saved to the + """Produce a 0/1 mask + for the boundary + points of the dataset, + these will sit at the + edges of the domain + (in x/y extent) and + will be used to mask + out the boundary + points from the loss + function and to + overwrite the boundary + points from the + prediction. For now + this is created when + the mask is requested, + but in the future this + could be saved to the zarr file. Returns @@ -239,19 +295,16 @@ def boundary_mask(self) -> xr.DataArray: xr.DataArray A 0/1 mask for the boundary points of the dataset, where 1 is a boundary point and 0 is not. + """ ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) + da_state_variable = ds_unstacked["state"].isel(time=0).isel(state_feature=0) da_domain_allzero = xr.zeros_like(da_state_variable) ds_unstacked["boundary_mask"] = da_domain_allzero.isel( x=slice(self._n_boundary_points, -self._n_boundary_points), y=slice(self._n_boundary_points, -self._n_boundary_points), ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) + ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(1).astype(int) return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) @property @@ -262,6 +315,7 @@ def coords_projection(self) -> ccrs.Projection: ------- ccrs.Projection The projection of the coordinates. + """ # TODO: danra doesn't contain projection information yet, but the next # version will for now we hardcode the projection @@ -276,6 +330,7 @@ def grid_shape_state(self): ------- CartesianGridShape The shape of the cartesian grid for the state variables. + """ ds_state = self.unstack_grid_coords(self._ds["state"]) da_x, da_y = ds_state.x, ds_state.y @@ -299,6 +354,7 @@ def get_xy(self, category: str, stacked: bool) -> ndarray: value of `stacked`: - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y. - `stacked==False`: shape `(2, N_y, N_x)` + """ # assume variables are stored in dimensions [grid_index, ...] 
ds_category = self.unstack_grid_coords(da_or_ds=self._ds[category]) diff --git a/neural_lam/datastore/multizarr/create_boundary_mask.py b/neural_lam/datastore/multizarr/create_boundary_mask.py index ae154941..31966394 100644 --- a/neural_lam/datastore/multizarr/create_boundary_mask.py +++ b/neural_lam/datastore/multizarr/create_boundary_mask.py @@ -21,6 +21,7 @@ def create_boundary_mask(data_config_path, zarr_path, n_boundary_cells): Data configuration. zarr_path : str Path to save the Zarr archive. + """ data_config_path = config.Config.from_file(str(data_config_path)) mask = np.zeros(list(data_config_path.grid_shape_state.values.values())) diff --git a/neural_lam/datastore/multizarr/create_datetime_forcings.py b/neural_lam/datastore/multizarr/create_datetime_forcings.py index 82a90147..7b645cae 100644 --- a/neural_lam/datastore/multizarr/create_datetime_forcings.py +++ b/neural_lam/datastore/multizarr/create_datetime_forcings.py @@ -36,6 +36,7 @@ def calculate_datetime_forcing(da_time: xr.DataArray): - hour_cos: The cosine of the hour of the day, normalized to [0, 1]. - year_sin: The sine of the time of year, normalized to [0, 1]. - year_cos: The cosine of the time of year, normalized to [0, 1]. + """ hours_of_day = xr.DataArray(da_time.dt.hour, dims=["time"]) seconds_into_year = xr.DataArray( @@ -49,10 +50,7 @@ def calculate_datetime_forcing(da_time: xr.DataArray): dims=["time"], ) year_seconds = xr.DataArray( - [ - get_seconds_in_year(pd.Timestamp(dt_obj).year) - for dt_obj in da_time.values - ], + [get_seconds_in_year(pd.Timestamp(dt_obj).year) for dt_obj in da_time.values], dims=["time"], ) hour_angle = (hours_of_day / 12) * np.pi @@ -85,6 +83,7 @@ def create_datetime_forcing_zarr( The time DataArray for which to create the datetime forcing. chunking : dict, optional The chunking to use when saving the Zarr archive. + """ if zarr_path is None: zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME @@ -92,9 +91,9 @@ def create_datetime_forcing_zarr( datastore = MultiZarrDatastore(config_path=data_config_path) da_state = datastore.get_dataarray(category="state", split="train") - da_datetime_forcing = calculate_datetime_forcing( - da_time=da_state.time - ).expand_dims({"grid_index": da_state.grid_index}) + da_datetime_forcing = calculate_datetime_forcing(da_time=da_state.time).expand_dims( + {"grid_index": da_state.grid_index} + ) if "x" in da_state.coords and "y" in da_state.coords: # copy the x and y coordinates to the datetime forcing diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py index b4cf1be6..7a6df4d2 100644 --- a/neural_lam/datastore/multizarr/create_normalization_stats.py +++ b/neural_lam/datastore/multizarr/create_normalization_stats.py @@ -21,8 +21,8 @@ def create_normalization_stats_zarr( data_config_path: str, zarr_path: str = None, ): - """Compute mean and std.-dev. for state and forcing variables and save them - to a Zarr file. + """Compute mean and std.-dev. for state and forcing variables and save them to a + Zarr file. Parameters ---------- @@ -32,6 +32,7 @@ def create_normalization_stats_zarr( Path to save the normalization statistics to. If not provided, the statistics are saved to the same directory as the data config file with the name `normalization.zarr`. 
+ """ if zarr_path is None: zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME @@ -54,9 +55,7 @@ def create_normalization_stats_zarr( for group in combined_stats: vars_to_combine = group["vars"] - da_forcing_means = da_forcing_mean.sel( - forcing_feature=vars_to_combine - ) + da_forcing_means = da_forcing_mean.sel(forcing_feature=vars_to_combine) stds = da_forcing_std.sel(forcing_feature=vars_to_combine) combined_mean = da_forcing_means.mean(dim="forcing_feature") @@ -65,12 +64,8 @@ def create_normalization_stats_zarr( da_forcing_mean.loc[ dict(forcing_feature=vars_to_combine) ] = combined_mean - da_forcing_std.loc[ - dict(forcing_feature=vars_to_combine) - ] = combined_std - print( - "Computing mean and std.-dev. for one-step differences...", flush=True - ) + da_forcing_std.loc[dict(forcing_feature=vars_to_combine)] = combined_std + print("Computing mean and std.-dev. for one-step differences...", flush=True) state_data_normalized = (da_state - da_state_mean) / da_state_std state_data_diff_normalized = state_data_normalized.diff(dim="time") diff_mean, diff_std = compute_stats(state_data_diff_normalized) diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py index 1a3a2a89..18af8457 100644 --- a/neural_lam/datastore/multizarr/store.py +++ b/neural_lam/datastore/multizarr/store.py @@ -18,15 +18,25 @@ class MultiZarrDatastore(BaseCartesianDatastore): DIMS_TO_KEEP = {"time", "grid_index", "variable_name"} def __init__(self, config_path): - """Create a multi-zarr datastore from the given configuration file. The - configuration file should be a YAML file, the format of which is should - be inferred from the example configuration file in - `tests/datastore_examples/multizarr/data_config.yml`. + """Create a multi-zarr + datastore from the + given configuration + file. The + configuration file + should be a YAML file, + the format of which is + should be inferred + from the example + configuration file in + `tests/datastore_examp + les/multizarr/data_con + fig.yml`. Parameters ---------- config_path : str The path to the configuration file. + """ self._config_path = Path(config_path) self._root_path = self._config_path.parent @@ -41,6 +51,7 @@ def root_path(self): ------- str The root path of the datastore. + """ return self._root_path @@ -80,6 +91,7 @@ def open_zarrs(self, category): ------- xr.Dataset The xarray Dataset object. + """ zarr_configs = self._config[category]["zarrs"] @@ -104,6 +116,7 @@ def coords_projection(self): Returns: cartopy.crs.Projection: The projection object. + """ proj_config = self._config["projection"] proj_class_name = proj_config["class"] @@ -117,6 +130,7 @@ def step_length(self): Returns: int: The step length in hours. + """ dataset = self.open_zarrs("state") time = dataset.time.isel(time=slice(0, 2)).values @@ -133,6 +147,7 @@ def get_vars_names(self, category): Returns: list: The names of the variables in the dataset. + """ surface_vars_names = self._config[category].get("surface_vars") or [] atmosphere_vars_names = [ @@ -151,6 +166,7 @@ def get_vars_units(self, category): Returns: list: The units of the variables in the dataset. + """ surface_vars_units = self._config[category].get("surface_units") or [] atmosphere_vars_units = [ @@ -169,14 +185,13 @@ def get_num_data_vars(self, category): Returns: int: The number of data variables in the dataset. 
+ """ surface_vars = self._config[category].get("surface_vars", []) atmosphere_vars = self._config[category].get("atmosphere_vars", []) levels = self._config[category].get("levels", []) - surface_vars_count = ( - len(surface_vars) if surface_vars is not None else 0 - ) + surface_vars_count = len(surface_vars) if surface_vars is not None else 0 atmosphere_vars_count = ( len(atmosphere_vars) if atmosphere_vars is not None else 0 ) @@ -192,6 +207,7 @@ def _stack_grid(self, ds): Returns: xr.Dataset: The xarray Dataset object with stacked grid dimensions. + """ if "grid_index" in ds.dims: raise ValueError("Grid dimensions already stacked.") @@ -212,6 +228,7 @@ def _convert_dataset_to_dataarray(self, dataset): Returns: xr.DataArray: The xarray DataArray object. + """ if isinstance(dataset, xr.Dataset): dataset = dataset.to_array(dim="variable_name") @@ -227,6 +244,7 @@ def _filter_dimensions(self, dataset, transpose_array=True): Returns: xr.Dataset: The xarray Dataset object with filtered dimensions. OR xr.DataArray: The xarray DataArray object with filtered dimensions. + """ dims_to_keep = self.DIMS_TO_KEEP dataset_dims = set(list(dataset.dims) + ["variable_name"]) @@ -277,9 +295,7 @@ def _filter_dimensions(self, dataset, transpose_array=True): dataset = self._convert_dataset_to_dataarray(dataset) if "time" in dataset.dims: - dataset = dataset.transpose( - "time", "grid_index", "variable_name" - ) + dataset = dataset.transpose("time", "grid_index", "variable_name") else: dataset = dataset.transpose("grid_index", "variable_name") dataset_vars = ( @@ -304,6 +320,7 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None): Returns: xr.Dataset: The xarray Dataset object with reshaped grid dimensions. + """ if grid_shape is None: grid_shape = dict(self.grid_shape_state.values.items()) @@ -311,13 +328,9 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None): x_coords = np.arange(x_dim) y_coords = np.arange(y_dim) - multi_index = pd.MultiIndex.from_product( - [y_coords, x_coords], names=["y", "x"] - ) + multi_index = pd.MultiIndex.from_product([y_coords, x_coords], names=["y", "x"]) - mindex_coords = xr.Coordinates.from_pandas_multiindex( - multi_index, "grid" - ) + mindex_coords = xr.Coordinates.from_pandas_multiindex(multi_index, "grid") dataset = dataset.drop_vars(["grid", "x", "y"], errors="ignore") dataset = dataset.assign_coords(mindex_coords) reshaped_data = dataset.unstack("grid") @@ -342,13 +355,12 @@ def get_xy(self, category, stacked=True): value of `stacked`: - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y. - `stacked==False`: shape `(2, N_y, N_x)` + """ dataset = self.open_zarrs(category) xs, ys = dataset.x.values, dataset.y.values - assert ( - xs.ndim == ys.ndim - ), "x and y coordinates must have the same dimensions." + assert xs.ndim == ys.ndim, "x and y coordinates must have the same dimensions." if xs.ndim == 1: x, y = np.meshgrid(xs, ys) @@ -366,14 +378,33 @@ def get_xy(self, category, stacked=True): @functools.lru_cache() def get_normalization_dataarray(self, category: str) -> xr.Dataset: - """Return the normalization dataarray for the given category. This - should contain a `{category}_mean` and `{category}_std` variable for - each variable in the category. For `category=="state"`, the dataarray - should also contain a `state_diff_mean` and `state_diff_std` variable - for the one-step differences of the state variables. 
The return - dataarray should at least have dimensions of `({category}_feature)`, - but can also include for example `grid_index` (if the normalisation is - done per grid point for example). + """Return the + normalization + dataarray for the + given category. This + should contain a + `{category}_mean` and + `{category}_std` + variable for each + variable in the + category. For + `category=="state"`, + the dataarray should + also contain a + `state_diff_mean` and + `state_diff_std` + variable for the one- + step differences of + the state variables. + The return dataarray + should at least have + dimensions of `({categ + ory}_feature)`, but + can also include for + example `grid_index` + (if the normalisation + is done per grid point + for example). Parameters ---------- @@ -386,6 +417,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset: The normalization dataarray for the given category, with variables for the mean and standard deviation of the variables (and differences for state variables). + """ # XXX: the multizarr code didn't include routines for computing the # normalization of "static" features previously, we'll just hack @@ -423,6 +455,7 @@ def _load_and_merge_stats(self): Returns: xr.Dataset: The merged normalization statistics for the dataset. + """ combined_stats = None for i, zarr_config in enumerate( @@ -449,6 +482,7 @@ def _rename_data_vars(self, combined_stats): Returns: xr.Dataset: The combined normalization statistics with renamed data variables. + """ vars_mapping = {} for zarr_config in self._config["utilities"]["normalization"]["zarrs"]: @@ -471,6 +505,7 @@ def _select_stats_by_category(self, combined_stats, category): Returns: xr.Dataset: The normalization statistics for the dataset. + """ if category == "state": stats = combined_stats.loc[ @@ -479,9 +514,7 @@ def _select_stats_by_category(self, combined_stats, category): stats = stats.drop_vars(["forcing_mean", "forcing_std"]) return stats elif category == "forcing": - non_normalized_vars = ( - self.utilities.normalization.non_normalized_vars - ) + non_normalized_vars = self.utilities.normalization.non_normalized_vars if non_normalized_vars is None: non_normalized_vars = [] forcing_vars = self.vars_names(category) @@ -517,6 +550,7 @@ def _extract_vars(self, category, ds=None): Returns: xr.Dataset: The xarray Dataset object with extracted variables. + """ if ds is None: ds = self.open_zarrs(category) @@ -529,9 +563,7 @@ def _extract_vars(self, category, ds=None): ds_atmosphere = None if atmoshere_vars is not None: - ds_atmosphere = self._extract_atmosphere_vars( - category=category, ds=ds - ) + ds_atmosphere = self._extract_atmosphere_vars(category=category, ds=ds) if ds_surface and ds_atmosphere: return xr.merge([ds_surface, ds_atmosphere]) @@ -551,15 +583,11 @@ def _extract_atmosphere_vars(self, category, ds): Returns: xr.Dataset: The xarray Dataset object with atmosphere variables. + """ - if ( - "level" not in list(ds.dims) - and self._config[category]["atmosphere_vars"] - ): - ds = self._rename_dataset_dims_and_vars( - ds.attrs["category"], dataset=ds - ) + if "level" not in list(ds.dims) and self._config[category]["atmosphere_vars"]: + ds = self._rename_dataset_dims_and_vars(ds.attrs["category"], dataset=ds) data_arrays = [ ds[var].sel(level=level, drop=True).rename(f"{var}_{level}") @@ -585,6 +613,7 @@ def _rename_dataset_dims_and_vars(self, category, dataset=None): variables. OR xr.DataArray: The xarray DataArray object with renamed dimensions and variables. 
+ """ convert = False if dataset is None: @@ -620,6 +649,7 @@ def _apply_time_split(self, dataset, split="train"): Returns:["window"] xr.Dataset: The xarray Dataset object filtered by the time split. + """ start, end = ( self._config["splits"][split]["start"], @@ -635,6 +665,7 @@ def grid_shape_state(self): Returns: CartesianGridShape: The shape of the state grid. + """ return CartesianGridShape( x=self._config["grid_shape_state"]["x"], @@ -643,13 +674,13 @@ def grid_shape_state(self): @property def boundary_mask(self) -> xr.DataArray: - """Load the boundary mask for the dataset, with spatial dimensions - stacked. + """Load the boundary mask for the dataset, with spatial dimensions stacked. Returns ------- xr.DataArray The boundary mask for the dataset, with dimensions `('grid_index',)`. + """ boundary_mask_path = self._normalize_path( self._config["boundary"]["mask"]["path"] @@ -670,6 +701,7 @@ def get_dataarray(self, category, split="train"): Returns: xr.DataArray: The xarray DataArray object with processed dataset. + """ dataset = self.open_zarrs(category) dataset = self._extract_vars(category, dataset) diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py index afb08c77..5cdb22ea 100644 --- a/neural_lam/datastore/npyfiles/config.py +++ b/neural_lam/datastore/npyfiles/config.py @@ -8,14 +8,14 @@ @dataclass class Projection: - """Represents the projection information for a dataset, including the type - of projection and its parameters. Capable of creating a cartopy.crs - projection object. + """Represents the projection information for a dataset, including the type of + projection and its parameters. Capable of creating a cartopy.crs projection object. Attributes: class_name: The class name of the projection, this should be a valid cartopy.crs class. kwargs: A dictionary of keyword arguments specific to the projection type. + """ class_name: str @@ -24,8 +24,8 @@ class Projection: @dataclass class Dataset: - """Contains information about the dataset, including variable names, units, - and descriptions. + """Contains information about the dataset, including variable names, units, and + descriptions. Attributes: name: The name of the dataset. @@ -33,6 +33,7 @@ class Dataset: var_units: A list of units for each variable. var_longnames: A list of long, descriptive names for each variable. num_forcing_features: The number of forcing features in the dataset. + """ name: str @@ -44,13 +45,14 @@ class Dataset: @dataclass class NpyDatastoreConfig(dataclass_wizard.YAMLWizard): - """Configuration for loading and processing a dataset, including dataset - details, grid shape, and projection information. + """Configuration for loading and processing a dataset, including dataset details, + grid shape, and projection information. Attributes: dataset: An instance of Dataset containing details about the dataset. grid_shape_state: A list representing the shape of the grid state. projection: An instance of Projection containing projection details. 
+ """ dataset: Dataset diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py index 674c368d..630a8dd0 100644 --- a/neural_lam/datastore/npyfiles/store.py +++ b/neural_lam/datastore/npyfiles/store.py @@ -1,5 +1,5 @@ -"""Numpy-files based datastore to support the MEPS example dataset introduced -in neural-lam v0.1.0.""" +"""Numpy-files based datastore to support the MEPS example dataset introduced in neural- +lam v0.1.0.""" # Standard library import functools import re @@ -138,9 +138,17 @@ def __init__( self, config_path, ): - """Create a new NpyFilesDatastore using the configuration file at the - given path. The config file should be a YAML file and will be loaded - into an instance of the `NpyDatastoreConfig` dataclass. + """Create a new + NpyFilesDatastore + using the + configuration file at + the given path. The + config file should be + a YAML file and will + be loaded into an + instance of the + `NpyDatastoreConfig` + dataclass. Internally, the datastore uses dask.delayed to load the data from the numpy files, so that the data isn't actually loaded until it's needed. @@ -149,6 +157,7 @@ def __init__( ---------- config_path : str The path to the configuration file for the datastore. + """ # XXX: This should really be in the config file, not hard-coded in this class self._num_timesteps = 65 @@ -161,21 +170,32 @@ def __init__( @property def root_path(self) -> Path: - """The root path of the datastore on disk. This is the directory - relative to which graphs and other files can be stored. + """The root path of the datastore on disk. This is the directory relative to + which graphs and other files can be stored. Returns ------- Path The root path of the datastore + """ return self._root_path def get_dataarray(self, category: str, split: str) -> DataArray: - """Get the data array for the given category and split of data. If the - category is 'state', the data array will be a concatenation of the data - arrays for all ensemble members. The data will be loaded as a dask - array, so that the data isn't actually loaded until it's needed. + """Get the data array + for the given category + and split of data. If + the category is + 'state', the data + array will be a + concatenation of the + data arrays for all + ensemble members. The + data will be loaded as + a dask array, so that + the data isn't + actually loaded until + it's needed. Parameters ---------- @@ -193,6 +213,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: ensemble_member]` forcing: `[elapsed_forecast_duration, analysis_time, grid_index, feature]` static: `[grid_index, feature]` + """ if category == "state": das = [] @@ -211,9 +232,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "column_water"] das = [ - self._get_single_timeseries_dataarray( - features=[feature], split=split - ) + self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features ] da = xr.concat(das, dim="feature") @@ -225,9 +244,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # .chunk({"elapsed_forecast_duration": 1}) this time variable is turned # into a dask array and so execution of the calculation is delayed # until the feature values are actually used. 
- da_forecast_time = ( - da.analysis_time + da.elapsed_forecast_duration - ).chunk({"elapsed_forecast_duration": 1}) + da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( + {"elapsed_forecast_duration": 1} + ) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -248,9 +267,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: features=features, split=split ) das.append(da) - da = xr.concat(das, dim="feature").transpose( - "grid_index", "feature" - ) + da = xr.concat(das, dim="feature").transpose("grid_index", "feature") else: raise NotImplementedError(category) @@ -270,11 +287,21 @@ def get_dataarray(self, category: str, split: str) -> DataArray: def _get_single_timeseries_dataarray( self, features: List[str], split: str, member: int = None ) -> DataArray: - """Get the data array spanning the complete time series for a given set - of features and split of data. For state features the `member` argument - should be specified to select the ensemble member to load. The data - will be loaded using dask.delayed, so that the data isn't actually - loaded until it's needed. + """Get the data array + spanning the complete + time series for a + given set of features + and split of data. For + state features the + `member` argument + should be specified to + select the ensemble + member to load. The + data will be loaded + using dask.delayed, so + that the data isn't + actually loaded until + it's needed. Parameters ---------- @@ -296,15 +323,12 @@ def _get_single_timeseries_dataarray( The data array for the given category and split, with dimensions `[elapsed_forecast_duration, analysis_time, grid_index, feature]` for all categories of data + """ assert split in ("train", "val", "test"), "Unknown dataset split" - if member is not None and features != self.get_vars_names( - category="state" - ): - raise ValueError( - "Member can only be specified for the 'state' category" - ) + if member is not None and features != self.get_vars_names(category="state"): + raise ValueError("Member can only be specified for the 'state' category") # XXX: we here assume that the grid shape is the same for all categories grid_shape = self.grid_shape_state @@ -387,9 +411,7 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format( - analysis_time=analysis_time, **file_params - ) + / filename_format.format(analysis_time=analysis_time, **file_params) for analysis_time in coords["analysis_time"] ] else: @@ -425,8 +447,8 @@ def _get_single_timeseries_dataarray( return da def _get_analysis_times(self, split) -> List[np.datetime64]: - """Get the analysis times for the given split by parsing the filenames - of all the files found for the given split. + """Get the analysis times for the given split by parsing the filenames of all + the files found for the given split. Parameters ---------- @@ -437,6 +459,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ------- List[dt.datetime] The analysis times for the given split. 
+ """ pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) @@ -449,9 +472,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError( - f"No files found in {sample_dir} with pattern {pattern}" - ) + raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") return times @@ -534,6 +555,7 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray: value of `stacked`: - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y. - `stacked==False`: shape `(2, N_y, N_x)` + """ # the array on disk has shape [2, N_x, N_y], but we want to return it @@ -557,6 +579,7 @@ def step_length(self) -> int: ------- int The length of each time step in hours. + """ return self._step_length @@ -568,19 +591,21 @@ def grid_shape_state(self) -> CartesianGridShape: ------- CartesianGridShape The shape of the cartesian grid for the state variables. + """ nx, ny = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) @property def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. + """The boundary mask for the dataset. This is a binary mask that is 1 where the + grid cell is on the boundary of the domain, and 0 otherwise. Returns ------- xr.DataArray The boundary mask for the dataset, with dimensions `[grid_index]`. + """ xs, ys = self.get_xy(category="state", stacked=False) assert np.all(xs[:, 0] == xs[:, -1]) @@ -595,11 +620,24 @@ def boundary_mask(self) -> xr.DataArray: return da_mask_stacked_xy def get_normalization_dataarray(self, category: str) -> xr.Dataset: - """Return the normalization dataarray for the given category. This - should contain a `{category}_mean` and `{category}_std` variable for - each variable in the category. For `category=="state"`, the dataarray - should also contain a `state_diff_mean` and `state_diff_std` variable - for the one-step differences of the state variables. + """Return the + normalization + dataarray for the + given category. This + should contain a + `{category}_mean` and + `{category}_std` + variable for each + variable in the + category. For + `category=="state"`, + the dataarray should + also contain a + `state_diff_mean` and + `state_diff_std` + variable for the one- + step differences of + the state variables. Parameters ---------- @@ -612,6 +650,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset: The normalization dataarray for the given category, with variables for the mean and standard deviation of the variables (and differences for state variables). + """ def load_pickled_tensor(fn): @@ -666,6 +705,7 @@ def coords_projection(self) -> ccrs.Projection: ------- ccrs.Projection The projection of the spatial coordinates. + """ proj_class_name = self.config.projection.class_name ProjectionClass = getattr(ccrs, proj_class_name) diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 4ed3e3eb..5ad0fdca 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -11,6 +11,7 @@ class InteractionNet(pyg.nn.MessagePassing): """Implementation of a generic Interaction Network, from Battaglia et al. 
(2016) + """ # pylint: disable=arguments-differ @@ -43,6 +44,7 @@ def __init__( representation into and use separate MLPs for (None = no chunking, same MLP) aggr: Message aggregation method (sum/mean) + """ assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}" super().__init__(aggr=aggr) @@ -55,9 +57,7 @@ def __init__( edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 - edge_index[0] = ( - edge_index[0] + self.num_rec - ) # Make sender indices after rec + edge_index[0] = edge_index[0] + self.num_rec # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) # Create MLPs @@ -83,8 +83,8 @@ def __init__( self.update_edges = update_edges def forward(self, send_rep, rec_rep, edge_rep): - """Apply interaction network to update the representations of receiver - nodes, and optionally the edge representations. + """Apply interaction network to update the representations of receiver nodes, + and optionally the edge representations. send_rep: (N_send, d_h), vector representations of sender nodes rec_rep: (N_rec, d_h), vector representations of receiver nodes @@ -94,6 +94,7 @@ def forward(self, send_rep, rec_rep, edge_rep): rec_rep: (N_rec, d_h), updated vector representations of receiver nodes (optionally) edge_rep: (M, d_h), updated vector representations of edges + """ # Always concatenate to [rec_nodes, send_nodes] for propagation, # but only aggregate to rec_nodes @@ -130,8 +131,11 @@ def aggregate(self, inputs, index, ptr, dim_size): class SplitMLPs(nn.Module): """Module that feeds chunks of input through different MLPs. - Split up input along dim -2 using given chunk sizes and feeds each - chunk through separate MLPs. + Split up input along dim + -2 using given chunk sizes + and feeds each chunk + through separate MLPs. + """ def __init__(self, mlps, chunk_sizes): @@ -150,6 +154,7 @@ def forward(self, x): Returns: joined_output: (..., N, d), concatenated results from the MLPs + """ chunks = torch.split(x, self.chunk_sizes, dim=-2) chunk_outputs = [ diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py index 1ed4fb08..324440a8 100644 --- a/neural_lam/metrics.py +++ b/neural_lam/metrics.py @@ -9,11 +9,10 @@ def get_metric(metric_name): Returns: metric: function implementing the metric + """ metric_name_lower = metric_name.lower() - assert ( - metric_name_lower in DEFINED_METRICS - ), f"Unknown metric: {metric_name}" + assert metric_name_lower in DEFINED_METRICS, f"Unknown metric: {metric_name}" return DEFINED_METRICS[metric_name_lower] @@ -31,22 +30,17 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars): Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. 
+ """ # Only keep grid nodes in mask if mask is not None: - metric_entry_vals = metric_entry_vals[ - ..., mask, : - ] # (..., N', d_state) + metric_entry_vals = metric_entry_vals[..., mask, :] # (..., N', d_state) # Optionally reduce last two dimensions if average_grid: # Reduce grid first - metric_entry_vals = torch.mean( - metric_entry_vals, dim=-2 - ) # (..., d_state) + metric_entry_vals = torch.mean(metric_entry_vals, dim=-2) # (..., d_state) if sum_vars: # Reduce vars second - metric_entry_vals = torch.sum( - metric_entry_vals, dim=-1 - ) # (..., N) or (...,) + metric_entry_vals = torch.sum(metric_entry_vals, dim=-1) # (..., N) or (...,) return metric_entry_vals @@ -67,6 +61,7 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. + """ entry_mse = torch.nn.functional.mse_loss( pred, target, reduction="none" @@ -97,11 +92,10 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. + """ # Replace pred_std with constant ones - return wmse( - pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars - ) + return wmse(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars) def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): @@ -120,6 +114,7 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. + """ entry_mae = torch.nn.functional.l1_loss( pred, target, reduction="none" @@ -150,11 +145,10 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. + """ # Replace pred_std with constant ones - return wmae( - pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars - ) + return wmae(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars) def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): @@ -173,6 +167,7 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. + """ # Broadcast pred_std if shaped (d_state,), done internally in Normal class dist = torch.distributions.Normal(pred, pred_std) # (..., N, d_state) @@ -183,11 +178,9 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) -def crps_gauss( - pred, target, pred_std, mask=None, average_grid=True, sum_vars=True -): - """(Negative) Continuous Ranked Probability Score (CRPS) Closed-form - expression based on Gaussian predictive distribution. +def crps_gauss(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + """(Negative) Continuous Ranked Probability Score (CRPS) Closed-form expression + based on Gaussian predictive distribution. (...,) is any number of batch dimensions, potentially different but broadcastable @@ -202,6 +195,7 @@ def crps_gauss( Returns: metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending on reduction arguments. 
+ """ std_normal = torch.distributions.Normal( torch.zeros((), device=pred.device), torch.ones((), device=pred.device) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index cea723b0..eadd9445 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -17,14 +17,13 @@ class ARModel(pl.LightningModule): """Generic auto-regressive weather model. Abstract class that can be extended. + """ # pylint: disable=arguments-differ # Disable to override args/kwargs from superclass - def __init__( - self, args, datastore: BaseDatastore, forcing_window_size: int - ): + def __init__(self, args, datastore: BaseDatastore, forcing_window_size: int): super().__init__() self.save_hyperparameters(ignore=["datastore"]) self.args = args @@ -33,17 +32,13 @@ def __init__( split = "train" num_state_vars = datastore.get_num_data_vars(category="state") num_forcing_vars = datastore.get_num_data_vars(category="forcing") - da_static_features = datastore.get_dataarray( - category="static", split=split - ) + da_static_features = datastore.get_dataarray(category="static", split=split) da_state_stats = datastore.get_normalization_dataarray(category="state") da_boundary_mask = datastore.boundary_mask # Load static features for grid/data, NB: self.predict_step assumes dimension # order to be (grid_index, static_feature) - arr_static = da_static_features.transpose( - "grid_index", "static_feature" - ).values + arr_static = da_static_features.transpose("grid_index", "static_feature").values self.register_buffer( "grid_static_features", torch.tensor(arr_static, dtype=torch.float32), @@ -136,9 +131,7 @@ def __init__( self.spatial_loss_maps = [] def configure_optimizers(self): - opt = torch.optim.AdamW( - self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) - ) + opt = torch.optim.AdamW(self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)) return opt @property @@ -185,8 +178,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): # Overwrite border with true state new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state + self.boundary_mask * border_state + self.interior_mask * pred_state ) prediction_list.append(new_state) @@ -231,9 +223,7 @@ def training_step(self, batch): # Compute loss batch_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ) + self.loss(prediction, target, pred_std, mask=self.interior_mask_bool) ) # mean over unrolled times and batch log_dict = {"train_loss": batch_loss} @@ -248,12 +238,13 @@ def training_step(self, batch): return batch_loss def all_gather_cat(self, tensor_to_gather): - """Gather tensors across all ranks, and concatenate across dim. 0 - (instead of stacking in new dim. 0) + """Gather tensors across all ranks, and concatenate across dim. 0 (instead of + stacking in new dim. 0) tensor_to_gather: (d1, d2, ...), distributed over K ranks returns: (K*d1, d2, ...) 
+ """ return self.all_gather(tensor_to_gather).flatten(0, 1) @@ -264,9 +255,7 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std, mask=self.interior_mask_bool), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -314,9 +303,7 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std, mask=self.interior_mask_bool), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) @@ -368,19 +355,14 @@ def test_step(self, batch, batch_idx): # (B, N_log, num_grid_nodes) # Plot example predictions (on rank 0 only) - if ( - self.trainer.is_global_zero - and self.plotted_examples < self.n_example_pred - ): + if self.trainer.is_global_zero and self.plotted_examples < self.n_example_pred: # Need to plot more example predictions n_additional_examples = min( prediction.shape[0], self.n_example_pred - self.plotted_examples, ) - self.plot_examples( - batch, n_additional_examples, prediction=prediction - ) + self.plot_examples(batch, n_additional_examples, prediction=prediction) def plot_examples(self, batch, n_examples, prediction=None): """Plot the first n_examples forecasts from batch. @@ -389,6 +371,7 @@ def plot_examples(self, batch, n_examples, prediction=None): number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. Generate if None. + """ if prediction is None: prediction, target, _, _ = self.common_step(batch) @@ -457,16 +440,12 @@ def plot_examples(self, batch, n_examples, prediction=None): ) } ) - plt.close( - "all" - ) # Close all figs for this time step, saves memory + plt.close("all") # Close all figs for this time step, saves memory # Save pred and target as .pt files torch.save( pred_slice.cpu(), - os.path.join( - wandb.run.dir, f"example_pred_{self.plotted_examples}.pt" - ), + os.path.join(wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"), ) torch.save( target_slice.cpu(), @@ -476,14 +455,15 @@ def plot_examples(self, batch, n_examples, prediction=None): ) def create_metric_log_dict(self, metric_tensor, prefix, metric_name): - """Put together a dict with everything to log for one metric. Also - saves plots as pdf and csv if using test prefix. + """Put together a dict with everything to log for one metric. Also saves plots + as pdf and csv if using test prefix. metric_tensor: (pred_steps, d_f), metric values per time and variable prefix: string, prefix to use for logging metric_name: string, name of the metric Return: log_dict: dict with everything to log for given metric + """ log_dict = {} metric_fig = vis.plot_error_map( @@ -496,9 +476,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): if prefix == "test": # Save pdf - metric_fig.savefig( - os.path.join(wandb.run.dir, f"{full_log_name}.pdf") - ) + metric_fig.savefig(os.path.join(wandb.run.dir, f"{full_log_name}.pdf")) # Save errors also as csv np.savetxt( os.path.join(wandb.run.dir, f"{full_log_name}.csv"), @@ -522,12 +500,12 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): return log_dict def aggregate_and_plot_metrics(self, metrics_dict, prefix): - """Aggregate and create error map plots for all metrics in - metrics_dict. 
+ """Aggregate and create error map plots for all metrics in metrics_dict. metrics_dict: dictionary with metric_names and list of tensors with step-evals. prefix: string, prefix to use for logging + """ log_dict = {} for metric_name, metric_val_list in metrics_dict.items(): @@ -548,9 +526,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): metric_rescaled = metric_tensor_averaged * self.state_std # (pred_steps, d_f) log_dict.update( - self.create_metric_log_dict( - metric_rescaled, prefix, metric_name - ) + self.create_metric_log_dict(metric_rescaled, prefix, metric_name) ) if self.trainer.is_global_zero and not self.trainer.sanity_checking: @@ -560,8 +536,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): def on_test_epoch_end(self): """Compute test metrics and make plots at the end of test epoch. - Will gather stored tensors and perform plotting and logging on - rank 0. + Will gather stored tensors and perform plotting and logging on rank 0. + """ # Create error maps for all test metrics self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") @@ -582,9 +558,7 @@ def on_test_epoch_end(self): self.data_config, title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", ) - for t_i, loss_map in zip( - self.args.val_steps_to_log, mean_spatial_loss - ) + for t_i, loss_map in zip(self.args.val_steps_to_log, mean_spatial_loss) ] # log all to same wandb key, sequentially @@ -624,9 +598,7 @@ def on_load_checkpoint(self, checkpoint): ) ) for old_key in replace_keys: - new_key = old_key.replace( - "g2m_gnn.grid_mlp", "encoding_grid_mlp" - ) + new_key = old_key.replace("g2m_gnn.grid_mlp", "encoding_grid_mlp") loaded_state_dict[new_key] = loaded_state_dict[old_key] del loaded_state_dict[old_key] if not self.restore_opt: diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index a76fc518..158275dd 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -8,8 +8,8 @@ class BaseGraphModel(ARModel): - """Base (abstract) class for graph-based models building on the encode- - process- decode idea.""" + """Base (abstract) class for graph-based models building on the encode- process- + decode idea.""" def __init__(self, args, datastore, forcing_window_size): super().__init__( @@ -20,9 +20,7 @@ def __init__(self, args, datastore, forcing_window_size): # NOTE: (IMPORTANT!) 
mesh nodes MUST have the first # num_mesh_nodes indices, graph_dir_path = datastore.root_path / "graph" / args.graph - self.hierarchical, graph_ldict = utils.load_graph( - graph_dir_path=graph_dir_path - ) + self.hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) for name, attr_value in graph_ldict.items(): # Make BufferLists module members and register tensors as buffers if isinstance(attr_value, torch.Tensor): @@ -44,9 +42,7 @@ def __init__(self, args, datastore, forcing_window_size): # Define sub-models # Feature embedders for grid self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) - self.grid_embedder = utils.make_mlp( - [self.grid_dim] + self.mlp_blueprint_end - ) + self.grid_embedder = utils.make_mlp([self.grid_dim] + self.mlp_blueprint_end) self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -72,27 +68,26 @@ def __init__(self, args, datastore, forcing_window_size): # Output mapping (hidden_dim -> output_dim) self.output_map = utils.make_mlp( - [args.hidden_dim] * (args.hidden_layers + 1) - + [self.grid_output_dim], + [args.hidden_dim] * (args.hidden_layers + 1) + [self.grid_output_dim], layer_norm=False, ) # No layer norm on this one def get_num_mesh(self): - """Compute number of mesh nodes from loaded features, and number of - mesh nodes that should be ignored in encoding/decoding.""" + """Compute number of mesh nodes from loaded features, and number of mesh nodes + that should be ignored in encoding/decoding.""" raise NotImplementedError("get_num_mesh not implemented") def embedd_mesh_nodes(self): - """Embed static mesh features Returns tensor of shape (num_mesh_nodes, - d_h)""" + """Embed static mesh features Returns tensor of shape (num_mesh_nodes, d_h)""" raise NotImplementedError("embedd_mesh_nodes not implemented") def process_step(self, mesh_rep): - """Process step of embedd-process-decode framework Processes the - representation on the mesh, possible in multiple steps. + """Process step of embedd-process-decode framework Processes the representation + on the mesh, possible in multiple steps. 
mesh_rep: has shape (B, num_mesh_nodes, d_h) Returns mesh_rep: (B, num_mesh_nodes, d_h) + """ raise NotImplementedError("process_step not implemented") @@ -147,9 +142,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing): ) # (B, num_grid_nodes, d_h) # Map to output dimension, only for grid - net_output = self.output_map( - grid_rep - ) # (B, num_grid_nodes, d_grid_out) + net_output = self.output_map(grid_rep) # (B, num_grid_nodes, d_grid_out) if self.output_std: pred_delta_mean, pred_std_raw = net_output.chunk( diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index 14827f25..8bfc2c3e 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -96,28 +96,34 @@ def __init__(self, args): ) def get_num_mesh(self): - """Compute number of mesh nodes from loaded features, and number of - mesh nodes that should be ignored in encoding/decoding.""" + """Compute number of mesh nodes from loaded features, and number of mesh nodes + that should be ignored in encoding/decoding.""" num_mesh_nodes = sum( node_feat.shape[0] for node_feat in self.mesh_static_features ) - num_mesh_nodes_ignore = ( - num_mesh_nodes - self.mesh_static_features[0].shape[0] - ) + num_mesh_nodes_ignore = num_mesh_nodes - self.mesh_static_features[0].shape[0] return num_mesh_nodes, num_mesh_nodes_ignore def embedd_mesh_nodes(self): - """Embed static mesh features This embeds only bottom level, rest is - done at beginning of processing step Returns tensor of shape - (num_mesh_nodes[0], d_h)""" + """Embed static mesh + features This embeds + only bottom level, + rest is done at + beginning of + processing step + Returns tensor of + shape + (num_mesh_nodes[0], + d_h)""" return self.mesh_embedders[0](self.mesh_static_features[0]) def process_step(self, mesh_rep): - """Process step of embedd-process-decode framework Processes the - representation on the mesh, possible in multiple steps. + """Process step of embedd-process-decode framework Processes the representation + on the mesh, possible in multiple steps. mesh_rep: has shape (B, num_mesh_nodes, d_h) Returns mesh_rep: (B, num_mesh_nodes, d_h) + """ batch_size = mesh_rep.shape[0] @@ -136,21 +142,15 @@ def process_step(self, mesh_rep): # Embed edges, expand with batch dimension mesh_same_rep = [ self.expand_to_batch(emb(edge_feat), batch_size) - for emb, edge_feat in zip( - self.mesh_same_embedders, self.m2m_features - ) + for emb, edge_feat in zip(self.mesh_same_embedders, self.m2m_features) ] mesh_up_rep = [ self.expand_to_batch(emb(edge_feat), batch_size) - for emb, edge_feat in zip( - self.mesh_up_embedders, self.mesh_up_features - ) + for emb, edge_feat in zip(self.mesh_up_embedders, self.mesh_up_features) ] mesh_down_rep = [ self.expand_to_batch(emb(edge_feat), batch_size) - for emb, edge_feat in zip( - self.mesh_down_embedders, self.mesh_down_features - ) + for emb, edge_feat in zip(self.mesh_down_embedders, self.mesh_down_features) ] # - MESH INIT. 
- @@ -160,20 +160,14 @@ def process_step(self, mesh_rep): send_node_rep = mesh_rep_levels[ level_l - 1 ] # (B, num_mesh_nodes[l-1], d_h) - rec_node_rep = mesh_rep_levels[ - level_l - ] # (B, num_mesh_nodes[l], d_h) + rec_node_rep = mesh_rep_levels[level_l] # (B, num_mesh_nodes[l], d_h) edge_rep = mesh_up_rep[level_l - 1] # Apply GNN - new_node_rep, new_edge_rep = gnn( - send_node_rep, rec_node_rep, edge_rep - ) + new_node_rep, new_edge_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[ - level_l - ] = new_node_rep # (B, num_mesh_nodes[l], d_h) + mesh_rep_levels[level_l] = new_node_rep # (B, num_mesh_nodes[l], d_h) mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h) # - PROCESSOR - @@ -190,18 +184,14 @@ def process_step(self, mesh_rep): send_node_rep = mesh_rep_levels[ level_l + 1 ] # (B, num_mesh_nodes[l+1], d_h) - rec_node_rep = mesh_rep_levels[ - level_l - ] # (B, num_mesh_nodes[l], d_h) + rec_node_rep = mesh_rep_levels[level_l] # (B, num_mesh_nodes[l], d_h) edge_rep = mesh_down_rep[level_l] # Apply GNN new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[ - level_l - ] = new_node_rep # (B, num_mesh_nodes[l], d_h) + mesh_rep_levels[level_l] = new_node_rep # (B, num_mesh_nodes[l], d_h) # Return only bottom level representation return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h) @@ -209,8 +199,8 @@ def process_step(self, mesh_rep): def hi_processor_step( self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep ): - """Internal processor step of hierarchical graph models. Between mesh - init and read out. + """Internal processor step of hierarchical graph models. Between mesh init and + read out. Each input is list with representations, each with shape @@ -220,5 +210,6 @@ def hi_processor_step( mesh_down_rep: (B, M_down[l <- l+1], d_h) Returns same lists + """ raise NotImplementedError("hi_process_step not implemented") diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index 723c4678..55befd02 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -8,20 +8,18 @@ class GraphLAM(BaseGraphModel): - """Full graph-based LAM model that can be used with different (non- - hierarchical )graphs. + """Full graph-based LAM model that can be used with different (non- hierarchical + )graphs. + + Mainly based on GraphCast, but the model from Keisler (2022) is almost identical. + Used for GC-LAM and L1-LAM in Oskarsson et al. (2023). - Mainly based on GraphCast, but the model from Keisler (2022) is - almost identical. Used for GC-LAM and L1-LAM in Oskarsson et al. - (2023). 
""" def __init__(self, args, datastore, forcing_window_size): super().__init__(args, datastore, forcing_window_size) - assert ( - not self.hierarchical - ), "GraphLAM does not use a hierarchical mesh graph" + assert not self.hierarchical, "GraphLAM does not use a hierarchical mesh graph" # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] @@ -56,8 +54,8 @@ def __init__(self, args, datastore, forcing_window_size): ) def get_num_mesh(self): - """Compute number of mesh nodes from loaded features, and number of - mesh nodes that should be ignored in encoding/decoding.""" + """Compute number of mesh nodes from loaded features, and number of mesh nodes + that should be ignored in encoding/decoding.""" return self.mesh_static_features.shape[0], 0 def embedd_mesh_nodes(self): @@ -65,20 +63,17 @@ def embedd_mesh_nodes(self): return self.mesh_embedder(self.mesh_static_features) # (N_mesh, d_h) def process_step(self, mesh_rep): - """Process step of embedd-process-decode framework Processes the - representation on the mesh, possible in multiple steps. + """Process step of embedd-process-decode framework Processes the representation + on the mesh, possible in multiple steps. mesh_rep: has shape (B, N_mesh, d_h) Returns mesh_rep: (B, N_mesh, d_h) + """ # Embed m2m here first batch_size = mesh_rep.shape[0] m2m_emb = self.m2m_embedder(self.m2m_features) # (M_mesh, d_h) - m2m_emb_expanded = self.expand_to_batch( - m2m_emb, batch_size - ) # (B, M_mesh, d_h) + m2m_emb_expanded = self.expand_to_batch(m2m_emb, batch_size) # (B, M_mesh, d_h) - mesh_rep, _ = self.processor( - mesh_rep, m2m_emb_expanded - ) # (B, N_mesh, d_h) + mesh_rep, _ = self.processor(mesh_rep, m2m_emb_expanded) # (B, N_mesh, d_h) return mesh_rep diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index a7d55ba0..95300185 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -7,10 +7,11 @@ class HiLAM(BaseHiGraphModel): - """Hierarchical graph model with message passing that goes sequentially - down and up the hierarchy during processing. + """Hierarchical graph model with message passing that goes sequentially down and up + the hierarchy during processing. The Hi-LAM model from Oskarsson et al. 
(2023) + """ def __init__(self, args): @@ -79,8 +80,8 @@ def mesh_down_step( down_gnns, same_gnns, ): - """Run down-part of vertical processing, sequentially alternating - between processing using down edges and same-level edges.""" + """Run down-part of vertical processing, sequentially alternating between + processing using down edges and same-level edges.""" # Run same level processing on level L mesh_rep_levels[-1], mesh_same_rep[-1] = same_gnns[-1]( mesh_rep_levels[-1], mesh_rep_levels[-1], mesh_same_rep[-1] @@ -93,9 +94,7 @@ def mesh_down_step( reversed(same_gnns[:-1]), ): # Extract representations - send_node_rep = mesh_rep_levels[ - level_l + 1 - ] # (B, N_mesh[l+1], d_h) + send_node_rep = mesh_rep_levels[level_l + 1] # (B, N_mesh[l+1], d_h) rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) down_edge_rep = mesh_down_rep[level_l] same_edge_rep = mesh_same_rep[level_l] @@ -129,9 +128,7 @@ def mesh_up_step( zip(up_gnns, same_gnns[1:]), start=1 ): # Extract representations - send_node_rep = mesh_rep_levels[ - level_l - 1 - ] # (B, N_mesh[l-1], d_h) + send_node_rep = mesh_rep_levels[level_l - 1] # (B, N_mesh[l-1], d_h) rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) up_edge_rep = mesh_up_rep[level_l - 1] same_edge_rep = mesh_same_rep[level_l] @@ -153,8 +150,8 @@ def mesh_up_step( def hi_processor_step( self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep ): - """Internal processor step of hierarchical graph models. Between mesh - init and read out. + """Internal processor step of hierarchical graph models. Between mesh init and + read out. Each input is list with representations, each with shape @@ -164,6 +161,7 @@ def hi_processor_step( mesh_down_rep: (B, M_down[l <- l+1], d_h) Returns same lists + """ for down_gnns, down_same_gnns, up_gnns, up_same_gnns in zip( self.mesh_down_gnns, diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index fe8152b3..26357281 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -8,11 +8,11 @@ class HiLAMParallel(BaseHiGraphModel): - """Version of HiLAM where all message passing in the hierarchical mesh (up, - down, inter-level) is ran in parallel. + """Version of HiLAM where all message passing in the hierarchical mesh (up, down, + inter-level) is ran in parallel. + + This is a somewhat simpler alternative to the sequential message passing of Hi-LAM. - This is a somewhat simpler alternative to the sequential message - passing of Hi-LAM. """ def __init__(self, args): @@ -52,8 +52,8 @@ def __init__(self, args): def hi_processor_step( self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep ): - """Internal processor step of hierarchical graph models. Between mesh - init and read out. + """Internal processor step of hierarchical graph models. Between mesh init and + read out. 
Each input is list with representations, each with shape @@ -63,6 +63,7 @@ def hi_processor_step( mesh_down_rep: (B, M_down[l <- l+1], d_h) Returns same lists + """ # First join all node and edge representations to single tensors @@ -75,9 +76,7 @@ def hi_processor_step( mesh_rep, mesh_edge_rep = self.processor(mesh_rep, mesh_edge_rep) # Split up again for read-out step - mesh_rep_levels = list( - torch.split(mesh_rep, self.level_mesh_sizes, dim=1) - ) + mesh_rep_levels = list(torch.split(mesh_rep, self.level_mesh_sizes, dim=1)) mesh_edge_rep_sections = torch.split( mesh_edge_rep, self.edge_split_sections, dim=1 ) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 4a69f1aa..4f011b76 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -38,9 +38,7 @@ def _init_datastore(datastore_kind, config_path): def main(input_args=None): """Main function for training and evaluating models.""" - parser = ArgumentParser( - description="Train or evaluate NeurWP models for LAM" - ) + parser = ArgumentParser(description="Train or evaluate NeurWP models for LAM") parser.add_argument( "datastore_kind", type=str, @@ -85,8 +83,7 @@ def main(input_args=None): "--restore_opt", type=int, default=0, - help="If optimizer state should be restored with model " - "(default: 0 (false))", + help="If optimizer state should be restored with model " "(default: 0 (false))", ) parser.add_argument( "--precision", @@ -100,8 +97,7 @@ def main(input_args=None): "--graph", type=str, default="multiscale", - help="Graph to load and use in graph-based model " - "(default: multiscale)", + help="Graph to load and use in graph-based model " "(default: multiscale)", ) parser.add_argument( "--hidden_dim", @@ -149,8 +145,7 @@ def main(input_args=None): "--control_only", type=int, default=0, - help="Train only on control member of ensemble data " - "(default: 0 (False))", + help="Train only on control member of ensemble data " "(default: 0 (False))", ) parser.add_argument( "--loss", @@ -165,8 +160,7 @@ def main(input_args=None): "--val_interval", type=int, default=1, - help="Number of epochs training between each validation run " - "(default: 1)", + help="Number of epochs training between each validation run " "(default: 1)", ) # Evaluation options @@ -187,8 +181,7 @@ def main(input_args=None): "--n_example_pred", type=int, default=1, - help="Number of example predictions to plot during evaluation " - "(default: 1)", + help="Number of example predictions to plot during evaluation " "(default: 1)", ) # Logger Settings @@ -261,9 +254,7 @@ def main(input_args=None): # Instantiate model + trainer if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision( - "high" - ) # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s else: device_name = "cpu" diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 79de3193..2ebe7b4d 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -9,11 +9,12 @@ class BufferList(nn.Module): - """A list of torch buffer tensors that sit together as a Module with no - parameters and only buffers. + """A list of torch buffer tensors that sit together as a Module with no parameters + and only buffers. This should be replaced by a native torch BufferList once implemented. 
See: https://github.com/pytorch/pytorch/issues/37386 + """ def __init__(self, buffer_tensors, persistent=True): @@ -74,6 +75,7 @@ def load_graph(graph_dir_path, device="cpu"): - mesh_up_features - mesh_down_features - mesh_static_features + """ def loads_file(fn): @@ -112,9 +114,7 @@ def loads_file(fn): ) # List of (N_mesh[l], d_mesh_static) # Some checks for consistency - assert ( - len(m2m_features) == n_levels - ), "Inconsistent number of levels in mesh" + assert len(m2m_features) == n_levels, "Inconsistent number of levels in mesh" assert ( len(mesh_static_features) == n_levels ), "Inconsistent number of levels in mesh" @@ -137,23 +137,15 @@ def loads_file(fn): # Rescale mesh_up_features = BufferList( - [ - edge_features / longest_edge - for edge_features in mesh_up_features - ], + [edge_features / longest_edge for edge_features in mesh_up_features], persistent=False, ) mesh_down_features = BufferList( - [ - edge_features / longest_edge - for edge_features in mesh_down_features - ], + [edge_features / longest_edge for edge_features in mesh_down_features], persistent=False, ) - mesh_static_features = BufferList( - mesh_static_features, persistent=False - ) + mesh_static_features = BufferList(mesh_static_features, persistent=False) else: # Extract single mesh level m2m_edge_index = m2m_edge_index[0] diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 98e066c4..e5c970c4 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -53,9 +53,7 @@ def plot_error_map( ax.set_yticks(np.arange(d_f)) var_names = datastore.get_vars_names(category="state") var_units = datastore.get_vars_units(category="state") - y_ticklabels = [ - f"{name} ({unit})" for name, unit in zip(var_names, var_units) - ] + y_ticklabels = [f"{name} ({unit})" for name, unit in zip(var_names, var_units)] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) if title: @@ -76,6 +74,7 @@ def plot_prediction( """Plot example prediction and grond truth. 
Each has shape (N_grid,) + """ # Get common scale for values if vrange is None: @@ -89,9 +88,7 @@ def plot_prediction( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region fig, axes = plt.subplots( 1, @@ -104,9 +101,7 @@ def plot_prediction( for ax, data in zip(axes, (target, pred)): ax.coastlines() # Add coastline outlines data_grid = ( - data.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() - .numpy() + data.reshape(list(datastore.grid_shape_state.values.values())).cpu().numpy() ) im = ax.imshow( data_grid, @@ -143,12 +138,8 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): extent = data_config.get_xy_extent("state") # Set up masking of border region - mask_reshaped = obs_mask.reshape( - list(data_config.grid_shape_state.values.values()) - ) - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + mask_reshaped = obs_mask.reshape(list(data_config.grid_shape_state.values.values())) + pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region fig, ax = plt.subplots( figsize=(5, 4.8), @@ -157,9 +148,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): ax.coastlines() # Add coastline outlines error_grid = ( - error.reshape(list(data_config.grid_shape_state.values.values())) - .cpu() - .numpy() + error.reshape(list(data_config.grid_shape_state.values.values())).cpu().numpy() ) im = ax.imshow( diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5ba1d326..a8213922 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -15,6 +15,7 @@ class WeatherDataset(torch.utils.data.Dataset): """Dataset class for weather data. This class loads and processes weather data from a given datastore. + """ def __init__( @@ -31,9 +32,7 @@ def __init__( self.ar_steps = ar_steps self.datastore = datastore - self.da_state = self.datastore.get_dataarray( - category="state", split=self.split - ) + self.da_state = self.datastore.get_dataarray(category="state", split=self.split) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) @@ -61,10 +60,8 @@ def __init__( self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = ( - self.datastore.get_normalization_dataarray( - category="forcing" - ) + self.ds_forcing_stats = self.datastore.get_normalization_dataarray( + category="forcing" ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -91,11 +88,23 @@ def __len__(self): return len(self.da_state.time) - self.ar_steps - 1 def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0): - """Produce a time slice of the given dataarray `da` (state or forcing) - starting at `idx` and with `n_steps` steps. The `n_timesteps_offset` - parameter is used to offset the start of the sample, for example to - exclude the first two steps when sampling the forcing data (and to - produce the windowing samples of forcing data by increasing the offset + """Produce a time + slice of the given + dataarray `da` (state + or forcing) starting + at `idx` and with + `n_steps` steps. 
The + `n_timesteps_offset` + parameter is used to + offset the start of + the sample, for + example to exclude the + first two steps when + sampling the forcing + data (and to produce + the windowing samples + of forcing data by + increasing the offset for each window). Parameters @@ -109,6 +118,7 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0): The index of the time step to start the sample from. n_steps : int The number of time steps to include in the sample. + """ # selecting the time slice if self.datastore.is_forecast: @@ -129,15 +139,13 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0): else: # only `time` dimension for analysis only data da = da.isel( - time=slice( - idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset - ) + time=slice(idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset) ) return da def __getitem__(self, idx): - """Return a single training sample, which consists of the initial - states, target states, forcing and batch times. + """Return a single training sample, which consists of the initial states, target + states, forcing and batch times. The implementation currently uses xarray.DataArray objects for the normalisation so that we can make us of xarray's broadcasting @@ -158,6 +166,7 @@ def __getitem__(self, idx): A training sample object containing the initial states, target states, forcing and batch times. The batch times are the times of the target steps. + """ # handling ensemble data if self.datastore.is_ensemble: @@ -182,9 +191,7 @@ def __getitem__(self, idx): # handle time sampling in a way that is compatible with both analysis # and forecast data - da_state = self._sample_time( - da=da_state, idx=idx, n_steps=2 + self.ar_steps - ) + da_state = self._sample_time(da=da_state, idx=idx, n_steps=2 + self.ar_steps) if da_forcing is not None: das_forcing = [] @@ -219,9 +226,7 @@ def __getitem__(self, idx): batch_times = da_target_states.time.values.astype(float) if self.standardize: - da_init_states = ( - da_init_states - self.da_state_mean - ) / self.da_state_std + da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -239,9 +244,7 @@ def __getitem__(self, idx): ) init_states = torch.tensor(da_init_states.values, dtype=torch.float32) - target_states = torch.tensor( - da_target_states.values, dtype=torch.float32 - ) + target_states = torch.tensor(da_target_states.values, dtype=torch.float32) if self.da_forcing is None: # create an empty forcing tensor @@ -250,9 +253,7 @@ def __getitem__(self, idx): dtype=torch.float32, ) else: - forcing = torch.tensor( - da_forcing_windowed.values, dtype=torch.float32 - ) + forcing = torch.tensor(da_forcing_windowed.values, dtype=torch.float32) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) @@ -264,8 +265,9 @@ def __getitem__(self, idx): def __iter__(self): """Convenience method to iterate over the dataset. - This isn't used by pytorch DataLoader which itself implements an - iterator that uses Dataset.__getitem__ and Dataset.__len__. + This isn't used by pytorch DataLoader which itself implements an iterator that + uses Dataset.__getitem__ and Dataset.__len__. 
+ """ for i in range(len(self)): yield self[i] diff --git a/plot_graph.py b/plot_graph.py index b7b710bf..e84bb627 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -64,9 +64,7 @@ def main(): # Add in z-dimension z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) - grid_pos = np.concatenate( - (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 - ) + grid_pos = np.concatenate((grid_pos, np.expand_dims(z_grid, axis=1)), axis=1) # List of edges to plot, (edge_index, color, line_width, label) edge_plot_list = [ @@ -118,9 +116,7 @@ def main(): z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees mesh_node_size = mesh_degrees / 2 - mesh_pos = np.concatenate( - (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 - ) + mesh_pos = np.concatenate((mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1) edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) diff --git a/pyproject.toml b/pyproject.toml index e661ff46..1c86119c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,6 @@ dev = [ [tool.setuptools] py-modules = ["neural_lam"] -[tool.black] -line-length = 80 - [tool.isort] default_section = "THIRDPARTY" profile = "black" @@ -70,7 +67,7 @@ known_first_party = [ ] [tool.flake8] -max-line-length = 80 +max-line-length = 88 ignore = [ "E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373) "I002", # Don't check for isort configuration @@ -114,3 +111,9 @@ min-similarity-lines=10 [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" + + +[tool.docformatter] +recursive = true +blank = true +black = true diff --git a/tests/conftest.py b/tests/conftest.py index 1f4edd1a..c8afc109 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,14 +66,15 @@ def download_meps_example_reduced_dataset(): def bootstrap_multizarr_example(): - """Run the steps that are needed to prepare the input data for the - multizarr datastore example. This includes: + """Run the steps that are needed to prepare the input data for the multizarr + datastore example. 
This includes: - Downloading the two zarr datasets (since training directly from S3 is error-prone as the connection often breaks) - Creating the datetime forcings zarr - Creating the normalization stats zarr - Creating the boundary mask zarr + """ multizarr_path = DATASTORE_EXAMPLES_ROOT_PATH / "multizarr" n_boundary_cells = 10 @@ -104,8 +105,7 @@ def bootstrap_multizarr_example(): # here assume that the data-config is referring the the default path # for the "datetime forcings" dataset datetime_forcing_zarr_path = ( - data_config_path.parent - / multizarr.create_datetime_forcings.DEFAULT_FILENAME + data_config_path.parent / multizarr.create_datetime_forcings.DEFAULT_FILENAME ) if not datetime_forcing_zarr_path.exists(): multizarr.create_datetime_forcings.create_datetime_forcing_zarr( @@ -113,8 +113,7 @@ def bootstrap_multizarr_example(): ) normalized_forcing_zarr_path = ( - data_config_path.parent - / multizarr.create_normalization_stats.DEFAULT_FILENAME + data_config_path.parent / multizarr.create_normalization_stats.DEFAULT_FILENAME ) if not normalized_forcing_zarr_path.exists(): multizarr.create_normalization_stats.create_normalization_stats_zarr( @@ -122,8 +121,7 @@ def bootstrap_multizarr_example(): ) boundary_mask_path = ( - data_config_path.parent - / multizarr.create_boundary_mask.DEFAULT_FILENAME + data_config_path.parent / multizarr.create_boundary_mask.DEFAULT_FILENAME ) if not boundary_mask_path.exists(): @@ -139,9 +137,7 @@ def bootstrap_multizarr_example(): DATASTORES_EXAMPLES = dict( multizarr=dict(config_path=bootstrap_multizarr_example()), mllam=dict( - config_path=DATASTORE_EXAMPLES_ROOT_PATH - / "mllam" - / "danra.example.yaml" + config_path=DATASTORE_EXAMPLES_ROOT_PATH / "mllam" / "danra.example.yaml" ), npyfiles=dict(config_path=download_meps_example_reduced_dataset()), ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 19ca1ed8..0dbd04a1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,8 +5,8 @@ def test_import(): - """This test just ensures that each cli entry-point can be imported for - now, eventually we should test their execution too.""" + """This test just ensures that each cli entry-point can be imported for now, + eventually we should test their execution too.""" assert neural_lam is not None assert neural_lam.create_graph is not None assert neural_lam.train_model is not None diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7e73f787..8ae9d917 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_dataset_item(datastore_name): - """Check that the `datastore.get_dataarray` method is implemented. + """Check that the `datasto re.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different components of the training sample. 
@@ -23,6 +23,7 @@ def test_dataset_item(datastore_name): init_states: (2, N_grid, d_features) target_states: (ar_steps, N_grid, d_features) forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,) + """ datastore = init_datastore(datastore_name) N_gridpoints = datastore.grid_shape_state.x * datastore.grid_shape_state.y @@ -59,8 +60,7 @@ def test_dataset_item(datastore_name): assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints assert ( - forcing.shape[2] - == datastore.get_num_data_vars("forcing") * forcing_window_size + forcing.shape[2] == datastore.get_num_data_vars("forcing") * forcing_window_size ) # batch times @@ -76,15 +76,14 @@ def test_dataset_item(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_single_batch(datastore_name, split): - """Check that the `datastore.get_dataarray` method is implemented. + """Check that the `datasto re.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. + """ datastore = init_datastore(datastore_name) - device_name = ( # noqa - torch.device("cuda") if torch.cuda.is_available() else "cpu" - ) + device_name = torch.device("cuda") if torch.cuda.is_available() else "cpu" # noqa graph_name = "1level" diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 198d4460..319c5a7c 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -23,6 +23,7 @@ - [x] `get_xy` (method): Return the x, y coordinates of the dataset. - [x] `coords_projection` (property): Projection object for the coordinates. - [x] `grid_shape_state` (property): Shape of the grid for the state variables. + """ # Standard library @@ -57,9 +58,9 @@ def test_step_length(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_datastore_grid_xy(datastore_name): - """Use the `datastore.get_xy` method to get the x, y coordinates of the - dataset and check that the shape is correct against the - `datastore.grid_shape_state` property.""" + """Use the `datastore.get_xy` method to get the x, y coordinates of the dataset and + check that the shape is correct against the `da tastore.grid_shape_state` + property.""" datastore = init_datastore(datastore_name) # check the shapes of the xy grid @@ -87,6 +88,7 @@ def test_get_vars(datastore_name): are consistent (as in the number of variables are the same) and that the return types of each are correct. + """ datastore = init_datastore(datastore_name) @@ -103,7 +105,7 @@ def test_get_vars(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_normalization_dataarray(datastore_name): - """Check that the `datastore.get_normalization_dataarray` method is + """Check that the `datasto re.get_normalization_dataa rray` method is implemented.""" datastore = init_datastore(datastore_name) @@ -132,9 +134,10 @@ def test_get_normalization_dataarray(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_dataarray(datastore_name): - """Check that the `datastore.get_dataarray` method is implemented. + """Check that the `datasto re.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. 
+ """ datastore = init_datastore(datastore_name) @@ -176,8 +179,8 @@ def test_get_dataarray(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" + """Check that the `datastore.boundary_mask` property is implemented and that the + returned object is an xarray DataArray with the correct shape.""" datastore = init_datastore(datastore_name) da_mask = datastore.boundary_mask @@ -195,8 +198,8 @@ def test_boundary_mask(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): - """Check that the `datastore.get_xy_extent` method is implemented and that - the returned object is a tuple of the correct length.""" + """Check that the `datastore.get_xy_extent` method is implemented and that the + returned object is a tuple of the correct length.""" datastore = init_datastore(datastore_name) if not isinstance(datastore, BaseCartesianDatastore): @@ -247,7 +250,7 @@ def test_get_xy(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_projection(datastore_name): - """Check that the `datastore.coords_projection` property is implemented.""" + """Check that the `datasto re.coords_projection` property is implemented.""" datastore = init_datastore(datastore_name) if not isinstance(datastore, BaseCartesianDatastore): @@ -258,7 +261,7 @@ def test_get_projection(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def get_grid_shape_state(datastore_name): - """Check that the `datastore.grid_shape_state` property is implemented.""" + """Check that the `datasto re.grid_shape_state` property is implemented.""" datastore = init_datastore(datastore_name) if not isinstance(datastore, BaseCartesianDatastore): diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 6384c46f..652e3dce 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -14,9 +14,10 @@ @pytest.mark.parametrize("graph_name", ["1level", "multiscale", "hierarchical"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_graph_creation(datastore_name, graph_name): - """Check that the `create_graph_from_datastore` function is implemented. + """Check that the `create_ graph_from_datastore` function is implemented. And that the graph is created in the correct location. 
+ """ datastore = init_datastore(datastore_name) if graph_name == "hierarchical": @@ -80,9 +81,7 @@ def test_graph_creation(datastore_name, graph_name): assert isinstance(result, torch.Tensor) if file_id.endswith("_index"): - assert ( - result.shape[0] == 2 - ) # adjacency matrix uses two rows + assert result.shape[0] == 2 # adjacency matrix uses two rows elif file_id.endswith("_features"): assert result.shape[1] == d_features @@ -91,9 +90,7 @@ def test_graph_creation(datastore_name, graph_name): if not hierarchical: assert len(result) == 1 else: - if file_id.startswith("mesh_up") or file_id.startswith( - "mesh_down" - ): + if file_id.startswith("mesh_up") or file_id.startswith("mesh_down"): assert len(result) == n_max_levels - 1 else: assert len(result) == n_max_levels diff --git a/tests/test_training.py b/tests/test_training.py index ee532656..94b36980 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -20,9 +20,7 @@ def test_training(datastore_name): if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision( - "high" - ) # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s else: device_name = "cpu"
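
A minimal usage sketch of the metric reduction arguments documented in the
`neural_lam/metrics.py` hunks above (illustrative only, not part of the patch;
it assumes the package is importable as `neural_lam`, and the dummy tensors and
shapes are invented for the example):

    import torch
    from neural_lam import metrics

    # (..., N, d_state) = (batch=4, grid nodes=10, state vars=3)
    pred = torch.randn(4, 10, 3)
    target = torch.randn(4, 10, 3)
    pred_std = torch.ones(3)  # per-variable std, broadcast over grid nodes

    # With the defaults (average_grid=True, sum_vars=True) the grid dimension is
    # averaged and the variable dimension summed, so the documented return shape
    # is (...,) -- here torch.Size([4]).
    val = metrics.wmse(pred, target, pred_std)
    print(val.shape)

    # mse() is the unweighted variant: per the hunk above it simply calls wmse()
    # with pred_std replaced by ones.
    val_unweighted = metrics.mse(pred, target, pred_std)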