chore: Refactor benchmark scripts and fix typos
LeiWang1999 committed Aug 9, 2024
1 parent 113e485 commit 46b2d7d
Showing 3 changed files with 70 additions and 81 deletions.
105 changes: 48 additions & 57 deletions benchmark/operators/benchmark_matmul_strategies.py
@@ -39,28 +39,26 @@ class BitblasMatmulOpsBenchmarkCompareStategies(BitblasOperatorBenchmarkBase):
OPT_SHAPES = [1, 16, 32, 64, 128, 256, 512, 4096]
CURRENT_COMMIT_ID = get_commit_id()

def __init__(
self
):
def __init__(self):
super().__init__()

def prepare_set_group_4x(self, name: str, N, K) -> List:
assert name in self.config_map, f"Operator {name} not found in config map"
optimize_strategy = self.config_map[name]["optimize_stratety"]
return [
self.generate_op_unit(
self.generate_operator_config(name, [1, 16, 32, 64, 128, 256, 512] if optimize_strategy == OptimizeStrategy.SingleBatchDecodeOnly else [16, 32, 64, 128, 256, 512], N, K)
),
self.generate_operator_config(
name, [1, 16, 32, 64, 128, 256, 512] if optimize_strategy
== OptimizeStrategy.SingleBatchDecodeOnly else [16, 32, 64, 128, 256, 512], N,
K)),
]

def prepare_benchmark_sets(self):
"""Prepare benchmark sets."""
self.add_benchmark_set(
"FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV",
[
*self.prepare_set_group_4x(
"FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 16384, 16384
),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 16384, 16384),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 3200, 3200),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 8640, 3200),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 3200, 8640),
@@ -85,19 +83,32 @@ def prepare_benchmark_sets(self):
16384,
16384,
),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 3200, 3200),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 8640, 3200),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 3200, 8640),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 5120, 5120),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 13824, 5120),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 5120, 13824),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 6656, 6656),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 17920, 6656),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 6656, 17920),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 1024, 8192),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 8192, 8192),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 28672, 8192),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", 8192, 28672),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
3200, 3200),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
8640, 3200),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
3200, 8640),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
5120, 5120),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
13824, 5120),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
5120, 13824),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
6656, 6656),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
17920, 6656),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
6656, 17920),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
1024, 8192),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
8192, 8192),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
28672, 8192),
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching",
8192, 28672),
],
)

@@ -135,13 +146,9 @@ def serialize_results(self) -> None:
for i, _ in enumerate(results):
config = self.benchmark_sets[name][i][1]
dyn_prof_shape = self.benchmark_sets[name][i][2]
shapes[name].append(
[config.M, config.N, config.K, dyn_prof_shape]
)
shapes[name].append([config.M, config.N, config.K, dyn_prof_shape])

self._save_json(
shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE)
)
self._save_json(shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE))

# Save device info into JSON
self._save_json(
@@ -162,8 +169,7 @@ def deserialize_from_logs(cls, commit_id: str) -> None:
log_commit_path = path.join(benchmark.log_path, commit_id_path)

benchmark.benchmark_results = cls._load_json(
path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE)
)
path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE))

shapes_file = path.join(log_commit_path, cls.BENCHMARK_SHAPES_FILE)

@@ -176,17 +182,14 @@ def deserialize_from_logs(cls, commit_id: str) -> None:
name,
[
benchmark.generate_op_unit(
benchmark.generate_operator_config(
name, M, N, K
),
benchmark.generate_operator_config(name, M, N, K),
dynamic_profiling_shape=dyn_prof_shape,
)
],
)

benchmark.benchmark_target = cls._load_json(
path.join(log_commit_path, cls.BENCHMARK_DEVICE_FILE)
)["device"]
path.join(log_commit_path, cls.BENCHMARK_DEVICE_FILE))["device"]

return benchmark

@@ -210,7 +213,7 @@ def report(self):
"Shape (M-N-K / N-K_M)",
"Single Batching Time (ms)",
"Shape (M-N-K / N-K_M)",
"Contigious Batching Time (ms)",
"Contiguous Batching Time (ms)",
"Tune Time (s)",
],
]
@@ -243,13 +246,11 @@ def legalize_shape(M, N, K, dyn_prof_shape):
origin_name = f"{name}STRATEGY{strategy_name}"
for i, benchmark_set in enumerate(self.benchmark_sets[origin_name]):
op_config = benchmark_set[1]
sub_results = results[i * len(self.OPT_SHAPES) : (i + 1) * len(self.OPT_SHAPES)]
sub_results = results[i * len(self.OPT_SHAPES):(i + 1) * len(self.OPT_SHAPES)]
for i, result in enumerate(sub_results):
latency = result[0]
dyn_prof_shape = {"m":self.OPT_SHAPES[i]}
shape = legalize_shape(
"DYN", op_config.N, op_config.K, dyn_prof_shape
)
dyn_prof_shape = {"m": self.OPT_SHAPES[i]}
shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape)
latency_str = "N/A" if latency is None else f"{latency:.3f}"
tmp_data.append([shape, latency_str])
if len(data) == 0:
@@ -268,9 +269,7 @@ def legalize_shape(M, N, K, dyn_prof_shape):
data[i][3] = f"{head} ({symbol}{speedup * 100 :.3f}%)"
table_data.append([*data[i], "N/A"])

print(
tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")
)
print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid"))

for data in table_data:
print(data)
@@ -286,9 +285,7 @@ def get_operator_config(self):
def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul:
"""Make an Matmul instance."""
# Disable default tuning when do benchmark
return operator(
config, target=self.benchmark_target, enable_tuning=False
)
return operator(config, target=self.benchmark_target, enable_tuning=False)

def benchmark(self):
"""Run benchmarks on all benchmark sets."""
@@ -297,14 +294,9 @@ def benchmark(self):
for op, config, _ in benchmark_set:
for opt in self.OPT_SHAPES:
self.benchmark_results[name].extend(
[
self.run_benchmark(op, config, {"m": opt})
]
)
[self.run_benchmark(op, config, {"m": opt})])

def run_compare_strategy(
self, report=True, serialize=True, enable_tuning: bool = False
):
def run_compare_strategy(self, report=True, serialize=True, enable_tuning: bool = False):
"""Run the benchmark process."""

if not path.exists(self.log_path):
@@ -323,9 +315,7 @@ def run_compare_strategy(


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Bitblas Matmul Operator Benchmark"
)
parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark")

parser.add_argument(
"--enable_tuning",
@@ -335,4 +325,5 @@ def run_compare_strategy(

args = parser.parse_args()
enable_tuning = args.enable_tuning
BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy(enable_tuning=args.enable_tuning)
BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy(
enable_tuning=args.enable_tuning)
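For orientation, the refactored entry point above can also be driven from Python rather than the command line. The following is a minimal usage sketch, not part of the commit: the import path is an assumption based on the file location benchmark/operators/benchmark_matmul_strategies.py, while the class name and the run_compare_strategy signature are taken from the diff above.

# Hypothetical usage sketch; import path assumed from the script's location.
from benchmark_matmul_strategies import BitblasMatmulOpsBenchmarkCompareStategies

# Mirrors the __main__ block: compare the single-batch-decode and contiguous-batching
# strategies across the OPT_SHAPES m values, with hardware-aware tuning disabled.
benchmark = BitblasMatmulOpsBenchmarkCompareStategies()
benchmark.run_compare_strategy(report=True, serialize=True, enable_tuning=False)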
41 changes: 20 additions & 21 deletions benchmark/operators/benchmark_ops_matmul.py
@@ -123,39 +123,37 @@ def prepare_set_group_llm(self, name: str, N, K) -> List:
),
]

def get_llm_benchmark_sets(self, name:str) -> List:
return [*self.prepare_set_group_llm(name, 3200, 3200),
*self.prepare_set_group_llm(name, 8640, 3200),
*self.prepare_set_group_llm(name, 3200, 8640),
*self.prepare_set_group_llm(name, 5120, 5120),
*self.prepare_set_group_llm(name, 13824, 5120),
*self.prepare_set_group_llm(name, 5120, 13824),
*self.prepare_set_group_llm(name, 6656, 6656),
*self.prepare_set_group_llm(name, 17920, 6656),
*self.prepare_set_group_llm(name, 6656, 17920),
*self.prepare_set_group_llm(name, 1024, 8192),
*self.prepare_set_group_llm(name, 8192, 8192),
*self.prepare_set_group_llm(name, 28672, 8192),
*self.prepare_set_group_llm(name, 8192, 28672)]
def get_llm_benchmark_sets(self, name: str) -> List:
return [
*self.prepare_set_group_llm(name, 3200, 3200),
*self.prepare_set_group_llm(name, 8640, 3200),
*self.prepare_set_group_llm(name, 3200, 8640),
*self.prepare_set_group_llm(name, 5120, 5120),
*self.prepare_set_group_llm(name, 13824, 5120),
*self.prepare_set_group_llm(name, 5120, 13824),
*self.prepare_set_group_llm(name, 6656, 6656),
*self.prepare_set_group_llm(name, 17920, 6656),
*self.prepare_set_group_llm(name, 6656, 17920),
*self.prepare_set_group_llm(name, 1024, 8192),
*self.prepare_set_group_llm(name, 8192, 8192),
*self.prepare_set_group_llm(name, 28672, 8192),
*self.prepare_set_group_llm(name, 8192, 28672)
]

def prepare_benchmark_sets(self):
"""Prepare benchmark sets."""
self.add_benchmark_set(
"FP16xFP16_ACCFP16_NT",
[
*self.prepare_set_group_4x(
"FP16xFP16_ACCFP16_NT", 16384, 16384, 16384
),
*self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),
*self.get_llm_benchmark_sets("FP16xFP16_ACCFP16_NT"),
],
)

self.add_benchmark_set(
"INT8xINT8_ACCINT32_NT",
[
*self.prepare_set_group_4x(
"INT8xINT8_ACCINT32_NT", 16384, 16384, 16384
),
*self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),
*self.get_llm_benchmark_sets("INT8xINT8_ACCINT32_NT"),
],
)
@@ -315,14 +313,15 @@ def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul:
# Disable default tuning when do benchmark
return operator(config, target=self.benchmark_target, enable_tuning=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark")
parser.add_argument(
"--enable_tuning",
action="store_true",
help="Enable hardware-aware tuning",
)

args = parser.parse_args()
enable_tuning = args.enable_tuning
BitblasMatmulOpsBenchmark().run(enable_tuning=args.enable_tuning)
5 changes: 2 additions & 3 deletions bitblas/ops/general_matmul/__init__.py
@@ -116,9 +116,8 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind],
else:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)

if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
isinstance(self.M, Tuple) or
(self.with_zeros and self.zeros_mode == "quantized")):
if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")):
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
else:
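The guard above decides when layout propagation is turned off for both operands. As a reading aid, it can be restated as a standalone predicate; this sketch is illustrative only: the helper name is hypothetical and MICRO_KERNEL_SIZE is given a placeholder value, since the real constant is defined elsewhere in bitblas.

from typing import Tuple, Union

MICRO_KERNEL_SIZE = 16  # placeholder for illustration; the actual constant lives in bitblas

def propagation_must_be_disabled(M: Union[int, Tuple[int, ...]], N: int, K: int,
                                 with_zeros: bool, zeros_mode: str) -> bool:
    # Restates the condition in __initialize_propagate: propagation is skipped when
    # M is 1, N or K is not a multiple of the micro-kernel size, M is a tuple of
    # dynamic candidate sizes, or quantized zero points are requested.
    return (M == 1 or (N % MICRO_KERNEL_SIZE) != 0 or (K % MICRO_KERNEL_SIZE) != 0
            or isinstance(M, tuple) or (with_zeros and zeros_mode == "quantized"))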
