From 45a78fa878894b9383fe9983b4cedadae46cb4fc Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 1 May 2024 13:19:09 +0000
Subject: [PATCH 1/2] update bitblas

---
 3rdparty/tvm | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index a9b770a8..0290a887 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit a9b770a85d2b856424a2b4c71d870e3f1af90396
+Subproject commit 0290a887df4a0f16284e413c26a533f2ee101fb5

From a70131f1bb2734dce71807ae9101e7c39c669228 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Thu, 2 May 2024 05:52:31 +0000
Subject: [PATCH 2/2] Merge branch 'main' of https://github.com/microsoft/BitBLAS into main

---
 .../bitblas/base/roller/policy/tensorcore.py |   3 +-
 python/bitblas/base/utils.py                 |   4 +
 python/bitblas/cache/operator.py             |   1 +
 python/bitblas/ops/general_matmul.py         | 105 ++++++++++--------
 python/bitblas/utils/target_detector.py      |   4 +-
 5 files changed, 69 insertions(+), 48 deletions(-)

diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py
index eb8aa060..f52a1b80 100644
--- a/python/bitblas/base/roller/policy/tensorcore.py
+++ b/python/bitblas/base/roller/policy/tensorcore.py
@@ -258,8 +258,7 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
         if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
             # allow pad, otherwise, we can not get a valid tile shape
             return None
-        if np.prod(space) % warps != 0:
-            return None
+
         factors = factorize(np.prod(space) // warps)

         def _score(node, thread): # small is better
diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py
index 0e51ef57..23a817f7 100644
--- a/python/bitblas/base/utils.py
+++ b/python/bitblas/base/utils.py
@@ -332,6 +332,10 @@ def fast_tune(
         policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags)

     configs = policy.emit_config(topk)
+
+    if len(configs) == 0:
+        raise ValueError("No valid config generated")
+
     cpresults, best = apply_and_build(
         func,
         configs,
diff --git a/python/bitblas/cache/operator.py b/python/bitblas/cache/operator.py
index 75c67662..9b30a620 100644
--- a/python/bitblas/cache/operator.py
+++ b/python/bitblas/cache/operator.py
@@ -164,6 +164,7 @@ def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_n
 def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None):
     if target is None:
         target = bitblas.auto_detect_nvidia_target()
+    logger.info(f"Loading operators from database {database_path} for target {target}")
     global_operator_cache.load_from_database(database_path, target)
     return global_operator_cache

diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py
index ce8a8aef..ef6b3dfc 100644
--- a/python/bitblas/ops/general_matmul.py
+++ b/python/bitblas/ops/general_matmul.py
@@ -94,44 +94,24 @@ class MatmulConfig:
     storage_dtype: str = "int8"

     # weight transform related flags
-    fast_decoding: bool = True # enable fast decoding by default
-    propagate_a: TransformKind = TransformKind.NonTransform
-    propagate_b: TransformKind = TransformKind.NonTransform
-
-    def __post_init__(self):
-        # set M to default dynamic range if it is None
-        if self.M is None:
-            object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
-        if self.N is None:
-            raise ValueError("N should be specified currently.")
-        if self.K is None:
-            raise ValueError("K should be specified currently.")
-
-        # set M to tuple if it is list
-        # otherwise, M is not hashable
-        object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M)
-        if isinstance(self.propagate_a, bool):
-            object.__setattr__(
-                self,
-                "propagate_a",
-                (TransformKind.IntraWarpTransform
-                 if self.propagate_a else TransformKind.NonTransform),
-            )
-        elif isinstance(self.propagate_a, int):
-            object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a))
-
-        if isinstance(self.propagate_b, bool):
-            object.__setattr__(
-                self,
-                "propagate_b",
-                (TransformKind.IntraWarpTransform
-                 if self.propagate_b else TransformKind.NonTransform),
-            )
-        elif isinstance(self.propagate_b, int):
-            object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b))
-
-        # This is hack to legalize propagate_a and b
-        # TODO(lei): should be removed in the future when tc+br template is ready.
+    fast_decoding: Optional[bool] = None # enable fast decoding by default, if not specified, it is enabled by a rule.
+    propagate_a: Optional[TransformKind] = None # propagate_a is a flag to control the ladder permutation.
+    propagate_b: Optional[TransformKind] = None # propagate_b is a flag to control the ladder permutation
+
+
+    def __legalize_dynamic_symbolic(self, M):
+        return tuple(self.M) if isinstance(self.M, list) else self.M
+
+    def __legalize_propagate(self, propagate):
+        if isinstance(propagate, bool):
+            return (TransformKind.IntraWarpTransform
+                    if propagate else TransformKind.NonTransform)
+        elif isinstance(propagate, int):
+            return TransformKind(propagate)
+
+        return propagate
+
+    def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]):
         MICRO_KERNEL_SIZE = 16
         if isinstance(
                 self.M,
@@ -148,13 +128,54 @@ def __post_init__(self):
         else:
             object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform)

-        if self.zeros_mode is None:
+        # set a and b value if is not None
+        if propagate_a is not None:
+            object.__setattr__(self, "propagate_a", propagate_a)
+        if propagate_b is not None:
+            object.__setattr__(self, "propagate_b", propagate_b)
+
+        # TODO(lei): This is a limitation arose by pytorch and llvm
+        # Should be removed in the future.
+        if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
+            object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
+            object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
+
+    def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
+        if zeros_mode is None:
             object.__setattr__(self, "zeros_mode", "original")

+    def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
         if "int" not in self.W_dtype or self.W_dtype == self.A_dtype:
             object.__setattr__(self, "fast_decoding", False)
         else:
             object.__setattr__(self, "fast_decoding", self.fast_decoding)
+        if fast_decoding is not None:
+            object.__setattr__(self, "fast_decoding", fast_decoding)
+
+    def __post_init__(self):
+        # set M to default dynamic range if it is None
+        if self.M is None:
+            object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
+        if self.N is None:
+            raise ValueError("N should be specified currently.")
+        if self.K is None:
+            raise ValueError("K should be specified currently.")
+
+        # set M to tuple if it is list
+        # otherwise, M is not hashable
+        object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))
+
+        # set propagate_a and propagate_b to default value if it is None
+        object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
+        object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))
+
+        # This is hack to legalize propagate_a and b
+        # TODO(lei): should be removed in the future when tc+br template is ready.
+        self.__initialize_propagate(self.propagate_a, self.propagate_b)
+
+        self.__initialize_zeros_mode(self.zeros_mode)
+
+        self.__initialize_fast_decoding(self.fast_decoding)

         if self.with_bias is None:
             object.__setattr__(self, "with_bias", False)
@@ -172,11 +193,6 @@ def __post_init__(self):
             "float16", "int8", "e4m3_float8", "e5m2_float8"
         ]:
             object.__setattr__(self, "storage_dtype", self.W_dtype)
-        # TODO(lei): This is a limitation arose by pytorch and llvm
-        # Should be removed in the future.
-        if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
-            object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
-            object.__setattr__(self, "propagate_b", TransformKind.NonTransform)


 class Matmul(Operator):
@@ -217,6 +233,7 @@ def __init__(
         # to save compilation time
         if target is None:
             target = auto_detect_nvidia_target()
+            logger.info(f"Auto detected target: {target}")
         assert (config.A_dtype
                 in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
         source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]
diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py
index ea731577..927e9f8e 100644
--- a/python/bitblas/utils/target_detector.py
+++ b/python/bitblas/utils/target_detector.py
@@ -2,11 +2,11 @@
 # Licensed under the MIT License.

 import subprocess
-import logging
 from thefuzz import process
 from tvm.target import Target
 from tvm.target.tag import list_tags

+import logging
 logger = logging.getLogger(__name__)


@@ -44,6 +44,7 @@ def check_target(best, default):
     if check_target(best_match, "cuda"):
         return best_match if score >= MATCH_THRESHOLD else "cuda"
     else:
+        logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'")
         return "cuda"


@@ -65,5 +66,4 @@ def auto_detect_nvidia_target() -> str:
     # Get the current GPU model and find the best matching target
     gpu_model = get_gpu_model_from_nvidia_smi()
     target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
-
     return target
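Note on the MatmulConfig refactor above: __post_init__ is split into __legalize_* and __initialize_* helpers, and fast_decoding/propagate_a/propagate_b become Optional so that unspecified values are filled in later by dtype-based rules. The sketch below is an illustration only, not part of the patch; the enum values are assumed to mirror BitBLAS's TransformKind (NonTransform = 0, InterWarpTransform = 1, IntraWarpTransform = 2). It shows the normalization rule that __legalize_propagate applies: a bool maps to IntraWarpTransform or NonTransform, a plain int is converted through the enum, and None or an existing TransformKind passes through unchanged so the dtype-based defaulting in __initialize_propagate can still run.

# Standalone sketch of the __legalize_propagate normalization (illustration only).
from enum import IntEnum
from typing import Optional, Union

class TransformKind(IntEnum):
    # Assumed to mirror BitBLAS's TransformKind enum.
    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2

def legalize_propagate(propagate: Optional[Union[bool, int, TransformKind]]):
    # bool must be tested before int, because bool is a subclass of int in Python.
    if isinstance(propagate, bool):
        return TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform
    elif isinstance(propagate, int):
        return TransformKind(propagate)
    return propagate  # None (defer to the dtype-based rule) or already a TransformKind

assert legalize_propagate(True) == TransformKind.IntraWarpTransform
assert legalize_propagate(2) == TransformKind.IntraWarpTransform
assert legalize_propagate(None) is None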