From 44574def7f34fb61bebd458b74c47fe33acec57d Mon Sep 17 00:00:00 2001
From: Selvaraj Anandaraj
Date: Mon, 29 Jan 2024 16:00:01 -0800
Subject: [PATCH] Fixed offloading for PyT version/ Added Attention activation
 offloading support/ Native FP8 support (#632)

* Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support

Signed-off-by: Selvaraj Anandaraj

* Removed activation offloading for fused attention

Signed-off-by: Selvaraj Anandaraj

* Fixed the illegal memory access issue for activation offloading of attention

Signed-off-by: Selvaraj Anandaraj

* Removed the version guard

Signed-off-by: Selvaraj Anandaraj

* Pipeline failures fix

Signed-off-by: Selvaraj Anandaraj

* Fixed lint errors

Signed-off-by: Selvaraj Anandaraj

* Lint error fix

Signed-off-by: Selvaraj Anandaraj

---------

Signed-off-by: Selvaraj Anandaraj
Co-authored-by: Selvaraj Anandaraj
---
 transformer_engine/pytorch/attention.py     | 24 ++++++++++
 transformer_engine/pytorch/cpu_offload.py   | 46 +++++++++++++------
 .../pytorch/module/layernorm_linear.py      |  2 +-
 .../pytorch/module/layernorm_mlp.py         |  3 +-
 transformer_engine/pytorch/module/linear.py |  2 +-
 5 files changed, 59 insertions(+), 18 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 9f4fe1f688..92e13108e4 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1752,6 +1752,14 @@ def forward(
                 deterministic=self.deterministic
             )
         else:
+
+            from .cpu_offload import CPUOffloadEnabled
+            if CPUOffloadEnabled:
+                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
+                for tensor in tensor_list:
+                    if tensor is not None:
+                        tensor.activation_offloading = True
+
             with self.attention_dropout_ctx():
                 fa_optional_forward_kwargs = {}
                 if _flash_attn_2_3_plus:
@@ -1938,6 +1946,15 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
                        attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type,
                        attn_mask_type, rng_gen)
 
+        from .cpu_offload import CPUOffloadEnabled
+        if CPUOffloadEnabled:
+            tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv]
+            qkv_layout = 'sbhd_sbhd_sbhd'
+            for tensor in tensor_list:
+                if tensor is not None:
+                    tensor.activation_offloading = True
+
+
         ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
@@ -2818,6 +2835,13 @@ def forward(
             assert (not context_parallel), \
                 "Context parallelism is only implemented with Flash Attention and Fused Attention!"
 
+            from .cpu_offload import CPUOffloadEnabled
+            if CPUOffloadEnabled:
+                warnings.warn(
+                    "Attention activation Offloading is only implemented"
+                    "with Flash Attention and Fused Attention!"
+                )
+
             if _NVTE_DEBUG:
                 print("[DotProductAttention]: using unfused DPA")
             if use_unfused_attention:
diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py
index dcede62ef7..b2635bb9bf 100644
--- a/transformer_engine/pytorch/cpu_offload.py
+++ b/transformer_engine/pytorch/cpu_offload.py
@@ -184,6 +184,7 @@ def groupid_reset(self):
         # the tensor back to gpu and deletes the cpu tensor.
         # These will increment whenever `group_commit()` is invoked
         self.current_group, self.tensor_count_current_group = (0, 0)
+        self.torch_tensor_count = 0
         self.tensor_tag_to_state = {}
 
     def on_group_commit_forward(self):
@@ -310,24 +311,35 @@ def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag):
 
     def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
-        # obtain a unique tensor tag
-        tensor_tag = (self.current_group, self.tensor_count_current_group)
-        self.tensor_count_current_group += 1
-        assert tensor_tag not in self.tensor_tag_to_state
-        if (self.current_group < self.num_offload_group
-            and self.tensor_need_offloading_checker(tensor)):
-            # first copy the tensor to tensorbuf, so that the original tensor will not be deleted
-            tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
-            tensor_buf.copy_(tensor)
-            if hasattr(tensor,"weight_offloading"):
-                tensor_buf.weight_offloading = True
-            if hasattr(tensor,"activation_offloading"):
-                tensor_buf.activation_offloading = True
-            # Here we just save it, and at commit, bulk_offload_group will handle it
-            self.tensor_tag_to_state[tensor_tag] = tensor_buf
+        torch_stray_tensor = isinstance(tensor,(torch._subclasses.fake_tensor.FakeTensor,
+                                                torch._subclasses.functional_tensor.FunctionalTensor))
+
+        if not torch_stray_tensor:
+            # obtain a unique tensor tag
+            tensor_tag = (self.current_group, self.tensor_count_current_group)
+            self.tensor_count_current_group += 1
+            assert tensor_tag not in self.tensor_tag_to_state
+
+            if (self.current_group < self.num_offload_group
+                and self.tensor_need_offloading_checker(tensor)):
+                # first copy the tensor to tensorbuf,
+                # so that the original tensor will not be deleted
+                tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
+                tensor_buf.copy_(tensor)
+                if hasattr(tensor,"weight_offloading"):
+                    tensor_buf.weight_offloading = True
+                if hasattr(tensor,"activation_offloading"):
+                    tensor_buf.activation_offloading = True
+                # Here we just save it, and at commit, bulk_offload_group will handle it
+                self.tensor_tag_to_state[tensor_tag] = tensor_buf
+            else:
+                self.tensor_tag_to_state[tensor_tag] = tensor
         else:
+            tensor_tag = (-1,self.torch_tensor_count)
+            self.torch_tensor_count += 1
             self.tensor_tag_to_state[tensor_tag] = tensor
+
         return tensor_tag
 
     def tensor_pop(self, tensor_tag, **kwargs):
@@ -350,6 +362,10 @@ def bulk_offload_group(self, group_to_offload):
 
                 # if offload, return the reference to cpu copy
                 if self.tensor_need_offloading_checker(tensor_on_device):
+                    if hasattr(tensor_on_device,"weight_offloading"):
+                        delattr(tensor_on_device,"weight_offloading")
+                    if hasattr(tensor_on_device,"activation_offloading"):
+                        delattr(tensor_on_device,"activation_offloading")
                     state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
                     self.tensor_tag_to_state[tensor_tag] = state
 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 2de860cf73..6836ef6d22 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -242,7 +242,7 @@ def forward(
         if cpu_offloading:
             if fuse_wgrad_accumulation:
                 weight.main_grad.weight_offloading = True
-            if fp8:
+            if fp8 and weight_t_fp8 is not None:
                 weight_t_fp8.weight_offloading = True
             ln_weight.weight_offloading = True
             weight.weight_offloading = True
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index d48ee4887d..3a0e5cb559 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -424,8 +424,9 @@ def forward(
             if fuse_wgrad_accumulation:
                 fc1_weight.main_grad.weight_offloading = True
                 fc2_weight.main_grad.weight_offloading = True
-            if fp8:
+            if fp8 and fc1_weight_t_fp8 is not None:
                 fc1_weight_t_fp8.weight_offloading = True
+            if fp8 and fc2_weight_t_fp8 is not None:
                 fc2_weight_t_fp8.weight_offloading = True
             ln_weight.weight_offloading = True
             fc1_weight.weight_offloading = True
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 68c5bf1a1d..f2c955bfc0 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -275,7 +275,7 @@ def forward(
         if cpu_offloading:
             if fuse_wgrad_accumulation:
                 weight.main_grad.weight_offloading = True
-            if fp8:
+            if fp8 and weight_t_fp8 is not None:
                 weight_t_fp8.weight_offloading = True
             weight.weight_offloading = True
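
For reference, the attention.py hunks above all use the same tagging pattern: when CPU offloading is enabled, the forward pass sets an `activation_offloading` attribute on the tensors it is about to save for backward, and the handler in cpu_offload.py then keys off that attribute (tensor_push() propagates it to the CPU buffer; bulk_offload_group() strips it before offloading). A minimal sketch of the tagging side is below; only `CPUOffloadEnabled` and the attribute name come from this patch, and the helper name `_mark_for_offload` is illustrative, not part of the change:

    def _mark_for_offload(*tensors):
        # Illustrative helper mirroring the inline blocks added to attention.py.
        # CPUOffloadEnabled is the module-level flag read at call time, as in the patch.
        from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
        if not CPUOffloadEnabled:
            return
        for t in tensors:
            if t is not None:
                # cpu_offload.py's tensor_push() copies this flag onto the CPU buffer,
                # and bulk_offload_group() removes it before offloading the tensor.
                t.activation_offloading = True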