From 82a36a1fb2f5a3bab0fe5f84b9edede7e3a639cb Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Fri, 31 May 2024 08:27:28 +0200
Subject: [PATCH] cleanup, modularize, linting, list-comprehensions

---
 create_parameter_weights.py | 208 ++++++++++++++++++------------------
 1 file changed, 105 insertions(+), 103 deletions(-)
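
Notes (ignored by git am): save_stats derives the standard deviation from
accumulated per-batch statistics via std = sqrt(E[x^2] - E[x]^2). A minimal
standalone sketch of that identity, with made-up shapes and equal batch sizes
(equal sizes are exactly what the dataset padding below guarantees):

    import torch

    # Hypothetical stand-in data: 8 equal batches of 4 samples, 10 features.
    data = torch.randn(8, 4, 10)
    means = [torch.mean(b, dim=0) for b in data]       # per-batch E[x]
    squares = [torch.mean(b**2, dim=0) for b in data]  # per-batch E[x^2]

    # The same reduction save_stats performs on the stacked lists.
    mean = torch.mean(torch.stack(means), dim=0)
    second_moment = torch.mean(torch.stack(squares), dim=0)
    std = torch.sqrt(second_moment - mean**2)

    # Matches the population std over all samples pooled together.
    flat = data.reshape(-1, 10)
    assert torch.allclose(std, torch.std(flat, dim=0, unbiased=False), atol=1e-5)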
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index e16ea707..c299208a 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -36,9 +36,11 @@ def __init__(
         )
 
     def __getitem__(self, idx):
-        if idx >= self.total_samples:
-            return self.base_dataset[self.original_indices[-1]]
-        return self.base_dataset[idx % len(self.base_dataset)]
+        return self.base_dataset[
+            self.original_indices[-1]
+            if idx >= self.total_samples
+            else idx % len(self.base_dataset)
+        ]
 
     def __len__(self):
         return self.total_samples + self.padded_samples
@@ -48,17 +50,11 @@ def get_original_indices(self):
 
 
 def get_rank():
-    """Get the rank of the current process in the distributed group."""
-    if "SLURM_PROCID" in os.environ:
-        return int(os.environ["SLURM_PROCID"])
-    return 0
+    return int(os.environ.get("SLURM_PROCID", 0))
 
 
 def get_world_size():
-    """Get the number of processes in the distributed group."""
-    if "SLURM_NTASKS" in os.environ:
-        return int(os.environ["SLURM_NTASKS"])
-    return 1
+    return int(os.environ.get("SLURM_NTASKS", 1))
 
 
 def setup(rank, world_size):  # pylint: disable=redefined-outer-name
@@ -73,22 +69,57 @@ def setup(rank, world_size):  # pylint: disable=redefined-outer-name
             .decode("utf-8")
         )
     else:
+        print(
+            "\033[91mCareful, you are running this script with --parallelize "
+            "without any scheduler. In most cases this will result in slower "
+            "execution and the --parallelize flag should be removed.\033[0m"
+        )
         master_node = "localhost"
-        master_port = "12355"
     os.environ["MASTER_ADDR"] = master_node
-    os.environ["MASTER_PORT"] = master_port
-    if torch.cuda.is_available():
-        dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    else:
-        dist.init_process_group("gloo", rank=rank, world_size=world_size)
+    os.environ["MASTER_PORT"] = "12355"
+    dist.init_process_group(
+        "nccl" if torch.cuda.is_available() else "gloo",
+        rank=rank,
+        world_size=world_size,
+    )
     print(
-        f"Initialized {dist.get_backend()} process group with "
-        f"world size "
-        f"{world_size}."
+        f"Initialized {dist.get_backend()} process group with world size {world_size}."
     )
 
 
-def main():  # pylint: disable=redefined-outer-name
+def save_stats(
+    static_dir_path, means, squares, flux_means, flux_squares, filename_prefix
+):
+    means = torch.stack(means) if len(means) > 1 else means[0]
+    squares = torch.stack(squares) if len(squares) > 1 else squares[0]
+    mean = torch.mean(means, dim=0)
+    second_moment = torch.mean(squares, dim=0)
+    std = torch.sqrt(second_moment - mean**2)
+    torch.save(
+        mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt")
+    )
+    torch.save(
+        std.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_std.pt")
+    )
+
+    if len(flux_means) == 0:
+        return
+    flux_means = (
+        torch.stack(flux_means) if len(flux_means) > 1 else flux_means[0]
+    )
+    flux_squares = (
+        torch.stack(flux_squares) if len(flux_squares) > 1 else flux_squares[0]
+    )
+    flux_mean = torch.mean(flux_means)
+    flux_second_moment = torch.mean(flux_squares)
+    flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
+    torch.save(
+        torch.stack((flux_mean, flux_std)).cpu(),
+        os.path.join(static_dir_path, f"{filename_prefix}_flux_stats.pt"),
+    )
+
+
+def main():
     parser = ArgumentParser(description="Training arguments")
     parser.add_argument(
         "--data_config",
@@ -129,16 +160,15 @@ def main():  # pylint: disable=redefined-outer-name
 
     rank = get_rank()
    world_size = get_world_size()
-
     config_loader = config.Config.from_file(args.data_config)
 
     if args.parallelize:
+        setup(rank, world_size)
 
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            torch.cuda.set_device(device)
-        else:
-            device = torch.device("cpu")
+    device = torch.device(
+        f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
+    )
+    torch.cuda.set_device(device) if torch.cuda.is_available() else None
 
     if rank == 0:
         static_dir_path = os.path.join(
@@ -171,14 +201,13 @@ def main():  # pylint: disable=redefined-outer-name
         pred_length=63,
         standardize=False,
     )
+    ds = PaddedWeatherDataset(
+        ds,
+        world_size,
+        args.batch_size,
+        duplication_factor=args.duplication_factor,
+    )
     if args.parallelize:
-        ds = PaddedWeatherDataset(
-            ds,
-            world_size,
-            args.batch_size,
-            duplication_factor=args.duplication_factor,
-        )
-
         sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank)
     else:
         sampler = None
@@ -190,14 +219,9 @@ def main():  # pylint: disable=redefined-outer-name
         sampler=sampler,
     )
 
-    # Compute mean and std.-dev. of each parameter (+ flux forcing) across
-    # full dataset
     if rank == 0:
         print("Computing mean and std.-dev. for parameters...")
-    means = []
-    squares = []
-    flux_means = []
-    flux_squares = []
+    means, squares, flux_means, flux_squares = [], [], [], []
 
     for init_batch, target_batch, forcing_batch in tqdm(loader):
         if args.parallelize:
@@ -214,45 +238,32 @@ def main():  # pylint: disable=redefined-outer-name
             flux_squares.append(torch.mean(flux_batch**2).cpu())
 
     if args.parallelize:
-        means_gathered = [None] * world_size
-        squares_gathered = [None] * world_size
+        means_gathered, squares_gathered = [None] * world_size, [
+            None
+        ] * world_size
         dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
         dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
 
         if rank == 0:
-            means_all = torch.cat(means_gathered, dim=0)
-            squares_all = torch.cat(squares_gathered, dim=0)
-            original_indices = ds.get_original_indices()
-            means = [means_all[i] for i in original_indices]
-            squares = [squares_all[i] for i in original_indices]
+            means_gathered, squares_gathered = torch.cat(
+                means_gathered, dim=0
+            ), torch.cat(squares_gathered, dim=0)
+            means, squares = [
+                means_gathered[i] for i in ds.get_original_indices()
+            ], [squares_gathered[i] for i in ds.get_original_indices()]
+
     if rank == 0:
-        if len(means) > 1:
-            means = torch.stack(means)
-            squares = torch.stack(squares)
-        else:
-            means = means[0]
-            squares = squares[0]
-        mean = torch.mean(means, dim=0)
-        second_moment = torch.mean(squares, dim=0)
-        std = torch.sqrt(second_moment - mean**2)
-        torch.save(
-            mean.cpu(), os.path.join(static_dir_path, "parameter_mean.pt")
-        )
-        torch.save(std.cpu(), os.path.join(static_dir_path, "parameter_std.pt"))
-        if len(flux_means) > 1:
-            flux_means_all = torch.stack(flux_means)
-            flux_squares_all = torch.stack(flux_squares)
-        else:
-            flux_means_all = flux_means[0]
-            flux_squares_all = flux_squares[0]
-        flux_mean = torch.mean(flux_means_all)
-        flux_second_moment = torch.mean(flux_squares_all)
-        flux_std = torch.sqrt(flux_second_moment - flux_mean**2)
-        torch.save(
-            torch.stack((flux_mean, flux_std)).cpu(),
-            os.path.join(static_dir_path, "flux_stats.pt"),
+        save_stats(
+            static_dir_path,
+            means,
+            squares,
+            flux_means,
+            flux_squares,
+            "parameter",
         )
+
     if args.parallelize:
         dist.barrier()
+
     if rank == 0:
         print("Computing mean and std.-dev. for one-step differences...")
 
     ds_standard = WeatherDataset(
@@ -262,14 +273,13 @@ def main():  # pylint: disable=redefined-outer-name
         pred_length=63,
         standardize=True,
     )
+    ds_standard = PaddedWeatherDataset(
+        ds_standard,
+        world_size,
+        args.batch_size,
+        duplication_factor=args.duplication_factor,
+    )
     if args.parallelize:
-        ds_standard = PaddedWeatherDataset(
-            ds_standard,
-            world_size,
-            args.batch_size,
-            duplication_factor=args.duplication_factor,
-        )
-
         sampler_standard = DistributedSampler(
             ds_standard, num_replicas=world_size, rank=rank
         )
@@ -284,8 +294,7 @@ def main():  # pylint: disable=redefined-outer-name
     )
 
     used_subsample_len = (65 // args.step_length) * args.step_length
-    diff_means = []
-    diff_squares = []
+    diff_means, diff_squares = [], []
 
     for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
         if args.parallelize:
@@ -301,41 +310,34 @@ def main():  # pylint: disable=redefined-outer-name
             dim=0,
         )
         batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]
-
         diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu())
         diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu())
 
     if args.parallelize:
         dist.barrier()
-
-        diff_means_gathered = [None] * world_size
-        diff_squares_gathered = [None] * world_size
+        diff_means_gathered, diff_squares_gathered = [None] * world_size, [
+            None
+        ] * world_size
         dist.all_gather_object(
             diff_means_gathered, torch.cat(diff_means, dim=0)
         )
         dist.all_gather_object(
             diff_squares_gathered, torch.cat(diff_squares, dim=0)
         )
-        if rank == 0:
-            diff_means_all = torch.cat(diff_means_gathered, dim=0)
-            diff_squares_all = torch.cat(diff_squares_gathered, dim=0)
-            original_indices = ds_standard.get_original_indices()
-            diff_means = [diff_means_all[i] for i in original_indices]
-            diff_squares = [diff_squares_all[i] for i in original_indices]
+        diff_means_gathered, diff_squares_gathered = torch.cat(
+            diff_means_gathered, dim=0
+        ), torch.cat(diff_squares_gathered, dim=0)
+        diff_means, diff_squares = [
+            diff_means_gathered[i]
+            for i in ds_standard.get_original_indices()
+        ], [
+            diff_squares_gathered[i]
+            for i in ds_standard.get_original_indices()
+        ]
+
     if rank == 0:
-        if len(diff_means) > 1:
-            diff_means = torch.stack(diff_means)
-            diff_squares = torch.stack(diff_squares)
-        else:
-            diff_means = diff_means[0]
-            diff_squares = diff_squares[0]
-        diff_mean = torch.mean(diff_means, dim=0)
-        diff_second_moment = torch.mean(diff_squares, dim=0)
-        diff_std = torch.sqrt(diff_second_moment - diff_mean**2)
-
-        torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
-        torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
+        save_stats(static_dir_path, diff_means, diff_squares, [], [], "diff")
 
     if args.parallelize:
         dist.destroy_process_group()
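
Note (illustration, not part of the patch): the index wrap-around in
PaddedWeatherDataset.__getitem__ is what keeps DistributedSampler shards
equal-sized, so no rank blocks in the all_gather_object calls. A minimal
stand-in showing the arithmetic; the class, the sizes, and the simplified
wrap rule here are hypothetical:

    from torch.utils.data import Dataset, DistributedSampler

    class TinyPaddedDataset(Dataset):
        """Hypothetical stand-in for PaddedWeatherDataset: pad the length
        up to a multiple of world_size * batch_size and wrap padded indices
        back onto real samples."""

        def __init__(self, base_len, world_size, batch_size):
            self.base_len = base_len
            self.padded_samples = (-base_len) % (world_size * batch_size)

        def __len__(self):
            return self.base_len + self.padded_samples

        def __getitem__(self, idx):
            return idx % self.base_len  # padded tail repeats early samples

    ds = TinyPaddedDataset(base_len=10, world_size=4, batch_size=2)
    assert len(ds) == 16  # padded from 10 up to the next multiple of 4 * 2

    # Each rank draws the same number of indices from the padded dataset.
    shards = [
        list(DistributedSampler(ds, num_replicas=4, rank=r, shuffle=False))
        for r in range(4)
    ]
    assert [len(s) for s in shards] == [4, 4, 4, 4]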