
[Quantization] Channel-wise Output Activation Quantization for Attention QKV Modules + KV-cache channel quantization #1233

Draft
wants to merge 9 commits into base: main

Conversation

horheynm (Collaborator) commented Mar 7, 2025

Blocked on: neuralmagic/compressed-tensors#270

SUMMARY:
Quantize the output activations of the attention layers channel-wise -> previously unsupported; the wrong dim was selected to quantize.
Quantize the kv-cache channel-wise in int8 -> previously only tensor-wise was supported.

The attention projections we need to worry about are Q/K/V. The o/up/down projections are not quantized.

Math:
x is the input tensor -> tokenized + embedded hidden states
the QKV weights are Linear modules
the output is the forward call of each QKV Linear on x

# x
(Pdb) hidden_states.shape -> torch.Size([1, 1930, 4096]) -> [batch, seq_len, hidden_size]

# weight
(Pdb) self.q_proj.weight.shape -> torch.Size([4096, 4096]) -> [hidden_size, hidden_size]
(Pdb) self.k_proj.weight.shape -> torch.Size([1024, 4096]) -> [num_key_value_heads * head_dim, hidden_size]
(Pdb) self.v_proj.weight.shape -> torch.Size([1024, 4096]) -> [num_key_value_heads * head_dim, hidden_size]

# output
(Pdb) self.q_proj(hidden_states).shape -> torch.Size([1, 1930, 4096]) -> [batch, seq_len, hidden_size]
(Pdb) self.k_proj(hidden_states).shape -> torch.Size([1, 1930, 1024]) -> [batch, seq_len, num_key_value_heads * head_dim]
(Pdb) self.v_proj(hidden_states).shape -> torch.Size([1, 1930, 1024]) -> [batch, seq_len, num_key_value_heads * head_dim]

# key_states, value_states shape
[batch, num_key_value_heads, seq_len, head_dim]
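
A minimal sketch (plain PyTorch, not the actual attention implementation) of where head_dim comes from, assuming num_key_value_heads=8 and head_dim=128 so that 8 * 128 = 1024 matches the k_proj/v_proj shapes above:

```python
import torch

batch, seq_len = 1, 1930
num_key_value_heads, head_dim = 8, 128  # assumed split of the 1024-dim k/v output

# stand-in for k_proj(hidden_states): [batch, seq_len, num_key_value_heads * head_dim]
k_out = torch.randn(batch, seq_len, num_key_value_heads * head_dim)

# reshape into key_states: [batch, num_key_value_heads, seq_len, head_dim]
key_states = k_out.view(batch, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
print(key_states.shape)  # torch.Size([1, 8, 1930, 128])
```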

Expected scale and zero-point (zp) shapes for the output activations:

q_proj activations -> [4096] -> [hidden_size]
k_proj activations -> [1024] -> [num_key_value_heads * head_dim]
v_proj activations -> [1024] -> [num_key_value_heads * head_dim]
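
For the output activations, a rough sketch of the channel-wise shape math (not the observer code in this PR; symmetric int8 assumed, so the zero-point is all zeros):

```python
import torch

out = torch.randn(1, 1930, 1024)  # e.g. k_proj(hidden_states)

# channel-wise: reduce over batch and seq_len, keep the last (channel) dim
max_val = out.amax(dim=(0, 1))  # torch.Size([1024])
min_val = out.amin(dim=(0, 1))  # torch.Size([1024])

# symmetric int8 scale / zero-point per channel
scale = torch.maximum(max_val.abs(), min_val.abs()) / 127.0
zero_point = torch.zeros_like(scale, dtype=torch.int8)
print(scale.shape, zero_point.shape)  # torch.Size([1024]) torch.Size([1024])
```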

Expected scale and zero-point (zp) shapes for channel-wise kv-cache quantization:

k_proj, v_proj -> [head_dim]
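
For the kv-cache channel case, the reduction instead keeps only head_dim (again a sketch, not the PR's observer):

```python
import torch

key_states = torch.randn(1, 8, 1930, 128)  # [batch, num_kv_heads, seq_len, head_dim]

# reduce over batch, heads, and seq_len; keep head_dim
k_absmax = key_states.abs().amax(dim=(0, 1, 2))
k_scale = k_absmax / 127.0
print(k_scale.shape)  # torch.Size([128]) == [head_dim]
```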

The observer outputs scale/zero-point vectors with the same ndim as the given output activation tensor (i.e. for an output of torch.Size([1, 1930, 1024]) it produces torch.Size([1, 1, 1024])). Squeeze them down to torch.Size([1024]), i.e. ndim of 1.
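
For example (illustrative only):

```python
import torch

observed = torch.randn(1, 1930, 1024)                          # k_proj output activations
scale = observed.abs().amax(dim=(0, 1), keepdim=True) / 127.0  # torch.Size([1, 1, 1024])
scale = scale.squeeze()                                         # torch.Size([1024]), ndim == 1
```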

TEST PLAN:

  • Pass tests
  • Pass eval

github-actions bot commented Mar 7, 2025

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

horheynm (Collaborator, Author) commented Mar 7, 2025

Next TODO: add support for group quantization for output activations.

horheynm added 2 commits March 6, 2025 23:36
Signed-off-by: George Ohashi <[email protected]>
@horheynm horheynm added the ready When a PR is ready for review label Mar 7, 2025
@dsikka dsikka marked this pull request as draft March 7, 2025 14:27
Signed-off-by: George Ohashi <[email protected]>
@horheynm horheynm changed the title [Quantization] Channel-wise Output Activation Quantization for Attention QKV Modules [Quantization] Channel-wise Output Activation Quantization for Attention QKV Modules + KV-cache channel quantization Mar 7, 2025
horheynm (Collaborator, Author) commented Mar 7, 2025

Will break the kv-cache logic out into a separate PR.
