[BitNet] Disable accelerate for BitNet (#36)
* update bitblas

* Merge branch 'main' of https://github.com/microsoft/BitBLAS into main

* make sure weight is contiguous.

* Refactor BitNet evaluation scripts for GPU acceleration
LeiWang1999 authored May 4, 2024
1 parent 9059fc4 commit dfe67fc
Showing 4 changed files with 139 additions and 6 deletions.
4 changes: 1 addition & 3 deletions integration/BitNet/eval_correctness.py
@@ -39,11 +39,9 @@ def main():
def main():
    model = BitnetForCausalLM.from_pretrained(
        '1bitLLM/bitnet_b1_58-3B',
        device_map='auto',
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    ).cuda().half()
    with torch.no_grad():
        model._post_process_weights()

6 changes: 3 additions & 3 deletions integration/BitNet/eval_ppl.py
@@ -35,11 +35,11 @@ def main(args):
    datasets = ['c4', 'wikitext2']
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map='auto',
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    ).cuda().half()
    with torch.no_grad():
        model._post_process_weights()
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()

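In both evaluation scripts the change replaces accelerate-backed loading (device_map='auto') with an explicit move of the whole model onto a single GPU. A minimal sketch of the resulting load path, assuming the local modeling_bitnet module shipped with integration/BitNet and a single-GPU machine:

import torch

# Assumed local module from integration/BitNet; adjust the import to your checkout.
from modeling_bitnet import BitnetForCausalLM

# Load in fp16, then place the whole model on one GPU explicitly
# instead of letting accelerate dispatch it via device_map='auto'.
model = BitnetForCausalLM.from_pretrained(
    '1bitLLM/bitnet_b1_58-3B',
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
).cuda().half()

with torch.no_grad():
    # Post-process the quantized weights, as both scripts do after loading.
    model._post_process_weights()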
133 changes: 133 additions & 0 deletions integration/BitNet/eval_utils.py
@@ -0,0 +1,133 @@
import torch

import numpy as np
import torch.nn.functional as F

from lm_eval.base import BaseLM
from datasets import load_dataset


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)

def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
    if dataset_name == "wikitext2":
        testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        testdata = "".join(testdata['text']).split('\n')
    elif dataset_name == "c4":
        testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text']
    else:
        raise NotImplementedError

    testdata = [item for item in testdata if item != ""]
    tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata]

    data, doc = [], [tokenizer.bos_token_id]
    for sen in tokenized_text:
        if len(sen) > seqlen:
            continue
        if len(doc) + len(sen) > seqlen:
            data.append(doc)
            doc = [tokenizer.bos_token_id]
        doc.extend(sen)
    if len(doc) > 1 and len(doc) <= seqlen:
        data.append(doc)
    return data


class LMEvalAdaptor(BaseLM):
    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model
        self.model.eval()

        self.tokenizer = tokenizer

        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size

        self._max_length = max_length

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length != -1:
            return self._max_length
        if hasattr(self.model.config, "n_ctx"):
            return self.model.config.n_ctx
        elif hasattr(self.model.config, "max_position_embeddings"):
            return self.model.config.max_position_embeddings
        elif hasattr(self.model.config, "n_positions"):
            return self.model.config.n_positions
        elif "bloom" in self.model_name:
            return 2048
        elif "llama" in self.model_name:
            return 2048  # TODO: did not check this
        elif "mpt" in self.model_name:
            return 2048
        elif "falcon" in self.model_name:
            return 2048
        else:
            print(self.model.config)
            raise NotImplementedError

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return "cuda"

    def tok_encode(self, string: str, add_special_tokens=True):
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in requests:
            context, continuation = context.strip(), continuation.strip()
            if context == "":
                # end of text as context
                context_enc = [self.eot_token_id]
            else:
                context_enc = self.tok_encode(context, add_special_tokens=True)

            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call
        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            out = self.model(inps)[0]
        return out

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
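
The new eval_utils.py implements the BaseLM interface from lm_eval 0.3.0, so LMEvalAdaptor can be handed straight to the harness. A minimal usage sketch, assuming the local modeling_bitnet and tokenization_bitnet modules from integration/BitNet are importable, that lm_eval 0.3.0's evaluator.simple_evaluate accepts a BaseLM instance, and illustrative task names:

import torch
from lm_eval import evaluator

from eval_utils import LMEvalAdaptor
from modeling_bitnet import BitnetForCausalLM      # assumed local modules
from tokenization_bitnet import BitnetTokenizer    # from integration/BitNet

model_path = '1bitLLM/bitnet_b1_58-3B'
model = BitnetForCausalLM.from_pretrained(
    model_path,
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
).cuda().half()
with torch.no_grad():
    model._post_process_weights()
tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)

# Wrap model and tokenizer in the BaseLM adaptor and run a couple of tasks.
lm = LMEvalAdaptor(model_path, model, tokenizer, batch_size=1)
results = evaluator.simple_evaluate(
    model=lm,
    tasks=['hellaswag', 'arc_easy'],  # illustrative task names
    no_cache=True,  # caching expects a model name string, so disable it when passing an LM object
)
print(evaluator.make_table(results))

The lm_eval==0.3.0 pin in the requirements.txt below matters here: the lm_eval.base.BaseLM import used by eval_utils.py was removed in later releases of the harness.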
2 changes: 2 additions & 0 deletions integration/BitNet/requirements.txt
@@ -0,0 +1,2 @@
lm_eval==0.3.0
flash_attn
