diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py
index f53090cfc..f5a0d1f24 100644
--- a/bitblas/base/roller/policy/tensorcore.py
+++ b/bitblas/base/roller/policy/tensorcore.py
@@ -333,7 +333,6 @@ def _score(node, thread):  # small is better
            logger.info(info_message)
            codegen_dict.shared_scope = "shared.dyn"
 
-        # Or assume we always use shared memory
        codegen_dict.shared_scope = "shared.dyn"
 
        codegen_dict.complete_config(node)
diff --git a/install.sh b/install.sh
index b7b389626..db3b36827 100755
--- a/install.sh
+++ b/install.sh
@@ -46,7 +46,7 @@ fi
 
 echo "Download and extraction completed successfully."
 
-LLVM_CONFIG_PATH="${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config"
+LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)"
 echo "LLVM config path: $LLVM_CONFIG_PATH"
 
 # clone and build tvm
diff --git a/integration/BitNet/README.md b/integration/BitNet/README.md
index 8fa09f764..63cc3e275 100644
--- a/integration/BitNet/README.md
+++ b/integration/BitNet/README.md
@@ -2,8 +2,42 @@
 license: mit
 ---
+
 This is a BitBLAS Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`.
 
+## Latest News
+
+- 08/09/2024 ✨: We provide a more efficient implementation of BitNet with vLLM, which requires specially prepared model checkpoints. To create the checkpoints and learn how to deploy them, please check out [Make Checkpoints for vLLM](#make-checkpoints-for-vllm).
+
+## Make Checkpoints for vLLM
+
+We provide two scripts to make the checkpoints for vLLM. The first script, `generate_bitnet_model_native_format.sh`, creates a checkpoint with fp16 uncompressed metadata. The main difference from the original checkpoint is the added `quantize_config.json`, which allows vLLM to load the model and execute it with the quant extension.
+
+```bash
+# move to the integration directory
+cd /root/to/BitBLAS/integration/BitNet
+# make the checkpoint
+./maint/generate_bitnet_model_native_format.sh
+# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
+```
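+
+For reference, the `quantize_config.json` that this script copies into the checkpoint ships under `./maint` in this integration and looks like this:
+
+```json
+{
+    "bits": 2,
+    "desc_act": false,
+    "static_groups": false,
+    "sym": true,
+    "lm_head": false,
+    "model_name_or_path": "1bitLLM/bitnet_b1_58-3B",
+    "quant_method": "bitnet",
+    "checkpoint_format": "bitnet"
+}
+```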
+
+The second script, `generate_bitnet_model_bitblas_format.sh`, creates a checkpoint with BitBLAS compressed metadata, which avoids the online dequantization stage during vLLM profiling and leads to more efficient memory utilization.
+
+```bash
+./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas
+# the output ckpt will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
+```
+
+Finally, you can use the checkpoints in vLLM with:
+
+```bash
+cd vllm_workspace
+# inference with the ckpt with fp16 uncompressed metadata
+python3 inference_with_native_format.py
+# inference with the ckpt with BitBLAS compressed metadata
+python3 inference_with_compress_format.py
+```
+
 ## BitBLAS Results
 
 ### Performance
diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py
index cef89313d..4017a6c17 100644
--- a/integration/BitNet/eval_correctness.py
+++ b/integration/BitNet/eval_correctness.py
@@ -18,9 +18,6 @@ def generate_text(model, tokenizer, prompt, max_length=100):
     seq_length = input_ids.size(1)
     position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
     position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
-    # position_embeddings = model.embed_positions(position_ids)
-    # cos = position_embeddings[:, :, 0::2].cos()
-    # sin = position_embeddings[:, :, 1::2].sin()
 
     generation_config = GenerationConfig(
         max_length=max_length,
@@ -32,7 +29,6 @@ def generate_text(model, tokenizer, prompt, max_length=100):
 
     start_time = time.time()
     output_ids = model.generate(input_ids, generation_config=generation_config)
-    # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin)
     end_time = time.time()
 
     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
diff --git a/integration/BitNet/create_bitblas_ckpt.py b/integration/BitNet/maint/create_bitblas_ckpt.py
similarity index 86%
rename from integration/BitNet/create_bitblas_ckpt.py
rename to integration/BitNet/maint/create_bitblas_ckpt.py
index d443b2e20..0bf603e0d 100644
--- a/integration/BitNet/create_bitblas_ckpt.py
+++ b/integration/BitNet/maint/create_bitblas_ckpt.py
@@ -1,24 +1,36 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
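+# Convert a BitNet 1.58-bit checkpoint into the BitBLAS (compressed) format.
+# Typical invocation (this is how maint/generate_bitnet_model_bitblas_format.sh calls it;
+# the example paths follow the README):
+#   python ./maint/create_bitblas_ckpt.py --model_name_or_path ./models/ckpt_bitnet_b1_58-3B --saved_model_path ./models/ckpt_bitnet_b1_58-3B_bitblas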
+import argparse import torch import bitblas -from modeling_bitnet import BitnetForCausalLM -from tokenization_bitnet import BitnetTokenizer from transformers.utils.hub import cached_file import os from transformers import GenerationConfig import time import json +import sys + +sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../") +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer + filepath = os.path.abspath(__file__) dirpath = os.path.dirname(filepath) torch.set_grad_enabled(False) bitblas.set_log_level("INFO") -model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" -saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") +parser = argparse.ArgumentParser() +parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B") +parser.add_argument("--saved_model_path", type=str, default=None) +args = parser.parse_args() + +model_name_or_path = args.model_name_or_path +saved_model_path = os.path.join( + dirpath, "models", + f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path def generate_text(model, tokenizer, prompt, max_length=100): diff --git a/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh new file mode 100755 index 000000000..3ace58031 --- /dev/null +++ b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# retrieve the native model input and saved model directory +MODEL_DIR=$1 +SAVED_MODEL_DIR=$2 + +# check if the model directory exists +if [ ! -d "$MODEL_DIR" ]; then + echo "Model directory does not exist!" + exit 1 +fi + +# if the saved model directory does not exist, create it +# if SAVED_MODEL_DIR is not provided, we do not pass it to the script +if [ -z "$SAVED_MODEL_DIR" ]; then + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR +else + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR +fi + +# get the realpath of the saved model directory +SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR) + +# cp files +cp $MODEL_DIR/quantize_config.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.model $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer_config.json $SAVED_MODEL_DIR/ + +echo "Model has been converted and save to $SAVED_MODEL_DIR" diff --git a/integration/BitNet/maint/generate_bitnet_model_native_format.sh b/integration/BitNet/maint/generate_bitnet_model_native_format.sh new file mode 100755 index 000000000..c002f6e13 --- /dev/null +++ b/integration/BitNet/maint/generate_bitnet_model_native_format.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# require git lfs +if ! command -v git-lfs &> /dev/null; then + echo "Please install git-lfs first by running 'sudo apt install git-lfs'" + exit 1 +fi + +mkdir -p models + +cd models + +# download the model +git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B ckpt_bitnet_b1_58-3B --depth 1 + +# copy quantized config into the model directory +cp ../maint/quantize_config.json ckpt_bitnet_b1_58-3B + +# get the realpath of the model directory +MODEL_DIR=$(realpath ckpt_bitnet_b1_58-3B) + +cd .. 
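+# At this point the native-format checkpoint (with quantize_config.json) lives at
+# $MODEL_DIR; it can be passed to ./maint/generate_bitnet_model_bitblas_format.sh
+# to produce the BitBLAS compressed checkpoint (see the README in this directory).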
+ +echo "Model has been converted and save to $MODEL_DIR" diff --git a/integration/BitNet/maint/quantize_config.json b/integration/BitNet/maint/quantize_config.json new file mode 100644 index 000000000..e2b24123a --- /dev/null +++ b/integration/BitNet/maint/quantize_config.json @@ -0,0 +1,10 @@ +{ + "bits": 2, + "desc_act": false, + "static_groups": false, + "sym": true, + "lm_head": false, + "model_name_or_path": "1bitLLM/bitnet_b1_58-3B", + "quant_method": "bitnet", + "checkpoint_format": "bitnet" +} \ No newline at end of file diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index d9cc25ae7..cb0c0f50b 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -138,6 +138,7 @@ def weight_quant(weight): result = (weight * s).round().clamp(-1, 1) return result.type(torch.int8) + @torch.compile def activation_quant(self, x, num_bits=8): x = x.float() Qn = -(2**(num_bits - 1)) @@ -146,6 +147,13 @@ def activation_quant(self, x, num_bits=8): result = (x * s).round().clamp(Qn, Qp) return result.type(torch.int8) + @torch.compile + def post_quant_process(self, input, si, sw): + out = input / si + out = out / sw + out = out.half() + return out + # for the correctness evaluation. def native_forward(self, input): quant_input = (input + (activation_quant(input, self.input_bits) - input).detach()) @@ -184,9 +192,8 @@ def forward(self, input): Qp = self.Qp si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) # if / (si * sw) it will inf in some cases - out = fp32_out / si - out = out / sw - out = out.half() + out = self.post_quant_process(fp32_out, si, sw) + if self.bias is not None: out += self.bias.view(1, -1).expand_as(out) return out diff --git a/integration/BitNet/vllm_workspace/conftest.py b/integration/BitNet/vllm_workspace/conftest.py new file mode 100644 index 000000000..fd5e162af --- /dev/null +++ b/integration/BitNet/vllm_workspace/conftest.py @@ -0,0 +1,625 @@ +import contextlib +import gc +import os +import sys +from collections import UserList +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoTokenizer, + BatchEncoding, +) + +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset +from vllm.config import TokenizerPoolConfig +from vllm.distributed import ( + destroy_distributed_environment, + destroy_model_parallel, +) +from vllm.inputs import TextPrompt +from vllm.logger import init_logger +from vllm.sequence import SampleLogprobs +from vllm.utils import cuda_device_count_stateless, is_cpu + +logger = init_logger(__name__) + +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] +_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] + + +def _read_prompts(filename: str) -> List[str]: + with open(filename, "r") as f: + prompts = f.readlines() + return prompts + + +class _ImageAssetPrompts(TypedDict): + stop_sign: str + cherry_blossom: str + + +if sys.version_info < (3, 9): + # UserList cannot be subscripted + class _ImageAssetsBase(UserList): + pass + +else: + + class _ImageAssetsBase(UserList[ImageAsset]): + pass + + +class _ImageAssets(_ImageAssetsBase): + + def __init__(self) -> None: + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) + + def prompts(self, 
prompts: _ImageAssetPrompts) -> List[str]: + """ + Convenience method to define the prompt for each test image. + + The order of the returned prompts matches the order of the + assets when iterating through this object. + """ + return [prompts["stop_sign"], prompts["cherry_blossom"]] + + +IMAGE_ASSETS = _ImageAssets() +"""Singleton instance of :class:`_ImageAssets`.""" + + +def cleanup(): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + if not is_cpu(): + torch.cuda.empty_cache() + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. + """ + + if request.node.get_closest_marker("skip_global_cleanup"): + return False + + return True + + +@pytest.fixture(autouse=True) +def cleanup_fixture(should_do_global_cleanup_after_test: bool): + yield + if should_do_global_cleanup_after_test: + cleanup() + + +@pytest.fixture +def example_prompts() -> List[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture +def example_long_prompts() -> List[str]: + prompts = [] + for filename in _LONG_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture(scope="session") +def image_assets() -> _ImageAssets: + return IMAGE_ASSETS + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, +} + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) + + +class HfRunner: + + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + return input.to("cuda") + else: + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + is_embedding_model: bool = False, + is_vision_model: bool = False, + is_sparseml_model: bool = False, + ) -> None: + assert dtype in _STR_DTYPE_TO_TORCH_DTYPE + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + + if is_embedding_model: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = self.wrap_device( + SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype) + ) + else: + if is_vision_model: + auto_cls = AutoModelForVision2Seq + elif is_sparseml_model: + from sparseml.transformers import SparseAutoModelForCausalLM + + auto_cls = SparseAutoModelForCausalLM + else: + auto_cls = AutoModelForCausalLM + + model_kwargs = model_kwargs if model_kwargs is not None else {} + self.model = self.wrap_device( + auto_cls.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + **model_kwargs, + ) + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for " + "model %s. 
Using tokenizer instead.", + model_name, + ) + self.processor = self.tokenizer + + def generate( + self, + prompts: List[str], + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images: + assert len(prompts) == len(images) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output_ids = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + **kwargs, + ) + output_str = self.processor.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output_ids = output_ids.cpu().tolist() + outputs.append((output_ids, output_str)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + **kwargs, + ) + + return [ + (output_ids[0], output_str[0]) for output_ids, output_str in outputs + ] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + ) + for i in range(len(outputs)): + output_ids, output_str = outputs[i] + for j in range(len(output_ids)): + output_ids[j] = [ + x for x in output_ids[j] if x != self.tokenizer.pad_token_id + ] + outputs[i] = (output_ids, output_str) + return outputs + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[List[torch.Tensor]]: + all_logprobs: List[List[torch.Tensor]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + seq_logprobs: List[torch.Tensor] = [] + for hidden_states in output.hidden_states: + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if self.model.get_output_embeddings().bias is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze( + 0 + ) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + all_logprobs.append(seq_logprobs) + return all_logprobs + + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": 
"pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + input_ids = inputs.input_ids + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + + seq_logprobs: List[torch.Tensor] = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if ( + getattr(self.model.get_output_embeddings(), "bias", None) + is not None + ): + logits += self.model.get_output_embeddings().bias.unsqueeze( + 0 + ) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] + + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner + + +class VllmRunner: + + def __init__( + self, + model_name: str, + tokenizer_name: Optional[str] = None, + # Use smaller max model length, otherwise bigger model cannot run due + # to kv cache size limit. 
+ max_model_len: int = 1024, + dtype: str = "half", + disable_log_stats: bool = True, + tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = False, + swap_space: int = 4, + enforce_eager: bool = False, + **kwargs, + ) -> None: + self.model = LLM( + model=model_name, + tokenizer=tokenizer_name, + trust_remote_code=True, + dtype=dtype, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate( + inputs, sampling_params=sampling_params + ) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for req_output in req_outputs: + prompt_str = req_output.prompt + prompt_ids = req_output.prompt_token_ids + req_sample_output_ids: List[List[int]] = [] + req_sample_output_strs: List[str] = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append(prompt_str + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) + return outputs + + def generate_w_logprobs( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + assert sampling_params.logprobs is not None + + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate( + inputs, sampling_params=sampling_params + ) + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str]]: + greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs = self.generate(prompts, greedy_params, images=images) + return [ + (output_ids[0], output_str[0]) for output_ids, output_str in outputs + ] + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs + ) + outputs = self.generate_w_logprobs( + prompts, greedy_logprobs_params, images=images + ) + + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: 
int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + beam_search_params = SamplingParams( + n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ) + outputs = self.generate(prompts, beam_search_params) + return outputs + + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig( + pool_size=1, pool_type="ray", extra_config={} + ) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog + + +@pytest.fixture(scope="session") +def num_gpus_available(): + """Get number of GPUs without initializing the CUDA context + in current process.""" + + return cuda_device_count_stateless() diff --git a/integration/BitNet/vllm_workspace/inference_with_compress_format.py b/integration/BitNet/vllm_workspace/inference_with_compress_format.py new file mode 100644 index 000000000..55a24543e --- /dev/null +++ b/integration/BitNet/vllm_workspace/inference_with_compress_format.py @@ -0,0 +1,46 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. 
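+
+In this workspace the runner below reuses those utilities to load the BitBLAS-format
+BitNet checkpoint (default: ../models/ckpt_bitnet_b1_58-3B_bitblas, overridable via
+--ckpt_path) with quantization="bitblas" and run a single greedy generation as a
+smoke test.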
+""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) + +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas") +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], + max_tokens=1024) + print("bitnet inference:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/integration/BitNet/vllm_workspace/inference_with_native_format.py b/integration/BitNet/vllm_workspace/inference_with_native_format.py new file mode 100644 index 000000000..4f5f87f6f --- /dev/null +++ b/integration/BitNet/vllm_workspace/inference_with_native_format.py @@ -0,0 +1,47 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. +""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B") + +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path + +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) + print("bitnet inference output:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/integration/BitNet/vllm_workspace/utils.py b/integration/BitNet/vllm_workspace/utils.py new file mode 100644 index 000000000..0d5e304d8 --- /dev/null +++ b/integration/BitNet/vllm_workspace/utils.py @@ -0,0 +1,65 @@ +from typing import Dict, List, Tuple + +TokensText = Tuple[List[int], str] + + +def check_outputs_equal(outputs_0_lst: List[TokensText], + outputs_1_lst: List[TokensText], name_0: str, + name_1: str): + """ + Compare the two sequences generated by different models, + which should be equal. 
+ """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0 = outputs_0 + output_ids_1, output_str_1 = outputs_1 + + assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + +TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] + + +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], + outputs_1_lst: List[TokensTextLogprobs], name_0: str, + name_1: str): + """ + Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, + output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_id_1 in logprobs_0[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + # Break out since sequences will now diverge. + break