Merge pull request #5 from HandH1998/tmp_yych
sm90 dispatch change
HandH1998 authored Jan 22, 2025
2 parents ca48f05 + 2812abb commit 0b6a79e
Showing 9 changed files with 372 additions and 63 deletions.
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/modelopt_quant.py
@@ -7,12 +7,13 @@
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
requantize_with_max_scale,
convert_to_channelwise,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -153,6 +154,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight, layer.weight_scale, layer.logical_widths
)
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
# cutlass sgl-kernel only supports per-channel scale
if self.cutlass_fp8_supported:
max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

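For context, a rough sketch of the per-tensor-to-per-channel scale expansion that this change relies on (hypothetical helper name; not vllm's convert_to_channelwise implementation):

# Hypothetical sketch: expand one scale per logical shard into one scale per
# output channel, which is the layout the cutlass sgl-kernel expects.
import torch
from typing import List

def expand_scales_to_channelwise(
    per_shard_scales: torch.Tensor,  # shape (num_shards,)
    logical_widths: List[int],       # output rows per shard
) -> torch.Tensor:
    # Repeat each shard's scale across that shard's output channels.
    return torch.cat(
        [s.repeat(w) for s, w in zip(per_shard_scales, logical_widths)]
    ).unsqueeze(-1)                  # shape (sum(logical_widths), 1)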
96 changes: 53 additions & 43 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -19,23 +19,25 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
apply_fp8_linear,
convert_to_channelwise,
cutlass_fp8_supported,
per_tensor_dequantize,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter, ChannelQuantScaleParameter

from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_utils import (
BlockQuantScaleParameter,
apply_w8a8_block_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
cutlass_fp8_supported,
input_to_float8,
apply_fp8_linear,
)
from sglang.srt.utils import (
get_bool_env_var,
@@ -45,6 +47,7 @@
)

ACTIVATION_SCHEMES = ["static", "dynamic"]
WEIGHT_SCALE_SUPPORTED = ["tensor", "channel", "block"]

logger = logging.getLogger(__name__)

@@ -58,6 +61,7 @@ def __init__(
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: List[int] = None,
weight_scale_type: str = "tensor",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
@@ -67,9 +71,13 @@
)
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
if weight_scale_type not in WEIGHT_SCALE_SUPPORTED:
raise ValueError(f"Unsupported weight scale type {weight_scale_type}")
self.activation_scheme = activation_scheme
self.weight_scale_type = weight_scale_type
self.ignored_layers = ignored_layers or []
if weight_block_size is not None:
self.weight_scale_type = "block"
if not is_checkpoint_fp8_serialized:
raise ValueError(
f"The block-wise quantization only supports fp8-serialized checkpoint for now."
@@ -107,11 +115,14 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
# Default weight scale type is per-tensor
weight_scale_type = cls.get_from_keys_or(config, ["weight_scale_type"], "tensor")
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
weight_scale_type=weight_scale_type,
)

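As a usage illustration, a hypothetical checkpoint quantization config exercising the new key could be parsed as below; all dict values are made up, and the quant_method handling sits above the shown hunk:

# Hypothetical example of building Fp8Config from a checkpoint's quantization
# config dict; values are illustrative only.
hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "ignored_layers": ["lm_head"],
    "weight_block_size": None,
    "weight_scale_type": "channel",  # new key; falls back to "tensor" when absent
}
fp8_config = Fp8Config.from_config(hf_quant_config)
assert fp8_config.weight_scale_type == "channel"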
def get_quant_method(
@@ -250,10 +261,17 @@ def create_weights(
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale_inv", scale)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
if self.quant_config.weight_scale_type == "tensor":
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
else:
scale = ChannelQuantScaleParameter(
data=torch.empty(output_size_per_partition, 1, dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)

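For a sense of the difference between the two branches, assuming a fused linear layer with two logical shards of widths 1024 and 512 (illustrative numbers only):

# Illustrative shapes: per-tensor keeps one scale per logical shard,
# per-channel keeps one scale per output channel.
import torch

output_partition_sizes = [1024, 512]
per_tensor_scale = torch.empty(len(output_partition_sizes), dtype=torch.float32)      # shape (2,)
per_channel_scale = torch.empty(sum(output_partition_sizes), 1, dtype=torch.float32)  # shape (1536, 1)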
@@ -263,7 +281,6 @@ def create_weights(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)

scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
@@ -289,15 +306,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
)
if self.cutlass_fp8_supported or self.use_marlin:
# Apply per-channel quantization by default, as the cutlass sgl-kernel and marlin only support per-channel scales
qweight, weight_scale = per_token_group_quant_fp8(layer.weight, layer.weight.shape[-1])
weight_scale = weight_scale.t().contiguous()
else:
# per-tensor quantization
qweight, weight_scale = input_to_float8(layer.weight)

# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
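A rough PyTorch reference for the per-channel branch above (grouping by the full input dimension gives one scale per output row, which is then transposed); this is a hypothetical sketch assuming e4m3fn, not the Triton kernel used in this PR:

# Hypothetical reference mirroring per_token_group_quant_fp8(weight, weight.shape[-1])
# followed by the scale transpose.
import torch

def quant_weight_per_channel_fp8(weight: torch.Tensor, eps: float = 1e-10):
    finfo = torch.finfo(torch.float8_e4m3fn)
    absmax = weight.abs().amax(dim=-1, keepdim=True).clamp(min=eps)  # (out, 1)
    scale = absmax / finfo.max                                       # (out, 1)
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale.t().contiguous()                           # scale -> (1, out)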
@@ -314,36 +329,31 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)

# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
weight = layer.weight
weight_scale = layer.weight_scale

# Dequant -> Quant with max scale so we can run per tensor.
if self.quant_config.weight_scale_type == "tensor":
if self.cutlass_fp8_supported or self.use_marlin:
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
else:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale,
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)

# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale,
logical_widths=layer.logical_widths,
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)

weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)

# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
111 changes: 109 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -66,7 +66,8 @@ def _per_token_group_quant_fp8(
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@@ -87,7 +88,7 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
dtype: The dtype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -134,6 +135,112 @@ def per_token_group_quant_fp8(

return x_q, x_s

@triton.jit
def _static_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
y_s_repeat_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
REPEAT_SCALE: tl.constexpr,
):
"""A Triton-accelerated function to perform quantization using the given scale on a
tensor.
This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
if REPEAT_SCALE:
y_s_repeat_ptr += g_id

cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N

y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(y_s_ptr).to(tl.float32)
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
if REPEAT_SCALE:
tl.store(y_s_repeat_ptr, y_s)


def static_quant_fp8(
x: torch.Tensor,
x_s: torch.Tensor,
repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
x_s: The quantization scale.
repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
dtype: The dtype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max

if is_hip_:
fp8_max = 224.0

fp8_min = -fp8_max

x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
x_s_repeat = torch.empty(
(M, 1),
device=x.device,
dtype=torch.float32,
)
else:
x_s_repeat = None

BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_static_quant_fp8[(M,)](
x,
x_q,
x_s,
x_s_repeat,
N,
N,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
REPEAT_SCALE=repeat_scale,
num_warps=num_warps,
num_stages=num_stages,
)
x_s = x_s_repeat if repeat_scale else x_s
return x_q, x_s

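A hypothetical call pattern for the new helper, quantizing activations with a precomputed per-tensor scale and optionally broadcasting it to one scale per row; shapes and the scale value are illustrative:

# Hypothetical usage of static_quant_fp8.
import torch

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
x_scale = torch.tensor([0.01], device="cuda", dtype=torch.float32)  # per-tensor scale

x_q, x_s = static_quant_fp8(x, x_scale)                          # x_s stays per-tensor
x_q, x_s_rows = static_quant_fp8(x, x_scale, repeat_scale=True)  # x_s_rows: shape (16, 1)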

@triton.jit
def _w8a8_block_fp8_matmul(
(Diffs for the remaining 6 changed files are not shown.)
