[Dev] Support Numeric Precision BFloat16 as activation type (#148)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as versions below Python 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix

* Merge block reduce for dequantize config.

* fix codeql

* chore: Update submodule reference to latest commit

* chore: Disable common subexpression elimination in TIR passes

* Lint Fix

* 4bit related lop3 updates.

* lint fix

* gptq test fix

* Fix for test

* lint fix

* lint fix

* typofix

* QuantCompress Test

* chore: Refactor quant_compress_impl.py for readability and maintainability

* Enhance docs to update latest works.

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* removed legacy operator

* Refactor weight executors in Matmul class for improved readability and maintainability

* LintFix

* Fix GPTQ Repack with the latest weight transform

* lint fix

* bug fix for rescale dequantize

* test fix

* typo fix

* lint fix

* Set default weight propagate kind into LDMatrixTransform

* lint fix

* bug fix

* bug fix for test

* set default to stage3

* revert change

* lint fix

* case fix

* bug fix

* fix for legalize

* bug fix

* chore: Clear global operator cache before running tests

* revert optimize_stratety into SingleBatchDecodeOnly

* typofix

* update benchmark scripts

* chore: Refactor benchmark scripts and fix typos

* fix for testing

* lint fix

* fix import.

* typo

* operator benchmark

* optimize

* always with shared.dyn

* optimize cache.

* dsl fix

* tqdm

* chore: Add serialize_results method to benchmark_matmul_strategies.py

* fix performance issue for dynamic async copy

* chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability

* bug fix

* update readme

* disable block reduce for int8

* bugfix for bitnet

* annotate todo.

* lint fix

* regist fast_decode for int8xint4

* Refactor CUDA code to use sm architecture instead of compute architecture

* compress qkv and gate up for bitnet

* improve elementwise schedule

* Refactor BitNet model checkpoint generation scripts

* cross thread reduce for tl

* fix scale only lop3 tensorize instructions.

* bug fix for scale only case

* fix scale for warp memory dequantize

* lint fix

* bug fix

* format

* fix repack from gptqv2

* chore: Enable large files for Hugging Face models

* bump version to dev14

* BF 16 Update

* lint fix

* chore: Update BitBLAS benchmark scripts and fix typos

* chore: Add gptqmodel to test requirements

* remove gptqmodel dep for test

* chore: Remove gptqmodel dependency for test
LeiWang1999 authored Aug 23, 2024
1 parent ef28a5d commit 673290b
Showing 19 changed files with 410 additions and 46 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -61,6 +61,14 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and

| **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS Support** | **Tested Platform** |
|:-----------:|:-----------:|:---------------:|:--------------------:|:-------------------:|:----------------------------------------------------:|
| BF16 | BF16 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | INT8 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | UINT4/INT4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | UINT2/INT2 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | UINT1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| BF16 | NF4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
| FP16 | FP16 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
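The BF16 rows added above can be exercised once this commit is installed. Below is a minimal usage sketch (not part of this diff) that builds a BF16-activation, INT4-weight operator; the `MatmulConfig`/`Matmul` names follow `bitblas/ops/general_matmul/__init__.py` in this commit, while the shapes and dtypes are arbitrary example values.

```python
# Hypothetical sketch, not taken from this commit: BF16 activation x INT4 weight,
# mirroring the new BF16 rows in the support matrix above.
import bitblas
import torch

config = bitblas.MatmulConfig(
    M=1,                    # rows of the activation
    N=1024,                 # output features
    K=1024,                 # input features
    A_dtype="bfloat16",     # activation dtype enabled by this commit
    W_dtype="int4",         # quantized weight dtype
    accum_dtype="float32",  # FP32/FP16 accumulation per the table
    out_dtype="float16",    # output dtype
)
matmul = bitblas.Matmul(config=config)

# Weights are packed into BitBLAS' internal layout before execution.
weight = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()
packed_weight = matmul.transform_weight(weight)

activation = torch.randn(1, 1024, dtype=torch.bfloat16).cuda()
output = matmul(activation, packed_weight)
```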
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.1.dev13
0.0.1.dev14
2 changes: 1 addition & 1 deletion bitblas/__init__.py
@@ -112,4 +112,4 @@ def new_func(*args, **kwargs):
return decorator


__version__ = "0.0.1.dev13"
__version__ = "0.0.1.dev14"
26 changes: 16 additions & 10 deletions bitblas/base/utils.py
@@ -20,6 +20,8 @@
import itertools
from tvm.ir.supply import GlobalVarSupply
from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils.tensor_adapter import (
np_float2np_bf16,)
import logging

logger = logging.getLogger(__name__)
@@ -149,17 +151,21 @@ def map_numpy_type(intype):

numpy_dtype = map_numpy_type(arg.dtype)
if distribution == "uniform":
profile_tensors.append(
tvm.nd.array(
np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
data_np = np.random.rand(*[var_wrapper(i) for i in arg.shape])
if arg.dtype == "bfloat16":
profile_tensors.append(
tvm.nd.empty(data_np.shape, device=device, dtype=arg.dtype).copyfrom(
np_float2np_bf16(data_np.astype(np.float32))))
else:
profile_tensors.append(tvm.nd.array(data_np.astype(numpy_dtype), device=device))
elif distribution == "onefill":
profile_tensors.append(
tvm.nd.array(
np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
data_np = np.ones(*[var_wrapper(i) for i in arg.shape])
if arg.dtype == "bfloat16":
profile_tensors.append(
tvm.nd.empty(data_np.shape, device=device,
dtype=arg.dtype).copyfrom(np_float2np_bf16(data_np)))
else:
profile_tensors.append(tvm.nd.array(data_np.astype(numpy_dtype), device=device))
else:
raise ValueError("Not supported distribution: ", distribution)
return profile_tensors
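NumPy has no native bfloat16, so the branch above materializes BF16 profile tensors by bit-truncating float32 data with `np_float2np_bf16` and copying the uint16 payload into a TVM NDArray declared as bfloat16. A minimal sketch of that pattern, assuming a CUDA device and an arbitrary 16x16 shape:

```python
# Sketch of the bfloat16 branch above: float32 data -> uint16 bf16 payload
# -> TVM NDArray typed as bfloat16.
import numpy as np
import tvm
from bitblas.utils.tensor_adapter import np_float2np_bf16

device = tvm.cuda(0)
data_np = np.random.rand(16, 16).astype(np.float32)
bf16_nd = tvm.nd.empty(data_np.shape, device=device, dtype="bfloat16")
bf16_nd.copyfrom(np_float2np_bf16(data_np))  # copy the raw uint16 bits in
```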
2 changes: 1 addition & 1 deletion bitblas/builder/wrapper/tir.py
@@ -18,7 +18,7 @@ class TIRCUDASourceWrapper(object):
_TYPE_MAP = {
"float32": "float",
"float16": "half",
"bfloat16": "__nv_bfloat162",
"bfloat16": "__nv_bfloat16",
"e4m3_float8": "__nv_fp8_e4m3",
"e5m2_float8": "__nv_fp8_e5m2",
"float64": "double",
6 changes: 4 additions & 2 deletions bitblas/gpu/gemv_dequantize.py
@@ -55,7 +55,8 @@ def check_weight_decode_info(weight_decode_info):
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
conditions.append(
weight_decode_info["target_format"] in ["float16", "bfloat16", "int8"])
return all(conditions)

if not check_weight_decode_info(weight_decode_info):
@@ -223,7 +224,8 @@ def check_weight_decode_info(weight_decode_info):
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
conditions.append(
weight_decode_info["target_format"] in ["float16", "bfloat16", "int8"])
return all(conditions)

if not check_weight_decode_info(weight_decode_info):
3 changes: 2 additions & 1 deletion bitblas/gpu/intrin/lop3.py
@@ -1626,7 +1626,8 @@ def get_lop3_intrin_group(
Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations.
"""
assert out_dtype in ["float16", "int8"]
assert out_dtype in ["float16",
"int8"], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8'.")

dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"}
target_dtype = dtype_mapping[out_dtype]
19 changes: 11 additions & 8 deletions bitblas/gpu/matmul_analysis.py
@@ -624,7 +624,7 @@ def check_last_trait(region: List[Range]):
# When the func is a dequantize like ops, we should consider the M
require_block_reduce = False
# And we only support float16 for now
if hasattr(func.attrs, "dequantize_info") and in_dtype == "float16":
if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
for arg in func.params:
inp_shape = func.buffer_map[arg].shape
M = inp_shape[0]
@@ -690,12 +690,14 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
)

assert dtype in [
"bfloat16",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
if dtype == "float16":
], "Only support bfloat16, float16, int8, e4m3_float8, e5m2_float8"
# TODO(lei): actually should analyze based on bits instead of dtype
if dtype in ["bfloat16", "float16"]:
ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout
ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
@@ -723,7 +725,7 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j):
local_id = kernel_j % 16
return ldmatrix_layout(thread_id, local_id)

if dtype == "float16":
if dtype in ["bfloat16", "float16"]:
ldmatrix_index_map = (
ldmatrix_trans_permutation_16x16_32x8_16x16
if trans else ldmatrix_permutation_16x16_32x8_16x16)
@@ -732,7 +734,7 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j):

ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype)
# TODO(lei): index_dtype should be analyzed from the schedule
row, col = [16, 16] if dtype == "float16" else [16, 32]
row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32]
inversed_index_map = ldmatrix_index_map.inverse([row, col])
return ldmatrix_index_map, inversed_index_map

@@ -753,12 +755,13 @@ def shared_32x16_to_mma_32x16_layout(i, j):
return thread_id, local_id

assert dtype in [
"bfloat16",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
if dtype == "float16":
if dtype in ["bfloat16", "float16"]:
stage3_layout = shared_32x8_to_mma_32x8_layout
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
stage3_layout = shared_32x16_to_mma_32x16_layout
Expand All @@ -782,14 +785,14 @@ def ladder_stage3_permutation_16x32_32x16_32x16_16x32(kernel_i, kernel_j):
new_kernel_j = (new_thread_id * 16 + new_local_id) % 32
return new_kernel_i, new_kernel_j

if dtype == "float16":
if dtype in ["bfloat16", "float16"]:
stage3_index_map = ladder_stage3_permutation_16x16_32x8_32x8_16x16
else:
stage3_index_map = ladder_stage3_permutation_16x32_32x16_32x16_16x32

stage3_index_map = IndexMap.from_func(stage3_index_map, index_dtype=index_dtype)
# TODO(lei): index_dtype should be analyzed from the schedule
row, col = [16, 16] if dtype == "float16" else [16, 32]
row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32]
inversed_index_map = stage3_index_map.inverse([row, col])
return stage3_index_map, inversed_index_map

3 changes: 2 additions & 1 deletion bitblas/gpu/matmul_mma_dequantize.py
@@ -1237,7 +1237,8 @@ def check_weight_decode_info(weight_decode_info):
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
conditions.append(
weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
return all(conditions)

assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"
44 changes: 44 additions & 0 deletions bitblas/module/__init__.py
@@ -39,6 +39,25 @@ def unpack_qzeros(qzeros, bits):
return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)


# For gptqv2 from gptqmodel
def unpack_qzeros_v2(qzeros, bits):
qzeros = qzeros.view(torch.int32)
elems_per_int32 = 32 // bits
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for col in range(unpacked_zeros.shape[1]):
i = col % elems_per_int32
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i))

# Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303
# NOTE: It appears that casting after the `unpacked_zeros + 1` is important.
return torch.bitwise_and(unpacked_zeros, 2**bits - 1)


def unpack_qweight(qweight, bits):
qweight = qweight.view(torch.int8)
elems_per_int8 = 8 // bits
@@ -318,6 +337,31 @@ def repack_from_gptq(self, gptq_module):
if self.bias is not None:
self.bias = gptq_module.bias.data.to(torch.float16).contiguous()

def repack_from_gptq_v2(self, gptq_module):
# qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
intweight = unpack_qweight(qweight, self.bits).contiguous()
if self.bitblas_matmul.weight_transform is not None:
qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
self.qweight = qweight
# scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
self.scales = scales
# qzeros should be dequantized to int zeros.
intzeros = unpack_qzeros_v2(gptq_module.qzeros, self.bits).T.contiguous()
if self.bitblas_matmul.config.zeros_mode == "original":
self.zeros = intzeros.to(torch.float16).contiguous()
elif self.bitblas_matmul.config.zeros_mode == "rescale":
self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :]
elif self.bitblas_matmul.config.zeros_mode == "quantized":
self.zeros = (
torch.Tensor(general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits)).to(
self.qweight.device).to(self.zeros.dtype).contiguous())
else:
raise ValueError(f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}")
if self.bias is not None:
self.bias = gptq_module.bias.data.to(torch.float16).contiguous()

@property
def consistent(self):
return self.is_consitent
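For reference, `unpack_qzeros_v2` above simply undoes the nibble packing used by GPTQv2-style checkpoints, without the `+1` offset applied in the older `unpack_qzeros`. A small self-contained check of the 4-bit case, with made-up zero-point values:

```python
# Illustrative only: pack eight 4-bit zero-points into one int32 column,
# then recover them with unpack_qzeros_v2 (defined in bitblas/module/__init__.py).
import torch
from bitblas.module import unpack_qzeros_v2

bits = 4
zeros = torch.tensor([[3, 1, 4, 1, 5, 9, 2, 6]], dtype=torch.int32)
packed = torch.zeros((1, 1), dtype=torch.int32)
for i in range(zeros.shape[1]):
    packed |= zeros[:, i] << (bits * i)  # nibble i occupies bits [4i, 4i+4)

print(unpack_qzeros_v2(packed, bits))
# tensor([[3, 1, 4, 1, 5, 9, 2, 6]], dtype=torch.int8)
```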
13 changes: 12 additions & 1 deletion bitblas/ops/general_matmul/__init__.py
@@ -30,6 +30,7 @@
("float64", "float64"),
("float32", "float32"),
("float16", "float16"),
("bfloat16", "bfloat16"),
("int8", "int8"),
("e4m3_float8", "e4m3_float8"),
("e4m3_float8", "e5m2_float8"),
@@ -140,7 +141,7 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind],

# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
if self.A_dtype in ["e4m3_float8", "e5m2_float8", "bfloat16"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)

@@ -159,6 +160,9 @@ def is_not_fast_decoding_supported():
# if the w_dtype is int4/uint4 and the a_dtype is int8
# we do not require fast decoding
conditions.append(self.W_dtype in ["int4", "uint4"] and self.A_dtype in ["int8"])
# do not support bfloat16 currently
# TODO(lei): should implement to improve the performance
conditions.append(self.A_dtype == "bfloat16")
return any(conditions)

if fast_decoding is not None:
@@ -214,6 +218,7 @@ def __post_init__(self):

if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16",
"bfloat16",
"int8",
"e4m3_float8",
"e5m2_float8",
@@ -228,6 +233,7 @@ class Matmul(Operator):
"float64": ("fp", 64),
"float32": ("fp", 32),
"float16": ("fp", 16),
"bfloat16": ("bf", 16),
"int32": ("int", 32),
"uint32": ("uint", 32),
"int16": ("int", 16),
@@ -260,8 +266,13 @@ def __init__(
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")

assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"

assert (config.W_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported weight dtype {config.W_dtype}"

source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]

self.source_format = source_format
2 changes: 1 addition & 1 deletion bitblas/ops/ladder_permutate/ladder_permutate_impl.py
@@ -12,7 +12,7 @@
def select_implementation(
M: int,
N: int,
datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16",
datatype: Literal["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"] = "float16",
dequantize_bits: int = -1,
storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16",
propagate_kind: Literal["A", "B"] = "B",
25 changes: 20 additions & 5 deletions bitblas/utils/tensor_adapter.py
@@ -91,11 +91,11 @@ def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
else:
raise RuntimeError("Not supported type: ", type(tensor))


def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
# It additionally needs the ctypes type as torch type
def as_tensor(address, shape, elems_inbytes, torch_type):
arr = (ctypes.c_int8 * elems_inbytes).from_address(
address)
arr = (ctypes.c_int8 * elems_inbytes).from_address(address)
return torch.frombuffer(arr, dtype=torch_type).view(*shape)

if isinstance(tensor, tvm.nd.NDArray):
@@ -110,21 +110,36 @@ def as_tensor(address, shape, elems_inbytes, torch_type):
else:
raise RuntimeError("Not supported type: ", type(tensor))


def lazy_torch_to_tvm_tensor(tensor):
# It additionally needs the ctypes type as torch type
def as_tensor(address, shape, elems_inbytes, numpy_type):
arr = (ctypes.c_int8 * elems_inbytes).from_address(
address)
arr = (ctypes.c_int8 * elems_inbytes).from_address(address)
return np.frombuffer(arr, dtype=numpy_type).reshape(shape)

if isinstance(tensor, torch.Tensor):
data_ptr = tensor.data_ptr()
shape = tensor.shape
torch_dtype = tensor.dtype
numpy_dtype = str(torch_dtype).replace("torch.", "")
num_elems_inbytes = prod(shape) * tensor.itemsize
num_elems_inbytes = prod(shape) * tensor.itemsize
np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype)
tvm_tensor = tvm.nd.array(np_tensor)
return tvm_tensor
else:
raise RuntimeError("Not supported type: ", type(tensor))


def np_float2np_bf16(arr):
"""Convert a numpy array of float to a numpy array
of bf16 in uint16"""
orig = arr.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
return np.right_shift(orig + bias, 16).astype("uint16")


def np_bf162np_float(arr):
"""Convert a numpy array of bf16 (uint16) to a numpy array
of float"""
u32 = np.left_shift(arr.astype("uint32"), 16)
return u32.view("<f4")
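The two helpers above are inverses up to rounding: `np_float2np_bf16` rounds float32 to the 8-bit bf16 significand (round-to-nearest-even) and stores the bits as uint16, while `np_bf162np_float` re-expands them to float32. A quick round-trip check, illustrative only:

```python
# Round trip float32 -> bf16 (uint16 payload) -> float32; the error stays
# within bf16's roughly 2^-8 relative precision.
import numpy as np
from bitblas.utils.tensor_adapter import np_float2np_bf16, np_bf162np_float

x = np.array([1.0, 3.14159265, 1e-3], dtype=np.float32)
bf16 = np_float2np_bf16(x)       # dtype uint16
x_back = np_bf162np_float(bf16)  # dtype float32
print(np.allclose(x, x_back, rtol=2.0**-8))  # True
```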
2 changes: 1 addition & 1 deletion bitblas/wrapper/general.py
@@ -20,7 +20,7 @@
_TYPE_MAP = {
"float32": "float",
"float16": "half",
"bfloat16": "__nv_bfloat162",
"bfloat16": "__nv_bfloat16",
"e4m3_float8": "__nv_fp8_e4m3",
"e5m2_float8": "__nv_fp8_e5m2",
"float64": "double",
2 changes: 2 additions & 0 deletions integration/BitNet/maint/upload_models.sh
100644 → 100755
@@ -32,6 +32,8 @@ git commit -m "Initial commit"

git remote add origin $REMOTE_DIR

huggingface-cli lfs-enable-largefiles .

git fetch origin

git push -f --set-upstream origin main