Skip to content

Commit

Permalink
typo fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Jul 21, 2024
1 parent 31813b2 commit 78b6a3d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions bitblas/gpu/matmul_mma_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,13 +1991,13 @@ def get_param_indices(
block_idx = sch.fuse(i1, j1)
thread_idy = i2
thread_idz = j2

sch.bind(batch, "blockIdx.z")
sch.bind(block_idx, "blockIdx.x")
sch.bind(block_idy, "blockIdx.y")
thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
sch.bind(thread_idy, "threadIdx.y")

def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741
if not enable:
return
Expand Down Expand Up @@ -2063,7 +2063,6 @@ def decode_fetch_to_shared(block, idx):
block_shared = sch.cache_read(block, idx, shared_scope)
sch.compute_at(block_shared, k0, preserve_unit_loops=True)


# TODO(lei): the factor should be analyzed more deeper.
decode_factor = get_coalesced_veclen(sch.get(block_shared))
_, B_shared_vi, _ = sch.split(
Expand Down Expand Up @@ -2166,7 +2165,7 @@ def get_idx():
_ = decode_fetch_to_shared(block_outer, 1)

# Put the thread binding after the shared memory prefetch
# Otherwise there's a axis mssing bug behind tvm
# Otherwise there's an axis-missing bug behind tvm
sch.bind(kr, "threadIdx.z")
# create read cache to load matrix from shared memory to wmma fragments
A_mat = sch.cache_read(block_outer, 0, "warp")
Expand Down

0 comments on commit 78b6a3d

Please sign in to comment.