Commit ca13cec
use g_idx rather than actorder flag to decide whether to reorder
kylesayrs committed Aug 30, 2024 (1 parent ef08596)
Showing 4 changed files with 34 additions and 30 deletions.
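In a nutshell: rather than threading an explicit actorder flag from the config into the scheme, the scheme now inspects the loaded weight_g_idx tensor itself. The tensor is pre-filled with the sentinel -1; a checkpoint quantized with activation ordering overwrites it with a real group-index permutation, so "does any -1 remain?" decides whether to reorder. A minimal standalone sketch of that detection idea (illustrative only; just the -1 sentinel and the membership test come from this diff):

import torch

def has_activation_ordering(weight_g_idx: torch.Tensor) -> bool:
    # `-1 not in t` uses torch.Tensor.__contains__, i.e. not (t == -1).any()
    return -1 not in weight_g_idx

untouched = torch.full((8, ), -1, dtype=torch.int32)  # loader never wrote it
loaded = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int32)
assert not has_activation_ordering(untouched)
assert has_activation_ordering(loaded)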
@@ -232,8 +232,7 @@ def _get_scheme_from_parts(
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
-                group_size=weight_quant.group_size,
-                actorder=weight_quant.actorder)
+                group_size=weight_quant.group_size)
 
         # Detect If Activation Quantization.
         # TODO @dsikka: clean-up conditions
@@ -7,9 +7,9 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, marlin_is_k_full, marlin_make_empty_g_idx,
-    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
+    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
+    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
@@ -31,8 +31,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None,
-                 actorder: bool = False):
+                 group_size: Optional[int] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
@@ -50,15 +49,6 @@ def __init__(self,
 
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
 
-        if actorder and self.group_size == -1:
-            # In this case, actorder == True is the same as actorder == False
-            # (since we have only one group per output channel)
-            logger.warning(
-                "Model must be quantized with group_size > 0 in order to use "
-                "activation ordering")
-            actorder = False
-        self.actorder = actorder
-
         # Verify supported on platform.
         verify_marlin_supported(quant_type=self.quant_type,
                                 group_size=self.group_size)
@@ -75,7 +65,6 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):
 
         output_size_per_partition = sum(output_partition_sizes)
-        is_row_parallel = input_size != input_size_per_partition
 
         # If group_size is -1, we are in channelwise case.
         channelwise = (self.group_size == -1)
@@ -133,21 +122,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                                dtype=torch.int64),
                                       weight_loader=weight_loader)
 
-        # G_IDX (for activation reordering)
-        g_idx = BasevLLMParameter(data=torch.empty(input_size_per_partition,
-                                                   dtype=torch.int32),
-                                  weight_loader=weight_loader)
+        # group index (for activation reordering)
+        weight_g_idx = BasevLLMParameter(data=torch.full((input_size_per_partition, ),
+                                                         -1,
+                                                         dtype=torch.int32),
+                                         weight_loader=weight_loader)
 
         layer.register_parameter("weight_packed", weight)
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
-        layer.register_parameter("weight_g_idx", g_idx)
+        layer.register_parameter("weight_g_idx", weight_g_idx)
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
         layer.group_size = group_size
-        layer.is_k_full = marlin_is_k_full(self.actorder, is_row_parallel)
 
         # Checkpoints are serialized in compressed-tensors format, which is
         # different from marlin format. Handle repacking here.
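Why torch.full((..., ), -1) replaces torch.empty here: an empty tensor holds whatever bytes happen to be in memory, so a layer whose checkpoint carries no g_idx could not be told apart from one whose checkpoint does. A short illustration (hypothetical code, not part of the diff):

import torch

uninitialized = torch.empty(4, dtype=torch.int32)    # arbitrary values; -1 may or may not appear
sentinel = torch.full((4, ), -1, dtype=torch.int32)  # deterministic all-(-1)
assert bool((sentinel == -1).all())                  # reliable "never loaded" marker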
@@ -159,7 +148,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                            layer.output_size_per_partition, device)
 
         # Handle sorting for activation reordering if needed.
-        if self.actorder:
+        has_g_idx = -1 not in layer.weight_g_idx
+        if has_g_idx:
             g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
             layer.g_idx_sort_indices = g_idx_sort_indices
             replace_tensor(layer, "weight_g_idx", g_idx)
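marlin_sort_g_idx comes from marlin_utils (imported above); judging from its use here, it returns the group indices in sorted order together with the permutation that sorts them. A rough sketch of what such a helper looks like (an assumption about its internals, not code from this diff):

import torch
from typing import Tuple

def sort_g_idx_sketch(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # argsort gives the sorting permutation; indexing with it gives the sorted copy
    sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[sort_indices], sort_indices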
@@ -188,7 +178,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
             size_k=(layer.input_size
-                    if self.actorder else layer.input_size_per_partition),
+                    if has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
@@ -1,6 +1,6 @@
 import re
 from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, Optional, Union
 
 from pydantic import BaseModel, Field
 from torch.nn import Module
@@ -40,6 +40,20 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrderingStrategy(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+    Weight := only reorder weight, not groups (default)
+    Group := reorder groups and weight
+    Off := do not reorder by activations
+    """
+
+    WEIGHT = "weight"
+    GROUP = "group"
+    OFF = "off"
+
+
 class QuantizationArgs(BaseModel):
     """
     User facing arguments used to define a quantization config
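Since ActivationOrderingStrategy subclasses both str and Enum, members compare equal to their string values and can be built directly from config strings. A quick self-contained usage sketch (the class is re-declared here just to make the snippet runnable):

from enum import Enum

class ActivationOrderingStrategy(str, Enum):  # mirror of the class above
    WEIGHT = "weight"
    GROUP = "group"
    OFF = "off"

assert ActivationOrderingStrategy("group") is ActivationOrderingStrategy.GROUP
assert ActivationOrderingStrategy.WEIGHT == "weight"  # str mixin: compares as a plain string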
@@ -69,7 +83,8 @@ class QuantizationArgs(BaseModel):
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
-    actorder: bool = False
+    actorder: Union[ActivationOrderingStrategy,
+                    bool] = ActivationOrderingStrategy.OFF
     observer: str = Field(
         default="minmax",
         description=("The class to use to compute the quantization param - "
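The Union[ActivationOrderingStrategy, bool] annotation keeps legacy configs with actorder: true/false parseable while accepting the new string values. A hedged sketch with a stripped-down model (not the real QuantizationArgs; exact coercion order can vary between pydantic versions):

from enum import Enum
from typing import Union
from pydantic import BaseModel

class ActivationOrderingStrategy(str, Enum):
    WEIGHT = "weight"
    GROUP = "group"
    OFF = "off"

class Args(BaseModel):
    actorder: Union[ActivationOrderingStrategy, bool] = ActivationOrderingStrategy.OFF

assert Args().actorder is ActivationOrderingStrategy.OFF                    # default
assert Args(actorder="group").actorder is ActivationOrderingStrategy.GROUP  # new style
assert Args(actorder=True).actorder is True                                 # legacy flag still accepted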
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -129,16 +129,16 @@ def marlin_make_workspace(output_size_per_partition: int,
                        requires_grad=False)
 
 
-def marlin_is_k_full(actorder: bool, is_row_parallel: bool) -> bool:
-    return (not actorder) or (actorder and not is_row_parallel)
+def marlin_is_k_full(has_g_idx: bool, is_row_parallel: bool) -> bool:
+    return (not has_g_idx) or (not is_row_parallel)
 
 
-def marlin_repeat_scales_on_all_ranks(actorder: bool, group_size: int,
+def marlin_repeat_scales_on_all_ranks(has_g_idx: bool, group_size: int,
                                       is_row_parallel: bool) -> bool:
     # Need to repeat scales on every rank if actorder or
     # channelwise and RowParallelLinear
     is_channelwise = group_size == -1
-    return actorder or (is_channelwise and is_row_parallel)
+    return has_g_idx or (is_channelwise and is_row_parallel)
 
 
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
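The new return in marlin_is_k_full is a pure boolean simplification: (not a) or (a and not b) is propositionally equivalent to (not a) or (not b). An exhaustive check to make that concrete:

from itertools import product

for has_g_idx, is_row_parallel in product((False, True), repeat=2):
    old = (not has_g_idx) or (has_g_idx and not is_row_parallel)
    new = (not has_g_idx) or (not is_row_parallel)
    assert old == new, (has_g_idx, is_row_parallel)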
