Skip to content

Commit

Permalink
remove shared mem hack
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Sep 1, 2024
1 parent 9d90c40 commit 9ef14e9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions bitblas/gpu/matmul_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, red
# Apply Swizzling
sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle)
# if not, apply padding to alleviate bank conflict
if not (can_swizzle or is_smooth):
pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset)
# if not (can_swizzle or is_smooth):
# pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
# sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset)
sch.annotate(f_2, "pragma_unroll_explicit", False)
return block_read

Expand Down Expand Up @@ -648,7 +648,7 @@ def inverse_permutation(i, j, ii, jj):
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global,
sch.get_loops(store)[-5],
sch.get_loops(store)[-6],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
Expand Down
8 changes: 4 additions & 4 deletions bitblas/gpu/matmul_mma_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def get_idx():
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global,
sch.get_loops(store)[-5],
sch.get_loops(store)[-6],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
Expand Down Expand Up @@ -1075,7 +1075,7 @@ def get_idx():
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global,
sch.get_loops(store)[-5],
sch.get_loops(store)[-6],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
Expand Down Expand Up @@ -1675,7 +1675,7 @@ def get_idx():
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global,
sch.get_loops(store)[-5],
sch.get_loops(store)[-6],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
Expand Down Expand Up @@ -2194,7 +2194,7 @@ def get_idx():
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global,
sch.get_loops(store)[-5],
sch.get_loops(store)[-6],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
Expand Down

0 comments on commit 9ef14e9

Please sign in to comment.