diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index 5b2be2c65..4d9f48e62 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -90,6 +90,7 @@ def update(
             self.k_observers.append(k_observer)
             self.v_observers.append(v_observer)
 
+        # batch x heads x seq_len x head_dim
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
@@ -150,7 +151,15 @@ def _quantize(self, tensor, kv_type, layer_idx):
             scales = self.v_scales
             zps = self.v_zps
 
-        scale, zp = observer(tensor)
+        # note: key and value states have the shape
+        # [batch, num_key_value_heads, seq_len, head_dim]
+
+        base_name = None  # tensor-wise quantization, scale shape of [1]
+        if self.quantization_args.strategy == "channel":
+            # quantize along the last dim, scale shape of [head_dim]
+            base_name = "kv_cache"
+
+        scale, zp = observer(tensor, base_name=base_name)
         if len(scales) <= layer_idx:
             scales.append(scale)
             zps.append(zp)
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index bcb4b7433..37cca0ac8 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -81,7 +81,9 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
         raise ValueError("Must provide a value to observe if not using weight observer")
 
     observer = getattr(module, f"{base_name}_observer")
-    updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
+    updated_scale, updated_zero_point = observer(
+        value, g_idx=g_idx, base_name=base_name
+    )
 
     # update scale and zero point
     update_parameter_data(module, updated_scale, f"{base_name}_scale")
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index e70125908..9bc030b52 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -31,7 +31,10 @@ def __init__(self, quantization_args: QuantizationArgs):
 
     @torch.no_grad()
     def forward(
-        self, observed: Tensor, g_idx: Optional[Tensor] = None
+        self,
+        observed: Tensor,
+        g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -40,8 +43,9 @@
         :param g_idx: optional mapping from column index to group index
+        :param base_name: optional name of the observed value, e.g. "kv_cache"
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed, g_idx=g_idx)
+        return self.get_qparams(observed=observed, g_idx=g_idx, base_name=base_name)
 
     def calculate_qparams(
         self,
@@ -66,6 +70,7 @@ def get_qparams(
         self,
         observed: Optional[Tensor] = None,
         g_idx: Optional[Tensor] = None,
+        base_name: Optional[str] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -123,8 +128,25 @@
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                if base_name in ("output", "kv_cache"):
+                    # the last dimension is the hidden dimension;
+                    # scale shape of [1, 1, num_key_value_heads * head_dim]
+                    scale, zero_point = self.get_qparams_along_dim(
+                        observed, observed.ndim - 1
+                    )
+                    self._scale = (
+                        scale.squeeze()
+                    )  # shape of [num_key_value_heads * head_dim]
+                    self._zero_point = (
+                        zero_point.squeeze()
+                    )  # shape of [num_key_value_heads * head_dim]
+                else:
+                    # weight or input
+                    # assume observed is transposed
+                    # (channel dim comes first), hence use dim 0
+                    self._scale, self._zero_point = self.get_qparams_along_dim(
+                        observed, 0
+                    )
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume observed.shape = [batch, token, hidden]
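To make the new shape handling concrete, here is a minimal standalone sketch. `channel_qparams` is a hypothetical helper (it is not part of llm-compressor, and it uses a bare min/max scheme instead of the observer's `calculate_qparams`), but it selects the reduction dim exactly as the patched `get_qparams` does: the last dim when `base_name` is `"output"` or `"kv_cache"`, dim 0 otherwise.

```python
from typing import Optional

import torch


def channel_qparams(observed: torch.Tensor, base_name: Optional[str] = None):
    """Toy int8 min/max qparams, choosing the channel dim the way the
    patched Observer.get_qparams does (hypothetical helper)."""
    qmin, qmax = -128, 127
    if base_name in ("output", "kv_cache"):
        # activations / kv states keep channels in the last dim
        dim = observed.ndim - 1
    else:
        # weights are [out_channels, in_channels]: channels come first
        dim = 0
    reduce_dims = tuple(d for d in range(observed.ndim) if d != dim)
    min_val = observed.amin(dim=reduce_dims)
    max_val = observed.amax(dim=reduce_dims)
    scale = ((max_val - min_val) / (qmax - qmin)).clamp(min=1e-8)
    zero_point = (qmin - min_val / scale).round().clamp(qmin, qmax).to(torch.int8)
    return scale, zero_point


# kv_cache states: [batch, num_key_value_heads, seq_len, head_dim]
kv = torch.randn(1, 8, 128, 64)
scale, _ = channel_qparams(kv, base_name="kv_cache")
print(scale.shape)  # torch.Size([64]) -> one scale per head_dim channel

# weight: [out_features, in_features]
w = torch.randn(512, 1024)
scale, _ = channel_qparams(w)  # base_name=None -> dim 0
print(scale.shape)  # torch.Size([512])
```

Note that the 4-D kv_cache case yields a `[head_dim]` scale vector, matching the `shape of [head_dim]` comment in `_quantize`, while a 3-D output tensor would yield `[num_key_value_heads * head_dim]` as noted in `get_qparams`.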