Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dev] Complete benchmark op sets of ci #100

Merged
merged 98 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
aa66a90
bug fix in test
LeiWang1999 Jul 6, 2024
ae14a53
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 6, 2024
79b08e4
lint fix.
LeiWang1999 Jul 6, 2024
86fd036
test cuda i4 kernel
LeiWang1999 Jul 7, 2024
6b73a21
Refactor copyright notice in i4matmul.hpp
LeiWang1999 Jul 7, 2024
0ba90c1
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 7, 2024
086d208
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 7, 2024
47a3abd
refactor test as version below python 3.9 cannot handle int32 overflow.
LeiWang1999 Jul 8, 2024
024b247
format lint for test
LeiWang1999 Jul 8, 2024
bfedeaa
Refactor test_int4b_fp16_convert.py for improved readability and main…
LeiWang1999 Jul 8, 2024
e672a23
remove unused design file
LeiWang1999 Jul 8, 2024
21e5430
move tile device from package to base
LeiWang1999 Jul 8, 2024
fd11940
dummy impl for codegen
LeiWang1999 Jul 8, 2024
9ccfa85
Refactor file structure for ladder_permutate module
LeiWang1999 Jul 8, 2024
7c7d73e
Refactor backend class and fix typos in comments
LeiWang1999 Jul 8, 2024
47d5fc5
Deep refactor Lib related code.
LeiWang1999 Jul 8, 2024
53dd0dd
remove ci pull.
LeiWang1999 Jul 10, 2024
d58ac43
LintFix
LeiWang1999 Jul 10, 2024
37cb07c
refactor builder for whl build
LeiWang1999 Jul 10, 2024
f5b9999
Refactor TIRWrapper.wrap() method to include an assertion for the opt…
LeiWang1999 Jul 11, 2024
fb78244
Refactor lib_generator to set library and source paths
LeiWang1999 Jul 11, 2024
706e227
lint fix
LeiWang1999 Jul 11, 2024
63f5515
BitNet vllm integration
LeiWang1999 Jul 16, 2024
de91c0d
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 16, 2024
b9655fd
chore: update codespell to version 2.3.0
LeiWang1999 Jul 16, 2024
fff385f
Lintfix
LeiWang1999 Jul 16, 2024
72a98e7
Bump version to 0.0.1.dev13
LeiWang1999 Jul 18, 2024
5646ab5
lint fix
LeiWang1999 Jul 18, 2024
b965863
disable fast decoding [u]int4xint8 by default.
LeiWang1999 Jul 21, 2024
1198fc7
optimize from dict design in Hint
LeiWang1999 Jul 21, 2024
014213c
Implement SplitK
LeiWang1999 Jul 21, 2024
e0ca752
bitnet benchmark generation.
LeiWang1999 Jul 21, 2024
81b9cf0
Add benchmark script for BitNet integration
LeiWang1999 Jul 21, 2024
02edc0b
AtomicAdd Support
LeiWang1999 Jul 21, 2024
1a70c2d
LintFix
LeiWang1999 Jul 21, 2024
28d851c
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 21, 2024
c447a95
ci fix when 3rdparty tvm is initialized.
LeiWang1999 Jul 21, 2024
79a001b
bug fix for setup
LeiWang1999 Jul 21, 2024
31813b2
fix a bug in block reduce
LeiWang1999 Jul 21, 2024
78b6a3d
typo fix
LeiWang1999 Jul 21, 2024
9c55218
BUG Fix for block reduce.
LeiWang1999 Jul 22, 2024
1aa8868
Lint fix
LeiWang1999 Jul 22, 2024
22f70bf
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
5f082a5
Refactor block reduce schedule template
LeiWang1999 Jul 22, 2024
b4fb31e
transform branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
35eaa00
Fix subproject commit reference in 3rdparty/tvm
LeiWang1999 Jul 22, 2024
254dd74
chore: update submodule branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
31a44aa
force update config.cmake
LeiWang1999 Jul 22, 2024
427800e
Bug fix
LeiWang1999 Jul 22, 2024
96db111
Fix subproject commit reference in 3rdparty/cutlass
LeiWang1999 Jul 22, 2024
38b251a
chore: Add submodule for cutlass library
LeiWang1999 Jul 22, 2024
87d1c5a
update tl cutlass path
LeiWang1999 Jul 22, 2024
6200b1e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
0ffe0b5
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 22, 2024
8e08e77
format fix
LeiWang1999 Jul 22, 2024
df05a64
Copy CUTLASS to the package directory
LeiWang1999 Jul 22, 2024
4f529c5
Refactor setup.py to include additional TVM header files
LeiWang1999 Jul 22, 2024
d02bbc7
lint fix
LeiWang1999 Jul 23, 2024
cffe3fd
bug fix
LeiWang1999 Jul 23, 2024
a8bed74
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 23, 2024
d4eb5fd
Implement Matmul Benchmark Design
LeiWang1999 Jul 23, 2024
4c6c2c1
chore: Update BitBLAS Matmul benchmark script
LeiWang1999 Jul 23, 2024
0acaca1
lint fix
LeiWang1999 Jul 23, 2024
54d2227
Refactor BitBLASMatmulOpsBenchmark for improved readability and maint…
LeiWang1999 Jul 23, 2024
c2edefb
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
e0bc723
lint fix
LeiWang1999 Jul 23, 2024
a4e68d1
Benchmark bot test
LeiWang1999 Jul 23, 2024
df7e9aa
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
1c03365
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
4f319fc
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
a8833d4
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
803f6c6
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
df4572b
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
45ded45
int8 test case
LeiWang1999 Jul 23, 2024
4229676
Refactor compare_benchmark.py to handle missing benchmark results gra…
LeiWang1999 Jul 23, 2024
b883290
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
476ffee
ci fix
LeiWang1999 Jul 23, 2024
9bd34ff
disable ci for test benchmark
LeiWang1999 Jul 23, 2024
e86f4b2
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
75f3dd9
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
79e04aa
remove cli installation
LeiWang1999 Jul 23, 2024
cdd3345
chore: Create virtual environment and install dependencies for benchmark
LeiWang1999 Jul 23, 2024
f099938
Merge branch 'main' into dev
LeiWang1999 Jul 23, 2024
f211ad4
chore: Update benchmark workflow to include comparison step
LeiWang1999 Jul 23, 2024
ddde02a
Lint fix
LeiWang1999 Jul 24, 2024
8045ce9
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 24, 2024
21aee89
Merge branch 'dev' of https://github.com/LeiWang1999/MSBitBLAS into dev
LeiWang1999 Jul 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
types: [created]

