From 2c6a1e87e6044924c877b7147cdae1c846192c2a Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Thu, 29 Aug 2024 17:30:55 +0800
Subject: [PATCH 1/6] [Benchmark] Fast Decoding Benchmark (#158)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Disable failure email for CI

* Remove email notifications

* Move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Bug fix for matrix support

* Lint fix
---
 README.md                                       |  16 +-
 .../operators/benchmark_matmul_fast_decoding.py | 440 ++++++++++++++++++
 2 files changed, 448 insertions(+), 8 deletions(-)
 create mode 100644 benchmark/operators/benchmark_matmul_fast_decoding.py

diff --git a/README.md b/README.md
index 43f1d92d..315fa307 100644
--- a/README.md
+++ b/README.md
@@ -61,14 +61,14 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
 | **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS Support** | **Tested Platform** |
 |:-----------:|:-----------:|:---------------:|:--------------------:|:-------------------:|:----------------------------------------------------:|
-| BF16 | BF16 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | INT8 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | UINT4/INT4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | UINT2/INT2 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | UINT1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
-| BF16 | NF4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | BF16 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | FP4_E2M1 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | FP8_E4M3 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | INT8 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | UINT4/INT4 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | UINT2/INT2 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | UINT1 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | NF4 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
 | FP16 | FP16 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
 | FP16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
 | FP16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
diff --git a/benchmark/operators/benchmark_matmul_fast_decoding.py b/benchmark/operators/benchmark_matmul_fast_decoding.py
new file mode 100644
index 00000000..c3c65b3b
--- /dev/null
+++ b/benchmark/operators/benchmark_matmul_fast_decoding.py
@@ -0,0 +1,440 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+ +from bitblas.benchmark import BitblasOperatorBenchmarkBase +from bitblas import Matmul, MatmulConfig +from bitblas.utils import get_commit_id +from bitblas import set_log_level +from tabulate import tabulate +from os import path, makedirs +from typing import List +import argparse +from tqdm import tqdm + +set_log_level("DEBUG") + + +class BitblasMatmulOpsBenchmarkCompareStategies(BitblasOperatorBenchmarkBase): + + BENCHMARK_RESULTS_FILE = "benchmark_results.json" + BENCHMARK_SHAPES_FILE = "benchmark_shapes.json" + BENCHMARK_DEVICE_FILE = "benchmark_device.json" + + config_map = { + "FP16xFP16_GEMV": { + "A_dtype": "float16", + "W_dtype": "float16", + "accum_dtype": "float16", + }, + "FP16xUINT4_GEMV_DECODING_NAIVE": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "fast_decoding": False, + }, + "FP16xUINT4_GEMV_DECODING_FAST": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "fast_decoding": True, + }, + "FP16xUINT2_GEMV_DECODING_NAIVE": { + "A_dtype": "float16", + "W_dtype": "uint2", + "accum_dtype": "float16", + "fast_decoding": False, + }, + "FP16xUINT2_GEMV_DECODING_FAST": { + "A_dtype": "float16", + "W_dtype": "uint2", + "accum_dtype": "float16", + "fast_decoding": True, + }, + "INT8xUINT2_GEMV_DECODING_NAIVE": { + "A_dtype": "int8", + "W_dtype": "uint2", + "accum_dtype": "int32", + "out_dtype": "int32", + "fast_decoding": False, + }, + "INT8xUINT2_GEMV_DECODING_FAST": { + "A_dtype": "int8", + "W_dtype": "uint2", + "accum_dtype": "int32", + "out_dtype": "int32", + "fast_decoding": True, + }, + } + + OPT_SHAPES = 1 # our test focuses on GEMV only + + CURRENT_COMMIT_ID = get_commit_id() + + def __init__(self): + super().__init__() + + def prepare_set_group_4x(self, name: str, N, K) -> List: + assert name in self.config_map, f"Operator {name} not found in config map" + return [ + self.generate_op_unit(self.generate_operator_config(name, self.OPT_SHAPES, N, K)), + ] + + def prepare_benchmark_sets(self): + """Prepare benchmark sets.""" + + self.add_benchmark_set( + "FP16xUINT4_GEMV_DECODING_NAIVE", + [ + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 16384, 16384), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 3200, 3200), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 8640, 3200), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 3200, 8640), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 1024, 8192), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 8192, 8192), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 28672, 8192), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 8192, 28672), + ], + ) + + self.add_benchmark_set( + "FP16xUINT4_GEMV_DECODING_FAST", + [ + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 16384, + 16384, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 3200, + 3200, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 8640, + 3200, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 3200, + 8640, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 1024, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 8192, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 28672, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 8192, + 28672, + ), + ], + ) + + self.add_benchmark_set( + "FP16xUINT2_GEMV_DECODING_NAIVE", + 
            [
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 16384, 16384),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 3200, 3200),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 8640, 3200),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 3200, 8640),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 1024, 8192),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 8192, 8192),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 28672, 8192),
                *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 8192, 28672),
            ],
        )

        self.add_benchmark_set(
            "FP16xUINT2_GEMV_DECODING_FAST",
            [
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    16384,
                    16384,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    3200,
                    3200,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    8640,
                    3200,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    3200,
                    8640,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    1024,
                    8192,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    8192,
                    8192,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    28672,
                    8192,
                ),
                *self.prepare_set_group_4x(
                    "FP16xUINT2_GEMV_DECODING_FAST",
                    8192,
                    28672,
                ),
            ],
        )

        self.add_benchmark_set(
            "INT8xUINT2_GEMV_DECODING_NAIVE",
            [
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 16384, 16384),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 3200, 3200),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 8640, 3200),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 3200, 8640),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 1024, 8192),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 8192, 8192),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 28672, 8192),
                *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 8192, 28672),
            ],
        )

        self.add_benchmark_set(
            "INT8xUINT2_GEMV_DECODING_FAST",
            [
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    16384,
                    16384,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    3200,
                    3200,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    8640,
                    3200,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    3200,
                    8640,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    1024,
                    8192,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    8192,
                    8192,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    28672,
                    8192,
                ),
                *self.prepare_set_group_4x(
                    "INT8xUINT2_GEMV_DECODING_FAST",
                    8192,
                    28672,
                ),
            ],
        )

    def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig:
        """Generate configuration for the given operator."""
        if name not in self.config_map:
            raise ValueError(f"Operator {name} not found in config map")
        return self.get_operator_config()(
            M=M,
            N=N,
            K=K,
            **self.config_map[name],
        )

    def report(self):
        """Generate and print a report of the benchmark results."""
        results4compare = {}
        for name, results in self.benchmark_results.items():
            if "DECODING" not in name:
                strategy = ""
            else:
                name, strategy = name.split("DECODING")
            results4compare.setdefault(name, {})[strategy] = results

        for name, strategy in results4compare.items():
            # Reset per comparison group so results from different TAGs do not
            # leak into each other's columns.
            data = []
            table_data = [
                ["TAG:", name, "Device:", self.benchmark_target],
                [
                    "Shape (M-N-K / N-K_M)",
                    "Naive Decoding Time (ms)",
                    "Shape (M-N-K / N-K_M)",
                    "Fast Decoding Time (ms)",
                    "Tune Time (s)",
                ],
            ]

            def legalize_shape(M, N, K, dyn_prof_shape):
                """Generate a string representation of the operator shape.

                Args:
                    M: The M dimension (can be an int or a tuple).
                    N: The N dimension (must be an int).
                    K: The K dimension (must be an int).
                    dyn_prof_shape: The dynamic profiling shape (dict with "m" key if M is dynamic).

                Returns:
                    A string representing the shape in either 'M-N-K' or 'N-K_M' format.
                """
                if isinstance(M, int):
                    return f"{M}-{N}-{K}"
                elif dyn_prof_shape and "m" in dyn_prof_shape:
                    return f"{M}-{N}-{K}_{dyn_prof_shape['m']}"
                else:
                    # Fall back to a representation built from the average of tuple M
                    str_m = "[" + "-".join(str(m) for m in M) + "]"
                    opt_m = sum(M) / len(M)
                    return f"{N}-{K}_{str_m}_{opt_m}"

            for strategy_name, results in strategy.items():
                tmp_data = []
                if strategy_name == "":
                    origin_name = name
                else:
                    origin_name = f"{name}DECODING{strategy_name}"
                for i, benchmark_set in enumerate(self.benchmark_sets[origin_name]):
                    op_config = benchmark_set[1]
                    if isinstance(self.OPT_SHAPES, int):
                        sub_results = results[i]
                        latency = sub_results[0]
                        dyn_prof_shape = {"m": self.OPT_SHAPES}
                        shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape)
                        latency_str = "N/A" if latency is None else f"{latency:.3f}"
                        tmp_data.append([shape, latency_str])
                    else:
                        sub_results = results[i * len(self.OPT_SHAPES):(i + 1) *
                                              len(self.OPT_SHAPES)]
                        for j, result in enumerate(sub_results):
                            latency = result[0]
                            dyn_prof_shape = {"m": self.OPT_SHAPES[j]}
                            shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape)
                            latency_str = "N/A" if latency is None else f"{latency:.3f}"
                            tmp_data.append([shape, latency_str])
                if len(data) == 0:
                    data = tmp_data
                else:
                    for i, item in enumerate(tmp_data):
                        data[i].extend(item)

            for i, item in enumerate(data):
                base = item[1]
                head = item[3]
                if "N/A" in (base, head):
                    # Skip the relative comparison when either run failed.
                    table_data.append([*data[i], "N/A"])
                    continue
                speedup = float(head) / float(base) - 1
                symbol = "+" if speedup > 0 else "-"
                speedup = abs(speedup)
                data[i][3] = f"{head} ({symbol}{speedup * 100:.3f}%)"
                table_data.append([*data[i], "N/A"])

            print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid"))

            for row in table_data:
                print(row)

    def get_operator(self):
        """Return the Matmul operator."""
        return Matmul

    def get_operator_config(self):
        """Return the Matmul operator configuration."""
        return MatmulConfig

    def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul:
        """Make a Matmul instance."""
        # Disable default tuning when benchmarking
        return operator(config, target=self.benchmark_target, enable_tuning=False)

    def benchmark(self):
        """Run benchmarks on all benchmark sets."""
        # Calculate the total number of benchmark runs for the progress bar
        total_runs = sum(
            (len(benchmark_set) *
             (len(self.OPT_SHAPES) if isinstance(self.OPT_SHAPES, list) else self.OPT_SHAPES))
            for benchmark_set in self.benchmark_sets.values())

        with tqdm(total=total_runs, desc="Total Progress", unit="benchmark") as pbar:
            for name, benchmark_set in self.benchmark_sets.items():
                self.benchmark_results[name] = []
                for op, config, _ in benchmark_set:
                    if isinstance(self.OPT_SHAPES, int):
                        print(f"Running benchmark for {name} with shape {self.OPT_SHAPES}")
                        self.benchmark_results[name].extend(
                            [self.run_benchmark(op, config, {"m": self.OPT_SHAPES})])
                        # Update the progress bar after each run
                        pbar.update(1)
                    else:
                        for opt in self.OPT_SHAPES:
                            print(f"Running benchmark for {name} with shape {opt}")
                            self.benchmark_results[name].extend(
                                [self.run_benchmark(op, config, {"m": opt})])
                            # Update the progress bar after each run
                            pbar.update(1)

    def run_compare_strategy(self, report=True, serialize=True, enable_tuning: bool = False):
        """Run the benchmark process."""

        if not path.exists(self.log_path):
            makedirs(self.log_path)

        if enable_tuning:
            self.enable_tuning()

        self.prepare_benchmark_sets()
        self.benchmark()

        if report:
            self.report()

        self.cleanup()

    def serialize_results(self) -> None:
        """Serialize the benchmark results."""
        # Results are compared in-memory for this benchmark, so serialization
        # is intentionally a no-op.
        pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark")

    parser.add_argument(
        "--enable_tuning",
        action="store_true",
        help="Enable hardware-aware tuning",
    )

    args = parser.parse_args()
    BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy(
        enable_tuning=args.enable_tuning)
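
The benchmark above reduces each naive-vs-fast comparison to the same `MatmulConfig` built twice, with only `fast_decoding` toggled. A minimal standalone sketch of one such pair, reusing the script's own dtypes, shape (`M=1` mirrors `OPT_SHAPES`), and its `enable_tuning=False` convention; profiling calls are omitted and a CUDA device plus an installed `bitblas` are assumed:

```python
from bitblas import Matmul, MatmulConfig

# One FP16xUINT4 GEMV entry from config_map, instantiated with and
# without fast decoding; only the fast_decoding flag differs.
for fast_decoding in (False, True):
    config = MatmulConfig(
        M=1,  # single-row activation: the GEMV case this benchmark targets
        N=16384,
        K=16384,
        A_dtype="float16",
        W_dtype="uint4",
        accum_dtype="float16",
        fast_decoding=fast_decoding,
    )
    # As in make_operator above: default tuning stays off while benchmarking.
    matmul = Matmul(config, enable_tuning=False)
```

The script's `report()` then prints the two latencies side by side with the relative difference per shape.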
config, {"m": self.OPT_SHAPES})]) + # Update the progress bar after each run + pbar.update(1) + else: + for opt in self.OPT_SHAPES: + print(f"Running benchmark for {name} with shape {opt}") + self.benchmark_results[name].extend( + [self.run_benchmark(op, config, {"m": opt})]) + # Update the progress bar after each run + pbar.update(1) + + def run_compare_strategy(self, report=True, serialize=True, enable_tuning: bool = False): + """Run the benchmark process.""" + + if not path.exists(self.log_path): + makedirs(self.log_path) + + if enable_tuning: + self.enable_tuning() + + self.prepare_benchmark_sets() + self.benchmark() + + if report: + self.report() + + self.cleanup() + + def serialize_results(self) -> None: + """Serialize the benchmark results.""" + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark") + + parser.add_argument( + "--enable_tuning", + action="store_true", + help="Enable hardware-aware tuning", + ) + + args = parser.parse_args() + enable_tuning = args.enable_tuning + BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy( + enable_tuning=args.enable_tuning) From ad1d7aea12a2f2c57e8965b02c74524c2dcae26b Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 29 Aug 2024 19:32:16 +0800 Subject: [PATCH 2/6] [BUGFix] Disable tensorcore when shape is really small (#159) * Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * buf fix for matrix support * lint fix * dispatch tensor core based on shapes --- bitblas/gpu/matmul_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 16f33664..4a0ef532 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -666,7 +666,7 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) - minimal_tensorize_threshold = 16 + minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 # the batch dimension is not taken into consideration. 
From 2c091f8eb29edb4405a0d56d820af7933105565f Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Fri, 30 Aug 2024 11:19:45 +0800
Subject: [PATCH 3/6] [BugFix] Register missing FP8 LDMATRIX instructions for
 dynamic shared memory (#162)

* Merge branch 'main' of https://github.com/microsoft/BitBLAS into main

* Remove debug print

* Refactor Matmul class for improved readability and maintainability

* Refactor Matmul class for improved readability and maintainability

* Revert set device

* Lint fix

* Register fp8 for dynamic
---
 3rdparty/tvm | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 3c6317a1..c5d98771 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 3c6317a1ea614b7277ffe0b4ede18b4652afad1c
+Subproject commit c5d9877154f67dc0d3651032b15521e09dfda882
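
PATCH 3/6 is only a submodule bump, but the TVM commit it pulls in registers the FP8 `ldmatrix` intrinsics for dynamically allocated shared memory, which the FP8 rows of the support matrix rely on. For orientation, a hedged sketch of the kind of operator this unblocks; the `e4m3_float8` dtype string is an assumption inferred from the FP8_E4M3 entry in the README table, not something this diff shows:

```python
from bitblas import Matmul, MatmulConfig

# FP16 activations with FP8 (E4M3) weights. The W_dtype spelling below is
# assumed; check bitblas' supported-dtype list before relying on it.
config = MatmulConfig(
    M=1,
    N=16384,
    K=16384,
    A_dtype="float16",
    W_dtype="e4m3_float8",
    accum_dtype="float16",
)
matmul = Matmul(config, enable_tuning=False)  # requires a CUDA-capable device
```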
To instal pip install bitblas ``` +Alternatively, to install the latest version of BitBLAS from the github repository, you can run the following command: + +```bash +pip install git+https://github.com/microsoft/BitBLAS.git +``` + After installing BitBLAS, you can verify the installation by running: ```bash diff --git a/docs/Installation.md b/docs/Installation.md index f30d2dfb..a50d478e 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -1,7 +1,5 @@ # Installation Guide - - ## Installing with pip **Prerequisites for installation via wheel or PyPI:** @@ -23,6 +21,12 @@ Alternatively, you may choose to install BitBLAS using prebuilt packages availab pip install bitblas-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl ``` +To install the latest version of BitBLAS from the github repository, you can run the following command: + +```bash +pip install git+https://github.com/microsoft/BitBLAS.git +``` + After installing BitBLAS, you can verify the installation by running: ```bash From f284c32be99a52fcf1e93218aa72d42d597b8f25 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Fri, 30 Aug 2024 12:30:50 +0800 Subject: [PATCH 5/6] [BugFix] Fix BitBLAS Linear with BFloat16 input (#164) * Merge branch 'main' of https://github.com/microsoft/BitBLAS into main * remove debug print * Refactor Matmul class for improved readability and maintainability * Refactor Matmul class for improved readability and maintainability * revert set device * lint fix * register fp8 for dynamic * Linear Fix --- bitblas/module/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index c1cf316f..b427a3c6 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -265,8 +265,6 @@ def warmup(self, topk=20): self.bitblas_matmul.hardware_aware_finetune(topk=topk) def forward(self, A, output=None): - if A.dtype != torch.float16: - A = A.half() A = self.bitblas_matmul.transform_input(A) stream = torch.cuda.current_stream() @@ -277,7 +275,9 @@ def forward(self, A, output=None): args = [A_void, *self.q_params] if output is None: output = torch.empty( - A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) + A.shape[:-1] + (self.out_features,), + dtype=getattr(torch, self.bitblas_matmul.out_dtype), + device=A.device) args.append(ctypes.c_void_p(output.data_ptr())) if self.bitblas_matmul.dynamic_range is not None: m = reduce(operator.mul, A.shape[:-1], 1) From b1f5e7915553b5cce4720aa373fdf97dcfc60a99 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sat, 31 Aug 2024 14:45:57 +0800 Subject: [PATCH 6/6] update tvm (#165) Co-authored-by: leiwang1999 --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c5d98771..8a1bf865 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c5d9877154f67dc0d3651032b15521e09dfda882 +Subproject commit 8a1bf865ae256ff4ae98eabf714b84c10425f912