sm90 dispatch change #5

Merged (10 commits, Jan 22, 2025)
Changes from 8 commits
python/sglang/srt/layers/modelopt_quant.py (6 changes: 5 additions & 1 deletion)
@@ -7,12 +7,13 @@
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
requantize_with_max_scale,
convert_to_channelwise,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -153,6 +154,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight, layer.weight_scale, layer.logical_widths
)
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
# cutlass sgl-kernel only supports per-channel scale
if self.cutlass_fp8_supported:
max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

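For context on the helper used above: convert_to_channelwise (from vllm's w8a8 utilities) turns the per-shard max scale into a per-output-channel scale, which is the layout the cutlass sgl-kernel expects. A minimal standalone sketch of that expansion, under the assumption that it simply repeats each logical shard's scale across that shard's output channels (the function name and sizes here are illustrative, not the vllm source):

import torch

def expand_to_channelwise(shard_scales: torch.Tensor, logical_widths: list) -> torch.Tensor:
    # shard_scales: one float32 scale per fused logical shard, shape (num_shards,)
    # returns: one scale per output channel, shape (sum(logical_widths), 1)
    repeats = torch.tensor(logical_widths, device=shard_scales.device)
    return torch.repeat_interleave(shard_scales, repeats).unsqueeze(-1)

# e.g. shard scales [0.1, 0.2] with logical_widths [2, 3] expand to
# [[0.1], [0.1], [0.2], [0.2], [0.2]]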
python/sglang/srt/layers/quantization/fp8.py (96 changes: 53 additions & 43 deletions)
@@ -19,23 +19,25 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
apply_fp8_linear,
convert_to_channelwise,
cutlass_fp8_supported,
per_tensor_dequantize,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter, ChannelQuantScaleParameter

from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_utils import (
BlockQuantScaleParameter,
apply_w8a8_block_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
cutlass_fp8_supported,
input_to_float8,
apply_fp8_linear,
)
from sglang.srt.utils import (
get_bool_env_var,
@@ -45,6 +47,7 @@
)

ACTIVATION_SCHEMES = ["static", "dynamic"]
WEIGHT_SCALE_SUPPORTED = ["tensor", "channel", "block"]

logger = logging.getLogger(__name__)

@@ -58,6 +61,7 @@ def __init__(
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: List[int] = None,
weight_scale_type: str = "tensor",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
@@ -67,9 +71,13 @@
)
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
if weight_scale_type not in WEIGHT_SCALE_SUPPORTED:
raise ValueError(f"Unsupported weight scale type {weight_scale_type}")
self.activation_scheme = activation_scheme
self.weight_scale_type = weight_scale_type
self.ignored_layers = ignored_layers or []
if weight_block_size is not None:
self.weight_scale_type = "block"
if not is_checkpoint_fp8_serialized:
raise ValueError(
f"The block-wise quantization only supports fp8-serialized checkpoint for now."
@@ -107,11 +115,14 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
# Default weight scale type is per-tensor
weight_scale_type = cls.get_from_keys_or(config, ["weight_scale_type"], "tensor")
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
weight_scale_type=weight_scale_type,
)
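For illustration, a minimal sketch of constructing Fp8Config with the new option; the argument values are made up and only keyword arguments visible in the diff are used:

quant_config = Fp8Config(
    is_checkpoint_fp8_serialized=True,
    activation_scheme="dynamic",
    weight_scale_type="channel",  # must be one of WEIGHT_SCALE_SUPPORTED: "tensor", "channel", "block"
)
# A serialized checkpoint can instead carry "weight_scale_type" in its quantization
# config; from_config() reads it with a "tensor" default, and a non-None
# weight_block_size forces weight_scale_type to "block".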

def get_quant_method(
@@ -250,10 +261,17 @@ def create_weights(
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale_inv", scale)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
if self.quant_config.weight_scale_type == "tensor":
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
else:
scale = ChannelQuantScaleParameter(
data=torch.empty(output_size_per_partition, 1, dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
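To make the two branches above concrete, a sketch of the scale shapes they register, assuming a hypothetical fused projection with three logical shards (all sizes are made up):

import torch

output_partition_sizes = [4096, 4096, 1024]              # hypothetical q/k/v shard widths
output_size_per_partition = sum(output_partition_sizes)  # 9216

# "tensor": one scale per logical shard
per_tensor_scale = torch.empty(len(output_partition_sizes), dtype=torch.float32)    # shape (3,)

# "channel": one scale per output channel, stored as a column vector
per_channel_scale = torch.empty(output_size_per_partition, 1, dtype=torch.float32)  # shape (9216, 1)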

@@ -263,7 +281,6 @@ def create_weights(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)

scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
@@ -289,15 +306,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
)
if self.cutlass_fp8_supported or self.use_marlin:
# Apply per-channel quantization by default, as the cutlass sgl-kernel and marlin only support per-channel scales
qweight, weight_scale = per_token_group_quant_fp8(layer.weight, layer.weight.shape[-1])
weight_scale = weight_scale.t().contiguous()
else:
# per-tensor quantization
qweight, weight_scale = input_to_float8(layer.weight)

# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
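As context for the branch above: when the checkpoint is not fp8-serialized, the weights are quantized on the fly, per output channel when the cutlass sgl-kernel or marlin will consume them and per tensor otherwise. A standalone torch sketch of the two schemes (illustrative only, not the per_token_group_quant_fp8 / input_to_float8 code paths used here):

import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def quant_per_tensor(w: torch.Tensor):
    # one scalar scale for the whole weight
    scale = w.abs().max().clamp(min=1e-12) / FP8_MAX
    wq = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return wq, scale

def quant_per_channel(w: torch.Tensor):
    # one scale per output channel (row of the [out_features, in_features] weight)
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    wq = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return wq, scale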
@@ -314,36 +329,31 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)

# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
weight = layer.weight
weight_scale = layer.weight_scale

# Dequant -> Quant with max scale so we can run per tensor.
if self.quant_config.weight_scale_type == "tensor":
if self.cutlass_fp8_supported or self.use_marlin:
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
else:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale,
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)

weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)

# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
python/sglang/srt/layers/quantization/fp8_kernel.py (111 changes: 109 additions & 2 deletions)
@@ -66,7 +66,8 @@ def _per_token_group_quant_fp8(
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@@ -87,7 +88,7 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
dtype: The dtype of output tensor.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -134,6 +135,112 @@

return x_q, x_s
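A hedged usage sketch of per_token_group_quant_fp8; the Triton kernel needs a CUDA device, and shapes and values below are illustrative:

import torch

x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# x_q: same shape as x, in the fp8 dtype
# x_s: one scale per 128-column group of each row (here, 2 groups per row)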

@triton.jit
def _static_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
y_s_repeat_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
REPEAT_SCALE: tl.constexpr,
):
"""A Triton-accelerated function to perform quantization using the given scale on a
tensor.

This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
if REPEAT_SCALE:
y_s_repeat_ptr += g_id

cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N

y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(y_s_ptr).to(tl.float32)
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
if REPEAT_SCALE:
tl.store(y_s_repeat_ptr, y_s)


def static_quant_fp8(
x: torch.Tensor,
x_s: torch.Tensor,
repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`.

It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.

Args:
x: The input tensor with ndim >= 2.
x_s: The quantization scale.
repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
dtype: The dtype of output tensor.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max

if is_hip_:
fp8_max = 224.0

fp8_min = -fp8_max

x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
x_s_repeat = torch.empty(
(M, 1),
device=x.device,
dtype=torch.float32,
)
else:
x_s_repeat = None

BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_static_quant_fp8[(M,)](
x,
x_q,
x_s,
x_s_repeat,
N,
N,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
REPEAT_SCALE=repeat_scale,
num_warps=num_warps,
num_stages=num_stages,
)
x_s = x_s_repeat if repeat_scale else x_s
return x_q, x_s
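A hedged usage sketch of the new static_quant_fp8 helper (CUDA device required for the Triton kernel; values are illustrative):

import torch

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
x_s = torch.tensor([0.05], device="cuda", dtype=torch.float32)  # precomputed per-tensor scale
x_q, x_s_out = static_quant_fp8(x, x_s, repeat_scale=True)
# x_q: x quantized to fp8 with the fixed scale
# x_s_out: with repeat_scale=True, the scale broadcast into a (num_rows, 1) tensor;
#          with repeat_scale=False, the original per-tensor scale is returned unchanged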


@triton.jit
def _w8a8_block_fp8_matmul(