From 879cfec1b49d0255ed963f44b3a9f55d42c9920a Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Wed, 29 May 2024 16:07:36 +0200 Subject: [PATCH 1/5] Make restoration of optimizer and scheduler more robust (#17) ## Summary This pull request introduces specific enhancements to the model loading and optimizer/scheduler state restoration functionalities, improving flexibility and compatibility with multi-GPU setups. ## Detailed Changes - **Enhanced Model Loading for Multi-GPU**: Modified the model loading logic to better support multi-GPU environments by ensuring that optimizer states are only loaded when necessary and appropriate. - **Checkpoint Adjustments**: Adjusted how learning rate schedulers are restored from checkpoints to ensure they align correctly with the current training state ## Impact These changes provide users with greater control over how training states are restored and improve the script's functionality in distributed training environments. ## Testing [x] Changes have been tested in both single and multi-GPU setups ## Notes Further integration testing with different types of training configurations is recommended to fully validate the new functionalities. --------- Co-authored-by: Simon Adamov --- CHANGELOG.md | 4 ++++ neural_lam/models/ar_model.py | 10 +++++----- neural_lam/vis.py | 4 ++-- train_model.py | 12 +++--------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63feff96..061aa6bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Robust restoration of optimizer and scheduler using `ckpt_path` + [\#17](https://github.com/mllam/neural-lam/pull/17) + @sadamov + - Updated scripts and modules to use `data_config.yaml` instead of `constants.py` [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) @sadamov diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 9cda9fc2..29b169d4 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -83,8 +83,8 @@ def __init__(self, args): if self.output_std: self.test_metrics["output_std"] = [] # Treat as metric - # For making restoring of optimizer state optional (slight hack) - self.opt_state = None + # For making restoring of optimizer state optional + self.restore_opt = args.restore_opt # For example plotting self.n_example_pred = args.n_example_pred @@ -97,9 +97,6 @@ def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) - if self.opt_state: - opt.load_state_dict(self.opt_state) - return opt @property @@ -597,3 +594,6 @@ def on_load_checkpoint(self, checkpoint): ) loaded_state_dict[new_key] = loaded_state_dict[old_key] del loaded_state_dict[old_key] + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 2b6abf15..8c9ca77c 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) # Plot pred and target @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): fig, ax = plt.subplots( figsize=(5, 4.8), - subplot_kw={"projection": data_config.coords_projection()}, + 
subplot_kw={"projection": data_config.coords_projection}, ) ax.coastlines() # Add coastline outlines diff --git a/train_model.py b/train_model.py index 390da6d4..fe064384 100644 --- a/train_model.py +++ b/train_model.py @@ -265,14 +265,7 @@ def main(): # Load model parameters Use new args for model model_class = MODELS[args.model] - if args.load: - model = model_class.load_from_checkpoint(args.load, args=args) - if args.restore_opt: - # Save for later - # Unclear if this works for multi-GPU - model.opt_state = torch.load(args.load)["optimizer_states"][0] - else: - model = model_class(args) + model = model_class(args) prefix = "subset-" if args.subset_ds else "" if args.eval: @@ -327,13 +320,14 @@ def main(): ) print(f"Running evaluation on {args.eval}") - trainer.test(model=model, dataloaders=eval_loader) + trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) else: # Train model trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, + ckpt_path=args.load, ) From 9d558d1f0d343cfe6e0babaa8d9e6c45b852fe21 Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Fri, 31 May 2024 12:12:58 +0200 Subject: [PATCH 2/5] Fix minor bugs in data_config.yaml workflow (#40) ### Summary https://github.com/mllam/neural-lam/pull/31 introduced three minor bugs that are fixed with this PR: - r"" strings are not required in units of `data_config.yaml` - dictionaries cannot be passed as argsparse, rather JSON strings. This bug is related to the flag `var_leads_metrics_watch` --------- Co-authored-by: joeloskarsson --- neural_lam/data_config.yaml | 10 +++++----- neural_lam/models/ar_model.py | 2 +- train_model.py | 13 +++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..f1527849 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -21,8 +21,8 @@ dataset: var_units: - Pa - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" + - $\mathrm{W}/\mathrm{m}^2$ + - $\mathrm{W}/\mathrm{m}^2$ - "" - "" - K @@ -33,9 +33,9 @@ dataset: - m/s - m/s - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" + - $\mathrm{kg}/\mathrm{m}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ + - $\mathrm{m}^2/\mathrm{s}^2$ var_longnames: - pres_heightAboveGround_0_instant - pres_heightAboveSea_0_instant diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 29b169d4..6ced211f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -473,7 +473,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): # Check if metrics are watched, log exact values for specific vars if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.dataset.var_nums[var_i] + var = self.config_loader.dataset.var_names[var_i] log_dict.update( { f"{full_log_name}_{var}_step_{step}": metric_tensor[ diff --git a/train_model.py b/train_model.py index fe064384..cbd787f0 100644 --- a/train_model.py +++ b/train_model.py @@ -1,4 +1,5 @@ # Standard library +import json import random import time from argparse import ArgumentParser @@ -196,17 +197,21 @@ def main(): ) parser.add_argument( "--metrics_watch", - type=list, + nargs="+", default=[], help="List of metrics to watch, including any prefix (e.g. 
val_rmse)", ) parser.add_argument( "--var_leads_metrics_watch", - type=dict, - default={}, - help="Dict with variables and lead times to log watched metrics for", + type=str, + default="{}", + help="""JSON string with variable-IDs and lead times to log watched + metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", ) args = parser.parse_args() + args.var_leads_metrics_watch = { + int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() + } config_loader = config.Config.from_file(args.data_config) From e5400bbfa92d959d0f4856b90786abb18d282754 Mon Sep 17 00:00:00 2001 From: Joel Oskarsson Date: Mon, 3 Jun 2024 14:39:41 +0200 Subject: [PATCH 3/5] Change copyright notice to specify all contributors (#47) ## Motivation As more people are now contributing to the code the copyright does not just belong to me and Tomas. To avoid having to update this with the name of every person contributing, I suggest to take some inspiration from e.g. https://github.com/pyg-team/pytorch_geometric/blob/master/LICENSE and https://github.com/numpy/numpy/blob/main/LICENSE.txt and use a general formulation "Neural-LAM contributors". ## Description of change Change the copyright notice in the MIT license to "Neural-LAM Contributors". The year can stay 2023, as that is the first year the work (the code) was published. As me and Tomas are included under "Neural-LAM Contributors" this is strictly expanding the number of copyright holders. --- CHANGELOG.md | 3 +++ LICENSE.txt | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 061aa6bb..fd836c7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,6 +69,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#29](https://github.com/mllam/neural-lam/pull/29) @leifdenby +- change copyright formulation in license to encompass all contributors + [\#47](https://github.com/mllam/neural-lam/pull/47) + @joeloskarsson ## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) diff --git a/LICENSE.txt b/LICENSE.txt index 1bb69de2..ed176ba1 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Joel Oskarsson, Tomas Landelius +Copyright (c) 2023 Neural-LAM Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From 743c07ac9f20ff16f05fcac5528196b8c4639c17 Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:45:03 +0200 Subject: [PATCH 4/5] Parallelize parameter weight computation using PyTorch Distributed (#22) ## Description This PR introduces parallelization to the `create_parameter_weights.py` script using PyTorch Distributed. The main changes include: 1. Added functions `get_rank()`, `get_world_size()`, `setup()`, and `cleanup()` to initialize and manage the distributed process group. - `get_rank()` retrieves the rank of the current process in the distributed group. - `get_world_size()` retrieves the total number of processes in the distributed group. - `setup()` initializes the distributed process group using NCCL (for GPU) or gloo (for CPU) backend. - `cleanup()` destroys the distributed process group. 2. Modified the `main()` function to take `rank` and `world_size` as arguments and set up the distributed environment. - The device is set based on the rank and available GPUs. 
- The dataset is adjusted to ensure its size is divisible by `(world_size * batch_size)` using the `adjust_dataset_size()` function. - A `DistributedSampler` is used to partition the dataset among the processes. 3. Parallelized the computation of means and squared values across the dataset. - Each process computes the means and squared values for its assigned portion of the dataset. - The results are gathered from all processes using `dist.all_gather_object()`. - The root process (rank 0) computes the final mean, standard deviation, and flux statistics using the gathered results. 4. Parallelized the computation of one-step difference means and squared values. - Similar to step 3, each process computes the difference means and squared values for its assigned portion of the dataset. - The results are gathered from all processes using `dist.all_gather_object()`. - The final difference mean and standard deviation are computed using the gathered results. These changes enable the script to leverage multiple processes/GPUs to speed up the computation of parameter weights, means, and standard deviations. The dataset is partitioned among the processes, and the results are gathered and aggregated by the root process. To run the script in a distributed manner, it can be launched using Slurm. Please review the changes and provide any feedback or suggestions. --------- Co-authored-by: Simon Adamov --- .gitignore | 2 + create_parameter_weights.py | 374 +++++++++++++++++++++++++++++------- 2 files changed, 302 insertions(+), 74 deletions(-) diff --git a/.gitignore b/.gitignore index c9d914c2..65e9f6f8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,14 @@ wandb slurm_log* saved_models +lightning_logs data graphs *.sif sweeps test_*.sh .vscode +*slurm* ### Python ### # Byte-compiled / optimized / DLL files diff --git a/create_parameter_weights.py b/create_parameter_weights.py index cae1ae3e..c85cd5a3 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -1,10 +1,13 @@ # Standard library import os +import subprocess from argparse import ArgumentParser # Third-party import numpy as np import torch +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm # First-party @@ -12,6 +15,117 @@ from neural_lam.weather_dataset import WeatherDataset +class PaddedWeatherDataset(torch.utils.data.Dataset): + def __init__(self, base_dataset, world_size, batch_size): + super().__init__() + self.base_dataset = base_dataset + self.world_size = world_size + self.batch_size = batch_size + self.total_samples = len(base_dataset) + self.padded_samples = ( + (self.world_size * self.batch_size) - self.total_samples + ) % self.world_size + self.original_indices = list(range(len(base_dataset))) + self.padded_indices = list( + range(self.total_samples, self.total_samples + self.padded_samples) + ) + + def __getitem__(self, idx): + return self.base_dataset[ + self.original_indices[-1] + if idx >= self.total_samples + else idx % len(self.base_dataset) + ] + + def __len__(self): + return self.total_samples + self.padded_samples + + def get_original_indices(self): + return self.original_indices + + def get_original_window_indices(self, step_length): + return [ + i // step_length + for i in range(len(self.original_indices) * step_length) + ] + + +def get_rank(): + return int(os.environ.get("SLURM_PROCID", 0)) + + +def get_world_size(): + return int(os.environ.get("SLURM_NTASKS", 1)) + + +def setup(rank, world_size): # pylint: 
disable=redefined-outer-name + """Initialize the distributed group.""" + if "SLURM_JOB_NODELIST" in os.environ: + master_node = ( + subprocess.check_output( + "scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1", + shell=True, + ) + .strip() + .decode("utf-8") + ) + else: + print( + "\033[91mCareful, you are running this script with --distributed " + "without any scheduler. In most cases this will result in slower " + "execution and the --distributed flag should be removed.\033[0m" + ) + master_node = "localhost" + os.environ["MASTER_ADDR"] = master_node + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group( + "nccl" if torch.cuda.is_available() else "gloo", + rank=rank, + world_size=world_size, + ) + if rank == 0: + print( + f"Initialized {dist.get_backend()} " + f"process group with world size {world_size}." + ) + + +def save_stats( + static_dir_path, means, squares, flux_means, flux_squares, filename_prefix +): + means = ( + torch.stack(means) if len(means) > 1 else means[0] + ) # (N_batch, d_features,) + squares = ( + torch.stack(squares) if len(squares) > 1 else squares[0] + ) # (N_batch, d_features,) + mean = torch.mean(means, dim=0) # (d_features,) + second_moment = torch.mean(squares, dim=0) # (d_features,) + std = torch.sqrt(second_moment - mean**2) # (d_features,) + torch.save( + mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt") + ) + torch.save( + std.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_std.pt") + ) + + if len(flux_means) == 0: + return + flux_means = ( + torch.stack(flux_means) if len(flux_means) > 1 else flux_means[0] + ) # (N_batch,) + flux_squares = ( + torch.stack(flux_squares) if len(flux_squares) > 1 else flux_squares[0] + ) # (N_batch,) + flux_mean = torch.mean(flux_means) # (,) + flux_second_moment = torch.mean(flux_squares) # (,) + flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + torch.save( + torch.stack((flux_mean, flux_std)).cpu(), + os.path.join(static_dir_path, "flux_stats.pt"), + ) + + def main(): """ Pre-compute parameter weights to be used in loss function @@ -41,32 +155,52 @@ def main(): default=4, help="Number of workers in data loader (default: 4)", ) + parser.add_argument( + "--distributed", + type=int, + default=0, + help="Run the script in distributed mode (1) or not (0) (default: 0)", + ) args = parser.parse_args() + distributed = bool(args.distributed) + rank = get_rank() + world_size = get_world_size() config_loader = config.Config.from_file(args.data_config) - static_dir_path = os.path.join("data", config_loader.dataset.name, "static") - - # Create parameter weights based on height - # based on fig A.1 in graph cast paper - w_dict = { - "2": 1.0, - "0": 0.1, - "65": 0.065, - "1000": 0.1, - "850": 0.05, - "500": 0.03, - } - w_list = np.array( - [ - w_dict[par.split("_")[-2]] - for par in config_loader.dataset.var_longnames - ] - ) - print("Saving parameter weights...") - np.save( - os.path.join(static_dir_path, "parameter_weights.npy"), - w_list.astype("float32"), - ) + + if distributed: + + setup(rank, world_size) + device = torch.device( + f"cuda:{rank}" if torch.cuda.is_available() else "cpu" + ) + torch.cuda.set_device(device) if torch.cuda.is_available() else None + + if rank == 0: + static_dir_path = os.path.join( + "data", config_loader.dataset.name, "static" + ) + # Create parameter weights based on height + # based on fig A.1 in graph cast paper + w_dict = { + "2": 1.0, + "0": 0.1, + "65": 0.065, + "1000": 0.1, + "850": 0.05, + "500": 0.03, + } + w_list = np.array( + [ + 
w_dict[par.split("_")[-2]] + for par in config_loader.dataset.var_longnames + ] + ) + print("Saving parameter weights...") + np.save( + os.path.join(static_dir_path, "parameter_weights.npy"), + w_list.astype("float32"), + ) # Load dataset without any subsampling ds = WeatherDataset( @@ -75,47 +209,97 @@ def main(): subsample_step=1, pred_length=63, standardize=False, - ) # Without standardization + ) + if distributed: + ds = PaddedWeatherDataset( + ds, + world_size, + args.batch_size, + ) + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=False + ) + else: + sampler = None loader = torch.utils.data.DataLoader( - ds, args.batch_size, shuffle=False, num_workers=args.n_workers + ds, + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + sampler=sampler, ) - # Compute mean and std.-dev. of each parameter (+ flux forcing) - # across full dataset - print("Computing mean and std.-dev. for parameters...") - means = [] - squares = [] - flux_means = [] - flux_squares = [] + + if rank == 0: + print("Computing mean and std.-dev. for parameters...") + means, squares, flux_means, flux_squares = [], [], [], [] + for init_batch, target_batch, forcing_batch in tqdm(loader): - batch = torch.cat( - (init_batch, target_batch), dim=1 - ) # (N_batch, N_t, N_grid, d_features) - means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) + if distributed: + init_batch, target_batch, forcing_batch = ( + init_batch.to(device), + target_batch.to(device), + forcing_batch.to(device), + ) + # (N_batch, N_t, N_grid, d_features) + batch = torch.cat((init_batch, target_batch), dim=1) + # Flux at 1st windowed position is index 1 in forcing + flux_batch = forcing_batch[:, :, :, 1] + # (N_batch, d_features,) + means.append(torch.mean(batch, dim=(1, 2)).cpu()) squares.append( - torch.mean(batch**2, dim=(1, 2)) + torch.mean(batch**2, dim=(1, 2)).cpu() ) # (N_batch, d_features,) + flux_means.append(torch.mean(flux_batch).cpu()) # (,) + flux_squares.append(torch.mean(flux_batch**2).cpu()) # (,) - # Flux at 1st windowed position is index 1 in forcing - flux_batch = forcing_batch[:, :, :, 1] - flux_means.append(torch.mean(flux_batch)) # (,) - flux_squares.append(torch.mean(flux_batch**2)) # (,) + if distributed and world_size > 1: + means_gathered, squares_gathered = [None] * world_size, [ + None + ] * world_size + flux_means_gathered, flux_squares_gathered = [None] * world_size, [ + None + ] * world_size + dist.all_gather_object(means_gathered, torch.cat(means, dim=0)) + dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0)) + dist.all_gather_object(flux_means_gathered, flux_means) + dist.all_gather_object(flux_squares_gathered, flux_squares) - mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) - second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) - std = torch.sqrt(second_moment - mean**2) # (d_features) + if rank == 0: + means_gathered, squares_gathered = torch.cat( + means_gathered, dim=0 + ), torch.cat(squares_gathered, dim=0) + flux_means_gathered, flux_squares_gathered = torch.tensor( + flux_means_gathered + ), torch.tensor(flux_squares_gathered) - flux_mean = torch.mean(torch.stack(flux_means)) # (,) - flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,) - flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) - flux_stats = torch.stack((flux_mean, flux_std)) + original_indices = ds.get_original_indices() + means, squares = [means_gathered[i] for i in original_indices], [ + squares_gathered[i] for i in 
original_indices + ] + flux_means, flux_squares = [ + flux_means_gathered[i] for i in original_indices + ], [flux_squares_gathered[i] for i in original_indices] + else: + means = [torch.cat(means, dim=0)] # (N_batch, d_features,) + squares = [torch.cat(squares, dim=0)] # (N_batch, d_features,) + flux_means = [torch.tensor(flux_means)] # (N_batch,) + flux_squares = [torch.tensor(flux_squares)] # (N_batch,) + + if rank == 0: + save_stats( + static_dir_path, + means, + squares, + flux_means, + flux_squares, + "parameter", + ) - print("Saving mean, std.-dev, flux_stats...") - torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt")) - torch.save(std, os.path.join(static_dir_path, "parameter_std.pt")) - torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt")) + if distributed: + dist.barrier() - # Compute mean and std.-dev. of one-step differences across the dataset - print("Computing mean and std.-dev. for one-step differences...") + if rank == 0: + print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( config_loader.dataset.name, split="train", @@ -123,17 +307,35 @@ def main(): pred_length=63, standardize=True, ) # Re-load with standardization + if distributed: + ds_standard = PaddedWeatherDataset( + ds_standard, + world_size, + args.batch_size, + ) + sampler_standard = DistributedSampler( + ds_standard, num_replicas=world_size, rank=rank, shuffle=False + ) + else: + sampler_standard = None loader_standard = torch.utils.data.DataLoader( - ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers + ds_standard, + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + sampler=sampler_standard, ) used_subsample_len = (65 // args.step_length) * args.step_length - diff_means = [] - diff_squares = [] - for init_batch, target_batch, _ in tqdm(loader_standard): - batch = torch.cat( - (init_batch, target_batch), dim=1 - ) # (N_batch, N_t', N_grid, d_features) + diff_means, diff_squares = [], [] + + for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0): + if distributed: + init_batch, target_batch = init_batch.to(device), target_batch.to( + device + ) + # (N_batch, N_t', N_grid, d_features) + batch = torch.cat((init_batch, target_batch), dim=1) # Note: batch contains only 1h-steps stepped_batch = torch.cat( [ @@ -144,24 +346,48 @@ def main(): ) # (N_batch', N_t, N_grid, d_features), # N_batch' = args.step_length*N_batch - batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] # (N_batch', N_t-1, N_grid, d_features) + diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu()) + # (N_batch', d_features,) + diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu()) + # (N_batch', d_features,) + + if distributed and world_size > 1: + dist.barrier() + diff_means_gathered, diff_squares_gathered = [None] * world_size, [ + None + ] * world_size + dist.all_gather_object( + diff_means_gathered, torch.cat(diff_means, dim=0) + ) + dist.all_gather_object( + diff_squares_gathered, torch.cat(diff_squares, dim=0) + ) + + if rank == 0: + diff_means_gathered, diff_squares_gathered = torch.cat( + diff_means_gathered, dim=0 + ).view(-1, *diff_means[0].shape), torch.cat( + diff_squares_gathered, dim=0 + ).view( + -1, *diff_squares[0].shape + ) + original_indices = ds_standard.get_original_window_indices( + args.step_length + ) + diff_means, diff_squares = [ + diff_means_gathered[i] for i in original_indices + ], [diff_squares_gathered[i] for i in original_indices] - diff_means.append( - 
torch.mean(batch_diffs, dim=(1, 2)) - ) # (N_batch', d_features,) - diff_squares.append( - torch.mean(batch_diffs**2, dim=(1, 2)) - ) # (N_batch', d_features,) + diff_means = [torch.cat(diff_means, dim=0)] # (N_batch', d_features,) + diff_squares = [torch.cat(diff_squares, dim=0)] # (N_batch', d_features,) - diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features) - diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) - diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) + if rank == 0: + save_stats(static_dir_path, diff_means, diff_squares, [], [], "diff") - print("Saving one-step difference mean and std.-dev...") - torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt")) - torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt")) + if distributed: + dist.destroy_process_group() if __name__ == "__main__": From 81d08400d5f11f40007c1c3686744fa01ee057b1 Mon Sep 17 00:00:00 2001 From: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:16:11 +0200 Subject: [PATCH 5/5] Feature: add tests for meps dataset (#38) Implemeted tests for loading a reduced size meps example dataset, creating graphs, and training model. - reduce number of variables, size of domain etc in Joel's MEPS data example so that the zip file is less than 500MB. Calling it `meps_example_reduced` - create test-data zip file and upload to EWC (credentials from @leifdenby) - implement test using pytorch to download and unpack testdata using [pooch](https://pypi.org/project/pooch/) - Implement testing of: - initiation of `neural_lam.weather_dataset.WeatherDataset` from downloaded data - check shapes of returned parts of training item - create new graph in tests for reduced dataset - feed single batch through model and check shape of output - add github action to run tests during ci/cd closes #30 --- .github/workflows/pre-commit.yml | 2 +- .github/workflows/run_tests.yml | 45 ++++ CHANGELOG.md | 3 + README.md | 5 + create_mesh.py | 4 +- .../create_reduced_meps_dataset.ipynb | 239 ++++++++++++++++++ neural_lam/utils.py | 7 +- requirements.txt | 2 + tests/__init__.py | 0 tests/test_mllam_dataset.py | 138 ++++++++++ train_model.py | 5 +- 11 files changed, 443 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/run_tests.yml create mode 100644 docs/notebooks/create_reduced_meps_dataset.ipynb create mode 100644 tests/__init__.py create mode 100644 tests/test_mllam_dataset.py diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index dc519e5b..dadac50d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,4 +1,4 @@ -name: lint +name: Linting on: # trigger on pushes to any branch, but not main diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml new file mode 100644 index 00000000..71bff3d3 --- /dev/null +++ b/.github/workflows/run_tests.yml @@ -0,0 +1,45 @@ +name: Unit Tests + +on: + # trigger on pushes to any branch, but not main + push: + branches-ignore: + - main + # and also on PRs to main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; 
then pip install -r requirements.txt; fi + pip install torch-geometric>=2.5.2 + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: data + key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.1.0 + - name: Test with pytest + run: | + pytest -v -s + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: data + key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index fd836c7a..3544b299 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) ### Added +- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data to speed up running tests. + [\#38](https://github.com/mllam/neural-lam/pull/38) + @SimonKamuk - Replaced `constants.py` with `data_config.yaml` for data configuration management [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) diff --git a/README.md b/README.md index ba0bb3fe..1bdc6602 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +![Linting](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml/badge.svg) +![Automatic tests](https://github.com/mllam/neural-lam/actions/workflows/run_tests.yml/badge.svg) +

@@ -279,6 +282,8 @@ pre-commit run --all-files ``` from the root directory of the repository. +Furthermore, all tests in the ```tests``` directory will be run upon pushing changes by a github action. Failure in any of the tests will also reject the push/PR. + # Contact If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/create_mesh.py b/create_mesh.py index f04b4d4b..41557a97 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -153,7 +153,7 @@ def prepend_node_index(graph, new_index): return networkx.relabel_nodes(graph, to_mapping, copy=True) -def main(): +def main(input_args=None): parser = ArgumentParser(description="Graph generation arguments") parser.add_argument( "--data_config", @@ -186,7 +186,7 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) - args = parser.parse_args() + args = parser.parse_args(input_args) # Load grid positions config_loader = config.Config.from_file(args.data_config) diff --git a/docs/notebooks/create_reduced_meps_dataset.ipynb b/docs/notebooks/create_reduced_meps_dataset.ipynb new file mode 100644 index 00000000..daba23c4 --- /dev/null +++ b/docs/notebooks/create_reduced_meps_dataset.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating meps_example_reduced\n", + "This notebook outlines how the small-size test dataset ```meps_example_reduced``` was created based on the slightly larger dataset ```meps_example```. The zipped up datasets are 263 MB and 2.6 GB, respectively. See [README.md](../../README.md) for info on how to download ```meps_example```.\n", + "\n", + "The dataset was reduced in size by reducing the number of grid points and variables.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Standard library\n", + "import os\n", + "\n", + "# Third-party\n", + "import numpy as np\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of grid points was reduced to 1/4 by halving the number of coordinates in both the x and y direction. 
This was done by removing a quarter of the grid points along each outer edge, so the center grid points would stay centered in the new set.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load existing grid\n", + "grid_xy = np.load('data/meps_example/static/nwp_xy.npy')\n", + "# Get slices in each dimension by cutting off a quarter along each edge\n", + "num_x, num_y = grid_xy.shape[1:]\n", + "x_slice = slice(num_x//4, 3*num_x//4)\n", + "y_slice = slice(num_y//4, 3*num_y//4)\n", + "# Index and save reduced grid\n", + "grid_xy_reduced = grid_xy[:, x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/nwp_xy.npy', grid_xy_reduced)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "This cut out the border, so a new perimeter of 10 grid points was established as border (10 was also the border size in the original \"meps_example\").\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Outer 10 grid points are border\n", + "old_border_mask = np.load('data/meps_example/static/border_mask.npy')\n", + "assert np.all(old_border_mask[10:-10, 10:-10] == False)\n", + "assert np.all(old_border_mask[:10, :] == True)\n", + "assert np.all(old_border_mask[:, :10] == True)\n", + "assert np.all(old_border_mask[-10:,:] == True)\n", + "assert np.all(old_border_mask[:,-10:] == True)\n", + "\n", + "# Create new array with False everywhere but the outer 10 grid points\n", + "border_mask = np.zeros_like(grid_xy_reduced[0,:,:], dtype=bool)\n", + "border_mask[:10] = True\n", + "border_mask[:,:10] = True\n", + "border_mask[-10:] = True\n", + "border_mask[:,-10:] = True\n", + "np.save('data/meps_example_reduced/static/border_mask.npy', border_mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few other files also needed to be copied using only the new reduced grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load surface_geopotential.npy, index only values from the reduced grid, and save to new file\n", + "surface_geopotential = np.load('data/meps_example/static/surface_geopotential.npy')\n", + "surface_geopotential_reduced = surface_geopotential[x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/surface_geopotential.npy', surface_geopotential_reduced)\n", + "\n", + "# Load pytorch file grid_features.pt\n", + "grid_features = torch.load('data/meps_example/static/grid_features.pt')\n", + "# Index only values from the reduced grid. \n", + "# First reshape from (num_grid_points_total, 4) to (num_grid_points_x, num_grid_points_y, 4), \n", + "# then index, then reshape back to new total number of grid points\n", + "print(grid_features.shape)\n", + "grid_features_new = grid_features.reshape(num_x, num_y, 4)[x_slice,y_slice,:].reshape((-1, 4))\n", + "# Save to new file\n", + "torch.save(grid_features_new, 'data/meps_example_reduced/static/grid_features.pt')\n", + "\n", + "# flux_stats.pt is just a vector of length 2, so the grid shape and variable changes does not change this file\n", + "torch.save(torch.load('data/meps_example/static/flux_stats.pt'), 'data/meps_example_reduced/static/flux_stats.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of variables was reduced by truncating the variable list to the first 8." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_vars = 8\n", + "\n", + "# Load parameter_weights.npy, truncate to first 8 variables, and save to new file\n", + "parameter_weights = np.load('data/meps_example/static/parameter_weights.npy')\n", + "parameter_weights_reduced = parameter_weights[:num_vars]\n", + "np.save('data/meps_example_reduced/static/parameter_weights.npy', parameter_weights_reduced)\n", + "\n", + "# Do the same for following 4 pytorch files\n", + "for file in ['diff_mean', 'diff_std', 'parameter_mean', 'parameter_std']:\n", + " old_file = torch.load(f'data/meps_example/static/{file}.pt')\n", + " new_file = old_file[:num_vars]\n", + " torch.save(new_file, f'data/meps_example_reduced/static/{file}.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, the files in each of the directories train, test, and val have to be reduced. The folders all have the same structure with files of the following types:\n", + "```\n", + "nwp_YYYYMMDDHH_mbrXXX.npy\n", + "wtr_YYYYMMDDHH.npy\n", + "nwp_toa_downwelling_shortwave_flux_YYYYMMDDHH.npy\n", + "```\n", + "with ```YYYYMMDDHH``` being some date with hours, and ```XXX``` being some 3-digit integer.\n", + "\n", + "The first type of file has x and y in dimensions 1 and 2, and variable index in dimension 3. Dimension 0 is unchanged.\n", + "The second type has x and y in dimensions 1 and 2. Dimension 0 is unchanged.\n", + "The last type has just x and y as the only 2 dimensions.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(65, 268, 238, 18)\n", + "(65, 268, 238)\n" + ] + } + ], + "source": [ + "print(np.load('data/meps_example/samples/train/nwp_2022040100_mbr000.npy').shape)\n", + "print(np.load('data/meps_example/samples/train/nwp_toa_downwelling_shortwave_flux_2022040112.npy').shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following loop goes through each file in each sample folder and indexes it according to the dimensions given by the file name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for sample in ['train', 'test', 'val']:\n", + " files = os.listdir(f'data/meps_example/samples/{sample}')\n", + "\n", + " for f in files:\n", + " data = np.load(f'data/meps_example/samples/{sample}/{f}')\n", + " if 'mbr' in f:\n", + " data = data[:,x_slice,y_slice,:num_vars]\n", + " elif 'wtr' in f:\n", + " data = data[x_slice, y_slice]\n", + " else:\n", + " data = data[:,x_slice,y_slice]\n", + " np.save(f'data/meps_example_reduced/samples/{sample}/{f}', data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, the file ```data_config.yaml``` is modified manually by truncating the variable units, long and short names, and setting the new grid shape. Also the unit descriptions containing ```^``` were automatically parsed using latex, and to avoid having to install latex in the GitHub CI/CD pipeline, this was changed to ```**```. \n", + "\n", + "This new config file was placed in ```data/meps_example_reduced```, and that directory was then zipped and placed in a European Weather Cloud S3 bucket."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 836b04ed..59a529eb 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,5 +1,6 @@ # Standard library import os +import shutil # Third-party import numpy as np @@ -250,7 +251,11 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=True, family="serif") + # If latex is not available, some visualizations might not render correctly, + # but will at least not raise an error. + # Alternatively, use unicode raised numbers. + usetex = True if shutil.which("latex") else False + bundle = bundles.neurips2023(usetex=usetex, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( diff --git a/requirements.txt b/requirements.txt index f381d54f..9309eea4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ plotly>=5.15.0 # for dev pre-commit>=2.15.0 +pytest>=8.1.1 +pooch>=1.8.1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py new file mode 100644 index 00000000..f91170c9 --- /dev/null +++ b/tests/test_mllam_dataset.py @@ -0,0 +1,138 @@ +# Standard library +import os + +# Third-party +import pooch + +# First-party +from create_mesh import main as create_mesh +from neural_lam.config import Config +from neural_lam.utils import load_static_data +from neural_lam.weather_dataset import WeatherDataset +from train_model import main as train_model + +# Disable weights and biases to avoid unnecessary logging +# and to avoid having to deal with authentication +os.environ["WANDB_DISABLED"] = "true" + +# Initializing variables for the s3 client +S3_BUCKET_NAME = "mllam-testdata" +S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int" +S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip" +S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH]) +TEST_DATA_KNOWN_HASH = ( + "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7" +) + + +def test_retrieve_data_ewc(): + # Download and unzip test data into data/meps_example_reduced + pooch.retrieve( + url=S3_FULL_PATH, + known_hash=TEST_DATA_KNOWN_HASH, + processor=pooch.Unzip(extract_dir=""), + path="data", + fname="meps_example_reduced.zip", + ) + + +def test_load_reduced_meps_dataset(): + # The data_config.yaml file is downloaded and extracted in + # test_retrieve_data_ewc together with the dataset itself + data_config_file = "data/meps_example_reduced/data_config.yaml" + dataset_name = "meps_example_reduced" + + dataset = WeatherDataset(dataset_name="meps_example_reduced") + config = Config.from_file(data_config_file) + + var_names = config.values["dataset"]["var_names"] + var_units = config.values["dataset"]["var_units"] + var_longnames = config.values["dataset"]["var_longnames"] + + assert len(var_names) == len(var_longnames) + assert len(var_names) == len(var_units) + + # in future the number of grid static features 
+ # will be provided by the Dataset class itself + n_grid_static_features = 4 + # Hardcoded in model + n_input_steps = 2 + + n_forcing_features = config.values["dataset"]["num_forcing_features"] + n_state_features = len(var_names) + n_prediction_timesteps = dataset.sample_length - n_input_steps + + nx, ny = config.values["grid_shape_state"] + n_grid = nx * ny + + # check that the dataset is not empty + assert len(dataset) > 0 + + # get the first item + init_states, target_states, forcing = dataset[0] + + # check that the shapes of the tensors are correct + assert init_states.shape == (n_input_steps, n_grid, n_state_features) + assert target_states.shape == ( + n_prediction_timesteps, + n_grid, + n_state_features, + ) + assert forcing.shape == ( + n_prediction_timesteps, + n_grid, + n_forcing_features, + ) + + static_data = load_static_data(dataset_name=dataset_name) + + required_props = { + "border_mask", + "grid_static_features", + "step_diff_mean", + "step_diff_std", + "data_mean", + "data_std", + "param_weights", + } + + # check the sizes of the props + assert static_data["border_mask"].shape == (n_grid, 1) + assert static_data["grid_static_features"].shape == ( + n_grid, + n_grid_static_features, + ) + assert static_data["step_diff_mean"].shape == (n_state_features,) + assert static_data["step_diff_std"].shape == (n_state_features,) + assert static_data["data_mean"].shape == (n_state_features,) + assert static_data["data_std"].shape == (n_state_features,) + assert static_data["param_weights"].shape == (n_state_features,) + + assert set(static_data.keys()) == required_props + + +def test_create_graph_reduced_meps_dataset(): + args = [ + "--graph=hierarchical", + "--hierarchical=1", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--levels=2", + ] + create_mesh(args) + + +def test_train_model_reduced_meps_dataset(): + args = [ + "--model=hi_lam", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--n_workers=4", + "--epochs=1", + "--graph=hierarchical", + "--hidden_dim=16", + "--hidden_layers=1", + "--processor_layers=1", + "--ar_steps=1", + "--eval=val", + "--n_example_pred=0", + ] + train_model(args) diff --git a/train_model.py b/train_model.py index cbd787f0..03863275 100644 --- a/train_model.py +++ b/train_model.py @@ -23,7 +23,7 @@ } -def main(): +def main(input_args=None): """ Main function for training and evaluating models """ @@ -208,11 +208,10 @@ def main(): help="""JSON string with variable-IDs and lead times to log watched metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", ) - args = parser.parse_args() + args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() } - config_loader = config.Config.from_file(args.data_config) # Asserts for arguments
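A minimal sketch of the round trip for the `--var_leads_metrics_watch` flag introduced in PATCH 2/5 and shown in context above: the flag is received as a JSON string and its keys are cast to integer variable indices after parsing. The option values below are the illustrative ones from the help text, not a recommendation:

```python
# Example invocation (illustrative values only):
#   python train_model.py --metrics_watch val_rmse \
#       --var_leads_metrics_watch '{"1": [1, 2], "3": [3, 4]}'
import json

raw = '{"1": [1, 2], "3": [3, 4]}'  # value as received from argparse (type=str)
# Same conversion as in train_model.py: JSON object keys are strings,
# so they are cast to int after json.loads
var_leads_metrics_watch = {int(k): v for k, v in json.loads(raw).items()}
assert var_leads_metrics_watch == {1: [1, 2], 3: [3, 4]}
```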
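For reference, the gather-and-reduce pattern that PATCH 4/5 uses in `create_parameter_weights.py`, in isolation: each rank averages its shard of batches, the per-rank results are exchanged with `dist.all_gather_object`, and rank 0 combines them into a mean and standard deviation. This is a simplified, self-contained sketch that assumes the process group is already initialized (as in `setup()` above), that the gathered tensors live on CPU, and that every rank holds the same number of samples (which `PaddedWeatherDataset` ensures); it omits the reordering by original indices and the flux statistics handled in the full script:

```python
import torch
import torch.distributed as dist


def reduce_stats(local_means, local_squares, world_size, rank):
    """Combine per-batch means and squared means from all ranks into mean/std."""
    # Stack this rank's per-batch statistics: (N_local_batches, d_features)
    means = torch.cat(local_means, dim=0)
    squares = torch.cat(local_squares, dim=0)

    # Collective call: every rank contributes its tensors and receives all others
    gathered_means = [None] * world_size
    gathered_squares = [None] * world_size
    dist.all_gather_object(gathered_means, means)
    dist.all_gather_object(gathered_squares, squares)

    if rank != 0:
        return None, None  # only rank 0 aggregates and saves the statistics

    all_means = torch.cat(gathered_means, dim=0)
    all_squares = torch.cat(gathered_squares, dim=0)
    mean = torch.mean(all_means, dim=0)  # (d_features,)
    second_moment = torch.mean(all_squares, dim=0)  # (d_features,)
    std = torch.sqrt(second_moment - mean**2)  # (d_features,)
    return mean, std
```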