Skip to content

Commit

Permalink
Fix a bug in fastPath condition
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 committed Jan 25, 2024
1 parent 6141b10 commit 36c6a58
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,8 @@ std::optional<int> findConstValue(Value val) {
return intAttr.getInt();
}

bool fastPathAvailable(const SharedMemoryObject &smemObj,
const SharedEncodingAttr &srcEncoding,
const MfmaEncodingAttr &dstEncoding) {
if (srcEncoding.getMaxPhase() > 1)
return false;
auto stride0 = findConstValue(smemObj.strides[0]);
auto stride1 = findConstValue(smemObj.strides[1]);
auto offset0 = findConstValue(smemObj.offsets[0]);
auto offset1 = findConstValue(smemObj.offsets[1]);
bool allValuesDefined = stride0.has_value() && stride1.has_value() &&
offset0.has_value() && offset1.has_value();
if (!allValuesDefined)
return false;
if (offset0.value() != 0 || offset1.value() != 0)
return false;
return true;
bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) {
return srcEncoding.getMaxPhase() > 1;
}

// Computes offsets for operand B or transposed operand A
Expand Down Expand Up @@ -469,10 +455,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
SmallVector<Value> loadedValues;
SmallVector<Value> offsets;
Value smemBase;
bool isFastPath = fastPathAvailable(smemObj, sharedLayout, mfmaLayout);
if (!isKMajor(order, opIdx) && isFastPath) {
// fast path handles tensors that are not k-major, in which case swizzling
// is disabled and offsets computation can be simplified
bool isFastPath = !isKMajor(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
if (isFastPath) {
// fast path handles tensors that are not k-major and have swizzling
// disabled, in which case offsets computation can be simplified
// TODO (zhanglx): later when we enable vector access to LDS for non k-major
// tensors, we'll refactor the scope of fast and normal path
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
Expand All @@ -499,8 +485,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
}
smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
} else { // normal path
// Normal path handles tensors that are k-major, in which case swizzling
// is enabled and it requires a 2-step method to compute the offsets.
// Normal path handles tensors that fall into either of the following three
// cases:
// 1. k-major + swizzling is enabled <-- this should be the most
// performant case
// 2. k-major + swizzling is disabled <-- for testing purpose only
// 3. non k-major + swizzling is enabled <-- for testing purpose only
//
// In this path, it requires a 2-step method to compute the offsets.
if (opIdx == 0) {
offsets = computeOffsetsAType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
Expand Down
4 changes: 2 additions & 2 deletions python/perf-kernels/03-matrix-multiplication-all-types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ def get_x_vals():
# Only test k-major tensors because
# 1. This is the most preformant config and the current focus
# 2. Other case does not work with num_stages=0 (TODO (zhanglx))
for col_a in [False]
for col_b in [True]]
for col_a in [True, False]
for col_b in [True, False]]
)
def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
a, a_fp16 = gen_input(M, K, in_dtype, col_a, 1, device='cuda')
Expand Down
8 changes: 3 additions & 5 deletions scripts/amd/gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,8 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
K, N = b.shape
# 1D launch kernel where each block gets its own program.

grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
META['SPLIT_K']
)
grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k

matmul_kernel[grid](
a, b, c,
M, N, K,
Expand Down Expand Up @@ -503,7 +501,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, v
size_str = ''
if verbose:
size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}'
if torch.allclose(triton_output.to(torch.float16), torch_output, atol=1e-1, rtol=rtol):
if torch.allclose(triton_output.to(torch.float16), torch_output, atol=1e-3, rtol=rtol):
print(f'{size_str} Correct✅')
else:
print(f'{size_str} Incorrect❌')
Expand Down

0 comments on commit 36c6a58

Please sign in to comment.