[core] TorchAO Quantizer #10009
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
looks good to me overall. I think we also want to think about how we can integrate the autoquant API: https://github.com/pytorch/ao/tree/main/torchao/quantization#autoquantization, which works on the full model instead of individual linear modules.
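For reference, the torchao README applies autoquant to a whole module roughly like this (a sketch; `model` and `example_input` are placeholders, and how this would plug into diffusers is still an open question):

```python
import torch
import torchao

# wrap the full model; per-layer quantization decisions are made automatically
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

# running an example input triggers the autoquant selection
model(*example_input)
```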
    if device_map is not None:
        raise NotImplementedError(
            "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
        )
@sayakpaul I'm not sure how this impacts the BnB quantizer. I assume it was disabled for BnB for some reason I'm not aware of. It works with TorchAO as expected though, so if you need some kind of torchao-specific guard here, I'll add it.
cc @SunMarc
> It works with TorchAO as expected though

How did you test that?

> I assume it was disabled for BnB for some reason I'm not aware of.

That is because we merge the sharded checkpoints when using bnb, and using custom `device_map`s needs this codepath:

    accelerate.load_checkpoint_and_dispatch(

This is not hit when loading quantized checkpoints, at least for bitsandbytes. This will be tackled in: #10013
> How did you test that?

There is a test for this in `tests/quantization/torchao/test_torchao.py` called `test_offload` that can be used to verify that cpu/disk offloading works with torchao.
Yeah, but this is about custom user-provided `device_map`s. What am I missing?
The `if device_map is not None` check was added in the BnB quantizer PR. I assume it was added because `device_map` is not supported in BnB. But it works perfectly fine with TorchAO (as the test checks), so I removed the check in order to do the initial testing of the TorchAO quantizer quickly.

I would like to know if an error should be raised when BnB is the quantization method used. Something like:

    if quantization method is BnB and device_map is not None:
        raise Error

Does that work?
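A minimal sketch of what that guard could look like (the helper name `check_device_map` is hypothetical, and `QuantizationMethod` is assumed from `diffusers.quantizers.quantization_config`; the final check may differ):

```python
from diffusers.quantizers.quantization_config import QuantizationMethod


def check_device_map(device_map, quantization_config):
    # only reject a user-provided device_map for bitsandbytes; torchao keeps supporting it
    if (
        device_map is not None
        and quantization_config is not None
        and quantization_config.quant_method == QuantizationMethod.BITS_AND_BYTES
    ):
        raise NotImplementedError(
            "`device_map` is not yet supported when loading bitsandbytes-quantized models."
        )
```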
yeah that works for me.
Thank you for working on this. I love many aspects of this PR, my favorite being how we're supporting many torchao quant configs!
Apart from the comments I left in-line, I have the following additional comments:
- Consider testing for model memory footprint as well.
- Consider including integration tests for Flux, Cog, etc. At least Flux should be covered.
- Consider adding a note on serialization in the docs.
- Consider testing for `modules_to_not_convert`.
- Consider adding a test for training. Example: `def test_training(self)` in `diffusers/tests/quantization/bnb/test_4bit.py` (line 365 in 069186f).
LMK if anything is unclear.
## Usage

Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
Suggested change:

Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] or even load a pre-quantized model. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
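For reference, basic usage of the config with `from_pretrained` looks roughly like this (model id and quant type are illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")  # int8 weight-only
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```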
"hf-internal-testing/tiny-flux-pipe", | ||
subfolder="transformer", | ||
quantization_config=quantization_config, | ||
device_map=device_map_offload, |
We assign a `hf_device_map` attribute to the model too, so we should also check if the `quantized_hf_device_map` matches the expected one.
@a-r-r-o-w just checking if this is remaining to be added?
Updated the `test_offload` test with a check to verify `hf_device_map` is the same as `device_map`. I don't see any `quantized_hf_device_map` when grepping the codebase or searching on GitHub, so I'm not sure what you are referring to. Could you help with this?
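The added check could be as simple as something along these lines inside `test_offload` (a sketch; variable names are assumptions):

```python
# verify the device map recorded on the quantized model matches what was requested
self.assertTrue(transformer.hf_device_map == device_map_offload)
```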
Did another pass and answered some questions.

Very important would be to have a test suite for `torchao` + `torch.compile()` (at least for some quant types), as that is a massive USP of torchao.
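A compile smoke test could look roughly like this (a sketch; helper names such as `get_dummy_inputs`, the quant type, and the tolerances are placeholders):

```python
import torch

quantization_config = TorchAoConfig("int8wo")
transformer = self.get_dummy_components(quantization_config)["transformer"]
inputs = self.get_dummy_inputs(torch_device)

# eager quantized forward pass as the reference
eager_output = transformer(**inputs)[0]

# compiled quantized forward pass should match the eager one closely
compiled_transformer = torch.compile(transformer, fullgraph=True)
compiled_output = compiled_transformer(**inputs)[0]

self.assertTrue(torch.allclose(eager_output, compiled_output, atol=1e-2, rtol=1e-3))
```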
    module, tensor_name = get_module_from_name(model, param_name)
    return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
We check it to do a further inspection on the shape, as well as to handle any parameter creation. But I guess it's fine with torchao because of the reasons you mentioned. @SunMarc WDYT?
    @property
    def is_trainable(self):
        # TODO(aryan): needs testing
        return self.quantization_config.quant_type.startswith("int8")
fp8 training is orthogonal here if we are talking about peft, I think. But I feel we should be able to fine-tune an fp8-quantized model as well (with `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight`). I haven't tried this though; did you see any errors when you tried it?

I think this should be checked in the torchao CI, given the popularity of training quantized models? If you give us a heads up about that, we'd be more than happy to configure this here accordingly.

@BenjaminBossan could you comment on the support of torchao <> peft a bit here?
    for param in module.parameters():
        if param.__class__.__name__ == "AffineQuantizedTensor":
            data, scale, zero_point = param.layout_tensor.get_plain()
            quantized_param_memory += data.numel() + data.element_size()
Oh typo here... should be multiplied
we also have this util in torchao btw: https://github.com/pytorch/ao/blob/abff563ba515576fc48cd4ac0feb923dd65dc267/torchao/utils.py#L245
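If the linked helper is `torchao.utils.get_model_size_in_bytes` (an assumption based on that file), usage would be roughly:

```python
from torchao.utils import get_model_size_in_bytes

# accounts for quantized tensor subclasses, unlike a naive numel * element_size sum
model_size = get_model_size_in_bytes(transformer)
```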
Oh interesting, missed it! Will try it out
Looking really good!

I think it would also make sense to run the existing and important integration tests before merging, to make sure there are no obvious bugs.
"hf-internal-testing/tiny-flux-pipe", | ||
subfolder="transformer", | ||
quantization_config=quantization_config, | ||
device_map=device_map_offload, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@a-r-r-o-w just checking if this is remaining to be added?
    self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

    @staticmethod
    def _get_memory_footprint(module):
Does this not work?

    # diffusers/src/diffusers/models/modeling_utils.py, line 1277 in 18f9b99
    def get_memory_footprint(self, return_buffers=True):

If not, we should consider these changes in `modeling_utils.py`, IMO.
Nope, this does not return the correct size of model weights when quantization is applied. We can consider the change in `modeling_utils.py` in a separate PR to account for the AQT tensors, since this is just present in the tests for the moment.

The TorchAO utility provided by Jerry here is probably better to use than what I have here.
    @staticmethod
    def _get_memory_footprint(module):
        quantized_param_memory = 0.0
        unquantized_param_memory = 0.0

        for param in module.parameters():
            if param.__class__.__name__ == "AffineQuantizedTensor":
                data, scale, zero_point = param.layout_tensor.get_plain()
                quantized_param_memory += data.numel() + data.element_size()
                quantized_param_memory += scale.numel() + scale.element_size()
                quantized_param_memory += zero_point.numel() + zero_point.element_size()
            else:
                unquantized_param_memory += param.data.numel() * param.data.element_size()

        total_memory = quantized_param_memory + unquantized_param_memory
        return total_memory, quantized_param_memory, unquantized_param_memory

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
        transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
        transformer_bf16 = self.get_dummy_components(None)["transformer"]

        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
            transformer_int4wo_gs32
        )
        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)

        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
        # int8 quantizes more layers compare to int4 with default group size
        self.assertTrue(quantized_int8wo < quantized_int4wo)
Yeah, both are separate: memory footprint gives us a ballpark of how much we need to load, while memory usage tells us the actual memory needed for execution. Both could be considered for inclusion here.
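Measuring the execution-time memory (as opposed to the load-time footprint) could be done with something like this (a sketch; `transformer` and `inputs` are placeholders):

```python
import torch

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = transformer(**inputs)
peak_memory_bytes = torch.cuda.max_memory_allocated()
```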
oh thanks for the great work!

PR looks very good to me and I think we can merge this very soon. The only concern I have is the API supporting all the shorthands; IMO we should not, but I'm open to different opinions! :)
    )

    @classmethod
    def _get_torchao_quant_type_to_method(cls):
ohh, I think it will be too much to support an API with all the shorthands: creating and maintaining all the related docs, and keeping the list up to date.

Is this something currently supported by TorchAO, or in their plans? If they support it, I think it will be ok/manageable for us to maintain a parallel mapping; otherwise, I think it is unnecessary/not meaningful for us to come up with new APIs for the external libraries we integrate.

If we want to create a set of "shorthand standards" that we can use in diffusers across all the different quantization methods/libraries we support (e.g. something we can use for both bnb, torchAO etc.), it might be meaningful, but I think it will be better if we do that after we have a few more libraries in :)

also cc @SunMarc here, because for quantization we would like to keep the API roughly consistent between transformers and diffusers

Overall, IMO, I think we should only accept passing the method name as-is; for shorthands, it is ok to support a very small and most commonly used list.

But I'm open to different opinions! so let me know cc @DN6 too
makes sense, maybe we can provide some common shorthand, these have been repeated in many libraries: https://github.com/pytorch/ao/blob/8a805d08898e5c961fb9b4f6ab61ffd5d5bdbca5/torchao/_models/llama/generate.py#L702
Looking forward to hearing thoughts from others! I find the shorthands for `uintx`/`fpx`, and the suffixes `wo` and `dq`, rather convenient. I believe they are commonly used too. No hard preferences; okay with whatever we decide.
Yeah, I think the number of terms here will be a bit difficult to maintain. Perhaps we support just the shorthands mentioned here? Does the community have any preference?
https://github.com/huggingface/diffusers/pull/10009/files#r1873188178
Thanks for the review @DN6! I've removed the documentation shorthands (anything of the form `{dtype}_aXwY`) after our discussion in DM.

I think we should definitely keep the fully qualified function names and the `wo` and `dq` suffixes. Updated the documentation accordingly.

For Jerry's suggestion, the hqq, marlin, sparsify, autoquant, intx (prototype), and spinquant methods can be tackled in a separate PR after trying them out. Let's keep this one to just the ones we have here already. Will check the generation quality with these soon.
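To illustrate the forms that stay in the docs (the quant types below are examples, not an exhaustive list):

```python
from diffusers import TorchAoConfig

# fully qualified torchao method names
TorchAoConfig("int8_weight_only")
TorchAoConfig("int8_dynamic_activation_int8_weight")

# shorthand suffixes: `wo` = weight-only, `dq` = dynamic activation + weight quantization
TorchAoConfig("int8wo")
TorchAoConfig("int8dq")
```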
Thanks for adding, really super!
"""This is a config class for torchao quantization/sparsity techniques. | ||
|
||
Args: | ||
quant_type (`str`): |
Yeah this is quite heavy for a docstring
@SunMarc can you do a final review if you haven't?
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.

Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.
Indeed, it seems PyTorch 2.5+ is required because the quantizer references `torch.uint1` (and others), which are not available in earlier torch versions. However, diffusers seems to require torch>=1.4 (ref), so this seems inconsistent. Am I missing something?
TorchAO will not be imported or usable unless the pytorch version of 2.5 or above is available. Some Diffusers models can run with the 1.4 version as well, which is why that's the minimum required version.
I'm running into the same issue with the `torch.uint1` import. It seems the TorchAO import is not guarded, according to the backtrace. The following backtrace stems from this import line:

    from diffusers import StableDiffusionXLPipeline

Here is the trace, and the pip list:
Traceback (most recent call last):
File "/github/home/.local/lib/python3.10/site-packages/diffusers/utils/import_utils.py", line 920, in _get_module
return importlib.import_module("." + module_name, self.__name__)
File "/usr/local/lib/python3.10/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 883, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "/github/home/.local/lib/python3.10/site-packages/diffusers/loaders/single_file.py", line 24, in <module>
from .single_file_utils import (
File "/github/home/.local/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py", line 28, in <module>
from ..models.modeling_utils import load_state_dict
File "/github/home/.local/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 35, in <module>
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/__init__.py", line 15, in <module>
from .auto import DiffusersAutoQuantizer
File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/auto.py", line 31, in <module>
from .torchao import TorchAoHfQuantizer
File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/torchao/__init__.py", line 15, in <module>
from .torchao_quantizer import TorchAoHfQuantizer
File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/torchao/torchao_quantizer.py", line 45, in <module>
torch.uint1,
File "/github/home/.local/lib/python3.10/site-packages/torch/__init__.py", line 1938, in __getattr__
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'uint1'
And the pip list:
pip list -v
Package Version Editable project location Location Installer
------------------------ ----------- ---------------------------- ------------------------------------------------ ---------
certifi 2024.12.14 /github/home/.local/lib/python3.10/site-packages pip
charset-normalizer 3.4.0 /github/home/.local/lib/python3.10/site-packages pip
colorama 0.4.6 /github/home/.local/lib/python3.10/site-packages pip
coloredlogs 15.0.1 /github/home/.local/lib/python3.10/site-packages pip
colorlog 6.9.0 /github/home/.local/lib/python3.10/site-packages pip
coverage 7.6.9 /github/home/.local/lib/python3.10/site-packages pip
diffusers 0.32.0 /github/home/.local/lib/python3.10/site-packages pip
exceptiongroup 1.2.2 /github/home/.local/lib/python3.10/site-packages pip
execnet 2.1.1 /github/home/.local/lib/python3.10/site-packages pip
filelock 3.16.1 /github/home/.local/lib/python3.10/site-packages pip
flatbuffers 24.12.23 /github/home/.local/lib/python3.10/site-packages pip
fsspec 2024.12.0 /github/home/.local/lib/python3.10/site-packages pip
huggingface-hub 0.27.0 /github/home/.local/lib/python3.10/site-packages pip
humanfriendly 10.0 /github/home/.local/lib/python3.10/site-packages pip
idna 3.10 /github/home/.local/lib/python3.10/site-packages pip
importlib_metadata 8.5.0 /github/home/.local/lib/python3.10/site-packages pip
iniconfig 2.0.0 /github/home/.local/lib/python3.10/site-packages pip
Jinja2 3.1.5 /github/home/.local/lib/python3.10/site-packages pip
markdown-it-py 3.0.0 /github/home/.local/lib/python3.10/site-packages pip
MarkupSafe 3.0.2 /github/home/.local/lib/python3.10/site-packages pip
mdurl 0.1.2 /github/home/.local/lib/python3.10/site-packages pip
mpmath 1.3.0 /github/home/.local/lib/python3.10/site-packages pip
networkx 3.4.2 /github/home/.local/lib/python3.10/site-packages pip
numpy 1.26.4 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cublas-cu12 12.1.3.1 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-cupti-cu12 12.1.105 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-nvrtc-cu12 12.1.105 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-runtime-cu12 12.1.105 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cudnn-cu12 8.9.2.26 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cufft-cu12 11.0.2.54 /github/home/.local/lib/python3.10/site-packages pip
nvidia-curand-cu12 10.3.2.106 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cusolver-cu12 11.4.5.107 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cusparse-cu12 12.1.0.106 /github/home/.local/lib/python3.10/site-packages pip
nvidia-nccl-cu12 2.19.3 /github/home/.local/lib/python3.10/site-packages pip
nvidia-nvjitlink-cu12 12.6.85 /github/home/.local/lib/python3.10/site-packages pip
nvidia-nvtx-cu12 12.1.105 /github/home/.local/lib/python3.10/site-packages pip
onnx 1.17.0 /github/home/.local/lib/python3.10/site-packages pip
onnx2torch 1.5.15 /github/home/.local/lib/python3.10/site-packages pip
onnxruntime 1.20.1 /github/home/.local/lib/python3.10/site-packages pip
onnxsim 0.4.36 /github/home/.local/lib/python3.10/site-packages pip
packaging 24.2 /github/home/.local/lib/python3.10/site-packages pip
pandas 2.2.3 /github/home/.local/lib/python3.10/site-packages pip
pillow 11.0.0 /github/home/.local/lib/python3.10/site-packages pip
pip 22.0.4 /usr/local/lib/python3.10/site-packages pip
pluggy 1.5.0 /github/home/.local/lib/python3.10/site-packages pip
protobuf 5.29.2 /github/home/.local/lib/python3.10/site-packages pip
Pygments 2.18.0 /github/home/.local/lib/python3.10/site-packages pip
pytest 8.3.4 /github/home/.local/lib/python3.10/site-packages pip
pytest-xdist 3.6.1 /github/home/.local/lib/python3.10/site-packages pip
python-dateutil 2.9.0.post0 /github/home/.local/lib/python3.10/site-packages pip
python-dotenv 1.0.1 /github/home/.local/lib/python3.10/site-packages pip
python-json-logger 3.2.1 /github/home/.local/lib/python3.10/site-packages pip
pytz 2024.2 /github/home/.local/lib/python3.10/site-packages pip
PyYAML 6.0.2 /github/home/.local/lib/python3.10/site-packages pip
regex 2024.11.6 /github/home/.local/lib/python3.10/site-packages pip
requests 2.32.3 /github/home/.local/lib/python3.10/site-packages pip
rich 13.9.4 /github/home/.local/lib/python3.10/site-packages pip
safetensors 0.4.5 /github/home/.local/lib/python3.10/site-packages pip
scipy 1.14.1 /github/home/.local/lib/python3.10/site-packages pip
sentencepiece 0.2.0 /github/home/.local/lib/python3.10/site-packages pip
setuptools 58.1.0 /usr/local/lib/python3.10/site-packages pip
six 1.17.0 /github/home/.local/lib/python3.10/site-packages pip
sympy 1.13.3 /github/home/.local/lib/python3.10/site-packages pip
tabulate 0.9.0 /github/home/.local/lib/python3.10/site-packages pip
tokenizers 0.15.2 /github/home/.local/lib/python3.10/site-packages pip
tomli 2.2.1 /github/home/.local/lib/python3.10/site-packages pip
torch 2.2.2 /github/home/.local/lib/python3.10/site-packages pip
torchvision 0.17.2 /github/home/.local/lib/python3.10/site-packages pip
tqdm 4.67.1 /github/home/.local/lib/python3.10/site-packages pip
transformers 4.38.2 /github/home/.local/lib/python3.10/site-packages pip
triton 2.2.0 /github/home/.local/lib/python3.10/site-packages pip
typing_extensions 4.12.2 /github/home/.local/lib/python3.10/site-packages pip
tzdata 2024.2 /github/home/.local/lib/python3.10/site-packages pip
urllib3 2.3.0 /github/home/.local/lib/python3.10/site-packages pip
wheel 0.37.1 /usr/local/lib/python3.10/site-packages pip
zipp 3.21.0 /github/home/.local/lib/python3.10/site-packages pip
Thanks a lot for reporting @fjeremic! We were able to replicate for torch <= 2.2. It seems to not cause the import errors for >= 2.3. We will be doing a patch release soon to fix this behaviour. Sorry for the inconvenience!
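A guard along these lines would avoid referencing the newer dtypes on older torch (a sketch, not the actual patch; the constant name is hypothetical and `is_torch_version` is assumed from `diffusers.utils`):

```python
import torch
from diffusers.utils import is_torch_version

if is_torch_version(">=", "2.3"):
    # torch.uint1 ... torch.uint7 only exist in newer torch releases
    SUPPORTED_QUANTIZATION_DTYPES = (
        torch.int8,
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.uint1,
        torch.uint2,
        torch.uint3,
        torch.uint4,
        torch.uint5,
        torch.uint6,
        torch.uint7,
    )
else:
    SUPPORTED_QUANTIZATION_DTYPES = (torch.int8, torch.float8_e4m3fn, torch.float8_e5m2)
```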
Thanks for providing a quick fix!

For completeness, I was running into the import error with torch 2.2.2 when importing `AutoencoderKL`.
@BeckerFelix @fjeremic The patch release is out! Hope it fixes any problems you were facing in torch < 2.3
What does this PR do?
Adds support for the TorchAO Quantizer.
Quantization formats
TorchAO supports a wide variety of quantizations. For a normal user, it can get quite overwhelming to understand all the different parameters and what they mean. In order to simplify this a bit, I've used some custom commonly used names that are easier to remember/use while also supporting full configurability of the original arguments.
The naming conventions used are:

- `int8_weight_only`, `float8_weight_only`, etc.: the full torchao method names. You can pass the arguments supported by each method (as described in the torchao docs) through the quantization kwargs.
- `wo` (weight-only) and `dq` (weight + activation quantization) suffixes.
- `{dtype}_a{x}w{y}`: shorthand notations for convenience, and because the `a{x}w{y}` notation is used extensively in the torchao docs.
- Per-tensor granularity is suffixed with `_tensor` and per-row with `_row`. Per-axis and per-group granularity are also supported, but they involve additional parameters in their constructors, so power-users are free to play with those if they like; the shorthands provided here are just for tensor/row.

Broadly, `int4`, `int8`, `uintx`, `fp8` and `fpx` quantizations are supported, with dynamic activation quants where applicable, otherwise weight-only. Group sizes can be specified by power-users via the full function names; we don't have special names to handle those.

Benchmarks
The following code is used for benchmarking:
Code
You can launch it with something like:
Here are the time/memory results from a single H100:
Flux Table
Flux visual results
CogVideoX table
CogVideoX visual results
Unfortunately, the prompt I used does not produce a very good initial video. Should have verified this in the beginning... 🫠
For some reason, GitHub does not render the videos from HF despite trying a few things. So, I'm not embedding it here. The results can be found here: https://huggingface.co/datasets/a-r-r-o-w/randoms/tree/main/cogvideox_benchmark_results
The minimal code for using the quantizer would be:
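A minimal sketch of end-to-end usage (model id, quant type, and prompt are illustrative):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe("a photo of a cat surfing a wave", num_inference_steps=28).images[0]
image.save("output.png")
```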
TODO

- `modules_to_not_convert`
- `torch.compile()` <> `torchao`
@DN6 @sayakpaul @yiyixuxu