diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
index 17462acd273f..865e43778fcf 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -323,22 +323,8 @@ std::optional<int64_t> findConstValue(Value val) {
   return intAttr.getInt();
 }
 
-bool fastPathAvailable(const SharedMemoryObject &smemObj,
-                       const SharedEncodingAttr &srcEncoding,
-                       const MfmaEncodingAttr &dstEncoding) {
-  if (srcEncoding.getMaxPhase() > 1)
-    return false;
-  auto stride0 = findConstValue(smemObj.strides[0]);
-  auto stride1 = findConstValue(smemObj.strides[1]);
-  auto offset0 = findConstValue(smemObj.offsets[0]);
-  auto offset1 = findConstValue(smemObj.offsets[1]);
-  bool allValuesDefined = stride0.has_value() && stride1.has_value() &&
-                          offset0.has_value() && offset1.has_value();
-  if (!allValuesDefined)
-    return false;
-  if (offset0.value() != 0 || offset1.value() != 0)
-    return false;
-  return true;
+bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) {
+  return srcEncoding.getMaxPhase() > 1;
 }
 
 // Computes offsets for operand B or transposed operand A
@@ -469,10 +455,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   SmallVector<Value> loadedValues;
   SmallVector<Value> offsets;
   Value smemBase;
-  bool isFastPath = fastPathAvailable(smemObj, sharedLayout, mfmaLayout);
-  if (!isKMajor(order, opIdx) && isFastPath) {
-    // fast path handles tensors that are not k-major, in which case swizzling
-    // is disabled and offsets computation can be simplified
+  bool isFastPath = !isKMajor(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
+  if (isFastPath) {
+    // fast path handles tensors that are not k-major and have swizzling
+    // disabled, in which case offsets computation can be simplified
     // TODO (zhanglx): later when we enable vector access to LDS for non k-major
     // tensors, we'll refactor the scope of fast and normal path
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
@@ -499,8 +485,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
     }
     smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
   } else { // normal path
-    // Normal path handles tensors that are k-major, in which case swizzling
-    // is enabled and it requires a 2-step method to compute the offsets.
+    // Normal path handles tensors that fall into one of the following three
+    // cases:
+    //   1. k-major + swizzling is enabled <-- this should be the most
+    //      performant case
+    //   2. k-major + swizzling is disabled <-- for testing purposes only
+    //   3. non k-major + swizzling is enabled <-- for testing purposes only
+    //
+    // This path requires a 2-step method to compute the offsets.
     if (opIdx == 0) {
       offsets = computeOffsetsAType(
          rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
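
Note on the refactor above: `fastPathAvailable` used to inspect the constant strides and offsets of the `SharedMemoryObject` in addition to the swizzle phase; after this change, path selection depends only on k-majorness and on whether swizzling is enabled (`maxPhase > 1`), checked at the call site. A minimal sketch, not part of the patch, of how the four (k-major, swizzled) combinations map onto the two code paths, using the helper names from the diff with surrounding types elided:

```cpp
bool kMajor   = isKMajor(order, opIdx);
bool swizzled = hasSwizzleEnabled(sharedLayout); // i.e. maxPhase > 1
// Only the non-k-major, unswizzled combination takes the fast path:
bool isFastPath = !kMajor && !swizzled;
// Every other combination falls through to the normal (2-step offset) path:
//    kMajor &&  swizzled --> normal path, the performant case
//    kMajor && !swizzled --> normal path, for testing purposes only
//   !kMajor &&  swizzled --> normal path, for testing purposes only
```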
diff --git a/python/perf-kernels/03-matrix-multiplication-all-types.py b/python/perf-kernels/03-matrix-multiplication-all-types.py
index ef3adf8fbc1d..498e0ef2b3bf 100644
--- a/python/perf-kernels/03-matrix-multiplication-all-types.py
+++ b/python/perf-kernels/03-matrix-multiplication-all-types.py
@@ -228,8 +228,8 @@ def get_x_vals():
     # Only test k-major tensors because
     # 1. This is the most performant config and the current focus
     # 2. Other cases do not work with num_stages=0 (TODO (zhanglx))
-    for col_a in [False]
-    for col_b in [True]]
+    for col_a in [True, False]
+    for col_b in [True, False]]
 )
 def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
     a, a_fp16 = gen_input(M, K, in_dtype, col_a, 1, device='cuda')
diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py
index f0b7eddfa618..3b712ee5bdb0 100644
--- a/scripts/amd/gemm/tune_gemm.py
+++ b/scripts/amd/gemm/tune_gemm.py
@@ -460,10 +460,8 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
     K, N = b.shape
     # 1D launch kernel where each block gets its own program.
-    grid = lambda META: (
-        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
-        META['SPLIT_K']
-    )
+    grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k
+
     matmul_kernel[grid](
         a, b, c,
         M, N, K,
@@ -503,7 +501,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, v
     size_str = ''
     if verbose:
         size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}'
-    if torch.allclose(triton_output.to(torch.float16), torch_output, atol=1e-1, rtol=rtol):
+    if torch.allclose(triton_output.to(torch.float16), torch_output, atol=1e-3, rtol=rtol):
         print(f'{size_str} Correct✅')
     else:
         print(f'{size_str} Incorrect❌')
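
Note on the `tune_gemm.py` change: the callable `grid = lambda META: ...` form is only needed when the launch dimensions depend on meta-parameters picked by the autotuner at launch time; here `block_m`, `block_n`, and `split_k` are plain Python arguments of `matmul`, so the grid can be computed eagerly as a tuple. A minimal standalone sketch of the two forms, with hypothetical example values (not from the repo):

```python
import triton

M, N = 4096, 4096
block_m, block_n, split_k = 128, 128, 2

# Callable grid: resolved against META at launch time. Required when
# BLOCK_SIZE_M/BLOCK_SIZE_N/SPLIT_K come from an autotuned config.
grid_callable = lambda META: (
    triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    META['SPLIT_K'],
)

# Static grid: fine when the tile sizes are already ordinary Python values,
# as in tune_gemm.py after this patch.
grid_static = (triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k)
```

Dropping the lambda here avoids the indirection entirely, since the tuner drives the kernel with explicit per-config arguments rather than through `@triton.autotune`.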