v0.5.0
Highlights
We are excited to announce the 0.5 release of torchao! This release adds support for memory efficient inference, float8 training and inference, int8 quantized training, HQQ, automatic mixed-precision quantization through Bayesian optimization, sparse marlin, and integrations with HuggingFace, SGLang, and diffusers.
Memory Efficient Inference Support #738
We've added support for Llama 3.1 to the llama benchmarks in TorchAO, along with new features and improvements as a proof of concept for memory efficient inference. Combining kv cache quantization, int4 weight-only quantization, and a linear causal mask, these additions allow us to do 130k context length inference with Llama 3.1-8B using only 18.91 GB of memory.
General savings depend on technique and context length as can be seen in the following graph:
Float8 Training #551
torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
With torch.compile on, current results show throughput speedups of up to 1.5x on 128 H100 GPU LLaMa 3 70B pretraining jobs (details).
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m, module_filter_fn=...)
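For example, a minimal sketch of the conversion step, assuming a toy two-layer model and that module_filter_fn receives each module together with its fully-qualified name and returns whether it should be converted:
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
# toy model: float8 training swaps the selected nn.Linear layers for float8 linears
m = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 128)).bfloat16().cuda()
# convert everything except the output projection (fully-qualified name "1")
def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    return fqn != "1"
convert_to_float8_training(m, module_filter_fn=module_filter_fn)
m = torch.compile(m)  # float8 training is designed to be used together with torch.compile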
And for an end-to-end minimal training recipe of pretraining with float8, you can check out torchtitan.
Float8 Inference #740 #819
We have introduced two new quantization APIs for Float8 inference:
- Float8 Weight-Only Quantization: A new quant_api float8_weight_only() has been added to apply float8 weight-only symmetric per-channel quantization to linear layers.
- Float8 Dynamic Activation and Weight Quantization: A new quant_api float8_dynamic_activation_float8_weight() has been introduced to apply float8 dynamic symmetric quantization to both activations and weights of linear layers. By default it uses PerTensor scaling; we have also added an option for PerRow scaling of both activations and weights. Computing scales at a finer granularity can reduce the overall quantization error and improve performance by reducing dynamic quantization overhead.
Example usage:
import torch
from torchao.quantization import quantize_, float8_weight_only, float8_dynamic_activation_float8_weight, PerRow
# Create a model
model = YourModel()
# Apply float8 weight-only quantization
quantize_(model, float8_weight_only())
# Apply float8 dynamic activation and weight quantization
quantize_(model, float8_dynamic_activation_float8_weight())
# Apply PerRow scaling to weight and activations
quantize_(linear_module, float8_dynamic_activation_float8_weight(granularity=PerRow()))
Notes:
- These new APIs are designed to work with PyTorch 2.5 and later versions.
- float8_dynamic_activation_float8_weight requires CUDA devices with compute capability 8.9 or higher for hardware acceleration.
Int8 quantized training #644 #748
@gau-nernst introduced two experimental approaches to training with INT8.
- INT8 quantized training (#644): weights are kept in INT8 for the whole duration of training to save memory, while compute remains in high precision. To train effectively with only quantized weights, we use stochastic rounding for weight updates. Right now, the memory saving is not yet competitive with a compiled BF16 baseline.
- INT8 mixed-precision training (#748): weight is kept in the original high precision, but weight and activation are dynamically quantized to INT8 during training to utilize INT8 tensor cores. We observe up to 70% speedup for Llama2 pre-training on 4090, and 20% speedup for Llama3 pre-training on 8x A100 with FSDP2.
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_weight_only_quantized_training, int8_mixed_precision_training
model = YourModel()
# apply INT8 quantized training
quantize_(model, int8_weight_only_quantized_training())
# apply INT8 mixed-precision training
quantize_(model, int8_mixed_precision_training())
For more information and benchmark results, see the README and the respective PRs (#644 and #748).
HQQ Integration in torchao #605 #786
hqq has been added to existing torchao APIs; it improves model accuracy while leveraging the existing efficient kernels in torchao. We enabled hqq for the int4_weight_only API:
quantize_(model, int4_weight_only(group_size, use_hqq=True))
We also added this to the uintx api for accuracy experiments (current uintx kernels are slow):
quantize_(model, uintx_weight_only(torch.uint2, group_size, use_hqq=True))
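Putting it together, a minimal end-to-end sketch of the int4 + HQQ path (the toy model, bfloat16 dtype, and group size of 64 are illustrative choices):
import torch
from torchao.quantization import quantize_, int4_weight_only
# toy bfloat16 model on GPU; int4 weight-only quantization targets nn.Linear layers
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
# int4 weight-only quantization, with HQQ used to search better quantization parameters
quantize_(model, int4_weight_only(group_size=64, use_hqq=True))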
Automatic Mixed-Precision Quantization through Bayesian Optimization #592, #694
We provide a Bayesian Optimization (BO) tool leveraging Ax to automatically search for a mixed-precision weight-only quantization configuration, i.e., the bit width and group size of intN_weight_only(bit_width, group_size) for each layer. It also includes a sensitivity analysis tool that calculates the layer-wise average Hessian trace and average Fisher information matrix trace, an optional step to customize and improve the BO search.
To optimize for model accuracy under a model size constraint (GB):
python BO_acc_modelsize.py --checkpoint=/tmp/Meta-Llama-3-8B --model_size_constraint=6.0
To optimize for inference throughput under a model perplexity constraint:
python BO_acc_throughput.py --checkpoint=/tmp/Meta-Llama-3-8B --ppl_constraint=7.5
For more detailed usage, please refer to this README. Compared to int8 uniform quantization on the Llama3-8B model, the mixed-precision configuration found by this tool reduces model size by 20.1% with a 2.8% perplexity reduction, and improves inference throughput by 15.1% with a 3.2% perplexity reduction.
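As a sketch of how a searched configuration could be applied afterwards, the loop below quantizes each layer with its own bit width and group size via quantize_'s filter_fn; the searched_config dict, the layer names, and the import path for intN_weight_only are assumptions for illustration only:
from torchao.quantization import quantize_
from torchao.quantization.prototype.mixed_precision.scripts import intN_weight_only  # assumed import path
model = YourModel()
# hypothetical BO search result: fully-qualified layer name -> (bit_width, group_size)
searched_config = {
    "model.layers.0.mlp.gate_proj": (4, 64),
    "model.layers.1.mlp.gate_proj": (8, 128),
}
# quantize_ accepts filter_fn(module, fqn), so each layer can receive its own config
for target_fqn, (bit_width, group_size) in searched_config.items():
    quantize_(
        model,
        intN_weight_only(bit_width, group_size),
        filter_fn=lambda mod, fqn, target_fqn=target_fqn: fqn == target_fqn,
    )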
Sparse Marlin #621, #733
@Diogo-V added support for sparse-marlin, a W4AFP16 2:4 sparse kernel, to TorchAO.
On Meta Llama 3, we observe a 25% increase in tokens/second (180 -> 226) compared to our existing int4-wo implementation.
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
|---|---|---|---|---|---|
| Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 | 15.01 |
| | int8dq | 8.61 | 64.75 | 9.24 | 7.52 |
| | int8wo | 153.03 | 1150.80 | 10.42 | 7.52 |
| | int4wo-64 | 180.80 | 763.33 | 6.88 | 4.22 |
| | int4wo-64-sparse-marlin | 226.02 | 689.20 | 5.32 | 3.05 |
HuggingFace Integration
torchao is integrated into HuggingFace transformers (https://huggingface.co/docs/transformers/main/en/quantization/torchao): you can now use int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight through TorchAoConfig in HuggingFace. This is currently available in the HuggingFace main branch only.
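A minimal sketch of loading a quantized model through transformers (the model id and group size here are illustrative):
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
# int4 weight-only quantization applied by torchao when the checkpoint is loaded
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)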
SGLang Integration
torchao is also integrated into sglang (sgl-project/sglang#1341) for the llama3 model; you can try it out with:
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128
Supported configurations are "int4wo-<group_size>", "int8wo", "int8dq", and "fp8wo" (the last is only available in torchao 0.5+).
diffusers Integration
The diffusers-torchao repo provides end-to-end inference and experimental training recipes for using torchao with diffusers. We demonstrate a 53.88% speedup on Flux.1-Dev* and a 27.33% speedup on CogVideoX-5b when comparing compiled quantized models against their standard bf16 counterparts.
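A minimal sketch of the general pattern, quantizing and compiling the pipeline's transformer (the model id, dtype, and choice of int8 weight-only quantization are illustrative):
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, int8_weight_only
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# quantize the denoiser weights with torchao, then compile it
quantize_(pipe.transformer, int8_weight_only())
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
image = pipe("a photo of a cat", num_inference_steps=28).images[0]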
BC Breaking
Add layout option to woq int4 api #670
# torchao 0.4.0
from torchao.quantization import quantize_, int4_weight_only
quantize_(my_model, int4_weight_only(inner_k_tiles=8))
# torchao 0.5.0
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayoutType
quantize_(my_model, int4_weight_only(layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))
Refactor QAT to use tensor subclasses #585
We refactored QAT to use tensor subclasses instead of module swap. This works well with torchtune and FSDP2, but currently lacks support for FSDP1 and DDP. As a fallback for these distributed training strategies, please continue to use the old module swap flows.
# torchao 0.4.0: This uses the module swap flow
# torchao 0.5.0 + FSDP2: This uses the tensor subclass flow
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)
train(model)
model = quantizer.convert(model)
# torchao 0.5.0 + DDP or FSDP1: This uses the module swap flow
from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap
quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap()
model = quantizer.prepare(model)
train(model)
model = quantizer.convert(model)
Deprecations
New Features
- Optimizer CPU offload for single GPU training #584
- Add support for save quantized checkpoint in llama code #553
- Intx quantization tensor subclass #468
- Add superblock to sparse/prototype #660
- Add AffineQuantizedObserver #650
- Add BSR subclass + torch.compile and clean up superblock #680
- Add HQQ support #605
- Add performance profiler #690
- Add experimental INT8 quantized training #644
- Add high-level operator interface #708
- Add sparse marlin 2:4 gemm op #733
- Example for GPTQ-like calibration flow #721
- Llama3.1 and KV cache quantization #738
- Add float8 weight only and weight + dynamic activation #740
- Add Auto-Round support #581
Mixed-Precision Quantization
- Add sensitivity analysis tool for layer-wise FIT and Hessian trace #592
- Bayesian optimization tool for mixed precision quantization #694
Improvements
- Move sam eval from scripts to torchao/_models #591
- QOL improvements to float8 gemm benchmark #596
- Move lowbit universal kernels from torchaccel to torchao #582
- Refactor autoquant to use AQT #609
- Add support for using AffineQuantizedTensor with weights_only=True #630
- Move Uintx out of prototype for future extension #635
- Refactor _quantized_linear for better extensibility #634
- Update micro benchmarking code for AQT #673
- Refactor superblock code + add final benchmark/eval scripts #691
- Relax QAT dtype assertion #692
- Add option to move param to device before quantization #699
- Add gpu benchmarking script #192
- Enable to(device=device_name) for Uintx #722
- Make torchao's llama model trainable #728
- Specify output dtype to torch.float32 in _foreach_norm #727
- Add semi-structured sparsity to hf eval #576
- Use torch.uint1 to torch.uint7 for Uintx tensor subclass #672
- Add AdamW to CPUOffloadOptimizer default #742
- Make developer experience better for extending AQT #749
- Add back QAT module swap API #762
- Refactor quant_llm to work with affine quantized tensor #772
- Move iOS benchmarking infra code to torchao #766
- Add CPU bandwidth benchmark #773
- Update method names to support intx and floatx changes #775
- Add implementation for torchao::parallel_for backends #774
- Add Llama2-7B finetune benchmarks for low-bit optimizers #746
- Fix Adam4bit support on PyTorch 2.3 and 2.4 and update AdamFp8 torch requirement #755
- Improve compile time + fix PyTorch 2.3 support for 4-bit optim #812
- Allow quantized linear registration in a different file #783
- Add 2bit, 5bit packing routines #797, #798
- Freeze dataclass in nf4, prep for better pt2 support #799
- Format and lint nf4 file and test #800
- Move more utils to TorchAOBaseTensor #784
- Add more information to quantized linear module and added some logs #782
- Add int4 mode to autoquant #804
- Add uintx quant to generate and eval #811
- Move non-NF4 tensor to device prior to quantization on copy #737
Static quantization
- Add float8 static quant support #787
- Update how block_size is calculated with Observers #815
- Add a linear observer class and test #807
Float8
- Update benchmarks to be more useful for smaller shapes #615
- Remove unneeded kernel for scale generation #616
- Filter out microbenchmarking overhead in profiling script #629
- Save torch_logs, and attach them to profiling trace #645
- Add option for gpu time in GEMM benchmarks #666
- Add roofline estimation of GEMM + overhead #668
- Make roofline utils reusable #731
- Use torch.compiler.is_compiling #739
- Float8 support in AQT #671
- Add static scaling for float8 training #760
- Make roofline script calculate observed overhead #734
- Make Inference and training code independent #808
- Add rowwise scaling option to float8 dynamic quant #819
Bug fixes
- Fix all-gather in 2D with DTensor (WeightWithDynamicFloat8CastTensor) #590
- Fix FP6-LLM API and add .to(device) op #595
- Fix linear_activation_tensor dynamic quant #622
- Fix bug with float8 inference_mode #659
- Quantization kernel bug fixes #717
- Cast local_scale_tensor to fp32 for precompute of float8 dynamic scaling #713
- Fix affine quantized tensor to device calls #726
- Small fix for micro benchmark code #711
- Fix LR schedule handling for low-bit optimizers #736
- Fix FPX inductor error #790
- Fixed llama model inference #769
Docs
- Add QAT README #597
- Update serialization.rst to include get_model_size_in_bytes import #604
- Clarify details around unwrap_tensor_subclass in README.md #618, #619
- Spelling fixes #662
- Move developer guide file to a folder #681
- Update docs on how to use AUTOQUANT_CACHE #649
- Update pip install command in README #723
- Fix docstring args names #735
- Update README example with correct import of sparsify_ #741
- Update main and quantization README #745, #747, #757
- Add README for mixed-precision search tool and code refactor #776
- Add performance section to float8 README.md #794
- Make float8 README.md examples standalone #809
- Add KV cache quantization to READMEs #813
- Update main README.md with more current float8 speedup #816
Not user facing
- Fix float8 inference tests and add export test #613
- Reduce atol/rtol for stable tests #617
- Fix version guard in #620, #679, #684
- Fix BC for QAT location #626
- Enable float8 CI on sm89 #587
- Fix Inductor bench BC change #638, #641
- Add CUDA compute capability compile guard #636
- Remove numpy as bitpack dependency #677
- Add PyTorch 2.4 tests in CI #654
- Remove torchao_nightly package #661
- Update licenses in torchao/experimental #720
- Add lint checks for float8 inference #779
New Contributors
- @sayakpaul made their first contribution in #604
- @metascroy made their first contribution in #582
- @raziel made their first contribution in #618
- @nmacchioni made their first contribution in #641
- @Diogo-V made their first contribution in #670
- @mobicham made their first contribution in #605
- @crcrpar made their first contribution in #703
- @ebsmothers made their first contribution in #737
- @a-r-r-o-w made their first contribution in #741
- @kimishpatel made their first contribution in #766
We were able to close about 70% of the tasks planned for 0.5.0; the remainder will spill over into upcoming releases. We will post a list for 0.6.0 next, which we aim to release at the end of September 2024. We plan to follow a monthly release cadence until further notice.
Full Changelog: v0.4.0...v0.5.0-rc1