From 1e081f2c9a9b46b239b3adbdd8567a375ac7b94f Mon Sep 17 00:00:00 2001
From: "lk.huang"
Date: Sun, 3 Sep 2023 21:44:18 +0800
Subject: [PATCH] remove unused code

---
 examples/train.py                        |  2 +-
 scripts/convert2ckpt.py                  |  4 +--
 .../models/llama_pipeline_model.py       | 29 +------------------
 src/transpeeder/models/patching.py       |  6 ++--
 4 files changed, 7 insertions(+), 34 deletions(-)

diff --git a/examples/train.py b/examples/train.py
index ed90f3a..1d533f2 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -84,7 +84,7 @@ def main():
         args.init_ckpt,
         model_max_length=args.max_seq_len,
         padding_side="right",
-        use_fast=True,
+        use_fast=False,
     )
 
     model_config = transformers.AutoConfig.from_pretrained(args.init_ckpt)
diff --git a/scripts/convert2ckpt.py b/scripts/convert2ckpt.py
index 28bc5d9..22989c0 100644
--- a/scripts/convert2ckpt.py
+++ b/scripts/convert2ckpt.py
@@ -5,10 +5,10 @@
 import torch
 import transformers
 
-from models.patching import (
+from transpeeder.models.patching import (
     smart_tokenizer_and_embedding_resize,
 )
-from feeder import (
+from transpeeder.feeder import (
     DEFAULT_BOS_TOKEN,
     DEFAULT_PAD_TOKEN,
     DEFAULT_EOS_TOKEN,
diff --git a/src/transpeeder/models/llama_pipeline_model.py b/src/transpeeder/models/llama_pipeline_model.py
index 9dc1a5f..cd163bc 100644
--- a/src/transpeeder/models/llama_pipeline_model.py
+++ b/src/transpeeder/models/llama_pipeline_model.py
@@ -12,20 +12,12 @@ def forward(self, args):
         inputs_embeds = super().forward(input_ids)
         return (inputs_embeds, position_ids, attention_mask)
 
-def _wrap_embed_layer(layer: torch.nn.Module):
-    layer.__class__ = EmbeddingPipe
-    return layer
-
 
 class ParallelTransformerLayerPipe(LlamaDecoderLayer):
-    def __init__(self, config: LlamaConfig, activation_checkpointing=False):
+    def __init__(self, config: LlamaConfig):
         super().__init__(config)
-        self.activation_checkpointing = activation_checkpointing
 
     def forward(self, args):
-        if self.activation_checkpointing:
-            return self._ckpt_forward(args)
-
         hidden_states, position_ids, mask = args
         attention_mask = torch.where(mask == True, float("-inf"), 0).long()
 
@@ -36,25 +28,6 @@ def forward(self, args):
         )
         return (outputs[0], position_ids, mask)
 
-    def _ckpt_forward(self, args):
-        hidden_states, position_ids, mask = args
-        attention_mask = torch.where(mask == True, float("-inf"), 0).long()
-
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return LlamaDecoderLayer.forward(module, *inputs)
-            return custom_forward
-
-        # deepspeed checkpoint auto use outputs[0] if len(outputs) == 1
-        outputs = deepspeed.checkpointing.checkpoint(
-            create_custom_forward(self),
-            hidden_states,
-            attention_mask,
-            position_ids,
-        )
-
-        return (outputs, position_ids, mask)
-
 
 class LayerNormPipe(LlamaRMSNorm):
     def forward(self, args):
diff --git a/src/transpeeder/models/patching.py b/src/transpeeder/models/patching.py
index b62ea02..5f7e2f6 100644
--- a/src/transpeeder/models/patching.py
+++ b/src/transpeeder/models/patching.py
@@ -8,7 +8,7 @@
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 from einops import rearrange
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from flash_attn.bert_padding import unpad_input, pad_input
 
 
@@ -85,7 +85,7 @@ def llama_flash_attn_forward(
         max_s = q_len
         cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
                                  device=qkv.device)
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
@@ -95,7 +95,7 @@
         x = rearrange(qkv, 'b s three h d -> b s (three h d)')
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )