From f4e15a5a64f4fc9a9703c4fc316a416d0f74ae1b Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 4 Jul 2024 13:50:57 +0900 Subject: [PATCH] [Dev] Refactor scripts based on our new directory structure (#69) * chore: Update support matrix in README * Move bitblas package to root * Remove unused code files * Create soft link for tvm * Create soft link for tvm * Update softlink paths for tvm in setup.py * Refactor import statements to use relative paths * fix test linear * Move bitblas package to root * Move bitblas package to root --- bitblas/__init__.py | 14 +- bitblas/base/roller/arch/cpu.py | 2 +- bitblas/base/roller/arch/cuda.py | 2 +- bitblas/base/roller/node.py | 2 +- bitblas/base/roller/policy/default.py | 2 +- bitblas/base/roller/policy/tensorcore.py | 2 +- bitblas/base/transform.py | 2 +- bitblas/base/utils.py | 2 +- bitblas/cache/operator.py | 2 +- bitblas/gpu/intrin/lop3.py | 2 +- bitblas/gpu/rmsnorm.py | 2 +- bitblas/module/__init__.py | 21 +- bitblas/ops/general_matmul.py | 7 +- .../ops/impl/batch_matmul_dequantize_impl.py | 2 +- bitblas/ops/impl/batch_matmul_impl.py | 2 +- bitblas/ops/impl/convolution2d_impl.py | 2 +- bitblas/ops/impl/matmul_dequantize_impl.py | 2 +- .../ops/impl/matmul_dequantize_splitk_impl.py | 2 +- bitblas/ops/impl/matmul_impl.py | 2 +- bitblas/ops/impl/matmul_splitk_impl.py | 2 +- bitblas/ops/matmul.py | 2 +- bitblas/ops/matmul_dequantize.py | 2 +- bitblas/ops/operator.py | 2 +- bitblas/quantization/quantization.py | 2 +- bitblas/testing/__init__.py | 3 +- bitblas/utils/tensor_adapter.py | 2 +- bitblas/wrapper/general.py | 2 +- install.sh | 2 +- maint/scripts/installation.sh | 6 +- setup.py | 5 + testing/python/cache/test_operator_cache.py | 249 ++---------------- testing/python/module/test_bitblas_linear.py | 2 +- .../operators/test_general_matmul_fp8.py | 67 ++--- .../operators/test_general_matmul_ops.py | 154 +++++------ .../test_general_matmul_splitk_ops.py | 2 +- .../operators/test_ladder_permutate_ops.py | 2 +- .../operators/test_lop3_permutate_ops.py | 2 +- .../operators/test_matmul_dequantize_ops.py | 2 +- .../operators/test_param_permutate_ops.py | 2 +- .../transform/test_weight_only_transform.py | 4 +- .../type_conversion/int4b_fp16_convert.py | 4 +- .../test_lop3_type_conversion.py | 2 +- .../python/weight_only/index_map_deduce.py | 4 +- testing/python/weight_only/index_map_fuse.py | 2 +- .../python/weight_only/inverse_index_map.py | 2 +- 45 files changed, 181 insertions(+), 423 deletions(-) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 172c4cbf1..c2f44ae05 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -5,17 +5,18 @@ # installing tvm install_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm", "python") + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: - os.environ["PYTHONPATH"] = install_tvm_path + ":" + os.environ.get("PYTHONPATH", "") - sys.path.insert(0, install_tvm_path) + os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, install_tvm_path + "/python") develop_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm", "python") + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: - os.environ["PYTHONPATH"] = develop_tvm_path + ":" + 
os.environ.get("PYTHONPATH", "") - sys.path.insert(0, develop_tvm_path) + os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, develop_tvm_path + "/python") +import tvm as tvm # noqa: E402 from . import gpu # noqa: F401 from .base import ( TileDevice, # noqa: F401 @@ -30,6 +31,7 @@ try_inline_contiguous_spatial, # noqa: F401 ) + from . import testing # noqa: F401 from .utils import auto_detect_nvidia_target # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 diff --git a/bitblas/base/roller/arch/cpu.py b/bitblas/base/roller/arch/cpu.py index 98fb14af5..65592cc7d 100644 --- a/bitblas/base/roller/arch/cpu.py +++ b/bitblas/base/roller/arch/cpu.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.target import Target from .arch_base import TileDevice diff --git a/bitblas/base/roller/arch/cuda.py b/bitblas/base/roller/arch/cuda.py index 2189947e7..8af1e0c8e 100644 --- a/bitblas/base/roller/arch/cuda.py +++ b/bitblas/base/roller/arch/cuda.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.target import Target from .arch_base import TileDevice from typing import List, Dict, Union diff --git a/bitblas/base/roller/node.py b/bitblas/base/roller/node.py index 8e20440bb..c9d648019 100644 --- a/bitblas/base/roller/node.py +++ b/bitblas/base/roller/node.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. """PrimFunc Wrapper and Block information Analaysis""" -import tvm +from bitblas import tvm from tvm import tir from tvm.tir import IterVar, PrimFunc from typing import Any, Dict, List, Tuple, Optional diff --git a/bitblas/base/roller/policy/default.py b/bitblas/base/roller/policy/default.py index 81aeba123..730c8336f 100644 --- a/bitblas/base/roller/policy/default.py +++ b/bitblas/base/roller/policy/default.py @@ -7,7 +7,7 @@ from typing import Iterable, Dict, List, Optional import numpy as np -import tvm +from bitblas import tvm from ..arch import TileDevice from ..bestfit import BestFit diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index f4047ef08..ae45b5893 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """Policy for tensorcore schedule""" -import tvm +from bitblas import tvm from typing import Dict, List, Tuple, Optional import numpy as np diff --git a/bitblas/base/transform.py b/bitblas/base/transform.py index 647efa772..ec2cbc1e7 100644 --- a/bitblas/base/transform.py +++ b/bitblas/base/transform.py @@ -9,7 +9,7 @@ import shutil import tempfile import os.path as osp -import tvm +from bitblas import tvm from tvm import tir from tvm import meta_schedule as ms from tvm.ir import IRModule diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 50adc135f..4cd82fa93 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import tvm +from bitblas import tvm import os from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind from concurrent.futures import ThreadPoolExecutor, as_completed diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 9b30a6200..0c41ab686 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -8,7 +8,7 @@ import tempfile from hashlib import sha256 import shutil -import tvm +from bitblas import tvm from tvm.contrib.tar import tar import logging diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index b5426cf59..dc4bb587b 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.tir.function import TensorIntrin from tvm.script import tir as T from typing import Dict, Literal diff --git a/bitblas/gpu/rmsnorm.py b/bitblas/gpu/rmsnorm.py index 6e6d3e247..0d8b37998 100644 --- a/bitblas/gpu/rmsnorm.py +++ b/bitblas/gpu/rmsnorm.py @@ -21,7 +21,7 @@ # pylint: disable=missing-docstring """A RMS norm schedule rule for GPU operators.""" -import tvm +from bitblas import tvm from tvm import tir from tvm.tir import Block, BufferStore from tvm.tir.expr import Cast, BufferLoad, Call diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index f353228a5..a3fbba213 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -231,21 +231,24 @@ def warmup(self, topk=20): def forward(self, A, output=None): if A.dtype != torch.float16: A = A.half() - # can be lifted to post init. - self.init_params() - - if output is None: - output = torch.empty( - A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) - m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) A = self.bitblas_matmul.transform_input(A) stream = torch.cuda.current_stream() A_void = ctypes.c_void_p(A.data_ptr()) stream_handle = ctypes.c_void_p(stream.cuda_stream) + # can be lifted to post init. + self.init_params() + args = [A_void, *self.q_params] + if output is None: + output = torch.empty( + A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) + args.append(ctypes.c_void_p(output.data_ptr())) + if self.bitblas_matmul.dynamic_range is not None: + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) + args.append(stream_handle) # m is the product of the last n - 1 dimensions of A - self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, - stream_handle) + self.bitblas_matmul.lib.call(*args) return output diff --git a/bitblas/ops/general_matmul.py b/bitblas/ops/general_matmul.py index af2da3f02..97dd7d13f 100644 --- a/bitblas/ops/general_matmul.py +++ b/bitblas/ops/general_matmul.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.target import Target import operator from functools import reduce @@ -89,7 +89,10 @@ def __legalize_propagate(self, propagate): def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]): MICRO_KERNEL_SIZE = 16 - if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and + if propagate_b is not None and propagate_b == TransformKind.NonTransform: + # Currently we do not support propagate_a when propagate_b is not transformed. 
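The bitblas/module/__init__.py hunk above rewrites Linear.forward so the output pointer, the optional dynamic m value, and the stream handle are appended to a single argument list before the ctypes call. A condensed sketch of that call-assembly logic, assuming a wrapper object exposing lib.call, q_params, and dynamic_range as in the hunk:

    import ctypes
    import operator
    from functools import reduce

    import torch

    def call_matmul_kernel(matmul, A, out_features, output=None):
        # Argument order expected by the generated kernel:
        # input, quantized-weight params, output, optional dynamic m, CUDA stream.
        args = [ctypes.c_void_p(A.data_ptr()), *matmul.q_params]
        if output is None:
            output = torch.empty(A.shape[:-1] + (out_features,), dtype=A.dtype, device=A.device)
        args.append(ctypes.c_void_p(output.data_ptr()))
        if matmul.dynamic_range is not None:
            # m is the product of all but the last dimension of A.
            args.append(reduce(operator.mul, A.shape[:-1], 1))
        args.append(ctypes.c_void_p(torch.cuda.current_stream().cuda_stream))
        matmul.lib.call(*args)
        return output
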
+ object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + elif (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and (self.K % MICRO_KERNEL_SIZE) == 0): object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) else: diff --git a/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/bitblas/ops/impl/batch_matmul_dequantize_impl.py index a3ab5ebef..6303f4bf8 100644 --- a/bitblas/ops/impl/batch_matmul_dequantize_impl.py +++ b/bitblas/ops/impl/batch_matmul_dequantize_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te, DataType from tvm.tir import IndexMap from bitblas.ops.operator import TransformKind diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 1828ed15d..09b536afa 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te from bitblas.ops.operator import TransformKind diff --git a/bitblas/ops/impl/convolution2d_impl.py b/bitblas/ops/impl/convolution2d_impl.py index d77d8f573..c7d21d7c8 100644 --- a/bitblas/ops/impl/convolution2d_impl.py +++ b/bitblas/ops/impl/convolution2d_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te, tir diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index d4aa02c84..417eaacfc 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te, DataType from tvm.tir import IndexMap from bitblas.ops.operator import TransformKind diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index afe241b65..d3ef23187 100644 --- a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, diff --git a/bitblas/ops/impl/matmul_impl.py b/bitblas/ops/impl/matmul_impl.py index 69b426354..b093f0d9c 100644 --- a/bitblas/ops/impl/matmul_impl.py +++ b/bitblas/ops/impl/matmul_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.ops.operator import TransformKind diff --git a/bitblas/ops/impl/matmul_splitk_impl.py b/bitblas/ops/impl/matmul_splitk_impl.py index c437f64cb..c314fa6ca 100644 --- a/bitblas/ops/impl/matmul_splitk_impl.py +++ b/bitblas/ops/impl/matmul_splitk_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
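In the general_matmul.py hunk above, propagate_a is now forced off whenever propagate_b is explicitly NonTransform, and the 16x16 micro-kernel divisibility check is only consulted otherwise. The decision roughly reduces to the following sketch (TransformKind values as used in the hunk):

    MICRO_KERNEL_SIZE = 16

    def choose_propagate_a(M, K, propagate_b, TransformKind):
        # Input-side layout propagation is only meaningful when the weight side
        # is itself transformed.
        if propagate_b == TransformKind.NonTransform:
            return TransformKind.NonTransform
        # Otherwise require a static M and K divisible by the 16x16 micro kernel.
        if isinstance(M, int) and M % MICRO_KERNEL_SIZE == 0 and K % MICRO_KERNEL_SIZE == 0:
            return TransformKind.IntraWarpTransform
        return TransformKind.NonTransform
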
# pre-transformed tir expression of matmul -import tvm +from bitblas import tvm from tvm import te from bitblas.ops.operator import TransformKind diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py index 7783c4972..34014abb9 100644 --- a/bitblas/ops/matmul.py +++ b/bitblas/ops/matmul.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm import numpy as np from tvm.target import Target from bitblas.utils.tensor_adapter import tvm_tensor_to_torch diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py index 25c68b121..7381b3f12 100644 --- a/bitblas/ops/matmul_dequantize.py +++ b/bitblas/ops/matmul_dequantize.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.target import Target from bitblas.base.roller.arch.cuda import CUDA from typing import Any, List, Literal, Optional, Tuple, Union diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 90930d6d3..d1555a674 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from abc import ABC, abstractmethod -import tvm +from bitblas import tvm from tvm import IRModule from tvm.target import Target from tvm.tir import PrimFunc diff --git a/bitblas/quantization/quantization.py b/bitblas/quantization/quantization.py index 71ef224d7..f6fc75b4e 100644 --- a/bitblas/quantization/quantization.py +++ b/bitblas/quantization/quantization.py @@ -21,7 +21,7 @@ # pylint: disable=invalid-name,missing-function-docstring,unused-variable """TIR computation utilities for quantization.""" -import tvm +from bitblas import tvm from tvm import tir diff --git a/bitblas/testing/__init__.py b/bitblas/testing/__init__.py index 24f896bd8..92a43b470 100644 --- a/bitblas/testing/__init__.py +++ b/bitblas/testing/__init__.py @@ -5,7 +5,8 @@ import pytest from bitblas.base import DefaultPolicy, TensorCorePolicy from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags - +from bitblas import tvm # pylint: disable=import-error +from tvm.testing.utils import * # pytest.main() wrapper to allow running single test file def main(): diff --git a/bitblas/utils/tensor_adapter.py b/bitblas/utils/tensor_adapter.py index 55b80d138..d4d052dbb 100644 --- a/bitblas/utils/tensor_adapter.py +++ b/bitblas/utils/tensor_adapter.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from typing import Union from enum import IntEnum import numpy as np diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py index 58aa8d226..1271329f1 100644 --- a/bitblas/wrapper/general.py +++ b/bitblas/wrapper/general.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from typing import Optional, List, Dict, Union from tvm import IRModule from bitblas import TileDevice diff --git a/install.sh b/install.sh index 584392820..4affa1da6 100755 --- a/install.sh +++ b/install.sh @@ -21,6 +21,6 @@ echo "set(USE_LLVM llvm-config-10)" >> config.cmake && echo "set(USE_CUDA ON)" > cmake .. && make -j && cd ../../.. 
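The setup.py hunk below adds create_softlink(tvm_path="../3rdparty/tvm/python/tvm", bitblas_path="bitblas/tvm") after TVM is built. A plausible sketch of such a helper, assuming it merely creates or refreshes a relative symlink; the repository's actual implementation may differ:

    import os

    def create_softlink(tvm_path: str, bitblas_path: str) -> None:
        # Hypothetical sketch: link bitblas/tvm -> ../3rdparty/tvm/python/tvm so that
        # `from bitblas import tvm` resolves to the vendored TVM python package.
        if os.path.islink(bitblas_path):
            os.remove(bitblas_path)
        if not os.path.exists(bitblas_path):
            os.symlink(tvm_path, bitblas_path, target_is_directory=True)
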
echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc -echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:\$PYTHONPATH" >> ~/.bashrc +echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc source ~/.bashrc diff --git a/maint/scripts/installation.sh b/maint/scripts/installation.sh index 8e083326a..4affa1da6 100755 --- a/maint/scripts/installation.sh +++ b/maint/scripts/installation.sh @@ -3,8 +3,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# install torch -pip install torch==2.1.0 +# install requirements +pip install -r requirements.txt # install llvm apt-get install llvm-10 @@ -21,6 +21,6 @@ echo "set(USE_LLVM llvm-config-10)" >> config.cmake && echo "set(USE_CUDA ON)" > cmake .. && make -j && cd ../../.. echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc -echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:\$PYTHONPATH" >> ~/.bashrc +echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc source ~/.bashrc diff --git a/setup.py b/setup.py index b011f006d..6a88fc7bf 100644 --- a/setup.py +++ b/setup.py @@ -209,6 +209,8 @@ def run(self): build_tvm(llvm_path) # Continue with the standard installation process install.run(self) + # Create softlink for bitblas + create_softlink(tvm_path="../3rdparty/tvm/python/tvm", bitblas_path="bitblas/tvm") class BitBLASBuilPydCommand(build_py): @@ -222,6 +224,9 @@ def run(self): _, llvm_path = setup_llvm_for_tvm() # Build TVM build_tvm(llvm_path) + # Create softlink for bitblas + create_softlink(tvm_path="../3rdparty/tvm/python/tvm", bitblas_path="bitblas/tvm") + # Copy the built TVM to the package directory TVM_PREBUILD_ITEMS = [ "3rdparty/tvm/build/libtvm_runtime.so", diff --git a/testing/python/cache/test_operator_cache.py b/testing/python/cache/test_operator_cache.py index fcb863f9a..51a155f7f 100644 --- a/testing/python/cache/test_operator_cache.py +++ b/testing/python/cache/test_operator_cache.py @@ -4,15 +4,11 @@ import os import torch import bitblas -from bitblas.ops.matmul import Matmul, MatmulConfig -from bitblas.ops.matmul_dequantize import ( - MatmulWeightOnlyDequantize, - MatmulWeightOnlyDequantizeConfig, -) +from bitblas import Matmul, MatmulConfig from bitblas.cache import global_operator_cache target = bitblas.utils.auto_detect_nvidia_target() - +bitblas.set_log_level("DEBUG") def get_codegen_result(ops, target): code = ops.get_source(target=target) @@ -23,10 +19,10 @@ def get_codegen_result(ops, target): @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", [ - (1, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), + (1, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), # dynamic shape - ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), - ([1, 32], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", True), + ([1], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), + ([1, 32], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", True), ], ) def test_config_hashable( @@ -47,7 +43,7 @@ def test_config_hashable( M=M, N=N, K=K, - in_dtype=in_dtype, + A_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, with_bias=with_bias, @@ -75,10 +71,10 @@ def test_config_hashable( @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", [ - (1, 16384, 16384, 
"float16", "float16", "float16", False, False, False, "nt", False), + (1, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), # dynamic shape - ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), - ([1, 32], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", True), + ([1], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), + ([1, 32], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", True), ], ) def test_global_cache_inquery( @@ -99,7 +95,7 @@ def test_global_cache_inquery( M=M, N=N, K=K, - in_dtype=in_dtype, + A_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, with_bias=with_bias, @@ -128,10 +124,10 @@ def test_global_cache_inquery( @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", [ - (1, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), + (1, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), # dynamic shape - ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), - ([1, 32], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", True), + ([1], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), + ([1, 32], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", True), ], ) def test_global_cache_inquery_torch_forward( @@ -152,7 +148,7 @@ def test_global_cache_inquery_torch_forward( M=M, N=N, K=K, - in_dtype=in_dtype, + A_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, with_bias=with_bias, @@ -197,16 +193,15 @@ def test_global_cache_inquery_torch_forward( permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) else: permuted_inputs.append(inputs[1]) - permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs) - torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + bitblas_out = matmul(permuted_inputs[0], permuted_inputs[1]) + torch.testing.assert_close(bitblas_out, ref_result, rtol=1e-2, atol=1e-2) @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", [ - (1, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), - ([1, 32], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), + (1, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), + ([1, 32], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", False), ], ) def test_global_cache_save_to_database( @@ -227,7 +222,7 @@ def test_global_cache_save_to_database( M=M, N=N, K=K, - in_dtype=in_dtype, + A_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, with_bias=with_bias, @@ -249,7 +244,7 @@ def test_global_cache_save_to_database( print(hash_error) assert success - database_path = "debug/test_database" + database_path = "/tmp/.tmp_bitblas_cache.db" global_operator_cache.save_into_database(database_path, target=target) assert os.path.exists(database_path) global_operator_cache.clear() @@ -280,208 +275,8 @@ def test_global_cache_save_to_database( permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) else: permuted_inputs.append(inputs[1]) - permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs) - torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize( - 
"M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", - [ - ( - 1, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "uint", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - False, - True, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - True, - True, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - True, - False, - -1, - False, - False, - True, - True, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - True, - False, - 128, - False, - False, - True, - True, - "nt", - ), - ], -) -def test_matmul_dequantize_save_into_database( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - propagate_a, - propagate_b, - layout, -): - - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - matmul.hardware_aware_finetune(topk=20) - database_path = "debug/test_database" - success = False - - try: - global_operator_cache.add(matmul.config, matmul) - success = True - except Exception as hash_error: - print(hash_error) - assert success - global_operator_cache.save_into_database(database_path, target=target) - assert os.path.exists(database_path) - global_operator_cache.clear() - assert global_operator_cache.size() == 0 - global_operator_cache.load_from_database(database_path, target=target) - assert global_operator_cache.size() > 0 + bitblas_output = matmul(*permuted_inputs) + torch.testing.assert_close(bitblas_output, ref_result, rtol=1e-2, atol=1e-2) # fmt: on diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index edab105ab..eeaf90475 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -9,7 +9,7 @@ import pytest torch.manual_seed(0) - +bitblas.set_log_level("DEBUG") @pytest.mark.parametrize( "m, in_features, out_features, bias", diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 603a57248..0514b9209 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -1,6 +1,4 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-import pytest +import torch import bitblas from bitblas import MatmulConfig, Matmul import logging @@ -9,23 +7,7 @@ set_log_level(logging.DEBUG) -# TODO(lei): should add requirements for cuda and sm version -@pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, - None, None), - (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, - None, None, None), - (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, - None, None), - (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, - None, None, None), - ], -) -def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, - group_size, with_scaling, with_zeros, zeros_mode): - import torch +def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): torch.random.manual_seed(0) matmul_config = MatmulConfig( @@ -49,7 +31,6 @@ def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, weight_shape = (N, K) if layout == "nt" else (K, N) def map_torch_type(intype): - typemap = { 'e4m3_float8': torch.float8_e4m3fn, 'e5m2_float8': torch.float8_e5m2, @@ -63,8 +44,8 @@ def map_torch_type(intype): numpytype_b = map_torch_type(W_dtype) numpytype_c = map_torch_type(out_dtype) - torch_a = torch.rand(M * K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() - torch_b = torch.rand(N * K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() + torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() ref_out = torch.matmul(torch_a.to(torch.float32), torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul( torch_a.to(torch.float32), torch_b.to(torch.float32)) @@ -75,25 +56,14 @@ def map_torch_type(intype): bitblas_out = matmul(torch_a, new_torch_b) print("bitblas_out", bitblas_out) +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_matmul_torch_forward(): + matmul_torch_forward(1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, None) + matmul_torch_forward(1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, None) + matmul_torch_forward(1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, None) + matmul_torch_forward(1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, None) -# TODO(lei): should add requirements for cuda and sm version -@pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, - None, None), - (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, - None, None), - (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, - None), - (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, - None, None), - ], -) -def test_matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtype, 
out_dtype, - layout, with_bias, group_size, with_scaling, - with_zeros, zeros_mode): - import torch +def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): torch.random.manual_seed(0) matmul_config = MatmulConfig( @@ -119,7 +89,6 @@ def test_matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum weight_shape = (N, K) if layout == "nt" else (K, N) def map_torch_type(intype): - typemap = { 'e4m3_float8': torch.float8_e4m3fn, 'e5m2_float8': torch.float8_e5m2, @@ -144,9 +113,8 @@ def map_torch_type(intype): group_size = -1 if group_size == -1: group_size = K - scale_tensor = torch.rand(N * K // group_size).uniform_(-4, 4).reshape( + scale_tensor = torch.rand(N * K // group_size).uniform_(-1, 1).reshape( [N, K // group_size]).type(torch.float16).cuda() - # scale_tensor = torch.ones([N, K // group_size]).type(torch.float16).cuda() rescale_b = torch.zeros_like(torch_b).type(torch.float16) for i in range(K): rescale_b[:, i] = torch_b.to(torch.float16)[:, i] * scale_tensor[:, i // group_size] @@ -168,9 +136,12 @@ def map_torch_type(intype): torch.testing.assert_close(ref_out, bitblas_out, rtol=1e-1, atol=1e-1) +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_matmul_torch_forward_weight_dequantize(): + matmul_torch_forward_weight_dequantize(1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, None, None) + matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, None, None) + matmul_torch_forward_weight_dequantize(1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, None) + matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, None) -# fmt: on if __name__ == "__main__": - # bitblas.testing.main() - test_matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, - None, None) + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index bb0b06719..6baa4d434 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
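The FP8 tests above replace pytest.mark.parametrize with plain helper functions that are invoked from decorated test entry points; because bitblas/testing/__init__.py now re-exports tvm.testing.utils, requires_cuda_compute_version can skip them on GPUs older than SM 8.9. The pattern in sketch form (the helper body is elided):

    import bitblas

    def matmul_fp8_case(M, N, K):
        # Builds a MatmulConfig / Matmul for e4m3/e5m2 inputs and checks the output
        # against a float32 torch.matmul reference (omitted in this sketch).
        ...

    @bitblas.testing.requires_cuda_compute_version(8, 9)
    def test_matmul_fp8_cases():
        # An explicit call list replaces the old parametrize table.
        matmul_fp8_case(1, 1024, 1024)
        matmul_fp8_case(1024, 1024, 1024)
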
-import pytest import bitblas from bitblas import MatmulConfig, Matmul import logging @@ -13,34 +12,8 @@ def get_codegen_result(ops): code = ops.get_source() return code - # fmt: off -@pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), - (768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, - None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, - "original"), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, - "original"), - ], -) -def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, +def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): matmul_config = MatmulConfig( @@ -61,33 +34,22 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtyp matmul = Matmul(config=matmul_config, enable_tuning=False) assert get_codegen_result(matmul) +def test_matmul_codegen_default(): + matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), + matmul_codegen_default(1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + matmul_codegen_default(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + matmul_codegen_default(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, "original"), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + matmul_codegen_default(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, "original"), -@pytest.mark.parametrize( - 
"M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), - (768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, - None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, - "original"), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, - "original"), - ], -) -def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + +def matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): matmul_config = MatmulConfig( @@ -110,28 +72,30 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo assert get_codegen_result(matmul) -@pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 1024, 1024, "float16", "int4", "float16", "float16", "nt", None, None, None, None, - None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, - None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), - (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, - "original"), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, - None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, +def test_matmul_finetune(): + matmul_finetune(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), - (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, - "original"), - ], -) -def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + matmul_finetune(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + matmul_finetune(1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + matmul_finetune(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + matmul_finetune(1, 768, 768, 
"float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + matmul_finetune(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, + None), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, + None), + matmul_finetune(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + + +def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): import torch torch.random.manual_seed(0) @@ -213,25 +177,33 @@ def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, if with_bias: permuted_inputs.append(bias) permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs) - print(permuted_inputs[-1]) - print(ref_result) + matmul(*permuted_inputs[:2], output=permuted_inputs[-1]) if zeros_mode == "rescale": torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e-0) else: torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e-1) -@pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,with_bias", - [ - (1, 768, 768, "float16", "uint4", "float16", "float16", False), - (1, 768, 768, "float16", "int4", "float16", "float16", False), - (768, 768, 768, "float16", "uint4", "float16", "float16", False), - (768, 768, 768, "float16", "int4", "float16", "float16", False), - ], -) -def test_matmul_transform_weight( +def test_matmul_torch_forward(): + matmul_torch_forward(1, 1024, 1024, "float16", "int4", "float16", "float16", "nt", None, None, None, None, + None) + matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None) + matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + matmul_torch_forward(1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original") + matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None) + matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, + None) + matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, + None) + matmul_torch_forward(768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original") + + +def matmul_transform_weight( M, N, K, @@ -281,6 +253,12 @@ def test_matmul_transform_weight( torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0) +def test_matmul_transform_weight(): + matmul_transform_weight(1, 768, 768, "float16", "uint4", "float16", "float16", False) + matmul_transform_weight(1, 768, 768, "float16", "int4", "float16", "float16", False) + matmul_transform_weight(768, 768, 768, "float16", "uint4", "float16", "float16", False) + matmul_transform_weight(768, 768, 768, "float16", "int4", "float16", "float16", False) + # fmt: on if __name__ == "__main__": bitblas.testing.main() diff --git 
a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index 12fbbcabe..307308bf1 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -143,7 +143,7 @@ def map_torch_type(intype): torch_a.to(torch.float32), torch_b.to(torch.float32)) ref_out = ref_out.to(torch.float16) bitblas_out = torch.empty_like(ref_out) - matmul.forward(torch_a, torch_b, output=bitblas_out) + matmul.forward(torch_a, torch_b) print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) diff --git a/testing/python/operators/test_ladder_permutate_ops.py b/testing/python/operators/test_ladder_permutate_ops.py index ed8586e09..b302edcc7 100644 --- a/testing/python/operators/test_ladder_permutate_ops.py +++ b/testing/python/operators/test_ladder_permutate_ops.py @@ -3,7 +3,7 @@ import pytest import bitblas from bitblas.ops.ladder_permutate import LadderPermutate, LadderPermutateConfig -import tvm +from bitblas import tvm target = tvm.target.Target("llvm") diff --git a/testing/python/operators/test_lop3_permutate_ops.py b/testing/python/operators/test_lop3_permutate_ops.py index 55dde1174..0f4965a5f 100644 --- a/testing/python/operators/test_lop3_permutate_ops.py +++ b/testing/python/operators/test_lop3_permutate_ops.py @@ -4,7 +4,7 @@ import bitblas from bitblas.ops.lop3_permutate import LOP3Permutate, LOP3PermutateConfig -import tvm +from bitblas import tvm target = tvm.target.Target("llvm") # fmt: off diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py index 12fc8364e..c727dc7e8 100644 --- a/testing/python/operators/test_matmul_dequantize_ops.py +++ b/testing/python/operators/test_matmul_dequantize_ops.py @@ -7,7 +7,7 @@ MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig, ) -import tvm +from bitblas import tvm import logging from bitblas import set_log_level diff --git a/testing/python/operators/test_param_permutate_ops.py b/testing/python/operators/test_param_permutate_ops.py index 9149c15e9..eb85e34cb 100644 --- a/testing/python/operators/test_param_permutate_ops.py +++ b/testing/python/operators/test_param_permutate_ops.py @@ -4,7 +4,7 @@ import bitblas from bitblas.ops.param_permutate import ParamPermutate, ParamPermutateConfig -import tvm +from bitblas import tvm target = tvm.target.Target("llvm") diff --git a/testing/python/transform/test_weight_only_transform.py b/testing/python/transform/test_weight_only_transform.py index 4f4860bc9..11d7ce6d7 100644 --- a/testing/python/transform/test_weight_only_transform.py +++ b/testing/python/transform/test_weight_only_transform.py @@ -4,8 +4,8 @@ from tvm.script import tir as T from tvm.script import relax as R -import tvm -import tvm.testing +from bitblas import tvm +import tvm.testing from tvm import relax from tvm.script import ir as I, relax as R, tir as T from tvm import tir diff --git a/testing/python/type_conversion/int4b_fp16_convert.py b/testing/python/type_conversion/int4b_fp16_convert.py index 92f5f46a8..43474d062 100644 --- a/testing/python/type_conversion/int4b_fp16_convert.py +++ b/testing/python/type_conversion/int4b_fp16_convert.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
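The remaining hunks here and below continue the same theme: every `import tvm` becomes `from bitblas import tvm`, so test and library modules resolve the TVM instance that bitblas/__init__.py placed on sys.path (via the bitblas/tvm softlink) rather than whatever tvm happens to be installed site-wide. In condensed form:

    # Importing bitblas first prepends the vendored 3rdparty/tvm/python to sys.path,
    # so this re-export and any later `from tvm... import ...` hit the same package.
    from bitblas import tvm
    from tvm.target import Target

    target = Target("llvm")  # same pattern the permutate/op tests use
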
-import tvm +from bitblas import tvm import torch import numpy as np -import tvm.testing +import tvm.testing from tvm.script import tir as T import os from tvm import te diff --git a/testing/python/type_conversion/test_lop3_type_conversion.py b/testing/python/type_conversion/test_lop3_type_conversion.py index e434c8a95..d5853a108 100644 --- a/testing/python/type_conversion/test_lop3_type_conversion.py +++ b/testing/python/type_conversion/test_lop3_type_conversion.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.script import tir as T import bitblas from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy diff --git a/testing/python/weight_only/index_map_deduce.py b/testing/python/weight_only/index_map_deduce.py index 892cea202..9f75267e6 100644 --- a/testing/python/weight_only/index_map_deduce.py +++ b/testing/python/weight_only/index_map_deduce.py @@ -3,8 +3,8 @@ import numpy as np import pytest -import tvm -import tvm.testing +from bitblas import tvm +import tvm.testing from tvm.ir import assert_structural_equal from tvm.runtime import const from tvm.tir import IndexMap, IntImm, floordiv, floormod diff --git a/testing/python/weight_only/index_map_fuse.py b/testing/python/weight_only/index_map_fuse.py index 6660903ca..328f8184d 100644 --- a/testing/python/weight_only/index_map_fuse.py +++ b/testing/python/weight_only/index_map_fuse.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.script import tir as T from tvm.tir import IndexMap from tvm.tir.tensor_intrin.cuda import ( diff --git a/testing/python/weight_only/inverse_index_map.py b/testing/python/weight_only/inverse_index_map.py index bfc178297..a27e72b28 100644 --- a/testing/python/weight_only/inverse_index_map.py +++ b/testing/python/weight_only/inverse_index_map.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm +from bitblas import tvm from tvm.script import tir as T from tvm.tir import IndexMap from tvm.tir.tensor_intrin.cuda import (