7
7
from vllm .model_executor .layers .quantization .compressed_tensors .schemes import (
8
8
CompressedTensorsScheme )
9
9
from vllm .model_executor .layers .quantization .utils .marlin_utils import (
10
- apply_gptq_marlin_linear , marlin_is_k_full , marlin_make_empty_g_idx ,
11
- marlin_make_workspace , marlin_permute_scales , marlin_sort_g_idx ,
12
- replace_tensor , verify_marlin_supported , verify_marlin_supports_shape )
10
+ apply_gptq_marlin_linear , marlin_make_empty_g_idx , marlin_make_workspace ,
11
+ marlin_permute_scales , marlin_sort_g_idx , replace_tensor ,
12
+ verify_marlin_supported , verify_marlin_supports_shape )
13
13
from vllm .model_executor .parameter import (BasevLLMParameter ,
14
14
ChannelQuantScaleParameter ,
15
15
GroupQuantScaleParameter ,
@@ -31,8 +31,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
31
31
def __init__ (self ,
32
32
strategy : str ,
33
33
num_bits : int ,
34
- group_size : Optional [int ] = None ,
35
- actorder : bool = False ):
34
+ group_size : Optional [int ] = None ):
36
35
37
36
self .pack_factor = 32 // num_bits
38
37
self .strategy = strategy
@@ -50,15 +49,6 @@ def __init__(self,
50
49
51
50
self .quant_type = WNA16_SUPPORTED_TYPES_MAP [num_bits ]
52
51
53
- if actorder and self .group_size == - 1 :
54
- # In this case, actorder == True is the same as actorder == False
55
- # (since we have only one group per output channel)
56
- logger .warning (
57
- "Model must be quantized with group_size > 0 in order to use "
58
- "activation ordering" )
59
- actorder = False
60
- self .actorder = actorder
61
-
62
52
# Verify supported on platform.
63
53
verify_marlin_supported (quant_type = self .quant_type ,
64
54
group_size = self .group_size )
@@ -75,7 +65,6 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
75
65
** kwargs ):
76
66
77
67
output_size_per_partition = sum (output_partition_sizes )
78
- is_row_parallel = input_size != input_size_per_partition
79
68
80
69
# If group_size is -1, we are in channelwise case.
81
70
channelwise = (self .group_size == - 1 )
@@ -133,21 +122,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
133
122
dtype = torch .int64 ),
134
123
weight_loader = weight_loader )
135
124
136
- # G_IDX (for activation reordering)
137
- g_idx = BasevLLMParameter (data = torch .empty (input_size_per_partition ,
138
- dtype = torch .int32 ),
125
+ # group index (for activation reordering)
126
+ weight_g_idx = BasevLLMParameter (data = torch .full ((input_size_per_partition , ),
127
+ - 1 ,
128
+ dtype = torch .int32 ),
139
129
weight_loader = weight_loader )
140
130
141
131
layer .register_parameter ("weight_packed" , weight )
142
132
layer .register_parameter ("weight_scale" , weight_scale )
143
133
layer .register_parameter ("weight_shape" , weight_shape )
144
- layer .register_parameter ("weight_g_idx" , g_idx )
134
+ layer .register_parameter ("weight_g_idx" , weight_g_idx )
145
135
146
136
layer .input_size_per_partition = input_size_per_partition
147
137
layer .output_size_per_partition = output_size_per_partition
148
138
layer .input_size = input_size
149
139
layer .group_size = group_size
150
- layer .is_k_full = marlin_is_k_full (self .actorder , is_row_parallel )
151
140
152
141
# Checkpoints are serialized in compressed-tensors format, which is
153
142
# different from marlin format. Handle repacking here.
@@ -159,7 +148,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
159
148
layer .output_size_per_partition , device )
160
149
161
150
# Handle sorting for activation reordering if needed.
162
- if self .actorder :
151
+ has_g_idx = - 1 not in layer .weight_g_idx
152
+ if has_g_idx :
163
153
g_idx , g_idx_sort_indices = marlin_sort_g_idx (layer .weight_g_idx )
164
154
layer .g_idx_sort_indices = g_idx_sort_indices
165
155
replace_tensor (layer , "weight_g_idx" , g_idx )
@@ -188,7 +178,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
188
178
marlin_scales = marlin_permute_scales (
189
179
layer .weight_scale ,
190
180
size_k = (layer .input_size
191
- if self . actorder else layer .input_size_per_partition ),
181
+ if has_g_idx else layer .input_size_per_partition ),
192
182
size_n = layer .output_size_per_partition ,
193
183
group_size = layer .group_size )
194
184
replace_tensor (layer , "weight_scale" , marlin_scales )
0 commit comments