v0.4.0

@jcaip released this 07 Aug 16:48

Highlights

We are excited to announce the 0.4 release of torchao! This release adds support for KV cache quantization, quantization-aware training (QAT), low-bit optimizers, composing quantization and sparsity, and more!

KV cache quantization (#532)

We've added support for KV cache quantization, which reduces peak memory from 19.7 GB to 19.2 GB on Llama3-8B at a context length of 8192. We plan to investigate Llama3.1 next.
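
As a rough sanity check on the memory reduction reported above, the size of a KV cache can be estimated from the model shape. The sketch below assumes the published Llama3-8B configuration (32 layers, 8 KV heads, head dimension 128), batch size 1, a 16-bit baseline cache, and an 8-bit quantized cache, and it ignores the small overhead of scales and zero points. These are illustrative assumptions, not details taken from #532, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem, batch_size=1):
    """Bytes needed to hold keys and values for every layer."""
    # Factor of 2: one tensor for keys and one for values per layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# Assumed Llama3-8B shape: 32 layers, 8 KV heads (GQA), head_dim 128.
shape = dict(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=8192)

bf16_cache = kv_cache_bytes(**shape, bytes_per_elem=2)
int8_cache = kv_cache_bytes(**shape, bytes_per_elem=1)
saved_gib = (bf16_cache - int8_cache) / 2**30

print(f"16-bit cache: {bf16_cache / 2**30:.2f} GiB")
print(f"8-bit cache:  {int8_cache / 2**30:.2f} GiB")
print(f"saved:        {saved_gib:.2f} GiB")
```

Under these assumptions the cache shrinks by about 0.5 GiB, which is in line with the roughly 0.5 GB drop in peak memory.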

Quantization-Aware Training (QAT) (#383, #555)

We now support two QAT schemes for linear layers: int8 per-token dynamic activations with int4 per-group weights, and int4 per-group weight-only quantization (which uses the efficient tinygemm int4 kernel after training). Users can access this feature by transforming their models before and after training using the appropriate quantizer, for example:

from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)

# Convert fake quantize to actual quantize operations
model = qat_quantizer.convert(model)

Initial evaluation results indicate that QAT in torchao can recover up to 96% of quantized accuracy degradation on hellaswag and up to 68% of quantized perplexity degradation on wikitext for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the README and this blog post.
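
To illustrate what a "fake quantize" operation does, here is a minimal pure-Python sketch of symmetric per-group int4 quantize-dequantize. It is a simplified illustration, not the torchao implementation: the function names, the symmetric scale choice, and the group size are all assumptions made for the example.

```python
def fake_quantize_group(values, n_bits=4):
    """Quantize-dequantize one group of floats with a shared symmetric scale."""
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1  # int4: [-8, 7]
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    out = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))  # snap to the integer grid
        out.append(q * scale)  # back to float: the value never changes dtype
    return out

def fake_quantize_per_group(weights, group_size):
    """Apply fake quantization independently to consecutive groups of a row."""
    out = []
    for start in range(0, len(weights), group_size):
        out.extend(fake_quantize_group(weights[start:start + group_size]))
    return out

row = [0.0, 7.0, -3.2, 1.4, 0.5, -0.25, 0.1, 0.0]
print(fake_quantize_per_group(row, group_size=4))
```

The output stays in floating point but only takes values on each group's int4 grid, so training sees the rounding error that real quantization would introduce.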

Composing quantization and sparsity (#457, #473)

We've added support for composing int8 dynamic quantization with 2:4 sparsity, using the quantize_ API. We also added segment-anything (SAM) benchmarks that show a 7% speedup over standalone sparsity / int8 dynamic quantization.

from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
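
For readers unfamiliar with the 2:4 pattern: in every group of four consecutive weights, at most two are nonzero. The sketch below shows magnitude-based pruning to that pattern in plain Python. It is only an illustration of the sparsity pattern; `prune_2_4` is a hypothetical helper, and magnitude pruning is one common choice, not necessarily how a given model's weights were sparsified.

```python
def prune_2_4(row):
    """Zero the two smallest-magnitude entries in each group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest magnitudes; ties keep the earlier index.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned

weights = [1.0, -4.0, 2.0, 3.0, 0.5, 0.1, -0.2, 0.9]
print(prune_2_4(weights))
```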

Community Contributions

Low-bit optimizer support (#478, #463, #482, #484, #538)

@gau-nernst added implementations for 4-bit, 8-bit, and FP8 Adam with FSDP2/FSDP support. Our API is a drop-in replacement for torch.optim.Adam and can be used as follows:

from torchao.prototype.low_bit_optim import Adam8bit, Adam4bit, AdamFp8
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8


model = ...
optim = Adam8bit(model.parameters())  # or Adam4bit / AdamFp8 for the 4-bit / FP8 versions

For more information about low-bit optimizer support, please refer to our README.
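
The memory saving of a low-bit optimizer comes from storing optimizer state in fewer bits. The following is a minimal sketch of that general idea using plain block-wise absmax scaling to int8. This is a simplified stand-in chosen for illustration; the quantization scheme actually used in torchao.prototype.low_bit_optim may differ, and the function names and block size here are hypothetical.

```python
from array import array

def quantize_state(state, block_size=4):
    """Return (int8 codes, per-block scales) for a flat list of floats."""
    codes, scales = array("b"), []
    for start in range(0, len(state), block_size):
        block = state[start:start + block_size]
        max_abs = max(abs(v) for v in block)
        scale = max_abs / 127 if max_abs > 0 else 1.0
        scales.append(scale)
        codes.extend(round(v / scale) for v in block)  # each code fits in [-127, 127]
    return codes, scales

def dequantize_state(codes, scales, block_size=4):
    """Rebuild approximate float values from codes and scales."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]

# A toy first-moment buffer with two blocks of very different magnitude.
exp_avg = [127.0, -64.0, 0.0, 32.4, 0.002, -0.001, 0.0005, 0.0]
codes, scales = quantize_state(exp_avg)
restored = dequantize_state(codes, scales)
print(codes.itemsize * len(codes), "bytes of codes vs", 4 * len(exp_avg), "bytes as fp32")
```

Using one scale per block keeps the small-magnitude block from being rounded to zero by the large values in the other block.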

Improvements to 4-bit quantization (#517, #552, #544, #479)

@bdhirsh, @jeromeku, @yanbing-j, @manuelcandales, and @larryliu0820 added torch.compile support for NF4Tensor, custom CUDA int4 tinygemm unpacking ops, and several bug fixes to torchao.

BC breaking

  • quantize has been renamed to quantize_ #467
# for torchao 0.4
from torchao.quantization import quantize_, int8_weight_only
quantize_(model, int8_weight_only())

# for torchao 0.3
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())
  • apply_sparse_semi_structured has been deprecated in favor of sparsify_ which matches the quantize_ API #473
# for torchao 0.4
from torchao.sparsity import sparsify_, semi_sparse_weight
sparsify_(model, semi_sparse_weight())

# for torchao 0.3
from torchao.sparsity import apply_sparse_semi_structured
apply_sparse_semi_structured(model)
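
Code that has to run against both torchao 0.3 and 0.4 could look up whichever name is present instead of pinning one import. The helper below is a hypothetical sketch of that pattern (`first_available` is not part of torchao); it is demonstrated on a standard-library module so that it runs without torchao installed.

```python
import importlib

def first_available(module_name, candidate_names):
    """Return the first attribute of a module that exists, trying names in order."""
    module = importlib.import_module(module_name)
    for name in candidate_names:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"none of {candidate_names} found in {module_name}")

# Intended use (requires torchao to be installed):
#   quantize_fn = first_available("torchao.quantization", ["quantize_", "quantize"])

# Demonstration with a standard-library module so the snippet runs anywhere.
sqrt = first_available("math", ["no_such_name", "sqrt"])
print(sqrt(9.0))
```

Note that this only resolves the name. Where the call signature also changed, as with sparsify_(model, semi_sparse_weight()) versus apply_sparse_semi_structured(model), the caller still needs to branch on which function was found.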

Deprecations

New Features

  • Added kv_cache quantization #532
  • Migrated float8_experimental to torchao.float8, enabling float8 training support #551 #529
  • Added FP5 E2M2 #399
  • Added 4-bit, 8-bit, and FP8 Adam support #478 #463 #482
  • Added FSDP2 support for low-bit optimizers #484
  • [prototype] mixed-precision quantization and eval framework #531
  • Added int4 weight-only QAT support #555, #383
  • Added custom CUDA tinygemm unpacking ops #415

Improvements

  • Composing quantization and sparsity now uses the unified AQT Layout #498
  • Added default inductor config settings #423
  • Better dtype and device handling for Int8DynActInt4WeightQuantizer and Int4WeightOnlyQuantizer #475 #479
  • Enable model.to for int4/int8 weight only quantized models #486 #522
  • Added more logging to TensorCoreTiledAQTLayout #520
  • Added general fake_quantize_affine op with mask support #492 #500
  • QAT now uses the shared fake_quantize_affine primitive #527
  • Improved FSDP support for low-bit optimizers #538
  • Custom op and inductor decomp registration now uses a decorator #434
  • Updated torch version to no longer require unwrap_tensor_subclass #595

Bug fixes

  • Fixed import for TORCH_VERSION_AFTER_* #433
  • Fixed crash when PYTORCH_VERSION is not defined #455
  • Added torch.compile support for NF4Tensor #544
  • Added fbcode check to fix torchtune in Genie #480
  • Fixed int4pack_mm error #517
  • Fixed cuda device check #536
  • Weight shuffling now runs on CPU for int4 quantization due to a MPS memory issue #552
  • Scale and input now are the same dtype for int8 weight only quantization #534
  • Fixed FP6-LLM API #595

Performance

  • Added segment-anything-fast benchmarks for composed quantization + sparsity #457
  • Updated low-bit Adam benchmark #481

Docs

  • Updated README.md #583 #438 #445 #460
  • Updated installation instructions #447 #459
  • Added more docs for int4_weight_only API #469
  • Added developer guide notebook #588
  • Added optimized model serialization/deserialization doc #524 #525
  • Added new float8 feature tracker #557
  • Added static quantization tutorial for calibration-based techniques #487

Devs

  • Fixed numpy version in CI #537
  • trymerge now uploads merge records to s3 #448
  • Updated python version to 3.9 #488
  • torchao no longer depends on torch #449
  • benchmark_model now accepts args and kwargs and supports cpu and mps backends #586 #406
  • Added git version suffix to package name #547
  • Added validations to torchao #453 #454
  • Parallel test support with pytest-xdist #518
  • Quantizer now uses logging instead of print #472

Not user facing

  • Refactored _replace_linear_8da4w #451
  • Removed unused code from AQT implementation #476 #440 #441 #471
  • Improved error message for lm_eval script #444
  • Updated HF_TOKEN env variable #427
  • Fixed typo in Quant-LLM #450
  • Added a test for map_location="cpu" #497
  • Removed sparse test collection warning #489
  • Refactored layout implementation #491
  • Refactored LinearActQuantizedTensor #542

New Contributors

Full Changelog: v0.3.1-rc1...v0.4.0-rc1

We closed about 60% of the tasks planned for 0.4.0; the remaining tasks will spill over into upcoming releases. We will post the task list for 0.5.0 next, which we aim to release at the end of August 2024. We plan to follow a monthly release cadence until further notice.