Llama2 #524

Merged 22 commits on Dec 18, 2023
11 changes: 8 additions & 3 deletions libai/inference/generator/generation_utils.py
@@ -124,7 +124,10 @@ def _prepare_attention_mask_for_generation(
pad_token_id: Optional[int],
eos_token_id: Optional[int],
):
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [flow.int64, flow.long]
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [
flow.int64,
flow.long,
]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
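Note: the dtype check above is only reformatted across lines; its behaviour is unchanged. For context, here is a minimal sketch of how these three flags are typically combined into an attention mask. The helper name make_attention_mask is hypothetical and follows the usual Hugging Face-style convention, not necessarily LibAI's exact code below this hunk.

import oneflow as flow

def make_attention_mask(inputs, pad_token_id, eos_token_id):
    # Build the mask from pad positions only when that is unambiguous,
    # otherwise fall back to attending to every position.
    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [flow.int64, flow.long]
    is_pad_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
    pad_differs_from_eos = (eos_token_id is None) or (pad_token_id != eos_token_id)

    if is_input_ids and is_pad_in_inputs and pad_differs_from_eos:
        return inputs.ne(pad_token_id).long()  # 0 at padded positions, 1 elsewhere
    return flow.ones(inputs.shape[0], inputs.shape[1], dtype=flow.long)

ids = flow.tensor([[1, 5, 5, 0, 0]], dtype=flow.long)
print(make_attention_mask(ids, pad_token_id=0, eos_token_id=2))  # [[1, 1, 1, 0, 0]]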
@@ -502,7 +505,7 @@ def greedy_search(
next_tokens = next_tokens.to_global(placement=input_ids.placement)
unfinished_sequences = unfinished_sequences.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
placement=input_ids.placement,
)

if eos_token_id is not None:
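Note: the placement change keeps unfinished_sequences on the same devices as input_ids instead of pinning it to pipeline stage 0, so later elementwise updates in greedy_search see consistent placements. A small standalone sketch of the pattern; the single-rank CPU placement is only for illustration and stands in for whatever placement input_ids actually has.

import oneflow as flow

placement = flow.placement("cpu", ranks=[0])  # stand-in for input_ids.placement
input_ids = flow.ones((1, 8), dtype=flow.long).to_global(
    sbp=flow.sbp.broadcast, placement=placement
)
unfinished_sequences = flow.ones(input_ids.shape[0], dtype=flow.long).to_global(
    sbp=flow.sbp.broadcast,
    placement=input_ids.placement,  # follow input_ids rather than a hard-coded stage
)
print(unfinished_sequences.placement)  # same placement as input_ids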
@@ -987,7 +990,9 @@ def generate(

# 8. Prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
max_length=max_length,
max_time=max_time,
stopping_criteria=stopping_criteria,
)

# 9. Go into different generation modes
37 changes: 24 additions & 13 deletions libai/models/utils/model_loader/base_loader.py
@@ -383,7 +383,6 @@ def _convert_tensor(self, tensor):
Returns:
flow.Tensor: The target tensor.
"""
tensor = tensor.float()
return flow.Tensor(tensor.detach().cpu().numpy())

def _convert_tensors(self, torch_state_dict):
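Note: dropping the tensor.float() upcast means the checkpoint tensor reaches the numpy round trip in whatever dtype torch stored it in. A self-contained sketch of this conversion path; torch_to_flow is a hypothetical name, and this is only the helper, not the full loader.

import oneflow as flow
import torch

def torch_to_flow(tensor):
    # Detach from the torch graph, move to CPU, and rebuild as a oneflow tensor.
    # The previous version forced float32 first; this version hands numpy the
    # checkpoint's original dtype.
    return flow.Tensor(tensor.detach().cpu().numpy())

weight = torch.randn(4, 4).half()  # simulate a half-precision checkpoint tensor
converted = torch_to_flow(weight)
print(converted.shape, converted.dtype)  # inspect what oneflow made of the fp16 array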
@@ -465,8 +464,15 @@ def _load_torch_state_dict(self, state_dict_file):
raise ImportError("Load torch state dict need torch.")

# load pytorch_model.bin
state_dict = torch.load(state_dict_file, map_location="cpu")
return state_dict
if isinstance(state_dict_file, str):
return torch.load(state_dict_file, map_location="cpu")

if isinstance(state_dict_file, list):
merged_state_dict = {}
for file in state_dict_file:
state_dict = torch.load(file, map_location="cpu")
merged_state_dict.update(state_dict)
return merged_state_dict

def _update_cfg(self, keys_libai, value_target):
"""Update the libai_cfg according to target_cfg.
@@ -491,11 +497,12 @@ def _update_cfg_log(self):
f"changed libai model cfg {temp_key} : "
f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]} "
)
logger.warning(
"The following model configurations has been modified according "
"to `config.json` or kwargs: \n"
f"{self.changed_keys} \n"
)
if len(self.changed_keys) > 0:
logger.warning(
"The following model configurations has been modified according "
"to `config.json` or kwargs: \n"
f"{self.changed_keys} \n"
)

if dist.get_pipeline_parallel_size() > 1:
logger.warning(
@@ -528,11 +535,15 @@ def load(self):
if dist.is_main_process():
if os.path.isdir(self.pretrained_model_path):
# state_dict file pytorch
if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
else:
model_files = [
os.path.join(self.pretrained_model_path, file)
for file in os.listdir(self.pretrained_model_path)
if file.endswith(".bin")
]

if len(model_files) == 0:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME_PT} found"
f"Error: no file named endswith '.bin' found"
f"in directory {self.pretrained_model_path}."
)

@@ -554,7 +565,7 @@
raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")

logger.info("loading torch model...")
torch_state_dict = self._load_torch_state_dict(model_file)
torch_state_dict = self._load_torch_state_dict(model_files)
torch_state_dict = self._fix_key(torch_state_dict)
logger.info("transfering torch model into oneflow model...")
flow_state_dict = self._convert_tensors(torch_state_dict)
2 changes: 1 addition & 1 deletion libai/tokenizer/build.py
@@ -24,7 +24,7 @@ def build_tokenizer(cfg):
"""Initialize tokenizer."""
tokenizer = instantiate(cfg.tokenizer)

if cfg.append_eod and tokenizer.eod_token is None:
if cfg.get("append_eod", None) and tokenizer.eod_token is None:
if tokenizer.eos_token is not None:
tokenizer.eod_token = tokenizer.eos_token
else:
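Note: switching from attribute access to cfg.get("append_eod", None) lets tokenizer configs that never define append_eod (such as the new Llama tokenization config) pass through build_tokenizer without tripping on the missing key. A small illustration with OmegaConf; the config node below is hypothetical.

from omegaconf import OmegaConf

cfg = OmegaConf.create({"make_vocab_size_divisible_by": 1})  # no append_eod entry

# .get() tolerates the missing key and returns the supplied default,
# so the eod fix-up branch is simply skipped for such configs.
print(cfg.get("append_eod", None))  # None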
62 changes: 62 additions & 0 deletions projects/Llama/configs/llama_config.py
@@ -0,0 +1,62 @@
from omegaconf import DictConfig, OmegaConf

from libai.config import LazyCall
from projects.Llama.llama import LlamaForCausalLM
from projects.Llama.tokenizer import LlamaTokenizer
from configs.common.train import train


cfg = dict(
# Model
hidden_act="silu",
hidden_size=4096,
initializer_range=0.02,
intermediate_size=11008,
max_position_embeddings=4096,
num_attention_heads=32,
hidden_layers=32,
num_key_value_heads=32,
pretraining_tp=1,
rms_norm_eps=1e-05,
rope_scaling=None,
tie_word_embeddings=False,
vocab_size=32000,
use_scaled_init_for_output_weights=False,
scale_mask_softmax_fusion=False,
amp_enabled=True,
# Inference
is_encoder_decoder=False,
max_length=256,
min_length=0,
do_sample=False,
early_stopping=False,
num_beams=1,
num_beam_groups=1,
diversity_penalty=0.0,
temperature=0.9,
top_k=50,
top_p=0.6,
typical_p=1.0,
repetition_penalty=1.0,
length_penalty=1.0,
no_repeat_ngram_size=0,
encoder_no_repeat_ngram_size=0,
num_return_sequences=1,
chunk_size_feed_forward=0,
output_scores=False,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
pad_token_id=0,
# train
pretrained_model_path="meta-llama/Llama-2-7b-hf",
)

cfg = DictConfig(cfg)

model = LazyCall(LlamaForCausalLM)(cfg=cfg)
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
pretrained_model_path="Llama-2-7b-hf/tokenizer.model"
)
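Note: a hedged sketch of how a config module like this is typically consumed through LibAI's detectron2-style lazy-config machinery. The exact entry points may differ, and the paths are illustrative; instantiating the full 7B model and the tokenizer requires the checkpoint and tokenizer.model files to exist locally.

from libai.config import LazyConfig, instantiate

config = LazyConfig.load("projects/Llama/configs/llama_config.py")
tokenizer = instantiate(config.tokenization.tokenizer)  # needs tokenizer.model on disk
model = instantiate(config.model)  # builds LlamaForCausalLM(cfg=...); needs substantial memory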
102 changes: 102 additions & 0 deletions projects/Llama/configs/llama_sft.py
@@ -0,0 +1,102 @@
import os
from omegaconf import OmegaConf

from libai.config import LazyCall
from libai.evaluation import PPLEvaluator
from libai.scheduler import WarmupExponentialLR
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader

from configs.common.train import train
from configs.common.models.graph import graph
from configs.common.optim import optim

from projects.Llama.configs.llama_config import cfg
from projects.Llama.dataset import AlpacaDataset
from projects.Llama.tokenizer import LlamaTokenizer
from projects.Llama.llama import LlamaForCausalLM


# Hyperparameters
weight_decay = 0.1
learning_rate = 2e-5
max_input_length = 1350
dataset_path = "alpaca_data"
pretrained_model_path = "meta-llama/Llama-2-7b-hf"

# graph & optim
graph["enabled"] = True
optim.update(
dict(
lr=learning_rate,
weight_decay=weight_decay,
)
)

# tokenize
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model")
)

# model
model = LazyCall(LlamaForCausalLM)(cfg=cfg)

# datasets
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(AlpacaDataset)(
path=os.path.join(dataset_path, "train"),
tokenizer=tokenization.tokenizer,
max_len=max_input_length,
)
],
)
dataloader.test = [
LazyCall(build_nlp_test_loader)(
dataset=LazyCall(AlpacaDataset)(
path=os.path.join(dataset_path, "test"),
tokenizer=tokenization.tokenizer,
max_len=max_input_length,
),
),
]


train.update(
dict(
output_dir="./sft_result",
train_micro_batch_size=2,
test_micro_batch_size=1,
train_epoch=5,
train_iter=1,
log_period=10,
warmup_ratio=2 / 5,
num_accumulation_steps=8,
rdma_enabled=True,
amp=dict(enabled=True),
activation_checkpoint=dict(enabled=True),
checkpointer=dict(
period=100,
max_to_keep=20,
),
dist=dict(
data_parallel_size=2,
tensor_parallel_size=1,
pipeline_parallel_size=4,
pipeline_num_layers=cfg.hidden_layers,
),
evaluation=dict(
enabled=True,
evaluator=LazyCall(PPLEvaluator)(),
eval_period=100,
eval_iter=1e5,
),
scheduler=LazyCall(WarmupExponentialLR)(
warmup_factor=0.0,
gamma=1.0,
warmup_method="linear",
),
)
)
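Note: for orientation, the dist block above spans 2 x 1 x 4 = 8 devices, and with gradient accumulation each optimizer step covers train_micro_batch_size x num_accumulation_steps x data_parallel_size samples. A quick check of that arithmetic; the values are copied from the config, while the global-batch formula is the usual Megatron-style convention and is stated here as an assumption rather than read from LibAI's source.

# Values from the train.update(...) block above.
data_parallel_size = 2
tensor_parallel_size = 1
pipeline_parallel_size = 4
train_micro_batch_size = 2
num_accumulation_steps = 8

devices = data_parallel_size * tensor_parallel_size * pipeline_parallel_size
global_batch = train_micro_batch_size * num_accumulation_steps * data_parallel_size
print(devices)       # 8
print(global_batch)  # 32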
46 changes: 46 additions & 0 deletions projects/Llama/dataset.py
@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import oneflow as flow
from oneflow.utils.data import Dataset

from libai.data.structures import DistTensorData, Instance


def pad_right(data, pad_id=0, max_len=1350):
n = max_len - data.shape[0]
return flow.cat((data, flow.full((n,), pad_id, dtype=data.dtype)))


class AlpacaDataset(Dataset):
def __init__(self, path, tokenizer, max_len=1350):
self.data = flow.load(path)
random.shuffle(self.data)
self.tokenizer = tokenizer
self.max_len = max_len

def __len__(self):
return len(self.data)

def __getitem__(self, index):
input_ids = pad_right(self.data[index]["input_ids"], pad_id=0, max_len=self.max_len)
labels = pad_right(self.data[index]["labels"], pad_id=-1, max_len=self.max_len)

return Instance(
input_ids=DistTensorData(input_ids),
labels=DistTensorData(labels),
)
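Note: a quick usage sketch for the padding helper above. The token ids are arbitrary examples; the dataset itself expects a flow.load-able list of dicts with input_ids and labels, and labels are padded with -1, presumably so padded positions are ignored by the loss.

import oneflow as flow

ids = flow.tensor([1, 306, 4658, 2], dtype=flow.long)  # arbitrary example ids
padded = pad_right(ids, pad_id=0, max_len=8)
print(padded.shape)  # (8,)
print(padded)        # [1, 306, 4658, 2, 0, 0, 0, 0]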