Merge pull request #39 from JY-Ren/new-dev
Add ntk, flash-attn2 and support llama2
HuangLK authored Sep 6, 2023
2 parents 1e081f2 + 7ef91e4 commit 8c6f663
Showing 7 changed files with 153 additions and 40 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -36,7 +36,8 @@ See `examples/train_llama_deepspeed.sh`.
```bash
python -m scripts.convert2hf --model_size 7B \
--input_dir ./output/llama-7B-ckpt/global_step1000/ \
--output_dir ./output/llama_hf_7B
--output_dir ./output/llama_hf_7B \
--tokenizer_size 32001
cp /path/to/llama-7b-hf/*.json ./output/llama_hf_7B
cp /path/to/llama-7b-hf/tokenizer.model ./output/llama_hf_7B
```
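
For context (not part of this diff): once the JSON and tokenizer files are copied in, the output directory is a regular Hugging Face checkpoint. A minimal loading sketch, assuming `transformers >= 4.31.0` from requirements.txt below and the output path from the README example:

```python
# Sketch only: load the converted checkpoint and confirm the resized vocabulary.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./output/llama_hf_7B")
model = AutoModelForCausalLM.from_pretrained("./output/llama_hf_7B")
print(model.config.vocab_size)  # should match --tokenizer_size, 32001 in the example above
```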
12 changes: 11 additions & 1 deletion examples/train.py
@@ -13,6 +13,7 @@
from transpeeder.models.llama_pipeline_model import get_model
from transpeeder.models.patching import (
replace_llama_attn_with_flash_attn,
refine_rope,
)
from transpeeder.feeder import (
make_prompt_dataloader,
@@ -51,7 +52,7 @@ class TrainerArguments:

resume_step: int = field(default=-1)
resume_ckpt: str = field(default="llama-7B-init-test-ckpt")

    ntk: Optional[bool] = field(default=False)

def read_ds_config(config_path):
config = jload(config_path)
@@ -79,6 +80,7 @@ def main():
if args.use_flash_attn:
logger.info("⚡⚡⚡ enable flash attention.")
replace_llama_attn_with_flash_attn()
refine_rope()

tokenizer = transformers.AutoTokenizer.from_pretrained(
args.init_ckpt,
@@ -88,6 +90,14 @@
)
model_config = transformers.AutoConfig.from_pretrained(args.init_ckpt)

if args.ntk:
rope_scaling = {
"type": "dynamic",
"factor": 2,
}
model_config.rope_scaling = rope_scaling
logger.info(f"Turn on dynamic rope for llama2")

# dataset
dataloader_maker = make_tokenized_dataloader if args.input_format == 'tokenized' else make_prompt_dataloader
train_dataloader = dataloader_maker(tokenizer=tokenizer, data_args=args)
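For context (not part of this diff): the `--ntk` flag only attaches a `rope_scaling` dict to the model config; with `transformers >= 4.31.0`, a `"dynamic"` entry makes the LLaMA attention layers pick the dynamic NTK rotary embedding. A minimal sketch of the config mutation, using a default `LlamaConfig` purely for illustration in place of `AutoConfig.from_pretrained(args.init_ckpt)`:

```python
# Illustration only: the config change performed when --ntk is true.
from transformers import LlamaConfig

config = LlamaConfig()  # stand-in for the checkpoint's real config
config.rope_scaling = {"type": "dynamic", "factor": 2}
print(config.rope_scaling)  # {'type': 'dynamic', 'factor': 2}
```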
1 change: 1 addition & 0 deletions examples/train_llama_deepspeed.sh
@@ -31,4 +31,5 @@ deepspeed --include localhost:$2 --master_port ${MASTER_PORT} ${WORK_DIR}/train
--pipe_parallel_size 8 \
--model_parallel_size 1 \
--use_flash_attn true \
--ntk true \
--deepspeed_config ${WORK_DIR}/../configs/ds_config_zero1.json
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
numpy
sentencepiece
transformers==4.28.0
transformers >= 4.31.0
deepspeed @ git+https://github.com/HuangLK/DeepSpeed.git@dev
flash_attn
flash_attn >= 2.0
10 changes: 8 additions & 2 deletions scripts/convert2hf.py
@@ -32,15 +32,15 @@ def write_json(text, path):
json.dump(text, f)


def write_model(model_path, input_base_path, model_size):
def write_model(model_path, input_base_path, model_size, tokenizer_size):
assert model_size in PARAM_MAP
os.makedirs(model_path, exist_ok=True)

params = PARAM_MAP[model_size]
n_layers = params["n_layers"]

loaded = {}
ORIGINAL_TOKENIZER_SIZE = 32000
ORIGINAL_TOKENIZER_SIZE = tokenizer_size
for pt in Path(input_base_path).iterdir():
# assert tp/mp == 1
sd = torch.load(pt, map_location="cpu")
@@ -78,11 +78,17 @@ def main():
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--tokenizer_size",
help="Size of tokenizer",
type=int,
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
tokenizer_size=args.tokenizer_size,
)


58 changes: 58 additions & 0 deletions src/transpeeder/models/modeling_llama.py
@@ -0,0 +1,58 @@
import torch

class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)

class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len

if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
105 changes: 71 additions & 34 deletions src/transpeeder/models/patching.py
@@ -1,13 +1,13 @@
""" https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
""" https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
"""

from typing import List, Optional, Tuple, Dict

import torch
import torch.nn.functional as F
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input

@@ -36,6 +36,16 @@ def smart_tokenizer_and_embedding_resize(
# input_embeddings[-num_new_tokens:] = input_embeddings_avg
# output_embeddings[-num_new_tokens:] = output_embeddings_avg

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def llama_flash_attn_forward(
self,
@@ -53,57 +63,79 @@ def llama_flash_attn_forward(
"""
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)


# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

# transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
qkv = torch.stack([query_states, key_states, value_states], dim=2)
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]

attention_mask = torch.ones((bsz, q_len), device=qkv.device)
key_padding_mask = attention_mask


if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
max_s = q_len
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
device=qkv.device)
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0,
softmax_scale=None, causal=True
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
output = output.view(bsz, q_len, -1)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
qkv = qkv.reshape(bsz, q_len, -1)
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0,
softmax_scale=None, causal=True
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, bsz, q_len),
'b s (h d) -> b s h d', h=nheads)
return self.o_proj(rearrange(output,
'b s h d -> b s (h d)')), None, None
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)

return self.o_proj(output), None, past_key_value


# Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -120,3 +152,8 @@ def _prepare_decoder_attention_mask(self,
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_flash_attn_forward

def refine_rope():
from .modeling_llama import LlamaRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding = LlamaDynamicNTKScalingRotaryEmbedding
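
For context (not part of this diff): the intended call order mirrors train.py above, patching before the model is built so the flash-attention forward and the rotary embedding classes are already in place. A minimal sketch with a plain Hugging Face model (train.py itself builds a pipeline-parallel model via `get_model`); the checkpoint path is a placeholder:

```python
# Usage sketch only; the path is a placeholder and transpeeder is assumed installed.
import transformers
from transpeeder.models.patching import replace_llama_attn_with_flash_attn, refine_rope

replace_llama_attn_with_flash_attn()  # swap in the flash-attn forward and mask stub
refine_rope()                         # swap in the rotary embedding classes above

config = transformers.AutoConfig.from_pretrained("/path/to/llama-2-7b-hf")
config.rope_scaling = {"type": "dynamic", "factor": 2}  # same dict --ntk sets in train.py
model = transformers.AutoModelForCausalLM.from_config(config)
```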
