diff --git a/3rdparty/tvm b/3rdparty/tvm index a9b770a85..0290a887d 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit a9b770a85d2b856424a2b4c71d870e3f1af90396 +Subproject commit 0290a887df4a0f16284e413c26a533f2ee101fb5 diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index 74a8267ae..578715da4 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -39,11 +39,9 @@ def get_runtime(num_repeats=1): def main(): model = BitnetForCausalLM.from_pretrained( '1bitLLM/bitnet_b1_58-3B', - device_map='auto', - low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, - ).half() + ).cuda().half() with torch.no_grad(): model._post_process_weights() diff --git a/integration/BitNet/eval_ppl.py b/integration/BitNet/eval_ppl.py index 0b096513b..8f6e7d347 100644 --- a/integration/BitNet/eval_ppl.py +++ b/integration/BitNet/eval_ppl.py @@ -35,11 +35,11 @@ def main(args): datasets = ['c4', 'wikitext2'] model = BitnetForCausalLM.from_pretrained( args.hf_path, - device_map='auto', - low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, - ).half() + ).cuda().half() + with torch.no_grad(): + model._post_process_weights() tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda() diff --git a/integration/BitNet/eval_utils.py b/integration/BitNet/eval_utils.py new file mode 100644 index 000000000..a7a57dd8a --- /dev/null +++ b/integration/BitNet/eval_utils.py @@ -0,0 +1,133 @@ +import torch + +import numpy as np +import torch.nn.functional as F + +from lm_eval.base import BaseLM +from datasets import load_dataset + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + +def get_test_dataset(dataset_name, tokenizer, seqlen=2048): + if dataset_name == "wikitext2": + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + testdata = "".join(testdata['text']).split('\n') + elif dataset_name == "c4": + testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text'] + else: + raise NotImplementedError + + testdata = [item for item in testdata if item != ""] + tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata] + + data, doc = [], [tokenizer.bos_token_id] + for sen in tokenized_text: + if len(sen) > seqlen: + continue + if len(doc) + len(sen) > seqlen: + data.append(doc) + doc = [tokenizer.bos_token_id] + doc.extend(sen) + if len(doc) > 1 and len(doc) <= seqlen: + data.append(doc) + return data + + +class LMEvalAdaptor(BaseLM): + def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): + super().__init__() + + assert isinstance(batch_size, int) + + self.model_name = model_name + self.model = model + self.model.eval() + + self.tokenizer = tokenizer + + self.vocab_size = self.tokenizer.vocab_size + + self._batch_size = batch_size + + self._max_length = max_length + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length != -1: + return self._max_length + if hasattr(self.model.config, "n_ctx"): + return self.model.config.n_ctx + elif hasattr(self.model.config, "max_position_embeddings"): + return self.model.config.max_position_embeddings + elif hasattr(self.model.config, "n_positions"): + return self.model.config.n_positions + elif "bloom" in self.model_name: + return 2048 + elif "llama" in self.model_name: + return 2048 # TODO: did not check this + elif "mpt" in self.model_name: + return 2048 + elif "falcon" in self.model_name: + return 2048 + else: + print(self.model.config) + raise NotImplementedError + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self._batch_size + + @property + def device(self): + return "cuda" + + def tok_encode(self, string: str, add_special_tokens=True): + return self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests): + new_reqs = [] + for context, continuation in requests: + context, continuation = context.strip(), continuation.strip() + if context == "": + # end of text as context + context_enc = [self.eot_token_id] + else: + context_enc = self.tok_encode(context, add_special_tokens=True) + + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs) + + def _model_call(self, inps): + """ + inps: a torch tensor of shape [batch, sequence] + the size of sequence may vary from call to call + + returns: a torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model + """ + with torch.no_grad(): + out = self.model(inps)[0] + return out + + def _model_generate(self, context, max_length, eos_token_id): + return self.model.generate( + context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False + ) \ No newline at end of file diff --git a/integration/BitNet/requirements.txt b/integration/BitNet/requirements.txt new file mode 100644 index 000000000..7d4b14956 --- /dev/null +++ b/integration/BitNet/requirements.txt @@ -0,0 +1,2 @@ +lm_eval==0.3.0 +flash_attn diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py index eb8aa0600..f52a1b80b 100644 --- a/python/bitblas/base/roller/policy/tensorcore.py +++ b/python/bitblas/base/roller/policy/tensorcore.py @@ -258,8 +258,7 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]: # allow pad, otherwise, we can not get a valid tile shape return None - if np.prod(space) % warps != 0: - return None + factors = factorize(np.prod(space) // warps) def _score(node, thread): # small is better diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py index 0e51ef57b..23a817f78 100644 --- a/python/bitblas/base/utils.py +++ b/python/bitblas/base/utils.py @@ -332,6 +332,10 @@ def fast_tune( policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags) configs = policy.emit_config(topk) + + if len(configs) == 0: + raise ValueError("No valid config generated") + cpresults, best = apply_and_build( func, configs, diff --git a/python/bitblas/cache/operator.py b/python/bitblas/cache/operator.py index 75c67662d..9b30a6200 100644 --- a/python/bitblas/cache/operator.py +++ b/python/bitblas/cache/operator.py @@ -164,6 +164,7 @@ def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_n def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None): if target is None: target = bitblas.auto_detect_nvidia_target() + logger.info(f"Loading operators from database {database_path} for target {target}") global_operator_cache.load_from_database(database_path, target) return global_operator_cache diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index ce8a8aef4..4c518f55d 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -94,44 +94,24 @@ class MatmulConfig: storage_dtype: str = "int8" # weight transform related flags - fast_decoding: bool = True # enable fast decoding by default - propagate_a: TransformKind = TransformKind.NonTransform - propagate_b: TransformKind = TransformKind.NonTransform - - def __post_init__(self): - # set M to default dynamic range if it is None - if self.M is None: - object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) - if self.N is None: - raise ValueError("N should be specified currently.") - if self.K is None: - raise ValueError("K should be specified currently.") - - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - # This is hack to legalize propagate_a and b - # TODO(lei): should be removed in the future when tc+br template is ready. + fast_decoding: Optional[bool] = None # enable fast decoding by default, if not specified, it is enabled by a rule. + propagate_a: Optional[TransformKind] = None # propagate_a is a flag to control the ladder permutation. + propagate_b: Optional[TransformKind] = None # propagate_b is a flag to control the ladder permutation + + + def __legalize_dynamic_symbolic(self, M): + return tuple(self.M) if isinstance(self.M, list) else self.M + + def __legalize_propagate(self, propagate): + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform + if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) + + return propagate + + def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]): MICRO_KERNEL_SIZE = 16 if isinstance( self.M, @@ -148,13 +128,54 @@ def __post_init__(self): else: object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform) - if self.zeros_mode is None: + # set a and b value if is not None + if propagate_a is not None: + object.__setattr__(self, "propagate_a", propagate_a) + if propagate_b is not None: + object.__setattr__(self, "propagate_b", propagate_b) + + # TODO(lei): This is a limitation arose by pytorch and llvm + # Should be removed in the future. + if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + object.__setattr__(self, "propagate_b", TransformKind.NonTransform) + + def __initialize_zeros_mode(self, zeros_mode: Optional[str]): + if zeros_mode is None: object.__setattr__(self, "zeros_mode", "original") + def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): if "int" not in self.W_dtype or self.W_dtype == self.A_dtype: object.__setattr__(self, "fast_decoding", False) else: object.__setattr__(self, "fast_decoding", self.fast_decoding) + if fast_decoding is not None: + object.__setattr__(self, "fast_decoding", fast_decoding) + + def __post_init__(self): + # set M to default dynamic range if it is None + if self.M is None: + object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) + if self.N is None: + raise ValueError("N should be specified currently.") + if self.K is None: + raise ValueError("K should be specified currently.") + + # set M to tuple if it is list + # otherwise, M is not hashable + object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M)) + + # set propagate_a and propagate_b to default value if it is None + object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) + object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) + + # This is hack to legalize propagate_a and b + # TODO(lei): should be removed in the future when tc+br template is ready. + self.__initialize_propagate(self.propagate_a, self.propagate_b) + + self.__initialize_zeros_mode(self.zeros_mode) + + self.__initialize_fast_decoding(self.fast_decoding) if self.with_bias is None: object.__setattr__(self, "with_bias", False) @@ -172,11 +193,6 @@ def __post_init__(self): "float16", "int8", "e4m3_float8", "e5m2_float8" ]: object.__setattr__(self, "storage_dtype", self.W_dtype) - # TODO(lei): This is a limitation arose by pytorch and llvm - # Should be removed in the future. - if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: - object.__setattr__(self, "propagate_a", TransformKind.NonTransform) - object.__setattr__(self, "propagate_b", TransformKind.NonTransform) class Matmul(Operator): @@ -217,6 +233,7 @@ def __init__( # to save compilation time if target is None: target = auto_detect_nvidia_target() + logger.info(f"Auto detected target: {target}") assert (config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}" source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype] @@ -400,6 +417,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): Returns: A list containing the transformed weight tensor and optionally the scale, zeros, and bias. """ + weight = weight.contiguous() if self.W_dtype == self.A_dtype: if self.weight_transform is not None: return self.weight_transform(weight.cpu()).cuda().contiguous() diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py index ea7315771..927e9f8e8 100644 --- a/python/bitblas/utils/target_detector.py +++ b/python/bitblas/utils/target_detector.py @@ -2,11 +2,11 @@ # Licensed under the MIT License. import subprocess -import logging from thefuzz import process from tvm.target import Target from tvm.target.tag import list_tags +import logging logger = logging.getLogger(__name__) @@ -44,6 +44,7 @@ def check_target(best, default): if check_target(best_match, "cuda"): return best_match if score >= MATCH_THRESHOLD else "cuda" else: + logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'") return "cuda" @@ -65,5 +66,4 @@ def auto_detect_nvidia_target() -> str: # Get the current GPU model and find the best matching target gpu_model = get_gpu_model_from_nvidia_smi() target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" - return target