[Dev] Refactor the range of INT Format to (-max_int_value - 1, max_int_value) (#15)

* rename transparency

* dependabot fix

* update transparency.

* update plugin

* remove redundant transparency

* dsl benchmark scripts

* update submodule.

* remove redundant code.

* remove transparency

* fix propagate map issue

* implement in-register dequantize config

* optimize target

* fix tag.

* fix some issues on ampere gaming devices

* finetune with data distribution.

* fill matmul benchmarking scripts

* refactor use_async_copy to bool value

* support af format

* format fix

* support propagate input transform for dequantization.

* update requirements

* update requirements.txt

* update af4 related tests.

* clean test

* naive support for dynamic zeros

* move to bitdistiller

* implement lop3 with zeros cpp test

* implement fast decoding with zeros

* update zero generation support.

* Bump transformers from 4.29.2 to 4.36.0

Bumps [transformers](https://github.com/huggingface/transformers) from 4.29.2 to 4.36.0.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](huggingface/transformers@v4.29.2...v4.36.0)

---
updated-dependencies:
- dependency-name: transformers
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump pillow from 9.4.0 to 10.2.0

Bumps [pillow](https://github.com/python-pillow/Pillow) from 9.4.0 to 10.2.0.
- [Release notes](https://github.com/python-pillow/Pillow/releases)
- [Changelog](https://github.com/python-pillow/Pillow/blob/main/CHANGES.rst)
- [Commits](python-pillow/Pillow@9.4.0...10.2.0)

---
updated-dependencies:
- dependency-name: pillow
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump tornado from 6.2 to 6.3.3

Bumps [tornado](https://github.com/tornadoweb/tornado) from 6.2 to 6.3.3.
- [Changelog](https://github.com/tornadoweb/tornado/blob/master/docs/releases.rst)
- [Commits](tornadoweb/tornado@v6.2.0...v6.3.3)

---
updated-dependencies:
- dependency-name: tornado
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump scipy from 1.5.3 to 1.11.1

Bumps [scipy](https://github.com/scipy/scipy) from 1.5.3 to 1.11.1.
- [Release notes](https://github.com/scipy/scipy/releases)
- [Commits](scipy/scipy@v1.5.3...v1.11.1)

---
updated-dependencies:
- dependency-name: scipy
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump jinja2 from 3.1.2 to 3.1.3

Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.2 to 3.1.3.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](pallets/jinja@3.1.2...3.1.3)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump pygments from 2.2.0 to 2.15.0

Bumps [pygments](https://github.com/pygments/pygments) from 2.2.0 to 2.15.0.
- [Release notes](https://github.com/pygments/pygments/releases)
- [Changelog](https://github.com/pygments/pygments/blob/master/CHANGES)
- [Commits](pygments/pygments@2.2.0...2.15.0)

---
updated-dependencies:
- dependency-name: pygments
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump pygments from 2.13.0 to 2.15.0

Bumps [pygments](https://github.com/pygments/pygments) from 2.13.0 to 2.15.0.
- [Release notes](https://github.com/pygments/pygments/releases)
- [Changelog](https://github.com/pygments/pygments/blob/master/CHANGES)
- [Commits](pygments/pygments@2.13.0...2.15.0)

---
updated-dependencies:
- dependency-name: pygments
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

* update requirements and matmul.

* support fast decode for int8 related items

* improve pass context

* update benchmark related figures.

* update benchmark readme

* reorganize readme

* refactor readme

* update benchmark readme

* refactor quant linear for bisect

* update tvm submodule

* fix blockIdx related

* update bitdistiller related.

* update zero type related test

* implement zero types support

* implement zero types support

* fix lop3 permutation issue.

* fix weight executor bug.

* improve typing

* resolve performance related items

* add implementation for dequantization with dynamic symbolic

* fix ladder transform related issues.

* improve ladder permutation for dequantization

* enhance dynamic symbolic for matmul_impl

* improve support for dynamic symbolic

* update tvm dependency

* implement operator cache.

* refactor print to logging

* append setup.py and remove tvm pythonpath dependency.

* update ignore

* improve installation scripts

* update scaling benchmark of 1bit

* int8xint1 lop3 support.

* replace with to_torch_func

* license related fix

* update contributing.md

* autogptq support.

* refactor docs

* refactor

* refactor docs

* typo fix

* implement disk cache

* refactor codegen to get_source

* support get weight shape.

* Update dependabot.yml

* Update dependabot.yml

* Update dependabot.yml

* Update dependabot.yml

* Update dependabot.yml

* Update requirements.txt

* Update requirements.txt

* Update requirements.txt

* refactor propagate into transform kind

* Update dependabot.yml

* implement scale and zero layout propagation

* typo fix

* refactor codes

* fix performance issue of dequantize propagate

* refactor print

* fix gemv scale bugs

* refactor ops configs

* improve tensor_adapter

* implement trick wrapper for integration

* code refactor

* SUPPORT.md commit

* spell check

* improve for linting

* overall lint improvements

* Add copyright and license information

* improve contributing

* Fix PYTHONPATH export in installation script and update BitBLAS package

* Update benchmark section in README.md

* Update performance benchmarks and integration details

* Fix typo in README.md

* Refactor index map logging in matmul_analysis.py

* Add .ruff_cache to .gitignore

* Add _tir_u32_to_f4_to_f16 function to quantization module

* Update performance benchmark images

* Update benchmark configurations

* Update benchmark information in README.md

* Refactor code for improved performance and readability

* convolution impl support

* Refactor convolution2d_impl.py and test_auto_normalized_tensorcore.py

* Fix code formatting and remove unnecessary code

* Update TensorCore GEMM Performance Comparison

* Update TensorCore GEMM performance comparison on A100 and RTX4090

* Refactor propagate_inputs method in TensorCorePolicy

* Fix BitBLAS import and remove debug print statements

* Add end-to-end integration with Quantize Inference Kernel for AutoGPTQ and vLLM

* Fix import order and handle exception in benchmark scripts

* Update TVM subproject commit

* Update TileDevice class names in bitblas package

* Update imports in roller module

* Update images

* Update images

* Update end2end_llama_13b_vllm.png

* Update trademark and acknowledgement section

* Update benchmark images for consistent GEMM operations

* Add test case for decoding UInt4 to Float16 with scaling and zeros quantized

* Remove benchmarking code for int4 on a specific target

* Update image files and add new functions for quantization and rasterization

* fix rescale and original lop3.

* Add integration example of FasterTransformers with BitBLAS

* Update integration example of FasterTransformer with BitBLAS

* Update requirements-dev.txt and requirements.txt

* Add LLVM download and extraction functionality

* Update FasterTransformer.gif

* Update BitBLAS version and requirements

* Update BitBLAS import paths and add support for installing and developing TVM

* Add GPU intrinsics module for BitBLAS

* Update requirements-dev.txt and requirements.txt

* Refactor import paths in BitBLAS GPU modules

* Update installation guide in Installation.md

* Refactor MatmulConfig class in matmul.py for improved readability and maintainability

* Refactor MatmulConfig class in matmul.py for improved readability and maintainability

* Refactor MatmulConfig class in matmul.py for improved readability and maintainability

* Update installation guide and QuickStart link in README.md

* Update installation guide and QuickStart link in README.md

* Append Default Schedule Fallback

* Refactor requirements-dev.txt and fix newline issue in arch_base.py

* Fix typo in check_mit_license.sh

* improve the target detection.

* Improve target detection and fix typos in code

* Fix auto-inline spacing issue in MatmulTensorizationMMAWithDequantizeInfo class

* Improve target detection and fix typos in code

* transform to submit

* Add support for weight_dtype transformation in MatmulWeightOnlyDequantizeConfig

* Update zeros_type to zeros_mode in code blocks

* update README

* update README

* Fix import errors and update paths in code

* Update variable names in test_bitblas_linear.py and __init__.py

* Update imports and add new function in quantization and cache modules

* Update README with support matrix table

* Update support matrix table and benchmark configurations

* Update support matrix table and benchmark configurations

* Update support matrix table and benchmark configurations

* Update support matrix table and benchmark configurations

* Update support matrix table and benchmark configurations

* Update import statements and add new functions in quantization and cache modules

* Fix default dynamic range for M in MatmulConfig

* Update support matrix table with new tested platforms and Out_dtype column

* Refactor code for mixed-precision matrix multiplication and update support matrix table

* Refactor code for mixed-precision matrix multiplication and update support matrix table

* Update MatmulConfig initialization in QuickStart.md

* Update support matrix table with new tested platforms and INT32/FP16/INT8 support

* Refactor code for mixed-precision matrix multiplication and update support matrix table

* Update link to code implementation in QuickStart.md

* Disable tuning for initial bitblas operator creation

* Update linear transformation description in PythonAPI.md

* Update MatmulConfig in PythonAPI.md

* convert af format to nf

* Enable hardware-aware tuning for bitblas operators

* Refactor code for mixed-precision matrix multiplication and update support matrix table

* Update support matrix table with new tested platforms and INT32/FP16/INT8 support

* Update OperatorConfig.md with matrix multiplication configuration details

* code refactor

* Fix capitalization in QuickStart.md

* update README

* Refactor setup.py to remove unnecessary code and improve readability

* refactor infeatures to in_features

* update README.md

* Fix incorrect data type mapping in general_matmul.py

* update doc

* Refactor variable names in bitblas_linear.py and bitblas_quant_linear.py

* uncomment some cases

* Add BITBLAS_DATABASE_PATH constant to OperatorCache and update load_global_ops_cache function

* Refactor variable names in bitblas_linear.py and bitblas_quant_linear.py

* Refactor variable names in bitblas_linear.py and bitblas_quant_linear.py

* Update dependencies in requirements-dev.txt and requirements.txt

* Refactor variable names in bitblas_linear.py and bitblas_quant_linear.py

* Fix BITBLAS_DATABASE_PATH constant assignment in OperatorCache

* Refactor variable names in bitblas_linear.py and bitblas_quant_linear.py

* Refactor variable names in bitblas_linear.py and bitblas_quant_linear.py

* update install

* Refactor variable names in setup.py and build_tvm function

* append linear benchmark scripts

* simple bug fix

* Update BitBLAS installation instructions for Ubuntu 20.04

* Refactor variable names and add output print statements for debugging

* Refactor variable names and update dependencies

* Update BitBLAS installation instructions for Ubuntu 20.04 and add note about Linux support

* Refactor logging handler and set log level in BitBLAS module

* Bump version to 0.0.1

* Implement BitNET LOP3 Test

* Refactor variable names and update dependencies

* Refactor variable names and update dependencies in quantization module

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Lingxiao Ma <[email protected]>
3 people authored Apr 15, 2024
1 parent b952851 commit 16e1f99
Showing 14 changed files with 149 additions and 63 deletions.
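
For reference, the range refactor in this PR can be summarized with a short sketch (hedged: the helper names below are illustrative, not code from the repository). For a b-bit signed INT format, max_int_value = 2**(b - 1) - 1 and the representable range becomes [-max_int_value - 1, max_int_value]; the unsigned-to-signed offset therefore becomes 2**(b - 1), which is why the diffs below move constants such as 7 to 8 (int4) and 0x64076407 to 0x64086408.

```python
# Illustrative sketch only; these helpers do not exist in the repository.
def int_format_range(bit: int) -> tuple:
    """New range of the INT format: [-max_int_value - 1, max_int_value]."""
    max_int_value = 2 ** (bit - 1) - 1
    return (-max_int_value - 1, max_int_value)

def unsigned_code_to_signed(u: int, bit: int) -> int:
    """Zero point moves from 2**(bit-1) - 1 to 2**(bit-1) (7 -> 8 for int4)."""
    return u - 2 ** (bit - 1)

assert int_format_range(4) == (-8, 7)
assert int_format_range(2) == (-2, 1)
assert [unsigned_code_to_signed(u, 4) for u in (0, 8, 15)] == [-8, 0, 7]
```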
3 changes: 1 addition & 2 deletions README.md
@@ -38,6 +38,7 @@ BitBLAS achieves exceptional performance across a variety of computational patte
</div>



- TensorCore FP16/INT8 GEMM Performance Vs. Vendor Library on A100 and RTX4090

<div>
@@ -78,7 +79,6 @@ We are continuously expanding the support matrix. If you have any specific requi

- [Customization](./docs/ExtendOperatorsWithDSL.md): BitBLAS supports implementing customized mixed-precision DNN operations rather than matrix multiplication with the flexible DSL (TIR Script).


## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
@@ -90,4 +90,3 @@ This project has adopted the Microsoft Open Source Code of Conduct. For more inf
## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.

4 changes: 2 additions & 2 deletions integration/pytorch/test_bitblas_quant_linear.py
@@ -17,7 +17,7 @@


def gen_quant4(k, n, groupsize=-1):
maxq = 2**4 - 1
maxq = 2**4
w = torch.randn((k, n), dtype=torch.half, device="cpu")

original_w = w.clone()
@@ -75,7 +75,7 @@ def test_quantization_accuracy(m, in_features, out_features, bits, group_size, b

if group_size == -1:
group_size = in_features
zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32)
zeros = torch.full((in_features // group_size, out_features), 8, dtype=torch.int32)

bitblas_zeros = zeros.clone().T
cuda_old_linear = CudaOldQuantLinear(
113 changes: 83 additions & 30 deletions python/bitblas/gpu/intrin/lop3.py
@@ -5,6 +5,7 @@
from tvm.script import tir as T
from typing import Dict, Literal
from bitblas.quantization import (
_tir_packed_int_to_int_convert,
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_packed_to_unsigned_convert_with_zeros,
@@ -19,7 +20,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
#pragma unroll
for (int i = 0; i < (N / 2); i++)
@@ -55,7 +56,7 @@
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -97,7 +98,7 @@
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -139,7 +140,7 @@
static constexpr uint BOTTOM_MASK = 0x000f000f;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
// Minus 7 to scale the value to signed
static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;
uint const i4s = *reinterpret_cast<uint *>(_i4s);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
@@ -217,7 +218,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -258,7 +259,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -300,7 +301,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -337,7 +338,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00030003;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400;
int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
// decode 2 elems at one time.
// interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
@@ -366,15 +367,15 @@
"""

decode_i1_to_f16 = """
template <typename T1, typename T2, bool isSigned = false>
__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
@@ -392,26 +393,41 @@
template <typename T1, typename T2>
__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
decode_i1b_to_f16<T1, T2, true>(_i1s, B_local_decode, N);
}
uint *h = reinterpret_cast<uint *>(B_local_decode);
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8)
{
decode_i1b_to_f16<T1, T2, false>(_i1u, B_local_decode, N);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
}
}
"""

decode_i1_to_f16_scale = """
template <typename T1, typename T2, typename T3, bool isSigned = false>
__device__ void decode_i1b_to_f16_scale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr)
template <typename T1, typename T2, typename T3>
__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -431,17 +447,41 @@
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
decode_i1b_to_f16_scale<T1, T2, T3, true>(_i1s, B_local_decode, N, scale);
}
template <typename T1, typename T2, typename T3>
__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8)
{
decode_i1b_to_f16_scale<T1, T2, T3, false>(_i1u, B_local_decode, N, scale);
uint *h = reinterpret_cast<uint *>(B_local_decode);
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
T3 const scale_r = *scale;
uint const packed_scales = __pack_half2(scale_r, scale_r);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
}
}
"""

decode_i1_to_f16_scale_zeros_original = """
template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
@@ -451,7 +491,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -491,7 +531,7 @@
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
@@ -538,12 +578,14 @@
static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1
static constexpr uint I8s_MAGIC_NUM = 0x00000000;
static constexpr uint MEDIAN_NUM = 0x00000000;
static constexpr uint TRANSFORM_SUBTRACT = 0x01010101;
for (int i = 0; i < N / 4; i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(i8s[i])
: "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsubss4(__vaddss4(i8s[i], i8s[i]), TRANSFORM_SUBTRACT);
}
}
@@ -709,7 +751,10 @@ def get_fast_decode_intrin(
if with_zeros and zeros_mode == "quantized":
decode_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)
elif source_format == "int":
decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit)
if source_bit == 1:
decode_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit)
else:
decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit)
elif source_format == "uint":
decode_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)
else:
@@ -1379,7 +1424,7 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
)

LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1389,6 +1434,14 @@ def fast_decode_impl(
source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
)


LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
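The new signed 1-bit decode paths above map each stored bit b to 2*b - 1, so {0, 1} becomes {-1, +1}. In the fp16 kernels this is done by doubling the decoded value and adding TRANSFORM_SUBTRACT = 0xbc00bc00 (a pair of fp16 -1.0 values), and in the int8 kernel via __vsubss4(__vaddss4(x, x), 0x01010101). Below is a rough NumPy reference for the numerics only, ignoring the interleaved lane layout the kernels handle.

```python
import numpy as np

def decode_i1s_reference(packed_byte: int, n: int = 8) -> np.ndarray:
    """Reference numerics for the signed 1-bit decode: bit -> 2*bit - 1.
    Sketch only; the CUDA above computes h = h + h followed by h + (-1.0)
    per fp16 lane, which is the same 2*x - 1 transform."""
    bits = np.array([(packed_byte >> i) & 1 for i in range(n)], dtype=np.float16)
    return 2 * bits - 1

print(decode_i1s_reference(0b10110010))  # [-1.  1. -1. -1.  1.  1. -1.  1.]
```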
2 changes: 1 addition & 1 deletion python/bitblas/ops/general_matmul.py
@@ -371,7 +371,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
if source_format == "int":
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1) - 1
maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
else:
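Restating the two changed lines above as a runnable snippet (a sketch of behaviour only, not new repository code): with the refactored range, an "int"-format weight is shifted by maxq = 2**(bit - 1) before packing, so the stored codes are non-negative.

```python
import torch

bit = 4
maxq = 2 ** (bit - 1)  # now 8; previously 2**(bit - 1) - 1 == 7
weight = torch.tensor([-8.0, -1.0, 0.0, 7.0], dtype=torch.float16)
# Same clamp-and-shift as transform_weight above for the "int" source format.
stored = torch.clamp(weight, -maxq, maxq).int() + maxq
print(stored.tolist())  # [0, 7, 8, 15]
```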
46 changes: 32 additions & 14 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -7,6 +7,7 @@
from bitblas.ops.operator import TransformKind
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.quantization import (
_tir_packed_int_to_int_convert,
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_u32_to_f4_to_f16,
@@ -76,8 +77,13 @@ def decode_func(n, k):
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "int":
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
Expand All @@ -91,6 +97,8 @@ def decode_func(n, k):
else:
raise ValueError("Unsupported source_format: {}".format(source_format))



if not with_scaling:
return w

@@ -236,12 +244,17 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "int":
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit,
@@ -417,12 +430,17 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "int":
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
if bit == 1:
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit,
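The added branches above pick _tir_packed_int_to_int_convert for 1-bit weights and _tir_packed_to_signed_convert otherwise. As a reference point only — the helper bodies are not part of this diff, so the element-wise semantics below are an assumption based on the added comments and the constants elsewhere in this commit:

```python
def dequant_int_element(storage_word: int, elem_index: int, bit: int) -> int:
    """Assumed element-wise meaning of the two "int" decode helpers used above.
    The real helpers build TIR expressions; this is plain Python for illustration."""
    mask = (1 << bit) - 1
    u = (storage_word >> (elem_index * bit)) & mask
    if bit == 1:
        # _tir_packed_int_to_int_convert: dequantize int1 to -1 and 1
        return 2 * u - 1
    # _tir_packed_to_signed_convert: offset-binary code minus 2**(bit - 1)
    return u - (1 << (bit - 1))

assert [dequant_int_element(0b0110, i, bit=1) for i in range(4)] == [-1, 1, 1, -1]
assert dequant_int_element(0xF, 0, bit=4) == 7
assert dequant_int_element(0x8, 0, bit=4) == 0
```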