diff --git a/docs/en/quantization/w8a8.md b/docs/en/quantization/w8a8.md
index 1b1726bd5f..5cdb48f764 100644
--- a/docs/en/quantization/w8a8.md
+++ b/docs/en/quantization/w8a8.md
@@ -1,55 +1,74 @@
 # SmoothQuant

-LMDeploy provides functions for quantization and inference of large language models using 8-bit integers.
+LMDeploy provides functions for quantization and inference of large language models using 8-bit integers (INT8). For GPUs such as NVIDIA H100, LMDeploy also supports 8-bit floating point (FP8).

-Before starting inference, ensure that lmdeploy and openai/triton are correctly installed. Execute the following commands to install these:
+The following NVIDIA GPUs are available for INT8 and FP8 inference respectively:
+
+- INT8
+  - V100(sm70): V100
+  - Turing(sm75): 20 series, T4
+  - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100
+  - Ada Lovelace(sm89): 40 series
+  - Hopper(sm90): H100
+- FP8
+  - Ada Lovelace(sm89): 40 series
+  - Hopper(sm90): H100
+
+First of all, run the following command to install lmdeploy:

 ```shell
-pip install lmdeploy
-pip install triton>=2.1.0
+pip install lmdeploy[all]
 ```

-## 8-bit Weight Model Inference
+## 8-bit Weight Quantization

-For performing 8-bit weight model inference, you can directly download the pre-quantized 8-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy). For instance, the 8-bit Internlm-chat-7B model is available for direct download from the model zoo:
+Performing 8-bit weight quantization involves three steps:

-```shell
-git-lfs install
-git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
-```
+1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantization.
+2. **Replace Modules**: Locate DecoderLayers and replace the RMSNorm and nn.Linear modules with QRMSNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
+3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.

-Alternatively, you can manually convert original 16-bit weights into 8-bit by referring to the content under the ["8bit Weight Quantization"](#8bit-weight-quantization) section. Save them in the internlm-chat-7b-w8 directory, using the command below:
+LMDeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three tasks detailed above. The `--quant-dtype` argument determines whether int8 or fp8 weight quantization is performed. For more information about the CLI usage, run `lmdeploy lite smooth_quant --help`.

-```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
-```
+Here are two examples:

-Afterwards, use the following command to interact with the model via the terminal:
+- int8

-```shell
-lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
-```
+  ```shell
+  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8
+  ```

-## Launching gradio service
+- fp8

-Coming soon...
+  ```shell
+  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8
+  ```

-## Inference Speed
+## Inference

-Coming soon...
+With the following code, you can perform batched offline inference with the quantized model:

-## 8bit Weight Quantization
+```python
+from lmdeploy import pipeline, PytorchEngineConfig

-Performing 8bit weight quantization involves three steps:
+engine_config = PytorchEngineConfig(tp=1)
+pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config)
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```

-1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing.
-2. **Replace Modules**: Locate DecoderLayers and replace the modules RSMNorm and nn.Linear with QRSMNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
-3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.
+## Service
+
+LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of service startup:
+
+```shell
+lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch
+```

-The script `lmdeploy/lite/apis/smooth_quant.py` accomplishes all three tasks detailed above. For example, you can obtain the model weights of the quantized Internlm-chat-7B model by running the following command:
+The default port of `api_server` is `23333`. After the server is launched, you can communicate with the server in the terminal through `api_client`:

 ```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
+lmdeploy serve api_client http://0.0.0.0:23333
 ```

-After saving, you can instantiate your quantized model by calling the from_pretrained interface.
+You can browse and try out the `api_server` APIs online through the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).
diff --git a/docs/zh_cn/quantization/w8a8.md b/docs/zh_cn/quantization/w8a8.md
index 302dd538fd..3a63c82f8c 100644
--- a/docs/zh_cn/quantization/w8a8.md
+++ b/docs/zh_cn/quantization/w8a8.md
@@ -1,56 +1,76 @@
 # W8A8 LLM 模型部署

-LMDeploy 提供了使用 8 bit 整数对神经网络模型进行量化和推理的功能。
+LMDeploy 提供了使用 8-bit 整数（INT8）和浮点数（FP8）对神经网络模型进行量化和推理的功能。

-在开始推理前，需要确保已经正确安装了 lmdeploy 和 openai/triton。可以通过以下命令进行安装：
+可用于 INT8 和 FP8 推理的 NVIDIA GPU 分别为：
+
+- INT8
+  - V100(sm70): V100
+  - Turing(sm75): 20 series, T4
+  - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100
+  - Ada Lovelace(sm89): 40 series
+  - Hopper(sm90): H100
+- FP8
+  - Ada Lovelace(sm89): 40 series
+  - Hopper(sm90): H100
+
+首先，执行如下命令安装 lmdeploy：

 ```shell
-pip install lmdeploy
-pip install triton>=2.1.0
+pip install lmdeploy[all]
 ```

-## 8bit 权重模型推理
+## 8-bit 权重量化

-如果你需要进行 8 bit 权重模型推理，可以直接从 LMDeploy 的 [model zoo](https://huggingface.co/lmdeploy) 下载已经量化好的 8bit 权重模型。以8bit 的 Internlm-chat-7B 模型为例，可以从 model zoo 直接下载：
+进行 8-bit 权重量化需要经历以下三步：

-```shell
-git-lfs install
-git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
-```
+1. **权重平滑**：首先对语言模型的权重进行平滑处理，以便更好地进行量化。
+2. **模块替换**：使用 `QRMSNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RMSNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。
+3. **保存量化模型**：完成上述必要的替换后，我们即可保存新的量化模型。

-你也可以参考["8bit 权重量化"](#8bit-权重量化)章节的内容手动将原 16bit 权重量化为 8bit，并保存至 `internlm-chat-7b-w8` 目录下，操作命令如下：
+lmdeploy 提供了命令行工具 `lmdeploy lite smooth_quant` 实现了以上三个步骤。并且其中命令行参数 `--quant-dtype` 可以用来控制是进行8-bit整数还是浮点数类型的量化。更多命令行工具使用方式，请执行 `lmdeploy lite smooth_quant --help` 查看。

-```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
-```
+以下示例演示了进行 int8 或 fp8 的量化命令。

-然后，执行以下命令，即可在终端与模型对话：
+- int8

-```shell
-lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
-```
+  ```shell
+  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8
+  ```

-## 启动 gradio 服务
+- fp8

-Coming soon...
+  ```shell
+  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8
+  ```

-## 推理速度
+## 模型推理

-Coming soon...
+量化后的模型，通过以下几行简单的代码，可以实现离线推理：

-## 8bit 权重量化
+```python
+from lmdeploy import pipeline, PytorchEngineConfig

-进行 8bit 权重量化需要经历以下三步：
+engine_config = PytorchEngineConfig(tp=1)
+pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config)
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```

-1. **权重平滑**：首先对语言模型的权重进行平滑处理，以便更好地进行量化。
-2. **模块替换**：使用 `QRSMNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RSMNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。
-3. **保存量化模型**：完成上述必要的替换后，我们即可保存新的量化模型。
+关于 pipeline 的详细介绍，请参考[这里](../llm/pipeline.md)

-我们在`lmdeploy/lite/api/smooth_quantity.py`脚本中已经实现了以上三个步骤。例如，可以通过以下命令得到量化后的 Internlm-chat-7B 模型的模型权重：
+## 推理服务
+
+LMDeploy `api_server` 支持把模型一键封装为服务，对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例：

 ```shell
+lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch
+```

-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
+服务默认端口是 23333。在 server 启动后，你可以在终端通过 `api_client` 与 server 进行对话：
+
+```shell
+lmdeploy serve api_client http://0.0.0.0:23333
 ```

-保存之后，你就可以通过调用from_pretrained接口来实例化你的量化模型。
+还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口，也可直接查阅[文档](../llm/api_server.md)，了解各接口的定义和使用方法。
diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py
index 236e022b34..499bace485 100644
--- a/lmdeploy/cli/lite.py
+++ b/lmdeploy/cli/lite.py
@@ -126,6 +126,7 @@ def add_parser_smooth_quant():
         ArgumentHelper.calib_batchsize(parser)
         ArgumentHelper.calib_search_scale(parser)
         ArgumentHelper.dtype(parser)
+        ArgumentHelper.quant_dtype(parser)

     @staticmethod
     def auto_awq(args):
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index cf7b6526ec..4edf23d684 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -113,6 +113,16 @@ def dtype(parser, default: str = 'auto'):
             'for BF16 models. This option will be ignored if '
             'the model is a quantized model')

+    @staticmethod
+    def quant_dtype(parser, default: str = 'int8'):
+        return parser.add_argument(
+            '--quant-dtype',
+            type=str,
+            default=default,
+            choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'],
+            help='data type for the quantized model weights and activations.'
+ 'Note "fp8" is the short version of "float8_e4m3fn"') + @staticmethod def model_format(parser, default: str = None): return parser.add_argument( diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index 188eedbd0e..8d67535bcc 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -24,7 +24,19 @@ def smooth_quant(model: str, batch_size: int = 1, w_bits: int = 8, dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto', - device: str = 'cuda'): + device: str = 'cuda', + quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', + 'float8_e5m2'] = 'int8'): + if quant_dtype == 'fp8': + quant_dtype = 'float8_e4m3fn' + + quant_dtype = getattr(torch, quant_dtype, torch.int8) + if quant_dtype.is_floating_point: + q_dtype_info = torch.finfo(quant_dtype) + else: + q_dtype_info = torch.iinfo(quant_dtype) + + assert q_dtype_info.bits == w_bits model_path = model vl_model, model, tokenizer, work_dir = calibrate(model, calib_dataset, @@ -84,7 +96,7 @@ def smooth_quant(model: str, if skipped_module(name): continue linear.to(device) - q_linear = QLinear.from_float(linear) + q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype) parent_name, _, child_name = name.rpartition('.') parent = model.get_submodule(parent_name) setattr(parent, child_name, q_linear) @@ -94,7 +106,7 @@ def smooth_quant(model: str, if skipped_module(name): continue norm.to(device) - q_norm = QRMSNorm.from_float(norm) + q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype) parent_name, _, child_name = name.rpartition('.') parent = model.get_submodule(parent_name) setattr(parent, child_name, q_norm) @@ -104,8 +116,10 @@ def smooth_quant(model: str, from .auto_awq import save_vl_model save_vl_model(vl_model, model_path, work_dir) else: + quant_dtype_s = str(quant_dtype).split('.')[1] model.config.update( - dict(quantization_config=dict(quant_method='smooth_quant'))) + dict(quantization_config=dict(quant_method='smooth_quant', + quant_dtype=f'{quant_dtype_s}'))) model.save_pretrained(work_dir, max_shard_size='2GB', safe_serialization=False) diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index ef1e510193..13d9a47ddf 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -15,34 +15,48 @@ class TritonRMSNormW8A8Impl(RMSNormW8A8Impl): """triton RMS norm w8a8 implementation api.""" - def __init__(self, hidden_size: int, eps: float = 1e-6): + def __init__(self, + hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): super().__init__() self.hidden_size = hidden_size self.eps = eps + self.quant_dtype = quant_dtype def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None): """forward.""" - if residual is not None: - x = x + residual - residual = x - hidden_states_quant, rms_scale = rms_norm_dynamic_quant( - x, weight, self.eps) - x = QTensor(hidden_states_quant, rms_scale) if residual is None: + (x, + rms_scale) = rms_norm_dynamic_quant(x, + weight, + self.eps, + quant_dtype=self.quant_dtype) + x = QTensor(x, rms_scale) return x - return x, residual + else: + (x, rms_scale, + residual) = rms_norm_dynamic_quant(x, + weight, + self.eps, + residual=residual, + quant_dtype=self.quant_dtype) + x = QTensor(x, rms_scale) + return x, residual class TritonRMSNormBuilder(RMSNormW8A8Builder): """triton RMS norm w8a8 implementation builder.""" @staticmethod - def build(hidden_size: int, eps: float = 1e-6): + def 
build(hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): """build.""" - return TritonRMSNormW8A8Impl(hidden_size, eps) + return TritonRMSNormW8A8Impl(hidden_size, eps, quant_dtype) class TritonLinearW8A8Impl(LinearW8A8Impl): @@ -51,10 +65,12 @@ class TritonLinearW8A8Impl(LinearW8A8Impl): def __init__(self, in_features: int, out_features: int, - out_dtype: torch.dtype = torch.float16): + out_dtype: torch.dtype = torch.float16, + quant_dtype: torch.dtype = torch.int8): self.in_features = in_features self.out_features = out_features self.out_dtype = out_dtype + self.quant_dtype = quant_dtype def forward(self, x, @@ -64,8 +80,8 @@ def forward(self, all_reduce: bool = False): """forward.""" if isinstance(x, torch.Tensor): - x = x.contiguous() - input_quant, input_scale = per_token_quant_int8(x, 1e-7) + input_quant, input_scale = per_token_quant_int8( + x, 1e-7, quant_dtype=self.quant_dtype) else: assert isinstance(x, QTensor) input_quant, input_scale = x.tensor, x.scale @@ -89,6 +105,10 @@ class TritonLinearW8A8Builder(LinearW8A8Builder): def build(in_features: int, out_features: int, bias: bool = True, - dtype: torch.dtype = None): + dtype: torch.dtype = None, + quant_dtype: torch.dtype = torch.int8): """build.""" - return TritonLinearW8A8Impl(in_features, out_features, dtype) + return TritonLinearW8A8Impl(in_features, + out_features, + dtype, + quant_dtype=quant_dtype) diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py index a61941b37d..e877a4ca6b 100644 --- a/lmdeploy/pytorch/backends/qmodules.py +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -37,7 +37,9 @@ class RMSNormW8A8Builder(ABC): @staticmethod @abstractmethod - def build(hidden_size: int, eps: float = 1e-6): + def build(hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): """build.""" raise NotImplementedError @@ -71,6 +73,7 @@ class LinearW8A8Builder(ABC): def build(in_features: int, out_features: int, bias: bool = True, - dtype: torch.dtype = None): + dtype: torch.dtype = None, + quant_dtype: torch.dtype = torch.int8): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index e06e0cf80a..727251e0e8 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -1003,9 +1003,29 @@ def __send_resps(step_outputs: Dict[int, InferOutput]): for out in step_outputs.values(): __send_resp(out) + def __do_prefill(): + # decoding if no waiting + if not self.scheduler.has_waiting(): + return False + + num_running = len(self.scheduler.running) + num_waiting = len(self.scheduler.waiting) + max_batches = self.scheduler_config.max_batches + + # prefill if too much waiting + if num_waiting >= 4: + return True + + # prefill if no enough running + if num_running < max_batches * 0.5: + return True + + # decoding + return False + async def __step(): """step decoding.""" - prefill = self.scheduler.has_waiting() + prefill = __do_prefill() schedule_output = self.scheduler.schedule( is_prefill=prefill, prealloc_size=prefill_interval) # schedule decoding if no valid prefill reqs. 
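Aside from the quantization changes, this patch also adjusts the prefill/decode scheduling policy in `lmdeploy/pytorch/engine/engine.py` (the new `__do_prefill` helper above): instead of prefilling whenever any request is waiting, the engine now keeps decoding unless the waiting queue grows or the running batch is under-filled. The sketch below restates that decision rule as a standalone function for readers skimming the diff; the function name and the example numbers are illustrative, only the thresholds (4 waiting requests, 50% of `max_batches`) come from the diff.

```python
# Illustrative restatement of __do_prefill from the engine.py hunk above.
# Not LMDeploy code; the thresholds mirror the diff.
def should_prefill(num_waiting: int, num_running: int, max_batches: int) -> bool:
    """Decide whether the next engine step schedules prefill or decoding."""
    if num_waiting == 0:
        return False                      # nothing to prefill, keep decoding
    if num_waiting >= 4:
        return True                       # enough pending requests to batch in
    if num_running < max_batches * 0.5:
        return True                       # running batch is under-utilized, top it up
    return False                          # otherwise protect decoding latency


# A few spot checks of the rule:
assert should_prefill(num_waiting=0, num_running=8, max_batches=16) is False
assert should_prefill(num_waiting=5, num_running=16, max_batches=16) is True
assert should_prefill(num_waiting=1, num_running=3, max_batches=16) is True
assert should_prefill(num_waiting=1, num_running=12, max_batches=16) is False
```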
diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py index 0d0e10ec83..a8eeb63a5f 100644 --- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py @@ -14,14 +14,13 @@ tl_round = tl.math.round -def per_channel_quant(x, n_bits, dtype): +def per_channel_quant(x: torch.Tensor, dtype: torch.dtype): """Quantize the input tensor 'x' channel-wise using the given number of bits. Args: x (torch.Tensor): The input tensor to be quantized. Must be a 2-dimensional tensor. - n_bits (int): The number of bits to use for quantization. dtype (torch.dtype): The data type to which the quantized tensor should be converted. @@ -32,31 +31,40 @@ def per_channel_quant(x, n_bits, dtype): assert x.ndim == 2 x = x.to(torch.float32) x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0] - q_max = 2**(n_bits - 1) - 1 - q_min = -2**(n_bits - 1) - scale = x_absmax / (2**(n_bits - 1) - 1) - x_q = torch.round(x / scale).clamp(q_min, q_max).to(dtype) + qtype_info = torch.finfo( + dtype) if dtype.is_floating_point else torch.iinfo(dtype) + q_max = qtype_info.max + q_min = qtype_info.min + scale = x_absmax / q_max + x_q = x / scale + if not dtype.is_floating_point: + x_q = torch.round(x_q) + x_q = x_q.clamp(q_min, q_max).to(dtype) return x_q, scale @triton.autotune( configs=[ triton.Config({ - 'BLOCK_N': 64, + 'BLOCK_M': 128, + 'BLOCK_N': 256, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4), + num_stages=3, + num_warps=8), triton.Config({ + 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4) + num_stages=3, + num_warps=8) ], key=['N', 'K'], + warmup=5, + rep=20, ) -@triton.jit +@triton.jit(do_not_specialize=['M']) def _linear( A, B, @@ -76,6 +84,7 @@ def _linear( GROUP_SIZE_M: tl.constexpr, rms_scale_ptr, linear_scale_ptr, + ACCUMULATOR_DTYPE: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B`, and store the result in output @@ -100,12 +109,11 @@ def _linear( offs_k = tl.arange(0, BLOCK_K) a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0) - accumulator += tl.dot(a, b) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None) + accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk c = accumulator.to(tl.float32) @@ -124,42 +132,31 @@ def _linear( @triton.autotune( configs=[ triton.Config({ - 'BLOCK_N': 64, + 'BLOCK_M': 128, + 'BLOCK_N': 256, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4), + num_stages=3, + num_warps=8), triton.Config({ + 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, }, - num_stages=4, - num_warps=4) + num_stages=3, + num_warps=8) ], key=['N', 'K'], + warmup=5, + rep=20, ) -@triton.jit -def _linear_add( - A, - B, - C, - residual_ptr, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, 
- GROUP_SIZE_M: tl.constexpr, - rms_scale_ptr, - linear_scale_ptr, -): +@triton.jit(do_not_specialize=['M']) +def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak, + stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + rms_scale_ptr, linear_scale_ptr, + ACCUMULATOR_DTYPE: tl.constexpr): """Triton-accelerated function used to perform a linear operation (dot product) on input tensors `A` and `B`, with addition of residual. @@ -183,11 +180,11 @@ def _linear_add( a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0) - accumulator += tl.dot(a, b) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None) + accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk c = accumulator.to(tl.float32) @@ -231,14 +228,11 @@ def matmul_kernel_dynamic_quant(a, assert residual.shape == c_shape assert residual.is_contiguous() c = a.new_empty(c_shape, dtype=output_dtype) - - BLOCK_M = 128 - if M < BLOCK_M: - BLOCK_M = triton.next_power_of_2(M) - BLOCK_M = max(BLOCK_M, 16) + accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32 def grid(META): - return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), ) + return (triton.cdiv(M, META['BLOCK_M']) * + triton.cdiv(N, META['BLOCK_N']), ) kernel_meta = get_kernel_meta(a) if residual is not None: @@ -255,10 +249,10 @@ def grid(META): b.stride(0), c.stride(-2), c.stride(-1), - BLOCK_M=BLOCK_M, GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, linear_scale_ptr=linear_scale, + ACCUMULATOR_DTYPE=accumulator_dtype, **kernel_meta) else: _linear[grid](a, @@ -273,10 +267,10 @@ def grid(META): b.stride(0), c.stride(-2), c.stride(-1), - BLOCK_M=BLOCK_M, GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, linear_scale_ptr=linear_scale, + ACCUMULATOR_DTYPE=accumulator_dtype, **kernel_meta) if bias is not None: c += bias @@ -286,13 +280,16 @@ def grid(META): @triton.jit def _per_token_quant_int8( - y_ptr, - y_q_ptr, - y_s_ptr, - y_stride, - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK: tl.constexpr, + y_ptr, + y_q_ptr, + y_s_ptr, + y_stride: tl.constexpr, + yq_stride: tl.constexpr, + N, # number of columns in X + eps: tl.constexpr, # epsilon to avoid division by zero + BLOCK: tl.constexpr, + Q_MAX: tl.constexpr, + IS_FLOATING_POINT: tl.constexpr, # True for floating point dtype ): """A Triton-accelerated function to perform per-token quantization on a tensor. @@ -302,7 +299,7 @@ def _per_token_quant_int8( # Map the program id to the row of X and Y it should compute. 
row = tl.program_id(0) y_ptr += row * y_stride - y_q_ptr += row * y_stride + y_q_ptr += row * yq_stride y_s_ptr += row cols = tl.arange(0, BLOCK) # N <= BLOCK @@ -311,21 +308,26 @@ def _per_token_quant_int8( y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / 127 - y_q = tl_round(y / y_s).to(tl.int8) + y_s = _absmax / Q_MAX + y_q = y / y_s + if not IS_FLOATING_POINT: + y_q = tl_round(y_q).to(tl.int8) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) -def per_token_quant_int8(x, eps): +def per_token_quant_int8(x, eps, quant_dtype=torch.int8): """Function to perform per-token quantization on an input tensor `x`. It converts the tensor values into signed 8-bit integers and returns the quantized tensor along with the scaling factor used for quantization. """ - - x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) + qdtype_info = torch.finfo( + quant_dtype) if quant_dtype.is_floating_point else torch.iinfo( + quant_dtype) + q_max = qdtype_info.max + x_q = torch.empty_like(x, device=x.device, dtype=quant_dtype) M = x.numel() // x.shape[-1] N = x.shape[-1] x_s = torch.empty(x.shape[:-1] + (1, ), @@ -334,94 +336,184 @@ def per_token_quant_int8(x, eps): BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) + + if x.dim() > 2: + x = x.flatten(0, -2) + assert x.stride(-1) == 1 # enqueue kernel kernel_meta = get_kernel_meta(x) - _per_token_quant_int8[(M, )](x, - x_q, - x_s, - x.stride(-2), - N, - eps, - BLOCK=BLOCK, - num_warps=num_warps, - **kernel_meta) + _per_token_quant_int8[(M, )]( + x, + x_q, + x_s, + y_stride=x.stride(-2), + yq_stride=x_q.stride(-2), + N=N, + eps=eps, + BLOCK=BLOCK, + Q_MAX=q_max, + IS_FLOATING_POINT=quant_dtype.is_floating_point, + num_warps=num_warps, + **kernel_meta) return x_q, x_s @triton.jit -def _rms_norm_fwd_fused_dynamic_symmetric( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - Scale, # pointer to the scales of the output activation - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, +def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr): + """compute rms norm.""" + xf = x.to(tl.float32) + + var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS) + out = xf * tl.math.rsqrt(var + eps) + out = (w * out).to(x.dtype) + return out + + +@triton.jit +def rms_norm_quant_kernel( + input, + weight, + output, + out_scale, + input_row_stride: tl.constexpr, + eps: tl.constexpr, + N_COLS: tl.constexpr, + BLOCK_N: tl.constexpr, + Q_MIN: tl.constexpr, + Q_MAX: tl.constexpr, + IS_FLOATING_POINT: tl.constexpr, ): - """A Triton kernel that calculates Root Mean Square (RMS) normalization - with fused dynamic symmetric quantization.""" - row = tl.program_id(0) - Y += row * stride - X += row * stride + """rms norm kernel.""" + prog_id = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) - cols = tl.arange(0, BLOCK_SIZE) - mask = cols < N - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) - _var = x * x - var = tl.sum(_var, axis=0) / N - rstd = tl.math.rsqrt(var + eps) + w = tl.load(weight + offsets, mask=offsets < N_COLS) + + x_ptr = input + prog_id * input_row_stride + x = tl.load(x_ptr + offsets, mask=offsets < N_COLS) + out = _compute_rms_norm(x, w, eps, N_COLS) + + scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX + out_s_ptr = out_scale + prog_id + 
tl.store(out_s_ptr, scale) + out = out / scale + if not IS_FLOATING_POINT: + out = tl_round(out) + out = tl.clamp(out, Q_MIN, Q_MAX) + out_ptr = output + prog_id * input_row_stride + tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) - w = tl.load(W + cols, mask=mask) - x_hat = x * rstd - y = x_hat * w - scale = tl.max(tl.abs(y)).to(tl.float32) / 127 - tl.store(Scale + row, scale) +@triton.jit +def add_rms_norm_quant_kernel( + input, + weight, + residual, + output, + out_scale, + out_residual, + input_row_stride: tl.constexpr, + residual_row_stride: tl.constexpr, + eps: tl.constexpr, + N_COLS: tl.constexpr, + BLOCK_N: tl.constexpr, + Q_MIN: tl.constexpr, + Q_MAX: tl.constexpr, + IS_FLOATING_POINT: tl.constexpr, +): + """rms norm kernel.""" + prog_id = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) + + w = tl.load(weight + offsets, mask=offsets < N_COLS) - y = tl_round(y / scale) - y = tl.minimum(y, 127) - y = tl.maximum(y, -128) - tl.store(Y + cols, y, mask=mask) + x_ptr = input + prog_id * input_row_stride + x = tl.load(x_ptr + offsets, mask=offsets < N_COLS) + res_ptr = residual + prog_id * residual_row_stride + res = tl.load(res_ptr + offsets, mask=offsets < N_COLS) -def rms_norm_dynamic_quant(x, w, eps): + new_x = x + res + out_res_ptr = out_residual + prog_id * residual_row_stride + tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS) + + out = _compute_rms_norm(new_x, w, eps, N_COLS) + + scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX + out_s_ptr = out_scale + prog_id + tl.store(out_s_ptr, scale) + out = out / scale + if not IS_FLOATING_POINT: + out = tl_round(out) + out = tl.clamp(out, Q_MIN, Q_MAX) + out_ptr = output + prog_id * input_row_stride + tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) + + +def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8): """Performs RMS normalization with dynamic quantization. The function reshapes the input tensor `x`, creates an empty tensor `y` with the same shape as `x`, and calculates RMS normalization on the - reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`. + reshaped `x` using a Triton kernel `rms_norm_quant_kernel`. 
""" - - x_arg = x.flatten(0, -2) - y = torch.empty_like(x, dtype=torch.int8) - M, K = x_arg.shape - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K)) - if K > BLOCK_SIZE: - raise RuntimeError( - "This rms norm doesn't support feature dim >= 64KB.") - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + qdtype_info = torch.finfo( + quant_dtype) if quant_dtype.is_floating_point else torch.iinfo( + quant_dtype) + y = torch.empty_like(x, dtype=quant_dtype) scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32) - kernel_meta = get_kernel_meta(x_arg) - _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg, - y, - w, - scale, - x_arg.stride(0), - K, - eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - **kernel_meta) - return y, scale + + feat_size = w.shape[0] + seq_len = x.numel() // x.size(-1) + input_stride = x.stride(-2) + BLOCK_N = triton.next_power_of_2(feat_size) + grid = (seq_len, ) + + if residual is None: + rms_norm_quant_kernel[grid]( + x, + w, + y, + scale, + input_row_stride=input_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + Q_MIN=qdtype_info.min, + Q_MAX=qdtype_info.max, + IS_FLOATING_POINT=quant_dtype.is_floating_point, + num_warps=4, + num_stages=2) + return y, scale + else: + out_residual = torch.empty_like(x) + res_stride = residual.stride(-2) + add_rms_norm_quant_kernel[grid]( + x, + w, + residual, + y, + scale, + out_residual, + input_row_stride=input_stride, + residual_row_stride=res_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + Q_MIN=qdtype_info.min, + Q_MAX=qdtype_info.max, + IS_FLOATING_POINT=quant_dtype.is_floating_point, + num_warps=4, + num_stages=2) + return y, scale, out_residual def test_rms_and_linear(x, rms_weight, linear_weight, - dtype=torch.float16, + output_dtype=torch.float16, + quant_dtype=torch.int8, eps=1e-5): """Test quantized rms norm and quantized linear layer.""" @@ -434,15 +526,18 @@ def linear_torch(x, b): return F.linear(x, b) linear_weight_quant, linear_scale = per_channel_quant( - linear_weight, 8, torch.int8) + linear_weight, quant_dtype) - rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps) + rms_out, rms_scale = rms_norm_dynamic_quant(x, + rms_weight, + eps, + quant_dtype=quant_dtype) assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1] linear_out = matmul_kernel_dynamic_quant(rms_out, linear_weight_quant, rms_scale, linear_scale, - output_dtype=dtype) + output_dtype=output_dtype) rms_out_torch = rms_norm_torch(x, rms_weight, eps).half() linear_out_torch = linear_torch(rms_out_torch, linear_weight) @@ -456,17 +551,26 @@ def linear_torch(x, b): linear_out_torch.flatten().to(torch.float32))) -def test_per_token_quant(x, eps): +def test_per_token_quant(x, eps, quant_dtype=torch.int8): """Test per-token quantization.""" - def per_token_quant_int8_torch(x, eps): + def per_token_quant_int8_torch(x, eps, quant_dtype): + qdtype_info = torch.finfo( + quant_dtype) if quant_dtype.is_floating_point else torch.iinfo( + quant_dtype) + _absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps) - x_s = _absmax / 127 - x_q = torch.clamp((x / x_s).round(), min=-128, max=127) + x_s = _absmax / qdtype_info.max + x_q = x / x_s + if not quant_dtype.is_floating_point: + x_q = x_q.round() + x_q = torch.clamp(x_q, min=qdtype_info.min, max=qdtype_info.max) return x_q, x_s - x_q, x_s = per_token_quant_int8(x, eps) - x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps) + x_q, x_s = per_token_quant_int8(x, eps, quant_dtype=quant_dtype) + 
x_q_torch, x_s_torch = per_token_quant_int8_torch(x, + eps, + quant_dtype=quant_dtype) assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape cos = torch.nn.CosineSimilarity(0) print( @@ -479,21 +583,11 @@ def per_token_quant_int8_torch(x, eps): x_s_torch.flatten().to(torch.float32))) -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M'], - x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)], - line_arg='provider', - line_vals=['int8_dynamic_triton_op', 'float_torch'], - line_names=['int8_dynamic_triton_op', 'float_torch'], - styles=[('blue', '-'), ('green', '-'), ('orange', '-'), - ('yellow', '-'), ('yellow', '-')], - ylabel='GB/s', - plot_name='forward', - args={ - 'dtype': torch.float16, - })) -def bench_rms_and_linear(M, dtype, provider, eps=1e-5, device='cuda'): +def bench_rms_and_linear(M: int, + provider: str, + dtype: torch.dtype = torch.float16, + eps: float = 1e-5): + """benchmark rms and linear.""" def rms_norm_torch(x, w, eps): variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) @@ -505,6 +599,7 @@ def linear_torch(x, b): N = 4096 K = 4096 + x_shape = (M, K) rms_w_shape = (x_shape[-1], ) rms_weight = torch.randn(rms_w_shape, @@ -516,14 +611,33 @@ def linear_torch(x, b): dtype=dtype, device='cuda', requires_grad=True) - linear_weight_quant, linear_scale = per_channel_quant( - linear_weight, 8, torch.int8) - alpha = max(x.max().abs(), x.min().abs()) - rms_scale = alpha / 127 + if provider == 'torch_fp16': + rms_out_torch = rms_norm_torch(x, rms_weight, eps).half() - if provider == 'int8_dynamic_triton_op': - rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps) + def y_fwd(): + linear_torch(rms_out_torch, linear_weight) + else: + if provider == 'triton_int8': + quant_dtype = torch.int8 + elif provider == 'triton_fp8_e4m3': + quant_dtype = torch.float8_e4m3fn + elif provider == 'triton_fp8_e5m2': + quant_dtype = torch.float8_e5m2 + + linear_weight_quant, linear_scale = per_channel_quant( + linear_weight, quant_dtype) + + alpha = max(x.max().abs(), x.min().abs()) + if quant_dtype.is_floating_point: + qdtype_info = torch.finfo(quant_dtype) + else: + qdtype_info = torch.iinfo(quant_dtype) + rms_scale = alpha / qdtype_info.max + rms_out, rms_scale = rms_norm_dynamic_quant(x, + rms_weight, + eps, + quant_dtype=quant_dtype) def y_fwd(): @@ -532,21 +646,22 @@ def y_fwd(): rms_scale, linear_scale, output_dtype=dtype) - elif provider == 'float_torch': - rms_out_torch = rms_norm_torch(x, rms_weight, eps).half() - - def y_fwd(): - linear_torch(rms_out_torch, linear_weight) quantiles = [0.5, 0.2, 0.8] ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) - return ms, max_ms, min_ms + + def perf(ms): + return 2 * M * N * K * 1e-12 / (ms * 1e-3) + + return perf(ms), perf(max_ms), perf(min_ms) if __name__ == '__main__': torch.manual_seed(0) + device_map = torch.cuda.get_device_capability() + is_fp8_supported = device_map[0] >= 9 dtype = torch.float16 # test (bs, seq_len, dim) x (dim, out_dim) x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda') @@ -559,7 +674,16 @@ def y_fwd(): dtype=dtype, device='cuda', requires_grad=True) - test_rms_and_linear(x, rms_weight, linear_weight) + test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8) + if is_fp8_supported: + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e4m3fn) + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e5m2) # test (M, K) x (K, N) x = 
torch.randn((4, 4096), dtype=dtype, device='cuda') @@ -572,11 +696,45 @@ def y_fwd(): dtype=dtype, device='cuda', requires_grad=True) - test_rms_and_linear(x, rms_weight, linear_weight) + test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8) + if is_fp8_supported: + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e4m3fn) + test_rms_and_linear(x, + rms_weight, + linear_weight, + quant_dtype=torch.float8_e5m2) # test per-token quant x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda') eps = 1e-7 - test_per_token_quant(x, eps) - - bench_rms_and_linear.run(print_data=True) + test_per_token_quant(x, eps, quant_dtype=torch.int8) + if is_fp8_supported: + test_per_token_quant(x, eps, quant_dtype=torch.float8_e4m3fn) + test_per_token_quant(x, eps, quant_dtype=torch.float8_e5m2) + + # benchmark triton kernels + line_vals = ['triton_int8', 'torch_fp16'] + line_names = ['triton_int8', 'torch_fp16'] + + if is_fp8_supported: + line_vals += ['triton_fp8_e4m3', 'triton_fp8_e5m2'] + line_names += ['triton_fp8_e4m3', 'triton_fp8_e5m2'] + config = triton.testing.Benchmark(x_names=['M'], + x_vals=[1, 16, 32, 64, 128, 256] + + [512 * i * 2 for i in range(1, 5)], + line_arg='provider', + line_vals=line_vals, + line_names=line_names, + styles=[('blue', '-'), ('green', '-'), + ('orange', '-'), ('black', '-'), + ('yellow', '-')], + ylabel='TFLOPS', + plot_name='bench-triton', + args={ + 'dtype': torch.float16, + }) + bench_funch = (triton.testing.perf_report(config))(bench_rms_and_linear) + bench_funch.run(print_data=True) diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 583cd19fe9..38d794f1be 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -228,7 +228,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -245,7 +244,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 73f64d277c..5a83154167 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -265,7 +265,6 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = getattr(config, 'quantization_config', None) self.num_layers = config.num_layers self.post_layer_norm = config.post_layer_norm @@ -280,7 +279,6 @@ def build_layer(layer_number): assert config.rmsnorm self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index c460b8e44f..8010e5cead 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -617,7 +617,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -634,7 +633,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, 
dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py index e1a537b7e5..09c0b74fcc 100644 --- a/lmdeploy/pytorch/models/deepseek.py +++ b/lmdeploy/pytorch/models/deepseek.py @@ -313,7 +313,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -330,7 +329,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index 86be85669e..1f24206b16 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -263,7 +263,6 @@ def __init__(self, self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -280,7 +279,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index db246331a1..52f51a3ad1 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -221,7 +221,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, @@ -241,7 +240,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/internlm2_ve.py b/lmdeploy/pytorch/models/internlm2_ve.py index b1a2329597..c10faa5f5d 100644 --- a/lmdeploy/pytorch/models/internlm2_ve.py +++ b/lmdeploy/pytorch/models/internlm2_ve.py @@ -105,7 +105,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, @@ -125,7 +124,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py index e551dda841..9e47c56437 100644 --- a/lmdeploy/pytorch/models/minicpmv26.py +++ b/lmdeploy/pytorch/models/minicpmv26.py @@ -227,7 +227,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -247,7 +246,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index ad27963093..962cdb3d2b 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -223,7 
+223,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -240,7 +239,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index 288fdf3b19..988fee11e5 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -226,7 +226,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -243,7 +242,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py index 001fab7a60..8379bb18c9 100644 --- a/lmdeploy/pytorch/models/q_modules.py +++ b/lmdeploy/pytorch/models/q_modules.py @@ -34,13 +34,17 @@ class QRMSNorm(nn.Module): """It performs traditional RMS normalization and then quantizes the output to 8-bit integers.""" - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, quant_dtype=torch.int8): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.quant_dtype = quant_dtype @classmethod - def from_float(cls, mod: nn.Module, initialization: bool = True): + def from_float(cls, + mod: nn.Module, + initialization: bool = True, + quant_dtype=torch.int8): """Class method to create a QRMSNorm instance from a floating-point module. @@ -49,7 +53,7 @@ def from_float(cls, mod: nn.Module, initialization: bool = True): """ hidden_size = mod.weight.shape[0] eps = mod.variance_epsilon - q_mod = cls(hidden_size, eps) + q_mod = cls(hidden_size, eps, quant_dtype=quant_dtype) if initialization: q_mod.weight = nn.Parameter(mod.weight.detach()) return q_mod @@ -62,7 +66,10 @@ def forward(self, hidden_states): with its scale factor. """ hidden_states_quant, rms_scale = rms_norm_dynamic_quant( - hidden_states, self.weight, self.variance_epsilon) + hidden_states, + self.weight, + self.variance_epsilon, + quant_dtype=self.quant_dtype) return QTensor(hidden_states_quant, rms_scale) @@ -83,16 +90,18 @@ def __init__(self, out_features: int, bias: bool = True, device=None, - dtype=None) -> None: + dtype=None, + quant_dtype=torch.int8) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features self.out_features = out_features + self.quant_dtype = quant_dtype self.register_buffer( 'weight', torch.empty((out_features, in_features), device=device, - dtype=torch.int8)) + dtype=quant_dtype)) self.register_buffer( 'scale', torch.empty((out_features, 1), device=device, dtype=torch.float32)) @@ -103,7 +112,10 @@ def __init__(self, self.register_parameter('bias', None) @classmethod - def from_float(cls, mod: nn.Module, initialization: bool = True): + def from_float(cls, + mod: nn.Module, + initialization: bool = True, + quant_dtype=torch.int8): """Class method to create a QLinear instance from a floating-point module. 
@@ -114,11 +126,12 @@ def from_float(cls, mod: nn.Module, initialization: bool = True): mod.out_features, mod.bias is not None, device=mod.weight.device, - dtype=mod.weight.dtype) + dtype=mod.weight.dtype, + quant_dtype=quant_dtype) if initialization: - weight_quant, scale = per_channel_quant(mod.weight.detach(), 8, - torch.int8) + weight_quant, scale = per_channel_quant(mod.weight.detach(), + quant_dtype) q_mod.weight.data = weight_quant q_mod.scale = scale @@ -137,7 +150,8 @@ def forward(self, input): """ if isinstance(input, torch.Tensor): - input_quant, input_scale = per_token_quant_int8(input, 1e-7) + input_quant, input_scale = per_token_quant_int8( + input, 1e-7, quant_dtype=self.quant_dtype) else: assert isinstance(input, QTensor) input_quant, input_scale = input.tensor, input.scale diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py index bf856461a3..20e184bdf8 100644 --- a/lmdeploy/pytorch/models/qwen.py +++ b/lmdeploy/pytorch/models/qwen.py @@ -229,7 +229,6 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = getattr(config, 'quantization_config', None) self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size self.wte = nn.Embedding(self.vocab_size, @@ -263,7 +262,6 @@ def __init__(self, self.ln_f = RMSNorm(self.embed_dim, eps=config.layer_norm_epsilon, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 38773c21e1..a26aa22d5a 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -225,7 +225,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -242,7 +241,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index 61dda5ada3..de990592d5 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -328,7 +328,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -345,7 +344,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index 4e2b1017b5..bfd6e352f1 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -260,7 +260,6 @@ def __init__(self, self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.mrope_section = config.rope_scaling['mrope_section'] - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -277,7 +276,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 
486c684a3c..a84d0ec3eb 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -598,17 +598,16 @@ def weight_spliter_lora_b(self, loaded_weight: torch.Tensor): class W8A8Linear(nn.Module): """w8a8 linear.""" - def __init__( - self, - in_features: int, - out_features: int, - bias: bool, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False, - all_reduce: bool = True, - ): + def __init__(self, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + quant_dtype: Optional[torch.dtype] = torch.int8): super().__init__() if device is None: device = torch.device('cpu') @@ -618,10 +617,12 @@ def __init__( in_features, out_features = self._get_io_features( in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8) + self.quant_dtype = quant_dtype self.impl = impl_builder.build(in_features, out_features, bias is not None, - dtype=dtype) + dtype=dtype, + quant_dtype=quant_dtype) weight, scale, bias = self.create_weights(in_features, out_features, bias, dtype, device) weight = torch.nn.Parameter(weight, requires_grad=False) @@ -663,7 +664,9 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int, world_size: int): """weight loader for rowwise linear.""" - if loaded_weight.dim() == 2 and param.dtype == torch.int8: + if loaded_weight.dim() == 2 and param.dtype in (torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2): weight = loaded_weight.chunk(world_size, 1)[rank] return default_weight_loader(param, weight) elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1: @@ -693,7 +696,7 @@ def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device): """create weights.""" weight = torch.empty((out_features, in_features), - dtype=torch.int8, + dtype=self.quant_dtype, device=device) scale = torch.empty((out_features, 1), dtype=torch.float32, @@ -745,7 +748,8 @@ def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: Optional[List[int]] = None): + out_names: Optional[List[int]] = None, + quant_dtype: torch.dtype = torch.int8): self.split_section = all_out_features all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features @@ -761,7 +765,8 @@ def __init__(self, dtype, device, colwise=True, - is_tp=is_tp) + is_tp=is_tp, + quant_dtype=quant_dtype) self.weight.weight_loader = self.weight_loader self.scale.weight_loader = self.weight_loader self.weight.weight_spliter = self.weight_spliter @@ -814,7 +819,9 @@ def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - num_replicate_kv_heads: int = 1): + num_replicate_kv_heads: int = 1, + quant_dtype: torch.dtype = torch.int8): + self.qkv_split_section = self._get_qkv_out_features( num_q_heads, num_kv_heads, head_size, head_size_v, num_replicate_kv_heads) @@ -835,7 +842,8 @@ def __init__(self, dtype=dtype, device=device, is_tp=is_tp, - out_names=out_names) + out_names=out_names, + quant_dtype=quant_dtype) def _update_all_out_features(self, all_out_features: List[int]): """update all out features.""" @@ -1200,6 +1208,10 @@ def build_linear(in_features: int, ) quant_method = 
quant_config['quant_method'] + quant_dtype = torch.int8 + if 'quant_dtype' in quant_config: + quant_dtype = eval('torch.' + quant_config['quant_dtype']) + if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) @@ -1215,16 +1227,15 @@ def build_linear(in_features: int, all_reduce=all_reduce, ) if quant_method == 'smooth_quant': - return W8A8Linear( - in_features, - out_features, - bias=bias, - dtype=dtype, - device=device, - colwise=colwise, - is_tp=is_tp, - all_reduce=all_reduce, - ) + return W8A8Linear(in_features, + out_features, + bias=bias, + dtype=dtype, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + quant_dtype=quant_dtype) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -1299,6 +1310,10 @@ def build_merged_colwise_linear( ) quant_method = quant_config['quant_method'] + quant_dtype = torch.int8 + if 'quant_dtype' in quant_config: + quant_dtype = eval('torch.' + quant_config['quant_dtype']) + if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) @@ -1312,15 +1327,14 @@ def build_merged_colwise_linear( is_tp=is_tp, ) if quant_method == 'smooth_quant': - return MergedW8A8Linear( - in_features=in_features, - all_out_features=all_out_features, - bias=bias, - dtype=dtype, - device=device, - is_tp=is_tp, - out_names=out_names, - ) + return MergedW8A8Linear(in_features=in_features, + all_out_features=all_out_features, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names, + quant_dtype=quant_dtype) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -1357,6 +1371,10 @@ def build_qkv_proj(in_features: int, num_replicate_kv_heads=num_replicate_kv_heads) quant_method = quant_config['quant_method'] + quant_dtype = torch.int8 + if 'quant_dtype' in quant_config: + quant_dtype = eval('torch.' 
+ quant_config['quant_dtype']) + if quant_method == 'awq': w_bit = quant_config.get('bits', 4) group_size = quant_config.get('group_size', 128) @@ -1381,6 +1399,7 @@ def build_qkv_proj(in_features: int, dtype=dtype, device=device, is_tp=is_tp, - num_replicate_kv_heads=num_replicate_kv_heads) + num_replicate_kv_heads=num_replicate_kv_heads, + quant_dtype=quant_dtype) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index ef244ff73f..ba565263c3 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -9,14 +9,15 @@ def _is_w8a8(quant_config: Any): """is w8a8.""" - if quant_config is None: - return False - else: + quant_dtype = None + w8a8_flag = False + if quant_config is not None: quant_method = quant_config['quant_method'] - if quant_method == 'w8a8': - return True - else: - return False + if quant_method == 'smooth_quant': + w8a8_flag = True + quant_dtype = quant_config.get('quant_dtype', 'int8') + quant_dtype = eval(f'torch.{quant_dtype}') + return w8a8_flag, quant_dtype class RMSNorm(nn.Module): @@ -30,13 +31,20 @@ def __init__(self, quant_config: Any = None): super().__init__() backend = get_backend() - if _is_w8a8(quant_config): + + w8a8_flag, quant_dtype = _is_w8a8(quant_config) + if w8a8_flag: builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8) else: builder = backend.get_layer_impl_builder(OpType.RMSNorm) self.register_parameter('weight', self.create_weight(hidden_size, dtype, device)) - self.impl = builder.build(hidden_size, eps) + if w8a8_flag: + self.impl = builder.build(hidden_size, + eps, + quant_dtype=quant_dtype) + else: + self.impl = builder.build(hidden_size, eps) @staticmethod def create_weight(hidden_size: int, diff --git a/tests/pytorch/kernel/test_fused_moe.py b/tests/pytorch/kernel/test_fused_moe.py index 9faa742f1d..cc309eb6a7 100644 --- a/tests/pytorch/kernel/test_fused_moe.py +++ b/tests/pytorch/kernel/test_fused_moe.py @@ -266,7 +266,7 @@ def quant_weight(self, w): per_channel_quant num_experts, num_outs, _ = w.shape w = w.flatten(0, -2) - w_i8, w_scale = per_channel_quant(w, 8, torch.int8) + w_i8, w_scale = per_channel_quant(w, torch.int8) w_i8 = w_i8.view(num_experts, num_outs, -1) w_scale = w_scale.view(num_experts, num_outs, -1) return w_i8, w_scale
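For readers who want to sanity-check the numerics outside of Triton, the arithmetic implemented by the updated kernels (`per_channel_quant`, `per_token_quant_int8`, `matmul_kernel_dynamic_quant`) boils down to symmetric dynamic quantization with a per-token scale for activations and a per-output-channel scale for weights, followed by a dequantizing matmul. Below is a minimal PyTorch reference sketch of that scheme, generalized over the quantization dtype the same way the patch is (round and clamp only for integer dtypes). It is illustrative only: the helper names are local to the sketch, not LMDeploy APIs, and the real kernels accumulate in int32/fp32 inside Triton.

```python
import torch


def quant_per_token(x: torch.Tensor, quant_dtype=torch.int8, eps: float = 1e-7):
    """Symmetric dynamic quantization with one scale per token (last-dim row)."""
    info = (torch.finfo(quant_dtype) if quant_dtype.is_floating_point
            else torch.iinfo(quant_dtype))
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax.float() / info.max
    x_q = x / scale
    if not quant_dtype.is_floating_point:
        x_q = x_q.round()
    return x_q.clamp(info.min, info.max).to(quant_dtype), scale


def quant_per_channel(w: torch.Tensor, quant_dtype=torch.int8):
    """Symmetric quantization with one scale per output channel of a weight matrix."""
    info = (torch.finfo(quant_dtype) if quant_dtype.is_floating_point
            else torch.iinfo(quant_dtype))
    scale = w.abs().amax(dim=1, keepdim=True).float() / info.max
    w_q = w / scale
    if not quant_dtype.is_floating_point:
        w_q = w_q.round()
    return w_q.clamp(info.min, info.max).to(quant_dtype), scale


def w8a8_linear_reference(x, w, quant_dtype=torch.int8):
    """Dequantizing matmul: y ~= (x_q @ w_q.T) * x_scale * w_scale.T."""
    x_q, x_s = quant_per_token(x, quant_dtype)
    w_q, w_s = quant_per_channel(w, quant_dtype)
    acc = x_q.float() @ w_q.float().T  # the Triton kernels accumulate in int32/fp32
    return (acc * x_s * w_s.T).to(x.dtype)


if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(4, 4096)
    w = torch.randn(11008, 4096)
    y_ref = x @ w.T
    y_q = w8a8_linear_reference(x, w, torch.int8)
    cos = torch.nn.functional.cosine_similarity(y_ref.flatten(), y_q.flatten(), dim=0)
    print(f'cosine similarity vs the unquantized matmul: {cos.item():.6f}')
```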