support dp + mp hybrid parallel training #167

Open · wants to merge 1 commit into base: main
passl/core/sync_utils.py (5 changes: 0 additions & 5 deletions)

@@ -29,8 +29,6 @@ def grad_sync(param_groups, comm_group=None, grad_avg=True):

     for group in param_groups:
         for p in group['params']:
-            if p.is_distributed:
-                continue

             grad = p.grad
             if grad is None:
@@ -57,9 +55,6 @@ def param_sync(model, src_rank=0, comm_group=None):

     for _, param in model._obtain_parameters_buffers().items():

-        if hasattr(param, 'is_distributed') and param.is_distributed:
-            continue
-
         if getattr(param, "no_sync", False):
             continue
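Note on this change: dropping the `is_distributed` early-exit means parameters partitioned by tensor parallelism are no longer skipped here; under hybrid parallel their gradients still have to be averaged across the data-parallel replicas, which is why the training loops below now pass an explicit data-parallel `comm_group` into `grad_sync`. A minimal sketch of what that synchronization amounts to in Paddle dygraph (an illustration, not the PR's exact implementation):

```python
# Minimal sketch: average every parameter gradient across a data-parallel group.
import paddle.distributed as dist

def sync_grads(parameters, comm_group, grad_avg=True):
    for p in parameters:
        if p.grad is None:
            continue
        dist.all_reduce(p.grad, group=comm_group)   # sum over dp replicas
        if grad_avg:
            p.grad.scale_(1.0 / comm_group.nranks)  # turn the sum into a mean
```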

passl/data/__init__.py (11 changes: 10 additions & 1 deletion)

@@ -23,7 +23,7 @@


 def build_dataloader(config, mode, device, use_dali=False,
-                     worker_init_fn=None):
+                     worker_init_fn=None, hybrid_parallel=False):
     assert mode in ['Train', 'Eval', 'Test'
                     ], "Dataset mode should be Train, Eval, Test"

@@ -57,6 +57,15 @@ def build_dataloader(config, mode, device, use_dali=False,
     config_sampler = config[mode]['sampler']
     config_sampler = copy.deepcopy(config_sampler)
     sampler_name = config_sampler.pop("name")
+
+    if hybrid_parallel:
+        from paddle.distributed import fleet
+        hcg = fleet.get_hybrid_communicate_group()
+        data_ranks = hcg.get_data_sharding_parallel_world_size()
+        data_rank = hcg.get_data_sharding_parallel_world_rank()
+        print(data_ranks, data_rank)
+        config_sampler.update({'num_replicas': data_ranks, 'rank': data_rank})
+
     batch_sampler = eval("sampler.{}".format(sampler_name))(dataset,
                                                             **config_sampler)
     logger.debug("build batch_sampler({}) success...".format(batch_sampler))
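Note: the point of feeding `num_replicas`/`rank` from the hybrid communicate group instead of the global world size is that only the data-parallel dimension shards the dataset; ranks that differ only in their model-parallel coordinate must read identical batches. A hedged sketch of the resulting sampler construction, assuming `fleet.init` has already been called with hybrid configs (the toy dataset and batch size below are illustrative, not this PR's code):

```python
import paddle
from paddle.distributed import fleet
from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 1024
    def __getitem__(self, idx):
        return paddle.zeros([3, 224, 224]), 0

hcg = fleet.get_hybrid_communicate_group()
batch_sampler = DistributedBatchSampler(
    ToyDataset(),
    batch_size=64,
    num_replicas=hcg.get_data_parallel_world_size(),  # dp replicas, not dist.get_world_size()
    rank=hcg.get_data_parallel_rank(),                # this process's dp coordinate
    shuffle=True)
```

The PR itself calls the patched `get_data_sharding_parallel_world_size`/`rank` helpers from passl/distributed/env.py; fleet's built-in `get_data_parallel_*` accessors used in the sketch should coincide with them when sharding_degree is 1.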
passl/distributed/env.py (22 changes: 15 additions & 7 deletions)

@@ -259,6 +259,8 @@ def get_data_sharding_parallel_group(self):
         return self._dp_sharding_comm_group

     def get_data_sharding_parallel_world_rank(self):
+        if self._dp_sharding_comm_group.nranks == 1:
+            return 0
         return self._dp_sharding_comm_group.rank

     def get_data_sharding_parallel_world_size(self):
@@ -303,20 +305,26 @@ def get_model_parallel_ring_group(self):
     hcg.get_model_parallel_ring_group = types.MethodType(get_model_parallel_ring_group, hcg)


-def init_dist_env(seed, mp_degree=1, pp_degree=1, sharding_degree=1):
+def init_dist_env(seed, hybrid_configs={}):
     """
     init distributed env
     """

+    mp_degree = hybrid_configs.get('mp_degree', 1)
+    pp_degree = hybrid_configs.get('pp_degree', 1)
+    sharding_degree = hybrid_configs.get('sharding_degree', 1)
+
     strategy = fleet.DistributedStrategy()
     other_degree = mp_degree * pp_degree * sharding_degree
     assert dist.get_world_size() % other_degree == 0
     dp_degree = dist.get_world_size() // other_degree
-    strategy.hybrid_configs = {
-        "dp_degree": dp_degree,
-        "mp_degree": mp_degree,
-        "pp_degree": pp_degree,
-        "sharding_degree": sharding_degree
-    }

+    if 'dp_degree' in hybrid_configs:
+        assert hybrid_configs['dp_degree'] == dp_degree
+    else:
+        hybrid_configs['dp_degree'] = dp_degree
+
+    strategy.hybrid_configs = hybrid_configs
     strategy.tensor_parallel_configs = {"tensor_init_seed": seed}

     # init Fleet env
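Note: `dp_degree` is now derived from the world size and the other degrees rather than passed in, so a config only has to state the non-data dimensions. A usage sketch, assuming a single node with 8 GPUs and the `mp_degree: 2` / `seed: 2021` values from the config added below:

```python
# Assumed example: 8 GPUs, mp_degree=2 -> dp_degree = 8 // (2 * 1 * 1) = 4.
from paddle.distributed import fleet
from passl.distributed import env as dist_env

dist_env.init_dist_env(seed=2021, hybrid_configs={"mp_degree": 2})

hcg = fleet.get_hybrid_communicate_group()
assert hcg.get_model_parallel_world_size() == 2
assert hcg.get_data_parallel_world_size() == 4   # derived, not configured
```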
passl/engine/engine.py (69 changes: 37 additions & 32 deletions)

@@ -40,6 +40,7 @@
 from passl.core import recompute_warp, GradScaler, param_sync
 from passl.models.utils import EMA
 from passl.utils.infohub import runtime_info_hub
+from passl.distributed import env as dist_env
 from . import loops


@@ -71,22 +72,40 @@ def __init__(self, config, mode="train"):
self.config["Global"]["distributed"] = dist.get_world_size() != 1
self.config["Global"]["rank"] = dist.get_rank()
self.config["Global"]["world_size"] = dist.get_world_size()
if self.config["Global"]["distributed"]:
dist.init_parallel_env()

# set seed
seed = self.config["Global"].get("seed", False)
if seed:
assert isinstance(seed, int), "The 'seed' must be a integer!"
seed += self.config["Global"]["rank"]
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)

if self.config["Global"]["distributed"]:
assert self.config.get("DistributedStrategy", None) is not None
hybrid_configs = self.config["DistributedStrategy"].get("hybrid_configs", {})
if len(hybrid_configs) > 0:
self.hybrid_parallel = True
seed = self.config["Global"].get("seed", 42)
dist_env.init_dist_env(seed=seed, hybrid_configs=hybrid_configs)
else:
self.hybrid_parallel = False
dist.fleet.init(is_collective=True)

if self.hybrid_parallel:
seed = dist_env.get_dp_seed()
def worker_init_fn(worker_id):
""" set seed in subproces for dataloader when num_workers > 0"""
np.random.seed(seed + worker_id)
random.seed(seed + worker_id)
else:
# backward compatibility
# set seed
seed = self.config["Global"].get("seed", False)
if seed:
assert isinstance(seed, int), "The 'seed' must be a integer!"
seed += self.config["Global"]["rank"]
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)

def worker_init_fn(worker_id):
""" set seed in subproces for dataloader when num_workers > 0"""
np.random.seed(seed + worker_id)
random.seed(seed + worker_id)

RELATED_FLAGS_SETTING = {}
RELATED_FLAGS_SETTING['FLAGS_cudnn_exhaustive_search'] = 1
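Note: in the hybrid path the dataloader workers are seeded from `dist_env.get_dp_seed()` instead of `seed + global_rank`, the intent being that ranks which differ only in their model-parallel coordinate draw the same data-side randomness while different data-parallel replicas diverge. A rough illustration of that intent (not the actual `get_dp_seed` implementation in passl/distributed/env.py):

```python
# Illustration only: one seed per data-parallel replica.
from paddle.distributed import fleet

def example_dp_seed(base_seed):
    hcg = fleet.get_hybrid_communicate_group()
    return base_seed + hcg.get_data_parallel_rank()
```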
@@ -131,12 +150,12 @@ def worker_init_fn(worker_id):
         if self.mode == 'train':
             self.train_dataloader = build_dataloader(
                 self.config["DataLoader"], "Train", self.device, self.use_dali,
-                worker_init_fn)
+                worker_init_fn, self.hybrid_parallel)
         if self.mode == "eval" or (self.mode == "train" and
                                    self.config["Global"]["eval_during_train"]):
             self.eval_dataloader = build_dataloader(
                 self.config["DataLoader"], "Eval", self.device, self.use_dali,
-                worker_init_fn)
+                worker_init_fn, self.hybrid_parallel)

         # build loss
         self.train_loss_func = None
@@ -253,26 +272,12 @@ def worker_init_fn(worker_id):
                 recompute_warp(
                     self.model,
                     **self.config["DistributedStrategy"]['recompute'])
-            if self.config["DistributedStrategy"].get("data_sharding", False):
-                assert 'data_parallel' not in self.config[
-                    "DistributedStrategy"], "data_parallel cannot be set when using data_sharding"
-                # from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
-                # from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
-                # from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
-
-                # # Note(GuoxiaWang): Only support global data parallel now!
-                # # First, we need to split optimizer
-                # self.optimizer = ShardingOptimizerStage2(
-                #     params=self.model.parameters(), optim=self.optimizer)
-
-                # # Second, warpper the origin model to have gradient sharding function
-                # self.model = ShardingStage2(
-                #     self.model,
-                #     self.optimizer,
-                #     accumulate_grads=self.accum_steps > 1,
-                #     device=self.config["Global"]["device"], )
-                # self.scaler = ShardingScaler(self.scaler)
-                assert False, "Do not support data_sharding now!"

+            if self.hybrid_parallel:
+                hcg = dist_env.get_hcg()
+                if hcg.get_model_parallel_world_size() > 1:
+                    from paddle.distributed.fleet.meta_parallel import TensorParallel
+                    self.model = TensorParallel(self.model, hcg, strategy=None)
             else:
                 # we always use pure data parallel default
                 assert 'data_parallel' in self.config["DistributedStrategy"] and \
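Note: wrapping the model in `TensorParallel` only takes care of the synchronization bookkeeping; the layers themselves must already be built from tensor-parallel building blocks for anything to be sharded (the `ViT_hybrid_base_patch16_224` model referenced in the new config is assumed to do this internally). A self-contained sketch with the standard fleet.meta_parallel layers, assuming fleet has been initialized with `mp_degree > 1`:

```python
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear, RowParallelLinear, TensorParallel)

class MpMlp(nn.Layer):
    """Megatron-style MLP: column-parallel fc1, row-parallel fc2."""
    def __init__(self, hidden=768):
        super().__init__()
        self.fc1 = ColumnParallelLinear(hidden, 4 * hidden, gather_output=False)
        self.fc2 = RowParallelLinear(4 * hidden, hidden, input_is_parallel=True)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

hcg = fleet.get_hybrid_communicate_group()
model = TensorParallel(MpMlp(), hcg, strategy=None)
```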
passl/engine/loops/classification_loop.py (16 changes: 10 additions & 6 deletions)

@@ -23,6 +23,7 @@
 import collections
 import platform
 import paddle
+import paddle.distributed as dist
 from passl.core import grad_sync, param_sync
 from passl.utils import io

@@ -82,7 +83,8 @@ def train_one_step(self, batch):
         # do forward and backward
         out, loss_dict = self.forward_backward(batch)

-        grad_sync(self.trainer.optimizer.param_groups)
+        comm_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        grad_sync(self.trainer.optimizer.param_groups, comm_group=comm_group)

         # do unscale and step if using fp16 and not found nan/inf
         # otherwise do nothing
@@ -186,14 +188,16 @@ def eval_one_dataset(self, eval_dataloader):
                     output_info[key].update(float(loss_dict[key]), batch_size)

             # just for DistributedBatchSampler issue: repeat sampling
-            current_samples = batch_size * paddle.distributed.get_world_size()
+            data_ranks = dist.fleet.get_hybrid_communicate_group().get_data_sharding_parallel_world_size()
+            current_samples = batch_size * data_ranks
             accum_samples += current_samples

             # calc metric
             if self.trainer.eval_metric_func is not None:
-                if paddle.distributed.get_world_size() > 1:
+                if data_ranks > 1:
+                    group = dist.fleet.get_hybrid_communicate_group().get_data_sharding_parallel_group()
                     label_list = []
-                    paddle.distributed.all_gather(label_list, batch[1])
+                    dist.all_gather(label_list, batch[1], group=group)
                     labels = paddle.concat(label_list, 0)

                     if isinstance(out, dict):
@@ -202,12 +206,12 @@
                         pred = []
                         for x in out:
                             pred_list = []
-                            paddle.distributed.all_gather(pred_list, x)
+                            dist.all_gather(pred_list, x, group=group)
                             pred_x = paddle.concat(pred_list, 0)
                             pred.append(pred_x)
                     else:
                         pred_list = []
-                        paddle.distributed.all_gather(pred_list, out)
+                        dist.all_gather(pred_list, out, group=group)
                         pred = paddle.concat(pred_list, 0)

             if accum_samples > total_samples and not self.trainer.use_dali:
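Note: gathering labels and predictions only over the data-parallel group (instead of the whole world, as before) keeps model-parallel peers from being counted as extra evaluation replicas. A minimal helper in the same spirit, using fleet's built-in data-parallel group rather than the patched `get_data_sharding_parallel_group` (an illustration, not this PR's code):

```python
import paddle
import paddle.distributed as dist

def gather_over_dp_group(tensor):
    """Concatenate a tensor from every data-parallel rank along axis 0."""
    group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
    if group.nranks == 1:
        return tensor
    parts = []
    dist.all_gather(parts, tensor, group=group)
    return paddle.concat(parts, axis=0)
```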
passl/engine/loops/contrastive_learning_loop.py (7 changes: 5 additions & 2 deletions)

@@ -16,9 +16,11 @@
 from __future__ import division
 from __future__ import print_function

-import paddle
 import collections

+import paddle
+import paddle.distributed as dist

 from passl.core import grad_sync
 from passl.utils import logger
 from .loop import TrainingEpochLoop
@@ -71,7 +73,8 @@ def train_one_step(self, batch):
         # do forward and backward
         loss_dict = self.forward_backward(batch)

-        grad_sync(self.trainer.optimizer.param_groups)
+        comm_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        grad_sync(self.trainer.optimizer.param_groups, comm_group=comm_group)

         # do unscale and step if using fp16 and not found nan/inf
         # otherwise do nothing
New file (129 additions): ViT hybrid parallel training config

@@ -0,0 +1,129 @@
# global configs
Global:
task_type: Classification
train_loop: ClassificationTrainingEpochLoop
validate_loop: ClassificationEvaluationLoop
checkpoint: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
max_num_latest_checkpoint: 0
eval_during_train: True
eval_interval: 1
eval_unit: "epoch"
accum_steps: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
seed: 2021

# FP16 setting
FP16:
level: O2
GradScaler:
init_loss_scaling: 65536.0

DistributedStrategy:
hybrid_configs:
mp_degree: 2

# model architecture
Model:
name: ViT_hybrid_base_patch16_224
class_num: 1000
drop_rate: 0.1

# loss function config for training/eval process
Loss:
Train:
- ViTCELoss:
weight: 1.0
epsilon: 0.0001
Eval:
- CELoss:
weight: 1.0

LRScheduler:
name: ViTLRScheduler
learning_rate: 3e-3
decay_type: cosine
warmup_steps: 10000

Optimizer:
name: AdamW
betas: (0.9, 0.999)
epsilon: 1e-8
weight_decay: 0.3
use_master_param: False
grad_clip:
name: ClipGradByGlobalNorm
clip_norm: 1.0


# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageFolder
root: ./dataset/ILSVRC2012/train
transform:
- RandomResizedCrop:
size: 224
scale: [0.05, 1.0]
interpolation: bicubic
- RandomHorizontalFlip:
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:

sampler:
name: DistributedBatchSampler
batch_size: 1024
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True

Eval:
dataset:
name: ImageFolder
root: ./dataset/ILSVRC2012/val
transform:
- Resize:
size: 256
interpolation: bicubic
backend: pil
- CenterCrop:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:

sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True

Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]

Export:
export_type: paddle
input_shape: [None, 3, 224, 224]
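Note on sizing this config: with `mp_degree: 2` the data-parallel degree is derived from the launch width, and the sampler `batch_size` is per data-parallel replica (standard `DistributedBatchSampler` semantics), so the effective global batch follows from both. A worked example assuming a single node with 8 GPUs:

```python
# Assumed launch: 8 GPUs on one node.
world_size = 8
mp_degree, pp_degree, sharding_degree = 2, 1, 1      # from DistributedStrategy above
dp_degree = world_size // (mp_degree * pp_degree * sharding_degree)   # -> 4
per_replica_batch = 1024                             # Train sampler batch_size above
global_batch = per_replica_batch * dp_degree         # -> 4096
print(dp_degree, global_batch)
```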