Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dev] Fix GEMV Dynamic Scheduling with Splitk #52

Merged
merged 25 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
75d2f3d
improve e4m3 decoding.
May 21, 2024
dd744d0
Merge branch 'main' of https://github.com/microsoft/BitBLAS into main
May 23, 2024
00bfa31
append fp16xint1
May 25, 2024
8cd8b10
Update submodule commit reference
Jun 1, 2024
9122ff7
chore: Update shared memory scope for float32 output dtype
Jun 1, 2024
b508acc
BUGFIX: UINT8/INT8 Decoding
Jun 2, 2024
58d55b7
feat: Add rasterization options for roller module
Jun 5, 2024
e7547ce
Refactor tensorcore_legalization method to optimize tensor core usage
Jun 5, 2024
678a2e1
feat: Add function to collect variables from expression, improve for …
Jun 5, 2024
3088b35
chore: Update typing import in __init__.py
Jun 5, 2024
5d206b3
chore: Refactor CPU execution of operators
Jun 5, 2024
e06ce10
Refactor matmul implementation for splitk layout
Jun 5, 2024
d67cc6d
Refactor matmul implementation for splitk layout
Jun 5, 2024
9e36b6d
Refactor matmul implementation for splitk layout
Jun 5, 2024
e1a0149
chore: Update version to 0.0.1.dev8
Jun 5, 2024
df0ed7a
chore: Enable debug output in bitblas.set_debug_level()
Jun 5, 2024
a0f651a
Refactor Linear module matmul implementation for splitk layout
Jun 5, 2024
88295a7
Refactor matmul implementation for splitk layout
Jun 5, 2024
3366dce
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 5, 2024
25b5c63
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 5, 2024
26a9f1b
Bumpt version to v0.0.1.dev9
Jun 5, 2024
251bf08
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 5, 2024
e0cf62c
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 6, 2024
2e4e8dd
Bump version to v0.0.1.dev10
Jun 6, 2024
0dec7d8
Merge branch 'main' into lei/splitk
LeiWang1999 Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.1.dev9
0.0.1.dev10
2 changes: 1 addition & 1 deletion python/bitblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def _init_logger():

_init_logger()

__version__ = "0.0.1.dev9"
__version__ = "0.0.1.dev10"
3 changes: 2 additions & 1 deletion python/bitblas/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
return None

block_info = block_infos[0]
if len(block_info.iters) not in [2, 3]:
if len(block_info.iters) not in [2, 3, 4]:
# either [SK, B, S, R] = [SK, B, S, R] * [SK, B, R]
# either [B, S, R] = [B, S, R] * [B, R]
# or [S, R] = [S, R] * [R]
return None
Expand Down
13 changes: 13 additions & 0 deletions python/bitblas/gpu/gemv_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def get_vectorize_factor(target_format):
if len(sch.get_loops(block_b)) == 3:
i = sch.get_loops(block_b)[0]
sch.bind(i, "blockIdx.z")
elif len(sch.get_loops(block_b)) == 4:
# splitk case
sk, i = sch.get_loops(block_b)[:2]
sch.bind(sk, "blockIdx.y")
sch.bind(i, "blockIdx.z")

# get target dequantize buffer's idx
def get_idx(weight_decode_info: Dict):
Expand Down Expand Up @@ -274,6 +279,14 @@ def get_vectorize_factor(target_format):
if len(sch.get_loops(block_b)) == 3:
i = sch.get_loops(block_b)[0]
sch.bind(i, "blockIdx.z")
elif len(sch.get_loops(block_b)) == 4:
# splitk case
sk, i = sch.get_loops(block_b)[:2]
sch.bind(sk, "blockIdx.y")
sch.bind(i, "blockIdx.z")
assert len(config.thread) == 2, "SplitK only support 2D thread config"
num_warps = int(num_warps // config.thread[0])


# get target dequantize buffer's idx
def get_idx(weight_decode_info: Dict):
Expand Down
Loading