jobs:
benchmark:
benchmark_base:
if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark')
runs-on: self-hosted

Expand All @@ -17,7 +17,13 @@ jobs:

- name: Get base branch commit ID
id: get_base_commit
run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV
run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" > base_commit_id.txt

- name: Upload base commit ID
uses: actions/upload-artifact@v3
with:
name: base-commit-id
path: base_commit_id.txt

- name: Set up Python
uses: actions/setup-python@v2
Expand Down Expand Up @@ -51,7 +57,18 @@ jobs:

- name: Get PR branch commit ID
id: get_pr_commit
run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV
run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" > pr_commit_id.txt

- name: Upload PR commit ID
uses: actions/upload-artifact@v3
with:
name: pr-commit-id
path: pr_commit_id.txt

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'

- name: Create virtual environment
run: python -m venv bitblas_benchmark
Expand All @@ -73,17 +90,49 @@ jobs:
cd benchmark/operators
python ./benchmark_ops_matmul.py

benchmark_compare:
if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark')
needs: [benchmark_base, benchmark_head]
runs-on: self-hosted

steps:
- name: Download commit IDs
uses: actions/download-artifact@v3
with:
name: base-commit-id
path: .

- name: Download PR commit ID
uses: actions/download-artifact@v3
with:
name: pr-commit-id
path: .

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'

- name: Create virtual environment
run: python -m venv bitblas_benchmark

- name: Activate virtual environment and install dependencies
run: |
source bitblas_benchmark/bin/activate
python -m pip install --upgrade pip
if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi

- name: Compare benchmark results
run: |
source bitblas_benchmark/bin/activate
cd benchmark/operators
python ./compare_benchmark.py --base ${{ env.BASE_COMMIT_ID }} --head ${{ env.PR_COMMIT_ID }} 2>&1 | tee compare_results.txt
python ./compare_benchmark.py --base $(cat base_commit_id.txt) --head $(cat pr_commit_id.txt) 2>&1 | tee compare_results.txt

- name: Authenticate GitHub CLI
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
gh auth login --with-token <<< $GITHUB_TOKEN
echo "${{ secrets.GITHUB_TOKEN }}" | gh auth login --with-token

