diff --git a/passl/core/sync_utils.py b/passl/core/sync_utils.py
index f7d6ba64..b52a01e1 100644
--- a/passl/core/sync_utils.py
+++ b/passl/core/sync_utils.py
@@ -29,8 +29,6 @@ def grad_sync(param_groups, comm_group=None, grad_avg=True):
 
     for group in param_groups:
         for p in group['params']:
-            if p.is_distributed:
-                continue
 
             grad = p.grad
             if grad is None:
@@ -57,9 +55,6 @@ def param_sync(model, src_rank=0, comm_group=None):
 
     for _, param in model._obtain_parameters_buffers().items():
 
-        if hasattr(param, 'is_distributed') and param.is_distributed:
-            continue
-
         if getattr(param, "no_sync", False):
             continue
 
diff --git a/passl/data/__init__.py b/passl/data/__init__.py
index 049606f6..bd383b7e 100644
--- a/passl/data/__init__.py
+++ b/passl/data/__init__.py
@@ -23,7 +23,7 @@
 def build_dataloader(config,
                      mode,
                      device,
                      use_dali=False,
-                     worker_init_fn=None):
+                     worker_init_fn=None, hybrid_parallel=False):
     assert mode in ['Train', 'Eval', 'Test'
                     ], "Dataset mode should be Train, Eval, Test"
@@ -57,6 +57,15 @@
     config_sampler = config[mode]['sampler']
     config_sampler = copy.deepcopy(config_sampler)
     sampler_name = config_sampler.pop("name")
+
+    if hybrid_parallel:
+        from paddle.distributed import fleet
+        hcg = fleet.get_hybrid_communicate_group()
+        data_ranks = hcg.get_data_sharding_parallel_world_size()
+        data_rank = hcg.get_data_sharding_parallel_world_rank()
+        print(data_ranks, data_rank)
+        config_sampler.update({'num_replicas': data_ranks, 'rank': data_rank})
+
     batch_sampler = eval("sampler.{}".format(sampler_name))(dataset,
                                                             **config_sampler)
     logger.debug("build batch_sampler({}) success...".format(batch_sampler))
diff --git a/passl/distributed/env.py b/passl/distributed/env.py
index 9452123b..d71b41e9 100644
--- a/passl/distributed/env.py
+++ b/passl/distributed/env.py
@@ -259,6 +259,8 @@ def get_data_sharding_parallel_group(self):
         return self._dp_sharding_comm_group
 
     def get_data_sharding_parallel_world_rank(self):
+        if self._dp_sharding_comm_group.nranks == 1:
+            return 0
         return self._dp_sharding_comm_group.rank
 
     def get_data_sharding_parallel_world_size(self):
@@ -303,20 +305,26 @@ def get_model_parallel_ring_group(self):
     hcg.get_model_parallel_ring_group = types.MethodType(get_model_parallel_ring_group, hcg)
 
 
-def init_dist_env(seed, mp_degree=1, pp_degree=1, sharding_degree=1):
+def init_dist_env(seed, hybrid_configs={}):
     """ init distributed env
     """
+
+    mp_degree = hybrid_configs.get('mp_degree', 1)
+    pp_degree = hybrid_configs.get('pp_degree', 1)
+    sharding_degree = hybrid_configs.get('sharding_degree', 1)
+
     strategy = fleet.DistributedStrategy()
     other_degree = mp_degree * pp_degree * sharding_degree
     assert dist.get_world_size() % other_degree == 0
     dp_degree = dist.get_world_size() // other_degree
-    strategy.hybrid_configs = {
-        "dp_degree": dp_degree,
-        "mp_degree": mp_degree,
-        "pp_degree": pp_degree,
-        "sharding_degree": sharding_degree
-    }
+
+    if 'dp_degree' in hybrid_configs:
+        assert hybrid_configs['dp_degree'] == dp_degree
+    else:
+        hybrid_configs['dp_degree'] = dp_degree
+
+    strategy.hybrid_configs = hybrid_configs
     strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
 
     # init Fleet env
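The reworked init_dist_env above now takes the whole hybrid_configs dict and fills in dp_degree itself. A minimal sketch of that derivation, in plain Python so it can be checked without launching a job (the helper name derive_dp_degree is made up for illustration and is not part of the patch):

# Sketch of the dp_degree derivation done inside init_dist_env: degrees that
# are missing from hybrid_configs default to 1, and dp_degree is whatever is
# left after dividing the world size by mp * pp * sharding.
def derive_dp_degree(world_size, hybrid_configs):
    mp_degree = hybrid_configs.get('mp_degree', 1)
    pp_degree = hybrid_configs.get('pp_degree', 1)
    sharding_degree = hybrid_configs.get('sharding_degree', 1)
    other_degree = mp_degree * pp_degree * sharding_degree
    assert world_size % other_degree == 0
    return world_size // other_degree

# 8 GPUs with mp_degree=2 (the dp4mp2 config added below) -> dp_degree == 4
assert derive_dp_degree(8, {'mp_degree': 2}) == 4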
diff --git a/passl/engine/engine.py b/passl/engine/engine.py
index 5302efff..2dcfd28e 100644
--- a/passl/engine/engine.py
+++ b/passl/engine/engine.py
@@ -40,6 +40,7 @@
 from passl.core import recompute_warp, GradScaler, param_sync
 from passl.models.utils import EMA
 from passl.utils.infohub import runtime_info_hub
+from passl.distributed import env as dist_env
 
 from . import loops
 
@@ -71,22 +72,40 @@ def __init__(self, config, mode="train"):
         self.config["Global"]["distributed"] = dist.get_world_size() != 1
         self.config["Global"]["rank"] = dist.get_rank()
         self.config["Global"]["world_size"] = dist.get_world_size()
 
-        if self.config["Global"]["distributed"]:
-            dist.init_parallel_env()
-        # set seed
-        seed = self.config["Global"].get("seed", False)
-        if seed:
-            assert isinstance(seed, int), "The 'seed' must be a integer!"
-            seed += self.config["Global"]["rank"]
-            paddle.seed(seed)
-            np.random.seed(seed)
-            random.seed(seed)
+        if self.config["Global"]["distributed"]:
+            assert self.config.get("DistributedStrategy", None) is not None
+            hybrid_configs = self.config["DistributedStrategy"].get("hybrid_configs", {})
+            if len(hybrid_configs) > 0:
+                self.hybrid_parallel = True
+                seed = self.config["Global"].get("seed", 42)
+                dist_env.init_dist_env(seed=seed, hybrid_configs=hybrid_configs)
+            else:
+                self.hybrid_parallel = False
+                dist.fleet.init(is_collective=True)
+
+        if self.hybrid_parallel:
+            seed = dist_env.get_dp_seed()
 
             def worker_init_fn(worker_id):
                 """ set seed in subproces for dataloader when num_workers > 0"""
                 np.random.seed(seed + worker_id)
                 random.seed(seed + worker_id)
+        else:
+            # backward compatibility
+            # set seed
+            seed = self.config["Global"].get("seed", False)
+            if seed:
+                assert isinstance(seed, int), "The 'seed' must be a integer!"
+                seed += self.config["Global"]["rank"]
+                paddle.seed(seed)
+                np.random.seed(seed)
+                random.seed(seed)
+
+            def worker_init_fn(worker_id):
+                """ set seed in subproces for dataloader when num_workers > 0"""
+                np.random.seed(seed + worker_id)
+                random.seed(seed + worker_id)
 
         RELATED_FLAGS_SETTING = {}
         RELATED_FLAGS_SETTING['FLAGS_cudnn_exhaustive_search'] = 1
@@ -131,12 +150,12 @@ def worker_init_fn(worker_id):
         if self.mode == 'train':
             self.train_dataloader = build_dataloader(
                 self.config["DataLoader"], "Train", self.device, self.use_dali,
-                worker_init_fn)
+                worker_init_fn, self.hybrid_parallel)
         if self.mode == "eval" or (self.mode == "train" and
                                    self.config["Global"]["eval_during_train"]):
             self.eval_dataloader = build_dataloader(
                 self.config["DataLoader"], "Eval", self.device, self.use_dali,
-                worker_init_fn)
+                worker_init_fn, self.hybrid_parallel)
 
         # build loss
         self.train_loss_func = None
@@ -253,26 +272,12 @@ def worker_init_fn(worker_id):
             recompute_warp(
                 self.model,
                 **self.config["DistributedStrategy"]['recompute'])
-        if self.config["DistributedStrategy"].get("data_sharding", False):
-            assert 'data_parallel' not in self.config[
-                "DistributedStrategy"], "data_parallel cannot be set when using data_sharding"
-            # from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
-            # from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
-            # from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
-
-            # # Note(GuoxiaWang): Only support global data parallel now!
-            # # First, we need to split optimizer
-            # self.optimizer = ShardingOptimizerStage2(
-            #     params=self.model.parameters(), optim=self.optimizer)
-
-            # # Second, warpper the origin model to have gradient sharding function
-            # self.model = ShardingStage2(
-            #     self.model,
-            #     self.optimizer,
-            #     accumulate_grads=self.accum_steps > 1,
-            #     device=self.config["Global"]["device"], )
-            # self.scaler = ShardingScaler(self.scaler)
-            assert False, "Do not support data_sharding now!"
+
+        if self.hybrid_parallel:
+            hcg = dist_env.get_hcg()
+            if hcg.get_model_parallel_world_size() > 1:
+                from paddle.distributed.fleet.meta_parallel import TensorParallel
+                self.model = TensorParallel(self.model, hcg, strategy=None)
         else:
             # we always use pure data parallel default
             assert 'data_parallel' in self.config["DistributedStrategy"] and \
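When hybrid_configs enables model parallelism, Engine.__init__ wraps the network with Fleet's TensorParallel. A rough usage sketch of that step (illustrative only, assuming init_dist_env has already been called inside a launched job; paddle.nn.Linear stands in for the real ViT model):

# Sketch, not part of the patch: tensor-parallel wrapping as done in the
# engine once the hybrid communicate group is available.
import paddle
from paddle.distributed.fleet.meta_parallel import TensorParallel
from passl.distributed import env as dist_env

model = paddle.nn.Linear(8, 8)  # stand-in for the real model
hcg = dist_env.get_hcg()
if hcg.get_model_parallel_world_size() > 1:
    model = TensorParallel(model, hcg, strategy=None)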
diff --git a/passl/engine/loops/classification_loop.py b/passl/engine/loops/classification_loop.py
index 92ce83f0..7ce90c91 100644
--- a/passl/engine/loops/classification_loop.py
+++ b/passl/engine/loops/classification_loop.py
@@ -23,6 +23,7 @@
 import collections
 import platform
 import paddle
+import paddle.distributed as dist
 
 from passl.core import grad_sync, param_sync
 from passl.utils import io
@@ -82,7 +83,8 @@ def train_one_step(self, batch):
         # do forward and backward
         out, loss_dict = self.forward_backward(batch)
 
-        grad_sync(self.trainer.optimizer.param_groups)
+        comm_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        grad_sync(self.trainer.optimizer.param_groups, comm_group=comm_group)
 
         # do unscale and step if using fp16 and not found nan/inf
         # otherwise do nothing
@@ -186,14 +188,16 @@ def eval_one_dataset(self, eval_dataloader):
                     output_info[key].update(float(loss_dict[key]), batch_size)
 
            # just for DistributedBatchSampler issue: repeat sampling
-            current_samples = batch_size * paddle.distributed.get_world_size()
+            data_ranks = dist.fleet.get_hybrid_communicate_group().get_data_sharding_parallel_world_size()
+            current_samples = batch_size * data_ranks
             accum_samples += current_samples
 
             # calc metric
             if self.trainer.eval_metric_func is not None:
-                if paddle.distributed.get_world_size() > 1:
+                if data_ranks > 1:
+                    group = dist.fleet.get_hybrid_communicate_group().get_data_sharding_parallel_group()
                     label_list = []
-                    paddle.distributed.all_gather(label_list, batch[1])
+                    dist.all_gather(label_list, batch[1], group=group)
                     labels = paddle.concat(label_list, 0)
 
                     if isinstance(out, dict):
@@ -202,12 +206,12 @@
                         pred = []
                         for x in out:
                             pred_list = []
-                            paddle.distributed.all_gather(pred_list, x)
+                            dist.all_gather(pred_list, x, group=group)
                             pred_x = paddle.concat(pred_list, 0)
                             pred.append(pred_x)
                     else:
                         pred_list = []
-                        paddle.distributed.all_gather(pred_list, out)
+                        dist.all_gather(pred_list, out, group=group)
                         pred = paddle.concat(pred_list, 0)
 
                     if accum_samples > total_samples and not self.trainer.use_dali:
diff --git a/passl/engine/loops/contrastive_learning_loop.py b/passl/engine/loops/contrastive_learning_loop.py
index 428d4853..26c18a6c 100644
--- a/passl/engine/loops/contrastive_learning_loop.py
+++ b/passl/engine/loops/contrastive_learning_loop.py
@@ -16,9 +16,11 @@
 from __future__ import division
 from __future__ import print_function
 
-import paddle
 import collections
 
+import paddle
+import paddle.distributed as dist
+
 from passl.core import grad_sync
 from passl.utils import logger
 from .loop import TrainingEpochLoop
@@ -71,7 +73,8 @@ def train_one_step(self, batch):
         # do forward and backward
         loss_dict = self.forward_backward(batch)
 
-        grad_sync(self.trainer.optimizer.param_groups)
+        comm_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        grad_sync(self.trainer.optimizer.param_groups, comm_group=comm_group)
 
         # do unscale and step if using fp16 and not found nan/inf
         # otherwise do nothing
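Both loops now pass the data-parallel group into grad_sync, and evaluation gathers over the data(+sharding) group instead of the global group. A hedged sketch of that gather pattern (not part of the patch; assumes a hybrid fleet environment is already up, and uses get_data_sharding_parallel_group, the helper that passl.distributed.env attaches to the hybrid communicate group):

# Sketch only: all_gather restricted to the data(-sharding) ranks, mirroring
# the eval changes above.
import paddle
import paddle.distributed as dist

hcg = dist.fleet.get_hybrid_communicate_group()
group = hcg.get_data_sharding_parallel_group()
labels = paddle.to_tensor([0, 1, 2])  # stand-in for batch[1]
label_list = []
dist.all_gather(label_list, labels, group=group)
labels = paddle.concat(label_list, 0)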
diff --git a/tasks/classification/vit/configs/ViT_hybird_base_patch16_224_in1k_1n8c_dp4mp2_fp16o2.yaml b/tasks/classification/vit/configs/ViT_hybird_base_patch16_224_in1k_1n8c_dp4mp2_fp16o2.yaml
new file mode 100644
index 00000000..5c25e0a7
--- /dev/null
+++ b/tasks/classification/vit/configs/ViT_hybird_base_patch16_224_in1k_1n8c_dp4mp2_fp16o2.yaml
@@ -0,0 +1,129 @@
+# global configs
+Global:
+  task_type: Classification
+  train_loop: ClassificationTrainingEpochLoop
+  validate_loop: ClassificationEvaluationLoop
+  checkpoint: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  max_num_latest_checkpoint: 0
+  eval_during_train: True
+  eval_interval: 1
+  eval_unit: "epoch"
+  accum_steps: 1
+  epochs: 300
+  print_batch_step: 10
+  use_visualdl: False
+  seed: 2021
+
+# FP16 setting
+FP16:
+  level: O2
+  GradScaler:
+    init_loss_scaling: 65536.0
+
+DistributedStrategy:
+  hybrid_configs:
+    mp_degree: 2
+
+# model architecture
+Model:
+  name: ViT_hybrid_base_patch16_224
+  class_num: 1000
+  drop_rate: 0.1
+
+# loss function config for traing/eval process
+Loss:
+  Train:
+    - ViTCELoss:
+        weight: 1.0
+        epsilon: 0.0001
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+LRScheduler:
+  name: ViTLRScheduler
+  learning_rate: 3e-3
+  decay_type: cosine
+  warmup_steps: 10000
+
+Optimizer:
+  name: AdamW
+  betas: (0.9, 0.999)
+  epsilon: 1e-8
+  weight_decay: 0.3
+  use_master_param: False
+  grad_clip:
+    name: ClipGradByGlobalNorm
+    clip_norm: 1.0
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageFolder
+      root: ./dataset/ILSVRC2012/train
+      transform:
+        - RandomResizedCrop:
+            size: 224
+            scale: [0.05, 1.0]
+            interpolation: bicubic
+        - RandomHorizontalFlip:
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+            order: ''
+        - ToCHWImage:
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 1024
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageFolder
+      root: ./dataset/ILSVRC2012/val
+      transform:
+        - Resize:
+            size: 256
+            interpolation: bicubic
+            backend: pil
+        - CenterCrop:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
+
+Export:
+  export_type: paddle
+  input_shape: [None, 3, 224, 224]
diff --git a/tasks/classification/vit/evaluate.py b/tasks/classification/vit/evaluate.py
index 1346c852..25e98409 100644
--- a/tasks/classification/vit/evaluate.py
+++ b/tasks/classification/vit/evaluate.py
@@ -44,7 +44,8 @@ def accuracy(output, target, topk=(1, )):
 pp_degree = 1
 sharding_degree = 1
 if mp_degree > 1:
-    dist_env.init_dist_env(seed=42, mp_degree=mp_degree, pp_degree=pp_degree, sharding_degree=sharding_degree)
+    hybrid_configs = {'mp_degree': mp_degree, 'pp_degree': pp_degree, 'sharding_degree': sharding_degree}
+    dist_env.init_dist_env(seed=42, hybrid_configs=hybrid_configs)
     from paddle.distributed.fleet.meta_parallel import TensorParallel
 
 model = vision_transformer_hybrid.ViT_hybrid_base_patch16_224()
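For reference, a minimal sketch of how a config like the YAML above feeds the new init_dist_env signature (illustrative, not part of the patch; the engine does this through its own config loader, and the call only succeeds inside a paddle.distributed.launch job):

# Sketch: DistributedStrategy.hybrid_configs from the YAML is handed straight
# to init_dist_env, which derives and inserts dp_degree.
import yaml
from passl.distributed import env as dist_env

with open("configs/ViT_hybird_base_patch16_224_in1k_1n8c_dp4mp2_fp16o2.yaml") as f:
    cfg = yaml.safe_load(f)

hybrid_configs = cfg.get("DistributedStrategy", {}).get("hybrid_configs", {})
dist_env.init_dist_env(seed=cfg["Global"].get("seed", 42),
                       hybrid_configs=hybrid_configs)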
diff --git a/tasks/classification/vit/train_vit_hybrid.sh b/tasks/classification/vit/train_vit_hybrid.sh
new file mode 100644
index 00000000..751085a4
--- /dev/null
+++ b/tasks/classification/vit/train_vit_hybrid.sh
@@ -0,0 +1,24 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+export PADDLE_NNODES=1
+export PADDLE_MASTER="127.0.0.1:12538"
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+python -m paddle.distributed.launch \
+    --nnodes=$PADDLE_NNODES \
+    --master=$PADDLE_MASTER \
+    --devices=$CUDA_VISIBLE_DEVICES \
+    passl-train \
+    -c ./configs/ViT_hybird_base_patch16_224_in1k_1n8c_dp4mp2_fp16o2.yaml