From 5d14d3161e10bb64e4f09c7e4a98a2859b792817 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 5 Aug 2024 21:29:26 +0800 Subject: [PATCH 1/7] [Dev] Refactor the weight transformation to support upcoming stage3 transform (#130) * Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * bug fix in test * lint fix. * test cuda i4 kernel * Refactor copyright notice in i4matmul.hpp * Refactor BitBLASLinear test module for improved readability and maintainability * refactor test as version below python 3.9 cannot handle int32 overflow. * format lint for test * Refactor test_int4b_fp16_convert.py for improved readability and maintainability * remove unused design file * move tile device from package to base * dummy impl for codegen * Refactor file structure for ladder_permutate module * Refactor backend class and fix typos in comments * Deep refactor Lib related code. * remove ci pull. * LintFix * refactor builder for whl build * Refactor TIRWrapper.wrap() method to include an assertion for the optimized module * Refactor lib_generator to set library and source paths * lint fix * BitNet vllm integration * chore: update codespell to version 2.3.0 * Lintfix * Bump version to 0.0.1.dev13 * lint fix * disable fast decoding [u]int4xint8 by default. * optimize from dict design in Hint * Implement SplitK * bitnet benchmark generation. * Add benchmark script for BitNet integration * AtomicAdd Support * LintFix * ci fix when 3rdparty tvm is initialized. * bug fix for setup * fix a bug in block reduce * typo fix * BUG Fix for block reduce. * Lint fix * Refactor block reduce schedule template * transform branch from bitblas to bitblas_tl * Fix subproject commit reference in 3rdparty/tvm * chore: update submodule branch from bitblas to bitblas_tl * force update config.cmake * Bug fix * Fix subproject commit reference in 3rdparty/cutlass * chore: Add submodule for cutlass library * update tl cutlass path * Refactor BitBLASLinear test module for improved readability and maintainability * format fix * Copy CUTLASS to the package directory * Refactor setup.py to include additional TVM header files * lint fix * bug fix * Refactor BitBLASLinear test module for improved readability and maintainability * Implement Matmul Benchmark Design * chore: Update BitBLAS Matmul benchmark script * lint fix * Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * lint fix * Benchmark bot test * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * int8 test case * Refactor compare_benchmark.py to handle missing benchmark results gracefully * ci fix * disable ci for test benchmark * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * remove cli installation * chore: Create virtual environment and install dependencies for benchmark * chore: Update benchmark workflow to include comparison step * Lint fix * upodate tvm cmmit * Imporve lower warp memory pass * Bug fix * Enhance to support warp schedule. * Enhance LOP3 Instructions * Enhance LOP3 Instructions * add test for stage3 propagate * implement propagate func * Stage3 Ladder Permutate integration * get_ladder_stage3_propagate * comments benchmark scirpts as the setting is too big * ci fix for benchmark * lint fix * chore: Update benchmark workflow to trigger on pull request comments * Add LDMatrix Transform 3 * Support GPTQ Test * Fuse BlockReduce Schedule * Support mma propagate 3 * Support MMA Propagate Stage 3 * Lint Fix * Merge block reduce for dequantze config. * fix codeql * chore: Update submodule reference to latest commit * chore: Disable common subexpression elimination in TIR passes * Lint Fix * 4bit related lop3 updates. * lint fix * gptq test fix * Fix for test * lint fix * lint fix * typofix * QuantCompress Test * chore: Refactor quant_compress_impl.py for readability and maintainability * Enhance docs to update latest works. * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * removed legacy operator * Refactor weight executors in Matmul class for improved readability and maintainability * LintFix * Fix GPTQ Repack with the latest weight transform * lint fix * bug fix for rescale dequantize * test fix * typo fix --- bitblas/__init__.py | 25 +- bitblas/gpu/matmul_mma_dequantize.py | 9 +- bitblas/module/__init__.py | 21 +- bitblas/ops/__init__.py | 3 +- bitblas/ops/general_matmul/__init__.py | 49 ++- bitblas/ops/ladder_permutate/__init__.py | 18 + bitblas/ops/lop3_permutate/__init__.py | 18 +- bitblas/ops/matmul.py | 276 --------------- bitblas/ops/matmul_dequantize.py | 320 ------------------ bitblas/ops/operator.py | 2 - integration/pytorch/bitblas_quant_linear.py | 2 +- testing/python/module/test_bitblas_linear.py | 8 +- .../operators/test_general_matmul_ops.py | 8 +- 13 files changed, 122 insertions(+), 637 deletions(-) delete mode 100644 bitblas/ops/matmul.py delete mode 100644 bitblas/ops/matmul_dequantize.py diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 91e88133..a1bc95f3 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -39,9 +39,10 @@ from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 -from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 from .module import Linear # noqa: F401 +import warnings +import functools import logging from tqdm import tqdm @@ -89,4 +90,26 @@ def _init_logger(): _init_logger() + +def deprecated(reason): + """ + This is a decorator which can be used to mark functions as deprecated. + It will result in a warning being emitted when the function is used. + """ + + def decorator(func): + + @functools.wraps(func) + def new_func(*args, **kwargs): + warnings.warn( + f"Call to deprecated function {func.__name__} ({reason}).", + category=DeprecationWarning, + stacklevel=2) + return func(*args, **kwargs) + + return new_func + + return decorator + + __version__ = "0.0.1.dev13" diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 4575fa36..2033b8f7 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -2264,10 +2264,11 @@ def get_idx(): lop3_intrin_info["compute"], ) # Assume the grouped K is the last dim of the scaling - grouped_k = sch.get(bf).reads[1].buffer.shape[-1] - # TODO(lei): This is a hack to get the loop extent - loop_extent = 8 if out_dtype == "float16" else 16 - sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + grouped_k = sch.get(bf).reads[1].buffer.shape[-1] + # TODO(lei): This is a hack to get the loop extent + loop_extent = 8 if out_dtype == "float16" else 16 + sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) import_source.append(lop3_intrin_info["c_source"]) def tensorize_init_store_compute(): diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index a3fbba21..f148097d 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -40,6 +40,24 @@ def unpack_qzeros(qzeros, bits): return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1) +def unpack_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> (bits * i)) + + # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303 + # NOTE: It appears that casting after the `unpacked_zeros + 1` is important. + return torch.bitwise_and(unpacked_weight, 2**bits - 1) + + class Linear(nn.Module): opt_M = [1, 16, 32, 64, 128, 256, 512] STORAGE_DTYPE = "int8" # assume int8 storage @@ -279,8 +297,9 @@ def load_and_transform_weight( def repack_from_gptq(self, gptq_module): # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) + intweight = unpack_qweight(qweight, self.bits).contiguous() if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda() self.qweight = qweight # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed. scales = gptq_module.scales.T.contiguous().view(self.torch_dtype) diff --git a/bitblas/ops/__init__.py b/bitblas/ops/__init__.py index 4fa45647..a132a83b 100644 --- a/bitblas/ops/__init__.py +++ b/bitblas/ops/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .operator import Operator, OperatorConfig # noqa: F401 -from .matmul import Matmul, MatmulConfig # noqa: F401 -from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 +from .general_matmul import Matmul, MatmulConfig # noqa: F401 from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401 from .quant_compress import QuantCompress, QuantCompressConfig # noqa: F401 diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 184da0b0..7d99d962 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -13,6 +13,7 @@ from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass from ..ladder_permutate import LadderPermutate, LadderPermutateConfig +from ..quant_compress import QuantCompress, QuantCompressConfig from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig import logging import torch @@ -292,6 +293,7 @@ def dispatch_tir(self, # create permutate_opertors self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning) self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning) + self.weight_compress = self._assign_weight_compress(target, enable_tuning) self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning) # create cpu weight executors self.input_executors = self._create_input_executors() @@ -338,11 +340,14 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): del enable_tuning if self.propagate_b: + # weight transform should be done in the unpacked level + # otherwise the bit trick should be applied and that is + # too complex to be implemented in the ladder permutation. ladder_permutate_config = LadderPermutateConfig( M=self.N, N=self.K, datatype=self.A_dtype, - dequantize_bits=self.bit, + dequantize_bits=-1, storage_dtype=self.storage_dtype, propagate_kind="B", transpose_matrix=self.layout == "nt", @@ -354,6 +359,25 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): ) return None + def _assign_weight_compress(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning + + require_compress: bool = self.bit in [1, 2, 4] + if require_compress: + quant_compress_config = QuantCompressConfig( + M=self.N, + N=self.K, + input_dtype=self.storage_dtype, + storage_dtype=self.storage_dtype, + dequantize_bits=self.bit) + return QuantCompress( + config=quant_compress_config, + target=tvm.target.Target("llvm"), + ) + return None + def _assign_lop3_permutate(self, target: Target, enable_tuning: bool): # unused variables del target @@ -381,10 +405,12 @@ def _create_input_executors(self): def _create_weight_executors(self): weight_executors = OPExecutorCPU() - if self.fast_decoding: - weight_executors.append(self.lop3_permutate) if self.propagate_b is not TransformKind.NonTransform: weight_executors.append(self.ladder_permutate_b) + if self.weight_compress is not None: + weight_executors.append(self.weight_compress) + if self.fast_decoding: + weight_executors.append(self.lop3_permutate) return weight_executors def _select_implementation(self): @@ -452,10 +478,6 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): return self.weight_transform(weight.cpu()).cuda().contiguous() return weight - from bitblas.quantization import general_compress - import torch - import numpy as np - source_format, bit = self.source_format, self.bit # Process integer source format @@ -464,20 +486,13 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): assert not self.with_zeros, "zeros should be False for int source format" maxq = 2**(bit - 1) # Clamp weight values to be within the quantizable range and adjust - weight = torch.clamp(weight, -maxq, maxq).int() + maxq + weight = torch.clamp(weight, -maxq, maxq).char() + maxq elif source_format in ["fp_e5m2", "fp_e4m3"]: weight = weight.view(torch.int8) - weight = weight.int() else: # For non-integer formats, simply convert weights to integers - weight = weight.int() - - np_storage_dtype = getattr(np, self.storage_dtype) - - weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) - - weight = torch.from_numpy(weight).cuda().contiguous() + # And assume weight is in the range of [-128, 127] for int8 + weight = weight.char() # Apply an optional weight transformation if specified if self.weight_transform is not None: diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index 6644705c..d09ee6da 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -5,6 +5,7 @@ from ..operator import Operator from .ladder_permutate_impl import select_implementation from dataclasses import dataclass +import torch @dataclass(frozen=True) @@ -57,6 +58,23 @@ def _select_implementation(self): target_instruction=self.target_instruction, ) + def forward(self, inp, out=None): + if out is None: + out_shape, out_dtype = self.retrieve_output_shape() + out = torch.zeros(out_shape, dtype=out_dtype).to(inp.device) + self.torch_func(inp, out) + return out + + def retrieve_output_shape(self): + """ + Retrieve the output shape of the operator + """ + func = self.prim_func + param = func.params[-1] + assert param in func.buffer_map, f"param {param} not in buffer_map" + arg = func.buffer_map[param] + return [int(i) for i in arg.shape], getattr(torch, arg.dtype) + @property def M(self): return self.config.M diff --git a/bitblas/ops/lop3_permutate/__init__.py b/bitblas/ops/lop3_permutate/__init__.py index 10c452b3..19c4b0ee 100644 --- a/bitblas/ops/lop3_permutate/__init__.py +++ b/bitblas/ops/lop3_permutate/__init__.py @@ -42,11 +42,23 @@ def _select_implementation(self): dequantize_bits=self.dequantize_bits, ) - def forward(self, weight, res): + def forward(self, inp, out=None): + out_shape = inp.shape + out_dtype = inp.dtype + if out is None: + # lop3 transform does not change the shape of the input tensor + out = torch.zeros(out_shape, dtype=out_dtype) # reinterpret the input tensor to int32 format - args = [arg.view(torch.int32) for arg in [weight, res]] + shape_2dim = self.retrieve_2d_weight_shape() + args = [arg.view(inp.dtype).view(shape_2dim).view(torch.int32) for arg in [inp, out]] self.torch_func(*args) - return args[-1].view(weight.dtype) + return args[-1].view(out_dtype).view(out_shape) + + def retrieve_2d_weight_shape(self): + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + elems_per_byte = storage_nbit // self.dequantize_bits + weight_shape = (self.M, self.N // elems_per_byte) + return weight_shape @property def M(self): diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py deleted file mode 100644 index e515a264..00000000 --- a/bitblas/ops/matmul.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -import numpy as np -from tvm.target import Target -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from typing import List, Union, Optional, Any, Tuple -from .operator import Operator, TransformKind -from .impl.matmul_impl import select_implementation -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class TransformExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - with_bias: bool = False - # layout of matrix A and B - # "nn": C[i, j] = A[i, k] * B[k, j] - # "nt": C[i, j] = A[i, k] * B[j, k] - layout: str = "nt" - # weight transformation kind of matrix A - propagate_a: TransformKind = TransformKind.NonTransform - # weight transformation kind of matrix B - propagate_b: TransformKind = TransformKind.NonTransform - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class Matmul(Operator): - - def __init__( - self, - config: MatmulConfig, - name: str = "matmul", - target: Union[str, Target] = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.update_func(self.prim_func.with_attrs({"opt_shapes": self.dynamic_range})) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - assert (self.propagate_a is - TransformKind.NonTransform), "Currently only support NonTransform for input" - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="B", - transpose_matrix=(self.layout == "nt"), - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - input_executors = TransformExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - - self.input_executors = input_executors - - weight_executors = TransformExecutorCPU() - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - with_bias=self.with_bias, - layout=self.layout, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def _profile_latency_with_dynamic_range(self) -> List: - func = self.prim_func_mod["main"] - device = self.arch.device - - def var_warpper(v, m): - if isinstance(v, tvm.tir.Var): - assert "opt_shapes" in func.attrs - assert v.name in func.attrs["opt_shapes"] - return m - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - benchmark_latencies = [] - for m in self.dynamic_range["m"]: - profile_tensors = [] - for param in func.params: - if param not in func.buffer_map: - # in case of dynamic symbolic may in params - continue - arg = func.buffer_map[param] - profile_tensors.append( - tvm.nd.array( - np.random.uniform(0, 1, - [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), - device=device, - )) - latency = self.time_evaluator(*profile_tensors).mean * 1e3 - benchmark_latencies.append({"m": m, "latency": latency}) - # ms - return benchmark_latencies - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def layout(self): - return self.config.layout - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None - - -__all__ = ["Matmul", "MatmulConfig"] diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py deleted file mode 100644 index 6971547b..00000000 --- a/bitblas/ops/matmul_dequantize.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from tvm.target import Target -from bitblas.base.arch.cuda import CUDA -from typing import Any, List, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind -from .impl.matmul_dequantize_impl import select_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class OPExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulWeightOnlyDequantizeConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - bit: int = 4 - storage_dtype: str = "int8" - # documents for source_format: - # the format of the source data, which can be "int", "uint", "fp", "nf" - # "int": dequantize_weight = (target)((int)(quantize_weight - fixed_zero_point)) * scale - # where the fixed_zero_point is 2^(bit - 1) - 1 - # "uint": dequantize_weight = (target)((uint)(quantize_weight - zero_point)) * scale - # where the zero_point is manually set by zeros tensor - # "fp": dequantize_weight = (quantize_weight - zero_point) * scale - # "nf": dequantize_weight = (lut[quantize_weight] - zero_point) * scale - source_format: Literal["int", "uint", "fp", "nf"] = "int" - with_scaling: bool = False - with_zeros: bool = False - group_size: int = -1 - fast_decoding: bool = False - with_bias: bool = False - propagate_a: TransformKind = TransformKind.NonTransform - propagate_b: TransformKind = TransformKind.NonTransform - layout: str = "nt" - # documents for zeros_mode: - # original: target = (dequantize_weight - zero_point) * scale - # rescale: target = dequantize_weight * scale - zero_point - # quantized: target = (dequantize_weight - dequantize_zeros) * scale - # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. - zeros_mode: Literal["original", "rescale", "quantized"] = "original" - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class MatmulWeightOnlyDequantize(Operator): - - def __init__( - self, - config: MatmulWeightOnlyDequantizeConfig, - name: str = "matmul_weight_only_dequantize", - target: Target = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - self.arch = CUDA(target) - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range}) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - propagate_kind="B", - transpose_matrix=self.layout == "nt", - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - if self.fast_decoding: - lop3_permutate_config = LOP3PermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - ) - self.lop3_permutate = LOP3Permutate( - config=lop3_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.lop3_permutate = None - - input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors - - weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: - weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - bit=self.bit, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - layout=self.layout, - zeros_mode=self.zeros_mode, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def retrieve_weight_shape(self): - return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def bit(self): - return self.config.bit - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def source_format(self): - return self.config.source_format - - @property - def with_scaling(self): - return self.config.with_scaling - - @property - def with_zeros(self): - return self.config.with_zeros - - @property - def group_size(self): - return self.config.group_size - - @property - def fast_decoding(self): - return self.config.fast_decoding - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def layout(self): - return self.config.layout - - @property - def zeros_mode(self): - return self.config.zeros_mode - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 29d38430..f6fa4cca 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -13,7 +13,6 @@ from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.arch import get_arch -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from bitblas.builder.wrapper import TIRWrapper from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass @@ -371,7 +370,6 @@ def is_none(self): def forward(self, weight): inputs = [weight] for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) inputs = [op.forward(*inputs)] return inputs[-1] diff --git a/integration/pytorch/bitblas_quant_linear.py b/integration/pytorch/bitblas_quant_linear.py index c0cdac61..6e6610c1 100644 --- a/integration/pytorch/bitblas_quant_linear.py +++ b/integration/pytorch/bitblas_quant_linear.py @@ -182,7 +182,7 @@ def pack(self, linear, scales, zeros=None): (w[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.contiguous() - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) # quantize to 4bit qw_np = general_compress(intweight, source_bits=self.bits, storage_dtype=np.int8) # do interleave for fast type conversion diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index f329a146..3adacaa8 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -98,7 +98,7 @@ def correctness_weight_only_dequantize( inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: @@ -109,15 +109,13 @@ def correctness_weight_only_dequantize( ref_result = ref_result + bias_tensor with torch.no_grad(): - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] permuted_inputs.append(inputs[0]) if linear_bitblas.bitblas_matmul.weight_transform is not None: permuted_inputs.append( - linear_bitblas.bitblas_matmul.weight_transform(qw_torch.cpu()).cuda()) + linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda()) else: - permuted_inputs.append(qw_torch) + permuted_inputs.append(inputs[1]) linear_bitblas.qweight.data = permuted_inputs[-1].clone() if with_scaling: if group_size == -1: diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index 62808e2a..354914d2 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -155,7 +155,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) + intweight = intweight.cpu().to(torch.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: @@ -165,14 +165,12 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) if with_bias: ref_result = ref_result + bias - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] permuted_inputs.append(inputs[0]) if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) + permuted_inputs.append(matmul.weight_transform(intweight.cpu()).cuda()) else: - permuted_inputs.append(qw_torch) + permuted_inputs.append(intweight) if with_scaling: if group_size == -1: group_size = K From 2e60d2bf98d0a777cbdd0ebcc9d5032746608123 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 5 Aug 2024 22:17:30 +0800 Subject: [PATCH 2/7] [Dev] Bring Block Reduction into our seach space and policy (#132) * Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * bug fix in test * lint fix. * test cuda i4 kernel * Refactor copyright notice in i4matmul.hpp * Refactor BitBLASLinear test module for improved readability and maintainability * refactor test as version below python 3.9 cannot handle int32 overflow. * format lint for test * Refactor test_int4b_fp16_convert.py for improved readability and maintainability * remove unused design file * move tile device from package to base * dummy impl for codegen * Refactor file structure for ladder_permutate module * Refactor backend class and fix typos in comments * Deep refactor Lib related code. * remove ci pull. * LintFix * refactor builder for whl build * Refactor TIRWrapper.wrap() method to include an assertion for the optimized module * Refactor lib_generator to set library and source paths * lint fix * BitNet vllm integration * chore: update codespell to version 2.3.0 * Lintfix * Bump version to 0.0.1.dev13 * lint fix * disable fast decoding [u]int4xint8 by default. * optimize from dict design in Hint * Implement SplitK * bitnet benchmark generation. * Add benchmark script for BitNet integration * AtomicAdd Support * LintFix * ci fix when 3rdparty tvm is initialized. * bug fix for setup * fix a bug in block reduce * typo fix * BUG Fix for block reduce. * Lint fix * Refactor block reduce schedule template * transform branch from bitblas to bitblas_tl * Fix subproject commit reference in 3rdparty/tvm * chore: update submodule branch from bitblas to bitblas_tl * force update config.cmake * Bug fix * Fix subproject commit reference in 3rdparty/cutlass * chore: Add submodule for cutlass library * update tl cutlass path * Refactor BitBLASLinear test module for improved readability and maintainability * format fix * Copy CUTLASS to the package directory * Refactor setup.py to include additional TVM header files * lint fix * bug fix * Refactor BitBLASLinear test module for improved readability and maintainability * Implement Matmul Benchmark Design * chore: Update BitBLAS Matmul benchmark script * lint fix * Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * lint fix * Benchmark bot test * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * int8 test case * Refactor compare_benchmark.py to handle missing benchmark results gracefully * ci fix * disable ci for test benchmark * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * remove cli installation * chore: Create virtual environment and install dependencies for benchmark * chore: Update benchmark workflow to include comparison step * Lint fix * upodate tvm cmmit * Imporve lower warp memory pass * Bug fix * Enhance to support warp schedule. * Enhance LOP3 Instructions * Enhance LOP3 Instructions * add test for stage3 propagate * implement propagate func * Stage3 Ladder Permutate integration * get_ladder_stage3_propagate * comments benchmark scirpts as the setting is too big * ci fix for benchmark * lint fix * chore: Update benchmark workflow to trigger on pull request comments * Add LDMatrix Transform 3 * Support GPTQ Test * Fuse BlockReduce Schedule * Support mma propagate 3 * Support MMA Propagate Stage 3 * Lint Fix * Merge block reduce for dequantze config. * fix codeql * chore: Update submodule reference to latest commit * chore: Disable common subexpression elimination in TIR passes * Lint Fix * 4bit related lop3 updates. * lint fix * gptq test fix * Fix for test * lint fix * lint fix * typofix * QuantCompress Test * chore: Refactor quant_compress_impl.py for readability and maintainability * Enhance docs to update latest works. * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * removed legacy operator * Refactor weight executors in Matmul class for improved readability and maintainability * LintFix * Fix GPTQ Repack with the latest weight transform * lint fix * bug fix for rescale dequantize * test fix * typo fix * lint fix --- 3rdparty/tvm | 2 +- bitblas/base/roller/policy/tensorcore.py | 158 ++++++++++-------- bitblas/gpu/matmul_analysis.py | 6 +- .../tirscript/matmul_dequantize_impl.py | 6 +- 4 files changed, 93 insertions(+), 79 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c441882e..6daecacc 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c441882e2372deeb33d0eaefd62a133d482ac669 +Subproject commit 6daecacc73c8c8fdea1b9732891e1d4a5ebbf818 diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index 9e6fff9e..468498fb 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -117,83 +117,92 @@ def _check_small_tile(td: TileDict): return True return False - if not _check_small_tile(td): - return None + if _check_small_tile(td): + + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + if any([v == [] for v in all_steps.values()]): + return rstep + + def _shared_memory_usage(td: TileDict): + return node.footprint(td.output_tile, new_rstep_map, + td.tensor_strides_map[node]) + + def _score(rstep_id): + rstep = { + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis + } + score = 0 + shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, input_buffer in enumerate(input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = { + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis + } + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = { + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] + for k in node.raxis + } + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = _shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = { + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis + } + return rstep - smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) - rstep_map = td.rstep_map.copy() + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map) + rstep_map = rstep - def _optimize(node, rstep): - all_steps = self.get_node_reduce_step_candidates(node) - # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] - for k in all_steps: - all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) - if any([v == [] for v in all_steps.values()]): - return rstep + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) - def _shared_memory_usage(td: TileDict): - return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) + if self.block_reduction_depth is not None: - def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } - score = 0 - shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) - input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) - for i, input_buffer in enumerate(input_buffers): - score += coalesced_factor(shape[i], input_buffer.shape) - return score - - def _enlarge(rstep_id): - candidates = [] - for ax in rstep_id: - if rstep_id[ax] + 1 == len(all_steps[ax]): - continue - r = rstep_id.copy() - r[ax] += 1 - candidates.append((r, _score(r))) - if len(candidates) == 0: - return None - return max(candidates, key=lambda x: x[1])[0] - - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } - new_rstep_map = rstep_map.copy() - while True: - new_rstep_id = _enlarge(cur_rstep_id) - if new_rstep_id is None: - break - new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis - } - old_rstep_map = td.rstep_map - td.rstep_map = new_rstep_map - smem_usage, _ = _shared_memory_usage(td) - td.rstep_map = old_rstep_map - if smem_usage > smem_limit: - break - else: - cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } - return rstep + def _expand_with_tags(rstep): + new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()} + return new_rstep + + rstep_map = td.rstep_map.copy() + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _expand_with_tags(rstep_map) + rstep_map = rstep + td.rstep_map = rstep_map - for node in self.ordered_nodes: - if len(node.raxis) > 0: - rstep = _optimize(node, rstep_map) - rstep_map = rstep - - # if is_block_reduction: - # # If block reduction, we should constrain the max value is 64 - # # Otherwise it will introduce an issue of cuda invalid args. - # MAX_REDUCE_K = 64 - # for k in rstep_map: - # rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) - td.rstep_map = rstep_map - td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) return def get_node_reduce_step_candidates(self, node): @@ -318,12 +327,15 @@ def _score(node, thread): # small is better # smem capacity # TODO: This is a dummy mul which avoid reusing some shared memory. # Should be removed in the future. - if td.smem_cost > (self.arch.smem_cap * 1.3): + if td.smem_cost > (self.arch.smem_cap): info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ " use dynamic shared memory." logger.info(info_message) codegen_dict.shared_scope = "shared.dyn" + # Or assume we always use shared memory + # codegen_dict.shared_scope = "shared.dyn" + codegen_dict.complete_config(node) codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) codegen_dict.arch = self.arch diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 210c560a..1d0889fa 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -622,14 +622,16 @@ def check_last_trait(region: List[Range]): # Analysis Block Reduction Optimization # Currently, we only support block reduction depth 2 for small M # When the func is a dequantize like ops, we should consider the M + require_block_reduce = False if hasattr(func.attrs, "dequantize_info"): for arg in func.params: inp_shape = func.buffer_map[arg].shape M = inp_shape[0] if isinstance(M, tir.IntImm) and M <= 128: - tags["block_reduction_depth"] = 2 + require_block_reduce = True break - + if require_block_reduce and check_sm_version(target.arch) == 80: + tags["block_reduction_depth"] = 2 return tags (main_block,) = reduction_blocks diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py index 17d22dcf..a86f6469 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -515,7 +515,7 @@ def matmul_nt_dequantize_b_propagate_b( fast_decoding=False, with_bias=False, zeros_mode="original", - transform_kind: Union[int, TransformKind] = TransformKind.NonTransform, + transform_kind: Union[int, TransformKind] = TransformKind.IntraWarpTransform, ): if isinstance(transform_kind, int): transform_kind = TransformKind(transform_kind) @@ -699,8 +699,8 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( fast_decoding=False, with_bias=False, zeros_mode="original", - transform_kind_input: Union[int, TransformKind] = TransformKind.NonTransform, - transform_kind_weight: Union[int, TransformKind] = TransformKind.NonTransform, + transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform, + transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform, ): if isinstance(transform_kind_input, int): transform_kind_input = TransformKind(transform_kind_input) From c6cc01e1f92cb63ed6e2ebb9472d0da14e42fcc9 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:02:38 +0800 Subject: [PATCH 3/7] Fix retrieve head commit in benchmark (#134) --- .github/workflows/benchmark.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 235b8686..013345f6 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -62,6 +62,7 @@ jobs: uses: actions/checkout@v2 with: fetch-depth: 0 + ref: ${{ github.event.pull_request.head.ref }} - name: Get PR branch commit ID id: get_pr_commit From 7c6bccf9c2b64d9ff0b709de2d49da994a926ba4 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:05:38 +0800 Subject: [PATCH 4/7] [Integration] Upload tutorial for making a bitnet ckpt for vLLM (#135) * fix install with absolute path * efficient inference with torch compile * update vllm ckpt tutorial for bitnet --- install.sh | 2 +- integration/BitNet/README.md | 33 + integration/BitNet/eval_correctness.py | 4 - .../BitNet/{ => maint}/create_bitblas_ckpt.py | 10 +- .../generate_bitnet_model_bitblas_format.sh | 27 + .../generate_bitnet_model_native_format.sh | 27 + integration/BitNet/maint/quant_config.json | 10 + integration/BitNet/utils_quant.py | 13 +- integration/BitNet/vllm_workspace/conftest.py | 625 ++++++++++++++++++ .../inference_with_compress_format.py | 45 ++ .../inference_with_native_format.py | 62 ++ integration/BitNet/vllm_workspace/utils.py | 65 ++ 12 files changed, 913 insertions(+), 10 deletions(-) rename integration/BitNet/{ => maint}/create_bitblas_ckpt.py (90%) create mode 100755 integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh create mode 100755 integration/BitNet/maint/generate_bitnet_model_native_format.sh create mode 100644 integration/BitNet/maint/quant_config.json create mode 100644 integration/BitNet/vllm_workspace/conftest.py create mode 100644 integration/BitNet/vllm_workspace/inference_with_compress_format.py create mode 100644 integration/BitNet/vllm_workspace/inference_with_native_format.py create mode 100644 integration/BitNet/vllm_workspace/utils.py diff --git a/install.sh b/install.sh index b7b38962..db3b3682 100755 --- a/install.sh +++ b/install.sh @@ -46,7 +46,7 @@ fi echo "Download and extraction completed successfully." -LLVM_CONFIG_PATH="${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config" +LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)" echo "LLVM config path: $LLVM_CONFIG_PATH" # clone and build tvm diff --git a/integration/BitNet/README.md b/integration/BitNet/README.md index 8fa09f76..f1e82625 100644 --- a/integration/BitNet/README.md +++ b/integration/BitNet/README.md @@ -2,8 +2,41 @@ license: mit --- +## Latest News + +- 08/09/2024 ✨: We provide a more efficient implementation for bitnet with vLLM, which should use special model checkpoints, to make the ckpt, please reach []. + This is a BitBLAS Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. +## Make Checkpoints for vLLM + +We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which is used to make a checkpoint with fp16 uncompressed metaadta, the main difference with the original checkpoint is the `quant_config.json`, which allow vLLM to load the model and execute with a quant extension. + +```bash +# move to the integration directory +cd /root/to/BitBLAS/integration/BitNet +# make the checkpoint +./maint/generate_bitnet_model_native_format.sh +# the output ckpy will be saved in the `./models/bitnet_b1_58-3B` directory +``` + +The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata, which can avoid the online dequantize sage for the profiling of vLLM, which lead to more efficient memory utilization. + +```bash +./maint/generate_bitnet_model_bitblas_format.sh ./models/bitnet_3B_1.58bit ./models/bitnet_3B_1.58bit_bitblas +# the output ckpy will be saved in the `./models/bitnet_b1_58-3B_bitblas` directory +``` + +Finnaly, you can use the ckpt in vLLM with: + +```bash +cd vllm_workspace +# inference with the ckpt with fp16 uncompressed metadata +python3 inference_with_native_format.py +# inference with the ckpt with BitBLAS compressed metadata +python3 inference_with_bitblas_format.py +``` + ## BitBLAS Results ### Performance diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index cef89313..4017a6c1 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -18,9 +18,6 @@ def generate_text(model, tokenizer, prompt, max_length=100): seq_length = input_ids.size(1) position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - # position_embeddings = model.embed_positions(position_ids) - # cos = position_embeddings[:, :, 0::2].cos() - # sin = position_embeddings[:, :, 1::2].sin() generation_config = GenerationConfig( max_length=max_length, @@ -32,7 +29,6 @@ def generate_text(model, tokenizer, prompt, max_length=100): start_time = time.time() output_ids = model.generate(input_ids, generation_config=generation_config) - # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin) end_time = time.time() generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) diff --git a/integration/BitNet/create_bitblas_ckpt.py b/integration/BitNet/maint/create_bitblas_ckpt.py similarity index 90% rename from integration/BitNet/create_bitblas_ckpt.py rename to integration/BitNet/maint/create_bitblas_ckpt.py index d443b2e2..d71f5958 100644 --- a/integration/BitNet/create_bitblas_ckpt.py +++ b/integration/BitNet/maint/create_bitblas_ckpt.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import argparse import torch import bitblas from modeling_bitnet import BitnetForCausalLM @@ -17,8 +18,13 @@ torch.set_grad_enabled(False) bitblas.set_log_level("INFO") -model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" -saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") +parser = argparse.ArgumentParser() +parser.add_argument("--model_name_or_path", type=str, default="BitBLASModel/open_llama_3b_1.58bits") +parser.add_argument("--saved_model_path", type=str, default=None) +args = parser.parse_args() + +model_name_or_path = args.model_name_or_path +saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path def generate_text(model, tokenizer, prompt, max_length=100): diff --git a/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh new file mode 100755 index 00000000..aea62db9 --- /dev/null +++ b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# retrieve the native model input and saved model directory +MODEL_DIR=$1 +SAVED_MODEL_DIR=$2 + +# check if the model directory exists +if [ ! -d "$MODEL_DIR" ]; then + echo "Model directory does not exist!" + exit 1 +fi + +# if the saved model directory does not exist, create it +# if SAVED_MODEL_DIR is not provided, we do not pass it to the script +if [ -z "$SAVED_MODEL_DIR" ]; then + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR +else + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR +fi + +# get the realpath of the saved model directory +SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR) + +echo "Model has been converted and save to $SAVED_MODEL_DIR" diff --git a/integration/BitNet/maint/generate_bitnet_model_native_format.sh b/integration/BitNet/maint/generate_bitnet_model_native_format.sh new file mode 100755 index 00000000..75bac8a7 --- /dev/null +++ b/integration/BitNet/maint/generate_bitnet_model_native_format.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# require git lfs +if ! command -v git-lfs &> /dev/null; then + echo "Please install git-lfs first by running 'sudo apt install git-lfs'" + exit 1 +fi + +mkdir -p models + +cd models + +# download the model +git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B bitnet_3B_1.58bits --depth 1 + +# copy quantized config into the model directory +cp ../maint/quant_config.json bitnet_3B_1.58bits + +# get the realpath of the model directory +MODEL_DIR=$(realpath bitnet_3B_1.58bits) + +cd .. + +echo "Model has been converted and save to $MODEL_DIR" diff --git a/integration/BitNet/maint/quant_config.json b/integration/BitNet/maint/quant_config.json new file mode 100644 index 00000000..e2b24123 --- /dev/null +++ b/integration/BitNet/maint/quant_config.json @@ -0,0 +1,10 @@ +{ + "bits": 2, + "desc_act": false, + "static_groups": false, + "sym": true, + "lm_head": false, + "model_name_or_path": "1bitLLM/bitnet_b1_58-3B", + "quant_method": "bitnet", + "checkpoint_format": "bitnet" +} \ No newline at end of file diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index d9cc25ae..cb0c0f50 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -138,6 +138,7 @@ def weight_quant(weight): result = (weight * s).round().clamp(-1, 1) return result.type(torch.int8) + @torch.compile def activation_quant(self, x, num_bits=8): x = x.float() Qn = -(2**(num_bits - 1)) @@ -146,6 +147,13 @@ def activation_quant(self, x, num_bits=8): result = (x * s).round().clamp(Qn, Qp) return result.type(torch.int8) + @torch.compile + def post_quant_process(self, input, si, sw): + out = input / si + out = out / sw + out = out.half() + return out + # for the correctness evaluation. def native_forward(self, input): quant_input = (input + (activation_quant(input, self.input_bits) - input).detach()) @@ -184,9 +192,8 @@ def forward(self, input): Qp = self.Qp si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) # if / (si * sw) it will inf in some cases - out = fp32_out / si - out = out / sw - out = out.half() + out = self.post_quant_process(fp32_out, si, sw) + if self.bias is not None: out += self.bias.view(1, -1).expand_as(out) return out diff --git a/integration/BitNet/vllm_workspace/conftest.py b/integration/BitNet/vllm_workspace/conftest.py new file mode 100644 index 00000000..fd5e162a --- /dev/null +++ b/integration/BitNet/vllm_workspace/conftest.py @@ -0,0 +1,625 @@ +import contextlib +import gc +import os +import sys +from collections import UserList +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoTokenizer, + BatchEncoding, +) + +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset +from vllm.config import TokenizerPoolConfig +from vllm.distributed import ( + destroy_distributed_environment, + destroy_model_parallel, +) +from vllm.inputs import TextPrompt +from vllm.logger import init_logger +from vllm.sequence import SampleLogprobs +from vllm.utils import cuda_device_count_stateless, is_cpu + +logger = init_logger(__name__) + +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] +_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] + + +def _read_prompts(filename: str) -> List[str]: + with open(filename, "r") as f: + prompts = f.readlines() + return prompts + + +class _ImageAssetPrompts(TypedDict): + stop_sign: str + cherry_blossom: str + + +if sys.version_info < (3, 9): + # UserList cannot be subscripted + class _ImageAssetsBase(UserList): + pass + +else: + + class _ImageAssetsBase(UserList[ImageAsset]): + pass + + +class _ImageAssets(_ImageAssetsBase): + + def __init__(self) -> None: + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) + + def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: + """ + Convenience method to define the prompt for each test image. + + The order of the returned prompts matches the order of the + assets when iterating through this object. + """ + return [prompts["stop_sign"], prompts["cherry_blossom"]] + + +IMAGE_ASSETS = _ImageAssets() +"""Singleton instance of :class:`_ImageAssets`.""" + + +def cleanup(): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + if not is_cpu(): + torch.cuda.empty_cache() + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. + """ + + if request.node.get_closest_marker("skip_global_cleanup"): + return False + + return True + + +@pytest.fixture(autouse=True) +def cleanup_fixture(should_do_global_cleanup_after_test: bool): + yield + if should_do_global_cleanup_after_test: + cleanup() + + +@pytest.fixture +def example_prompts() -> List[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture +def example_long_prompts() -> List[str]: + prompts = [] + for filename in _LONG_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture(scope="session") +def image_assets() -> _ImageAssets: + return IMAGE_ASSETS + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, +} + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) + + +class HfRunner: + + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + return input.to("cuda") + else: + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + is_embedding_model: bool = False, + is_vision_model: bool = False, + is_sparseml_model: bool = False, + ) -> None: + assert dtype in _STR_DTYPE_TO_TORCH_DTYPE + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + + if is_embedding_model: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = self.wrap_device( + SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype) + ) + else: + if is_vision_model: + auto_cls = AutoModelForVision2Seq + elif is_sparseml_model: + from sparseml.transformers import SparseAutoModelForCausalLM + + auto_cls = SparseAutoModelForCausalLM + else: + auto_cls = AutoModelForCausalLM + + model_kwargs = model_kwargs if model_kwargs is not None else {} + self.model = self.wrap_device( + auto_cls.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + **model_kwargs, + ) + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for " + "model %s. Using tokenizer instead.", + model_name, + ) + self.processor = self.tokenizer + + def generate( + self, + prompts: List[str], + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images: + assert len(prompts) == len(images) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output_ids = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + **kwargs, + ) + output_str = self.processor.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output_ids = output_ids.cpu().tolist() + outputs.append((output_ids, output_str)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + **kwargs, + ) + + return [ + (output_ids[0], output_str[0]) for output_ids, output_str in outputs + ] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + ) + for i in range(len(outputs)): + output_ids, output_str = outputs[i] + for j in range(len(output_ids)): + output_ids[j] = [ + x for x in output_ids[j] if x != self.tokenizer.pad_token_id + ] + outputs[i] = (output_ids, output_str) + return outputs + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[List[torch.Tensor]]: + all_logprobs: List[List[torch.Tensor]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + seq_logprobs: List[torch.Tensor] = [] + for hidden_states in output.hidden_states: + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if self.model.get_output_embeddings().bias is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze( + 0 + ) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + all_logprobs.append(seq_logprobs) + return all_logprobs + + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + input_ids = inputs.input_ids + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + + seq_logprobs: List[torch.Tensor] = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if ( + getattr(self.model.get_output_embeddings(), "bias", None) + is not None + ): + logits += self.model.get_output_embeddings().bias.unsqueeze( + 0 + ) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] + + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner + + +class VllmRunner: + + def __init__( + self, + model_name: str, + tokenizer_name: Optional[str] = None, + # Use smaller max model length, otherwise bigger model cannot run due + # to kv cache size limit. + max_model_len: int = 1024, + dtype: str = "half", + disable_log_stats: bool = True, + tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = False, + swap_space: int = 4, + enforce_eager: bool = False, + **kwargs, + ) -> None: + self.model = LLM( + model=model_name, + tokenizer=tokenizer_name, + trust_remote_code=True, + dtype=dtype, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate( + inputs, sampling_params=sampling_params + ) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for req_output in req_outputs: + prompt_str = req_output.prompt + prompt_ids = req_output.prompt_token_ids + req_sample_output_ids: List[List[int]] = [] + req_sample_output_strs: List[str] = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append(prompt_str + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) + return outputs + + def generate_w_logprobs( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + assert sampling_params.logprobs is not None + + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate( + inputs, sampling_params=sampling_params + ) + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str]]: + greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs = self.generate(prompts, greedy_params, images=images) + return [ + (output_ids[0], output_str[0]) for output_ids, output_str in outputs + ] + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs + ) + outputs = self.generate_w_logprobs( + prompts, greedy_logprobs_params, images=images + ) + + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + beam_search_params = SamplingParams( + n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ) + outputs = self.generate(prompts, beam_search_params) + return outputs + + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig( + pool_size=1, pool_type="ray", extra_config={} + ) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog + + +@pytest.fixture(scope="session") +def num_gpus_available(): + """Get number of GPUs without initializing the CUDA context + in current process.""" + + return cuda_device_count_stateless() diff --git a/integration/BitNet/vllm_workspace/inference_with_compress_format.py b/integration/BitNet/vllm_workspace/inference_with_compress_format.py new file mode 100644 index 00000000..45426d65 --- /dev/null +++ b/integration/BitNet/vllm_workspace/inference_with_compress_format.py @@ -0,0 +1,45 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. +""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) + +ckpt_path = os.path.join(current_dir, "../models/bitnet_3b_1.58bits_bitblas") +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitblas", + enforce_eager=True, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy( + ["Hi, tell me about microsoft?"], max_tokens=1024 + ) + print("bitnet inference:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/integration/BitNet/vllm_workspace/inference_with_native_format.py b/integration/BitNet/vllm_workspace/inference_with_native_format.py new file mode 100644 index 00000000..07aefeec --- /dev/null +++ b/integration/BitNet/vllm_workspace/inference_with_native_format.py @@ -0,0 +1,62 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. +""" + +from conftest import VllmRunner +import os +import argparse + + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) +ckpt_path = os.path.join(current_dir, "../models/bitnet_3b_1.58bits") + +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path + +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitnet", + gpu_memory_utilization=0.5, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy( + ["Hi, tell me about microsoft?"], max_tokens=128 + ) + print("bitnet inference output:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) + +# with VllmRunner( +# "BitBLASModel/open_llama_3b_1.58bits_bitblas", +# dtype="half", +# quantization="bitblas", +# enforce_eager=True, +# ) as bitnet_model: +# torch.cuda.profiler.start() +# bitbnet_outputs = bitnet_model.generate_greedy( +# ["Hi, tell me about microsoft?"], max_tokens=1024 +# ) +# torch.cuda.profiler.stop() +# print("bitnet:") +# print(bitbnet_outputs[0][0]) +# print(bitbnet_outputs[0][1]) diff --git a/integration/BitNet/vllm_workspace/utils.py b/integration/BitNet/vllm_workspace/utils.py new file mode 100644 index 00000000..0d5e304d --- /dev/null +++ b/integration/BitNet/vllm_workspace/utils.py @@ -0,0 +1,65 @@ +from typing import Dict, List, Tuple + +TokensText = Tuple[List[int], str] + + +def check_outputs_equal(outputs_0_lst: List[TokensText], + outputs_1_lst: List[TokensText], name_0: str, + name_1: str): + """ + Compare the two sequences generated by different models, + which should be equal. + """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0 = outputs_0 + output_ids_1, output_str_1 = outputs_1 + + assert output_str_0 == output_str_1, (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + +TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] + + +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], + outputs_1_lst: List[TokensTextLogprobs], name_0: str, + name_1: str): + """ + Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, + output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_id_1 in logprobs_0[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + # Break out since sequences will now diverge. + break From d52f93dbf2b09c2e7265f19657cc82d165f68127 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:09:22 +0800 Subject: [PATCH 5/7] [Typo] Fix missing links in the bitnet integration's docs (#136) * fix install with absolute path * efficient inference with torch compile * update vllm ckpt tutorial for bitnet * ReadME Fix. --- integration/BitNet/README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integration/BitNet/README.md b/integration/BitNet/README.md index f1e82625..78d8a7eb 100644 --- a/integration/BitNet/README.md +++ b/integration/BitNet/README.md @@ -2,12 +2,13 @@ license: mit --- -## Latest News - -- 08/09/2024 ✨: We provide a more efficient implementation for bitnet with vLLM, which should use special model checkpoints, to make the ckpt, please reach []. This is a BitBLAS Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. +## Latest News + +- 08/09/2024 ✨: We provide a more efficient implementation for bitnet with vLLM, which should use special model checkpoints, to make the ckpt and study how to deploy, please checkout [Make Checkpoints for vLLM](#make-checkpoints-for-vllm). + ## Make Checkpoints for vLLM We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which is used to make a checkpoint with fp16 uncompressed metaadta, the main difference with the original checkpoint is the `quant_config.json`, which allow vLLM to load the model and execute with a quant extension. From 22b5262dc136b8b70400162c3c278cf6ec313e87 Mon Sep 17 00:00:00 2001 From: Lingxiao Ma Date: Fri, 9 Aug 2024 20:47:13 +0800 Subject: [PATCH 6/7] fix BitNet integration for vLLM (#137) * fix BitNet integration for vLLM * update ckpt name of BitNet integration for vLLM * format code --- integration/BitNet/README.md | 6 ++-- .../BitNet/maint/create_bitblas_ckpt.py | 14 ++++++--- .../generate_bitnet_model_bitblas_format.sh | 6 ++++ .../generate_bitnet_model_native_format.sh | 6 ++-- ...quant_config.json => quantize_config.json} | 0 .../inference_with_compress_format.py | 15 +++++----- .../inference_with_native_format.py | 30 ++++--------------- 7 files changed, 35 insertions(+), 42 deletions(-) rename integration/BitNet/maint/{quant_config.json => quantize_config.json} (100%) diff --git a/integration/BitNet/README.md b/integration/BitNet/README.md index 78d8a7eb..63cc3e27 100644 --- a/integration/BitNet/README.md +++ b/integration/BitNet/README.md @@ -18,14 +18,14 @@ We provide two scripts to make the checkpoints for vLLM. The first script is `ge cd /root/to/BitBLAS/integration/BitNet # make the checkpoint ./maint/generate_bitnet_model_native_format.sh -# the output ckpy will be saved in the `./models/bitnet_b1_58-3B` directory +# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory ``` The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata, which can avoid the online dequantize sage for the profiling of vLLM, which lead to more efficient memory utilization. ```bash -./maint/generate_bitnet_model_bitblas_format.sh ./models/bitnet_3B_1.58bit ./models/bitnet_3B_1.58bit_bitblas -# the output ckpy will be saved in the `./models/bitnet_b1_58-3B_bitblas` directory +./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas +# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory ``` Finnaly, you can use the ckpt in vLLM with: diff --git a/integration/BitNet/maint/create_bitblas_ckpt.py b/integration/BitNet/maint/create_bitblas_ckpt.py index d71f5958..0bf603e0 100644 --- a/integration/BitNet/maint/create_bitblas_ckpt.py +++ b/integration/BitNet/maint/create_bitblas_ckpt.py @@ -4,14 +4,18 @@ import argparse import torch import bitblas -from modeling_bitnet import BitnetForCausalLM -from tokenization_bitnet import BitnetTokenizer from transformers.utils.hub import cached_file import os from transformers import GenerationConfig import time import json +import sys + +sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../") +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer + filepath = os.path.abspath(__file__) dirpath = os.path.dirname(filepath) @@ -19,12 +23,14 @@ bitblas.set_log_level("INFO") parser = argparse.ArgumentParser() -parser.add_argument("--model_name_or_path", type=str, default="BitBLASModel/open_llama_3b_1.58bits") +parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B") parser.add_argument("--saved_model_path", type=str, default=None) args = parser.parse_args() model_name_or_path = args.model_name_or_path -saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +saved_model_path = os.path.join( + dirpath, "models", + f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path def generate_text(model, tokenizer, prompt, max_length=100): diff --git a/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh index aea62db9..3ace5803 100755 --- a/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh +++ b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh @@ -24,4 +24,10 @@ fi # get the realpath of the saved model directory SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR) +# cp files +cp $MODEL_DIR/quantize_config.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.model $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer_config.json $SAVED_MODEL_DIR/ + echo "Model has been converted and save to $SAVED_MODEL_DIR" diff --git a/integration/BitNet/maint/generate_bitnet_model_native_format.sh b/integration/BitNet/maint/generate_bitnet_model_native_format.sh index 75bac8a7..c002f6e1 100755 --- a/integration/BitNet/maint/generate_bitnet_model_native_format.sh +++ b/integration/BitNet/maint/generate_bitnet_model_native_format.sh @@ -14,13 +14,13 @@ mkdir -p models cd models # download the model -git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B bitnet_3B_1.58bits --depth 1 +git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B ckpt_bitnet_b1_58-3B --depth 1 # copy quantized config into the model directory -cp ../maint/quant_config.json bitnet_3B_1.58bits +cp ../maint/quantize_config.json ckpt_bitnet_b1_58-3B # get the realpath of the model directory -MODEL_DIR=$(realpath bitnet_3B_1.58bits) +MODEL_DIR=$(realpath ckpt_bitnet_b1_58-3B) cd .. diff --git a/integration/BitNet/maint/quant_config.json b/integration/BitNet/maint/quantize_config.json similarity index 100% rename from integration/BitNet/maint/quant_config.json rename to integration/BitNet/maint/quantize_config.json diff --git a/integration/BitNet/vllm_workspace/inference_with_compress_format.py b/integration/BitNet/vllm_workspace/inference_with_compress_format.py index 45426d65..9e60fa97 100644 --- a/integration/BitNet/vllm_workspace/inference_with_compress_format.py +++ b/integration/BitNet/vllm_workspace/inference_with_compress_format.py @@ -19,7 +19,7 @@ current_file_path = os.path.realpath(__file__) current_dir = os.path.dirname(current_file_path) -ckpt_path = os.path.join(current_dir, "../models/bitnet_3b_1.58bits_bitblas") +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas") parser = argparse.ArgumentParser(description="Inference with BitNet") parser.add_argument( "--ckpt_path", @@ -32,14 +32,13 @@ ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitblas", - enforce_eager=True, + ckpt_path, + dtype="half", + quantization="bitblas", + enforce_eager=True, # set False to enable cuda graph ) as bitnet_model: - bitbnet_outputs = bitnet_model.generate_greedy( - ["Hi, tell me about microsoft?"], max_tokens=1024 - ) + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], + max_tokens=1024) print("bitnet inference:") print(bitbnet_outputs[0][0]) print(bitbnet_outputs[0][1]) diff --git a/integration/BitNet/vllm_workspace/inference_with_native_format.py b/integration/BitNet/vllm_workspace/inference_with_native_format.py index 07aefeec..579c5e17 100644 --- a/integration/BitNet/vllm_workspace/inference_with_native_format.py +++ b/integration/BitNet/vllm_workspace/inference_with_native_format.py @@ -15,11 +15,10 @@ import os import argparse - # get the path of the current file current_file_path = os.path.realpath(__file__) current_dir = os.path.dirname(current_file_path) -ckpt_path = os.path.join(current_dir, "../models/bitnet_3b_1.58bits") +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas") parser = argparse.ArgumentParser(description="Inference with BitNet") parser.add_argument( @@ -34,29 +33,12 @@ ckpt_path = args.ckpt_path with VllmRunner( - ckpt_path, - dtype="half", - quantization="bitnet", - gpu_memory_utilization=0.5, + ckpt_path, + dtype="half", + quantization="bitnet", + gpu_memory_utilization=0.5, ) as bitnet_model: - bitbnet_outputs = bitnet_model.generate_greedy( - ["Hi, tell me about microsoft?"], max_tokens=128 - ) + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) print("bitnet inference output:") print(bitbnet_outputs[0][0]) print(bitbnet_outputs[0][1]) - -# with VllmRunner( -# "BitBLASModel/open_llama_3b_1.58bits_bitblas", -# dtype="half", -# quantization="bitblas", -# enforce_eager=True, -# ) as bitnet_model: -# torch.cuda.profiler.start() -# bitbnet_outputs = bitnet_model.generate_greedy( -# ["Hi, tell me about microsoft?"], max_tokens=1024 -# ) -# torch.cuda.profiler.stop() -# print("bitnet:") -# print(bitbnet_outputs[0][0]) -# print(bitbnet_outputs[0][1]) From 0e1e3663d3096ad4bee7ee0c119d0662c69b0b3e Mon Sep 17 00:00:00 2001 From: Lingxiao Ma Date: Fri, 9 Aug 2024 22:10:34 +0800 Subject: [PATCH 7/7] fix BitNet integration for vLLM (#139) * fix BitNet integration for vLLM * update ckpt name of BitNet integration for vLLM * format code * fix BitNet integration for vLLM native version --- .../vllm_workspace/inference_with_compress_format.py | 4 +++- .../BitNet/vllm_workspace/inference_with_native_format.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/integration/BitNet/vllm_workspace/inference_with_compress_format.py b/integration/BitNet/vllm_workspace/inference_with_compress_format.py index 9e60fa97..55a24543 100644 --- a/integration/BitNet/vllm_workspace/inference_with_compress_format.py +++ b/integration/BitNet/vllm_workspace/inference_with_compress_format.py @@ -35,7 +35,9 @@ ckpt_path, dtype="half", quantization="bitblas", - enforce_eager=True, # set False to enable cuda graph + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) diff --git a/integration/BitNet/vllm_workspace/inference_with_native_format.py b/integration/BitNet/vllm_workspace/inference_with_native_format.py index 579c5e17..4f5f87f6 100644 --- a/integration/BitNet/vllm_workspace/inference_with_native_format.py +++ b/integration/BitNet/vllm_workspace/inference_with_native_format.py @@ -18,7 +18,7 @@ # get the path of the current file current_file_path = os.path.realpath(__file__) current_dir = os.path.dirname(current_file_path) -ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas") +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B") parser = argparse.ArgumentParser(description="Inference with BitNet") parser.add_argument( @@ -35,8 +35,11 @@ with VllmRunner( ckpt_path, dtype="half", - quantization="bitnet", + quantization="bitnet_bitblas", gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, ) as bitnet_model: bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) print("bitnet inference output:")