diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index dc519e5b..dadac50d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,4 +1,4 @@ -name: lint +name: Linting on: # trigger on pushes to any branch, but not main diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml new file mode 100644 index 00000000..71bff3d3 --- /dev/null +++ b/.github/workflows/run_tests.yml @@ -0,0 +1,45 @@ +name: Unit Tests + +on: + # trigger on pushes to any branch, but not main + push: + branches-ignore: + - main + # and also on PRs to main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install torch-geometric>=2.5.2 + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: data + key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.1.0 + - name: Test with pytest + run: | + pytest -v -s + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: data + key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0 diff --git a/.gitignore b/.gitignore index 2a12cf57..022206f5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,14 @@ wandb slurm_log* saved_models +lightning_logs data graphs *.sif sweeps test_*.sh .vscode +*slurm* ### Python ### # Byte-compiled / optimized / DLL files diff --git a/CHANGELOG.md b/CHANGELOG.md index 63feff96..3544b299 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) ### Added +- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data tp speed up running tests. + [/#38](https://github.com/mllam/neural-lam/pull/38) + @SimonKamuk - Replaced `constants.py` with `data_config.yaml` for data configuration management [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) @@ -27,6 +30,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Robust restoration of optimizer and scheduler using `ckpt_path` + [\#17](https://github.com/mllam/neural-lam/pull/17) + @sadamov + - Updated scripts and modules to use `data_config.yaml` instead of `constants.py` [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) @sadamov @@ -65,6 +72,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#29](https://github.com/mllam/neural-lam/pull/29) @leifdenby +- change copyright formulation in license to encompass all contributors + [\#47](https://github.com/mllam/neural-lam/pull/47) + @joeloskarsson ## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) diff --git a/LICENSE.txt b/LICENSE.txt index 1bb69de2..ed176ba1 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Joel Oskarsson, Tomas Landelius +Copyright (c) 2023 Neural-LAM Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index ba0bb3fe..1bdc6602 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +![Linting](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml/badge.svg) +![Automatic tests](https://github.com/mllam/neural-lam/actions/workflows/run_tests.yml/badge.svg) +
@@ -279,6 +282,8 @@ pre-commit run --all-files ``` from the root directory of the repository. +Furthermore, all tests in the ```tests``` directory will be run upon pushing changes by a github action. Failure in any of the tests will also reject the push/PR. + # Contact If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/create_mesh.py b/create_mesh.py index f04b4d4b..41557a97 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -153,7 +153,7 @@ def prepend_node_index(graph, new_index): return networkx.relabel_nodes(graph, to_mapping, copy=True) -def main(): +def main(input_args=None): parser = ArgumentParser(description="Graph generation arguments") parser.add_argument( "--data_config", @@ -186,7 +186,7 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) - args = parser.parse_args() + args = parser.parse_args(input_args) # Load grid positions config_loader = config.Config.from_file(args.data_config) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index cae1ae3e..c85cd5a3 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -1,10 +1,13 @@ # Standard library import os +import subprocess from argparse import ArgumentParser # Third-party import numpy as np import torch +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm # First-party @@ -12,6 +15,117 @@ from neural_lam.weather_dataset import WeatherDataset +class PaddedWeatherDataset(torch.utils.data.Dataset): + def __init__(self, base_dataset, world_size, batch_size): + super().__init__() + self.base_dataset = base_dataset + self.world_size = world_size + self.batch_size = batch_size + self.total_samples = len(base_dataset) + self.padded_samples = ( + (self.world_size * self.batch_size) - self.total_samples + ) % self.world_size + self.original_indices = list(range(len(base_dataset))) + self.padded_indices = list( + range(self.total_samples, self.total_samples + self.padded_samples) + ) + + def __getitem__(self, idx): + return self.base_dataset[ + self.original_indices[-1] + if idx >= self.total_samples + else idx % len(self.base_dataset) + ] + + def __len__(self): + return self.total_samples + self.padded_samples + + def get_original_indices(self): + return self.original_indices + + def get_original_window_indices(self, step_length): + return [ + i // step_length + for i in range(len(self.original_indices) * step_length) + ] + + +def get_rank(): + return int(os.environ.get("SLURM_PROCID", 0)) + + +def get_world_size(): + return int(os.environ.get("SLURM_NTASKS", 1)) + + +def setup(rank, world_size): # pylint: disable=redefined-outer-name + """Initialize the distributed group.""" + if "SLURM_JOB_NODELIST" in os.environ: + master_node = ( + subprocess.check_output( + "scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1", + shell=True, + ) + .strip() + .decode("utf-8") + ) + else: + print( + "\033[91mCareful, you are running this script with --distributed " + "without any scheduler. In most cases this will result in slower " + "execution and the --distributed flag should be removed.\033[0m" + ) + master_node = "localhost" + os.environ["MASTER_ADDR"] = master_node + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group( + "nccl" if torch.cuda.is_available() else "gloo", + rank=rank, + world_size=world_size, + ) + if rank == 0: + print( + f"Initialized {dist.get_backend()} " + f"process group with world size {world_size}." + ) + + +def save_stats( + static_dir_path, means, squares, flux_means, flux_squares, filename_prefix +): + means = ( + torch.stack(means) if len(means) > 1 else means[0] + ) # (N_batch, d_features,) + squares = ( + torch.stack(squares) if len(squares) > 1 else squares[0] + ) # (N_batch, d_features,) + mean = torch.mean(means, dim=0) # (d_features,) + second_moment = torch.mean(squares, dim=0) # (d_features,) + std = torch.sqrt(second_moment - mean**2) # (d_features,) + torch.save( + mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt") + ) + torch.save( + std.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_std.pt") + ) + + if len(flux_means) == 0: + return + flux_means = ( + torch.stack(flux_means) if len(flux_means) > 1 else flux_means[0] + ) # (N_batch,) + flux_squares = ( + torch.stack(flux_squares) if len(flux_squares) > 1 else flux_squares[0] + ) # (N_batch,) + flux_mean = torch.mean(flux_means) # (,) + flux_second_moment = torch.mean(flux_squares) # (,) + flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + torch.save( + torch.stack((flux_mean, flux_std)).cpu(), + os.path.join(static_dir_path, "flux_stats.pt"), + ) + + def main(): """ Pre-compute parameter weights to be used in loss function @@ -41,32 +155,52 @@ def main(): default=4, help="Number of workers in data loader (default: 4)", ) + parser.add_argument( + "--distributed", + type=int, + default=0, + help="Run the script in distributed mode (1) or not (0) (default: 0)", + ) args = parser.parse_args() + distributed = bool(args.distributed) + rank = get_rank() + world_size = get_world_size() config_loader = config.Config.from_file(args.data_config) - static_dir_path = os.path.join("data", config_loader.dataset.name, "static") - - # Create parameter weights based on height - # based on fig A.1 in graph cast paper - w_dict = { - "2": 1.0, - "0": 0.1, - "65": 0.065, - "1000": 0.1, - "850": 0.05, - "500": 0.03, - } - w_list = np.array( - [ - w_dict[par.split("_")[-2]] - for par in config_loader.dataset.var_longnames - ] - ) - print("Saving parameter weights...") - np.save( - os.path.join(static_dir_path, "parameter_weights.npy"), - w_list.astype("float32"), - ) + + if distributed: + + setup(rank, world_size) + device = torch.device( + f"cuda:{rank}" if torch.cuda.is_available() else "cpu" + ) + torch.cuda.set_device(device) if torch.cuda.is_available() else None + + if rank == 0: + static_dir_path = os.path.join( + "data", config_loader.dataset.name, "static" + ) + # Create parameter weights based on height + # based on fig A.1 in graph cast paper + w_dict = { + "2": 1.0, + "0": 0.1, + "65": 0.065, + "1000": 0.1, + "850": 0.05, + "500": 0.03, + } + w_list = np.array( + [ + w_dict[par.split("_")[-2]] + for par in config_loader.dataset.var_longnames + ] + ) + print("Saving parameter weights...") + np.save( + os.path.join(static_dir_path, "parameter_weights.npy"), + w_list.astype("float32"), + ) # Load dataset without any subsampling ds = WeatherDataset( @@ -75,47 +209,97 @@ def main(): subsample_step=1, pred_length=63, standardize=False, - ) # Without standardization + ) + if distributed: + ds = PaddedWeatherDataset( + ds, + world_size, + args.batch_size, + ) + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=False + ) + else: + sampler = None loader = torch.utils.data.DataLoader( - ds, args.batch_size, shuffle=False, num_workers=args.n_workers + ds, + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + sampler=sampler, ) - # Compute mean and std.-dev. of each parameter (+ flux forcing) - # across full dataset - print("Computing mean and std.-dev. for parameters...") - means = [] - squares = [] - flux_means = [] - flux_squares = [] + + if rank == 0: + print("Computing mean and std.-dev. for parameters...") + means, squares, flux_means, flux_squares = [], [], [], [] + for init_batch, target_batch, forcing_batch in tqdm(loader): - batch = torch.cat( - (init_batch, target_batch), dim=1 - ) # (N_batch, N_t, N_grid, d_features) - means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) + if distributed: + init_batch, target_batch, forcing_batch = ( + init_batch.to(device), + target_batch.to(device), + forcing_batch.to(device), + ) + # (N_batch, N_t, N_grid, d_features) + batch = torch.cat((init_batch, target_batch), dim=1) + # Flux at 1st windowed position is index 1 in forcing + flux_batch = forcing_batch[:, :, :, 1] + # (N_batch, d_features,) + means.append(torch.mean(batch, dim=(1, 2)).cpu()) squares.append( - torch.mean(batch**2, dim=(1, 2)) + torch.mean(batch**2, dim=(1, 2)).cpu() ) # (N_batch, d_features,) + flux_means.append(torch.mean(flux_batch).cpu()) # (,) + flux_squares.append(torch.mean(flux_batch**2).cpu()) # (,) - # Flux at 1st windowed position is index 1 in forcing - flux_batch = forcing_batch[:, :, :, 1] - flux_means.append(torch.mean(flux_batch)) # (,) - flux_squares.append(torch.mean(flux_batch**2)) # (,) + if distributed and world_size > 1: + means_gathered, squares_gathered = [None] * world_size, [ + None + ] * world_size + flux_means_gathered, flux_squares_gathered = [None] * world_size, [ + None + ] * world_size + dist.all_gather_object(means_gathered, torch.cat(means, dim=0)) + dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0)) + dist.all_gather_object(flux_means_gathered, flux_means) + dist.all_gather_object(flux_squares_gathered, flux_squares) - mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) - second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) - std = torch.sqrt(second_moment - mean**2) # (d_features) + if rank == 0: + means_gathered, squares_gathered = torch.cat( + means_gathered, dim=0 + ), torch.cat(squares_gathered, dim=0) + flux_means_gathered, flux_squares_gathered = torch.tensor( + flux_means_gathered + ), torch.tensor(flux_squares_gathered) - flux_mean = torch.mean(torch.stack(flux_means)) # (,) - flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,) - flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) - flux_stats = torch.stack((flux_mean, flux_std)) + original_indices = ds.get_original_indices() + means, squares = [means_gathered[i] for i in original_indices], [ + squares_gathered[i] for i in original_indices + ] + flux_means, flux_squares = [ + flux_means_gathered[i] for i in original_indices + ], [flux_squares_gathered[i] for i in original_indices] + else: + means = [torch.cat(means, dim=0)] # (N_batch, d_features,) + squares = [torch.cat(squares, dim=0)] # (N_batch, d_features,) + flux_means = [torch.tensor(flux_means)] # (N_batch,) + flux_squares = [torch.tensor(flux_squares)] # (N_batch,) + + if rank == 0: + save_stats( + static_dir_path, + means, + squares, + flux_means, + flux_squares, + "parameter", + ) - print("Saving mean, std.-dev, flux_stats...") - torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt")) - torch.save(std, os.path.join(static_dir_path, "parameter_std.pt")) - torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt")) + if distributed: + dist.barrier() - # Compute mean and std.-dev. of one-step differences across the dataset - print("Computing mean and std.-dev. for one-step differences...") + if rank == 0: + print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( config_loader.dataset.name, split="train", @@ -123,17 +307,35 @@ def main(): pred_length=63, standardize=True, ) # Re-load with standardization + if distributed: + ds_standard = PaddedWeatherDataset( + ds_standard, + world_size, + args.batch_size, + ) + sampler_standard = DistributedSampler( + ds_standard, num_replicas=world_size, rank=rank, shuffle=False + ) + else: + sampler_standard = None loader_standard = torch.utils.data.DataLoader( - ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers + ds_standard, + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + sampler=sampler_standard, ) used_subsample_len = (65 // args.step_length) * args.step_length - diff_means = [] - diff_squares = [] - for init_batch, target_batch, _ in tqdm(loader_standard): - batch = torch.cat( - (init_batch, target_batch), dim=1 - ) # (N_batch, N_t', N_grid, d_features) + diff_means, diff_squares = [], [] + + for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0): + if distributed: + init_batch, target_batch = init_batch.to(device), target_batch.to( + device + ) + # (N_batch, N_t', N_grid, d_features) + batch = torch.cat((init_batch, target_batch), dim=1) # Note: batch contains only 1h-steps stepped_batch = torch.cat( [ @@ -144,24 +346,48 @@ def main(): ) # (N_batch', N_t, N_grid, d_features), # N_batch' = args.step_length*N_batch - batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] # (N_batch', N_t-1, N_grid, d_features) + diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu()) + # (N_batch', d_features,) + diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu()) + # (N_batch', d_features,) + + if distributed and world_size > 1: + dist.barrier() + diff_means_gathered, diff_squares_gathered = [None] * world_size, [ + None + ] * world_size + dist.all_gather_object( + diff_means_gathered, torch.cat(diff_means, dim=0) + ) + dist.all_gather_object( + diff_squares_gathered, torch.cat(diff_squares, dim=0) + ) + + if rank == 0: + diff_means_gathered, diff_squares_gathered = torch.cat( + diff_means_gathered, dim=0 + ).view(-1, *diff_means[0].shape), torch.cat( + diff_squares_gathered, dim=0 + ).view( + -1, *diff_squares[0].shape + ) + original_indices = ds_standard.get_original_window_indices( + args.step_length + ) + diff_means, diff_squares = [ + diff_means_gathered[i] for i in original_indices + ], [diff_squares_gathered[i] for i in original_indices] - diff_means.append( - torch.mean(batch_diffs, dim=(1, 2)) - ) # (N_batch', d_features,) - diff_squares.append( - torch.mean(batch_diffs**2, dim=(1, 2)) - ) # (N_batch', d_features,) + diff_means = [torch.cat(diff_means, dim=0)] # (N_batch', d_features,) + diff_squares = [torch.cat(diff_squares, dim=0)] # (N_batch', d_features,) - diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features) - diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) - diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) + if rank == 0: + save_stats(static_dir_path, diff_means, diff_squares, [], [], "diff") - print("Saving one-step difference mean and std.-dev...") - torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) - torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) + if distributed: + dist.destroy_process_group() if __name__ == "__main__": diff --git a/docs/notebooks/create_reduced_meps_dataset.ipynb b/docs/notebooks/create_reduced_meps_dataset.ipynb new file mode 100644 index 00000000..daba23c4 --- /dev/null +++ b/docs/notebooks/create_reduced_meps_dataset.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating meps_example_reduced\n", + "This notebook outlines how the small-size test dataset ```meps_example_reduced``` was created based on the slightly larger dataset ```meps_example```. The zipped up datasets are 263 MB and 2.6 GB, respectively. See [README.md](../../README.md) for info on how to download ```meps_example```.\n", + "\n", + "The dataset was reduced in size by reducing the number of grid points and variables.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Standard library\n", + "import os\n", + "\n", + "# Third-party\n", + "import numpy as np\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of grid points was reduced to 1/4 by halving the number of coordinates in both the x and y direction. This was done by removing a quarter of the grid points along each outer edge, so the center grid points would stay centered in the new set.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load existing grid\n", + "grid_xy = np.load('data/meps_example/static/nwp_xy.npy')\n", + "# Get slices in each dimension by cutting off a quarter along each edge\n", + "num_x, num_y = grid_xy.shape[1:]\n", + "x_slice = slice(num_x//4, 3*num_x//4)\n", + "y_slice = slice(num_y//4, 3*num_y//4)\n", + "# Index and save reduced grid\n", + "grid_xy_reduced = grid_xy[:, x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/nwp_xy.npy', grid_xy_reduced)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "This cut out the border, so a new perimeter of 10 grid points was established as border (10 was also the border size in the original \"meps_example\").\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Outer 10 grid points are border\n", + "old_border_mask = np.load('data/meps_example/static/border_mask.npy')\n", + "assert np.all(old_border_mask[10:-10, 10:-10] == False)\n", + "assert np.all(old_border_mask[:10, :] == True)\n", + "assert np.all(old_border_mask[:, :10] == True)\n", + "assert np.all(old_border_mask[-10:,:] == True)\n", + "assert np.all(old_border_mask[:,-10:] == True)\n", + "\n", + "# Create new array with False everywhere but the outer 10 grid points\n", + "border_mask = np.zeros_like(grid_xy_reduced[0,:,:], dtype=bool)\n", + "border_mask[:10] = True\n", + "border_mask[:,:10] = True\n", + "border_mask[-10:] = True\n", + "border_mask[:,-10:] = True\n", + "np.save('data/meps_example_reduced/static/border_mask.npy', border_mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few other files also needed to be copied using only the new reduced grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load surface_geopotential.npy, index only values from the reduced grid, and save to new file\n", + "surface_geopotential = np.load('data/meps_example/static/surface_geopotential.npy')\n", + "surface_geopotential_reduced = surface_geopotential[x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/surface_geopotential.npy', surface_geopotential_reduced)\n", + "\n", + "# Load pytorch file grid_features.pt\n", + "grid_features = torch.load('data/meps_example/static/grid_features.pt')\n", + "# Index only values from the reduced grid. \n", + "# First reshape from (num_grid_points_total, 4) to (num_grid_points_x, num_grid_points_y, 4), \n", + "# then index, then reshape back to new total number of grid points\n", + "print(grid_features.shape)\n", + "grid_features_new = grid_features.reshape(num_x, num_y, 4)[x_slice,y_slice,:].reshape((-1, 4))\n", + "# Save to new file\n", + "torch.save(grid_features_new, 'data/meps_example_reduced/static/grid_features.pt')\n", + "\n", + "# flux_stats.pt is just a vector of length 2, so the grid shape and variable changes does not change this file\n", + "torch.save(torch.load('data/meps_example/static/flux_stats.pt'), 'data/meps_example_reduced/static/flux_stats.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of variables was reduced by truncating the variable list to the first 8." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_vars = 8\n", + "\n", + "# Load parameter_weights.npy, truncate to first 8 variables, and save to new file\n", + "parameter_weights = np.load('data/meps_example/static/parameter_weights.npy')\n", + "parameter_weights_reduced = parameter_weights[:num_vars]\n", + "np.save('data/meps_example_reduced/static/parameter_weights.npy', parameter_weights_reduced)\n", + "\n", + "# Do the same for following 4 pytorch files\n", + "for file in ['diff_mean', 'diff_std', 'parameter_mean', 'parameter_std']:\n", + " old_file = torch.load(f'data/meps_example/static/{file}.pt')\n", + " new_file = old_file[:num_vars]\n", + " torch.save(new_file, f'data/meps_example_reduced/static/{file}.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly the files in each of the directories train, test, and val have to be reduced. The folders all have the same structure with files of the following types:\n", + "```\n", + "nwp_YYYYMMDDHH_mbrXXX.npy\n", + "wtr_YYYYMMDDHH.npy\n", + "nwp_toa_downwelling_shortwave_flux_YYYYMMDDHH.npy\n", + "```\n", + "with ```YYYYMMDDHH``` being some date with hours, and ```XXX``` being some 3-digit integer.\n", + "\n", + "The first type of file has x and y in dimensions 1 and 2, and variable index in dimension 3. Dimension 0 is unchanged.\n", + "The second type has has x and y in dimensions 1 and 2. Dimension 0 is unchanged.\n", + "The last type has just x and y as the only 2 dimensions.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(65, 268, 238, 18)\n", + "(65, 268, 238)\n" + ] + } + ], + "source": [ + "print(np.load('data/meps_example/samples/train/nwp_2022040100_mbr000.npy').shape)\n", + "print(np.load('data/meps_example/samples/train/nwp_toa_downwelling_shortwave_flux_2022040112.npy').shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following loop goes through each file in each sample folder and indexes them according to the dimensions given by the file name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for sample in ['train', 'test', 'val']:\n", + " files = os.listdir(f'data/meps_example/samples/{sample}')\n", + "\n", + " for f in files:\n", + " data = np.load(f'data/meps_example/samples/{sample}/{f}')\n", + " if 'mbr' in f:\n", + " data = data[:,x_slice,y_slice,:num_vars]\n", + " elif 'wtr' in f:\n", + " data = data[x_slice, y_slice]\n", + " else:\n", + " data = data[:,x_slice,y_slice]\n", + " np.save(f'data/meps_example_reduced/samples/{sample}/{f}', data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, the file ```data_config.yaml``` is modified manually by truncating the variable units, long and short names, and setting the new grid shape. Also the unit descriptions containing ```^``` was automatically parsed using latex, and to avoid having to install latex in the GitHub CI/CD pipeline, this was changed to ```**```. \n", + "\n", + "This new config file was placed in ```data/meps_example_reduced```, and that directory was then zipped and placed in a European Weather Cloud S3 bucket." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..f1527849 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -21,8 +21,8 @@ dataset: var_units: - Pa - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" + - $\mathrm{W}/\mathrm{m}^2$ + - $\mathrm{W}/\mathrm{m}^2$ - "" - "" - K @@ -33,9 +33,9 @@ dataset: - m/s - m/s - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" + - $\mathrm{kg}/\mathrm{m}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ var_longnames: - pres_heightAboveGround_0_instant - pres_heightAboveSea_0_instant diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 9cda9fc2..6ced211f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -83,8 +83,8 @@ def __init__(self, args): if self.output_std: self.test_metrics["output_std"] = [] # Treat as metric - # For making restoring of optimizer state optional (slight hack) - self.opt_state = None + # For making restoring of optimizer state optional + self.restore_opt = args.restore_opt # For example plotting self.n_example_pred = args.n_example_pred @@ -97,9 +97,6 @@ def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) - if self.opt_state: - opt.load_state_dict(self.opt_state) - return opt @property @@ -476,7 +473,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.dataset.var_nums[var_i] + var = self.config_loader.dataset.var_names[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ @@ -597,3 +594,6 @@ def on_load_checkpoint(self, checkpoint): ) loaded_state_dict[new_key] = loaded_state_dict[old_key] del loaded_state_dict[old_key] + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 836b04ed..59a529eb 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,5 +1,6 @@ # Standard library import os +import shutil # Third-party import numpy as np @@ -250,7 +251,11 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=True, family="serif") + # If latex is not available, some visualizations might not render correctly, + # but will at least not raise an error. + # Alternatively, use unicode raised numbers. + usetex = True if shutil.which("latex") else False + bundle = bundles.neurips2023(usetex=usetex, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 2b6abf15..8c9ca77c 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) # Plot pred and target @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): fig, ax = plt.subplots( figsize=(5, 4.8), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) ax.coastlines() # Add coastline outlines diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py new file mode 100644 index 00000000..f91170c9 --- /dev/null +++ b/tests/test_mllam_dataset.py @@ -0,0 +1,138 @@ +# Standard library +import os + +# Third-party +import pooch + +# First-party +from create_mesh import main as create_mesh +from neural_lam.config import Config +from neural_lam.utils import load_static_data +from neural_lam.weather_dataset import WeatherDataset +from train_model import main as train_model + +# Disable weights and biases to avoid unnecessary logging +# and to avoid having to deal with authentication +os.environ["WANDB_DISABLED"] = "true" + +# Initializing variables for the s3 client +S3_BUCKET_NAME = "mllam-testdata" +S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int" +S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip" +S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH]) +TEST_DATA_KNOWN_HASH = ( + "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7" +) + + +def test_retrieve_data_ewc(): + # Download and unzip test data into data/meps_example_reduced + pooch.retrieve( + url=S3_FULL_PATH, + known_hash=TEST_DATA_KNOWN_HASH, + processor=pooch.Unzip(extract_dir=""), + path="data", + fname="meps_example_reduced.zip", + ) + + +def test_load_reduced_meps_dataset(): + # The data_config.yaml file is downloaded and extracted in + # test_retrieve_data_ewc together with the dataset itself + data_config_file = "data/meps_example_reduced/data_config.yaml" + dataset_name = "meps_example_reduced" + + dataset = WeatherDataset(dataset_name="meps_example_reduced") + config = Config.from_file(data_config_file) + + var_names = config.values["dataset"]["var_names"] + var_units = config.values["dataset"]["var_units"] + var_longnames = config.values["dataset"]["var_longnames"] + + assert len(var_names) == len(var_longnames) + assert len(var_names) == len(var_units) + + # in future the number of grid static features + # will be provided by the Dataset class itself + n_grid_static_features = 4 + # Hardcoded in model + n_input_steps = 2 + + n_forcing_features = config.values["dataset"]["num_forcing_features"] + n_state_features = len(var_names) + n_prediction_timesteps = dataset.sample_length - n_input_steps + + nx, ny = config.values["grid_shape_state"] + n_grid = nx * ny + + # check that the dataset is not empty + assert len(dataset) > 0 + + # get the first item + init_states, target_states, forcing = dataset[0] + + # check that the shapes of the tensors are correct + assert init_states.shape == (n_input_steps, n_grid, n_state_features) + assert target_states.shape == ( + n_prediction_timesteps, + n_grid, + n_state_features, + ) + assert forcing.shape == ( + n_prediction_timesteps, + n_grid, + n_forcing_features, + ) + + static_data = load_static_data(dataset_name=dataset_name) + + required_props = { + "border_mask", + "grid_static_features", + "step_diff_mean", + "step_diff_std", + "data_mean", + "data_std", + "param_weights", + } + + # check the sizes of the props + assert static_data["border_mask"].shape == (n_grid, 1) + assert static_data["grid_static_features"].shape == ( + n_grid, + n_grid_static_features, + ) + assert static_data["step_diff_mean"].shape == (n_state_features,) + assert static_data["step_diff_std"].shape == (n_state_features,) + assert static_data["data_mean"].shape == (n_state_features,) + assert static_data["data_std"].shape == (n_state_features,) + assert static_data["param_weights"].shape == (n_state_features,) + + assert set(static_data.keys()) == required_props + + +def test_create_graph_reduced_meps_dataset(): + args = [ + "--graph=hierarchical", + "--hierarchical=1", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--levels=2", + ] + create_mesh(args) + + +def test_train_model_reduced_meps_dataset(): + args = [ + "--model=hi_lam", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--n_workers=4", + "--epochs=1", + "--graph=hierarchical", + "--hidden_dim=16", + "--hidden_layers=1", + "--processor_layers=1", + "--ar_steps=1", + "--eval=val", + "--n_example_pred=0", + ] + train_model(args) diff --git a/train_model.py b/train_model.py index 390da6d4..03863275 100644 --- a/train_model.py +++ b/train_model.py @@ -1,4 +1,5 @@ # Standard library +import json import random import time from argparse import ArgumentParser @@ -22,7 +23,7 @@ } -def main(): +def main(input_args=None): """ Main function for training and evaluating models """ @@ -196,18 +197,21 @@ def main(): ) parser.add_argument( "--metrics_watch", - type=list, + nargs="+", default=[], help="List of metrics to watch, including any prefix (e.g. val_rmse)", ) parser.add_argument( "--var_leads_metrics_watch", - type=dict, - default={}, - help="Dict with variables and lead times to log watched metrics for", - ) - args = parser.parse_args() - + type=str, + default="{}", + help="""JSON string with variable-IDs and lead times to log watched + metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", + ) + args = parser.parse_args(input_args) + args.var_leads_metrics_watch = { + int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() + } config_loader = config.Config.from_file(args.data_config) # Asserts for arguments @@ -265,14 +269,7 @@ def main(): # Load model parameters Use new args for model model_class = MODELS[args.model] - if args.load: - model = model_class.load_from_checkpoint(args.load, args=args) - if args.restore_opt: - # Save for later - # Unclear if this works for multi-GPU - model.opt_state = torch.load(args.load)["optimizer_states"][0] - else: - model = model_class(args) + model = model_class(args) prefix = "subset-" if args.subset_ds else "" if args.eval: @@ -327,13 +324,14 @@ def main(): ) print(f"Running evaluation on {args.eval}") - trainer.test(model=model, dataloaders=eval_loader) + trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) else: # Train model trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, + ckpt_path=args.load, )