[BitNet] Disable accelerate for BitNET #36

Merged · 4 commits · May 4, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 file
+2 −0 src/target/tag.cc
4 changes: 1 addition & 3 deletions integration/BitNet/eval_correctness.py
@@ -39,11 +39,9 @@ def main():
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
).cuda().half()
with torch.no_grad():
model._post_process_weights()

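Note: the change above drops accelerate-backed loading (`device_map='auto'`) in favour of placing the model on the GPU explicitly. A minimal sketch of the resulting loading pattern follows; the local `modeling_bitnet` import and the exact set of remaining keyword arguments are assumptions, since the rendered diff does not mark which lines were kept.

import torch
from modeling_bitnet import BitnetForCausalLM  # assumed local module, as used by the integration scripts

model = BitnetForCausalLM.from_pretrained(
    '1bitLLM/bitnet_b1_58-3B',
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
).cuda().half()                    # explicit device placement replaces accelerate's device_map
with torch.no_grad():
    model._post_process_weights()  # BitNet weight post-processing, as in the script above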
6 changes: 3 additions & 3 deletions integration/BitNet/eval_ppl.py
@@ -35,11 +35,11 @@ def main(args):
datasets = ['c4', 'wikitext2']
model = BitnetForCausalLM.from_pretrained(
args.hf_path,
device_map='auto',
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).half()
).cuda().half()
with torch.no_grad():
model._post_process_weights()
tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()

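Note: the `CrossEntropyLoss(reduction="sum")` line above implies the script accumulates summed token negative log-likelihood and exponentiates the per-token average. The sketch below shows that standard recipe on a single packed sequence; it is an illustration, not a copy of eval_ppl.py.

import torch

def perplexity(model, input_ids: torch.Tensor) -> float:
    """input_ids: [1, seq_len] token ids on the same device as the model."""
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
    with torch.no_grad():
        logits = model(input_ids).logits                 # [1, seq_len, vocab]
    shift_logits = logits[:, :-1, :].contiguous()        # token t predicts token t+1
    shift_labels = input_ids[:, 1:].contiguous()
    nll_sum = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return torch.exp(nll_sum / shift_labels.numel()).item()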
133 changes: 133 additions & 0 deletions integration/BitNet/eval_utils.py
@@ -0,0 +1,133 @@
import torch

import numpy as np
import torch.nn.functional as F

from lm_eval.base import BaseLM
from datasets import load_dataset


def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)

def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
if dataset_name == "wikitext2":
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
testdata = "".join(testdata['text']).split('\n')
elif dataset_name == "c4":
testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text']
else:
raise NotImplementedError

testdata = [item for item in testdata if item != ""]
tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata]

data, doc = [], [tokenizer.bos_token_id]
for sen in tokenized_text:
if len(sen) > seqlen:
continue
if len(doc) + len(sen) > seqlen:
data.append(doc)
doc = [tokenizer.bos_token_id]
doc.extend(sen)
if len(doc) > 1 and len(doc) <= seqlen:
data.append(doc)
return data


class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
super().__init__()

assert isinstance(batch_size, int)

self.model_name = model_name
self.model = model
self.model.eval()

self.tokenizer = tokenizer

self.vocab_size = self.tokenizer.vocab_size

self._batch_size = batch_size

self._max_length = max_length

@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id

@property
def max_length(self):
if self._max_length != -1:
return self._max_length
if hasattr(self.model.config, "n_ctx"):
return self.model.config.n_ctx
elif hasattr(self.model.config, "max_position_embeddings"):
return self.model.config.max_position_embeddings
elif hasattr(self.model.config, "n_positions"):
return self.model.config.n_positions
elif "bloom" in self.model_name:
return 2048
elif "llama" in self.model_name:
return 2048 # TODO: did not check this
elif "mpt" in self.model_name:
return 2048
elif "falcon" in self.model_name:
return 2048
else:
print(self.model.config)
raise NotImplementedError

@property
def max_gen_toks(self):
return 256

@property
def batch_size(self):
return self._batch_size

@property
def device(self):
return "cuda"

def tok_encode(self, string: str, add_special_tokens=True):
return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)

def loglikelihood(self, requests):
new_reqs = []
for context, continuation in requests:
context, continuation = context.strip(), continuation.strip()
if context == "":
# end of text as context
context_enc = [self.eot_token_id]
else:
context_enc = self.tok_encode(context, add_special_tokens=True)

continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

new_reqs.append(((context, continuation), context_enc, continuation_enc))

return self._loglikelihood_tokens(new_reqs)

def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call

returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
out = self.model(inps)[0]
return out

def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
)
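
Note: a hedged usage sketch for the helpers added above, assuming lm_eval==0.3.0 (as pinned in requirements.txt) accepts a BaseLM instance directly, and a `model`/`tokenizer` pair loaded as in eval_correctness.py. The task name is only an example.

from lm_eval import evaluator
from eval_utils import LMEvalAdaptor, get_test_dataset

# Perplexity-style evaluation: pack the test split into <=2048-token documents.
chunks = get_test_dataset("wikitext2", tokenizer, seqlen=2048)

# Task accuracy: wrap the model so the lm_eval harness can drive it.
lm = LMEvalAdaptor("bitnet_b1_58-3B", model, tokenizer, batch_size=1)
results = evaluator.simple_evaluate(
    model=lm,              # lm_eval 0.3.0 allows passing a BaseLM instance here
    tasks=["arc_easy"],    # example task
    batch_size=1,
    no_cache=True,
    num_fewshot=0,
)
print(results["results"])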
2 changes: 2 additions & 0 deletions integration/BitNet/requirements.txt
@@ -0,0 +1,2 @@
lm_eval==0.3.0
flash_attn
3 changes: 1 addition & 2 deletions python/bitblas/base/roller/policy/tensorcore.py
@@ -258,8 +258,7 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
# allow pad, otherwise, we can not get a valid tile shape
return None
if np.prod(space) % warps != 0:
return None

factors = factorize(np.prod(space) // warps)

def _score(node, thread): # small is better
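Note: removing the divisibility guard means tile spaces whose thread count is not a multiple of `warps` are no longer rejected outright; the floor division simply drops the remainder. A toy illustration is below; `factorize` here is a hypothetical stand-in that returns prime factors, which is what the policy code appears to expect, not the repo's implementation.

import numpy as np

def factorize(n: int) -> list:
    """Hypothetical stand-in: prime factors of n, smallest first."""
    factors, p = [], 2
    while n > 1:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    return factors

space, warps = [3, 5], 4
print(int(np.prod(space)) % warps)               # 3 -> the removed guard would have returned None
print(factorize(int(np.prod(space)) // warps))   # [3] -> the policy now keeps going with the floored value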
4 changes: 4 additions & 0 deletions python/bitblas/base/utils.py
@@ -332,6 +332,10 @@ def fast_tune(
policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk)

if len(configs) == 0:
raise ValueError("No valid config generated")

cpresults, best = apply_and_build(
func,
configs,
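Note: with the added check, a tuning run that produces no candidate configs now fails loudly instead of proceeding with an empty list. A hedged caller-side sketch is below; `fast_tune(func, target, topk)` is taken to match its use in bitblas.base.utils, and `prim_func`/`target` are placeholders prepared by the caller.

from bitblas.base.utils import fast_tune

try:
    tuned = fast_tune(prim_func, target, topk=20)
except ValueError as err:
    # raised when policy.emit_config(topk) yields no valid candidate
    print(f"fast_tune produced no valid config: {err}")
    tuned = None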
1 change: 1 addition & 0 deletions python/bitblas/cache/operator.py
@@ -164,6 +164,7 @@ def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_n
def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None):
if target is None:
target = bitblas.auto_detect_nvidia_target()
logger.info(f"Loading operators from database {database_path} for target {target}")
global_operator_cache.load_from_database(database_path, target)
return global_operator_cache

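Note: the new log line makes it visible which database and target the operator cache is loaded for. A minimal sketch, assuming the standard-library logging setup below is enough to surface bitblas INFO messages in your environment:

import logging
from bitblas.cache.operator import load_global_ops_cache

logging.basicConfig(level=logging.INFO)
cache = load_global_ops_cache()   # target=None triggers auto-detection; logs the database path and target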
106 changes: 62 additions & 44 deletions python/bitblas/ops/general_matmul.py
@@ -94,44 +94,24 @@ class MatmulConfig:
storage_dtype: str = "int8"

# weight transform related flags
fast_decoding: bool = True # enable fast decoding by default
propagate_a: TransformKind = TransformKind.NonTransform
propagate_b: TransformKind = TransformKind.NonTransform

def __post_init__(self):
# set M to default dynamic range if it is None
if self.M is None:
object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
if self.N is None:
raise ValueError("N should be specified currently.")
if self.K is None:
raise ValueError("K should be specified currently.")

# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M)
if isinstance(self.propagate_a, bool):
object.__setattr__(
self,
"propagate_a",
(TransformKind.IntraWarpTransform
if self.propagate_a else TransformKind.NonTransform),
)
elif isinstance(self.propagate_a, int):
object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a))

if isinstance(self.propagate_b, bool):
object.__setattr__(
self,
"propagate_b",
(TransformKind.IntraWarpTransform
if self.propagate_b else TransformKind.NonTransform),
)
elif isinstance(self.propagate_b, int):
object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
fast_decoding: Optional[bool] = None # enable fast decoding by default, if not specified, it is enabled by a rule.
propagate_a: Optional[TransformKind] = None # propagate_a is a flag to control the ladder permutation.
propagate_b: Optional[TransformKind] = None # propagate_b is a flag to control the ladder permutation


def __legalize_dynamic_symbolic(self, M):
return tuple(self.M) if isinstance(self.M, list) else self.M

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
return (TransformKind.IntraWarpTransform
if propagate else TransformKind.NonTransform)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate

def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
if isinstance(
self.M,
@@ -148,13 +128,54 @@ def __post_init__(self):
else:
object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform)

if self.zeros_mode is None:
# set a and b value if is not None
if propagate_a is not None:
object.__setattr__(self, "propagate_a", propagate_a)
if propagate_b is not None:
object.__setattr__(self, "propagate_b", propagate_b)

# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)

def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
if zeros_mode is None:
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if "int" not in self.W_dtype or self.W_dtype == self.A_dtype:
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", self.fast_decoding)
if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)

def __post_init__(self):
# set M to default dynamic range if it is None
if self.M is None:
object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
if self.N is None:
raise ValueError("N should be specified currently.")
if self.K is None:
raise ValueError("K should be specified currently.")

# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
self.__initialize_propagate(self.propagate_a, self.propagate_b)

self.__initialize_zeros_mode(self.zeros_mode)

self.__initialize_fast_decoding(self.fast_decoding)

if self.with_bias is None:
object.__setattr__(self, "with_bias", False)
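
Note: the refactor above makes `propagate_a`, `propagate_b`, `zeros_mode`, and `fast_decoding` Optional, with `None` meaning "pick a value by rule", and moves each rule into a small helper. The standalone sketch below mirrors the boolean/int legalization in `__legalize_propagate`; the `TransformKind` enum values are an assumption for illustration.

from enum import IntEnum
from typing import Optional, Union

class TransformKind(IntEnum):      # assumed layout, for illustration only
    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2

def legalize_propagate(propagate: Optional[Union[bool, int, TransformKind]]):
    if isinstance(propagate, bool):
        return TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform
    if isinstance(propagate, int):
        return TransformKind(propagate)
    return propagate               # None is passed through and resolved later by __initialize_propagate

print(legalize_propagate(True))    # TransformKind.IntraWarpTransform
print(legalize_propagate(0))       # TransformKind.NonTransform
print(legalize_propagate(None))    # None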
Expand All @@ -172,11 +193,6 @@ def __post_init__(self):
"float16", "int8", "e4m3_float8", "e5m2_float8"
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)
# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)


class Matmul(Operator):
@@ -217,6 +233,7 @@ def __init__(
# to save compilation time
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]
@@ -400,6 +417,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
Returns:
A list containing the transformed weight tensor and optionally the scale, zeros, and bias.
"""
weight = weight.contiguous()
if self.W_dtype == self.A_dtype:
if self.weight_transform is not None:
return self.weight_transform(weight.cpu()).cuda().contiguous()
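Note: the added `weight = weight.contiguous()` guards against weights that arrive as strided views (for example a transpose), whose memory layout would otherwise be misread by the packing that follows. A minimal, bitblas-independent illustration:

import torch

w = torch.randn(4, 8).t()    # transposed view: shape [8, 4], not contiguous
print(w.is_contiguous())     # False
wc = w.contiguous()          # dense row-major copy with identical values
print(wc.is_contiguous())    # True
print(torch.equal(w, wc))    # True -- only the memory layout changed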
4 changes: 2 additions & 2 deletions python/bitblas/utils/target_detector.py
@@ -2,11 +2,11 @@
# Licensed under the MIT License.

import subprocess
import logging
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags

import logging
logger = logging.getLogger(__name__)


@@ -44,6 +44,7 @@ def check_target(best, default):
if check_target(best_match, "cuda"):
return best_match if score >= MATCH_THRESHOLD else "cuda"
else:
logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'")
return "cuda"


@@ -65,5 +66,4 @@ def auto_detect_nvidia_target() -> str:
# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi()
target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"

return target
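
Note: a hedged sketch of the matching step used by `auto_detect_nvidia_target`: the GPU name reported by nvidia-smi is fuzzy-matched against TVM's registered NVIDIA target tags, with a plain "cuda" fallback. The threshold and GPU string are illustrative; the real threshold is `MATCH_THRESHOLD` in target_detector.py.

from thefuzz import process
from tvm.target.tag import list_tags

MATCH_THRESHOLD = 25                              # illustrative; see target_detector.py
nvidia_tags = [tag for tag in list_tags() if "nvidia" in tag]
gpu_model = "NVIDIA GeForce RTX 3090"             # e.g. parsed from `nvidia-smi`
best_match, score = process.extractOne(gpu_model, nvidia_tags)
target = best_match if score >= MATCH_THRESHOLD else "cuda"
print(target)                                     # e.g. "nvidia/geforce-rtx-3090"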