Supporting quantized weights from Quark by default. #47

Merged · merged 9 commits on Jun 13, 2024

Changes from all commits
18 changes: 16 additions & 2 deletions ROCm_performance.md
@@ -21,9 +21,23 @@ The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head

## Fp8 Quantization

To use fp8 quantization, the first step is to quantize your model to fp8 format. Please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed under your model folder.
To use fp8 quantization, the first step is to quantize your model to fp8 format.

Then we can run a model with fp8 quantization using vLLM. When creating the `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors within your model path}`.
By default, rocm-vllm accepts quantized weights generated by the Quark quantizer. To produce them, install Quark and run the following command:

```
python3 quantize_quark.py --model_dir [llama2 checkpoint folder] \
--output_dir output_dir \
--quant_scheme w_fp8_a_fp8_o_fp8 \
--num_calib_data 128 \
--export_safetensors \
--no_weight_matrix_merge
```
For more details, please refer to Quark's documentation.

To use AMMO instead, please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) and set `VLLM_FP8_USE_AMMO=1`.
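For example, a minimal sketch (the variable only needs to be `1` in the environment of the process that builds the engine):

```python
import os

# Switch the fp8 weight loader from the default Quark layout to
# AMMO-format checkpoints (checked via os.getenv in llama.py below).
os.environ["VLLM_FP8_USE_AMMO"] = "1"
```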

Both quantizers generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. Place the safetensors file under your model folder. You can then run the model with fp8 quantization in vLLM: when creating the `vllm.LLM` object, add two additional parameters, `quantization="fp8"` and `quantization_param_path={relative path of the safetensors within your model folder}`.
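For example, a minimal sketch of engine creation (the model folder, the safetensors file name, and the prompt are placeholders, not files shipped with this repo):

```python
from vllm import LLM, SamplingParams

# Hypothetical layout: the quantized safetensors file sits inside the model folder.
model_path = "/models/Llama-2-7b-chat-hf"
scales_file = "llama.safetensors"  # relative to model_path

llm = LLM(
    model=model_path,
    quantization="fp8",
    quantization_param_path=scales_file,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```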

## Gemm Tuning for Fp8

2 changes: 1 addition & 1 deletion csrc/quantization/fp8/amd/gemm_kernel.cu
@@ -1,7 +1,7 @@
#include <cstdint>
#include <cstdio>

#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContextLight.h>
151 changes: 106 additions & 45 deletions vllm/model_executor/models/llama.py
@@ -21,6 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
@@ -441,57 +442,117 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

def load_quantized_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]):
params_dict = dict(self.named_parameters())
#with open("/projects/a.txt", "r") as f:
# j = json.load(f)
# for k, v in j.items():
# params_dict[k].data.copy_(v)
quant_shards = [
("mlp.gate_up_proj", "mlp.fc", 0), # fc is gate_proj
("mlp.gate_up_proj", "mlp.gate", 1), # gate is up_proj
]
quant_map = [
("mlp.down_proj", "mlp.proj"),
("self_attn.o_proj", "attention.dense"),
("self_attn.qkv_proj", "attention.qkv"),
]
for name, loaded_weight in weights:
#print(name)
name = name.replace('transformer', 'model')
name = name.replace('kv_cache_scaling_factor',
'qkv.output_scaling_factor')
loaded_weight = loaded_weight.to("cuda")
if loaded_weight.dtype == torch.int8:
loaded_weight[loaded_weight == -128] = 0
assert loaded_weight.is_contiguous
loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
for (param_name, weight_name, shard_id) in quant_shards:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:

def load_ammo():
params_dict = dict(self.named_parameters())
quant_shards = [
("mlp.gate_up_proj", "mlp.fc", 0), # fc is gate_proj
("mlp.gate_up_proj", "mlp.gate", 1), # gate is up_proj
]
quant_map = [
("mlp.down_proj", "mlp.proj"),
("self_attn.o_proj", "attention.dense"),
("self_attn.qkv_proj", "attention.qkv"),
]
for name, loaded_weight in weights:
name = name.replace('transformer', 'model')
name = name.replace('kv_cache_scaling_factor',
'qkv.output_scaling_factor')
loaded_weight = loaded_weight.to("cuda")
if loaded_weight.dtype == torch.int8:
loaded_weight[loaded_weight == -128] = 0
assert loaded_weight.is_contiguous()
loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
for (param_name, weight_name, shard_id) in quant_shards:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
for (param_name, weight_name) in quant_map:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
if ("activation_scaling_factor" in name
or "weights_scaling_factor" in name
or "output_scaling_factor" in name):
param.data.copy_(loaded_weight)
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
break

def load_quark():
params_dict = dict(self.named_parameters())
quant_shards = [
("mlp.gate_up_proj", "mlp.gate_proj", 0), # fc is gate_proj
("mlp.gate_up_proj", "mlp.up_proj", 1), # gate is up_proj
]
quant_map = [
("mlp.down_proj", "mlp.down_proj"),
("self_attn.o_proj", "self_attn.o_proj"),
("self_attn.qkv_proj", "self_attn.qkv"),
]
scaling_factor_map = [
("activation_scaling_factor", "input_quant_scale"),
("weights_scaling_factor", "weight_quant_scale"),
("output_scaling_factor", "output_quant_scale"),
]
for name, loaded_weight in weights:
if "zero_point" in name:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
for (param_name, weight_name) in quant_map:
if len(loaded_weight.shape) == 0:
loaded_weight = torch.Tensor([loaded_weight])
# replace the name for scaling factor
for (scale_name, weight_name) in scaling_factor_map:
if weight_name not in name:
continue
name = name.replace(weight_name, scale_name)
if loaded_weight.dtype == torch.int8:
loaded_weight[loaded_weight == -128] = 0
assert loaded_weight.is_contiguous()
loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)

for (param_name, weight_name, shard_id) in quant_shards:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if ("activation_scaling_factor" in name
or "weights_scaling_factor" in name
or "output_scaling_factor" in name):
param.data.copy_(loaded_weight)
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
for (param_name, weight_name) in quant_map:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
if ("activation_scaling_factor" in name
or "weights_scaling_factor" in name
or "output_scaling_factor" in name):
param.data.copy_(loaded_weight)
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
break

load_func = load_ammo if os.getenv(
"VLLM_FP8_USE_AMMO") == "1" else load_quark
load_func()
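# Note: the checkpoint layout is chosen per call via VLLM_FP8_USE_AMMO;
# "1" routes to load_ammo above, otherwise the Quark path (load_quark),
# the default, is used.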

# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should