- name: Post benchmark results
run: |
Expand Down
107 changes: 95 additions & 12 deletions benchmark/operators/benchmark_ops_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,109 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase):
"accum_dtype": "int32",
"out_dtype": "int8",
},
"FP16xINT4_ACCINT32_NT": {
"FP16xUINT4_ACCFP16_NT": {
"A_dtype": "float16",
"W_dtype": "int4",
"W_dtype": "uint4",
"accum_dtype": "float16",
},
"FP16xUINT2_ACCFP16_NT": {
"A_dtype": "float16",
"W_dtype": "uint2",
"accum_dtype": "float16",
},
"INT8xUINT2_ACCINT32_NT": {
"A_dtype": "int8",
"W_dtype": "uint2",
"accum_dtype": "int32",
"out_dtype": "int8",
},
}

CURRENT_COMMIT_ID = get_commit_id()

def prepare_set_group_4x(self, name: str, M, N, K) -> List:
return [
self.generate_op_unit(self.generate_operator_config(name, 1, N, K)),
self.generate_op_unit(self.generate_operator_config(name, M, N, K)),
self.generate_op_unit(
self.generate_operator_config(name, [1, M], N, K),
dynamic_profiling_shape={"m": 1},
),
self.generate_op_unit(
self.generate_operator_config(name, [1, M], N, K),
dynamic_profiling_shape={"m": M},
),
]

def prepare_set_group_llm(self, name: str, N, K) -> List:
return [
self.generate_op_unit(self.generate_operator_config(name, 1, N, K)),
self.generate_op_unit(self.generate_operator_config(name, 16, N, K)),
self.generate_op_unit(self.generate_operator_config(name, 32, N, K)),
self.generate_op_unit(self.generate_operator_config(name, 64, N, K)),
self.generate_op_unit(self.generate_operator_config(name, 128, N, K)),
self.generate_op_unit(self.generate_operator_config(name, 2048, N, K)),
self.generate_op_unit(
self.generate_operator_config(name, [1, 16], N, K),
dynamic_profiling_shape={"m": 1},
),
self.generate_op_unit(
self.generate_operator_config(name, [1, 32], N, K),
dynamic_profiling_shape={"m": 32},
),
self.generate_op_unit(
self.generate_operator_config(name, [1, 64], N, K),
dynamic_profiling_shape={"m": 64},
),
self.generate_op_unit(
self.generate_operator_config(name, [1, 128], N, K),
dynamic_profiling_shape={"m": 128},
),
self.generate_op_unit(
self.generate_operator_config(name, [1, 2048], N, K),
dynamic_profiling_shape={"m": 2048},
),
]

def prepare_benchmark_sets(self):
"""Prepare benchmark sets."""
self.add_benchmark_set(
"FP16xFP16_ACCFP16_NT",
[
self.generate_op_unit(
self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),),
self.generate_op_unit(
self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384),
dynamic_profiling_shape={"M": 1024},
),
*self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192),
*self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672),
],
)

self.add_benchmark_set(
"INT8xINT8_ACCINT32_NT",
[
*self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192),
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672),
],
)

Expand Down Expand Up @@ -168,15 +251,15 @@ def legalize_shape(M, N, K, dyn_prof_shape):
M: The M dimension (can be an int or a tuple).
N: The N dimension (must be an int).
K: The K dimension (must be an int).
dyn_prof_shape: The dynamic profiling shape (dict with 'M' key if M is dynamic).
dyn_prof_shape: The dynamic profiling shape (dict with "m" key if M is dynamic).

