
Commit: remove unused code
lk.huang committed Sep 3, 2023
1 parent 294d7e8 commit 1e081f2
Showing 4 changed files with 7 additions and 34 deletions.
examples/train.py: 2 changes (1 addition, 1 deletion)
@@ -84,7 +84,7 @@ def main():
         args.init_ckpt,
         model_max_length=args.max_seq_len,
         padding_side="right",
-        use_fast=True,
+        use_fast=False,
     )
     model_config = transformers.AutoConfig.from_pretrained(args.init_ckpt)

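For context, the hunk above switches training to the slow (SentencePiece-based) tokenizer. A minimal sketch of the resulting call, with a placeholder checkpoint path and sequence length standing in for args.init_ckpt and args.max_seq_len:

```python
import transformers

# Placeholder values; the script takes these from args.init_ckpt / args.max_seq_len.
ckpt_path = "path/to/llama-checkpoint"
max_seq_len = 2048

tokenizer = transformers.AutoTokenizer.from_pretrained(
    ckpt_path,
    model_max_length=max_seq_len,
    padding_side="right",
    use_fast=False,  # slow Python/SentencePiece tokenizer instead of the fast Rust one
)
```
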
scripts/convert2ckpt.py: 4 changes (2 additions, 2 deletions)
@@ -5,10 +5,10 @@
 import torch
 import transformers

-from models.patching import (
+from transpeeder.models.patching import (
     smart_tokenizer_and_embedding_resize,
 )
-from feeder import (
+from transpeeder.feeder import (
     DEFAULT_BOS_TOKEN,
     DEFAULT_PAD_TOKEN,
     DEFAULT_EOS_TOKEN,
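
This hunk only changes import paths: the script now imports from the transpeeder package instead of bare module names that resolve only when src/transpeeder is on sys.path. A short before/after sketch (the editable install command is an example, not taken from the repo):

```python
# Before: bare module paths, resolvable only with src/transpeeder on sys.path.
# from models.patching import smart_tokenizer_and_embedding_resize
# from feeder import DEFAULT_PAD_TOKEN

# After: absolute imports against the installed package (e.g. after `pip install -e .`).
from transpeeder.models.patching import smart_tokenizer_and_embedding_resize
from transpeeder.feeder import DEFAULT_BOS_TOKEN, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN
```
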
src/transpeeder/models/llama_pipeline_model.py: 29 changes (1 addition, 28 deletions)
@@ -12,20 +12,12 @@ def forward(self, args):
         inputs_embeds = super().forward(input_ids)
         return (inputs_embeds, position_ids, attention_mask)

-def _wrap_embed_layer(layer: torch.nn.Module):
-    layer.__class__ = EmbeddingPipe
-    return layer
-

 class ParallelTransformerLayerPipe(LlamaDecoderLayer):
-    def __init__(self, config: LlamaConfig, activation_checkpointing=False):
+    def __init__(self, config: LlamaConfig):
         super().__init__(config)
-        self.activation_checkpointing = activation_checkpointing

     def forward(self, args):
-        if self.activation_checkpointing:
-            return self._ckpt_forward(args)
-
         hidden_states, position_ids, mask = args
         attention_mask = torch.where(mask == True, float("-inf"), 0).long()

@@ -36,25 +28,6 @@ def forward(self, args):
         )
         return (outputs[0], position_ids, mask)
-
-    def _ckpt_forward(self, args):
-        hidden_states, position_ids, mask = args
-        attention_mask = torch.where(mask == True, float("-inf"), 0).long()
-
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return LlamaDecoderLayer.forward(module, *inputs)
-            return custom_forward
-
-        # deepspeed checkpoint auto use outputs[0] if len(outputs) == 1
-        outputs = deepspeed.checkpointing.checkpoint(
-            create_custom_forward(self),
-            hidden_states,
-            attention_mask,
-            position_ids,
-        )
-
-        return (outputs, position_ids, mask)


 class LayerNormPipe(LlamaRMSNorm):
     def forward(self, args):
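
The deleted `_ckpt_forward` wrapped the decoder layer in DeepSpeed activation checkpointing, so per-layer activations are recomputed during the backward pass instead of being stored. For reference, a self-contained sketch of that pattern (not repo code; assumes deepspeed and transformers are installed, and `ckpt_forward` is a hypothetical helper name):

```python
import deepspeed
import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def ckpt_forward(layer: LlamaDecoderLayer,
                 hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor,
                 position_ids: torch.Tensor):
    """Run one decoder layer under DeepSpeed activation checkpointing."""

    def create_custom_forward(module):
        # checkpoint() re-invokes this closure during backward to recompute
        # the layer's activations instead of keeping them in memory.
        def custom_forward(*inputs):
            return LlamaDecoderLayer.forward(module, *inputs)
        return custom_forward

    outputs = deepspeed.checkpointing.checkpoint(
        create_custom_forward(layer),
        hidden_states,
        attention_mask,
        position_ids,
    )
    # Per the comment in the removed code, checkpoint unwraps a single-element output.
    return outputs
```
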
src/transpeeder/models/patching.py: 6 changes (3 additions, 3 deletions)
@@ -8,7 +8,7 @@
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

 from einops import rearrange
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from flash_attn.bert_padding import unpad_input, pad_input


@@ -85,7 +85,7 @@ def llama_flash_attn_forward(
         max_s = q_len
         cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
                                  device=qkv.device)
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0,
             softmax_scale=None, causal=True
         )
@@ -95,7 +95,7 @@ def llama_flash_attn_forward(
         x = rearrange(qkv, 'b s three h d -> b s (three h d)')
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
             x_unpad, cu_q_lens, max_s, 0.0,
             softmax_scale=None, causal=True
         )
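
These last hunks track the flash-attn rename of the unpadded interface: `flash_attn_unpadded_qkvpacked_func` became `flash_attn_varlen_qkvpacked_func` in newer flash-attn releases, with the same calling convention for packed variable-length batches. A small usage sketch with toy shapes (requires a CUDA GPU and flash-attn installed):

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

# Two sequences of lengths 3 and 5, packed back-to-back without padding.
seqlens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0)).cuda()
max_seqlen = int(seqlens.max())

nheads, headdim = 4, 64
# Packed layout: (total_tokens, 3, nheads, headdim), fp16/bf16 on GPU.
qkv = torch.randn(int(seqlens.sum()), 3, nheads, headdim,
                  dtype=torch.float16, device="cuda")

out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen, 0.0,  # 0.0 = dropout probability
    softmax_scale=None, causal=True,
)
print(out.shape)  # (total_tokens, nheads, headdim)
```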
