From f118a46ca35dac3dd4090c56d9affbbdb25d7273 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 1 Jul 2024 08:24:51 -0700 Subject: [PATCH] create an engine for all things pipelining --- axonn/axonn.py | 764 +------------------------------------------ axonn/inter_layer.py | 456 ++++++++++++++++++++++++++ axonn/utils.py | 25 ++ 3 files changed, 491 insertions(+), 754 deletions(-) create mode 100644 axonn/inter_layer.py create mode 100644 axonn/utils.py diff --git a/axonn/axonn.py b/axonn/axonn.py index 4ea3f77..f678712 100644 --- a/axonn/axonn.py +++ b/axonn/axonn.py @@ -5,22 +5,15 @@ from . import config -from typing import Optional, List, Tuple +from typing import Optional from .communication import communication_handle -from .optim import CPUAdam import torch -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from enum import Enum -import numpy as np -import types try: - # from mpi4py import MPI import mpi4py MPI4PY = True mpi4py.rc.initialize = False # do not initialize MPI automatically - from mpi4py import MPI except ImportError: MPI4PY = False @@ -28,77 +21,6 @@ is_initialized = False # Communication handle for point-to-point (MPI) and collective (NCCL) communication comm_handle = None -# store references to input activation -input_tensors_cache = {} -# store references to output activation -output_tensors_cache = {} -# store (future object, tensor reference) for pending isends -transit_tensors = [] -# store (future object, tensor reference) for pending irecvs -requests = { - "fw": None, - "bw": None, -} -# store reference to model shard -model = None -# loss function -criterion = None -# reference to flattened model params -model_params_fp32, model_params_fp16 = None, None -# reference to flattened model gradients -model_grads_fp32, model_grads_fp16 = None, None -fp32_optimizer = None -# the computation dtype (one of fp16/fp32) -computation_dtype = None -# fp16 all reduce, only applicable with mixed precision -_fp16_all_reduce = None -# loss_scale -loss_scale = 2.0**16 -max_scale = 2.0**24 -min_scale = 2.0**10 -scaling_window = 200 -no_overflow_iters = 0 - -_cpu_offload = False - - -class Operation(Enum): - """ - AxoNNs enum class for the 2 microbatch operations - forward and backward pass - """ - - FW = 0 - BW = 1 - - -class empty_dataset(torch.utils.data.Dataset): - """ - Proxy dataset object for GPUs with inter_layer_parallel_rank > 0 - """ - - def __init__(self, length: int, num_tensors: int): - """Constructor for the proxy dataset class - - Arguments: - length (int): number of datapoints in the dataset - num_tensors (int): number of tensors per datapoint - - Returns: - A PyTorch dataset object - - """ - self.length = length - self.num_tensors = num_tensors - - def __len__(self): - return self.length - - def __getitem__(self, idx): - data = [0 for _ in range(self.num_tensors)] - if self.num_tensors == 1: - return torch.Tensor(data) - else: - return data def init( @@ -108,9 +30,6 @@ def init( G_intra_c: int = 1, G_intra_d: int = 1, gpus_per_node: Optional[int] = None, - mixed_precision=False, - fp16_allreduce=True, - cpu_offload=False, ) -> None: """ Initialize AxoNN's 2D parallelism with G_inter-way inter-layer @@ -124,16 +43,9 @@ def init( AxoNN just creates the required process groups. 
gpus_per_node (int, optional): number of GPUs per node, if not provided this is inferred using pytorch - mixed_precision (bool): whether to use mixed precision - fp16_allreduce (bool): invoke all reduce on fp16 parameters, - only applicable when mixed precision is True - cpu_offload (bool): offload optimizer states and fp32 parameters to - the cpu to save gpu memory. Currently only works with - mixed_precision, fp16_allreduce and axonn.optim.CPUAdam optimizer. """ - global comm_handle, is_initialized, computation_dtype, _fp16_all_reduce - global _cpu_offload + global comm_handle, is_initialized comm_handle = communication_handle( G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node ) @@ -152,21 +64,6 @@ def init( comm_handle.intra_layer_column_parallel_rank ) is_initialized = True - if mixed_precision: - computation_dtype = torch.float16 - else: - computation_dtype = torch.float32 - _fp16_all_reduce = fp16_allreduce - _cpu_offload = cpu_offload - - -def get_comm_handle(): - global comm_handle - return comm_handle - - -def is_zeroth_rank(): - return comm_handle.world_rank == 0 def create_dataloader( @@ -195,9 +92,7 @@ def create_dataloader( assert is_initialized config.micro_batch_size = micro_batch_size config.global_batch_size = global_batch_size - config.batch_size_per_network_instance = global_batch_size // ( - config.G_data * config.G_intra_d - ) + config.batch_size_per_gpu = global_batch_size // (config.G_data * config.G_intra_d) assert ( global_batch_size % (config.G_data * micro_batch_size) == 0 ), "Batch Size should be divisible by the G_data*micro_batch_size" @@ -208,9 +103,14 @@ def create_dataloader( rank=config.G_intra_d * config.data_parallel_rank + config.intra_layer_depth_parallel_rank, ) - data_loader = torch.utils.data.DataLoader( + if config.G_inter > 1: + batch_size_for_dataloader = config.batch_size_per_gpu + else: + batch_size_for_dataloader = config.micro_batch_size + + return torch.utils.data.DataLoader( dataset=dataset, - batch_size=config.batch_size_per_network_instance, + batch_size=batch_size_for_dataloader, shuffle=False, num_workers=num_workers, sampler=sampler, @@ -218,647 +118,3 @@ def create_dataloader( *args, **kwargs, ) # not working with drop_last=False - - return data_loader - - -def _coalesce_and_reassign(tensors: List[torch.Tensor]) -> torch.Tensor: - """Coalesce tensors into a flattened 1D tensor and reassign them to - subtensors in this 1D tensor. - - TODO:- By creating a flat tensor first this doubles the gpu memory. - Make this less memory consuming - - Arguments: - tensors (List[torch.Tensor]): list of tensors to be coalesced - - Returns: - flatenned_tensors (torch.tensor): the flattened tensor. - - """ - flattened_tensor = _flatten_dense_tensors(tensors) - for old_tensor, new_tensor in zip( - tensors, _unflatten_dense_tensors(flattened_tensor, tensors) - ): - old_tensor.data = new_tensor - return flattened_tensor - - -def _initialize_mixed_precision( - model: torch.nn.Module, optimizer: torch.optim.Optimizer -) -> Tuple[torch.nn.Module, torch.optim.Optimizer]: - """ - Initialize mixed precision. Makes model parameters and gradients fp-16 and - optimizer parameters as an fp-32 copy. Similar to Apex's O2 mode. - Also flattens fp-32/fp-16 parameters and gradients for a bulk - descaling and all-reduce. 
- - Arguments: - model: model object on the GPU - optimizer: the optimizer for the model - - Returns - model: modified model object with fp-16 parameters and gradients - optimizer : modified optimizer object with fp-32 parameters and gradients - """ - global model_params_fp32, model_params_fp16, model_grads_fp32, model_grads_fp16 - assert ( - computation_dtype == torch.float16 - ), "call this method only for mixed precision" - model = model.half() - # now model and optimizer both point to fp16 weights - # change optimizer to point to fp32 weights - fp32_params = [] - fp16_params = [] - fp32_grads = [] - fp16_grads = [] - for group in optimizer.param_groups: - for param_no, param in enumerate(group["params"]): - assert ( - param.dtype == torch.float16 - ), "currently does not handle a mix of fp-16/fp-32" - if param.requires_grad: - fp16_params.append(param) - param.grad = torch.zeros_like(param) - fp16_grads.append(param.grad) - fp32_param = param.detach().float() - fp32_params.append(fp32_param) - fp32_param.grad = torch.empty_like(fp32_param) - fp32_grads.append(fp32_param.grad) - group["params"][param_no] = fp32_param - - optimizer.load_state_dict( - optimizer.state_dict() - ) # trick to recast optimizer states - - model_params_fp32 = _coalesce_and_reassign(fp32_params) - model_params_fp16 = _coalesce_and_reassign(fp16_params) - model_grads_fp32 = _coalesce_and_reassign(fp32_grads) - model_grads_fp16 = _coalesce_and_reassign(fp16_grads) - - return model, optimizer - - -def _initialize_full_precision( - model: torch.nn.Module, optimizer: torch.optim.Optimizer -) -> Tuple[torch.nn.Module, torch.optim.Optimizer]: - """ - Initialize full precision training - leaves model and optimizer untouched. - Flattens fp-32 parameters and gradients. - """ - global model_params_fp32, model_params_fp16, model_grads_fp32, model_grads_fp16 - assert ( - computation_dtype == torch.float32 - ), "call this method only for mixed precision" - - fp32_params = [] - fp32_grads = [] - for group in optimizer.param_groups: - for param in group["params"]: - assert ( - param.dtype == torch.float32 - ), "currently does not handle a mix of fp-16/fp-32" - if param.requires_grad: - fp32_params.append(param) - param.grad = torch.empty_like(param) - fp32_grads.append(param.grad) - - model_params_fp32 = _coalesce_and_reassign(fp32_params) - model_grads_fp32 = _coalesce_and_reassign(fp32_grads) - model_grads_fp16 = None - model_params_fp16 = None - - return model, optimizer - - -def _initialize_mixed_precision_with_cpu_offload( - model: torch.nn.Module, optimizer: torch.optim.Optimizer -) -> Tuple[torch.nn.Module, torch.optim.Optimizer]: - """ - Initialize mixed precision. Makes model parameters and gradients fp-16 and - optimizer parameters as an fp-32 copy. Similar to Apex's O2 mode. - Also flattens fp-32/fp-16 parameters and gradients for a bulk - descaling and all-reduce. 
- - Arguments: - model: model object on the GPU - optimizer: the optimizer for the model - - Returns - model: modified model object with fp-16 parameters and gradients - optimizer : modified optimizer object with fp-32 parameters and gradients - """ - global model_params_fp32, model_params_fp16, model_grads_fp32, model_grads_fp16 - assert ( - computation_dtype == torch.float16 - ), "CPU offload only supports mixed precision" - assert _fp16_all_reduce, "CPU offload only supports fp-16 allreduce" - assert isinstance( - optimizer, CPUAdam - ), "only AxoNN's implementation of Adam is supported" - - model = model.half() - # now model and optimizer both point to fp16 weights - # change optimizer to point to fp32 weights - fp32_params = [] - fp16_params = [] - fp16_grads = [] - for group in optimizer.param_groups: - for param_no, param in enumerate(group["params"]): - assert ( - param.dtype == torch.float16 - ), "currently does not handle a mix of fp-16/fp-32" - if param.requires_grad: - fp16_params.append(param) - param.grad = torch.zeros_like(param) - fp16_grads.append(param.grad) - # create fp32 parameters and move them to cpu - fp32_param = param.detach().float().cpu() - fp32_params.append(fp32_param) - group["params"][param_no] = fp32_param - - optimizer.load_state_dict( - optimizer.state_dict() - ) # trick to recast optimizer states - - model_params_fp32 = _coalesce_and_reassign(fp32_params) - model_params_fp16 = _coalesce_and_reassign(fp16_params) - model_grads_fp16 = _coalesce_and_reassign(fp16_grads) - - return model, optimizer - - -@torch.no_grad() -def register_model_and_optimizer(model_shard, optimizer): - """AxoNN's user facing function to register a model shard and - the corresponding optimizer. - - Arguments: - model_shard (torch.nn.Module): the model shard created by the - user to be registered - optimizer (torch.nn.Optim): optimizer object for the model - """ - global model, model_params_fp32, model_grads_fp32, model_params_fp16 - global model_grads_fp16, fp32_optimizer - - assert is_initialized - - model = model_shard - if _cpu_offload: - model, optimizer = _initialize_mixed_precision_with_cpu_offload( - model, optimizer - ) - model_params = model_params_fp16 - elif computation_dtype == torch.float16: - model, optimizer = _initialize_mixed_precision(model, optimizer) - model_params = model_params_fp16 - else: - model, optimizer = _initialize_full_precision(model, optimizer) - model_params = model_params_fp32 - - comm_handle.allreduce( - model_params.div_(config.G_data), async_op=False - ) # sync all parameters across data parallel ranks - - if computation_dtype == torch.float16: - model_params_fp32.copy_(model_params_fp16) - - fp32_optimizer = optimizer - fp32_optimizer.skip_next_step = False - - unmodified_step = fp32_optimizer.step - - def modified_step(self): - if not self.skip_next_step: - unmodified_step() - model_params_fp16.copy_(model_params_fp32) - - if computation_dtype == torch.float16 and not _cpu_offload: - fp32_optimizer.step = types.MethodType(modified_step, fp32_optimizer) - - return model, optimizer - - -def register_loss_fn(loss_fn): - """AxoNN's user facing function to register a loss function. - - Arguments: - loss_fn: a PyTorch loss function (eg: torch.nn.CrossEntropy) - """ - global criterion - assert is_initialized - criterion = loss_fn - - -def _get_subtensor(tensor, microbatch_no): - """divide the tensor into equal tensors of micro_batch_size and - retrieve the microbatch_no tensor. 
Useful when fetching data - corresponding to a microbatch from a batch/labels. - - Arguments: - tensor (torch.Tensor): tensor to be divided - """ - start = microbatch_no * config.micro_batch_size - end = (microbatch_no + 1) * config.micro_batch_size - return tensor[start:end] - - -def print_status(*msg): - """print msg - - Arguments: - msg (str): message to be printed - """ - - print( - f"DP Rank : {config.data_parallel_rank} |", - f"ILP Rank : {config.inter_layer_parallel_rank} -", - *msg, - ) - - -def _forward_pass(input_activation: torch.Tensor, microbatch_no: int, eval_mode: bool): - """do the forward pass on an input activation and send the data to a forward GPU - - Arguments: - input_activation (torch.Tensor): input activation from the previous GPU - microbatch_no (int): the microbatch number of the input activation - eval_mode (bool): true if evaluating the model for validation/testing - - """ - if eval_mode: - with torch.no_grad(): - output_activation = model(input_activation) - if config.inter_layer_parallel_rank == config.G_inter - 1: - output_tensors_cache[microbatch_no] = output_activation - else: - output_activation = model(input_activation) - input_tensors_cache[microbatch_no] = input_activation - output_tensors_cache[microbatch_no] = output_activation - if config.inter_layer_parallel_rank + 1 < config.G_inter: - _send(output_activation, config.inter_layer_parallel_rank + 1, microbatch_no) - - -def _clear_transit_tensors(clear_all=False): - """test pending isends for completion and delete tensors that have been sent - - Arguments: - clear_all (bool): if true, return only after all isends have finished - """ - global transit_tensors - remaining_tensors = [] - for f, tensor in transit_tensors: - if clear_all: - f.Wait() - elif not f.Test(): - remaining_tensors.append([f, tensor]) - transit_tensors = remaining_tensors - - -def _send(tensor: torch.Tensor, destination: int, tag: int): - """send a tensor to a particular rank with a particular tag using MPI - - Arguments: - tensor (torch.Tensor): tensor to be sent - destination (int): inter-layer-parallel rank of the destination - tag (int): tag of the message - """ - if (destination < 0) or (destination >= config.G_inter): - return - _clear_transit_tensors() - tensor = tensor.contiguous() - torch.cuda.synchronize() - transit_tensors.append([comm_handle.send(tensor, destination, tag), tensor]) - - -def _fill_shape(shape): - return [config.micro_batch_size if x == -1 else x for x in shape] - - -def _post_fw_recv_requests(): - """ - Post a receive request for a forward pass - """ - if (requests["fw"] is None) and config.inter_layer_parallel_rank > 0: - tensor = torch.empty( - size=_fill_shape(model.get_input_shape()), - device="cuda", - dtype=computation_dtype, - ) - tensor.requires_grad = True - requests["fw"] = [ - tensor, - comm_handle.recv(tensor, config.inter_layer_parallel_rank - 1), - ] - - -def _post_bw_recv_requests(): - """ - Post a receive request for a backward pass - """ - if (requests["bw"] is None) and ( - config.inter_layer_parallel_rank < config.G_inter - 1 - ): - tensor = torch.empty( - size=_fill_shape(model.get_output_shape()), - device="cuda", - dtype=computation_dtype, - ) - requests["bw"] = [ - tensor, - comm_handle.recv(tensor, config.inter_layer_parallel_rank + 1), - ] - - -def _post_recv_requests(post_fw_recv=True, post_bw_recv=True): - """ - post mpi irecv requests if they haven't been posted. 
- """ - if post_fw_recv: - _post_fw_recv_requests() - if post_bw_recv: - _post_bw_recv_requests() - - -def _recv(post_fw_recv=True, post_bw_recv=True, eval_mode=False) -> int: - """ - Message driven scheduling of forward and backward passes for pipelining. - - Arguments: - post_fw_recv(bool): Post a new receive request for a forward pass if needed - post_bw_recv(bool): post a new receive request for a backward pass if needed - eval_mode(bool): True if evaluating - Returns: - tag(int): the tag of the received message which is the microbatch number - """ - assert MPI4PY, "attempting to use inter-layer parallelism without mpi4py installed" - status = MPI.Status() - if (requests["bw"] is None) and (requests["fw"] is not None): - requests["fw"][1].Wait(status) - tag = status.Get_tag() - input_activation = requests["fw"][0] - requests["fw"] = None - if post_fw_recv: - _post_fw_recv_requests() - _forward_pass(input_activation, tag, eval_mode) - op = Operation.FW - elif (requests["fw"] is None) and (requests["bw"] is not None): - requests["bw"][1].Wait(status) - tag = status.Get_tag() - output_gradients = requests["bw"][0] - requests["bw"] = None - if post_bw_recv: - _post_bw_recv_requests() - _backward_pass(output_gradients, tag) - op = Operation.BW - else: - index = MPI.Request.Waitany([requests["fw"][1], requests["bw"][1]], status) - tag = status.Get_tag() - if index == 0: # forward pass - input_activation = requests["fw"][0] - requests["fw"] = None - if post_fw_recv: - _post_fw_recv_requests() - _forward_pass(input_activation, tag, eval_mode) - op = Operation.FW - else: - output_gradients = requests["bw"][0] - requests["bw"] = None - if post_bw_recv: - _post_bw_recv_requests() - _backward_pass(output_gradients, tag) - op = Operation.BW - return tag, op - - -def _calc_loss(microbatch_no, microbatch_labels, mul_factor=1.0, eval_mode=False): - """Calculate the loss for a given microbatch number and its corresponding labels - - Arguments: - microbatch_no (int): the microbatch number - microbatch_labels (torch.Tensor): the true labels for the microbatch - mul_factor (float): premultiply loss by this number - """ - # for cross entropy calculation use float - loss = criterion(output_tensors_cache[microbatch_no].float(), microbatch_labels) - if computation_dtype == torch.float16: - output_tensors_cache[microbatch_no] = ( - mul_factor * loss * loss_scale - ) # scale up for mixed precision to - # prevent underflow - else: - output_tensors_cache[microbatch_no] = mul_factor * loss - if eval_mode: - del output_tensors_cache[microbatch_no] - return loss - - -def _backward_pass(output_gradients, microbatch_no): - """do the backward pass of a microbatch and send the input activation gradients - to the previous GPU. 
- - Arguments: - output gradients (torch.Tensor): the gradient of the loss wrt the output tensor - microbatch_no (int): the microbatch number - """ - output_tensors_cache[microbatch_no].backward(output_gradients) - input_tensor = input_tensors_cache[microbatch_no] - del output_tensors_cache[microbatch_no] - del input_tensors_cache[microbatch_no] - if config.inter_layer_parallel_rank - 1 >= 0: - _send(input_tensor.grad, config.inter_layer_parallel_rank - 1, microbatch_no) - - -def _sync_scale(local_overflow): - assert MPI4PY, "attempting to use inter-layer parallelism without mpi4py installed" - - global loss_scale, no_overflow_iters, max_scale - assert computation_dtype == torch.float16 - overflow_np = np.array(int(local_overflow), "i") - overflow_np_recv = np.array(int(local_overflow), "i") - MPI.COMM_WORLD.Allreduce( - [overflow_np, MPI.INT], [overflow_np_recv, MPI.INT], op=MPI.SUM - ) - if overflow_np_recv > 0: - loss_scale = max(loss_scale / 2.0, min_scale) - if comm_handle.world_rank == 0: - print_status(f"overflow detected - reducing loss scale to {loss_scale}") - no_overflow_iters = 0 - global_overflow = True - else: - no_overflow_iters += 1 - if no_overflow_iters == scaling_window: - loss_scale = min(loss_scale * 2.0, max_scale) - if comm_handle.world_rank == 0: - print_status(f"increasing loss scale to {loss_scale}") - no_overflow_iters = 0 - global_overflow = False - return global_overflow - - -def run_batch( - batch: torch.Tensor, labels: torch.Tensor, eval_mode=False, post_bw_hook=None -) -> int: - """Perform forward and backward pass on a batch. This function invokes - inter-layer-parallelism followed by an all-reduce. - - Arguments: - batch (torch.Tensor): the input batch, for inter-layer-parallel-rank > 0 - this is a proxy tensor with the first dimension equal to the batch size - labels (torch.Tensor): the true labels, for inter-layer-parallel-rank - < G_inter-1, this can be None - eval_mode (bool): set to true if you are doing validation/testing - - Returns: - loss (float): the loss on the batch for inter-layer-parallel-rank - == G_inter - 1, else 0 - """ - batch_loss = 0 - ilp_rank, G_inter, G_data = ( - config.inter_layer_parallel_rank, - config.G_inter, - config.G_data, - ) - num_microbatches_per_network = batch.shape[0] // config.micro_batch_size - - if computation_dtype == torch.float16 and batch.dtype == torch.float32: - batch = batch.half() - - if eval_mode: - model.eval() - else: - model.train() - - if G_inter == 1: - for microbatch_no in range(num_microbatches_per_network): - _forward_pass( - _get_subtensor(batch, microbatch_no), microbatch_no, eval_mode - ) - microbatch_loss = _calc_loss( - microbatch_no, - _get_subtensor(labels, microbatch_no), - 1 / G_data / num_microbatches_per_network, - eval_mode, - ) - batch_loss += microbatch_loss.item() - if not eval_mode: - _backward_pass(None, microbatch_no) - else: - remaining_microbatches = num_microbatches_per_network - num_msgs = remaining_microbatches - if (ilp_rank != 0) and (ilp_rank != G_inter - 1): - num_msgs += remaining_microbatches - forward_msgs = backward_msgs = num_msgs // 2 - elif ilp_rank == 0: - backward_msgs = num_msgs - forward_msgs = 0 - else: - forward_msgs = num_msgs - backward_msgs = 0 - if eval_mode: - num_msgs -= backward_msgs - backward_msgs = 0 - next_microbatch = 0 - if ilp_rank == 0: - for _ in range(G_inter): - if remaining_microbatches == 0: - break - _forward_pass( - _get_subtensor(batch, next_microbatch), next_microbatch, eval_mode - ) - next_microbatch += 1 - remaining_microbatches -= 1 
- _post_recv_requests( - post_fw_recv=(forward_msgs > 1), post_bw_recv=(backward_msgs > 1) - ) - while num_msgs: - microbatch_no, op = _recv( - post_fw_recv=(forward_msgs > 1), - post_bw_recv=(backward_msgs > 1), - eval_mode=eval_mode, - ) - num_msgs -= 1 - if op == Operation.FW: - forward_msgs -= 1 - elif op == Operation.BW: - backward_msgs -= 1 - if ilp_rank == 0 and remaining_microbatches: # inject next microbatch - _forward_pass( - _get_subtensor(batch, next_microbatch), next_microbatch, eval_mode - ) - next_microbatch += 1 - remaining_microbatches -= 1 - elif ilp_rank == G_inter - 1: - microbatch_loss = _calc_loss( - microbatch_no, - _get_subtensor(labels, microbatch_no), - 1 / G_data / num_microbatches_per_network, - eval_mode, - ) - batch_loss += microbatch_loss.item() - if not eval_mode: - _backward_pass(None, microbatch_no) - - if eval_mode and ilp_rank == 0: - global transit_tensors - while remaining_microbatches: - while len(transit_tensors) == G_inter: - _clear_transit_tensors() - _forward_pass( - _get_subtensor(batch, next_microbatch), next_microbatch, eval_mode - ) - next_microbatch += 1 - remaining_microbatches -= 1 - - _clear_transit_tensors(clear_all=True) - if post_bw_hook is not None: - assert not eval_mode - post_bw_hook(model) - if not _cpu_offload: - _allreduce_and_descale() - return batch_loss / num_microbatches_per_network - - -def _check_nan(tensor): - """ - check a tensor for overflow - - Arguments: - tensor (torch.Tensor): the tensor to be checked - Return - overflow (bool): true if there is overflow - """ - sum_ = tensor.sum() - return (torch.isinf(sum_) + torch.isnan(sum_)) > 0 - - -def _allreduce_and_descale(): - """ - allreduce and descale the gradients in accoradance with mixed precision - semantics. For fp-16_all_reduce mode, we first all-reduce and then descale - to prevent underflow. Note that it is not possible to check for underflow - so it is absolutely essential to maintain this order. For fp-32 all reduce - mode, we first descale and then all-reduce. After descaling there cannot - be underflow so this order is safe and prevents overflow. 
- """ - # at this point for mixed precision we will have unscaled fp-16 gradients - # for full precision we will have normal gradients - with torch.no_grad(): - if computation_dtype == torch.float32: - comm_handle.allreduce(model_grads_fp32, async_op=False) - else: - if _fp16_all_reduce: - # first all reduce then descale to prevent underflow - comm_handle.allreduce(model_grads_fp16, async_op=False) - model_grads_fp32.copy_(model_grads_fp16) - model_grads_fp32.div_(loss_scale) - else: - # first descale then allreduce to precent overflow - model_grads_fp32.copy_(model_grads_fp16) - model_grads_fp32.div_(loss_scale) - comm_handle.allreduce(model_grads_fp32, async_op=False) - - model_grads_fp16.zero_() - local_overflow = _check_nan(model_grads_fp32) - global_overflow = _sync_scale(local_overflow) - fp32_optimizer.skip_next_step = global_overflow diff --git a/axonn/inter_layer.py b/axonn/inter_layer.py new file mode 100644 index 0000000..f60e8f1 --- /dev/null +++ b/axonn/inter_layer.py @@ -0,0 +1,456 @@ +from enum import Enum +from dataclasses import dataclass +from axonn import axonn as ax +from mpi4py import MPI +from axonn.intra_layer import ( + sync_gradients_data_parallel, + sync_gradients_depth_parallel, +) +import torch +import numpy as np + + +@dataclass +class LossScaler: + """ + Dataclass for scaling the loss for fp-16 training + """ + + loss_scale: float = 2.0**16 + max_scale: float = 2.0**24 + min_scale: float = 2.0**10 + scaling_window: float = 200 + no_overflow_iters: float = 0 + + +class Operation(Enum): + """ + AxoNNs enum class for the 2 microbatch operations - forward and backward pass + """ + + FW = 0 + BW = 1 + + +class AxoNN_Inter_Layer_Engine: + def __init__(self, model, loss_fn, computation_dtype=torch.float16): + assert ( + ax.is_initialized + ), "Please call ax.init(....) before calling AxoNNPipelineEngine" + self.model = model + self.criterion = loss_fn + + # store references to input activation + self.input_tensors_cache = {} + # store references to output activation + self.output_tensors_cache = {} + # store (future object, tensor reference) for pending isends + self.transit_tensors = [] + # store (future object, tensor reference) for pending irecvs + self.requests = { + "fw": None, + "bw": None, + } + + self.computation_dtype = computation_dtype + self.scaler = LossScaler() + + def _get_subtensor(self, tensor, microbatch_no): + """divide the tensor into equal tensors of micro_batch_size and + retrieve the microbatch_no tensor. Useful when fetching data + corresponding to a microbatch from a batch/labels. 
+ + Arguments: + tensor (torch.Tensor): tensor to be divided + """ + start = microbatch_no * ax.config.micro_batch_size + end = (microbatch_no + 1) * ax.config.micro_batch_size + return tensor[start:end] + + def _forward_pass( + self, input_activation: torch.Tensor, microbatch_no: int, eval_mode: bool + ): + """do the forward pass on an input activation and send the data to a forward GPU + + Arguments: + input_activation (torch.Tensor): input activation from the previous GPU + microbatch_no (int): the microbatch number of the input activation + eval_mode (bool): true if evaluating the model for validation/testing + + """ + with torch.autocast(device_type="cuda", dtype=self.computation_dtype): + if eval_mode: + with torch.no_grad(): + output_activation = self.model(input_activation) + if ax.config.inter_layer_parallel_rank == ax.config.G_inter - 1: + self.output_tensors_cache[microbatch_no] = output_activation + else: + output_activation = self.model(input_activation) + self.input_tensors_cache[microbatch_no] = input_activation + self.output_tensors_cache[microbatch_no] = output_activation + if ax.config.inter_layer_parallel_rank + 1 < ax.config.G_inter: + self._send( + output_activation, + ax.config.inter_layer_parallel_rank + 1, + microbatch_no, + ) + + def _send(self, tensor: torch.Tensor, destination: int, tag: int): + """send a tensor to a particular rank with a particular tag using MPI + + Arguments: + tensor (torch.Tensor): tensor to be sent + destination (int): inter-layer-parallel rank of the destination + tag (int): tag of the message + """ + if (destination < 0) or (destination >= ax.config.G_inter): + return + self._clear_transit_tensors() + tensor = tensor.contiguous().to(self.computation_dtype) + torch.cuda.synchronize() # TODO - replace with stream synchronize. + self.transit_tensors.append( + [ax.comm_handle.send(tensor, destination, tag), tensor] + ) + + def _clear_transit_tensors(self, clear_all=False): + """test pending isends for completion and delete tensors that have been sent + Arguments: + clear_all (bool): if true, return only after all isends have finished + """ + remaining_tensors = [] + for f, tensor in self.transit_tensors: + if clear_all: + f.Wait() + elif not f.Test(): + remaining_tensors.append([f, tensor]) + self.transit_tensors = remaining_tensors + + def _fill_shape(self, shape): + return [ax.config.micro_batch_size if x == -1 else x for x in shape] + + def _post_fw_recv_requests(self): + """ + Post a receive request for a forward pass + """ + if (self.requests["fw"] is None) and ax.config.inter_layer_parallel_rank > 0: + tensor = torch.empty( + size=self._fill_shape(self.model.get_input_shape()), + device="cuda", + dtype=self.computation_dtype, + ) + tensor.requires_grad = True + self.requests["fw"] = [ + tensor, + ax.comm_handle.recv(tensor, ax.config.inter_layer_parallel_rank - 1), + ] + + def _post_bw_recv_requests(self): + """ + Post a receive request for a backward pass + """ + if (self.requests["bw"] is None) and ( + ax.config.inter_layer_parallel_rank < ax.config.G_inter - 1 + ): + tensor = torch.empty( + size=self._fill_shape(self.model.get_output_shape()), + device="cuda", + dtype=self.computation_dtype, + ) + self.requests["bw"] = [ + tensor, + ax.comm_handle.recv(tensor, ax.config.inter_layer_parallel_rank + 1), + ] + + def _post_recv_requests(self, post_fw_recv=True, post_bw_recv=True): + """ + post mpi irecv requests if they haven't been posted. 
+ """ + if post_fw_recv: + self._post_fw_recv_requests() + if post_bw_recv: + self._post_bw_recv_requests() + + def _recv(self, post_fw_recv=True, post_bw_recv=True, eval_mode=False) -> int: + """ + Message driven scheduling of forward and backward passes for pipelining. + + Arguments: + post_fw_recv(bool): Post a new receive request for a forward pass if needed + post_bw_recv(bool): post a new receive request for a backward pass if needed + eval_mode(bool): True if evaluating + Returns: + tag(int): the tag of the received message which is the microbatch number + """ + status = MPI.Status() + if (self.requests["bw"] is None) and (self.requests["fw"] is not None): + self.requests["fw"][1].Wait(status) + tag = status.Get_tag() + input_activation = self.requests["fw"][0] + self.requests["fw"] = None + if post_fw_recv: + self._post_fw_recv_requests() + self._forward_pass(input_activation, tag, eval_mode) + op = Operation.FW + elif (self.requests["fw"] is None) and (self.requests["bw"] is not None): + self.requests["bw"][1].Wait(status) + tag = status.Get_tag() + output_gradients = self.requests["bw"][0] + self.requests["bw"] = None + if post_bw_recv: + self._post_bw_recv_requests() + self._backward_pass(output_gradients, tag) + op = Operation.BW + else: + index = MPI.Request.Waitany( + [self.requests["fw"][1], self.requests["bw"][1]], status + ) + tag = status.Get_tag() + if index == 0: # forward pass + input_activation = self.requests["fw"][0] + self.requests["fw"] = None + if post_fw_recv: + self._post_fw_recv_requests() + self._forward_pass(input_activation, tag, eval_mode) + op = Operation.FW + else: + output_gradients = self.requests["bw"][0] + self.requests["bw"] = None + if post_bw_recv: + self._post_bw_recv_requests() + self._backward_pass(output_gradients, tag) + op = Operation.BW + return tag, op + + def _calc_loss( + self, microbatch_no, microbatch_labels, mul_factor=1.0, eval_mode=False + ): + """Calculate the loss for a given microbatch number and its corresponding labels + + Arguments: + microbatch_no (int): the microbatch number + microbatch_labels (torch.Tensor): the true labels for the microbatch + mul_factor (float): premultiply loss by this number + """ + # for cross entropy calculation use float + loss = self.criterion( + self.output_tensors_cache[microbatch_no].float(), microbatch_labels + ) + if self.computation_dtype == torch.float16: + self.output_tensors_cache[microbatch_no] = ( + mul_factor * loss * self.scaler.loss_scale + ) # scale up for mixed precision to + # prevent underflow + else: + self.output_tensors_cache[microbatch_no] = mul_factor * loss + if eval_mode: + del self.output_tensors_cache[microbatch_no] + return loss + + def _backward_pass(self, output_gradients, microbatch_no): + """do the backward pass of a microbatch and send the input activation gradients + to the previous GPU. 
+ + Arguments: + output gradients (torch.Tensor): the gradient of the loss wrt the + output tensor + microbatch_no (int): the microbatch number + """ + self.output_tensors_cache[microbatch_no].backward(output_gradients) + input_tensor = self.input_tensors_cache[microbatch_no] + del self.output_tensors_cache[microbatch_no] + del self.input_tensors_cache[microbatch_no] + if ax.config.inter_layer_parallel_rank - 1 >= 0: + self._send( + input_tensor.grad, + ax.config.inter_layer_parallel_rank - 1, + microbatch_no, + ) + + def _sync_scale(self, local_overflow): + assert self.computation_dtype == torch.float16 + overflow_np = np.array(int(local_overflow), "i") + overflow_np_recv = np.array(int(local_overflow), "i") + MPI.COMM_WORLD.Allreduce( + [overflow_np, MPI.INT], [overflow_np_recv, MPI.INT], op=MPI.SUM + ) + if overflow_np_recv > 0: + self.scaler.loss_scale = max( + self.scaler.loss_scale / 2.0, self.scaler.min_scale + ) + if ax.comm_handle.world_rank == 0: + print( + f"overflow detected - reducing loss scale" + f"to {self.scaler.loss_scale}" + ) + self.scaler.no_overflow_iters = 0 + global_overflow = True + else: + self.scaler.no_overflow_iters += 1 + if self.scaler.no_overflow_iters == self.scaler.scaling_window: + self.scaler.loss_scale = min( + self.scaler.loss_scale * 2.0, self.scaler.max_scale + ) + if ax.comm_handle.world_rank == 0: + print(f"increasing loss scale to {self.scaler.loss_scale}") + self.scaler.no_overflow_iters = 0 + global_overflow = False + return global_overflow + + def forward_backward_optimizer( + self, + batch: torch.Tensor, + labels: torch.Tensor, + optimizer: torch.optim.Optimizer, + eval_mode=False, + post_bw_hook=None, + ) -> int: + """Perform forward pass, backward pass and optimizer step on a batch. + + Arguments: + batch (torch.Tensor): the input batch + labels (torch.Tensor): the true labels + eval_mode (bool): set to true if you are doing validation/testing + + Returns: + loss (float): the loss on the batch for inter-layer-parallel-rank + == G_inter - 1, else 0 + """ + batch_loss = 0 + ilp_rank, G_inter = ( + ax.config.inter_layer_parallel_rank, + ax.config.G_inter, + ) + num_microbatches_per_gpu = batch.shape[0] // ax.config.micro_batch_size + + if eval_mode: + self.model.eval() + else: + self.model.train() + + if G_inter == 1: + for microbatch_no in range(num_microbatches_per_gpu): + self._forward_pass( + self._get_subtensor(batch, microbatch_no), microbatch_no, eval_mode + ) + microbatch_loss = self._calc_loss( + microbatch_no, + self._get_subtensor(labels, microbatch_no), + 1 / num_microbatches_per_gpu, + eval_mode, + ) + batch_loss += microbatch_loss.item() + if not eval_mode: + self._backward_pass(None, microbatch_no) + else: + remaining_microbatches = num_microbatches_per_gpu + num_msgs = remaining_microbatches + if (ilp_rank != 0) and (ilp_rank != G_inter - 1): + num_msgs += remaining_microbatches + forward_msgs = backward_msgs = num_msgs // 2 + elif ilp_rank == 0: + backward_msgs = num_msgs + forward_msgs = 0 + else: + forward_msgs = num_msgs + backward_msgs = 0 + if eval_mode: + num_msgs -= backward_msgs + backward_msgs = 0 + next_microbatch = 0 + if ilp_rank == 0: + for _ in range(G_inter): + if remaining_microbatches == 0: + break + self._forward_pass( + self._get_subtensor(batch, next_microbatch), + next_microbatch, + eval_mode, + ) + next_microbatch += 1 + remaining_microbatches -= 1 + self._post_recv_requests( + post_fw_recv=(forward_msgs > 1), post_bw_recv=(backward_msgs > 1) + ) + while num_msgs: + microbatch_no, op = self._recv( + 
post_fw_recv=(forward_msgs > 1), + post_bw_recv=(backward_msgs > 1), + eval_mode=eval_mode, + ) + num_msgs -= 1 + if op == Operation.FW: + forward_msgs -= 1 + elif op == Operation.BW: + backward_msgs -= 1 + if ilp_rank == 0 and remaining_microbatches: # inject next microbatch + self._forward_pass( + self._get_subtensor(batch, next_microbatch), + next_microbatch, + eval_mode, + ) + next_microbatch += 1 + remaining_microbatches -= 1 + elif ilp_rank == G_inter - 1: + microbatch_loss = self._calc_loss( + microbatch_no, + self._get_subtensor(labels, microbatch_no), + 1 / num_microbatches_per_gpu, + eval_mode, + ) + batch_loss += microbatch_loss.item() + if not eval_mode: + self._backward_pass(None, microbatch_no) + + if eval_mode and ilp_rank == 0: + while remaining_microbatches: + while len(self.transit_tensors) == G_inter: + self._clear_transit_tensors() + self._forward_pass( + self._get_subtensor(batch, next_microbatch), + next_microbatch, + eval_mode, + ) + next_microbatch += 1 + remaining_microbatches -= 1 + + self._clear_transit_tensors(clear_all=True) + if post_bw_hook is not None: + assert not eval_mode + post_bw_hook(self.model) + + sync_gradients_depth_parallel(self.model, mean=True) + sync_gradients_data_parallel(self.model, mean=True) + if self.computation_dtype == torch.float16: + global_overflow = self._unscale_gradients() + if not global_overflow: + optimizer.step() + else: + optimizer.step() + return batch_loss / num_microbatches_per_gpu + + def _check_nan(self, tensor): + """ + check a tensor for overflow + + Arguments: + tensor (torch.Tensor): the tensor to be checked + Return + overflow (bool): true if there is overflow + """ + sum_ = tensor.sum() + return (torch.isinf(sum_) + torch.isnan(sum_)) > 0 + + def _unscale_gradients(self): + """ + unscale the gradients and check for overflow across all GPUs + """ + # at this point for mixed precision we will have unscaled fp-16 gradients + # for full precision we will have normal gradients + local_overflow = False + with torch.no_grad(): + for p in self.model.parameters(): + if p.grad is not None: + local_overflow = local_overflow or self._check_nan(p.grad) + p.grad.div_(self.scaler.loss_scale) + global_overflow = self._sync_scale(local_overflow) + return global_overflow diff --git a/axonn/utils.py b/axonn/utils.py new file mode 100644 index 0000000..e517f85 --- /dev/null +++ b/axonn/utils.py @@ -0,0 +1,25 @@ +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from typing import List + + +def _coalesce_and_reassign(tensors: List[torch.Tensor]) -> torch.Tensor: + """Coalesce tensors into a flattened 1D tensor and reassign them to + subtensors in this 1D tensor. + + TODO:- By creating a flat tensor first this doubles the gpu memory. + Make this less memory consuming + + Arguments: + tensors (List[torch.Tensor]): list of tensors to be coalesced + + Returns: + flatenned_tensors (torch.tensor): the flattened tensor. + + """ + flattened_tensor = _flatten_dense_tensors(tensors) + for old_tensor, new_tensor in zip( + tensors, _unflatten_dense_tensors(flattened_tensor, tensors) + ): + old_tensor.data = new_tensor + return flattened_tensor
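
For readers of this patch, the net effect is that the old module-level flow (register_model_and_optimizer, register_loss_fn, run_batch) is replaced by constructing an AxoNN_Inter_Layer_Engine and calling forward_backward_optimizer. The sketch below is not part of the patch; it is a minimal, hypothetical illustration of how downstream training code might drive the new engine. Names such as MyShardedModel and train_dataset are placeholders, and details like the zero_grad placement or the G_* values are assumptions, not API guarantees.

    # hypothetical usage sketch of the new inter-layer engine (not in this patch)
    import torch
    from axonn import axonn as ax
    from axonn.inter_layer import AxoNN_Inter_Layer_Engine

    # set up 2-way inter-layer (pipeline) and 2-way data parallelism
    ax.init(G_inter=2, G_data=2, G_intra_r=1, G_intra_c=1, G_intra_d=1)

    # the model shard must expose get_input_shape()/get_output_shape(), which the
    # engine uses to size the buffers for its forward/backward irecv requests
    model = MyShardedModel().cuda()          # placeholder model shard
    optimizer = torch.optim.Adam(model.parameters())
    engine = AxoNN_Inter_Layer_Engine(
        model, torch.nn.CrossEntropyLoss(), computation_dtype=torch.float16
    )

    train_loader = ax.create_dataloader(
        train_dataset,                        # placeholder dataset
        global_batch_size=64,
        micro_batch_size=4,
        num_workers=2,
    )

    for batch, labels in train_loader:
        batch, labels = batch.cuda(), labels.cuda()
        optimizer.zero_grad()
        # runs the pipelined forward and backward passes over all microbatches,
        # syncs gradients across depth/data-parallel groups, and calls
        # optimizer.step() internally (skipped if fp16 gradients overflowed);
        # the returned loss is meaningful only on the last pipeline stage
        loss = engine.forward_backward_optimizer(batch, labels, optimizer)

Note that, as in the patch, loss scaling and the overflow check are handled inside the engine when computation_dtype is torch.float16, so the caller no longer manages a loss scale or a skip_next_step flag.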