diff --git a/Dockerfile.rocm b/Dockerfile.rocm index ec669ca89c9b8..a143f37ab4f2f 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,9 +1,14 @@ -FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 -ENV WORKSPACE_DIR=/workspace -RUN mkdir -p $WORKSPACE_DIR -WORKDIR $WORKSPACE_DIR -# Limit arch's so composable kernel doesn't take days to finish -ENV PYTORCH_ROCM_ARCH=gfx90a;gfx942 +# default base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" + +FROM $BASE_IMAGE + +RUN echo "Base image is $BASE_IMAGE" + +# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" +# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" + + ARG FA_GFX_ARCHS="gfx90a;gfx942" RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" @@ -22,22 +27,9 @@ ARG BUILD_CUPY="1" # whether to build triton on rocm ARG BUILD_TRITON="1" -# Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y - # Install some basic utilities RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - git \ - bzip2 \ - libx11-6 \ - build-essential \ - wget \ - unzip \ - nvidia-cuda-toolkit \ - tmux \ + sqlite3 libsqlite3-dev libfmt-dev \ && rm -rf /var/lib/apt/lists/* ### Mount Point ### @@ -60,6 +52,8 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ && cd libs \ && git clone https://github.com/ROCm/flash-attention.git \ && cd flash-attention \ + && git checkout ${FA_BRANCH} \ + && git submodule update --init \ && export GPU_ARCHS=${FA_GFX_ARCHS} \ && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ @@ -94,10 +88,11 @@ RUN if [ "$BUILD_TRITON" = "1"]; then \ mkdir -p libs \ && cd libs \ && pip uninstall -y triton \ - && git clone https://github.com/ROCmSoftwarePlatform/triton.git + && git clone https://github.com/ROCm/triton.git \ && cd triton/python \ && pip3 install -e . \ - && cd ../..; \ + && cd ../.. \ + && rm -r triton; \ fi COPY ./ /app/vllm @@ -105,16 +100,17 @@ COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install xformers==0.0.23 --no-deps -RUN cd vllm \ - && pip install -r requirements-rocm.txt \ - && pip install typing-extensions==4.8.0 \ - && bash patch_xformers.rocm.sh \ - && cd gradlib && python setup.py develop && cd ../ \ - && python setup.py build && python setup.py develop; exit 0 - -RUN pip install pyarrow Ray pandas==2.0 numpy==1.20.3 +RUN cd /app \ + && cd vllm \ + && pip install -U -r requirements-rocm.txt \ + && if [ "$BUILD_FA" = "1" ]; then \ + bash patch_xformers.rocm.sh; fi \ + && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ + patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch; fi \ + && python3 setup.py install \ + && cd .. 
-RUN git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData.git \ - && cd rocmProfileData && make; make install +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir ray[all] -WORKDIR /workspace/vllm +CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e940c26f24ed9..0eabd1f66ffc5 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -3,23 +3,17 @@ import time from pathlib import Path from typing import Optional -import pandas as pd + import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams -from torch.profiler import profile, record_function, ProfilerActivity -def list_of_ints(arg): - return list(map(int, arg.split(','))) def main(args: argparse.Namespace): print(args) - print(f'>>>Loading LLM') - if args.report: - results_df = pd.DataFrame(columns=['model', 'batch', 'tp', 'input', 'output', 'latency']) # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM( @@ -36,101 +30,60 @@ def main(args: argparse.Namespace): ray_workers_use_nsight=args.ray_workers_use_nsight, ) - for batch_size in args.batch_size: - for output_len in args.output_len: - for input_len in args.input_len: - print(f'>>>RUNNING {args.model} Batch_size:{batch_size} Input_len:{input_len} Output_len:{output_len}') - sampling_params = SamplingParams( - n=args.n, - temperature=0.0 if args.use_beam_search else 1.0, - top_p=1.0, - use_beam_search=args.use_beam_search, - ignore_eos=True, - max_tokens=output_len, - ) - print(sampling_params) - dummy_prompt_token_ids = [[0] * input_len] * batch_size - dummy_prompts = [] - dummy_prompts.append('DeepSpeed is a machine learning library that deep learning practitioners should use for what purpose') - - def run_to_completion(profile_dir: Optional[str] = None): - if profile_dir: - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - on_trace_ready=torch.profiler.tensorboard_trace_handler( - str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) - print(p.key_averages()) - elif args.accuracy: - start_time = time.perf_counter() - rsp = llm.generate( - #prompt_token_ids=dummy_prompt_token_ids, - prompts=dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) - end_time = time.perf_counter() - latency = end_time - start_time - print('>>Rsp', rsp[0].outputs) - return latency - else: - start_time = time.perf_counter() - rsp = llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) - end_time = time.perf_counter() - latency = end_time - start_time - print('>>Rsp', rsp[0].outputs) - return latency - - print("Warming up...") - run_to_completion(profile_dir=None) - - if (args.warmup_only): - - print(">>> Warmup only specified, exiting") - continue + sampling_params = SamplingParams( + n=args.n, + temperature=0.0 if args.use_beam_search else 1.0, + top_p=1.0, + use_beam_search=args.use_beam_search, + ignore_eos=True, + max_tokens=args.output_len, + ) + print(sampling_params) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() - if args.profile: - profile_dir = args.profile_result_dir - if not profile_dir: - profile_dir = Path( - 
"." - ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" - print(f"Profiling (results will be saved to '{profile_dir}')...") - run_to_completion(profile_dir=args.profile_result_dir) - return - if args.rpd: - from rpdTracerControl import rpdTracerControl - rpdTracerControl.setFilename(name = "/workspace/trace.rpd", append=True) - profile_rpd = rpdTracerControl() - profile_rpd.start() - print(f"RPD Profiling'...") - with torch.autograd.profiler.emit_nvtx(): - run_to_completion(profile_dir=None) - profile_rpd.stop() - return + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir))) as p: + llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) + print(p.key_averages()) + else: + start_time = time.perf_counter() + llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) + end_time = time.perf_counter() + latency = end_time - start_time + return latency - # Benchmark. - latencies = [] - for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): - latencies.append(run_to_completion(profile_dir=None)) + print("Warming up...") + run_to_completion(profile_dir=None) - if torch.distributed.get_rank() == 0: - #results_df = pd.DataFrame(columns=['model', 'batch', 'tp', 'input', 'output', 'latency']) - latency=np.mean(latencies) - print(f'Avg latency: {latency} seconds') - if args.report: - entry = {'model':[args.model], 'tp':[args.tensor_parallel_size],'batch':[batch_size], 'input':[input_len], 'output':[output_len], 'latency':[latency]} - results_df = pd.concat([results_df, pd.DataFrame(entry)], ignore_index=True) - if torch.distributed.get_rank() == 0 and args.report: - print(results_df) - results_df.to_csv(args.report_file, index=False) + if args.profile: + profile_dir = args.profile_result_dir + if not profile_dir: + profile_dir = Path( + "." + ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=profile_dir) + return + # Benchmark. 
+ latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + print(f'Avg latency: {np.mean(latencies)} seconds') if __name__ == '__main__': @@ -144,9 +97,9 @@ def run_to_completion(profile_dir: Optional[str] = None): choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) - parser.add_argument('--input-len', type=list_of_ints, default=32) - parser.add_argument('--output-len', type=list_of_ints, default=128) - parser.add_argument('--batch-size', type=list_of_ints, default=8) + parser.add_argument('--input-len', type=int, default=32) + parser.add_argument('--output-len', type=int, default=128) + parser.add_argument('--batch-size', type=int, default=8) parser.add_argument('--n', type=int, default=1, @@ -159,7 +112,6 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') - parser.add_argument( '--dtype', type=str, @@ -172,9 +124,6 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--enforce-eager', action='store_true', help='enforce eager mode and disable CUDA graph') - parser.add_argument('--accuracy', - action='store_true', - help='Run an Actual query through vllm') parser.add_argument( "--kv-cache-dtype", type=str, @@ -216,14 +165,5 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help="If specified, use nsight to profile ray workers", ) - parser.add_argument( - '--rpd', - action='store_true', - help='profile the generation process of a single batch using the rpd tracer') - parser.add_argument('--warmup-only', action='store_true', - help='only run warmup, useful for tuning') - parser.add_argument('--report', action='store_true', - help='turn on dataframe reporting') - parser.add_argument('--report-file', type=str, default=None) args = parser.parse_args() main(args) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 11524cda2041a..7005509094fec 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -629,7 +629,11 @@ template< typename CACHE_T, int BLOCK_SIZE, bool IS_FP8_KV_CACHE, +#ifdef USE_ROCM int NUM_THREADS = 1024> +#else + int NUM_THREADS = 128> +#endif void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, @@ -810,8 +814,13 @@ template< typename CACHE_T, int BLOCK_SIZE, bool IS_FP8_KV_CACHE, +#ifdef USE_ROCM + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +#else int NUM_THREADS = 1024, int PARTITION_SIZE = 1024> +#endif void paged_attention_v2_launcher( torch::Tensor& out, torch::Tensor& exp_sums, diff --git a/csrc/cache.h b/csrc/cache.h index 82b90eb4ab631..718a5f6cfd7f7 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -24,13 +24,6 @@ void reshape_and_cache( const std::string& kv_cache_dtype, const float kv_scale); -void gather_cached_kv( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping); - // Just for unittest void convert_fp8( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 73f61e92b1a51..24aaa2ff3e263 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -277,167 +277,6 @@ void reshape_and_cache( namespace vllm { -// Grid: (num_blocks, block_size). 
-template<typename scalar_t>
-__global__ void gather_cached_kv_kernel(
-  scalar_t* __restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
-  scalar_t* __restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
-  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,     // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x) {
-    const int token_idx = blockIdx.x;
-    const int slot_idx = slot_mapping[token_idx];
-    const int block_idx = slot_idx / block_size;
-    const int block_offset = slot_idx % block_size;
-
-    const int num_tokens = num_heads * head_size;
-    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-      const int tgt_key_idx = token_idx * key_stride + i;
-      const int tgt_value_idx = token_idx * value_stride + i;
-
-      const int head_idx = i / head_size;
-      const int head_offset = i % head_size;
-      const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
-      const int x_offset = head_offset % x;
-
-      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                              + head_idx * (head_size / x) * block_size * x
-                              + x_idx * block_size * x
-                              + block_offset * x
-                              + x_offset;
-      const int src_value_idx = block_idx * num_heads * head_size * block_size
-                                + head_idx * head_size * block_size
-                                + head_offset * block_size
-                                + block_offset;
-
-      key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
-      value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
-    }
-}
-
-template<typename scalar_t>
-__global__ void gather_cached_kv_kernel_optimized(
-    scalar_t *__restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
-    scalar_t *__restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
-    const scalar_t *__restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    const scalar_t *__restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-    const int *__restrict__ slot_mapping,      // [num_tokens]
-    const int key_stride,
-    const int value_stride,
-    const int num_heads,
-    const int head_size,
-    const int block_size,
-    const int x)
-{
-    const int token_idx = blockIdx.x;
-    const int slot_idx = slot_mapping[token_idx];
-    const int block_idx = slot_idx / block_size;
-    const int block_offset = slot_idx % block_size;
-
-    const int dim = num_heads * head_size;
-    assert(dim % 4 == 0);  // this is true for known use cases
-    const int unroll_factor = 4;
-    const int unrolled_dim = dim / unroll_factor;
-
-    for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
-    {
-        int tgt_key_indices[unroll_factor];
-        int tgt_value_indices[unroll_factor];
-        int src_key_indices[unroll_factor];
-        int src_value_indices[unroll_factor];
-        scalar_t keys_to_store[unroll_factor];
-        scalar_t values_to_store[unroll_factor];
-
-        #pragma unroll
-        for (int j = 0; j < unroll_factor; ++j)
-        {
-            int index = i + j * unrolled_dim;
-
-            const int tgt_key_idx = token_idx * key_stride + index;
-            const int tgt_value_idx = token_idx * value_stride + index;
-
-            const int head_idx = index / head_size;
-            const int head_offset = index % head_size;
-            const int x_idx = head_offset / x;
-            const int x_offset = head_offset % x;
-
-            const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                                    + head_idx * (head_size / x) * block_size * x
-                                    + x_idx * block_size * x
-                                    + block_offset * x
-                                    + x_offset;
-            const int src_value_idx = block_idx * num_heads * head_size * block_size
-                                      + head_idx * head_size * block_size
-                                      + head_offset * block_size
-                                      + block_offset;
-
-            tgt_key_indices[j] = tgt_key_idx;
-            tgt_value_indices[j] = tgt_value_idx;
-            src_key_indices[j] = src_key_idx;
-            src_value_indices[j] = src_value_idx;
-
-            keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
-            values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
-        }
-
-        #pragma unroll
-        for (int j = 0; j < unroll_factor; ++j)
-        {
-            key[tgt_key_indices[j]] = keys_to_store[j];
-            value[tgt_value_indices[j]] = values_to_store[j];
-        }
-    }
-}
-
-}  // namespace vllm
-
-void gather_cached_kv(
-  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
-  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
-{
-  int num_tokens = key.size(0);
-  int num_heads = key.size(1);
-  int head_size = key.size(2);
-  int block_size = key_cache.size(3);
-  int x = key_cache.size(4);
-
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-    key.scalar_type(),
-    "gather_cached_kv_kernel_optimized",
-    [&] {
-      vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
-}
-
-namespace vllm {
-
 template<typename Tin, typename Tout>
 __global__ void convert_fp8_kernel(
   const Tin* __restrict__ src_cache,
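For reference, the gather_cached_kv kernels removed above read from vLLM's blocked KV cache, whose key layout is [num_blocks, num_heads, head_size/x, block_size, x] and whose value layout is [num_blocks, num_heads, head_size, block_size]. The following is a minimal Python sketch of the same index arithmetic, using hypothetical helper and argument names; it mirrors the removed kernel for illustration and does not correspond to any function that remains in the tree.

    def gather_kv_for_token(key_cache, value_cache, slot_idx,
                            num_heads, head_size, block_size, x):
        """Recover one token's key/value rows from flat (1-D) cache buffers
        laid out as described in the comments of the removed kernel."""
        block_idx, block_offset = divmod(slot_idx, block_size)
        key_row = [None] * (num_heads * head_size)
        value_row = [None] * (num_heads * head_size)
        for i in range(num_heads * head_size):
            head_idx, head_offset = divmod(i, head_size)
            x_idx, x_offset = divmod(head_offset, x)
            # key cache: [num_blocks, num_heads, head_size/x, block_size, x], flattened
            src_key_idx = (block_idx * num_heads * (head_size // x) * block_size * x
                           + head_idx * (head_size // x) * block_size * x
                           + x_idx * block_size * x
                           + block_offset * x
                           + x_offset)
            # value cache: [num_blocks, num_heads, head_size, block_size], flattened
            src_value_idx = (block_idx * num_heads * head_size * block_size
                             + head_idx * head_size * block_size
                             + head_offset * block_size
                             + block_offset)
            key_row[i] = key_cache[src_key_idx]
            value_row[i] = value_cache[src_value_idx]
        return key_row, value_row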
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index d69fc63c62716..de02afc162113 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -90,10 +90,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "reshape_and_cache",
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
-  cache_ops.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
   cache_ops.def(
     "convert_fp8",
     &convert_fp8,
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
deleted file mode 100644
index 5d0b93793c89d..0000000000000
--- a/vllm/model_executor/models/internlm.py
+++ /dev/null
@@ -1,299 +0,0 @@
-# -*- coding: utf-8 -*-
-from typing import Any, Dict, List, Optional, Tuple
-
-import torch
-from torch import nn
-from transformers import LlamaConfig
-
-from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttention
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size)
-from vllm.model_executor.sampling_metadata import
SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class InternLMMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class InternLMAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - bias: bool, - rope_theta: float = 10000, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - ): - super().__init__() - self.hidden_size = hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - self.total_num_heads = num_heads - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) - self.head_dim = hidden_size // self.total_num_heads - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - bias=bias, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=bias, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class InternLMDecoderLayer(nn.Module): - - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = InternLMAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - bias=config.bias, - rope_theta=rope_theta, - max_position_embeddings=max_position_embeddings, - linear_method=linear_method, - rope_scaling=getattr(config, "rope_scaling", None), - ) - self.mlp = InternLMMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) - self.input_layernorm = 
RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class InternLMModel(nn.Module): - - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.embed_tokens = VocabParallelEmbedding( - vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i], - input_metadata, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class InternLMForCausalLM(nn.Module): - - def __init__( - self, - config, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = InternLMModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - 
continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight)
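The benchmark_latency.py changes above replace the comma-separated list handling for --input-len, --output-len and --batch-size with plain integers and drop the pandas report, so each invocation now measures a single configuration. A minimal sketch of driving a sweep externally instead, assuming the revised CLI shown in this diff (the parsing of the printed "Avg latency" line is illustrative only):

    import itertools
    import subprocess

    # Hypothetical sweep driver: one benchmark process per configuration.
    batch_sizes = [1, 8]
    input_lens = [32, 128]
    output_lens = [128]

    for bs, ilen, olen in itertools.product(batch_sizes, input_lens, output_lens):
        cmd = [
            "python3", "benchmarks/benchmark_latency.py",
            "--batch-size", str(bs),
            "--input-len", str(ilen),
            "--output-len", str(olen),
        ]
        proc = subprocess.run(cmd, capture_output=True, text=True, check=True)
        # The script prints a line of the form "Avg latency: <seconds> seconds".
        for line in proc.stdout.splitlines():
            if line.startswith("Avg latency"):
                print(f"batch={bs} input={ilen} output={olen} -> {line}")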