From 2f6aeaead124d6fba094cc4a4b90e1da1fd7f0da Mon Sep 17 00:00:00 2001 From: Vincent-syr <583636762@qq.com> Date: Wed, 24 Jan 2024 14:23:00 +0800 Subject: [PATCH 1/2] [fix] fix requirements.txt --- model_zoo/mixtral/huggingface/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/model_zoo/mixtral/huggingface/requirements.txt b/model_zoo/mixtral/huggingface/requirements.txt index 1a0b606..2e839fd 100644 --- a/model_zoo/mixtral/huggingface/requirements.txt +++ b/model_zoo/mixtral/huggingface/requirements.txt @@ -5,3 +5,4 @@ sentencepiece==0.1.99 torch==2.0.1 torchaudio==2.0.2 torchvision==0.15.2 +safetensors==0.3.2 \ No newline at end of file From fb707a0745e9ac4db09fe43002e441b1e0726eff Mon Sep 17 00:00:00 2001 From: Vincent-syr <583636762@qq.com> Date: Mon, 1 Apr 2024 22:34:53 +0800 Subject: [PATCH 2/2] [feature] add llama pipeline parallel --- model_zoo/ModelParallel.py | 72 +++ model_zoo/llama_pp/README.md | 86 +++ model_zoo/llama_pp/huggingface/Demo.py | 174 +++++++ model_zoo/llama_pp/huggingface/Export.py | 43 ++ model_zoo/llama_pp/modeling/Loader.py | 117 +++++ .../modeling/dynamic_batching/Model_pp.py | 293 +++++++++++ .../modeling/dynamic_batching/Pipeline_pp.py | 491 ++++++++++++++++++ 7 files changed, 1276 insertions(+) create mode 100644 model_zoo/llama_pp/README.md create mode 100644 model_zoo/llama_pp/huggingface/Demo.py create mode 100644 model_zoo/llama_pp/huggingface/Export.py create mode 100644 model_zoo/llama_pp/modeling/Loader.py create mode 100644 model_zoo/llama_pp/modeling/dynamic_batching/Model_pp.py create mode 100644 model_zoo/llama_pp/modeling/dynamic_batching/Pipeline_pp.py diff --git a/model_zoo/ModelParallel.py b/model_zoo/ModelParallel.py index f37f999..83c6a58 100644 --- a/model_zoo/ModelParallel.py +++ b/model_zoo/ModelParallel.py @@ -24,6 +24,78 @@ def setup(use_cpu: bool = True) -> Tuple[int, int]: torch.manual_seed(1) return local_rank, world_size +class DistMapping(object): + ''' + A node with 8 GPUs, tp_size = 4, pp_size = 2 + + 2 tp groups: + + - [0, 1, 2, 3] + - [4, 5, 6, 7] + + 4 pp groups: + + - [0, 4] + - [1, 5] + - [2, 6] + - [3, 7] + ''' + def __init__(self, + world_size=1, + rank=0, + tp_size=1, + pp_size=1): + self.tp_size = tp_size + self.pp_size = pp_size + self.world_size = world_size + self.rank = rank + + if pp_size * tp_size != world_size: + raise ValueError("world_size must equal to pp_size * tp_size") + self.pp_groups = [] + self.tp_groups = [] + + # init pp group + for i in range(tp_size): + ranks = range(i, world_size, tp_size) + self.pp_groups.append(list(ranks)) + + # init tp group + for i in range(pp_size): + ranks = range(i * tp_size, (i + 1) * tp_size) + self.tp_groups.append(list(ranks)) + + self.pp_rank = self.rank // self.tp_size + self.tp_rank = self.rank % self.tp_size + + self.tp_group = self.tp_groups[self.pp_rank] + self.pp_group = self.pp_groups[self.tp_rank] + + self.tp_proc_group = None + self.pp_proc_group = None + + print(f"rank: {self.rank}, tp_groups: {self.tp_groups}, pp_groups: {self.pp_groups}") + + def is_last_pp_rank(self): + return self.pp_rank == self.pp_size - 1 + + def is_first_pp_rank(self): + return self.pp_rank == 0 + + def has_pp(self): + return self.pp_size > 1 + + def prev_pp_rank(self): + p = self.rank - self.tp_size + if p < 0: + p = p + self.world_size + return p + + def next_pp_rank(self): + p = self.rank + self.tp_size + if p >= self.world_size: + p = p - self.world_size + return p class ParallelEmbedding(torch.nn.Module): def __init__( diff --git a/model_zoo/llama_pp/README.md 
b/model_zoo/llama_pp/README.md
new file mode 100644
index 0000000..c22eb19
--- /dev/null
+++ b/model_zoo/llama_pp/README.md
@@ -0,0 +1,86 @@
+# Usage example
+
+## Download model from Hugging Face
+
+Download the model file from the [Hugging Face](https://huggingface.co/models) model hub.
+
+## Convert model params
+
+Because the RotaryPositionEmbedding implementation here is inconsistent with Hugging Face's, we need to convert the weight parameters.
+
+```
+cd ppl.pmx/model_zoo/llama/huggingface
+python ConvertWeightToPmx.py --input_dir --output_dir
+```
+
+You can find the pmx model file in `` after the conversion.
+
+## Splitting model
+The model is split at runtime with tensor parallelism and pipeline parallelism, so no extra model-splitting script is needed.
+
+## Testing Model
+
+The `Demo.py` script provides functionality to test the model for correctness before exporting.
+
+```bash
+OMP_NUM_THREADS=1 python Demo.py --nproc_per_node $pp_size \
+--ckpt_dir \
+--tokenizer_path /tokenizer.model \
+--temperature 0 \
+--top_p 0.95 \
+--batch 4 \
+--seqlen_scale_up 1 \
+--unaligned_batch 0 \
+--max_gen_len 16 \
+--friendly_gqa 0 \
+--fused_qkv 1 \
+--fused_kvcache 0 \
+--fused_ffn_glu 1 \
+--auto_causal 1 \
+--quantized_cache 1 \
+--cache_layout 3 \
+--cache_mode 0 \
+--dynamic_batching 1 \
+--pp_size $pp_size
+```
+
+- `OMP_NUM_THREADS`: Number of OpenMP threads. Each PyTorch process opens its own OpenMP thread pool, so this is set to 1 to avoid occupying too many CPU cores.
+- `--nproc_per_node`: Specifies the number of model slices per node.
+
+## Exporting Model
+
+To export a model, use the provided `Export.py` script. Here is an example command for exporting a model split across `$pp_size` GPUs:
+
+```bash
+OMP_NUM_THREADS=1 torchrun --nproc_per_node $pp_size Export.py \
+--ckpt_dir \
+--export_path \
+--friendly_gqa 1 \
+--fused_qkv 1 \
+--fused_kvcache 1 \
+--fused_ffn_glu 1 \
+--auto_causal 1 \
+--quantized_cache 1 \
+--cache_layout 3 \
+--cache_mode 0 \
+--dynamic_batching 1 \
+--pp_size $pp_size
+```
+
+Make sure to replace `$pp_size` with the actual number of GPUs you want to use.
+
+## Generating Test Data
+
+The following command demonstrates how to generate test data for steps 0, 1, and 255:
+
+```bash
+OMP_NUM_THREADS=1 python Demo.py --nproc_per_node $pp_size --pp_size $pp_size --ckpt_dir --tokenizer_path /tokenizer.model --fused_qkv 1 --fused_kvcache 1 --auto_causal 1 --quantized_cache 1 --dynamic_batching 1 --seqlen_scale_up 1 --max_gen_len 256 --dump_steps 0,1,255 --dump_tensor_path --batch 1
+```
+
+- `seqlen_scale_up`: Scale factor for the input sequence length (the base prompt is repeated this many times).
+- `max_gen_len`: Maximum number of tokens to generate.
+- `dump_steps`: Steps at which to dump the test data.
+- `dump_tensor_path`: Path to store the dumped test data.
+- `batch`: Batch size for data processing.
+
+Make sure to replace ``, `` and `` with the actual directory paths in your environment.
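+
+## Parallel group mapping
+
+At runtime, model slicing is driven by the `DistMapping` class added in `ModelParallel.py`. The standalone sketch below illustrates how ranks are assigned to tensor-parallel and pipeline-parallel groups; the values are illustrative only, and note that `Loader.py` currently asserts `tp_size == 1`, so in practice `world_size == pp_size`.
+
+```python
+# Illustrative sketch of the rank-to-group mapping used by DistMapping.
+world_size, tp_size, pp_size = 8, 4, 2    # DistMapping requires world_size == tp_size * pp_size
+
+pp_groups = [list(range(i, world_size, tp_size)) for i in range(tp_size)]
+tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(pp_size)]
+print(pp_groups)   # [[0, 4], [1, 5], [2, 6], [3, 7]]
+print(tp_groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
+
+rank = 5
+pp_rank, tp_rank = rank // tp_size, rank % tp_size   # pipeline stage 1, tensor-parallel slot 1
+prev_stage = (rank - tp_size) % world_size           # rank of the previous pipeline stage: 1
+next_stage = (rank + tp_size) % world_size           # rank of the next pipeline stage: 1 (wraps, since pp_size == 2)
+```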
diff --git a/model_zoo/llama_pp/huggingface/Demo.py b/model_zoo/llama_pp/huggingface/Demo.py new file mode 100644 index 0000000..7cec795 --- /dev/null +++ b/model_zoo/llama_pp/huggingface/Demo.py @@ -0,0 +1,174 @@ +import sys +import os +import json + +from pathlib import Path +from typing import List + +sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../..") + +import llama_pp.modeling.Loader as Loader +from llama.huggingface.Tokenizer import Tokenizer +from ModelParams import ModelParams + +import torch.multiprocessing as mp +import argparse + +import signal +import traceback +import os +import time + +def ParseCommandLineArgs(): + parser = argparse.ArgumentParser() + # basic command + parser.add_argument( + "--nnodes", + type=int, + default=1, + help="Number of nodes, or the range of nodes in form :.", + ) + parser.add_argument( + "--nproc_per_node", + type=int, + default=1, + help="Number of workers per node;" + ) + + parser.add_argument( + "--master_port", + type=str, + default="29500" + ) + + parser.add_argument( + "--local_addr", + type=str, + default="localhost" + ) + + # llm param + parser.add_argument("--ckpt_dir", type=str) + parser.add_argument("--tokenizer_path", type=str) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top_p", type=float, default=0.95) + parser.add_argument("--batch", type=int, default=4) + + parser.add_argument("--seqlen_scale_up", type=int, default=1) + parser.add_argument("--unaligned_batch", type=int, default=False) + parser.add_argument("--max_gen_len", type=int, default=256) + parser.add_argument("--friendly_gqa", type=int, default=False) + parser.add_argument("--fused_qkv", type=int, default=False) + + parser.add_argument("--fused_kvcache", type=int, default=True) + parser.add_argument("--fused_ffn_glu", type=int, default=True) + parser.add_argument("--auto_causal", type=int, default=True) + parser.add_argument("--quantized_cache", type=int, default=True) + parser.add_argument("--cache_layout", type=int, default=0) + + parser.add_argument("--cache_mode", type=int, default=0) + parser.add_argument("--dynamic_batching", type=int, default=True) + parser.add_argument("--pp_size", type=int, default=1) + # parser.add_argument("--tp_size", type=int, default=1) + parser.add_argument("--dump_tensor_path", type=str, default=None) + parser.add_argument("--dump_steps", type=str, default=None) + + args = parser.parse_args() + if args.dump_steps: + args.dump_steps = [int(s) for s in args.dump_steps.split(",")] + + args.unaligned_batch = bool(args.unaligned_batch) + args.friendly_gqa = bool(args.friendly_gqa) + args.fused_qkv = bool(args.fused_qkv) + args.fused_kvcache = bool(args.fused_kvcache) + args.fused_ffn_glu = bool(args.fused_ffn_glu) + args.auto_causal = bool(args.auto_causal) + args.quantized_cache = bool(args.quantized_cache) + args.dynamic_batching = bool(args.dynamic_batching) + + args.world_size = args.nproc_per_node * args.nnodes + return args + + +def set_dist_env_var(rank: int, world_size: int, local_addr: str, master_port: str): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = local_addr + os.environ["MASTER_PORT"] = master_port + +def main(rank: int, args: argparse.Namespace, queue: mp.Queue, global_start=None): + set_dist_env_var(rank, args.world_size, args.local_addr, args.master_port) + + tokenizer = Tokenizer(model_path=args.tokenizer_path) + + with open(Path(args.ckpt_dir) / 
"pmx_params.json", "r") as f: + params = json.loads(f.read()) + params: ModelParams = ModelParams(**params) + + generator = Loader.load( + args.ckpt_dir, params, args.friendly_gqa, + args.fused_qkv, args.fused_kvcache, args.fused_ffn_glu, + args.auto_causal, args.quantized_cache, args.cache_layout, + args.cache_mode, args.dynamic_batching, + False, False, False, False, + 0, pp_size=args.pp_size, + dump_tensor_path=args.dump_tensor_path, dump_steps=args.dump_steps + ) + + if args.unaligned_batch: + test_prompt = [ # For these prompts, the expected answer is the natural continuation of the prompt + "I believe the meaning of life is", + "Simply put, the theory of relativity states that ", + """A brief message congratulating the team on the launch: + + Hi everyone, + + I just """, + # Few shot prompt (providing a few examples before asking model to complete more); + """Translate English to French: + + sea otter => loutre de mer + peppermint => menthe poivrée + plush girafe => girafe peluche + cheese =>""", + ] + test_prompt = [tokenizer.encode(t, bos=True, eos=False) for t in test_prompt] + + prompt_tokens = test_prompt.copy() + for _ in range((args.batch - 1) // len(test_prompt)): + prompt_tokens.extend(test_prompt) + else: + test_prompt = "I believe the meaning of life is" + test_prompt = tokenizer.encode(test_prompt, bos=True, eos=False) + + _scale_up_prompt = [] + for _ in range(args.seqlen_scale_up): + _scale_up_prompt.extend(test_prompt) + test_prompt = _scale_up_prompt + + prompt_tokens = [test_prompt for _ in range(args.batch)] + + print(f"prepared {len(prompt_tokens)} prompts") + results = generator.generate( + prompt_tokens[:args.batch], tokenizer.get_eos_id(), tokenizer.get_pad_id(), + max_gen_len=args.max_gen_len, temperature=args.temperature, top_p=args.top_p, top_k=0, + queue=queue, global_start=global_start + ) + if generator.model.dist_mapping.is_last_pp_rank(): + for result in results: + print(result) + print(tokenizer.decode(result)) + print("\n==================================\n") + + +if __name__ == "__main__": + args = ParseCommandLineArgs() + print(args) + mp.set_start_method('spawn') + queue = mp.Queue() + global_start = time.time() + # tid_dict = mp.Manager().dict() + # lock = mp.Lock() + + mp.spawn(main, nprocs=args.world_size, args=(args, queue, global_start), join=True) \ No newline at end of file diff --git a/model_zoo/llama_pp/huggingface/Export.py b/model_zoo/llama_pp/huggingface/Export.py new file mode 100644 index 0000000..ae3a732 --- /dev/null +++ b/model_zoo/llama_pp/huggingface/Export.py @@ -0,0 +1,43 @@ +import fire +import sys +import os +import json + +from pathlib import Path + +sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../..") + +import llama_pp.modeling.Loader as Loader +from ModelParams import ModelParams + +def main( + ckpt_dir: str, + export_path: str, + friendly_gqa: bool = False, # done gqa by repeating key and value by key_value_cache op + fused_qkv: bool = True, # fuse qkv linear + fused_kvcache: bool = True, # fuse key_value_cache and multi_head_attention + fused_ffn_glu: bool = True, # fuse feed forward gate linear unit + auto_causal: bool = True, # causal mask is auto done by attention op, no need to pass additional mask to the model + quantized_cache: bool = True, # 8bit kv cache quantization + cache_layout: int = 3, # change kv cache layout for hardware performance friendly + cache_mode: int = 0, # change kv cache indexing mode for memory management friendly, only affected when dynamic_batching == True + 
dynamic_batching: bool = True, # use dynamic batching scheduling + pp_size: int = 1, +): + with open(Path(ckpt_dir) / "pmx_params.json", "r") as f: + params = json.loads(f.read()) + params: ModelParams = ModelParams(**params) + + generator = Loader.load( + ckpt_dir, params, friendly_gqa, + fused_qkv, fused_kvcache, fused_ffn_glu, + auto_causal, quantized_cache, cache_layout, + cache_mode, dynamic_batching, + False, False, False, True, + 0, pp_size=pp_size + ) + + generator.export(export_path) + +if __name__ == "__main__": + fire.Fire(main) diff --git a/model_zoo/llama_pp/modeling/Loader.py b/model_zoo/llama_pp/modeling/Loader.py new file mode 100644 index 0000000..965e390 --- /dev/null +++ b/model_zoo/llama_pp/modeling/Loader.py @@ -0,0 +1,117 @@ +import os +import sys +import torch +import time + +from pathlib import Path +from typing import List + +import torch.distributed as dist + +sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../..") + +from ModelParams import ModelParams +from ModelUtils import __TextGenerator__ +import ModelParallel + + +def load( + ckpt_dir: str, + model_params: ModelParams, + friendly_gqa: bool, # done gqa by repeating key and value by key_value_cache op + fused_qkv: bool, # fuse qkv linear + fused_kvcache: bool, # fuse key_value_cache and multi_head_attention + fused_ffn_glu: bool, # fuse feed forward gate linear unit + auto_causal: bool, # causal mask is auto done by attention op, no need to pass additional mask to the model + quantized_cache: bool, # 8bit kv cache quantization + cache_layout: int, # change kv cache layout for hardware performance friendly + cache_mode: int, # change kv cache indexing mode for memory management friendly, only affected when dynamic_batching == True + dynamic_batching: bool, # use dynamic batching scheduling + attn_wqkv_bias_term: bool, + attn_wo_bias_term: bool, + ffn_linear_bias_term: bool, + load_to_cpu: bool, + rotary_dim: int = 0, + pp_size: int = 1, + dump_tensor_path: str = None, + dump_steps: List[int] = [] +) -> __TextGenerator__: + start_time = time.time() + + if dynamic_batching: + from llama_pp.modeling.dynamic_batching.Model_pp import TensorDumper, Transformer + from llama_pp.modeling.dynamic_batching.Pipeline_pp import LLaMA + if cache_layout != 3: + print("Info: we suggest using cache_layout 3 for cuda inference performance") + else: + raise ValueError("we only support dynamic_batching == True") + + assert model_params.num_layers % pp_size == 0, \ + f"num_layers {model_params.num_layers} must be a multiple of pipeline parallelism size {pp_size}" + + local_rank, world_size = ModelParallel.setup(load_to_cpu) + + tp_size = world_size // pp_size + + assert tp_size == 1, \ + f"tensor parallelism size [{tp_size}] must equal to 1" + + dist_mapping = ModelParallel.DistMapping(world_size=world_size, rank=local_rank, tp_size=tp_size, pp_size=pp_size) + + dist_mapping.tp_proc_group = dist.new_group(ranks=dist_mapping.tp_group) + dist_mapping.pp_proc_group = dist.new_group(ranks=dist_mapping.pp_group) + + if dist_mapping.tp_rank > 0: + sys.stdout = open(os.devnull, "w") + + model_params.dynamic_batching = bool(dynamic_batching) + model_params.auto_causal = bool(auto_causal) + model_params.cache_layout = cache_layout + model_params.cache_mode = cache_mode + if quantized_cache: + model_params.cache_quant_bit = 8 + model_params.cache_quant_group = 8 + else: + model_params.cache_quant_bit = 0 + model_params.cache_quant_group = 0 + + if load_to_cpu: + torch.set_default_tensor_type(torch.HalfTensor) + else: + 
torch.set_default_tensor_type(torch.cuda.HalfTensor) + + model = Transformer(model_params, + friendly_gqa, + fused_qkv, + fused_kvcache, + fused_ffn_glu, + attn_wqkv_bias_term, + attn_wo_bias_term, + ffn_linear_bias_term, + rotary_dim=rotary_dim, + dist_mapping=dist_mapping) + + torch.set_default_tensor_type(torch.FloatTensor) + + + print("Loading") + + ckpt_path = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(ckpt_path) == 1, f"Expect only one checkpoint file in {ckpt_dir}" + checkpoint = torch.load(ckpt_path[0], map_location="cpu") + model.load_state_dict(checkpoint) + del checkpoint + # exit(0) + + generator = LLaMA(model) + + if dump_tensor_path is not None: + dump_path = os.path.join(dump_tensor_path, "rank_{}".format(local_rank)) + if not os.path.exists(dump_path): + os.makedirs(dump_path) + TensorDumper.dir = dump_path + TensorDumper.enable_dump = True + TensorDumper.dump_steps = dump_steps + + print(f"Loaded in {time.time() - start_time:.2f} seconds") + return generator diff --git a/model_zoo/llama_pp/modeling/dynamic_batching/Model_pp.py b/model_zoo/llama_pp/modeling/dynamic_batching/Model_pp.py new file mode 100644 index 0000000..c391f41 --- /dev/null +++ b/model_zoo/llama_pp/modeling/dynamic_batching/Model_pp.py @@ -0,0 +1,293 @@ +import sys +import os + +import torch +from torch import nn +import torch.distributed as dist + +from typing import Mapping, Any, Optional + +sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../..") + +import torch_function as PMX +from ModelParams import ModelParams +import ModelUtils +from ModelParallel import ColumnParallelLinear, RowParallelLinear, ParallelEmbedding, DistMapping +from llama.modeling.dynamic_batching.Model import RMSNorm, Attention, FeedForward + +TensorDumper = ModelUtils.__TensorDumper__() + +class TransformerBlockPP(nn.Module): + def __init__(self, layer_id: int, + args: ModelParams, + friendly_gqa: bool, + fused_qkv: bool, + fused_kvcache: bool, + fused_ffn_glu: bool, + attn_wqkv_bias_term: bool, + attn_wo_bias_term: bool, + ffn_linear_bias_term: bool, + rotary_dim: int, + tp_proc_group: dist.ProcessGroup): + super().__init__() + self.attention = Attention(args, + layer_id, + friendly_gqa, + fused_qkv, + fused_kvcache, + attn_wqkv_bias_term, + attn_wo_bias_term, + rotary_dim=rotary_dim, + proc_group=tp_proc_group) + self.feed_forward = FeedForward(args, + layer_id, + fused_ffn_glu, + ffn_linear_bias_term, + proc_group=tp_proc_group) + + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.hidden_dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.hidden_dim, eps=args.norm_eps) + + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor], + seqstarts: torch.Tensor, kvstarts: torch.Tensor, cachestarts: torch.Tensor, + decoding_batches: torch.Tensor, start_pos: torch.Tensor, + max_seqlen: torch.Tensor, max_kvlen: torch.Tensor, + kv_cache: torch.Tensor, kv_sacle: torch.Tensor = None): + residual = x + norm = self.attention_norm(x) + TensorDumper.dump(norm, "layer{}_attention_norm_out".format(self.layer_id)) + attn = self.attention.forward(norm, attn_mask, seqstarts, kvstarts, + cachestarts, decoding_batches, + start_pos, max_seqlen, max_kvlen, + kv_cache, kv_sacle) + hidden_states = residual + attn + residual = hidden_states + + norm = self.ffn_norm(hidden_states) + TensorDumper.dump(norm, "layer{}_ffn_norm_out".format(self.layer_id)) + ffn = self.feed_forward.forward(norm) + + hidden_states = residual + ffn + TensorDumper.dump(hidden_states, 
"layer{}_block_output".format(self.layer_id)) + return hidden_states + +class Transformer(nn.Module): + def __init__(self, params: ModelParams, + friendly_gqa: bool, + fused_qkv: bool, + fused_kvcache: bool, + fused_ffn_glu: bool, + attn_wqkv_bias_term: bool, + attn_wo_bias_term: bool, + ffn_linear_bias_term: bool, + rotary_dim: int, + dist_mapping: DistMapping): + super().__init__() + self.params = params + self.dist_mapping = dist_mapping + self.vocab_size = params.vocab_size + + self.num_layers = params.num_layers if not self.dist_mapping.has_pp() else (params.num_layers) // self.dist_mapping.pp_size + self.layer_range = list(range(dist_mapping.pp_rank * self.num_layers, (dist_mapping.pp_rank + 1) * self.num_layers, 1)) + + self.tp_proc_group = dist_mapping.tp_proc_group + self.fused_qkv = fused_qkv + self.fused_kvcache = fused_kvcache + self.fused_ffn_glu = fused_ffn_glu + + tp_size = 1 if self.tp_proc_group is None else self.tp_proc_group.size() + num_kv_heads = params.num_heads if params.num_kv_heads is None else params.num_kv_heads + self.hidden_dim = params.hidden_dim + self.num_local_heads = params.num_heads // tp_size + self.num_local_kv_heads = num_kv_heads // tp_size + self.head_dim = params.hidden_dim // params.num_heads + self.local_q_dim = self.num_local_heads * self.head_dim + self.local_kv_dim = self.num_local_kv_heads * self.head_dim + self.local_imm_dim = params.intermediate_dim // tp_size + + if self.dist_mapping.is_first_pp_rank(): + self.tok_embeddings = ParallelEmbedding(self.tp_proc_group, params.vocab_size, params.hidden_dim) + self.layers = torch.nn.ModuleList() + + # for layer_id in self.layer_range: + for layer_id in range(self.num_layers): + self.layers.append(TransformerBlockPP( + layer_id, params, + friendly_gqa, + fused_qkv, + fused_kvcache, + fused_ffn_glu, + attn_wqkv_bias_term, + attn_wo_bias_term, + ffn_linear_bias_term, + rotary_dim, + self.tp_proc_group)) + + if self.dist_mapping.is_last_pp_rank(): + self.norm = RMSNorm(params.hidden_dim, eps=params.norm_eps) + self.output = ColumnParallelLinear(self.tp_proc_group, params.hidden_dim, params.vocab_size, bias_term=False) + + @torch.inference_mode() + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor], + seqstarts: torch.Tensor, kvstarts: torch.Tensor, + cachestarts: torch.Tensor, decoding_batches: torch.Tensor, + start_pos: torch.Tensor, max_seqlen: torch.Tensor, max_kvlen: torch.Tensor, + kv_cache: torch.Tensor, kv_scale: torch.Tensor = None): + # if is first rank, x is tokens, else hidden states + # print(f"rank: {self.dist_mapping.rank} forward") + TensorDumper.dump(x, "model_input") + if self.dist_mapping.is_first_pp_rank(): + h = self.tok_embeddings(x) + else: + h = x + _kv_scale = kv_scale + if attn_mask is not None: + TensorDumper.dump(attn_mask, "attn_mask") + if self.fused_kvcache and attn_mask is not None: + if kv_scale is None: # mount an empty scale for friendly exporting + _kv_scale = torch.empty(0, dtype=h.dtype) + TensorDumper.dump(seqstarts, "seqstarts") + TensorDumper.dump(kvstarts, "kvstarts") + TensorDumper.dump(cachestarts, "cachestarts") + TensorDumper.dump(decoding_batches, "decoding_batches") + TensorDumper.dump(start_pos, "start_pos") + TensorDumper.dump(max_seqlen, "max_seqlen") + TensorDumper.dump(max_kvlen, "max_kvlen") + TensorDumper.dump(kv_cache, "kv_cache") + if kv_scale is not None: + TensorDumper.dump(kv_scale, "kv_scale") + for i, layer in enumerate(self.layers): + # for layer in self.layers: + h = layer(h, attn_mask, seqstarts, kvstarts, cachestarts, + 
decoding_batches, start_pos, max_seqlen, max_kvlen, + kv_cache, _kv_scale) + # print(f"rank: {self.dist_mapping.rank} after layer forward") + + if self.dist_mapping.is_last_pp_rank(): + h = self.norm(h) + TensorDumper.dump(h, "last_rms_norm") + gathered_h = torch.index_select(h, 0, seqstarts[1:] - 1) + TensorDumper.dump(gathered_h, "gathered_h") + output = self.output(gathered_h) # only compute last logits + TensorDumper.dump(output, "logits_before_cast") + output = output.float() + TensorDumper.dump(output, "logits") + # print(f"rank: {self.dist_mapping.rank} after output") + else: + output = h + return output + + @torch.no_grad() + def load_state_dict(self, state_dict: Mapping[str, Any]): + loaded_params = set() + model_params = {key: value for key, value in self.named_parameters()} + + pp_state_dict = filter_pp_state_dict(state_dict, self.layer_range, self.dist_mapping) + + tp_rank, tp_size = self.dist_mapping.tp_rank, self.dist_mapping.tp_size + + for key, value in pp_state_dict.items(): + try: + if 'attention.wq.weight' in key: + value = value.reshape( + self.num_local_heads * tp_size, self.head_dim, self.hidden_dim).split( + [self.num_local_heads] * tp_size, dim=0)[tp_rank].reshape(-1, self.hidden_dim) + if 'attention.wk.weight' in key or 'attention.wv.weight' in key: + value = value.reshape( + self.num_local_kv_heads * tp_size, self.head_dim, self.hidden_dim).split( + [self.num_local_kv_heads] * tp_size, dim=0)[tp_rank].reshape(-1, self.hidden_dim) + if 'attention.wo.weight' in key: + value = value.split([self.hidden_dim // tp_size] * tp_size, dim=1)[tp_rank] + if 'feed_forward.w1.weight' in key or 'feed_forward.w3.weight' in key: + value = value.split([self.local_imm_dim] * tp_size, dim=0)[tp_rank] + if 'feed_forward.w2.weight' in key: + value = value.split([self.local_imm_dim]*tp_size, dim=1)[tp_rank] + if 'tok_embeddings.weight' in key: + value = value.split([self.hidden_dim // tp_size] * tp_size, dim=1)[tp_rank] + if 'output.weight' in key: + value = value.split([self.vocab_size // tp_size] * tp_size, dim=0)[tp_rank] + # split ColParaelleLinear bias + if 'attention.wq.bias' in key or 'attention.wk.bias' in key or 'attention.wv.bias' in key or \ + 'attention.w1.bias' in key or 'attention.w3.bias' in key or \ + 'tok_embeddings.bias' in key or 'output.bias' in key: + bias_dim = value.shape[0] + value = value.split([bias_dim // tp_size] * tp_size)[tp_rank] + + module_name, param_name = key.rsplit(".", 1) + if key in model_params: + self.get_submodule(module_name)._parameters[param_name][:] = value + loaded_params.add(key) + print(f'Loaded: {key} -> {key}[{value.shape}]') + + if self.fused_qkv: + if 'attention.wq' in key: + loaded_params.add(key) + module_name = module_name.replace('wq', 'wqkv') + self.get_submodule(module_name)._parameters[param_name][ + :self.local_q_dim] = value + replaced_key = module_name + '.' + param_name + print(f'Loaded: {key} -> {replaced_key}[{value.shape}]') + elif 'attention.wk' in key: + loaded_params.add(key) + module_name = module_name.replace('wk', 'wqkv') + self.get_submodule(module_name)._parameters[param_name][ + self.local_q_dim:self.local_q_dim + self.local_kv_dim] = value + replaced_key = module_name + '.' 
+ param_name + print(f'Loaded: {key} -> {replaced_key}[{value.shape}]') + elif 'attention.wv' in key: + loaded_params.add(key) + module_name = module_name.replace('wv', 'wqkv') + self.get_submodule(module_name)._parameters[param_name][ + self.local_q_dim + self.local_kv_dim: + self.local_q_dim + self.local_kv_dim * 2] = value + replaced_key = module_name + '.' + param_name + print(f'Loaded: {key} -> {replaced_key}[{value.shape}]') + if self.fused_ffn_glu: + if 'feed_forward.w1' in key: + loaded_params.add(key) + module_name = module_name.replace('w1', 'wu') + self.get_submodule(module_name)._parameters[param_name][ + :self.local_imm_dim] = value + replaced_key = module_name + '.' + param_name + print(f'Loaded: {key} -> {replaced_key}[{value.shape}]') + if 'feed_forward.w3' in key: + loaded_params.add(key) + module_name = module_name.replace('w3', 'wu') + self.get_submodule(module_name)._parameters[param_name][ + self.local_imm_dim:] = value + replaced_key = module_name + '.' + param_name + print(f'Loaded: {key} -> {replaced_key}[{value.shape}]') + except AttributeError as e: + raise Exception(f'Failed to inject model weight {key}, can not find corresponding layer.') + + for key in pp_state_dict: + if key not in loaded_params: + print(f'{key} is not loaded.') + + +def filter_pp_state_dict(state_dict: Mapping[str, Any], layer_range: list, dist_mapping: DistMapping): + if dist_mapping.pp_size == 1: + return state_dict + + pp_state_dict = {} + + for key, value in state_dict.items(): + if "tok_embeddings" in key: + if dist_mapping.is_first_pp_rank(): + pp_state_dict.update({key: value}) + continue + if "norm.weight" == key or "norm.bias" == key or "output" in key: + if dist_mapping.is_last_pp_rank(): + pp_state_dict.update({key: value}) + continue + prefix, layer_idx, param = key.split(".", 2) + if int(layer_idx) not in layer_range: + continue + + remapped_layer_idx = str(int(layer_idx) - layer_range[0]) + pp_key = ".".join([prefix, remapped_layer_idx, param]) + + pp_state_dict.update({pp_key: value}) + return pp_state_dict \ No newline at end of file diff --git a/model_zoo/llama_pp/modeling/dynamic_batching/Pipeline_pp.py b/model_zoo/llama_pp/modeling/dynamic_batching/Pipeline_pp.py new file mode 100644 index 0000000..b39bcae --- /dev/null +++ b/model_zoo/llama_pp/modeling/dynamic_batching/Pipeline_pp.py @@ -0,0 +1,491 @@ +from typing import List +import sys +import os +import torch +import os +import json +import torch.multiprocessing as mp +import torch.distributed as dist + +from .Model_pp import Transformer, TensorDumper + +sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../..") + +from ModelUtils import __Tokenizer__, __TextGenerator__ +from ModelParallel import DistMapping + +import time +from torch.cuda import synchronize + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +def recv_tensor_info(dist_mapping: DistMapping, hidden_dim: int, step: int): + # print(f"step [{step}] start recv") + tmp_tensor = torch.empty((2), dtype=torch.int64).cuda() + dist.recv(tmp_tensor, src=dist_mapping.prev_pp_rank()) + _batches, _decoding_batches = tmp_tensor.chunk(2, dim=0) + tmp_tensor2 = torch.empty((_batches.item() * 4), 
dtype=torch.int64).cuda() + dist.recv(tmp_tensor2, src=dist_mapping.prev_pp_rank()) + _seqlens, _start_pos, _cache_starts, _tid_list = tmp_tensor2.chunk(4, dim=0) + + total_seqlen = torch.sum(_seqlens) + _hs = torch.empty((total_seqlen, hidden_dim), dtype=torch.half).cuda() + dist.recv(_hs, src=dist_mapping.prev_pp_rank()) + torch.cuda.synchronize() + # print("recv end") + return _decoding_batches, _seqlens, _start_pos, _cache_starts, _tid_list.tolist(), _hs + +def send_tensor_info(dist_mapping: DistMapping, decoding_batches: int, _seqlens: torch.tensor, + _start_pos: torch.tensor, _cache_starts: torch.tensor, _tid_list: torch.tensor, _hidden_states: torch.tensor, step: int): + if dist_mapping.is_last_pp_rank(): + return + # print(f"step [{step}] start send") + batches = _seqlens.shape[0] + tmp_tensor = torch.tensor([batches, decoding_batches], dtype=torch.int64).cuda() + dist.send(tensor=tmp_tensor, dst=dist_mapping.next_pp_rank()) + tmp_tensor2 = torch.cat([_seqlens, _start_pos, _cache_starts, _tid_list], dim=0) + dist.send(tensor=tmp_tensor2, dst=dist_mapping.next_pp_rank()) + dist.send(tensor=_hidden_states, dst=dist_mapping.next_pp_rank()) + # print("end send") + +def post_process(queue: mp.Queue, logits: torch.tensor, temperature: float, top_p: float, start_pos, tid_list, current_steps, max_gen_len, output_ids: torch.tensor): + if temperature > 0: + probs = torch.softmax(logits / temperature, dim=-1) + next_tokens = sample_top_p(probs, top_p) + else: + next_tokens = torch.argmax(logits, dim=-1) + next_tokens = next_tokens.reshape(-1) + output_ids[tid_list, start_pos] = next_tokens # send + + # push to queue and early finish check + for idx, tid, next_token in zip(range(len(tid_list)), tid_list, next_tokens): + # early finish check + if current_steps[idx] >= max_gen_len: + continue + # print(f"queue.put(({tid}, [{next_token}])") + queue.put((tid, [next_token])) + return output_ids + + +def remove_finished_task(tid, tid_list, start_pos, seqlens, cachestarts, current_steps): + idx = tid_list.index(tid) + tid_list.pop(idx) + seqlens.pop(idx) + start_pos.pop(idx) + cachestarts.pop(idx) + current_steps.pop(idx) + +def update_decode_input(seqlens: List[int], start_pos: List[int]): + for idx in range(len(seqlens)): + start_pos[idx] += seqlens[idx] + seqlens[idx] = 1 + +class Profiler(): + def __init__(self): + self.recv_start = [] + self.recv_end = [] + self.forward_start = [] + self.forward_end = [] + self.send_start = [] + self.send_end = [] + +class LLaMA(__TextGenerator__): + def __init__(self, model: Transformer): + self.model = model + self.profiler = Profiler() + + def generate( + self, + prompts_ids: List[List[int]], + eos_id: int, + pad_id: int, + max_gen_len: int, + temperature: float, + top_k: int, + top_p: float, + queue: mp.Queue, + global_start = None + ) -> List[List[int]]: + + global_step = 0 + + total_prompt_len = 0 + for i, p in enumerate(prompts_ids): + total_prompt_len = total_prompt_len + len(p) + + total_cache_len = total_prompt_len + len(prompts_ids) * max_gen_len + head_dim = self.model.params.hidden_dim // self.model.params.num_heads + num_local_kv_heads = self.model.params.num_kv_heads // self.model.dist_mapping.tp_size + hidden_dim = self.model.params.hidden_dim + local_num_layer = self.model.params.num_layers + + dist_mapping = self.model.dist_mapping + + if self.model.params.cache_layout == 0: + cache_prefix_shape = (total_cache_len, local_num_layer, 2, num_local_kv_heads) + elif self.model.params.cache_layout == 1: + cache_prefix_shape = (local_num_layer, 
total_cache_len, 2, num_local_kv_heads) + elif self.model.params.cache_layout == 2: + cache_prefix_shape = (local_num_layer, 2, total_cache_len, num_local_kv_heads) + elif self.model.params.cache_layout == 3: + cache_prefix_shape = (local_num_layer, 2, num_local_kv_heads, total_cache_len) + else: + raise Exception("unsupported cache_layout: {}".format(self.model.params.cache_layout)) + + if self.model.params.cache_quant_bit == 8: + scale_head_dim = head_dim // self.model.params.cache_quant_group + kv_cache = torch.zeros(cache_prefix_shape + (head_dim,), dtype=torch.int8).cuda() + kv_scale = torch.zeros(cache_prefix_shape + (scale_head_dim,), dtype=torch.float16).cuda() + else: + kv_cache = torch.zeros(cache_prefix_shape + (head_dim,), dtype=torch.float16).cuda() + kv_scale = None + + max_prompt_len = max([len(t) for t in prompts_ids]) + output_ids = torch.full((len(prompts_ids), max_prompt_len + max_gen_len), pad_id).cuda().long() + for k, t in enumerate(prompts_ids): + output_ids[k, : len(t)] = torch.tensor(t).long() + + pp_rank = self.model.dist_mapping.pp_rank + + def myprint(string: str): + print(f"step[{global_step}]-rank[{pp_rank}]: {string}") + + + # send in queue + tid = 0 + if pp_rank == 0: + # token_ids = [] + for prompt_id in prompts_ids: + queue.put((tid, prompt_id)) # 0 means new req, 1 means decoding req + # print(f" [x] Sent {tid}: {prompt_id}") + tid += 1 + + TensorDumper.step = 0 + + if pp_rank == 0: + allocated_cache_len = 0 + seqlens = [] + start_pos = [] + cachestarts = [] + + tid_list = [] + + max_seqlen = 0 + current_steps = [] + total_tid_cnt = len(prompts_ids) + + while True: + # rank 0,从msgq里取数据 + # print(f"---------- step: {global_step}-------pp_rank: {pp_rank} ---------------- ") + if dist_mapping.is_first_pp_rank(): + tokens_ids = [] + current_batches = 0 + decoding_batches = torch.tensor([0]) + prev_tid_set = set(tid_list) + tid_set = set() + + self.profiler.recv_start.append(time.time() - global_start) + while True: + item = queue.get() # 期望是block wait + if item is None: + print("item is none") + break + tid, prompt_id = item + # print(f"queue.get(({tid}, [{prompt_id}])") + current_batches += 1 + tid_set.add(tid) + + if tid not in prev_tid_set: # new req + tokens_ids.extend(prompt_id) + l = len(prompt_id) + start_pos.append(0) + cachestarts.append(allocated_cache_len) + seqlens.append(l) + tid_list.append(tid) + allocated_cache_len += l + max_gen_len + current_steps.append(0) + else: # decode + tokens_ids.extend(prompt_id) + decoding_batches += 1 + + if queue.qsize() == 0: + break + self.profiler.recv_end.append(time.time() - global_start) + + # early 检测与处理处理 + tid_finished = [] + + for tid in prev_tid_set: + if tid not in tid_set: + tid_finished.append(tid) + remove_finished_task(tid, tid_list, start_pos, seqlens, cachestarts, current_steps) + + if len(tid_finished) > 0: + print("early finish tid list: ", tid_finished) + + _tokens_ids = torch.tensor(tokens_ids, dtype=torch.int64).cuda() + if self.model.params.cache_mode == 0: + _cachestarts = torch.tensor(cachestarts, dtype=torch.int64).cuda() + else: + raise Exception("unsupported cache_mode: {}".format(self.model.params.cache_mode)) + + _start_pos = torch.tensor(start_pos, dtype=torch.int64).cuda() + _seqlens = torch.tensor(seqlens, dtype=torch.int64).cuda() + model_input0 = _tokens_ids + else: # other pp_rank + prev_tid_list = tid_list + prev_tid_set = set(prev_tid_list) + + self.profiler.recv_start.append(time.time() - global_start) + + decoding_batches, _seqlens, _start_pos, _cachestarts, tid_list, _hs = 
recv_tensor_info(dist_mapping, hidden_dim, global_step) + + self.profiler.recv_end.append(time.time() - global_start) + + tid_set = set(tid_list) + + all_tid_set = prev_tid_set | tid_set + new_tid_set = all_tid_set - prev_tid_set + finished_tid_set = all_tid_set - tid_set + + # update current_steps + for tid in finished_tid_set: # delete finish tid + idx = prev_tid_list.index(tid) + current_steps.pop(idx) + current_steps.extend([0 for _ in range(len(new_tid_set))]) # add new tid + + current_batches = _seqlens.shape[0] + model_input0 = _hs + + _kvlens = _start_pos + _seqlens + max_seqlen = _seqlens.max().cpu().unsqueeze(0) + max_kvlen = _kvlens.max().cpu().unsqueeze(0) + _seqstarts = torch.zeros(current_batches + 1, dtype=torch.int64).cuda() + _seqstarts[1:] = _seqlens + _seqstarts = _seqstarts.cumsum(0) + _kvstarts = torch.zeros(current_batches + 1, dtype=torch.int64).cuda() + _kvstarts[1:] = _kvlens + _kvstarts = _kvstarts.cumsum(0) + + attn_mask = torch.empty(0, dtype=torch.float16).cuda() + if self.model.params.auto_causal == False and decoding_batches < current_batches: + padded_last_dim = (_kvstarts[-1] + 15) // 16 * 16 + attn_mask = torch.zeros((_seqstarts[-1], padded_last_dim), dtype=torch.float16).cuda() + for b in range(decoding_batches, current_batches): + seqbeg = _seqstarts[b] + seqend = _seqstarts[b+1] + kvbeg = _kvstarts[b] + kvend = _kvstarts[b+1] + attn_mask[seqbeg:seqend, kvbeg:kvend] = \ + torch.triu(torch.full_like(attn_mask[seqbeg:seqend, kvbeg:kvend], float("-inf")), diagonal=1) + + self.profiler.forward_start.append(time.time() - global_start) + synchronize() + model_output = self.model.forward(model_input0, attn_mask, _seqstarts, _kvstarts, + _cachestarts, decoding_batches, + _start_pos, max_seqlen, max_kvlen, + kv_cache, kv_scale) + synchronize() + self.profiler.forward_end.append(time.time() - global_start) + + + current_steps = [i + 1 for i in current_steps] + global_step += 1 + # update input 参数 + if pp_rank == 0: + update_decode_input(seqlens, start_pos) + + self.profiler.send_start.append(time.time() - global_start) + if dist_mapping.is_last_pp_rank(): + next_token_idx = (_start_pos + _seqlens).tolist() + post_process(queue, model_output, temperature, top_p, next_token_idx, tid_list, current_steps, max_gen_len, output_ids) + else: + send_tensor_info(dist_mapping, decoding_batches, _seqlens, _start_pos, _cachestarts, torch.tensor(tid_list).cuda(), model_output, global_step) + + self.profiler.send_end.append(time.time() - global_start) + + if max(current_steps) >= max_gen_len: + # myprint("break") + break + + # print(f"rank [{pp_rank}]: recv start | recv end | forward start | forward end | send start | send end ") + # for i in range(max_gen_len): + # print(f"rank[{pp_rank}] step[{i}]: {round(self.profiler.recv_start[i] * 1000)} | {round(self.profiler.recv_end[i] * 1000)} | {round(self.profiler.forward_start[i] * 1000)} | {round(self.profiler.forward_end[i] * 1000)} | {round(self.profiler.send_start[i] * 1000)} | {round(self.profiler.send_end[i] * 1000)} ") + + return output_ids.tolist() + + + # python export method + def export( + self, + export_path: str + ): + bsz = 4 + total_len = 16 + + total_cache_len = bsz * total_len + head_dim = self.model.params.hidden_dim // self.model.params.num_heads + num_local_kv_heads = self.model.params.num_kv_heads // self.model.dist_mapping.tp_size + num_layers = self.model.params.num_layers + + if self.model.params.cache_layout == 0: + cache_prefix_shape = (total_cache_len, num_layers, 2, num_local_kv_heads) + max_tokenlen_idx 
= 0 + elif self.model.params.cache_layout == 1: + cache_prefix_shape = (num_layers, total_cache_len, 2, num_local_kv_heads) + max_tokenlen_idx = 1 + elif self.model.params.cache_layout == 2: + cache_prefix_shape = (num_layers, 2, total_cache_len, num_local_kv_heads) + max_tokenlen_idx = 2 + elif self.model.params.cache_layout == 3: + cache_prefix_shape = (num_layers, 2, num_local_kv_heads, total_cache_len) + max_tokenlen_idx = 3 + else: + raise Exception("unsupported cache_layout: {}".format(self.model.params.cache_layout)) + + if self.model.params.cache_quant_bit == 8: + scale_head_dim = head_dim // self.model.params.cache_quant_group + kv_cache = torch.zeros(cache_prefix_shape + (head_dim,), dtype=torch.int8) + kv_scale = torch.zeros(cache_prefix_shape + (scale_head_dim,), dtype=torch.float16) + else: + kv_cache = torch.zeros(cache_prefix_shape + (head_dim,), dtype=torch.float16) + kv_scale = None + + seqlen = total_len // 2 + token_ids = torch.ones(bsz * seqlen, dtype=torch.int64) + hidden_states = torch.ones(bsz * seqlen, self.model.params.hidden_dim) + start_pos = torch.zeros(bsz, dtype=torch.int64) + seqstarts = torch.arange(0, seqlen * (bsz + 1), seqlen, dtype=torch.int64) + kvstarts = torch.arange(0, seqlen * (bsz + 1), seqlen, dtype=torch.int64) + decoding_batches = torch.tensor([0], dtype=torch.int64) + max_seqlen = torch.tensor([seqlen]) + attn_mask = torch.empty(0, dtype=torch.float16) + + if self.model.params.cache_mode == 0: + cachestarts = torch.arange(0, total_len * bsz, total_len, dtype=torch.int64) + cachestarts_dim_name = 'batch' + elif self.model.params.cache_mode == 1: + cachestarts = torch.arange(0, total_len * bsz, dtype=torch.int64) + cachestarts_dim_name = 'total_kvlen' + else: + raise Exception("unsupported cache_mode: {}".format(self.model.params.cache_mode)) + + # set input + if self.model.dist_mapping.is_first_pp_rank(): + input_names = ["token_ids", "attn_mask", "seqstarts", "kvstarts", + "cachestarts", "decoding_batches", + "start_pos", "max_seqlen", "max_kvlen", + "kv_cache", "kv_scale"] + input_tensors = [token_ids, attn_mask, + seqstarts, kvstarts, + cachestarts, decoding_batches, + start_pos, max_seqlen, max_seqlen, + kv_cache, kv_scale + ] + dynamic_axes = { + 'token_ids': { + 0:'total_seqlen' + }, + 'seqstarts': { + 0:'batch + 1' + }, + 'kvstarts': { + 0:'batch + 1' + }, + 'cachestarts': { + 0:cachestarts_dim_name + }, + 'start_pos': { + 0:'batch' + }, + 'kv_cache': { + max_tokenlen_idx: 'max_tokenlen' + }, + 'kv_scale': { + max_tokenlen_idx: 'max_tokenlen' + }, + } + else: + input_names = ["hidden_states", "attn_mask", "seqstarts", "kvstarts", + "cachestarts", "decoding_batches", + "start_pos", "max_seqlen", "max_kvlen", + "kv_cache", "kv_scale"] + input_tensors = [hidden_states, attn_mask, + seqstarts, kvstarts, + cachestarts, decoding_batches, + start_pos, max_seqlen, max_seqlen, + kv_cache, kv_scale + ] + + dynamic_axes = { + 'hidden_states': { + 0:'total_seqlen' + }, + 'seqstarts': { + 0:'batch + 1' + }, + 'kvstarts': { + 0:'batch + 1' + }, + 'cachestarts': { + 0:cachestarts_dim_name + }, + 'start_pos': { + 0:'batch' + }, + 'kv_cache': { + max_tokenlen_idx: 'max_tokenlen' + }, + 'kv_scale': { + max_tokenlen_idx: 'max_tokenlen' + }, + } + + # set output + if self.model.dist_mapping.is_last_pp_rank(): + output_names = ["logits"] + dynamic_axes.update({ + "logits": { + 0: 'batch', + 1: 'vocab_size'} + }) + + else: + output_names = ["hidden_states"] + dynamic_axes.update({ + 'hidden_states': { + 0: 'total_seqlen', + 1: 'hidden_dim'} + }) + + if 
self.model.params.cache_quant_bit == 0: + dynamic_axes.pop('kv_scale') + input_names.pop() + input_tensors.pop() + + local_rank = self.model.dist_mapping.rank + model_path = os.path.join(export_path, "model_slice_{}".format(local_rank)) + if not os.path.exists(model_path): + os.makedirs(model_path) + + torch.onnx.export( + self.model.cpu(), + tuple(input_tensors), + os.path.join(model_path, "model.onnx"), + input_names=input_names, + output_names=output_names, + do_constant_folding=True, + opset_version=11, + dynamic_axes=dynamic_axes) + + if local_rank == 0: + with open(os.path.join(export_path, "params.json"), "w") as f: + json.dump(self.model.params.__dict__, f) \ No newline at end of file
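
A note for readers of `Model_pp.py`: with pipeline parallelism each stage keeps only its own slice of transformer layers, and `filter_pp_state_dict` renumbers the checkpoint keys so that every stage's local layers start at index 0 (embedding weights stay on the first stage; the final norm and output weights on the last stage). Below is a minimal, standalone sketch of that renumbering; the layer counts and the checkpoint key are illustrative.

```python
# Illustrative sketch of the per-stage layer renumbering done by filter_pp_state_dict.
num_layers, pp_size = 8, 2
pp_rank = 1                                   # second pipeline stage
layers_per_stage = num_layers // pp_size      # Loader.py requires num_layers % pp_size == 0
layer_range = list(range(pp_rank * layers_per_stage, (pp_rank + 1) * layers_per_stage))
print(layer_range)                            # [4, 5, 6, 7]

key = "layers.6.attention.wq.weight"          # hypothetical checkpoint key
prefix, layer_idx, param = key.split(".", 2)
if int(layer_idx) in layer_range:             # keys outside the range are dropped on this stage
    local_key = ".".join([prefix, str(int(layer_idx) - layer_range[0]), param])
    print(local_key)                          # layers.2.attention.wq.weight
```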