From a174985b91db023181465a3ec9ad161a68787c57 Mon Sep 17 00:00:00 2001
From: Jaemin Choi
Date: Mon, 12 Feb 2024 14:55:44 -0800
Subject: [PATCH] Support GEMM-GELU fusion with split AG overlap (#661)

* Support GEMM-GELU fusion with split AG overlap

Signed-off-by: Jaemin Choi

* Fix linter complaints

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jaemin Choi

* Avoid code duplication

Signed-off-by: Jaemin Choi

* Fix issue with modifying tuple

Signed-off-by: Jaemin Choi

* Disable GEMM-GELU fusion when split AG overlap is not enabled

Signed-off-by: Jaemin Choi

* Add ub_split_ag parameter to LayerNormMLP unit test

Signed-off-by: Jaemin Choi

* Move knob into LayerNormMLP, auto-disable fusion when split AG overlap is not enabled

Signed-off-by: Jaemin Choi

* Revert changes to test_layernorm_mlp_accuracy

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jaemin Choi

---------

Signed-off-by: Jaemin Choi
Signed-off-by: Jaemin Choi
Co-authored-by: Jaemin Choi
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 .../pytorch/csrc/comm_gemm_overlap.h          | 20 +++++++++-
 .../pytorch/module/layernorm_mlp.py           | 40 +++++++++++++++----
 2 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h
index ef4614272b..8d30bf4a8c 100644
--- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h
+++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -793,10 +793,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
 
     // Get communication and GEMM output chunk sizes
     const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
-    const int output_chunk_bytes = (n_chunk * m) * HALF_BYTES;
+    const bool do_gelu = pre_gelu_out.numel() > 0;
+    const int output_chunk_bytes = (do_gelu
+                                        ? (n_chunk * m) * D.element_size()
+                                        : (n_chunk * m) * HALF_BYTES);
+    const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
 
     // Get output and workspace data pointers
     char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
+    char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.data_ptr());
     char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
     int workspace_size_chunk = workspaceSize / _stream_compute.size();
 
@@ -809,7 +814,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
     CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
 
-    assert(pre_gelu_out.numel() == 0);
     if (_aggregate2) {
       // Catch up the default torch stream
       CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
@@ -848,6 +852,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
           torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options());
       torch::Tensor output_chunk = torch::from_blob(
           output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options());
+      if (do_gelu) {
+        pre_gelu_out = torch::from_blob(
+            pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes),
+            {n_chunk * 2, m},
+            pre_gelu_out.options());
+      }
       torch::Tensor workspace_chunk =
           torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
                            {workspace_size_chunk}, workspace.options());
@@ -901,6 +911,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
       // GEMM
       torch::Tensor output_chunk = torch::from_blob(
          output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options());
+      if (do_gelu) {
+        pre_gelu_out = torch::from_blob(
+            pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes),
+            {n_chunk, m},
+            pre_gelu_out.options());
+      }
       torch::Tensor workspace_chunk =
           torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
                            {workspace_size_chunk}, workspace.options());
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index a5c9652f0d..5d1d001be4 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -118,6 +118,7 @@ def forward(
         ub_atomic_gemm_rs: bool,
         ub_split_ag: bool,
         ub_atomic_gemm_ag: bool,
+        gemm_gelu_fusion: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -261,7 +262,9 @@ def forward(
             ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
             ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
-            fc1_out, _ = tex.fp8_gemm(
+
+            # Perform FP8 GEMM
+            fp8_gemm_args = [
                 fc1_weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -272,6 +275,8 @@ def forward(
                 fp8_dtype_forward,
                 activation_dtype,
                 get_workspace(),
+            ]
+            fp8_gemm_kwargs = dict(
                 bias=fc1_bias,
                 use_bias=use_fc1_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
@@ -279,15 +284,31 @@ def forward(
                 ub=ub_obj_lnout if ub_overlap_ag else None,
                 extra_output_tensor=ln_out if ub_overlap_ag else None,
             )
+            if gemm_gelu_fusion:
+                fp8_gemm_args[8] = torch.uint8  # out_dtype
+                fp8_gemm_kwargs.update(
+                    dict(
+                        gelu=True,
+                        out_index=tex.FP8FwdTensors.GEMM2_INPUT,
+                        fp8_meta_tensor=fp8_meta["scaling_fwd"],
+                        D_dtype=fp8_dtype_forward,
+                    )
+                )
+            fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs)
             if not is_grad_enabled:
                 clear_tensor_data(ln_out_total)
-            gelu_out = activation_func(
-                fc1_out,
-                fp8_meta["scaling_fwd"],
-                tex.FP8FwdTensors.GEMM2_INPUT,
-                fp8_dtype_forward,
-            )
+            # Perform activation
+            if gemm_gelu_fusion:
+                gelu_out, fc1_out = fp8_gemm_out
+            else:
+                fc1_out, _ = fp8_gemm_out
+                gelu_out = activation_func(
+                    fc1_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM2_INPUT,
+                    fp8_dtype_forward,
+                )
 
             if not is_grad_enabled:
                 clear_tensor_data(fc1_out)
@@ -1033,6 +1054,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )
 
@@ -1175,6 +1197,9 @@ def __init__(
         self.ub_split_ag = ub_split_ag
         self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
         self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+        # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
+        self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
+                                 self.activation == 'gelu' and self.ub_split_ag)
 
         if (ub_bulk_wgrad  # pylint: disable=too-many-boolean-expressions
                 or ub_bulk_dgrad
@@ -1438,6 +1463,7 @@ def forward(
             self.ub_atomic_gemm_rs,
             self.ub_split_ag,
             self.ub_atomic_gemm_ag,
+            self.gemm_gelu_fusion,
         )
         out = fwd_fn(*args)
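
Usage sketch (illustrative, not part of the patch): the fused GEMM-GELU path added above is opted into through the NVTE_GEMM_GELU_FUSION environment variable, and LayerNormMLP only keeps it enabled when the activation is 'gelu' and split all-gather overlap (ub_split_ag) is active. The snippet below assumes a tensor-parallel, FP8-capable multi-GPU job in which torch.distributed and the userbuffers (comm+GEMM overlap) state have already been initialized; the sizes and module arguments are placeholders, not values taken from this patch.

import os

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Opt-in knob read by LayerNormMLP.__init__ in this patch; the module
# auto-disables the fusion unless activation == 'gelu' and ub_split_ag is set.
os.environ["NVTE_GEMM_GELU_FUSION"] = "1"

# Placeholder sizes; a real run additionally requires torch.distributed and
# userbuffers initialization for the split AG overlap, which is elided here.
layer = te.LayerNormMLP(
    hidden_size=4096,
    ffn_hidden_size=16384,
    activation="gelu",
    ub_split_ag=True,             # split AG overlap, required for the fusion
    set_parallel_mode=True,       # tensor-parallel MLP (assumed configuration)
    params_dtype=torch.bfloat16,
).cuda()

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = layer(x)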