[ Misc ] fbgemm checkpoints #6559

Merged Jul 20, 2024 (39 commits)
Changes from 33 commits

Commits:
2426e29  stash (robertgshaw2-redhat, Jul 18, 2024)
7665b7b  format (robertgshaw2-redhat, Jul 18, 2024)
ef27613  tweak arg name (robertgshaw2-redhat, Jul 18, 2024)
2f96157  fix test (robertgshaw2-redhat, Jul 18, 2024)
e748554  format (robertgshaw2-redhat, Jul 18, 2024)
3ef571b  working e2e with our cutlass kernels (robertgshaw2-redhat, Jul 19, 2024)
ad83666  added fp8 gemm (robertgshaw2-redhat, Jul 19, 2024)
eb7d48c  remove (robertgshaw2-redhat, Jul 19, 2024)
90bd839  format (robertgshaw2-redhat, Jul 19, 2024)
15cc823  Merge branch 'main' into turn-on-fp8-dyn-per-token (robertgshaw2-redhat, Jul 19, 2024)
d064dd7  stash (robertgshaw2-redhat, Jul 19, 2024)
6aa37e5  dynamic per token (robertgshaw2-redhat, Jul 19, 2024)
c9d819a  format (robertgshaw2-redhat, Jul 19, 2024)
08cbaf7  reenable cutlass (robertgshaw2-redhat, Jul 19, 2024)
f4cdda1  cleanup comment (robertgshaw2-redhat, Jul 19, 2024)
2971f4d  format (robertgshaw2-redhat, Jul 19, 2024)
b601033  added dynamic per token test case (robertgshaw2-redhat, Jul 19, 2024)
5d8edf9  Merge branch 'turn-on-fp8-dyn-per-token' into fbgemm-checkpoints (robertgshaw2-redhat, Jul 19, 2024)
8b5d638  added use per token (robertgshaw2-redhat, Jul 19, 2024)
006ccf0  format (Jul 19, 2024)
1884acf  format (Jul 19, 2024)
fe14072  Make optional ubs none (Jul 19, 2024)
254dcff  format (Jul 19, 2024)
919d866  Merge branch 'fp8-dpt-fpgemm' into fbgemm-checkpoints (robertgshaw2-redhat, Jul 19, 2024)
227a277  hook up end to end with varun's ub quant kernel (robertgshaw2-redhat, Jul 19, 2024)
951834a  formatted (robertgshaw2-redhat, Jul 19, 2024)
9aa66d3  updated for nonuniform (robertgshaw2-redhat, Jul 19, 2024)
458a410  formatting after passing prefix around (robertgshaw2-redhat, Jul 19, 2024)
278f6d6  Merge branch 'main' into fbgemm-checkpoints (robertgshaw2-redhat, Jul 19, 2024)
3e4aaad  fixed bad merge (robertgshaw2-redhat, Jul 19, 2024)
de2a764  updated message (robertgshaw2-redhat, Jul 20, 2024)
268fe94  Merge branch 'main' into fbgemm-checkpoints (robertgshaw2-redhat, Jul 20, 2024)
c88fe34  merged varun's pr (robertgshaw2-redhat, Jul 20, 2024)
bb02a3f  fixed (robertgshaw2-redhat, Jul 20, 2024)
1c8f71c  cleanup pr (robertgshaw2-redhat, Jul 20, 2024)
6970e50  Update config.py (robertgshaw2-redhat, Jul 20, 2024)
94617f0  fixed config (robertgshaw2-redhat, Jul 20, 2024)
f9d569c  updated for new ckpt format, turned on ada lovelace, and added test case (robertgshaw2-redhat, Jul 20, 2024)
ae45615  format (robertgshaw2-redhat, Jul 20, 2024)
3 changes: 3 additions & 0 deletions tests/kernels/quant_utils.py
@@ -31,6 +31,9 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
if scale_ub is not None:
x_token_max = x_token_max.clamp(max=scale_ub)
scales = (x_token_max / qtype_max)[:, None]
if quant_dtype == torch.float8_e4m3fn:
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)

# Quant
if quant_dtype == torch.int8:
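For reference, a self-contained sketch of the per-token scale computation this test helper now performs, assuming s_1 and s_512 above are scalar tensors holding 1.0 and 512.0 (the same 512 factor the FBGEMM kernel uses); not part of the diff:

    from typing import Optional

    import torch

    def per_token_fp8_scales(x: torch.Tensor,
                             scale_ub: Optional[float] = None) -> torch.Tensor:
        qtype_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
        x_token_max = x.abs().max(dim=-1).values.to(torch.float32)
        if scale_ub is not None:
            # Upper-bound each row max before computing its scale.
            x_token_max = x_token_max.clamp(max=scale_ub)
        scales = (x_token_max / qtype_max)[:, None]
        # Floor the scale so all-zero or tiny rows do not yield a zero scale.
        min_scaling_factor = 1.0 / (qtype_max * 512.0)
        return scales.clamp(min=min_scaling_factor)

    x = torch.randn(4, 8) * 1000.0
    print(per_token_fp8_scales(x, scale_ub=1200.0))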
4 changes: 3 additions & 1 deletion vllm/_custom_ops.py
@@ -299,8 +299,8 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
batch_dim_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
batch_dim_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -315,6 +315,8 @@ def scaled_fp8_quant(
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
batch_dim_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
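As a usage illustration (not part of the diff), a minimal sketch of calling the reordered signature for dynamic per-token quantization; the shapes, the CUDA device, and the upper-bound value are assumptions:

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    ub = torch.tensor([1200.0], dtype=torch.float32, device="cuda")

    # scale=None requests dynamic quantization; with use_per_token_if_dynamic=True
    # a scale is computed per token (row), clamped from above by scale_ub.
    x_q, x_scales = ops.scaled_fp8_quant(x,
                                         scale=None,
                                         scale_ub=ub,
                                         use_per_token_if_dynamic=True)
    print(x_q.dtype, x_scales.shape)  # expected: torch.float8_e4m3fn, one scale per row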
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -34,6 +34,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
prefix: str = "",
) -> None:
super().__init__()
if cache_config is not None:
Expand All @@ -56,7 +57,7 @@ def __init__(
self._k_scale = 1.0
self._v_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
self, prefix=prefix) if quant_config else None
if quant_method is not None:
assert isinstance(quant_method, Fp8KVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -248,7 +248,7 @@ def _verify_quantization(self) -> None:
f"supported in ROCm.")
if (self.quantization
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"compressed_tensors")):
"fpgemm_fp8", "compressed_tensors")):
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -182,7 +182,7 @@ def __init__(
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod())
else:
self.quant_method = quant_config.get_quant_method(self)
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None

self.quant_method.create_weights(
26 changes: 16 additions & 10 deletions vllm/model_executor/layers/linear.py
@@ -141,6 +141,7 @@ def __init__(
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()

@@ -155,7 +156,8 @@ def __init__(
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)

def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@@ -182,9 +184,13 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
prefix: str = ""):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)

# All the linear layer supports quant method.
assert self.quant_method is not None
@@ -258,9 +264,9 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: Optional[str] = None):
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
quant_config, prefix)

self.gather_output = gather_output

@@ -370,7 +376,7 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
prefix: str = ""):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -514,7 +520,7 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
prefix: str = ""):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -707,9 +713,9 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
quant_config, prefix)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -10,6 +10,7 @@
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
@@ -24,6 +25,7 @@
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
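With the method registered, an FBGEMM FP8 checkpoint can be requested by name. A hedged sketch: the model id below is a placeholder, and in practice the method is usually picked up from the checkpoint's quantization_config rather than passed explicitly:

    from vllm import LLM

    # "example-org/llama-3-8b-fbgemm-fp8" is a placeholder checkpoint id.
    llm = LLM(model="example-org/llama-3-8b-fbgemm-fp8",
              quantization="fbgemm_fp8")
    out = llm.generate("The capital of France is")
    print(out[0].outputs[0].text)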
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/aqlm.py
@@ -207,8 +207,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/awq.py
@@ -63,8 +63,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
return None
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -97,12 +97,13 @@ def get_from_keys_or(config: Dict[str, Any], keys: List[str],
return default

@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.

Args:
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
The quantize method. None if the given layer doesn't support quant
method.
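A toy, standard-library-only sketch of the dispatch shape this API change enables (layer type plus state-dict prefix); the class names are illustrative, not vLLM's:

    from typing import Optional

    class ToyLinear: ...
    class ToyLinearMethod: ...
    class ToyUnquantizedMethod: ...

    class ToyQuantConfig:
        def __init__(self, ignore_list):
            self.ignore_list = ignore_list

        def get_quant_method(self, layer, prefix: str) -> Optional[object]:
            # prefix is the layer's full name in the state dict,
            # e.g. "model.layers.0.mlp.up_proj".
            if isinstance(layer, ToyLinear):
                if prefix in self.ignore_list:
                    return ToyUnquantizedMethod()  # ignored layers stay unquantized
                return ToyLinearMethod()
            return None  # non-linear layers are not handled here

    cfg = ToyQuantConfig(ignore_list=["lm_head"])
    print(type(cfg.get_quant_method(ToyLinear(), "model.layers.0.mlp.up_proj")).__name__)
    print(type(cfg.get_quant_method(ToyLinear(), "lm_head")).__name__)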
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -60,9 +60,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)

def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
if isinstance(layer, LinearBase):
return BitsAndBytesLinearMethod(self)
return None
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -44,8 +44,12 @@ def get_min_capability(cls) -> int:
def get_name(self) -> str:
return "compressed_tensors"

# TODO (@robertgshaw2-neuralmagic): do layer skipping through here
# rather than through create_weights to match other methods
def get_quant_method(
self, layer: torch.nn.Module
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["CompressedTensorsLinearMethod"]:
if isinstance(layer, LinearBase):
return CompressedTensorsLinearMethod(self)
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -69,9 +69,8 @@ def get_config_filenames() -> List[str]:
"quantize_config.json",
]

def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
if isinstance(layer, LinearBase):
return DeepSpeedFPLinearMethod(self)
return None
160 changes: 160 additions & 0 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -0,0 +1,160 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}


class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""

def __init__(self, ignore_list: List[str]):
self.ignore_list = ignore_list

@classmethod
def get_name(cls) -> str:
return "fbgemm_fp8"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.float16]

@classmethod
def get_min_capability(cls) -> int:
return 90

@classmethod
def get_config_filenames(cls) -> List[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
return cls(ignore_list=ignore_list)

def _is_layer_skipped(self, prefix: str) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in _FUSED_LAYER_NAME_MAPPING:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
]

is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = shard_prefix in self.ignore_list

if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_skipped = prefix in self.ignore_list

assert is_skipped is not None
return is_skipped

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self._is_layer_skipped(prefix):
return UnquantizedLinearMethod()
return FBGEMMFp8LinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []


class FBGEMMFp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)

layer.logical_widths = output_partition_sizes

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype

# WEIGHT
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
**extra_weight_attrs,
})

# WEIGHT SCALE
weight_scale = create_per_channel_scale_param(output_partition_sizes,
**extra_weight_attrs)
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE UPPER BOUND
input_scale_ub = torch.nn.Parameter(torch.zeros((1),
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale_ub", input_scale_ub)
set_weight_attrs(input_scale_ub, {
"ignore_warning": True,
**extra_weight_attrs
})

def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return apply_fp8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
cutlass_fp8_supported=True,
use_per_token_if_dynamic=True)

Review thread on the apply() implementation:

vlasenkoalexey (Jul 19, 2024):
I don't think that apply_fp8_linear would work here. Using fbgemm it would look like:

        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
            x, num_tokens, self.activation_scale_ub
        )
        y = torch.ops.fbgemm.f8f8bf16_rowwise(
            xq, layer.weight, x_scale, layer.weight_scale, use_fast_accum=True
        )

In particular, input_scale=None is most likely wrong; here is a reference implementation for quantize_fp8_per_row:

def fp8_row_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Quantize an input tensor and return the fp8 tensor and its inverse scale.
    x_row_max = torch.max(torch.abs(x), dim=1).values
    max_scaling_factor = E4M3_MAX_POS * 512.0  # Match kernel logics
    scale = torch.Tensor(E4M3_MAX_POS / x_row_max).clamp(max=max_scaling_factor)
    xq = (x * scale.unsqueeze(1)).to(fp8_e4m3)
    return xq, scale.reciprocal().to(torch.float32)

robertgshaw2-redhat (Collaborator, Author):
Yup, we are updating the quant kernel right now to use the activation_scale_ub, in combo with https://github.com/vllm-project/vllm/pull/6547/files, which enables per token scales.

Reply:
Besides, both torch._scaled_mm and ops.scaled_fp8_quant expect scale to be a scalar.

robertgshaw2-redhat (Collaborator, Author, Jul 19, 2024):
#6547 extends ops.scaled_fp8_quant to accept per token scales (a vector of scales). torch._scaled_mm will not be used; cutlass_scaled_mm accepts per channel weight scales and per token activation scales. I pulled #6547 into this PR so you can see, and #6593 adds the ub.

This PR now:
  • has the scale_ub
  • uses dynamic per token activation scales
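To make the ignore-list handling above concrete, here is a self-contained sketch (standard library only) of the fused-layer skip check; the example prefixes and ignore list are illustrative, not taken from a real checkpoint:

    from typing import Dict, List

    FUSED_LAYER_NAME_MAPPING: Dict[str, List[str]] = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def is_layer_skipped(prefix: str, ignore_list: List[str]) -> bool:
        proj_name = prefix.split(".")[-1]
        if proj_name not in FUSED_LAYER_NAME_MAPPING:
            return prefix in ignore_list
        # Fused layers are resolved through their unfused shard names, since the
        # checkpoint's modules_to_not_convert lists q_proj/k_proj/v_proj etc.
        shard_prefixes = [
            prefix.replace(proj_name, shard)
            for shard in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]
        skipped = [p in ignore_list for p in shard_prefixes]
        if any(skipped) and not all(skipped):
            raise ValueError(f"Some but not all shards of {prefix} are ignored; "
                             "all shards of a fused layer must have the same precision.")
        return all(skipped)

    ignore = [
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ]
    print(is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignore))  # True
    print(is_layer_skipped("model.layers.1.mlp.down_proj", ignore))       # False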