fix BitNet integration for vLLM (#137)
* fix BitNet integration for vLLM

* update ckpt name of BitNet integration for vLLM

* format code
xysmlx authored Aug 9, 2024
1 parent d52f93d commit 22b5262
Showing 7 changed files with 35 additions and 42 deletions.
6 changes: 3 additions & 3 deletions integration/BitNet/README.md
@@ -18,14 +18,14 @@ We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`
cd /root/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
-# the output ckpt will be saved in the `./models/bitnet_b1_58-3B` directory
+# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
```

The second script is `generate_bitnet_model_bitblas_format.sh`, which makes a checkpoint with BitBLAS compressed metadata; this avoids the online dequantization stage during vLLM profiling and leads to more efficient memory utilization.

```bash
-./maint/generate_bitnet_model_bitblas_format.sh ./models/bitnet_3B_1.58bit ./models/bitnet_3B_1.58bit_bitblas
-# the output ckpt will be saved in the `./models/bitnet_b1_58-3B_bitblas` directory
+./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas
+# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
```

Finally, you can use the ckpt in vLLM with:
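The command itself is collapsed in this view. As a minimal sketch, the converted checkpoint can be passed to one of the `vllm_workspace` demo scripts changed later in this commit; the script name below is hypothetical (only `inference_with_native_format.py` appears in a visible file header), and the `--ckpt_path` flag is taken from the demo scripts' argparse:

```bash
# hypothetical script name; substitute the actual compressed-format demo under vllm_workspace/
python vllm_workspace/inference_with_bitblas_format.py \
    --ckpt_path ./models/ckpt_bitnet_b1_58-3B_bitblas
```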
14 changes: 10 additions & 4 deletions integration/BitNet/maint/create_bitblas_ckpt.py
@@ -4,27 +4,33 @@
import argparse
import torch
import bitblas
-from modeling_bitnet import BitnetForCausalLM
-from tokenization_bitnet import BitnetTokenizer
from transformers.utils.hub import cached_file
import os
from transformers import GenerationConfig
import time
import json

+import sys
+
+sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../")
+from modeling_bitnet import BitnetForCausalLM
+from tokenization_bitnet import BitnetTokenizer
+
filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="BitBLASModel/open_llama_3b_1.58bits")
parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B")
parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()

model_name_or_path = args.model_name_or_path
-saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path
+saved_model_path = os.path.join(
+    dirpath, "models",
+    f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path


def generate_text(model, tokenizer, prompt, max_length=100):
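The relocated imports now sit behind a `sys.path` tweak so `modeling_bitnet` and `tokenization_bitnet` resolve from the integration directory no matter where the script is launched. A hedged usage sketch from `integration/BitNet`, using only the argparse flags visible in this diff:

```bash
# convert the HF checkpoint into BitBLAS compressed format;
# if --saved_model_path is omitted, the script writes to models/<model_name_or_path>_bitblas
python maint/create_bitblas_ckpt.py \
    --model_name_or_path 1bitLLM/bitnet_b1_58-3B \
    --saved_model_path ./models/ckpt_bitnet_b1_58-3B_bitblas
```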
integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh
@@ -24,4 +24,10 @@ fi
# get the realpath of the saved model directory
SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR)

+# copy the quantization config and tokenizer files into the converted checkpoint
+cp $MODEL_DIR/quantize_config.json $SAVED_MODEL_DIR/
+cp $MODEL_DIR/tokenizer.json $SAVED_MODEL_DIR/
+cp $MODEL_DIR/tokenizer.model $SAVED_MODEL_DIR/
+cp $MODEL_DIR/tokenizer_config.json $SAVED_MODEL_DIR/
+
echo "Model has been converted and save to $SAVED_MODEL_DIR"
integration/BitNet/maint/generate_bitnet_model_native_format.sh
@@ -14,13 +14,13 @@ mkdir -p models
cd models

# download the model
-git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B bitnet_3B_1.58bits --depth 1
+git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B ckpt_bitnet_b1_58-3B --depth 1

# copy the quantization config into the model directory
-cp ../maint/quant_config.json bitnet_3B_1.58bits
+cp ../maint/quantize_config.json ckpt_bitnet_b1_58-3B

# get the realpath of the model directory
-MODEL_DIR=$(realpath bitnet_3B_1.58bits)
+MODEL_DIR=$(realpath ckpt_bitnet_b1_58-3B)

cd ..

File renamed without changes: integration/BitNet/maint/quant_config.json → integration/BitNet/maint/quantize_config.json
@@ -19,7 +19,7 @@
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)

-ckpt_path = os.path.join(current_dir, "../models/bitnet_3b_1.58bits_bitblas")
+ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
@@ -32,14 +32,13 @@

ckpt_path = args.ckpt_path
with VllmRunner(
-    ckpt_path,
-    dtype="half",
-    quantization="bitblas",
-    enforce_eager=True,
+        ckpt_path,
+        dtype="half",
+        quantization="bitblas",
+        enforce_eager=True,  # set False to enable cuda graph
) as bitnet_model:
-    bitbnet_outputs = bitnet_model.generate_greedy(
-        ["Hi, tell me about microsoft?"], max_tokens=1024
-    )
+    bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"],
+                                                   max_tokens=1024)
print("bitnet inference:")
print(bitbnet_outputs[0][0])
print(bitbnet_outputs[0][1])
30 changes: 6 additions & 24 deletions integration/BitNet/vllm_workspace/inference_with_native_format.py
@@ -15,11 +15,10 @@
import os
import argparse

-
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
-ckpt_path = os.path.join(current_dir, "../models/bitnet_3b_1.58bits")
+ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas")

parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
@@ -34,29 +33,12 @@
ckpt_path = args.ckpt_path

with VllmRunner(
-    ckpt_path,
-    dtype="half",
-    quantization="bitnet",
-    gpu_memory_utilization=0.5,
+        ckpt_path,
+        dtype="half",
+        quantization="bitnet",
+        gpu_memory_utilization=0.5,
) as bitnet_model:
-    bitbnet_outputs = bitnet_model.generate_greedy(
-        ["Hi, tell me about microsoft?"], max_tokens=128
-    )
+    bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128)
print("bitnet inference output:")
print(bitbnet_outputs[0][0])
print(bitbnet_outputs[0][1])

-# with VllmRunner(
-#         "BitBLASModel/open_llama_3b_1.58bits_bitblas",
-#         dtype="half",
-#         quantization="bitblas",
-#         enforce_eager=True,
-# ) as bitnet_model:
-#     torch.cuda.profiler.start()
-#     bitbnet_outputs = bitnet_model.generate_greedy(
-#         ["Hi, tell me about microsoft?"], max_tokens=1024
-#     )
-#     torch.cuda.profiler.stop()
-#     print("bitnet:")
-#     print(bitbnet_outputs[0][0])
-#     print(bitbnet_outputs[0][1])
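A usage sketch for this demo, using the script path from the file header above; the `--ckpt_path` flag mirrors the argparse shown in these scripts, and the checkpoint directory is an assumption based on the README's native-format output:

```bash
python vllm_workspace/inference_with_native_format.py \
    --ckpt_path ./models/ckpt_bitnet_b1_58-3B
```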
