[FP8] Support FP8 MatrixCore Code gen and related test (#29)
* Add Str Parse library to requirements.txt and requirements-dev.txt

* Support quantized zero types for uint2.

* Support FP8 Codegen

* Add support for e4m3_float8 and e5m2_float8 types in CUDA wrapper

* Support FP8
LeiWang1999 authored Apr 29, 2024
1 parent d536dde commit c01a3a7
Showing 16 changed files with 194 additions and 39 deletions.
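The headline change is native FP8 (e4m3_float8/e5m2_float8) MatrixCore code generation for same-size A/W operands. As a rough orientation before the per-file diffs, a minimal usage sketch follows; the field names mirror `python/bitblas/ops/general_matmul.py` in this commit, but the exact `MatmulConfig`/`Matmul` constructor signature and the PyTorch conversion details are assumptions, not part of this diff.

```python
# Hypothetical sketch of exercising the new FP8 path (not taken from this commit's tests).
import torch
import bitblas

config = bitblas.MatmulConfig(
    M=16,
    N=1024,
    K=1024,
    A_dtype="e4m3_float8",   # BitBLAS spelling; maps to torch/numpy float8_e4m3fn
    W_dtype="e4m3_float8",
    accum_dtype="float32",
    out_dtype="float16",
)
matmul = bitblas.Matmul(config=config)

# torch.float8_e4m3fn requires a recent PyTorch (>= 2.1).
A = torch.randn(16, 1024, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
W = torch.randn(1024, 1024, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
C = matmul(A, W)  # FP16 output, accumulated in FP32
```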
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 933 files
13 changes: 13 additions & 0 deletions README.md
@@ -13,6 +13,17 @@ Some of the key features of BitBLAS include:
- BitBLAS first implemented $W_{INT2}A_{INT8}$ GEMV/GEMM in [BitNet-b1.58](https://arxiv.org/abs/2402.17764) with 8x/2x speedup over cuBLAS $W_{FP16}A_{FP16}$ on A100; please check out [op_benchmark_a100_int2_scaling](https://github.com/microsoft/BitBLAS/blob/main/images/figures/op_benchmark_a100_int2_scaling.png) for detailed benchmark results. Please check out [BitNet-b1.58 integration](https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet) for the integration with the third-party reproduced BitNet-b1.58 model.
- Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script).

## Latest News

- 2024.04.19: BitBLAS is now open source! We are excited to announce that BitBLAS, a high-performance library for mixed-precision DNN model deployment, is now available to the public.
- 2024.04.30: BitBLAS now supports FP8 ($W_{FP8}A_{FP8}$) MatrixCore code generation (e4m3_float8 and e5m2_float8, currently on SM_89 GPUs such as the RTX 4090).

## Integration Example of FasterTransformer with BitBLAS
![FasterTransformer Integration](images/gif/FasterTransformer.gif)

## Benchmark Summary


## Integration Example of FasterTransformer with BitBLAS
![FasterTransformer Integration](images/gif/FasterTransformer.gif)

@@ -63,6 +74,8 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
| INT8 | UINT4/INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| INT8 | UINT2/INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) |
| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) |

We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR.

17 changes: 14 additions & 3 deletions python/bitblas/base/utils.py
@@ -135,16 +135,27 @@ def var_wrapper(v):
else:
raise ValueError("Not supported type: ", type(func))

def map_numpy_type(intype):
typemap = {
'e4m3_float8': 'float8_e4m3fn',
'e5m2_float8': 'float8_e5m2',
}
if intype in typemap:
return typemap[intype]
else:
return intype

numpy_dtype = map_numpy_type(arg.dtype)
if distribution == "uniform":
profile_tensors.append(
tvm.nd.array(
np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(arg.dtype),
np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
elif distribution == "onefill":
profile_tensors.append(
tvm.nd.array(
np.ones([var_wrapper(i) for i in arg.shape]).astype(arg.dtype),
np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
else:
@@ -245,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _):
try:
latency = cpresult.profile()
except Exception as e_mesg:
logger.debug("Evaluation with config failed: ", e_mesg)
logger.debug(f"Evaluation with config failed {e_mesg}")
continue
logger.info("Evaluation with config {}".format(config))
logger.info("Time cost of this config: {:.3f} ms".format(latency))
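The new `map_numpy_type` helper exists because TVM spells the FP8 types `e4m3_float8`/`e5m2_float8` while NumPy-side they are named `float8_e4m3fn`/`float8_e5m2`. A minimal sketch of the idea, assuming the `ml_dtypes` package (not named in this diff) is what makes NumPy recognize those names:

```python
import numpy as np
import ml_dtypes  # assumption: registers float8_e4m3fn / float8_e5m2 with NumPy

# map_numpy_type("e4m3_float8") -> "float8_e4m3fn", so the profiling code can do:
x = np.random.rand(4).astype("float8_e4m3fn")
print(x.dtype)  # float8_e4m3fn
```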
7 changes: 3 additions & 4 deletions python/bitblas/gpu/gemv.py
@@ -64,10 +64,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):


def get_bytes(dtype: Union[DataType, str]) -> int:
num = re.findall(r"\d+", dtype)
if len(num) != 1:
raise ValueError(f"Cannot get bytes from {dtype}")
return int(num[0]) // 8
if isinstance(dtype, str):
dtype = DataType(dtype)
return int(dtype.bits) // 8


def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
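The `get_bytes` rewrite is needed because the old regex collected every digit in the dtype name, which breaks on FP8 spellings; reading `DataType.bits` sidesteps that. An illustration, assuming the bundled TVM build recognizes the `e4m3_float8` dtype string:

```python
import re
from tvm import DataType

print(re.findall(r"\d+", "e4m3_float8"))  # ['4', '3', '8'] -> the old code raised ValueError
print(DataType("e4m3_float8").bits // 8)  # 1 byte, via the new code path
print(DataType("float16").bits // 8)      # 2 bytes, unchanged for existing dtypes
```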
21 changes: 15 additions & 6 deletions python/bitblas/gpu/matmul_analysis.py
@@ -512,7 +512,7 @@ def get_tensorized_func_and_tags(
allow_gemv: bool = False,
) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel
get_wmma_intrin_group,)
get_mma_intrin_group,)
"""
transform function to matmul if necessary (e.g. transform conv2d with im2col)
"""
@@ -607,14 +607,18 @@ def check_last_trait(region: List[Range]):

block_stmt = sch.get(main_block)
if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
# TODO(lei): we should consider the dtype of the input a and b
# instead of assuming both a and b share the same dtype.
# As the tensorcore may supports e4m3_float8 * e5m2_float8
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
try:
_ = get_wmma_intrin_group(
in_dtype=in_dtype,
_ = get_mma_intrin_group(
a_dtype=in_dtype,
b_dtype=in_dtype,
out_dtype=out_dtype,
)
except Exception:
logger.debug("Cannot find the corresponding wmma intrin group")
logger.debug("Cannot find the corresponding mma intrin group")
return func, None

# reindex and transform functions
@@ -651,11 +655,16 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b,
)

assert dtype in ["float16", "int8"], "Only support float16 for now"
assert dtype in [
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
if dtype == "float16":
ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout
ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
elif dtype == "int8":
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
# int8 mma only support 32x16 to 16x32 layout
if matrix_name == "A" and trans is False:
ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a
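Switching from `get_wmma_intrin_group` to `get_mma_intrin_group` with separate `a_dtype`/`b_dtype` is what leaves room for the mixed products (e.g. e4m3_float8 x e5m2_float8) mentioned in the TODO above. A minimal probe in the same style as the updated pass; the keyword names are taken from this diff and may differ in other TVM versions:

```python
from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

try:
    get_mma_intrin_group(
        a_dtype="e4m3_float8",
        b_dtype="e4m3_float8",
        out_dtype="float32",
    )
    tensorizable = True
except Exception:
    tensorizable = False  # fall back to non-tensorized scheduling, as the pass does
```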
6 changes: 4 additions & 2 deletions python/bitblas/gpu/matmul_mma.py
@@ -303,7 +303,8 @@ def store_output(block_outer, write_buffer_idx):
intrin_group = get_mma_intrin_group(
load_scope="shared.dyn",
store_scope="shared.dyn",
in_dtype=str(dtype_a),
a_dtype=str(dtype_a),
b_dtype=str(dtype_b),
out_dtype=str(dtype_c),
trans_a=is_transpose_a,
trans_b=is_transpose_b,
@@ -396,7 +397,8 @@ def check_has_dynamic(func: tir.PrimFunc):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
9 changes: 6 additions & 3 deletions python/bitblas/gpu/matmul_mma_dequantize.py
@@ -167,7 +167,8 @@ def check_weight_decode_info(weight_decode_info):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
@@ -654,7 +655,8 @@ def check_weight_decode_info(weight_decode_info):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
@@ -1143,7 +1145,8 @@ def check_weight_decode_info(weight_decode_info):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
36 changes: 32 additions & 4 deletions python/bitblas/ops/general_matmul.py
@@ -23,6 +23,24 @@

WORKSPACE_SIZE = 1024 * 1024 * 256

# TODO(lei): This should be improved into a general
# Method to get the consistent compute patterns.
NATIVE_COMPUTE_PATTERNS = [
# A_dtype, W_dtype
("float64", "float64"),
("float32", "float32"),
("float16", "float16"),
("int8", "int8"),
("e4m3_float8", "e4m3_float8"),
("e4m3_float8", "e5m2_float8"),
("e5m2_float8", "e4m3_float8"),
("e5m2_float8", "e5m2_float8"),
]


def is_native_compute(A_dtype, W_dtype) -> bool:
return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS


class OPExecutorCPU:

@@ -150,8 +168,15 @@ def __post_init__(self):
if self.with_zeros is None:
object.__setattr__(self, "with_zeros", False)

if self.A_dtype == self.W_dtype and self.W_dtype in ["float16", "int8"]:
if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16", "int8", "e4m3_float8", "e5m2_float8"
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)
# TODO(lei): This is a limitation arose by pytorch
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)


class Matmul(Operator):
@@ -176,6 +201,8 @@ class Matmul(Operator):
"nf4": ("nf", 4),
"fp8_e5m2": ("fp", 8),
"fp4_e2m1": ("fp", 4),
"e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn"
"e5m2_float8": ("fp", 8),
}

def __init__(
@@ -316,7 +343,7 @@ def _build_default_module(self, target: Target):
self._build_runtime_module(target)

def _select_implementation(self):
if self.A_dtype == self.W_dtype:
if is_native_compute(self.A_dtype, self.W_dtype):
return consistent_implementation(
M=self.M,
N=self.N,
@@ -446,8 +473,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args.append(bias)
args.append(output)

m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)
if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

if self.lib is None:
self._forward_from_torch_func(*args)
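`NATIVE_COMPUTE_PATTERNS` / `is_native_compute` replace the old `A_dtype == W_dtype` check so that mixed FP8 pairs also dispatch to the consistent (non-dequantize) implementation. Illustrative assertions; the non-native case is presumed to keep going through the dequantize implementations as before:

```python
from bitblas.ops.general_matmul import is_native_compute

assert is_native_compute("e4m3_float8", "e4m3_float8")
assert is_native_compute("e4m3_float8", "e5m2_float8")  # mixed FP8 counts as native compute
assert not is_native_compute("float16", "uint4")        # still handled by the dequantize path
```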
4 changes: 2 additions & 2 deletions python/bitblas/ops/impl/ladder_permutate_impl.py
@@ -9,7 +9,7 @@
def select_implementation(
M: int,
N: int,
datatype: Literal["float16", "int8"] = "float16",
datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16",
dequantize_bits: int = -1,
storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16",
propagate_kind: Literal["A", "B"] = "B",
@@ -23,7 +23,7 @@ def select_implementation(
# This is trick to get the basic tile size for the current datatype
# as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8
l = r = 16 # noqa: E741
if datatype == "int8":
if datatype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741
intra_index_map, _ = get_propagate_map(
transpose_matrix, dtype=datatype, matrix_name=propagate_kind)
6 changes: 2 additions & 4 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -97,8 +97,6 @@ def decode_func(n, k):
else:
raise ValueError("Unsupported source_format: {}".format(source_format))



if not with_scaling:
return w

@@ -187,7 +185,7 @@ def matmul_nt_dequantize_b_propagate_b(
M = tvm.te.var("m")

l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

_, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B")
@@ -358,7 +356,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
M = tvm.te.var("m")

l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741
_, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A")
A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype)
6 changes: 3 additions & 3 deletions python/bitblas/ops/impl/matmul_impl.py
@@ -111,7 +111,7 @@ def matmul_nt_propagate_a(
if not isinstance(M, int):
M = tvm.te.var("m")
l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

_, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A")
@@ -171,7 +171,7 @@ def matmul_nt_propagate_b(
if not isinstance(M, int):
M = tvm.te.var("m")
l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

_, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B")
@@ -232,7 +232,7 @@ def matmul_nt_propagate_a_propagate_b(
if not isinstance(M, int):
M = tvm.te.var("m")
l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype)
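The repeated `l, r` edits across these kernels all encode the same rule: ladder/propagate layouts use 16x16 tiles for 16-bit inputs and 16x32 tiles for 8-bit inputs, and both FP8 variants now fall in the 8-bit bucket. A small sketch of the resulting placeholder shape (the helper below is hypothetical; the shape follows the `A` placeholder in `matmul_nt_propagate_a_propagate_b`):

```python
def ladder_tile(in_dtype: str):
    # Restates the rule from this diff: 16x16 for 16-bit dtypes,
    # 16x32 for 8-bit dtypes (int8, e4m3_float8, e5m2_float8).
    l = r = 16  # noqa: E741
    if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
        l, r = 16, 32
    return l, r

M, K = 256, 1024
l, r = ladder_tile("e4m3_float8")
print((M // l, K // r, l, r))  # (16, 32, 16, 32): shape of the propagated A placeholder
```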
2 changes: 1 addition & 1 deletion python/bitblas/ops/impl/param_permutate_impl.py
@@ -24,7 +24,7 @@ def select_implementation(
# This is trick to get the basic tile size for the current datatype
# as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8
l = r = 16 # noqa: E741
if datatype == "int8":
if datatype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741
if group_size == -1:
group_size = N
14 changes: 13 additions & 1 deletion python/bitblas/ops/operator.py
@@ -220,15 +220,27 @@ def var_warpper(v):
else:
raise RuntimeError("Not supported type: ", type(v))

def map_numpy_type(intype):
typemap = {
'e4m3_float8': 'float8_e4m3fn',
'e5m2_float8': 'float8_e5m2',
}
if intype in typemap:
return typemap[intype]
else:
return intype

profile_tensors = []
for param in func.params:
if param not in func.buffer_map:
# in case of dynamic symbolic may in params
continue
arg = func.buffer_map[param]
numpy_dtype = map_numpy_type(arg.dtype)
profile_tensors.append(
tvm.nd.array(
np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype(arg.dtype),
np.random.uniform(0, 1,
[var_warpper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
self.profile_tensors = profile_tensors
3 changes: 2 additions & 1 deletion python/bitblas/relax/transform/weight_only_propagate.py
@@ -130,7 +130,8 @@ def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info):
intrin_group = get_mma_intrin_group(
load_scope="shared",
store_scope="shared",
in_dtype=intrin_info["in_dtype"],
a_dtype=intrin_info["in_dtype"],
b_dtype=intrin_info["in_dtype"],
out_dtype=intrin_info["out_dtype"],
trans_a=False,
trans_b=intrin_info["trans_b"],