[Bug] Improve the default config values and fix a bug in the TensorCore config for small shapes #32

Merged 2 commits on May 2, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 file
+2 −0 src/target/tag.cc
3 changes: 1 addition & 2 deletions python/bitblas/base/roller/policy/tensorcore.py
@@ -258,8 +258,7 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
# allow pad, otherwise, we can not get a valid tile shape
return None
if np.prod(space) % warps != 0:
return None

factors = factorize(np.prod(space) // warps)

def _score(node, thread): # small is better
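
As I read this hunk together with the new "No valid config generated" check below, the removed divisibility test could drop every candidate block size for very small GEMM shapes, leaving the policy with nothing to emit. A minimal sketch of that arithmetic, with made-up numbers rather than BitBLAS internals:

```python
# Illustrative only: how the old `np.prod(space) % warps != 0` early return
# behaves as the thread space shrinks. `warps` and `space` values are made up.
import numpy as np

warps = 4
for space in ([8, 4], [4, 2], [2, 1]):   # hypothetical per-block thread spaces
    kept = np.prod(space) % warps == 0
    print(f"space={space}: old check {'keeps' if kept else 'drops'} this candidate")
# For tiny shapes every candidate can be dropped this way; removing the early
# return lets the `_score`-based selection below still consider them.
```
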
4 changes: 4 additions & 0 deletions python/bitblas/base/utils.py
@@ -332,6 +332,10 @@ def fast_tune(
policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk)

if len(configs) == 0:
raise ValueError("No valid config generated")

cpresults, best = apply_and_build(
func,
configs,
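
With the added guard, an exhausted search now fails fast with a clear message instead of handing an empty config list to apply_and_build. A hedged usage sketch; the fast_tune call signature, the tiny TE workload, and the target string are my assumptions, and only the ValueError text comes from the diff:

```python
# Hedged sketch: catching the new "No valid config generated" failure.
# Requires TVM and a CUDA-capable environment; everything except the ValueError
# message is assumed rather than taken from this PR.
import tvm
from tvm import te
from bitblas.base.utils import fast_tune

# A tiny fp16 GEMM PrimFunc just to have something tunable to pass in.
m, n, k = 16, 16, 16
A = te.placeholder((m, k), dtype="float16", name="A")
B = te.placeholder((n, k), dtype="float16", name="B")
r = te.reduce_axis((0, k), name="r")
C = te.compute((m, n), lambda i, j: te.sum(A[i, r] * B[j, r], axis=r), name="C")
func = te.create_prim_func([A, B, C])

try:
    result = fast_tune(func, tvm.target.Target("cuda"), topk=20)
except ValueError as err:
    # Raised when policy.emit_config(topk) produces no candidate configs at all.
    print(f"tuning aborted early: {err}")
```
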
1 change: 1 addition & 0 deletions python/bitblas/cache/operator.py
@@ -164,6 +164,7 @@ def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_n
def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None):
if target is None:
target = bitblas.auto_detect_nvidia_target()
logger.info(f"Loading operators from database {database_path} for target {target}")
global_operator_cache.load_from_database(database_path, target)
return global_operator_cache

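
The new log line makes it visible which database path and target the cache loader is using. A small sketch for surfacing it; the module path mirrors the file above, and it assumes the bitblas logger propagates to the root handler:

```python
# Hedged sketch: enable INFO logging so the new message
# "Loading operators from database ... for target ..." is printed.
import logging

from bitblas.cache.operator import load_global_ops_cache

logging.basicConfig(level=logging.INFO)  # stdlib setup; bitblas may also configure its own handlers
cache = load_global_ops_cache()          # target=None -> auto-detected NVIDIA target, per the diff
```
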
105 changes: 61 additions & 44 deletions python/bitblas/ops/general_matmul.py
@@ -94,44 +94,24 @@ class MatmulConfig:
storage_dtype: str = "int8"

# weight transform related flags
fast_decoding: bool = True # enable fast decoding by default
propagate_a: TransformKind = TransformKind.NonTransform
propagate_b: TransformKind = TransformKind.NonTransform

def __post_init__(self):
# set M to default dynamic range if it is None
if self.M is None:
object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
if self.N is None:
raise ValueError("N should be specified currently.")
if self.K is None:
raise ValueError("K should be specified currently.")

# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M)
if isinstance(self.propagate_a, bool):
object.__setattr__(
self,
"propagate_a",
(TransformKind.IntraWarpTransform
if self.propagate_a else TransformKind.NonTransform),
)
elif isinstance(self.propagate_a, int):
object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a))

if isinstance(self.propagate_b, bool):
object.__setattr__(
self,
"propagate_b",
(TransformKind.IntraWarpTransform
if self.propagate_b else TransformKind.NonTransform),
)
elif isinstance(self.propagate_b, int):
object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
fast_decoding: Optional[bool] = None # enable fast decoding by default, if not specified, it is enabled by a rule.
propagate_a: Optional[TransformKind] = None # propagate_a is a flag to control the ladder permutation.
propagate_b: Optional[TransformKind] = None # propagate_b is a flag to control the ladder permutation


def __legalize_dynamic_symbolic(self, M):
return tuple(self.M) if isinstance(self.M, list) else self.M

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
return (TransformKind.IntraWarpTransform
if propagate else TransformKind.NonTransform)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate

def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
if isinstance(
self.M,
@@ -148,13 +128,54 @@ def __post_init__(self):
else:
object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform)

if self.zeros_mode is None:
# set a and b value if is not None
if propagate_a is not None:
object.__setattr__(self, "propagate_a", propagate_a)
if propagate_b is not None:
object.__setattr__(self, "propagate_b", propagate_b)

# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)

def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
if zeros_mode is None:
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if "int" not in self.W_dtype or self.W_dtype == self.A_dtype:
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", self.fast_decoding)
if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)

def __post_init__(self):
# set M to default dynamic range if it is None
if self.M is None:
object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
if self.N is None:
raise ValueError("N should be specified currently.")
if self.K is None:
raise ValueError("K should be specified currently.")

# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
self.__initialize_propagate(self.propagate_a, self.propagate_b)

self.__initialize_zeros_mode(self.zeros_mode)

self.__initialize_fast_decoding(self.fast_decoding)

if self.with_bias is None:
object.__setattr__(self, "with_bias", False)
@@ -172,11 +193,6 @@ def __post_init__(self):
"float16", "int8", "e4m3_float8", "e5m2_float8"
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)
# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)


class Matmul(Operator):
@@ -217,6 +233,7 @@ def __init__(
# to save compilation time
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]
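
Taken together, the MatmulConfig changes move the hard-coded defaults (fast_decoding=True, NonTransform propagation) into rules that run inside __post_init__, overridden only when the caller passes explicit values. A hedged usage sketch; the field names come from the diff, while the shapes and dtype strings are examples of mine:

```python
# Hedged sketch of the new defaulting behaviour; not an official BitBLAS example.
from bitblas.ops.general_matmul import MatmulConfig

cfg = MatmulConfig(
    N=4096,               # N and K must still be provided; None raises ValueError
    K=4096,
    A_dtype="float16",    # example dtypes, chosen to exercise the defaulting rules
    W_dtype="int4",
)

print(cfg.M)              # default dynamic range, legalized to a hashable tuple
print(cfg.propagate_a)    # TransformKind chosen by __initialize_propagate
print(cfg.propagate_b)
print(cfg.fast_decoding)  # whatever __initialize_fast_decoding's rule (or an explicit flag) decided
```
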
4 changes: 2 additions & 2 deletions python/bitblas/utils/target_detector.py
@@ -2,11 +2,11 @@
# Licensed under the MIT License.

import subprocess
import logging
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags

import logging
logger = logging.getLogger(__name__)


@@ -44,6 +44,7 @@ def check_target(best, default):
if check_target(best_match, "cuda"):
return best_match if score >= MATCH_THRESHOLD else "cuda"
else:
logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'")
return "cuda"


@@ -65,5 +66,4 @@ def auto_detect_nvidia_target() -> str:
# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi()
target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"

return target
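
The detector now logs when the fuzzy match against TVM's target tags is discarded and the plain "cuda" target is used instead. A short sketch for observing that path; bitblas.auto_detect_nvidia_target() appears in the operator.py hunk above, while the logging setup is my assumption:

```python
# Hedged sketch: watch target detection fall back to "cuda" when the GPU model
# reported by nvidia-smi has no sufficiently close TVM target tag.
import logging

import bitblas

logging.basicConfig(level=logging.INFO)
target = bitblas.auto_detect_nvidia_target()
print(target)  # either a fuzzy-matched TVM target tag or the generic "cuda"
```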