
v0.5.0

Released by @andrewor14 on 08 Sep 17:18

Highlights

We are excited to announce the 0.5 release of torchao! This release adds support for memory-efficient inference, float8 training and inference, int8 quantized training, HQQ, automatic mixed-precision quantization through Bayesian optimization, sparse marlin, and integrations with Hugging Face, SGLang, and diffusers.

Memory Efficient Inference Support #738

We've added Llama 3.1 support to the llama benchmarks in TorchAO, along with new features and improvements as a proof of concept for memory-efficient inference. These additions allow us to run 130k context length inference with Llama 3.1-8B using only 18.91 GB of memory when combining kv cache quantization, int4 weight-only quantization, and a linear causal mask.

General savings depend on technique and context length as can be seen in the following graph:
(Graph: memory savings by technique and context length.)
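As a rough illustration of how these pieces combine, here is a hypothetical invocation of the llama benchmark script. The script path, checkpoint path, and flag names (--quantization, --kv_cache_quantization, --linear_causal_mask, --cache_size) are assumptions based on the features described above, so check the benchmark README for the exact interface:

# hypothetical flags; see torchao/_models/llama for the actual benchmark interface
python torchao/_models/llama/generate.py \
    --checkpoint_path checkpoints/Meta-Llama-3.1-8B/model.pth \
    --quantization int4wo-64 \
    --kv_cache_quantization \
    --linear_causal_mask \
    --cache_size 131072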

Float8 Training #551

torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.

With torch.compile on, current results show throughput speedups of up to 1.5x on 128-GPU H100 Llama 3 70B pretraining jobs (details)

from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m, module_filter_fn=...)

And for an end-to-end minimal recipe for pretraining with float8, you can check out torchtitan.
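For context, here is a slightly fuller sketch, assuming a model containing nn.Linear layers (YourTransformer and the "output" module name are placeholders); module_filter_fn receives each module and its fully qualified name and returns whether to convert it:

import torch
from torchao.float8 import convert_to_float8_training

m = YourTransformer()  # placeholder for a model with nn.Linear layers

# skip modules we don't want in float8, e.g. the final output projection
def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    return fqn != "output"

convert_to_float8_training(m, module_filter_fn=module_filter_fn)
m = torch.compile(m)  # torch.compile generates the fused float8 scaling/casting kernels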

Float8 Inference #740 #819

We have introduced two new quantization APIs for Float8 inference:

  1. Float8 Weight-Only Quantization: A new quant_api float8_weight_only() has been added to apply float8 weight-only symmetric per-channel quantization to linear layers.

  2. Float8 Dynamic Activation and Weight Quantization: A new quant_api float8_dynamic_activation_float8_weight() has been introduced to apply float8 dynamic symmetric quantization to both activations and weights of linear layers. By default it uses PerTensor scaling; we have also added an option for PerRow scaling of both activations and weights. By computing scales at a finer granularity, PerRow scaling can reduce the overall quantization error and increase performance by reducing dynamic quantization overhead.

Example usage:

import torch
from torchao.quantization import quantize_, float8_weight_only, float8_dynamic_activation_float8_weight, PerRow

# Create a model
model = YourModel()

# Apply float8 weight-only quantization
quantize_(model, float8_weight_only())

# Apply float8 dynamic activation and weight quantization
quantize_(model, float8_dynamic_activation_float8_weight())

# Apply PerRow scaling to weights and activations
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))

Notes:

  • These new APIs are designed to work with PyTorch 2.5 and later versions.
  • float8_dynamic_activation_float8_weight requires CUDA devices with compute capability 8.9 or higher for hardware acceleration.

Int8 quantized training #644 #748

@gau-nernst contributed two experimental approaches to training with INT8.

  • INT8 quantized training (#644): weights are quantized to INT8 for the entire duration of training to save memory, while compute remains in high precision. To train the model effectively with only quantized weights, we use stochastic rounding for the weight update. Right now, the memory savings are not very competitive compared to a compiled BF16 baseline.
  • INT8 mixed-precision training (#748): weights are kept in the original high precision, but weights and activations are dynamically quantized to INT8 during training to utilize the INT8 tensor cores. We observe up to 70% speedup for Llama2 pre-training on a 4090, and 20% speedup for Llama3 pre-training on 8x A100 with FSDP2.

from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_weight_only_quantized_training, int8_mixed_precision_training

model = YourModel()

# apply INT8 quantized training
quantize_(model, int8_weight_only_quantized_training())

# apply INT8 mixed-precision training
quantize_(model, int8_mixed_precision_training())

For more information and benchmark results, see the README and the respective PRs (#644, #748).

HQQ Integration in torchao #605 #786

HQQ has been added to existing torchao APIs; it improves model accuracy and leverages the existing efficient kernels in torchao. We enabled HQQ for the int4_weight_only API:

quantize_(model, int4_weight_only(group_size, use_hqq=True))

We also added HQQ to the uintx API for accuracy experiments (the current uintx kernels are slow):

quantize_(model, uintx_weight_only(torch.uint2, group_size, use_hqq=True))

Automatic Mixed-Precision Quantization through Bayesian Optimization #592, #694

We provide a Bayesian optimization (BO) tool, built on Ax, that automatically searches for a mixed-precision weight-only quantization configuration, i.e., the bit width and group size of intN_weight_only(bit_width, group_size) for each layer. It also includes a sensitivity analysis tool that computes the layer-wise average Hessian trace and average Fisher information matrix trace, an optional step that can be used to customize and improve the BO search.

To optimize for model accuracy under a model size constraint (GB):

python BO_acc_modelsize.py --checkpoint=/tmp/Meta-Llama-3-8B --model_size_constraint=6.0

To optimize for inference throughput under a model perplexity constraint:

python BO_acc_throughput.py --checkpoint=/tmp/Meta-Llama-3-8B --ppl_constraint=7.5

For more detailed usage, please refer to this README. On Llama3-8B, the mixed-precision configuration found by this tool reduces model size by 20.1% with a 2.8% perplexity reduction, and improves inference throughput by 15.1% with a 3.2% perplexity reduction, compared to uniform int8 quantization.

Sparse Marlin #621, #733

@Diogo-V added support for sparse-marlin, a W4AFP16 2:4 sparse kernel, to TorchAO.
On Meta Llama 3, we observe a 25% increase in tokens/s (180 -> 226) compared to our existing int4-wo implementation.

from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

Model        Technique                  Tokens/Second   Memory Bandwidth (GB/s)   Peak Memory (GB)   Model Size (GB)
Llama-3-8B   Base (bfloat16)            95.64           1435.54                   16.43              15.01
             int8dq                     8.61            64.75                     9.24               7.52
             int8wo                     153.03          1150.80                   10.42              7.52
             int4wo-64                  180.80          763.33                    6.88               4.22
             int4wo-64-sparse-marlin    226.02          689.20                    5.32               3.05

HuggingFace Integration

torchao is integrated into Hugging Face Transformers: https://huggingface.co/docs/transformers/main/en/quantization/torchao. You can now use int4_weight_only, int8_weight_only, and int8_dynamic_activation_int8_weight through TorchAoConfig in Hugging Face. This is currently available only in the Hugging Face main branch.
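For reference, a minimal usage sketch following the linked documentation; the checkpoint name is only an example, and kwargs such as group_size are forwarded to the underlying torchao quantization function:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "meta-llama/Meta-Llama-3-8B"  # example checkpoint
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)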

SGLang Integration

torchao is also integrated into sglang (sgl-project/sglang#1341) for the Llama 3 model; you can try it out with:

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128

Supported configurations are ["int4wo-<group_size>", "int8wo", "int8dq", "fp8wo"], where "fp8wo" is only available with torchao 0.5+.

diffusers Integration

The diffusers-torchao repo provides end-to-end inference and experimental training recipes for using torchao with diffusers. We demonstrate a 53.88% speedup on Flux.1-Dev* and a 27.33% speedup on CogVideoX-5b when comparing compiled quantized models against their standard bf16 counterparts.
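As a rough sketch of the kind of recipe that repo demonstrates (the model id, the pipe.transformer attribute, and the choice of int8 weight-only quantization are assumptions for illustration; see diffusers-torchao for the exact scripts and quantization configs):

import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, int8_weight_only

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# quantize the transformer backbone, then compile it for additional speedup
quantize_(pipe.transformer, int8_weight_only())
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

image = pipe("an astronaut riding a horse on the moon", num_inference_steps=28).images[0]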

BC Breaking

Add layout option to woq int4 api #670

# torchao 0.4.0
from torchao.quantization import quantize_, int4_weight_only
quantize_(my_model, int4_weight_only(inner_k_tiles=8))

# torchao 0.5.0
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayoutType
quantize_(my_model, int4_weight_only(layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))

Refactor QAT to use tensor subclasses #585

We refactored QAT to use tensor subclasses instead of module swap. This works well with torchtune and FSDP2, but currently lacks support for FSDP1 and DDP. For these distributed training strategies, please continue to use the old module swap flow as a fallback.

# torchao 0.4.0: This uses the module swap flow
# torchao 0.5.0 + FSDP2: This uses the tensor subclass flow
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)
train(model)
model = quantizer.convert(model)

# torchao 0.5.0 + DDP or FSDP1: This uses the module swap flow
from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap
quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap()
model = quantizer.prepare(model)
train(model)
model = quantizer.convert(model)

Deprecations

New Features

  • Optimizer CPU offload for single GPU training #584
  • Add support for save quantized checkpoint in llama code #553
  • Intx quantization tensor subclass #468
  • Add superblock to sparse/prototype #660
  • Add AffineQuantizedObserver #650
  • Add BSR subclass + torch.compile and clean up superblock #680
  • Add HQQ support #605
  • Add performance profiler #690
  • Add experimental INT8 quantized training #644
  • Add high-level operator interface #708
  • Add sparse marlin 2:4 gemm op #733
  • Example for GPTQ-like calibration flow #721
  • Llama3.1 and KV cache quantization #738
  • Add float8 weight only and weight + dynamic activation #740
  • Add Auto-Round support #581

Mixed-Precision Quantization

  • Add sensitivity analysis tool for layer-wise FIT and Hessian trace #592
  • Bayesian optimization tool for mixed precision quantization #694

Improvements

  • Move sam eval from scripts to torchao/_models #591
  • QOL improvements to float8 gemm benchmark #596
  • Move lowbit universal kernels from torchaccel to torchao #582
  • Refactor autoquant to use AQT #609
  • Add support for using AffineQuantizedTensor with weights_only=True #630
  • Move Uintx out of prototype for future extension #635
  • Refactor _quantized_linear for better extensibility #634
  • Update micro benchmarking code for AQT #673
  • Refactor superblock code + add final benchmark/eval scripts #691
  • Relax QAT dtype assertion #692
  • Add option to move param to device before quantization #699
  • Add gpu benchmarking script #192
  • Enable to(device=device_name) for Uintx #722
  • Make torchao's llama model trainable #728
  • Specify output dtype to torch.float32 in _foreach_norm #727
  • Add semi-structured sparsity to hf eval #576
  • Use torch.uint1 to torch.uint7 for Uintx tensor subclass #672
  • Add AdamW to CPUOffloadOptimizer default #742
  • Make developer experience better for extending AQT #749
  • Add back QAT module swap API #762
  • Refactor quant_llm to work with affine quantized tensor #772
  • Move iOS benchmarking infra code to torchao #766
  • Add CPU bandwidth benchmark #773
  • Update method names to support intx and floatx changes #775
  • Add implementation for torchao::parallel_for backends #774
  • Add Llama2-7B finetune benchmarks for low-bit optimizers #746
  • Fix Adam4bit support on PyTorch 2.3 and 2.4 and update AdamFp8 torch requirement #755
  • Improve compile time + fix PyTorch 2.3 support for 4-bit optim #812
  • Allow quantized linear registration in a different file #783
  • Add 2bit, 5bit packing routines #797, #798
  • Freeze dataclass in nf4, prep for better pt2 support #799
  • Format and lint nf4 file and test #800
  • Move more utils to TorchAOBaseTensor #784
  • Add more information to quantized linear module and added some logs #782
  • Add int4 mode to autoquant #804
  • Add uintx quant to generate and eval #811
  • Move non-NF4 tensor to device prior to quantization on copy #737

Static quantization

  • Add float8 static quant support #787
  • Update how block_size is calculated with Observers #815
  • Add a linear observer class and test #807

Float8

  • Update benchmarks to be more useful for smaller shapes #615
  • Remove unneeded kernel for scale generation #616
  • Filter out microbenchmarking overhead in profiling script #629
  • Save torch_logs, and attach them to profiling trace #645
  • Add option for gpu time in GEMM benchmarks #666
  • Add roofline estimation of GEMM + overhead #668
  • Make roofline utils reusable #731
  • Use torch.compiler.is_compiling #739
  • Float8 support in AQT #671
  • Add static scaling for float8 training #760
  • Make roofline script calculate observed overhead #734
  • Make Inference and training code independent #808
  • Add rowwise scaling option to float8 dynamic quant #819

Bug fixes

  • Fix all-gather in 2D with DTensor (WeightWithDynamicFloat8CastTensor) #590
  • Fix FP6-LLM API and add .to(device) op #595
  • Fix linear_activation_tensor dynamic quant #622
  • Fix bug with float8 inference_mode #659
  • Quantization kernel bug fixes #717
  • Cast local_scale_tensor to fp32 for precompute of float8 dynamic scaling #713
  • Fix affine quantized tensor to device calls #726
  • Small fix for micro benchmark code #711
  • Fix LR schedule handling for low-bit optimizers #736
  • Fix FPX inductor error #790
  • Fixed llama model inference #769

Docs

  • Add QAT README #597
  • Update serialization.rst to include get_model_size_in_bytes import #604
  • Clarify details around unwrap_tensor_subclass in README.md #618, #619
  • Spelling fixes #662
  • Move developer guide file to a folder #681
  • Update docs on how to use AUTOQUANT_CACHE #649
  • Update pip install command in README #723
  • Fix docstring args names #735
  • Update README example with correct import of sparsify_ #741
  • Update main and quantization README #745, #747, #757
  • Add README for mixed-precision search tool and code refactor #776
  • Add performance section to float8 README.md #794
  • Make float8 README.md examples standalone #809
  • Add KV cache quantization to READMEs #813
  • Update main README.md with more current float8 speedup #816

Not user facing

  • Fix float8 inference tests and add export test #613
  • Reduce atol/rtol for stable tests #617
  • Fix version guard in #620, #679, #684
  • Fix BC for QAT location #626
  • Enable float8 CI on sm89 #587
  • Fix Inductor bench BC change #638, #641
  • Add CUDA compute capability compile guard #636
  • Remove numpy as bitpack dependency #677
  • Add PyTorch 2.4 tests in CI #654
  • Remove torchao_nightly package #661
  • Update licenses in torchao/experimental #720
  • Add lint checks for float8 inference #779

New Contributors

We were able to close about 70% of the tasks planned for 0.5.0; the remainder will spill over into upcoming releases. We will post the task list for 0.6.0 next, which we aim to release at the end of September 2024. We plan to follow a monthly release cadence until further notice.

Full Changelog: v0.4.0...v0.5.0-rc1