Returns:
A string representing the shape in either 'M-N-K' or 'N-K_M' format.
"""
if isinstance(M, int):
return f"{M}-{N}-{K}"
elif dyn_prof_shape and "M" in dyn_prof_shape:
return f"{N}-{K}_{dyn_prof_shape['M']}"
elif dyn_prof_shape and "m" in dyn_prof_shape:
return f"{N}-{K}_{dyn_prof_shape['m']}"
else:
# Calculate the average of tuple M
opt_m = sum(M) / len(M)
Expand All @@ -195,7 +278,7 @@ def legalize_shape(M, N, K, dyn_prof_shape):
f"{(2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12):.3f}"
if latency else "N/A")
latency_str = "N/A" if latency is None else f"{latency:.3f}"
tuning_time_str = ("N/A" if tuning_time is None else f"{tuning_time:.3f}")
tuning_time_str = "N/A" if tuning_time is None else f"{tuning_time:.3f}"

table_data.append([shape, latency_str, throughput, tuning_time_str])

Expand Down
37 changes: 29 additions & 8 deletions benchmark/operators/compare_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,42 @@ def legalize_shape(M, N, K, dyn_prof_shape):
sum(op_config.M) /
len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M)

base_latency = base.benchmark_results[name][i][0]
try:
base_latency = base.benchmark_results[name][i][0]
except IndexError:
print(f"Operator {name} not found in benchmark sets")
base_latency = None

if latency is not None:
throughput = (2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12)
base_throughput = (2 * benchmark_M * op_config.N * op_config.K /
(base_latency * 1e-3) / 1e12)
throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}"
if base_latency is not None:
base_throughput = (2 * benchmark_M * op_config.N * op_config.K /
(base_latency * 1e-3) / 1e12)
throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}"
else:
throughput = f"{throughput:.3f}"
else:
throughput = "N/A"

if base_latency is not None:
latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}"
if latency is not None:
if base_latency is not None:
latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}"
else:
latency_str = f"{latency:.3f}"
else:
latency_str = "N/A"

base_tuning_time = base.benchmark_results[name][i][1]
try:
base_tuning_time = base.benchmark_results[name][i][1]
except IndexError:
print(f"Operator {name} not found in benchmark sets")
base_tuning_time = None

if tuning_time is not None:
tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}"
if base_tuning_time is not None:
tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}"
else:
tuning_time_str = f"{tuning_time:.3f}"
else:
tuning_time_str = "N/A"

Expand All @@ -95,6 +114,8 @@ def legalize_shape(M, N, K, dyn_prof_shape):
)
args = parser.parse_args()

print(f"Comparing base commit {args.base} with head commit {args.head}")

base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.base)

head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.head)
Expand Down
14 changes: 7 additions & 7 deletions bitblas/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ def __init__(self, config, sch, mod: Module):
self.mod = mod
self.code = mod.imported_modules[0].get_source() if mod else None
self.latency = 1e9
self.profile_tensors = []
self.time_evaluator = None

def profile(self):
profile_tensors = self.profile_tensors
return self.time_evaluator(*profile_tensors).mean * 1e3
def profile(self, data_distribution="uniform"):
func = self.sch.mod["main"]
device = self.config.arch.device
profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution)
latency = self.time_evaluator(*profile_tensors).mean * 1e3
return latency


def _apply_config(
Expand Down Expand Up @@ -172,7 +174,6 @@ def apply_and_build_parallel(func,
data_distribution="uniform") -> CompileResult:
cpresults = []

profile_tensors = get_dummy_input_arrays(func, arch.device, distribution=data_distribution)
max_workers = min(len(configs), os.cpu_count(), max_workers)

# apply config in thread parallel
Expand Down Expand Up @@ -242,7 +243,6 @@ def tvm_callback_cuda_postproc(code, _):
cpresult = CompileResult(config, sch, rt_mod)
timer_cuda_mod = rt_mod.time_evaluator(
rt_mod.entry_name, arch.device, number=num_repeats)
cpresult.profile_tensors = profile_tensors
cpresult.time_evaluator = timer_cuda_mod
cpresult.code = code
cpresults.append(cpresult)
Expand All @@ -256,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _):
for cpresult in cpresults:
config = cpresult.config
try:
latency = cpresult.profile()
latency = cpresult.profile(data_distribution=data_distribution)
except Exception as e_mesg:
logger.debug(f"Evaluation with config failed {e_mesg}")
continue
Expand Down
2 changes: 2 additions & 0 deletions bitblas/benchmark/operator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def run_benchmark(

latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape)

op_inst.cleanup()

return latency, tuning_time

@abstractmethod
Expand Down
8 changes: 8 additions & 0 deletions bitblas/ops/general_matmul/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,11 @@ def dispatch_tir(self,
def _alloc_workspace(self):
return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda()

def _free_workspace(self):
# release the workspace if it is None
if self.workspace is not None:
self.workspace = None

def _assign_ladder_permutate_a(self, target: Target, enable_tuning: bool):
ladder_permutate_a = None
if self.propagate_a:
Expand Down Expand Up @@ -534,6 +539,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)

def cleanup(self):
self._free_workspace()

@property
def M(self):
return self.config.M
Expand Down
1 change: 0 additions & 1 deletion bitblas/ops/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def var_warpper(v, m):
[var_warpper(i, m) for i in arg.shape]).astype(arg.dtype),
device=device,
))
self.profile_tensors = profile_tensors
latency = self.time_evaluator(*profile_tensors).mean * 1e3
benchmark_latencies.append({"m": m, "latency": latency})
# ms
Expand Down
Loading
Loading