From 46b2d7d809ce19f7a74d77d25d7acd1f3ababad3 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 9 Aug 2024 07:31:59 +0000
Subject: [PATCH] chore: Refactor benchmark scripts and fix typos

---
 .../operators/benchmark_matmul_strategies.py | 105 ++++++++----
 benchmark/operators/benchmark_ops_matmul.py  |  41 ++++---
 bitblas/ops/general_matmul/__init__.py       |   5 +-
 3 files changed, 70 insertions(+), 81 deletions(-)

diff --git a/benchmark/operators/benchmark_matmul_strategies.py b/benchmark/operators/benchmark_matmul_strategies.py
index aa9b4392f..71ee758b3 100644
--- a/benchmark/operators/benchmark_matmul_strategies.py
+++ b/benchmark/operators/benchmark_matmul_strategies.py
@@ -39,9 +39,7 @@ class BitblasMatmulOpsBenchmarkCompareStategies(BitblasOperatorBenchmarkBase):
     OPT_SHAPES = [1, 16, 32, 64, 128, 256, 512, 4096]
     CURRENT_COMMIT_ID = get_commit_id()
 
-    def __init__(
-        self
-    ):
+    def __init__(self):
         super().__init__()
 
     def prepare_set_group_4x(self, name: str, N, K) -> List:
@@ -49,8 +47,10 @@ def prepare_set_group_4x(self, name: str, N, K) -> List:
         optimize_strategy = self.config_map[name]["optimize_stratety"]
         return [
             self.generate_op_unit(
-                self.generate_operator_config(name, [1, 16, 32, 64, 128, 256, 512] if optimize_strategy == OptimizeStrategy.SingleBatchDecodeOnly else [16, 32, 64, 128, 256, 512], N, K)
-            ),
+                self.generate_operator_config(
+                    name, [1, 16, 32, 64, 128, 256, 512] if optimize_strategy
+                    == OptimizeStrategy.SingleBatchDecodeOnly else [16, 32, 64, 128, 256, 512], N,
+                    K)),
         ]
 
     def prepare_benchmark_sets(self):
@@ -58,9 +58,7 @@ def prepare_benchmark_sets(self):
         self.add_benchmark_set(
             "FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV",
             [
-                *self.prepare_set_group_4x(
-                    "FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 16384, 16384
-                ),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 16384, 16384),
                 *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 3200, 3200),
                 *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 8640, 3200),
                 *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 3200, 8640),
@@ -85,19 +83,32 @@ def prepare_benchmark_sets(self):
                     16384,
                     16384,
                 ),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 3200, 3200),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 8640, 3200),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 3200, 8640),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 5120, 5120),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 13824, 5120),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 5120, 13824),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 6656, 6656),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 17920, 6656),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 6656, 17920),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 1024, 8192),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 8192, 8192),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 28672, 8192),
-                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 8192, 28672),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           3200, 3200),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           8640, 3200),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           3200, 8640),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           5120, 5120),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           13824, 5120),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           5120, 13824),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           6656, 6656),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           17920, 6656),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           6656, 17920),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           1024, 8192),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           8192, 8192),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           28672, 8192),
+                *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
+                                           8192, 28672),
             ],
         )
 
@@ -135,13 +146,9 @@ def serialize_results(self) -> None:
             for i, _ in enumerate(results):
                 config = self.benchmark_sets[name][i][1]
                 dyn_prof_shape = self.benchmark_sets[name][i][2]
-                shapes[name].append(
-                    [config.M, config.N, config.K, dyn_prof_shape]
-                )
+                shapes[name].append([config.M, config.N, config.K, dyn_prof_shape])
 
-        self._save_json(
-            shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE)
-        )
+        self._save_json(shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE))
 
         # Save device info into JSON
         self._save_json(
@@ -162,8 +169,7 @@ def deserialize_from_logs(cls, commit_id: str) -> None:
         log_commit_path = path.join(benchmark.log_path, commit_id_path)
 
         benchmark.benchmark_results = cls._load_json(
-            path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE)
-        )
+            path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE))
 
         shapes_file = path.join(log_commit_path, cls.BENCHMARK_SHAPES_FILE)
 
@@ -176,17 +182,14 @@ def deserialize_from_logs(cls, commit_id: str) -> None:
                 name,
                 [
                     benchmark.generate_op_unit(
-                        benchmark.generate_operator_config(
-                            name, M, N, K
-                        ),
+                        benchmark.generate_operator_config(name, M, N, K),
                         dynamic_profiling_shape=dyn_prof_shape,
                     )
                 ],
             )
 
         benchmark.benchmark_target = cls._load_json(
-            path.join(log_commit_path, cls.BENCHMARK_DEVICE_FILE)
-        )["device"]
+            path.join(log_commit_path, cls.BENCHMARK_DEVICE_FILE))["device"]
 
         return benchmark
 
@@ -210,7 +213,7 @@ def report(self):
                 "Shape (M-N-K / N-K_M)",
                 "Single Batching Time (ms)",
                 "Shape (M-N-K / N-K_M)",
-                "Contigious Batching Time (ms)",
+                "Contiguous Batching Time (ms)",
                 "Tune Time (s)",
             ],
         ]
@@ -243,13 +246,11 @@ def legalize_shape(M, N, K, dyn_prof_shape):
            origin_name = f"{name}_STRATEGY_{strategy_name}"
            for i, benchmark_set in enumerate(self.benchmark_sets[origin_name]):
                op_config = benchmark_set[1]
-                sub_results = results[i * len(self.OPT_SHAPES) : (i + 1) * len(self.OPT_SHAPES)]
+                sub_results = results[i * len(self.OPT_SHAPES):(i + 1) * len(self.OPT_SHAPES)]
                for i, result in enumerate(sub_results):
                    latency = result[0]
-                    dyn_prof_shape = {"m":self.OPT_SHAPES[i]}
-                    shape = legalize_shape(
-                        "DYN", op_config.N, op_config.K, dyn_prof_shape
-                    )
+                    dyn_prof_shape = {"m": self.OPT_SHAPES[i]}
+                    shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape)
                    latency_str = "N/A" if latency is None else f"{latency:.3f}"
                    tmp_data.append([shape, latency_str])
            if len(data) == 0:
@@ -268,9 +269,7 @@ def legalize_shape(M, N, K, dyn_prof_shape):
                 data[i][3] = f"{head} ({symbol}{speedup * 100 :.3f}%)"
             table_data.append([*data[i], "N/A"])
 
-        print(
-            tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")
-        )
+        print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid"))
 
         for data in table_data:
             print(data)
@@ -286,9 +285,7 @@ def get_operator_config(self):
     def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul:
         """Make an Matmul instance."""
         # Disable default tuning when do benchmark
-        return operator(
-            config, target=self.benchmark_target, enable_tuning=False
-        )
+        return operator(config, target=self.benchmark_target, enable_tuning=False)
 
     def benchmark(self):
         """Run benchmarks on all benchmark sets."""
@@ -297,14 +294,9 @@ def benchmark(self):
             for op, config, _ in benchmark_set:
                 for opt in self.OPT_SHAPES:
                     self.benchmark_results[name].extend(
-                        [
-                            self.run_benchmark(op, config, {"m": opt})
-                        ]
-                    )
+                        [self.run_benchmark(op, config, {"m": opt})])
 
-    def run_compare_strategy(
-        self, report=True, serialize=True, enable_tuning: bool = False
-    ):
+    def run_compare_strategy(self, report=True, serialize=True, enable_tuning: bool = False):
         """Run the benchmark process."""
 
         if not path.exists(self.log_path):
@@ -323,9 +315,7 @@ def run_compare_strategy(
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Bitblas Matmul Operator Benchmark"
-    )
+    parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark")
 
     parser.add_argument(
         "--enable_tuning",
         action="store_true",
         help="Enable hardware-aware tuning",
     )
 
     args = parser.parse_args()
     enable_tuning = args.enable_tuning
-    BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy(enable_tuning=args.enable_tuning)
+    BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy(
+        enable_tuning=args.enable_tuning)
diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py
index 113076a0e..282e26997 100644
--- a/benchmark/operators/benchmark_ops_matmul.py
+++ b/benchmark/operators/benchmark_ops_matmul.py
@@ -123,29 +123,29 @@ def prepare_set_group_llm(self, name: str, N, K) -> List:
             ),
         ]
 
-    def get_llm_benchmark_sets(self, name:str) -> List:
-        return [*self.prepare_set_group_llm(name, 3200, 3200),
-                *self.prepare_set_group_llm(name, 8640, 3200),
-                *self.prepare_set_group_llm(name, 3200, 8640),
-                *self.prepare_set_group_llm(name, 5120, 5120),
-                *self.prepare_set_group_llm(name, 13824, 5120),
-                *self.prepare_set_group_llm(name, 5120, 13824),
-                *self.prepare_set_group_llm(name, 6656, 6656),
-                *self.prepare_set_group_llm(name, 17920, 6656),
-                *self.prepare_set_group_llm(name, 6656, 17920),
-                *self.prepare_set_group_llm(name, 1024, 8192),
-                *self.prepare_set_group_llm(name, 8192, 8192),
-                *self.prepare_set_group_llm(name, 28672, 8192),
-                *self.prepare_set_group_llm(name, 8192, 28672)]
+    def get_llm_benchmark_sets(self, name: str) -> List:
+        return [
+            *self.prepare_set_group_llm(name, 3200, 3200),
+            *self.prepare_set_group_llm(name, 8640, 3200),
+            *self.prepare_set_group_llm(name, 3200, 8640),
+            *self.prepare_set_group_llm(name, 5120, 5120),
+            *self.prepare_set_group_llm(name, 13824, 5120),
+            *self.prepare_set_group_llm(name, 5120, 13824),
+            *self.prepare_set_group_llm(name, 6656, 6656),
+            *self.prepare_set_group_llm(name, 17920, 6656),
+            *self.prepare_set_group_llm(name, 6656, 17920),
+            *self.prepare_set_group_llm(name, 1024, 8192),
+            *self.prepare_set_group_llm(name, 8192, 8192),
+            *self.prepare_set_group_llm(name, 28672, 8192),
+            *self.prepare_set_group_llm(name, 8192, 28672)
+        ]
 
     def prepare_benchmark_sets(self):
         """Prepare benchmark sets."""
         self.add_benchmark_set(
             "FP16xFP16_ACCFP16_NT",
             [
-                *self.prepare_set_group_4x(
-                    "FP16xFP16_ACCFP16_NT", 16384, 16384, 16384
-                ),
+                *self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),
                 *self.get_llm_benchmark_sets("FP16xFP16_ACCFP16_NT"),
             ],
         )
 
@@ -153,9 +153,7 @@ def prepare_benchmark_sets(self):
         self.add_benchmark_set(
             "INT8xINT8_ACCINT32_NT",
             [
-                *self.prepare_set_group_4x(
-                    "INT8xINT8_ACCINT32_NT", 16384, 16384, 16384
-                ),
+                *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),
                 *self.get_llm_benchmark_sets("INT8xINT8_ACCINT32_NT"),
             ],
         )
 
@@ -315,6 +313,7 @@ def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul:
         # Disable default tuning when do benchmark
         return operator(config, target=self.benchmark_target, enable_tuning=False)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark")
     parser.add_argument(
         "--enable_tuning",
@@ -322,7 +321,7 @@ def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul:
         action="store_true",
         help="Enable hardware-aware tuning",
     )
-    
+
     args = parser.parse_args()
     enable_tuning = args.enable_tuning
     BitblasMatmulOpsBenchmark().run(enable_tuning=args.enable_tuning)
diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index c46353ccf..16908dd41 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -116,9 +116,8 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind],
         else:
             object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
 
-        if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
-                isinstance(self.M, Tuple) or
-                (self.with_zeros and self.zeros_mode == "quantized")):
+        if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
+                isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")):
             object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
             object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
         else:
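
Usage note: below is a minimal sketch, not part of the patch itself, of driving the refactored strategy benchmark programmatically; it mirrors the __main__ block of benchmark_matmul_strategies.py shown above. The direct import is an assumption (it presumes running from benchmark/operators/ in an environment with bitblas and a CUDA device available).

    # Hypothetical driver, mirroring the patched __main__ block.
    # Assumes execution from benchmark/operators/ so the import resolves.
    from benchmark_matmul_strategies import BitblasMatmulOpsBenchmarkCompareStategies

    # enable_tuning=False matches the default of the --enable_tuning flag
    # (store_true); report and serialize default to True per the
    # run_compare_strategy signature in the diff above.
    BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy(
        report=True, serialize=True, enable_tuning=False)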