torchtune quantization: model output differs from the documentation #1701

Open
elfisworking opened this issue Sep 27, 2024 · 1 comment

@elfisworking
elfisworking commented Sep 27, 2024

I'm using torchtune to quantize a model with QAT (quantization-aware training). I'm following the tutorial at https://pytorch.org/torchtune/main/tutorials/qat_finetune.html, but when I print `prepared_model`, the output differs from what the tutorial shows. Is this expected?

```python
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
from torchtune.models.llama3 import llama3_8b

model = llama3_8b()

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# fine-tuning without performing any dtype casting
prepared_model = quantizer.prepare(model)

# Inspect the attention block of the first transformer layer
print(prepared_model.layers[0].attn)
```

The tutorial shows this:

```
>>> print(prepared_model.layers[0].attn)
MultiHeadAttention(
  (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
  (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)
```

But I get this:

```
MultiHeadAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)
```

torch 2.4.1
torchtune 0.3.0
torchao 0.5.0
device: NVIDIA RTX 3090
python: 3.10
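Since `print` only shows each module's class, one quick way to see what `prepare` actually changed is to list every `nn.Linear` together with the type of its weight. This is a minimal sketch using only standard PyTorch; `describe_linears` is a hypothetical helper, not a torchtune or torchao API, and whether the weight type changes after `prepare` depends on how the installed torchao version implements QAT.

```python
import torch.nn as nn


def describe_linears(model: nn.Module) -> dict[str, tuple[str, str]]:
    """Map each nn.Linear submodule's name to (module class name, weight class name).

    print(model) only shows module classes; the weight class also reveals
    changes made by replacing a layer's weight with a tensor subclass.
    """
    report = {}
    for name, module in model.named_modules():
        # isinstance also matches subclasses of nn.Linear (e.g. swapped-in modules)
        if isinstance(module, nn.Linear):
            report[name] = (type(module).__name__, type(module.weight).__name__)
    return report


# Demo on a small toy model (not the llama3_8b model from the issue)
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
for name, (module_cls, weight_cls) in describe_linears(toy).items():
    print(f"{name}: module={module_cls}, weight={weight_cls}")
# 0: module=Linear, weight=Parameter
# 2: module=Linear, weight=Parameter
```

For the model in the issue, `describe_linears(prepared_model.layers[0].attn)` would list `q_proj`, `k_proj`, `v_proj`, and `output_proj`. Module-swapped layers are also listed as long as they subclass `nn.Linear`.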

@ebsmothers
Contributor

Hi @elfisworking, thanks for creating the issue! I think our QAT tutorial may be slightly out of date. It was written when QAT was done with module swapping (which is why it shows Linear -> Int8DynActInt4WeightQATLinear), but QAT now uses tensor subclasses. If I understand correctly, still seeing Linear just means you're on the latest version. cc @andrewor14 to confirm. If so, we can update our QAT tutorial to reflect this.
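To illustrate why the printed class can stay `Linear`, here is a toy sketch of both designs in plain PyTorch, assuming a simple symmetric per-group int4 fake quantization. `fake_quant_int4`, `FakeQuantLinear`, and `FakeQuantWeight` are hypothetical names; torchao's real quantizer and tensor subclass are more involved, so this only shows the difference in where the change becomes visible.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fake_quant_int4(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Toy symmetric per-group int4 fake quantization (quantize, then dequantize).

    The result keeps w's dtype; only the values are snapped to the int4 grid.
    The straight-through trick lets gradients reach the original weight.
    """
    rows, cols = w.shape
    groups = w.reshape(rows, cols // group_size, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0
    q = torch.clamp(torch.round(groups / scale), -8, 7)
    dq = (q * scale).reshape(rows, cols)
    return w + (dq - w).detach()


# Approach 1: module swapping. The module's class changes, so print() shows it.
class FakeQuantLinear(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, fake_quant_int4(self.weight), self.bias)


# Approach 2: tensor subclass. The module stays nn.Linear; only the weight's type changes.
class FakeQuantWeight(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w, *rest = args
            plain_w = w.as_subclass(torch.Tensor)  # unwrap to avoid re-entering this hook
            return F.linear(x, fake_quant_int4(plain_w), *rest, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)


torch.manual_seed(0)
base = nn.Linear(64, 16)

swapped = FakeQuantLinear(64, 16)
swapped.load_state_dict(base.state_dict())

subclassed = nn.Linear(64, 16)
subclassed.load_state_dict(base.state_dict())
subclassed.weight = nn.Parameter(subclassed.weight.detach().as_subclass(FakeQuantWeight))

print(swapped)     # FakeQuantLinear(in_features=64, out_features=16, bias=True)
print(subclassed)  # Linear(in_features=64, out_features=16, bias=True)
print(type(subclassed.weight).__name__)  # FakeQuantWeight

# Both approaches compute the same fake-quantized forward pass
x = torch.randn(4, 64)
assert torch.allclose(swapped(x), subclassed(x))
```

With module swapping, the new class appears in `print(model)`. With a tensor subclass, the forward numerics are the same, but the change is only visible on the weight itself, for example via `type(module.weight)`.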
