Getting this error with int4 quantization. May be a noob question: is this a bug, or does int4 require the weights to be in bfloat16?
Traceback (most recent call last):
  File "/home/agunapal/torch_ao/vit_ao.py", line 16, in <module>
    quantize_(model, int4_weight_only())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 463, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 203, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 2 more times]
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 199, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 393, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=requires_grad)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 553, in apply_int4_weight_only_quant
    return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 286, in from_hp_to_intx
    layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/dtypes/affine_quantized_tensor.py", line 1033, in from_plain
    scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 319, in pack_tinygemm_scales_and_zeros
    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
  File "/home/agunapal/anaconda3/envs/torchao/lib/python3.10/site-packages/torchao/quantization/utils.py", line 128, in guard_dtype_size
    raise ValueError(f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.")
ValueError: Expected Tensor argument scales to have dtype torch.bfloat16, but got torch.float32 instead.
Code for repro:
import torch
import torchao
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.utils import benchmark_model
from torchao.quantization import int4_weight_only, quantize_

torch.set_float32_matmul_precision('high')

dtype = torch.float32
device = "cuda"
N = 1
result = []  # collect (method, batch_size, elapsed_time) tuples

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()

quantize_(model, int4_weight_only())  # fails here with the ValueError above
model = torch.compile(model, mode='max-autotune').to(device).to(dtype)

method = "int4 quantize followed by compile"
input = (torch.randn(N, 3, 224, 224).to(device).to(dtype),)

with torch.no_grad():
    # warmup
    benchmark_model(model, 20, input)
    # benchmark
    result.append((method, N, benchmark_model(model, 100, input)))

for (method, N, elapsed_time) in result:
    print(f"batch_size={N} : elapsed time {elapsed_time:.3f} ms : {method}")
Yeah, int4_weight_only quant requires bfloat16 right now, I think, since that's the only dtype supported by the tinygemm kernel (int4_weight_only corresponds specifically to the int4 tinygemm kernel; it's not a general int4 weight-only quant).
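
For anyone hitting this, here is a minimal sketch of the workaround, assuming the only change needed is casting the model to bfloat16 before calling quantize_ (that is the dtype the guard_dtype_size check in the traceback enforces on the scales):

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.quantization import int4_weight_only, quantize_

device = "cuda"

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()

# Cast weights to bfloat16 before quantizing: the tinygemm int4 kernel
# behind int4_weight_only only supports bfloat16, so float32 weights
# trip the scales dtype check seen in the traceback.
model = model.to(device=device, dtype=torch.bfloat16)
quantize_(model, int4_weight_only())

# Inputs must match the model dtype as well.
x = torch.randn(1, 3, 224, 224, device=device, dtype=torch.bfloat16)
with torch.no_grad():
    out = model(x)

With that cast in place, torch.compile can be applied to the quantized model as in the repro script.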