From 1185ad9ca3fec9977ca1c2501b5d4c42c1b86ff3 Mon Sep 17 00:00:00 2001 From: Zipeng Xie <53039617+xiezipeng-ML@users.noreply.github.com> Date: Mon, 15 Jan 2024 20:55:54 +0800 Subject: [PATCH] Adpter finetune (#528) * support adapter finetune * refine * reformat --- projects/Llama/adapter/adapter_config.py | 63 ++ projects/Llama/adapter/adapter_model.py | 730 +++++++++++++++++++++++ projects/Llama/adapter/adapter_sft.py | 97 +++ projects/Llama/adapter/train_net.py | 115 ++++ projects/Llama/configs/llama_config.py | 3 +- projects/Llama/configs/llama_sft.py | 29 +- projects/Llama/dataset.py | 33 +- projects/Llama/llama.py | 23 +- projects/Llama/pipeline.py | 4 +- projects/Llama/readme.md | 5 +- projects/Llama/utils/eval_adapter.py | 2 +- projects/Llama/utils/llama_loader.py | 3 - projects/Llama/utils/prepare_alpaca.py | 57 +- 13 files changed, 1073 insertions(+), 91 deletions(-) create mode 100644 projects/Llama/adapter/adapter_config.py create mode 100644 projects/Llama/adapter/adapter_model.py create mode 100644 projects/Llama/adapter/adapter_sft.py create mode 100644 projects/Llama/adapter/train_net.py diff --git a/projects/Llama/adapter/adapter_config.py b/projects/Llama/adapter/adapter_config.py new file mode 100644 index 000000000..7381e64af --- /dev/null +++ b/projects/Llama/adapter/adapter_config.py @@ -0,0 +1,63 @@ +from omegaconf import DictConfig, OmegaConf + +from configs.common.train import train # noqa +from libai.config import LazyCall +from projects.Llama.adapter.adapter_model import LlamaForCausalLM +from projects.Llama.tokenizer import LlamaTokenizer + +cfg = dict( + # Model + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=11008, + max_position_embeddings=4096, + num_attention_heads=32, + hidden_layers=32, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=False, + vocab_size=32000, + use_scaled_init_for_output_weights=False, + scale_mask_softmax_fusion=False, + amp_enabled=True, + # Inference + is_encoder_decoder=False, + max_length=256, + min_length=0, + do_sample=False, + early_stopping=False, + num_beams=1, + num_beam_groups=1, + diversity_penalty=0.0, + temperature=0.9, + top_k=50, + top_p=0.6, + typical_p=1.0, + repetition_penalty=1.0, + length_penalty=1.0, + no_repeat_ngram_size=0, + encoder_no_repeat_ngram_size=0, + num_return_sequences=1, + chunk_size_feed_forward=0, + output_scores=False, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + pad_token_id=0, + # adapter + adapter_len=10, + adapter_layer=30, + # train + pretrained_model_path="meta-llama/Llama-2-7b-hf/", +) + +cfg = DictConfig(cfg) + +model = LazyCall(LlamaForCausalLM)(cfg=cfg) +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(LlamaTokenizer)( + pretrained_model_path="Llama-2-7b-hf/tokenizer.model" +) diff --git a/projects/Llama/adapter/adapter_model.py b/projects/Llama/adapter/adapter_model.py new file mode 100644 index 000000000..78a72b4fe --- /dev/null +++ b/projects/Llama/adapter/adapter_model.py @@ -0,0 +1,730 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import oneflow as flow +import oneflow.nn.functional as F +from oneflow import nn + +from libai.config import configurable +from libai.inference.generator.generation_utils import Generator +from libai.layers import Embedding, Linear, RMSLayerNorm, VocabEmbedding +from libai.layers.attention import AttnMaskType +from libai.models.utils import init_method_normal, scaled_init_method_normal +from libai.utils import distributed as dist + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return flow.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + + def forward(self, x, seq_len=None, cos_cached=None, sin_cached=None): + if seq_len > self.max_position_embeddings: + raise ValueError( + f"The maximum supported length is {self.max_position_embeddings}, " + f"and the current length is{seq_len}." + ) + + return ( + cos_cached[:seq_len].to_global(placement=x.placement), + sin_cached[:seq_len].to_global(placement=x.placement), + ) + + +class MLP(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + *, + layer_idx=0, + ): + super().__init__() + + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.gate_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.up_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.down_proj = Linear( + intermediate_size, + hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.activation_func = nn.SiLU() + + def forward(self, hidden_states): + gate_out = self.activation_func(self.gate_proj(hidden_states)) + up_out = self.up_proj(hidden_states) + output = self.down_proj(gate_out * up_out) + return output + + +class MultiheadAttention(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + max_position_embeddings, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.num_heads = num_attention_heads + self.head_size = hidden_size // num_attention_heads + self.attn_mask_type = attn_mask_type + + self.norm_factor = 1.0 / math.sqrt(float(self.head_size)) + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.query_key_value = Linear( + self.hidden_size, + self.hidden_size * 3, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.o_proj = Linear( + self.hidden_size, + self.hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.coeff = None + + rotary_dim = self.head_size + self.rotary_embed = RotaryEmbedding( + dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + ) + + self.gate = flow.nn.Parameter( + flow.zeros( + 1, + self.num_heads, + 1, + 1, + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + ) + + def forward( + self, + hidden_states: flow.Tensor, + encoder_states: flow.Tensor = None, + attention_mask: flow.Tensor = None, + position_ids=None, + past_key_value: Tuple[flow.Tensor, flow.Tensor] = None, + cos_cached: flow.Tensor = None, + sin_cached: flow.Tensor = None, + use_cache: bool = False, + adapter=None, + ): + if encoder_states is not None: + encoder_states = encoder_states.to_global(placement=hidden_states.placement) + + if attention_mask is not None: + attention_mask = attention_mask.to_global(placement=hidden_states.placement) + + if adapter is not None: + adapter = adapter.to_global(placement=hidden_states.placement) + + bsz, tgt_len = hidden_states.size()[:2] + + query_key_value = self.query_key_value(hidden_states) + query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size) + query_key_value = query_key_value.permute( + 0, 2, 1, 3 + ) # [bsz, num_heads, src_len, 3 * head_size] + query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1) + + kv_seq_len = key.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_embed( + value, seq_len=kv_seq_len, cos_cached=cos_cached, sin_cached=sin_cached + ) + query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + + # [1, adapter_len, 4096] + if adapter is not None: + adapter_len = adapter.shape[1] + adapter_qkv = self.query_key_value(adapter) + adapter_qkv = adapter_qkv.view(1, -1, self.num_heads, 3 * self.head_size) + adapter_qkv = adapter_qkv.permute(0, 2, 1, 3) # [1, num_heads, src_len, 3 * head_size] + _, adapter_key, adapter_value = flow.chunk(adapter_qkv, chunks=3, dim=-1) + adapter_key = adapter_key.repeat(bsz, 1, 1, 1) + adapter_value = adapter_value.repeat(bsz, 1, 1, 1) + key = flow.cat([adapter_key, key], dim=2) + value = flow.cat([adapter_value, value], dim=2) + extra_mask = flow.zeros( + bsz, + 1, + tgt_len, + adapter_len, + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=attention_mask.placement, + ) + attention_mask = flow.cat([extra_mask, attention_mask], dim=-1) + + if past_key_value is not None: + past_key, past_value = past_key_value + key = flow.cat((past_key.type_as(key), key), dim=2) + value = flow.cat((past_value.type_as(value), value), dim=2) + + # query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size] + if use_cache: + past_key_value = (key, value) + + # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)] + attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor) + attention_weights = attention_scores + attention_mask + + if adapter is not None: + attention_weights = flow.cat( + [ + self.gate.tanh().half() + * F.softmax(attention_weights[:, :, :, :adapter_len].float(), dim=-1).to( + query.dtype + ), + F.softmax(attention_weights[:, :, :, adapter_len:].float(), dim=-1).to( + query.dtype + ), + ], + dim=-1, + ) + else: + attention_weights = flow.softmax(attention_weights, dim=-1) + # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)] + context = flow.matmul(attention_weights, value) + + # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size] + context = context.transpose(1, 2) + output = self.o_proj(context.flatten(2)) + + if use_cache: + output = (output, past_key_value) + + return output + + +class CasualMask(nn.Module): + def __init__(self, max_positions=1024, dtype=flow.float16, *, layer_idx=0): + super().__init__() + self.dtype = dtype + self.mask = flow.full( + (max_positions, max_positions), + flow.finfo(dtype).min, + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + mask_cond = flow.arange( + self.mask.size(-1), + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + self.mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.mask.size(-1), 1), 0) + self.mask = self.mask.to(dtype) + + def forward(self, input_ids, past_length=0, attention_mask=None, input_dtype=None): + bsz, tgt_len = input_ids.size() + casual_mask = self.mask[:tgt_len, :tgt_len] + if past_length > 0: + # in case past_key_values are used, we need to add a prefix ones mask to casual mask + casual_mask = flow.cat( + [flow.ones(tgt_len, past_length, dtype=self.dtype), casual_mask], dim=-1 + ) + casual_mask = ( + casual_mask.unsqueeze(0).unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len + past_length) + ) + casual_mask = casual_mask.to_global(sbp=input_ids.sbp) + if attention_mask is not None: + bsz, src_len = attention_mask.size() + attention_mask = ( + attention_mask[:, None, None, :] + .expand(bsz, 1, tgt_len, src_len) + .to(casual_mask.dtype) + ) + attention_mask = attention_mask.to_global(placement=casual_mask.placement) + casual_mask = casual_mask + attention_mask + if input_dtype is not None: + casual_mask = casual_mask.to(input_dtype) + return casual_mask + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + is_decoder=False, + rms_norm_eps=1e-5, + max_position_embeddings=None, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.max_position_embeddings = max_position_embeddings + self.attn_mask_type = attn_mask_type + + self.layer_idx = layer_idx + self.is_decoder = is_decoder + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.init_method = init_method + if output_layer_init_method is None: + output_layer_init_method = init_method + self.output_layer_init_method = output_layer_init_method + + self.input_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.self_attn = self.build_attention() + self.post_attention_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.mlp = MLP( + self.hidden_size, + self.intermediate_size, + self.init_method, + output_layer_init_method=self.output_layer_init_method, + layer_idx=self.layer_idx, + ) + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_value=None, + cos_cached=None, + sin_cached=None, + use_cache=False, + adapter=None, + ): + hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx)) + + # hidden_states shape: (batch_size, seq_length, hidden_size) + if attention_mask is not None: + attention_mask = attention_mask.to_global( + placement=dist.get_layer_placement(self.layer_idx) + ) + + if past_key_value is not None: + if self.is_decoder: + assert len(past_key_value) == 4 + self_attn_past_key_value = past_key_value[:2] + else: + self_attn_past_key_value = past_key_value + else: + self_attn_past_key_value = None + + layernorm_output = self.input_layernorm(hidden_states) + attention_output = self.self_attn( + layernorm_output, + attention_mask=attention_mask, + past_key_value=self_attn_past_key_value, + cos_cached=cos_cached, + sin_cached=sin_cached, + use_cache=use_cache, + adapter=adapter, + ) + + if use_cache: + attention_output, presents = attention_output + + hidden_states = hidden_states + attention_output + + layernorm_output = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(layernorm_output) + + output = hidden_states + mlp_output + + if use_cache: + output = (output, presents) + return output + + def build_attention(self): + return MultiheadAttention( + self.hidden_size, + self.num_attention_heads, + self.max_position_embeddings, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + scale_mask_softmax_fusion=self.scale_mask_softmax_fusion, + attn_mask_type=self.attn_mask_type, + layer_idx=self.layer_idx, + ) + + +class LlamaModel(nn.Module): + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + cfg=None, + ): + super().__init__() + self.cfg = cfg + init_method = init_method_normal(sigma=initializer_range) + if use_scaled_init_for_output_weights: + output_layer_init_method = scaled_init_method_normal(initializer_range, hidden_layers) + else: + output_layer_init_method = init_method + + self.embed_tokens = VocabEmbedding( + vocab_size, hidden_size, init_method=init_method, amp_enabled=amp_enabled + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + hidden_size, + intermediate_size, + num_attention_heads, + rms_norm_eps=rms_norm_eps, + max_position_embeddings=max_position_embeddings, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + attn_mask_type=AttnMaskType.causal, + layer_idx=i, + ) + for i in range(hidden_layers) + ] + ) + self.norm = RMSLayerNorm(hidden_size, eps=rms_norm_eps, layer_idx=-1) + + self.adapter_query = Embedding( + cfg.adapter_len * cfg.adapter_layer, hidden_size, amp_enabled=amp_enabled + ) + + self._set_cos_sin_cache( + rotary_dim=hidden_size // num_attention_heads, + seq_len=max_position_embeddings, + dtype=flow.float32, + layer_idx=0, + ) + + def _set_cos_sin_cache(self, rotary_dim, seq_len, base=10000, dtype=None, layer_idx=0): + position = flow.arange( + 0, + rotary_dim, + 2, + dtype=dtype, + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=dist.get_layer_placement(layer_idx), + ) + inv_freq = 1.0 / (base ** (position / rotary_dim)) + + t = flow.arange( + seq_len, + dtype=inv_freq.dtype, + sbp=inv_freq.sbp, + placement=inv_freq.placement, + ) + + freqs = flow.einsum("i,j->ij", t, inv_freq) + emb = flow.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype)) + self.register_buffer("sin_cached", emb.sin().to(dtype)) + + def forward( + self, + input_ids, + attention_mask=None, + past_key_values=None, + use_cache=False, + set_cache=None, + ): + with flow.no_grad(): + if use_cache: + presents = [] + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + hidden_states = self.embed_tokens(input_ids) + + for layer, past_key_value in zip( + self.layers[: -self.cfg.adapter_layer], past_key_values[: -self.cfg.adapter_layer] + ): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + use_cache=False, + adapter=None, + ) + if use_cache: + hidden_states, present = hidden_states + presents.append(present) + + adapter_index = 0 + # [num_adapter_layer, 1, adapter_len, 4096] + adapter = self.adapter_query.weight.reshape(-1, self.cfg.adapter_len, 4096).unsqueeze(1) + for layer, past_key_value in zip( + self.layers[-self.cfg.adapter_layer :], past_key_values[-self.cfg.adapter_layer :] + ): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + use_cache=False, + adapter=adapter[adapter_index], # [1, adapter_len, 4096] + ) + adapter_index += 1 + if use_cache: + hidden_states, present = hidden_states + presents.append(present) + + hidden_states = self.norm(hidden_states) + + if use_cache: + set_cache(presents) + + return hidden_states + + +class CrossEntropyLoss(nn.Module): + def forward(self, logits: flow.Tensor, target: flow.Tensor): + assert logits.ndim == 3 + assert target.ndim == 2 + assert logits.shape[0:2] == target.shape + + target = target.to_global(placement=logits.placement) + target = target * (target >= 0) + + lm_loss = flow._C.cross_entropy( + logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=0 + ) + return lm_loss + + +class SFTLoss(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lm_loss = CrossEntropyLoss() + + def forward(self, logits, lm_labels): + lm_loss = self.lm_loss(logits, lm_labels) + lm_loss = lm_loss.mean() + return {"lm_loss": lm_loss} + + +class LlamaForCausalLM(nn.Module, Generator): + @configurable + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + cfg=None, + ): + super().__init__() + self.cfg = cfg + self.model = LlamaModel( + hidden_layers=hidden_layers, + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=rms_norm_eps, + initializer_range=initializer_range, + use_scaled_init_for_output_weights=use_scaled_init_for_output_weights, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + amp_enabled=amp_enabled, + cfg=cfg, + ) + self.casual_mask = CasualMask(max_position_embeddings, layer_idx=0) + self.lm_head = Linear(hidden_size, vocab_size, bias=False, layer_idx=-1) + self.loss_func = SFTLoss() + + self.past_key_values = [None] * hidden_layers + self.past_length = 0 + + def forward(self, input_ids, attention_mask=None, labels=None, use_cache=False): + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + attention_mask = ( + attention_mask.to_global(placement=dist.get_layer_placement(0)) + if attention_mask is not None + else attention_mask + ) + labels = ( + labels.to_global(placement=dist.get_layer_placement(0)) + if labels is not None + else labels + ) + + if use_cache and self.past_key_values[0] is not None: + self.past_length = self.past_key_values[0][0].size(-2) + else: + self.past_length = 0 + + mask = self.casual_mask( + input_ids, + past_length=self.past_length, + attention_mask=attention_mask, + input_dtype=self.lm_head.weight.dtype, + ) + + output = self.model( + input_ids, + attention_mask=mask, + past_key_values=self.past_key_values, + use_cache=use_cache, + set_cache=self.set_cache, + ) + + logits = self.lm_head(output) + + if labels is not None: + lm_loss = self.loss_func(logits, labels) + return lm_loss + else: + return {"logits": logits} + + def set_cache(self, past_key_values): + self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2] + + if past_key_values is None: + past_key_values = [None] * self.cfg.hidden_layers + + assert len(past_key_values) == self.cfg.hidden_layers, ( + f"past_key_values's length {len(past_key_values)} doesn't match " + f"num_layers:' {self.cfg.hidden_layers}" + ) + + def prepare_inputs_for_generation(self, input_ids: flow.Tensor, **kwargs): + if "attention_mask" in kwargs: + attention_mask = kwargs.pop("attention_mask").float() + attention_mask = attention_mask - 1 + attention_mask.masked_fill_(attention_mask == -1, flow.finfo(flow.float32).min) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def from_config(cls, cfg): + return { + "hidden_layers": cfg.hidden_layers, + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_attention_heads": cfg.num_attention_heads, + "max_position_embeddings": cfg.max_position_embeddings, + "rms_norm_eps": cfg.rms_norm_eps, + "initializer_range": cfg.initializer_range, + "use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights, + "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion, + "amp_enabled": cfg.amp_enabled, + "cfg": cfg, + } + + @staticmethod + def set_activation_checkpoint(model): + for module_block in model.modules(): + # Old API in OneFlow 0.8 + if hasattr(module_block, "origin"): + if isinstance(module_block.origin, LlamaDecoderLayer): + module_block.config.activation_checkpointing = True + else: + if isinstance(module_block.to(nn.Module), LlamaDecoderLayer): + module_block.to(nn.graph.GraphModule).activation_checkpointing = True diff --git a/projects/Llama/adapter/adapter_sft.py b/projects/Llama/adapter/adapter_sft.py new file mode 100644 index 000000000..e95e012bb --- /dev/null +++ b/projects/Llama/adapter/adapter_sft.py @@ -0,0 +1,97 @@ +import os + +from omegaconf import OmegaConf + +from configs.common.models.graph import graph +from configs.common.optim import optim +from configs.common.train import train +from libai.config import LazyCall +from libai.data.build import build_nlp_test_loader, build_nlp_train_loader +from libai.evaluation import PPLEvaluator +from libai.scheduler import WarmupExponentialLR +from projects.Llama.adapter.adapter_config import cfg +from projects.Llama.adapter.adapter_model import LlamaForCausalLM +from projects.Llama.dataset import AlpacaDataset +from projects.Llama.tokenizer import LlamaTokenizer + +# Hyperparameters +weight_decay = 0.1 +learning_rate = 2e-5 +max_input_length = 512 +dataset_path = "alpaca_data" +pretrained_model_path = "meta-llama/Llama-2-7b-hf" + +# graph & optim +graph["enabled"] = False +optim.update( + dict( + lr=learning_rate, + weight_decay=weight_decay, + ) +) + +# tokenize +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(LlamaTokenizer)( + pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model") +) + +# model +cfg.use_cache = False +model = LazyCall(LlamaForCausalLM)(cfg=cfg) + +# datasets +dataloader = OmegaConf.create() +dataloader.train = LazyCall(build_nlp_train_loader)( + dataset=[ + LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer + ) + ], +) +dataloader.test = [ + LazyCall(build_nlp_test_loader)( + dataset=LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer + ), + ), +] + + +train.update( + dict( + output_dir="./sft_result", + train_micro_batch_size=8, + test_micro_batch_size=1, + train_epoch=3, + train_iter=1, + log_period=10, + warmup_ratio=2 / 5, + num_accumulation_steps=8, + rdma_enabled=False, + amp=dict(enabled=True), + activation_checkpoint=dict(enabled=True), + checkpointer=dict( + period=5000, + max_to_keep=20, + ), + dist=dict( + data_parallel_size=1, + tensor_parallel_size=1, + pipeline_parallel_size=8, + pipeline_num_layers=cfg.hidden_layers, + ), + evaluation=dict( + enabled=True, + evaluator=LazyCall(PPLEvaluator)(), + eval_period=1000, + eval_iter=100, + ), + scheduler=LazyCall(WarmupExponentialLR)( + warmup_factor=0.0, + gamma=1.0, + warmup_method="linear", + ), + ) +) diff --git a/projects/Llama/adapter/train_net.py b/projects/Llama/adapter/train_net.py new file mode 100644 index 000000000..327b95e6f --- /dev/null +++ b/projects/Llama/adapter/train_net.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import random +import sys + +import numpy as np +import oneflow as flow + +import libai.utils.distributed as dist +from libai.config import LazyConfig, default_argument_parser, try_get_key +from libai.engine import DefaultTrainer, default_setup +from libai.utils.checkpoint import Checkpointer +from projects.Llama.utils.llama_loader import LlamaLoaderHuggerFace + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +logger = logging.getLogger("libai." + __name__) + + +def build_model(cfg): + model_loader = LlamaLoaderHuggerFace( + cfg, + cfg.cfg, + cfg.cfg.pretrained_model_path, + ) + model = model_loader.load() + + for name, param in model.named_parameters(): + if "adapter" not in name: + param.requires_grad = False + else: + param.requires_grad = True + param.data = param.data.float() + + for name, param in model.model.layers[-cfg.cfg.adapter_layer :].named_parameters(): + if "gate" in name or "adapter" in name: + param.data = param.data.float() + param.requires_grad = True + + return model + + +class LlamaTrainer(DefaultTrainer): + @classmethod + def build_model(cls, cfg): + assert try_get_key(cfg, "model") is not None, "cfg must contain `model` namespace" + # Set model fp16 option because of embedding layer `white_identity` manual + # insert for amp training if provided. + if try_get_key(cfg.model, "cfg.amp_enabled") is not None: + cfg.model.cfg.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled + # In case some model define without cfg keyword. + elif try_get_key(cfg.model, "amp_enabled") is not None: + cfg.model.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled + model = build_model(cfg.model) + logger = logging.getLogger(__name__) + logger.info("Model:\n{}".format(model)) + model._apply(dist.convert_to_distributed_default_setting) + return model + + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + seed_for_rank = cfg.train.seed + flow.env.get_rank() + flow.manual_seed(seed_for_rank) + flow.cuda.manual_seed(seed_for_rank) + np.random.seed(seed_for_rank) + random.seed(seed_for_rank) + + if args.fast_dev_run: + cfg.train.train_epoch = 0 + cfg.train.train_iter = 20 + cfg.train.evaluation.eval_period = 10 + cfg.train.log_period = 1 + + if args.eval_only: + tokenizer = None + if try_get_key(cfg, "tokenization") is not None: + tokenizer = DefaultTrainer.build_tokenizer(cfg) + model = DefaultTrainer.build_model(cfg) + Checkpointer(model, save_dir=cfg.train.output_dir).resume_or_load( + cfg.train.load_weight, resume=args.resume + ) + if try_get_key(cfg, "graph.enabled", default=False): + model = DefaultTrainer.build_graph(cfg, model, is_train=False) + test_loader = DefaultTrainer.build_test_loader(cfg, tokenizer) + if len(test_loader) == 0: + logger.info("No dataset in dataloader.test, please set dataset for dataloader.test") + _ = DefaultTrainer.test(cfg, test_loader, model) + return + + trainer = LlamaTrainer(cfg) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + main(args) diff --git a/projects/Llama/configs/llama_config.py b/projects/Llama/configs/llama_config.py index 715b96dd3..58b86ecd6 100644 --- a/projects/Llama/configs/llama_config.py +++ b/projects/Llama/configs/llama_config.py @@ -15,7 +15,6 @@ max_position_embeddings=4096, num_attention_heads=32, hidden_layers=32, - num_key_value_heads=32, pretraining_tp=1, rms_norm_eps=1e-05, rope_scaling=None, @@ -58,5 +57,5 @@ tokenization = OmegaConf.create() tokenization.make_vocab_size_divisible_by = 1 tokenization.tokenizer = LazyCall(LlamaTokenizer)( - pretrained_model_path="Llama-2-7b-hf/tokenizer.model" + pretrained_model_path="meta-llama/Llama-2-7b-hf/tokenizer.model" ) diff --git a/projects/Llama/configs/llama_sft.py b/projects/Llama/configs/llama_sft.py index 322c88c1a..e767d84d7 100644 --- a/projects/Llama/configs/llama_sft.py +++ b/projects/Llama/configs/llama_sft.py @@ -18,13 +18,12 @@ # Hyperparameters weight_decay = 0.1 -learning_rate = 2e-5 -max_input_length = 1350 +learning_rate = 5e-5 dataset_path = "alpaca_data" pretrained_model_path = "meta-llama/Llama-2-7b-hf" # graph & optim -graph["enabled"] = True +graph["enabled"] = False optim.update( dict( lr=learning_rate, @@ -47,18 +46,14 @@ dataloader.train = LazyCall(build_nlp_train_loader)( dataset=[ LazyCall(AlpacaDataset)( - path=os.path.join(dataset_path, "train"), - tokenizer=tokenization.tokenizer, - max_len=max_input_length, + path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer ) ], ) dataloader.test = [ LazyCall(build_nlp_test_loader)( dataset=LazyCall(AlpacaDataset)( - path=os.path.join(dataset_path, "test"), - tokenizer=tokenization.tokenizer, - max_len=max_input_length, + path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer ), ), ] @@ -67,30 +62,30 @@ train.update( dict( output_dir="./sft_result", - train_micro_batch_size=2, + train_micro_batch_size=4, test_micro_batch_size=1, - train_epoch=5, + train_epoch=3, train_iter=1, log_period=10, - warmup_ratio=2 / 5, + warmup_ratio=1 / 3, num_accumulation_steps=8, - rdma_enabled=True, + rdma_enabled=False, amp=dict(enabled=True), activation_checkpoint=dict(enabled=True), checkpointer=dict( - period=100, + period=5000, max_to_keep=20, ), dist=dict( - data_parallel_size=2, + data_parallel_size=1, tensor_parallel_size=1, - pipeline_parallel_size=4, + pipeline_parallel_size=8, pipeline_num_layers=cfg.hidden_layers, ), evaluation=dict( enabled=True, evaluator=LazyCall(PPLEvaluator)(), - eval_period=100, + eval_period=1000, eval_iter=1e5, ), scheduler=LazyCall(WarmupExponentialLR)( diff --git a/projects/Llama/dataset.py b/projects/Llama/dataset.py index b500998b1..d78efe9fe 100644 --- a/projects/Llama/dataset.py +++ b/projects/Llama/dataset.py @@ -1,46 +1,19 @@ -# coding=utf-8 -# Copyright 2021 The OneFlow Authors. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random - import oneflow as flow from oneflow.utils.data import Dataset from libai.data.structures import DistTensorData, Instance -def pad_right(data, pad_id=0, max_len=1350): - n = max_len - data.shape[0] - return flow.cat((data, flow.full((n,), pad_id, dtype=data.dtype))) - - class AlpacaDataset(Dataset): - def __init__(self, path, tokenizer, max_len=1350): + def __init__(self, path, tokenizer): self.data = flow.load(path) - random.shuffle(self.data) self.tokenizer = tokenizer - self.max_len = max_len def __len__(self): return len(self.data) def __getitem__(self, index): - input_ids = pad_right(self.data[index]["input_ids"], pad_id=0, max_len=self.max_len) - labels = pad_right(self.data[index]["labels"], pad_id=-1, max_len=self.max_len) - return Instance( - input_ids=DistTensorData(input_ids), - labels=DistTensorData(labels), + input_ids=DistTensorData(self.data[index]["input_ids"]), + labels=DistTensorData(self.data[index]["labels"]), ) diff --git a/projects/Llama/llama.py b/projects/Llama/llama.py index 1736253bb..ea1b73541 100644 --- a/projects/Llama/llama.py +++ b/projects/Llama/llama.py @@ -22,7 +22,7 @@ from libai.config import configurable from libai.inference.generator.generation_utils import Generator -from libai.layers import Linear, ParallelCrossEntropyLoss, RMSLayerNorm, VocabEmbedding +from libai.layers import Linear, RMSLayerNorm, VocabEmbedding from libai.layers.attention import AttnMaskType from libai.models.utils import init_method_normal, scaled_init_method_normal from libai.utils import distributed as dist @@ -394,7 +394,6 @@ def __init__( hidden_size, intermediate_size, num_attention_heads, - num_key_value_heads, max_position_embeddings=1024, rms_norm_eps=1e-5, initializer_range=0.02, @@ -495,10 +494,25 @@ def forward( return hidden_states +class CrossEntropyLoss(nn.Module): + def forward(self, logits: flow.Tensor, target: flow.Tensor): + assert logits.ndim == 3 + assert target.ndim == 2 + assert logits.shape[0:2] == target.shape + + target = target.to_global(placement=logits.placement) + target = target * (target >= 0) + + lm_loss = flow._C.cross_entropy( + logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=0 + ) + return lm_loss + + class SFTLoss(nn.Module): def __init__(self) -> None: super().__init__() - self.lm_loss = ParallelCrossEntropyLoss() + self.lm_loss = CrossEntropyLoss() def forward(self, logits, lm_labels): lm_loss = self.lm_loss(logits, lm_labels) @@ -515,7 +529,6 @@ def __init__( hidden_size, intermediate_size, num_attention_heads, - num_key_value_heads, max_position_embeddings=1024, rms_norm_eps=1e-5, initializer_range=0.02, @@ -532,7 +545,6 @@ def __init__( hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, max_position_embeddings=max_position_embeddings, rms_norm_eps=rms_norm_eps, initializer_range=initializer_range, @@ -614,7 +626,6 @@ def from_config(cls, cfg): "hidden_size": cfg.hidden_size, "intermediate_size": cfg.intermediate_size, "num_attention_heads": cfg.num_attention_heads, - "num_key_value_heads": cfg.num_key_value_heads, "max_position_embeddings": cfg.max_position_embeddings, "rms_norm_eps": cfg.rms_norm_eps, "initializer_range": cfg.initializer_range, diff --git a/projects/Llama/pipeline.py b/projects/Llama/pipeline.py index 0b936da67..bea4a2f56 100644 --- a/projects/Llama/pipeline.py +++ b/projects/Llama/pipeline.py @@ -114,7 +114,9 @@ def postprocess(self, model_output_dict, **kwargs) -> dict: mode="libai", ) - text = ["a dog is flying on the sky", "Wikipedia is a free online", "what is beam search?"] + text = [ + "Give three tips for staying healthy.", + ] output = pipeline(inputs=text) if dist.is_main_process(): print(output) diff --git a/projects/Llama/readme.md b/projects/Llama/readme.md index a7ab82577..9adb3d925 100644 --- a/projects/Llama/readme.md +++ b/projects/Llama/readme.md @@ -24,8 +24,11 @@ python projects/Llama/utils/prepare_alpaca.py ### 3. Run the following code to start SFT ```bash -# cd /path/to/libai +# full finetune bash tools/train.sh projects/Llama/train_net.py projects/Llama/configs/llama_sft.py 8 + +# adapter finetune +bash tools/train.sh projects/Llama/adapter/train_net.py projects/Llama/adapter/adapter_sft.py 8 ``` ## Evaluate diff --git a/projects/Llama/utils/eval_adapter.py b/projects/Llama/utils/eval_adapter.py index 717107653..954c47df4 100644 --- a/projects/Llama/utils/eval_adapter.py +++ b/projects/Llama/utils/eval_adapter.py @@ -150,7 +150,7 @@ def run_eval_harness( parallel_config = DictConfig( dict( data_parallel_size=1, - tensor_parallel_size=1, + tensor_parallel_size=8, pipeline_parallel_size=1, pipeline_num_layers=32, device_type="cuda", diff --git a/projects/Llama/utils/llama_loader.py b/projects/Llama/utils/llama_loader.py index 59b46343a..20b9ba258 100644 --- a/projects/Llama/utils/llama_loader.py +++ b/projects/Llama/utils/llama_loader.py @@ -43,8 +43,6 @@ def _convert_state_dict(self, flow_state_dict, cfg): # Get configs num_attention_heads = cfg.get("num_attention_heads") - num_key_value_heads = cfg.get("num_key_value_heads") - assert num_attention_heads == num_key_value_heads hidden_size = cfg.get("hidden_size") head_size = int(hidden_size // num_attention_heads) @@ -83,7 +81,6 @@ def _load_config_from_json(self, config_file): self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"]) self._update_cfg("hidden_size", cfg_dict["hidden_size"]) self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"]) - self._update_cfg("num_key_value_heads", cfg_dict["num_key_value_heads"]) self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"]) self._update_cfg("intermediate_size", cfg_dict["intermediate_size"]) self._update_cfg("rms_norm_eps", cfg_dict["rms_norm_eps"]) diff --git a/projects/Llama/utils/prepare_alpaca.py b/projects/Llama/utils/prepare_alpaca.py index ce2fc92b1..c21f505fb 100644 --- a/projects/Llama/utils/prepare_alpaca.py +++ b/projects/Llama/utils/prepare_alpaca.py @@ -1,4 +1,5 @@ """Implementation derived from https://github.com/tloen/alpaca-lora""" +import copy import json import math import os @@ -18,18 +19,17 @@ def prepare( - destination_path: Path = Path("/alpaca_data"), - checkpoint_dir: Path = Path("/Llama-2-7b-hf"), + destination_path: Path = Path("alpaca_data"), + checkpoint_dir: Path = Path("meta-llama/Llama-2-7b-hf"), test_split_fraction: float = 0.03865, # to get exactly 2000 test samples, seed: int = 42, mask_inputs: bool = False, # as in alpaca-lora data_file_name: str = "alpaca_data_cleaned_archive.json", data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", # noqa ignore_index: int = -1, - max_seq_length: Optional[int] = None, + max_seq_length: Optional[int] = 512, ) -> None: """Prepare the Alpaca dataset for instruction tuning. - The output is a training and test dataset saved as `train.pt` and `test.pt`, which stores the preprocessed and tokenized prompts and labels. """ @@ -67,8 +67,6 @@ def prepare( example=sample, tokenizer=tokenizer, max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, ) for sample in tqdm(train_set) ] @@ -80,8 +78,6 @@ def prepare( example=sample, tokenizer=tokenizer, max_length=max_seq_length, - mask_inputs=mask_inputs, - ignore_index=ignore_index, ) for sample in tqdm(test_set) ] @@ -99,46 +95,47 @@ def download_if_missing(file_path: Path, file_url: str) -> None: f.write(requests.get(file_url).text) -def prepare_sample( - example: dict, tokenizer, max_length: int, mask_inputs: bool, ignore_index: int -) -> dict: +def prepare_sample(example: dict, tokenizer, max_length: int) -> dict: """Processes a single sample. - Each sample in the dataset consists of: - instruction: A string describing the task - input: A string holding a special input value for the instruction. This only applies to some samples, and in others this is empty. - output: The response string - This function processes this data to produce a prompt text and a label for supervised training. The prompt text is formed as a single message including both the instruction and the input. The label/target is the same message but with the response attached. - Finally, both the prompt and the label get tokenized. If desired, all tokens in the label that correspond to the original input prompt get masked out (default). """ full_prompt = generate_prompt(example) full_prompt_and_response = full_prompt + example["output"] - encoded_full_prompt = tokenizer.tokenize( - full_prompt, max_length=max_length, device="cpu" - ).squeeze(0) - encoded_full_prompt_and_response = tokenizer.tokenize( - full_prompt_and_response, add_eos=True, max_length=max_length, device="cpu" - ).squeeze(0) - - # The labels are the full prompt with response, but with the prompt masked out - labels = encoded_full_prompt_and_response.clone() - encoded_full_prompt_and_response = encoded_full_prompt_and_response[:-1] + prompt = tokenizer.tokenize(full_prompt, add_bos=True, add_eos=False, device="cpu")[0] + example = tokenizer.tokenize( + full_prompt_and_response, add_bos=True, add_eos=True, device="cpu" + )[0] + + padding = max_length - example.shape[0] + if padding > 0: + example = flow.cat((example, flow.zeros(padding, dtype=flow.long) - 1)) + elif padding < 0: + example = example[:max_length] + labels = copy.deepcopy(example) + labels[: len(prompt)] = -1 + example_mask = example.ge(0) + label_mask = labels.ge(0) + example[~example_mask] = 0 + labels[~label_mask] = -1 + example = example[:-1] labels = labels[1:] - if mask_inputs: - labels[: len(encoded_full_prompt)] = ignore_index - + example_mask = flow.where( + example_mask, flow.tensor(0, dtype=flow.float), flow.tensor(-float("inf")) + ) + example_mask = example_mask[:-1] return { - **example, - "input_ids": encoded_full_prompt_and_response, - "input_ids_no_response": encoded_full_prompt, + "input_ids": example, "labels": labels, }