
Commit ca13cec

use g_idx rather than actorder flag to decide whether to reorder

1 parent ef08596

4 files changed: +34 −30 lines
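The idea behind the change: instead of threading an explicit `actorder` flag from the quantization config into the scheme, the `weight_g_idx` parameter is created pre-filled with a `-1` sentinel, and reordering happens only if a loaded checkpoint overwrote that sentinel with a real group-index permutation. A minimal sketch of the detection in plain `torch` (variable names mirror the diff; the surrounding scaffolding is illustrative, not the vLLM API):

import torch

input_size_per_partition = 8

# Create the group-index parameter pre-filled with the -1 sentinel, as the
# new create_weights does. A checkpoint quantized with activation ordering
# overwrites it; any other checkpoint leaves it untouched.
weight_g_idx = torch.full((input_size_per_partition, ), -1, dtype=torch.int32)

# Nothing loaded: every entry is still -1, so no reordering is applied.
# (`x in tensor` lowers to `(tensor == x).any()` in PyTorch.)
print(-1 not in weight_g_idx)  # False -> skip g_idx sorting

# Simulate a checkpoint that carries a real g_idx (group of each channel).
weight_g_idx = torch.tensor([0, 2, 1, 3, 0, 1, 2, 3], dtype=torch.int32)
print(-1 not in weight_g_idx)  # True -> sort and reorder

The decision is thus driven by what the checkpoint actually contains rather than by config plumbing, which is why the `actorder` flag could be dropped from the scheme's constructor below.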

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Lines changed: 1 addition & 2 deletions

@@ -232,8 +232,7 @@ def _get_scheme_from_parts(
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
-                group_size=weight_quant.group_size,
-                actorder=weight_quant.actorder)
+                group_size=weight_quant.group_size)
 
         # Detect If Activation Quantization.
         # TODO @dsikka: clean-up conditions

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
Lines changed: 12 additions & 22 deletions

@@ -7,9 +7,9 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, marlin_is_k_full, marlin_make_empty_g_idx,
-    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
+    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
+    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
@@ -31,8 +31,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None,
-                 actorder: bool = False):
+                 group_size: Optional[int] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
@@ -50,15 +49,6 @@ def __init__(self,
 
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
 
-        if actorder and self.group_size == -1:
-            # In this case, actorder == True is the same as actorder == False
-            # (since we have only one group per output channel)
-            logger.warning(
-                "Model must be quantized with group_size > 0 in order to use "
-                "activation ordering")
-            actorder = False
-        self.actorder = actorder
-
         # Verify supported on platform.
         verify_marlin_supported(quant_type=self.quant_type,
                                 group_size=self.group_size)
@@ -75,7 +65,6 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):
 
         output_size_per_partition = sum(output_partition_sizes)
-        is_row_parallel = input_size != input_size_per_partition
 
         # If group_size is -1, we are in channelwise case.
         channelwise = (self.group_size == -1)
@@ -133,21 +122,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                                 dtype=torch.int64),
                                         weight_loader=weight_loader)
 
-        # G_IDX (for activation reordering)
-        g_idx = BasevLLMParameter(data=torch.empty(input_size_per_partition,
-                                                   dtype=torch.int32),
+        # group index (for activation reordering)
+        weight_g_idx = BasevLLMParameter(data=torch.full((input_size_per_partition, ),
+                                                         -1,
+                                                         dtype=torch.int32),
                                          weight_loader=weight_loader)
 
         layer.register_parameter("weight_packed", weight)
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
-        layer.register_parameter("weight_g_idx", g_idx)
+        layer.register_parameter("weight_g_idx", weight_g_idx)
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
         layer.group_size = group_size
-        layer.is_k_full = marlin_is_k_full(self.actorder, is_row_parallel)
 
         # Checkpoints are serialized in compressed-tensors format, which is
         # different from marlin format. Handle repacking here.
@@ -159,7 +148,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                                    layer.output_size_per_partition, device)
 
         # Handle sorting for activation reordering if needed.
-        if self.actorder:
+        has_g_idx = -1 not in layer.weight_g_idx
+        if has_g_idx:
             g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
             layer.g_idx_sort_indices = g_idx_sort_indices
             replace_tensor(layer, "weight_g_idx", g_idx)
@@ -188,7 +178,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
             size_k=(layer.input_size
-                    if self.actorder else layer.input_size_per_partition),
+                    if has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
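For context on the sorting step above: `marlin_sort_g_idx` yields the g_idx sorted so that channels sharing a group index become contiguous, plus the permutation that must be applied to the weight rows the same way. A stand-in sketch under that assumption (this body is illustrative, not the actual `marlin_utils` implementation):

import torch

def sort_g_idx_sketch(g_idx: torch.Tensor):
    # Permutation that orders channels by their group index, so a kernel
    # can walk the input dimension one contiguous group at a time.
    sort_indices = torch.argsort(g_idx)
    return g_idx[sort_indices], sort_indices

g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int32)
sorted_g_idx, perm = sort_g_idx_sketch(g_idx)
print(sorted_g_idx)  # tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)
print(perm)          # e.g. tensor([1, 3, 2, 5, 0, 4])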

vllm/model_executor/layers/quantization/compressed_tensors/utils.py
Lines changed: 17 additions & 2 deletions

@@ -1,6 +1,6 @@
 import re
 from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, Optional, Union
 
 from pydantic import BaseModel, Field
 from torch.nn import Module
@@ -40,6 +40,20 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrderingStrategy(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Weight := only reorder weight, not groups (default)
+    Grouped := reorder groups and weight
+    Off := do not reorder by activations
+    """
+
+    WEIGHT = "weight"
+    GROUP = "group"
+    OFF = "off"
+
+
 class QuantizationArgs(BaseModel):
     """
     User facing arguments used to define a quantization config
@@ -69,7 +83,8 @@ class QuantizationArgs(BaseModel):
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
-    actorder: bool = False
+    actorder: Union[ActivationOrderingStrategy,
+                    bool] = ActivationOrderingStrategy.OFF
     observer: str = Field(
         default="minmax",
         description=("The class to use to compute the quantization param - "
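Since `QuantizationArgs` is a pydantic model, widening `actorder` to a `Union` keeps older configs that serialized a boolean parseable while newer ones can name a strategy. A trimmed, self-contained sketch of that behavior (the model below is a stand-in with only the relevant field, not the full vLLM class):

from enum import Enum
from typing import Union

from pydantic import BaseModel

class ActivationOrderingStrategy(str, Enum):
    WEIGHT = "weight"
    GROUP = "group"
    OFF = "off"

class Args(BaseModel):  # trimmed stand-in for QuantizationArgs
    actorder: Union[ActivationOrderingStrategy,
                    bool] = ActivationOrderingStrategy.OFF

print(Args().actorder)                  # ActivationOrderingStrategy.OFF
print(Args(actorder="group").actorder)  # ActivationOrderingStrategy.GROUP
print(Args(actorder=False).actorder)    # False (legacy boolean still parses)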

vllm/model_executor/layers/quantization/utils/marlin_utils.py
Lines changed: 4 additions & 4 deletions

@@ -129,16 +129,16 @@ def marlin_make_workspace(output_size_per_partition: int,
                        requires_grad=False)
 
 
-def marlin_is_k_full(actorder: bool, is_row_parallel: bool) -> bool:
-    return (not actorder) or (actorder and not is_row_parallel)
+def marlin_is_k_full(has_g_idx: bool, is_row_parallel: bool) -> bool:
+    return (not has_g_idx) or (not is_row_parallel)
 
 
-def marlin_repeat_scales_on_all_ranks(actorder: bool, group_size: int,
+def marlin_repeat_scales_on_all_ranks(has_g_idx: bool, group_size: int,
                                       is_row_parallel: bool) -> bool:
     # Need to repeat scales on every rank if actorder or
     # channelwise and RowParallelLinear
     is_channelwise = group_size == -1
-    return actorder or (is_channelwise and is_row_parallel)
+    return has_g_idx or (is_channelwise and is_row_parallel)
 
 
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
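The rewritten `marlin_is_k_full` leans on a small boolean identity: `(not a) or (a and not b)` simplifies to `(not a) or (not b)`, i.e. `not (a and b)`. A quick exhaustive check:

from itertools import product

for has_g_idx, is_row_parallel in product((False, True), repeat=2):
    old = (not has_g_idx) or (has_g_idx and not is_row_parallel)
    new = (not has_g_idx) or (not is_row_parallel)
    assert old == new == (not (has_g_idx and is_row_parallel))
print("equivalent for all four input combinations")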
