This repository was archived by the owner on Oct 19, 2024. It is now read-only.

[EXAMPLE] Add llama finetune #923

Open · wants to merge 12 commits into main · Changes from 1 commit
support manual sharding
ZYHowell committed Apr 30, 2023

commit ad6000ee8101d9c498817569183eeef82fbeb9c1
111 changes: 63 additions & 48 deletions examples/opt_finetune/run_easylm_flax.py
@@ -1,9 +1,10 @@
# TODO:
# 1. Import Llama Model Definition(done);
-# 2. Import Manual partition spec;
+# 2. Import Manual partition spec(done);
# 3. Import Fastchat dataset;
# 4. Weight Conversion(done);
# 5. Distributed load/store.
+# 6. wandb support

#!/usr/bin/env python
# coding=utf-8
@@ -48,33 +49,31 @@

import alpa
from alpa.model.model_util import DynamicScale, TrainState
from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption
+from alpa import ManualShardingOption
import jax
from jax.experimental.pjit import PartitionSpec
import jax.numpy as jnp
import optax
import transformers
from transformers.testing_utils import CaptureLogger
from transformers.utils import get_full_repo_name, send_example_telemetry
import tensorflow as tf
from flax import traverse_util
+from optax import tree_map_params
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxAutoModelForCausalLM,
HfArgumentParser,
is_tensorboard_available,
set_seed,
)

alpa.init(cluster="ray")

from transformers.testing_utils import CaptureLogger
from transformers.utils import get_full_repo_name, send_example_telemetry

tf.config.experimental.set_visible_devices([], 'GPU')

-from EasyLM.EasyLM.models.llama.llama_model import (
+from EasyLM.models.llama.llama_model import (
LLaMAConfig, FlaxLLaMAForCausalLMModule, FlaxLLaMAForCausalLM
)
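
For anyone reproducing this: the fix assumes the EasyLM checkout's root directory (the one containing the EasyLM package folder) is on PYTHONPATH; the old double-qualified path only worked when the checkout's parent directory was on the path instead. A quick sanity check of the corrected import (the clone location below is hypothetical):

import sys
sys.path.insert(0, "/path/to/EasyLM")  # hypothetical clone location (repo root)

from EasyLM.models.llama.llama_model import LLaMAConfig
print(LLaMAConfig.__module__)  # expect: EasyLM.models.llama.llama_model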

@@ -381,6 +380,46 @@ def call(
setattr(FlaxOPTDecoderLayerCollection, "__call__", call)


+def llama_manual_sharding(num_layers, state: TrainState):
+    # TODO: when rebased to jax 0.4.6, use the tree_map_with_path
+    param_partition = {
+        'transformer': {
+            'wte': {'embedding': PartitionSpec("mp", None)},
+            'ln_f': {'kernel': PartitionSpec(None)},
+            'h': {
+                '%d' % (layer): {
+                    'attention': {
+                        # TODO: check whether we need the transpose or not
+                        'wq': {'kernel': PartitionSpec(None, "mp")},
+                        'wk': {'kernel': PartitionSpec(None, "mp")},
+                        'wv': {'kernel': PartitionSpec(None, "mp")},
+                        'wo': {'kernel': PartitionSpec("mp", None)},
+                    },
+                    'feed_forward': {
+                        'w1': {'kernel': PartitionSpec(None, "mp")},
+                        'w2': {'kernel': PartitionSpec("mp", None)},
+                        'w3': {'kernel': PartitionSpec(None, "mp")},
+                    },
+                    'attention_norm': {'kernel': PartitionSpec(None)},
+                    'ffn_norm': {'kernel': PartitionSpec(None)},
+                }
+                for layer in range(num_layers)},
+        },
+        'lm_head': {'kernel': PartitionSpec(None, "mp")},
+    }
+    replicate = lambda x: jax.tree_util.tree_map(lambda _: PartitionSpec(None), x)
+    opt_state = tree_map_params(state.tx, lambda _, spec: spec, state.opt_state,
+                                param_partition, transform_non_params=lambda _: PartitionSpec(None))
+    manual_partition = TrainState(step=PartitionSpec(None),
+                                  params=param_partition,
+                                  master_copy=param_partition,
+                                  dynamic_scale=replicate(state.dynamic_scale),
+                                  tx=state.tx,
+                                  apply_fn=state.apply_fn,
+                                  opt_state=opt_state)
+    return manual_partition
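
For context on the specs above: this follows the usual Megatron-style tensor-parallel layout, splitting each matmul across the "mp" mesh axis, and tree_map_params then mirrors the parameter specs onto the optimizer state so the Adam moments are sharded like their parameters. A minimal illustration of the convention (the sizes are hypothetical, not read from the config):

from jax.experimental.pjit import PartitionSpec

hidden, ffn = 4096, 11008        # example LLaMA-7B-like shapes
# Column-parallel (wq/wk/wv, w1, w3): split the output dim over "mp",
# so a (hidden, ffn) kernel puts a (hidden, ffn / mp_size) slice on each device.
col = PartitionSpec(None, "mp")
# Row-parallel (wo, w2): split the input dim over "mp"; the partial
# outputs are combined with an all-reduce.
row = PartitionSpec("mp", None)
# 1-D norm scales stay replicated.
rep = PartitionSpec(None)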


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@@ -514,22 +553,6 @@ def main():
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
-    # if model_args.config_name:
-    #     config = AutoConfig.from_pretrained(
-    #         model_args.config_name,
-    #         cache_dir=model_args.cache_dir,
-    #         use_auth_token=True if model_args.use_auth_token else None,
-    #     )
-    # elif model_args.model_name_or_path:
-    #     config = AutoConfig.from_pretrained(
-    #         model_args.model_name_or_path,
-    #         cache_dir=model_args.cache_dir,
-    #         use_auth_token=True if model_args.use_auth_token else None,
-    #     )
-    # else:
-    #     config = CONFIG_MAPPING[model_args.model_type]()
-    #     logger.warning("You are instantiating a new config instance from scratch.")
-    # TODO: merge with the above
config = LLaMAConfig.load_config('test')

if training_args.use_remat:
@@ -556,28 +579,9 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

-    # if model_args.model_name_or_path:
-    #     model = FlaxAutoModelForCausalLM.from_pretrained(
-    #         model_args.model_name_or_path,
-    #         config=config,
-    #         seed=training_args.seed,
-    #         dtype=getattr(jnp, model_args.dtype),
-    #         use_auth_token=True if model_args.use_auth_token else None,
-    #     )
-    #     #from transformers import FlaxOPTForCausalLM
-    #     #config.num_hidden_layers = 2
-    #     #model = FlaxOPTForCausalLM(
-    #     #    config=config,
-    #     #    seed=training_args.seed,
-    #     #    dtype=getattr(jnp, model_args.dtype),
-    #     #)
-    # else:
-    #     model = FlaxAutoModelForCausalLM.from_config(
-    #         config,
-    #         seed=training_args.seed,
-    #         dtype=getattr(jnp, model_args.dtype),
-    #     )
-    model = FlaxLLaMAForCausalLM(config, (4, 2048))
+    # TODO(yonghao): skip weight init when the weights are loaded from elsewhere
+    dummy_input_shape = (4, config.max_sequence_length)
+    model = FlaxLLaMAForCausalLM(config, dummy_input_shape)
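
On the weight-init TODO above: one possible direction is to build the parameters abstractly and fill them from a checkpoint later. A sketch, assuming the model keeps the standard Flax init_weights(rng, input_shape) interface (this is not what the commit implements):

import jax

# Shapes and dtypes only; no device memory is allocated for the weights.
abstract_params = jax.eval_shape(
    lambda rng: model.init_weights(rng, dummy_input_shape),
    jax.random.PRNGKey(0))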

# Preprocessing the datasets.
# First we tokenize all the texts.
@@ -747,6 +751,11 @@ def decay_mask_fn(params):
learning_rate=linear_decay_lr_schedule_fn,
)
else:
+        # A temporary hack for llama finetune. Remove it once either:
+        # 1) we rebase to jax 0.4 and use tree_util's mask-with-path for the partition spec; or
+        # 2) optax fixes the issue of symbolic execution with the decay mask fn.
+        if training_args.weight_decay == 0.0:
+            decay_mask_fn = None
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(
@@ -769,6 +778,11 @@ def decay_mask_fn(params):
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,
dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)

+    # Manual partition spec
+    state_manual_sharding = llama_manual_sharding(config.num_hidden_layers, state)
+    ms_option = ManualShardingOption(
+        ("dp", "mp"), in_axis_resources=(state_manual_sharding, PartitionSpec("dp", None)))

def loss_fn(logits, labels):
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
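
The shift pairs each position's logits with the following token as its label; a toy check of the alignment (dummy batch and vocab sizes):

import jax.numpy as jnp

labels = jnp.array([[5, 7, 9, 2]])   # (batch=1, seq=4) token ids
logits = jnp.zeros((1, 4, 32000))    # (batch, seq, vocab), dummy values
shift_logits = logits[..., :-1, :]   # predictions at positions 0..2
shift_labels = labels[..., 1:]       # targets: 7, 9, 2
assert shift_logits.shape[1] == shift_labels.shape[1] == 3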
@@ -828,7 +842,8 @@ def eval_step(params, batch):
num_micro_batches=training_args.num_micro_batches,
data_parallel=-1,
operator_parallel=training_args.operator_parallel,
-        pipeline_parallel=training_args.pipeline_parallel)
+        pipeline_parallel=training_args.pipeline_parallel,
+        manual_sharding_option=ms_option)
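
For orientation: the manual specs target the logical (dp, mp) mesh, and with data_parallel=-1 the "dp" size is presumably whatever remains after operator and pipeline parallelism are fixed. Back-of-envelope arithmetic for the run_llama.sh settings below, assuming a 4-GPU cluster (the device count is not stated in the diff):

num_devices = 4                  # assumption: one 4-GPU node
operator_parallel = 2            # size of the "mp" mesh axis
pipeline_parallel = 2            # number of pipeline stages
data_parallel = num_devices // (operator_parallel * pipeline_parallel)  # "dp" axis -> 1
assert data_parallel * operator_parallel * pipeline_parallel == num_devices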

p_train_step = alpa.parallelize(train_step,
method=method,
4 changes: 2 additions & 2 deletions examples/opt_finetune/run_llama.sh
@@ -9,8 +9,8 @@ python3 run_easylm_flax.py \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--num_micro_batches 64 \
-    --operator_parallel 1 \
-    --pipeline_parallel 4 \
+    --operator_parallel 2 \
+    --pipeline_parallel 2 \
--dtype="float16" \
--learning_rate="5e-4" --warmup_steps="2000" \
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.0" \