diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 5d92f99b..3dafd395 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -571,9 +571,9 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, red # Apply Swizzling sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + # if not (can_swizzle or is_smooth): + # pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + # sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) sch.annotate(f_2, "pragma_unroll_explicit", False) return block_read @@ -648,7 +648,7 @@ def inverse_permutation(i, j, ii, jj): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 6bc0e39b..f6f1e098 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -578,7 +578,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) @@ -1075,7 +1075,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) @@ -1675,7 +1675,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) @@ -2194,7 +2194,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))