Support fp8 w8a8 for pt backend #2959

Open · wants to merge 5 commits into base: main
67 changes: 37 additions & 30 deletions docs/en/quantization/w8a8.md
@@ -1,55 +1,62 @@
# SmoothQuant

LMDeploy provides functions for quantization and inference of large language models using 8-bit integers.
LMDeploy provides functions for quantization and inference of large language models using 8-bit integers (INT8). For GPUs such as the NVIDIA H100, LMDeploy also supports 8-bit floating point (FP8).

Before starting inference, ensure that lmdeploy and openai/triton are correctly installed. Execute the following commands to install these:
First, run the following command to install lmdeploy:

```shell
pip install lmdeploy
pip install triton>=2.1.0
pip install lmdeploy[all]
```

## 8-bit Weight Model Inference
## 8-bit Weight Quantization

For performing 8-bit weight model inference, you can directly download the pre-quantized 8-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy). For instance, the 8-bit Internlm-chat-7B model is available for direct download from the model zoo:
Performing 8-bit weight quantization involves three steps:

```shell
git-lfs install
git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
```
1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing.
2. **Replace Modules**: Locate the DecoderLayers and replace the RMSNorm and nn.Linear modules with QRMSNorm and QLinear modules, respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.

Alternatively, you can manually convert original 16-bit weights into 8-bit by referring to the content under the ["8bit Weight Quantization"](#8bit-weight-quantization) section. Save them in the internlm-chat-7b-w8 directory, using the command below:
LMDeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three steps above. The argument `--quant-dtype` determines whether int8 or fp8 weight quantization is performed. For more information about the CLI usage, run `lmdeploy lite smooth_quant --help`.

```shell
lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
```
Here are two examples:

Afterwards, use the following command to interact with the model via the terminal:
- int8

```shell
lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
```
```shell
lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8-w8a8 --quant-dtype int8
```

## Launching gradio service
- fp8

Coming soon...
```shell
lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8-w8a8 --quant-dtype fp8
```

## Inference Speed
## Inference

Coming soon...
With the following code, you can perform batched offline inference with the quantized model:

## 8bit Weight Quantization
```python
from lmdeploy import pipeline, PytorchEngineConfig

Performing 8bit weight quantization involves three steps:
engine_config = PytorchEngineConfig(tp=1)
pipe = pipeline("./internlm2_5-7b-chat-int8-w8a8", backend_config=engine_config)
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)
```
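
The same pipeline call also works for the fp8 work directory produced above; only the path changes. A minimal sketch, assuming the fp8 quantization command above has been run:

```python
from lmdeploy import pipeline, PytorchEngineConfig

# load the fp8 w8a8 model from the work directory created by smooth_quant
engine_config = PytorchEngineConfig(tp=1)
pipe = pipeline("./internlm2_5-7b-chat-fp8-w8a8", backend_config=engine_config)
print(pipe(["Hi, pls intro yourself", "Shanghai is"]))
```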

1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing.
2. **Replace Modules**: Locate the DecoderLayers and replace the RMSNorm and nn.Linear modules with QRMSNorm and QLinear modules, respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.
## Service

LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service:

```shell
lmdeploy serve api_server ./internlm2_5-7b-chat-int8-w8a8 --backend pytorch
```

The script `lmdeploy/lite/apis/smooth_quant.py` accomplishes all three tasks detailed above. For example, you can obtain the model weights of the quantized Internlm-chat-7B model by running the following command:
The default port of `api_server` is `23333`. After the server is launched, you can communicate with the server from the terminal through `api_client`:

```shell
lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
lmdeploy serve api_client http://0.0.0.0:23333
```

After saving, you can instantiate your quantized model by calling the from_pretrained interface.
You can browse and try out the `api_server` APIs online via the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).
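
Since the RESTful APIs are compatible with OpenAI's interfaces, the server can also be queried programmatically. Below is a minimal sketch using the `openai` Python client; the served model name is fetched from `/v1/models` rather than assumed:

```python
from openai import OpenAI

# point the OpenAI client at the local api_server
client = OpenAI(base_url="http://0.0.0.0:23333/v1", api_key="none")

# query the first served model reported by the server
model_name = client.models.list().data[0].id
resp = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "Hi, pls intro yourself"}],
)
print(resp.choices[0].message.content)
```
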
68 changes: 38 additions & 30 deletions docs/zh_cn/quantization/w8a8.md
@@ -1,56 +1,64 @@
# W8A8 LLM Model Deployment

LMDeploy provides the ability to quantize and run inference on neural network models using 8-bit integers.
LMDeploy provides the ability to quantize and run inference on neural network models using 8-bit integers (INT8) and 8-bit floating point (FP8).

Before starting inference, make sure that lmdeploy and openai/triton are correctly installed. They can be installed with the following commands:
First, run the following command to install lmdeploy:

```shell
pip install lmdeploy
pip install triton>=2.1.0
pip install lmdeploy[all]
```

## 8-bit Weight Model Inference
## 8-bit Weight Quantization

If you want to perform 8-bit weight model inference, you can directly download the pre-quantized 8-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy). Taking the 8-bit Internlm-chat-7B model as an example, it can be downloaded directly from the model zoo:
Performing 8-bit weight quantization involves the following three steps:

```shell
git-lfs install
git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
```
1. **Smooth Weights**: First, smooth the weights of the language model so that they are easier to quantize.
2. **Replace Modules**: Replace the `RMSNorm` and `nn.Linear` modules in the model's `DecoderLayer` with `QRMSNorm` and `QLinear` modules. These quantized modules are defined in `lmdeploy/pytorch/models/q_modules.py`.
3. **Save the Quantized Model**: After making the replacements above, save the new quantized model.

Alternatively, you can follow the ["8-bit Weight Quantization"](#8bit-权重量化) section to manually quantize the original 16-bit weights to 8-bit and save them under the `internlm-chat-7b-w8` directory, using the command below:
lmdeploy provides the `lmdeploy lite smooth_quant` command-line tool, which implements the three steps above. The `--quant-dtype` argument controls whether 8-bit integer or 8-bit floating-point quantization is performed. For more information about the CLI usage, run `lmdeploy lite smooth_quant --help`.

```shell
lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
```
The following examples show the quantization commands for int8 and fp8:

Then, run the following command to chat with the model in the terminal:
- int8

```shell
lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
```
```shell
lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8-w8a8 --quant-dtype int8
```

## Launching the gradio Service
- fp8

Coming soon...
```shell
lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8-w8a8 --quant-dtype fp8
```

## Inference Speed
## Model Inference

Coming soon...
With a few simple lines of code, you can perform offline inference with the quantized model:

## 8-bit Weight Quantization
```python
from lmdeploy import pipeline, PytorchEngineConfig

Performing 8-bit weight quantization involves the following three steps:
engine_config = PytorchEngineConfig(tp=1)
pipe = pipeline("./internlm2_5-7b-chat-int8-w8a8", backend_config=engine_config)
response = pipe(["Hi, pls intro yourself", "Shanghai is"])
print(response)
```

1. **Smooth Weights**: First, smooth the weights of the language model so that they are easier to quantize.
2. **Replace Modules**: Replace the `RMSNorm` and `nn.Linear` modules in the model's `DecoderLayer` with `QRMSNorm` and `QLinear` modules. These quantized modules are defined in `lmdeploy/pytorch/models/q_modules.py`.
3. **Save the Quantized Model**: After making the replacements above, save the new quantized model.
For a detailed introduction to the pipeline, please refer to [this guide](../llm/pipeline.md).

## Inference Service

The script `lmdeploy/lite/apis/smooth_quant.py` implements the three steps above. For example, you can obtain the quantized model weights of the Internlm-chat-7B model with the following command:
LMDeploy's `api_server` can pack a model into a service with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service:

```shell
lmdeploy serve api_server ./internlm2_5-7b-chat-int8-w8a8 --backend pytorch
```

The default port of the service is 23333. After the server is launched, you can communicate with the server from the terminal through `api_client`:

lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
```shell
lmdeploy serve api_client http://0.0.0.0:23333
```

After saving, you can instantiate your quantized model by calling the `from_pretrained` interface.
You can also browse and try out the `api_server` APIs online via the Swagger UI at `http://0.0.0.0:23333`, or read the [documentation](../llm/api_server.md) to learn the definition and usage of each API.
1 change: 1 addition & 0 deletions lmdeploy/cli/lite.py
@@ -126,6 +126,7 @@ def add_parser_smooth_quant():
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
ArgumentHelper.dtype(parser)
ArgumentHelper.quant_dtype(parser)

@staticmethod
def auto_awq(args):
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -113,6 +113,16 @@ def dtype(parser, default: str = 'auto'):
'for BF16 models. This option will be ignored if '
'the model is a quantized model')

@staticmethod
def quant_dtype(parser, default: str = 'int8'):
return parser.add_argument(
'--quant-dtype',
type=str,
default=default,
choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'],
help='data type for the quantized model weights and activations. '
'Note "fp8" is the short version of "float8_e4m3fn"')

@staticmethod
def model_format(parser, default: str = None):
return parser.add_argument(
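
For illustration only, a standalone parser with the same choices behaves as follows (a hypothetical snippet mirroring the argument above, not the actual CLI wiring):

```python
import argparse

# mirror of the --quant-dtype argument, in isolation
parser = argparse.ArgumentParser()
parser.add_argument('--quant-dtype', type=str, default='int8',
                    choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'],
                    help='data type for the quantized model weights and activations. '
                         'Note "fp8" is the short version of "float8_e4m3fn"')

args = parser.parse_args(['--quant-dtype', 'fp8'])
print(args.quant_dtype)  # fp8; resolved to torch.float8_e4m3fn later on
```
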
22 changes: 18 additions & 4 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -24,7 +24,19 @@ def smooth_quant(model: str,
batch_size: int = 1,
w_bits: int = 8,
dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
device: str = 'cuda'):
device: str = 'cuda',
quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn',
'float8_e5m2'] = 'int8'):
if quant_dtype == 'fp8':
quant_dtype = 'float8_e4m3fn'

quant_dtype = getattr(torch, quant_dtype, torch.int8)
if quant_dtype.is_floating_point:
q_dtype_info = torch.finfo(quant_dtype)
else:
q_dtype_info = torch.iinfo(quant_dtype)

assert q_dtype_info.bits == w_bits
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
@@ -84,7 +96,7 @@ def smooth_quant(model: str,
if skipped_module(name):
continue
linear.to(device)
q_linear = QLinear.from_float(linear)
q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype)
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_linear)
@@ -94,7 +106,7 @@ def forward(self,
if skipped_module(name):
continue
norm.to(device)
q_norm = QRMSNorm.from_float(norm)
q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype)
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_norm)
@@ -104,8 +116,10 @@
from .auto_awq import save_vl_model
save_vl_model(vl_model, model_path, work_dir)
else:
quant_dtype_s = str(quant_dtype).split('.')[1]
model.config.update(
dict(quantization_config=dict(quant_method='smooth_quant')))
dict(quantization_config=dict(quant_method='smooth_quant',
quant_dtype=f'{quant_dtype_s}')))
model.save_pretrained(work_dir,
max_shard_size='2GB',
safe_serialization=False)
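
The dtype resolution above can be exercised on its own. A small sketch of the logic (standalone illustration; the helper name is made up, and a PyTorch build with float8 dtypes is assumed):

```python
import torch

def resolve_quant_dtype(name: str, w_bits: int = 8) -> torch.dtype:
    """Map a --quant-dtype string to a torch dtype, mirroring smooth_quant above."""
    if name == 'fp8':  # 'fp8' is shorthand for float8_e4m3fn
        name = 'float8_e4m3fn'
    dtype = getattr(torch, name, torch.int8)
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    assert info.bits == w_bits, f'{name} is {info.bits}-bit, expected {w_bits}-bit'
    return dtype

print(resolve_quant_dtype('fp8'))   # torch.float8_e4m3fn
print(resolve_quant_dtype('int8'))  # torch.int8
```
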
51 changes: 36 additions & 15 deletions lmdeploy/pytorch/backends/cuda/qmodules.py
@@ -15,42 +15,60 @@
class TritonRMSNormW8A8Impl(RMSNormW8A8Impl):
"""triton RMS norm w8a8 implementation api."""

def __init__(self, hidden_size: int, eps: float = 1e-6):
def __init__(self,
hidden_size: int,
eps: float = 1e-6,
quant_dtype: torch.dtype = torch.int8):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.quant_dtype = quant_dtype

def forward(self,
x: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor = None):
"""forward."""
if residual is not None:
x = x + residual
residual = x
hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
x, weight, self.eps)
x = QTensor(hidden_states_quant, rms_scale)
if residual is None:
(x,
rms_scale) = rms_norm_dynamic_quant(x,
weight,
self.eps,
quant_dtype=self.quant_dtype)
x = QTensor(x, rms_scale)
return x
return x, residual
else:
(x, rms_scale,
residual) = rms_norm_dynamic_quant(x,
weight,
self.eps,
residual=residual,
quant_dtype=self.quant_dtype)
x = QTensor(x, rms_scale)
return x, residual


class TritonRMSNormBuilder(RMSNormW8A8Builder):
"""triton RMS norm w8a8 implementation builder."""

@staticmethod
def build(hidden_size: int, eps: float = 1e-6):
def build(hidden_size: int,
eps: float = 1e-6,
quant_dtype: torch.dtype = torch.int8):
"""build."""
return TritonRMSNormW8A8Impl(hidden_size, eps)
return TritonRMSNormW8A8Impl(hidden_size, eps, quant_dtype)


class TritonLinearW8A8Impl(LinearW8A8Impl):
"""triton linear w8a8 implementation."""

def __init__(self, in_features: int, out_features: int):
def __init__(self,
in_features: int,
out_features: int,
quant_dtype: torch.dtype = torch.int8):
self.in_features = in_features
self.out_features = out_features
self.quant_dtype = quant_dtype

def forward(self,
x,
@@ -60,8 +78,8 @@ def forward(self,
all_reduce: bool = False):
"""forward."""
if isinstance(x, torch.Tensor):
x = x.contiguous()
input_quant, input_scale = per_token_quant_int8(x, 1e-7)
input_quant, input_scale = per_token_quant_int8(
x, 1e-7, quant_dtype=self.quant_dtype)
else:
assert isinstance(x, QTensor)
input_quant, input_scale = x.tensor, x.scale
@@ -85,6 +103,9 @@ class TritonLinearW8A8Builder(LinearW8A8Builder):
def build(in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None):
dtype: torch.dtype = None,
quant_dtype: torch.dtype = torch.int8):
"""build."""
return TritonLinearW8A8Impl(in_features, out_features)
return TritonLinearW8A8Impl(in_features,
out_features,
quant_dtype=quant_dtype)
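
For readers unfamiliar with w8a8 dynamic quantization, the sketch below approximates in plain PyTorch what the fused Triton kernels compute: one scale per token row, with values stored as int8 or fp8. It is a rough reference under the assumption of PyTorch 2.1+ for the float8 dtypes, not the actual kernel:

```python
import torch

def per_token_quant_ref(x: torch.Tensor, eps: float = 1e-7,
                        quant_dtype: torch.dtype = torch.int8):
    """Reference per-token dynamic quantization: one scale per row."""
    qmax = (torch.finfo(quant_dtype).max if quant_dtype.is_floating_point
            else torch.iinfo(quant_dtype).max)
    # each row's absolute maximum maps to the largest representable value
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / qmax
    q = (x / scale).clamp(-qmax, qmax)
    if not quant_dtype.is_floating_point:
        q = q.round()
    return q.to(quant_dtype), scale


x = torch.randn(4, 128, dtype=torch.float16)
q, scale = per_token_quant_ref(x, quant_dtype=torch.float8_e4m3fn)
print(q.dtype, scale.shape)  # torch.float8_e4m3fn torch.Size([4, 1])
# dequantized approximation: q.to(torch.float16) * scale
```
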
7 changes: 5 additions & 2 deletions lmdeploy/pytorch/backends/qmodules.py
@@ -37,7 +37,9 @@ class RMSNormW8A8Builder(ABC):

@staticmethod
@abstractmethod
def build(hidden_size: int, eps: float = 1e-6):
def build(hidden_size: int,
eps: float = 1e-6,
quant_dtype: torch.dtype = torch.int8):
"""build."""
raise NotImplementedError

@@ -71,6 +73,7 @@ class LinearW8A8Builder(ABC):
def build(in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None):
dtype: torch.dtype = None,
quant_dtype: torch.dtype = torch.int8):
"""build."""
raise NotImplementedError