From 8f8b6f1a46643be1786ac1abe4a6b6e59f0d6c38 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 10 Oct 2023 22:08:00 -0500 Subject: [PATCH 1/9] rebase onto improve_fwd_fa --- .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index f09df4a51b68..cf311fbfa5d8 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -64,6 +64,10 @@ static void addWSNamedAttrs(Operation *op, op->setAttr(attr.getName(), attr.getValue()); } +#ifdef USE_ROCM +constexpr int LDSSize = 65536; +constexpr int kPtrBitWidth = 64; +#endif class TritonLLVMFunctionConversionTarget : public ConversionTarget { public: explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target) @@ -410,6 +414,7 @@ struct ConvertTritonGPUToLLVM decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs); #ifdef USE_ROCM decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs); + reduceCvtOpLDSUsage(mod); #endif decomposeBlockedToDotOperand(mod); decomposeInsertSliceAsyncOp(mod); @@ -710,6 +715,151 @@ struct ConvertTritonGPUToLLVM } }); } + + int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) const { + unsigned inVec = 0; + unsigned outVec = 0; + auto smemShape = getScratchConfigForCvtLayout(cvtOp, inVec, outVec); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto srcType = cvtOp.getOperand().getType().cast(); + auto bytes = + srcType.getElementType().isa() + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcType.getElementTypeBitWidth()) / 8; + + return bytes; + } + + bool isPowerOfTwo(unsigned x) const { return x && (x & (x - 1)) == 0; } + + std::vector> factorizePowerOf2(int n) const { + assert(isPowerOfTwo(n)); + int x = log2(n); + std::vector> pairs; + + for (int i = 0; i <= x / 2; ++i) { + int j = x - i; + pairs.push_back({pow(2, i), pow(2, j)}); + pairs.push_back({pow(2, j), pow(2, i)}); + } + + return pairs; + } + + std::pair + createNewConvertOps(ModuleOp &mod, OpBuilder &builder, + triton::gpu::ConvertLayoutOp &cvtOp, + std::pair warpsPerCta) const { + unsigned warpsPerCtaX = warpsPerCta.first; + unsigned warpsPerCtaY = warpsPerCta.second; + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + + auto srcMfma = + srcType.getEncoding().dyn_cast(); + auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get( + mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY}, + srcMfma.getIsTransposed()); + + auto newDstType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); + auto newSrcType = RankedTensorType::get( + srcType.getShape(), srcType.getElementType(), newMfmaEnc); + + auto tmpCvt = builder.create( + cvtOp.getLoc(), newSrcType, cvtOp.getOperand()); + auto newEpliogueCvt = builder.create( + cvtOp.getLoc(), newDstType, tmpCvt); + + return std::make_pair(tmpCvt, newEpliogueCvt); + } + + // Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of + // WarpsPerCta attribute in mfma layout. The implicit LDS usage of + // cvt(mfma->blocked) op depends on the number of warps per CTA that mfma + // layout uses along x dimension and block layout uses across y dimension. 
+ // + // clang-format off + // + // LDS usage of this op is roughly calculated as: + // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type) + // LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C, + // where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type) + // + // clang-format on + // + // When LDS_USAGE exceeds the size of LDS, try to lower LDS usage by + // decomposing cvt(mfma->blocked) op into 2 conversions: cvt(mfma->mfma_tmp) + // and cvt(mfma_tmp->blocked), where mfma_tmp has WarpsPerCta attribute that + // minimizes uses of LDS for these conversions. + void reduceCvtOpLDSUsage(ModuleOp mod) const { + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + + auto srcMfma = + srcType.getEncoding().dyn_cast(); + auto dstBlocked = + dstType.getEncoding().dyn_cast(); + + if (!srcMfma || !dstBlocked) { + return; + } + + auto currLDSUsage = getCvtOpLDSUsage(cvtOp); + if (currLDSUsage <= LDSSize) { + return; + } + + unsigned numWarps = + srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1]; + + triton::gpu::ConvertLayoutOp tmpCvt; + triton::gpu::ConvertLayoutOp newEpliogueCvt; + + // Find all possible shapes of WarpsPerCTA by finding all possible + // factorizations of numWarps. Pick shape for which both conversions in + // decomposition use LDS less than LDSSize and for which sum of LDS usage + // is minimal. If no such shape exists, do not decompose. + unsigned minLDSUsage = 2 * LDSSize; + int minIdx = -1; + auto factorizedNumWarps = factorizePowerOf2(numWarps); + + for (int i = 0; i < factorizedNumWarps.size(); i++) { + auto warpsPerCTAPair = factorizedNumWarps[i]; + std::tie(tmpCvt, newEpliogueCvt) = + createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + + int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt); + int newCvtLDS = getCvtOpLDSUsage(newEpliogueCvt); + if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) { + int LDSUsage = tmpCvtLDS + newCvtLDS; + if (LDSUsage < minLDSUsage) { + minLDSUsage = LDSUsage; + minIdx = i; + } + } + newEpliogueCvt.erase(); + tmpCvt.erase(); + } + + if (minIdx == -1) { + return; + } + + assert(minIdx >= 0 && minIdx < factorizedNumWarps.size()); + auto warpsPerCTAPair = factorizedNumWarps[minIdx]; + std::tie(tmpCvt, newEpliogueCvt) = + createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + + cvtOp.replaceAllUsesWith(newEpliogueCvt.getResult()); + cvtOp.erase(); + }); + } + #endif void decomposeBlockedToDotOperand(ModuleOp mod) const { From eff2b833086b7202be59c225da47a9682dfa2889 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 10 Oct 2023 22:16:52 -0500 Subject: [PATCH 2/9] Fixed a leftover from rebase --- lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index cf311fbfa5d8..4bba88ae4079 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -760,7 +760,7 @@ struct ConvertTritonGPUToLLVM srcType.getEncoding().dyn_cast(); auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get( mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY}, - srcMfma.getIsTransposed()); + srcMfma.getIsTransposed(), srcMfma.getCTALayout()); auto 
newDstType = RankedTensorType::get( dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); From 2e220550038a50702be5ac6b1cd0cd05bbd48d5b Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 11 Oct 2023 12:07:22 -0500 Subject: [PATCH 3/9] rebase onto improve_fa_fwd --- python/tutorials/06-fused-attention.py | 43 ++++++++++++++------------ 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index b69ff61689a1..ffd6e07edd87 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -80,6 +80,8 @@ def _attn_fwd_inner( @triton.autotune( configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4), @@ -101,7 +103,7 @@ def _attn_fwd_inner( triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4), ], - key=['N_CTX', 'STAGE'], + key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'], verbose = True ) @@ -114,9 +116,9 @@ def _attn_fwd( stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, + BLOCK_DMODEL: tl.constexpr, STAGE: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, pre_load_v: tl.constexpr, ): @@ -747,30 +749,31 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): FLASH_VER = None HAS_FLASH = FLASH_VER is not None -BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX= 4, 48, 4096 # vary seq length for fixed head and batch=4 configs = [] for mode in ['fwd', 'bwd']: for causal in [False, True]: if mode == 'bwd' and causal == False: continue - configs.append(triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 15)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}', - args={ - 'H': N_HEADS, - 'BATCH': BATCH, - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode, - 'causal': causal}) - ) + for D_HEAD in [64, 128]: + configs.append(triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 15)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal}) + ) @triton.testing.perf_report(configs) From cb381e51ff04e33c857ed8a8aeda412e7f9561e4 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 11 Oct 2023 12:17:08 -0500 Subject: [PATCH 4/9] Reduce 
tuning space --- python/tutorials/06-fused-attention.py | 28 +++++--------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index ffd6e07edd87..9e24bb77a858 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -80,30 +80,12 @@ def _attn_fwd_inner( @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True ], - key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'], verbose = True + key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'], ) @@ -564,7 +546,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): ) 
## restore the grid for bwd kernel - best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage) + best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk) block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) From 379374eb983d620433e7d358c859fda329eb325f Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 11 Oct 2023 12:22:26 -0500 Subject: [PATCH 5/9] Disable bwd with D=128 --- python/tutorials/06-fused-attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 9e24bb77a858..5bada9109d5b 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -739,6 +739,8 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): if mode == 'bwd' and causal == False: continue for D_HEAD in [64, 128]: + if mode == 'bwd' and D_HEAD == 128: + continue configs.append(triton.testing.Benchmark( x_names=['N_CTX'], x_vals=[2**i for i in range(10, 15)], From a1bea586e16e01a6f105fb4c203d04b2765d1948 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 11 Oct 2023 12:32:44 -0500 Subject: [PATCH 6/9] Add test for d=128 --- python/tutorials/06-fused-attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 5bada9109d5b..e2c5b2fe5e12 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -639,6 +639,9 @@ def backward(ctx, do): [(4, 48, 1024, 64), (4, 48, 2048, 64), (4, 48, 4096, 64), + (4, 48, 1024, 128), + (4, 48, 2048, 128), + (4, 48, 4096, 128), #(4, 48, 8192, 64), #(4, 48, 16384, 64) ]) @@ -808,4 +811,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype # only works on post-Ampere GPUs right now -bench_flash_attention.run(save_path='.', print_data=True) +#bench_flash_attention.run(save_path='.', print_data=True) + + +test_op_fwd(4, 48, 4096, ) From 169d6d5551899bcfec90049b5f1138576aaff86b Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 11 Oct 2023 12:50:33 -0500 Subject: [PATCH 7/9] Fix an issue with get_best_config when there is only one config --- python/triton/runtime/autotuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 4d3c324f537b..15bcf53063a4 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -100,7 +100,7 @@ def get_best_config(self, *args, **kwargs): key_values.append(kwargs[name]) key = tuple(key_values) - return self.cache[key] if key in self.cache else Config({}) + return self.best_config def run(self, *args, **kwargs): From c9b6b3bccc5ef82c38314fe6b92ce0c51dcd2b54 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 11 Oct 2023 14:19:52 -0500 Subject: [PATCH 8/9] Added better configs for d=128 --- python/tutorials/06-fused-attention.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index e2c5b2fe5e12..d0019b1fedcb 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -80,8 +80,8 @@ def _attn_fwd_inner( @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, 
num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True ], @@ -811,7 +811,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype # only works on post-Ampere GPUs right now -#bench_flash_attention.run(save_path='.', print_data=True) - - -test_op_fwd(4, 48, 4096, ) +bench_flash_attention.run(save_path='.', print_data=True) From 8348b7989fd55931ec8b42826b1621faaaa63286 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 24 Oct 2023 09:51:50 -0500 Subject: [PATCH 9/9] Fix typos --- .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 4bba88ae4079..180767c296bc 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -769,10 +769,10 @@ struct ConvertTritonGPUToLLVM auto tmpCvt = builder.create( cvtOp.getLoc(), newSrcType, cvtOp.getOperand()); - auto newEpliogueCvt = builder.create( + auto newEpilogueCvt = builder.create( cvtOp.getLoc(), newDstType, tmpCvt); - return std::make_pair(tmpCvt, newEpliogueCvt); + return std::make_pair(tmpCvt, newEpilogueCvt); } // Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of @@ -783,7 +783,7 @@ struct ConvertTritonGPUToLLVM // clang-format off // // LDS usage of this op is roughly calculated as: - // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type) + // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type) // LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C, // where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type) // @@ -818,7 +818,7 @@ struct ConvertTritonGPUToLLVM srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1]; triton::gpu::ConvertLayoutOp tmpCvt; - triton::gpu::ConvertLayoutOp newEpliogueCvt; + triton::gpu::ConvertLayoutOp newEpilogueCvt; // Find all possible shapes of WarpsPerCTA by finding all possible // factorizations of numWarps. 
Pick shape for which both conversions in @@ -830,11 +830,11 @@ struct ConvertTritonGPUToLLVM for (int i = 0; i < factorizedNumWarps.size(); i++) { auto warpsPerCTAPair = factorizedNumWarps[i]; - std::tie(tmpCvt, newEpliogueCvt) = + std::tie(tmpCvt, newEpilogueCvt) = createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt); - int newCvtLDS = getCvtOpLDSUsage(newEpliogueCvt); + int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt); if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) { int LDSUsage = tmpCvtLDS + newCvtLDS; if (LDSUsage < minLDSUsage) { @@ -842,7 +842,7 @@ struct ConvertTritonGPUToLLVM minIdx = i; } } - newEpliogueCvt.erase(); + newEpilogueCvt.erase(); tmpCvt.erase(); } @@ -852,10 +852,10 @@ struct ConvertTritonGPUToLLVM assert(minIdx >= 0 && minIdx < factorizedNumWarps.size()); auto warpsPerCTAPair = factorizedNumWarps[minIdx]; - std::tie(tmpCvt, newEpliogueCvt) = + std::tie(tmpCvt, newEpilogueCvt) = createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); - cvtOp.replaceAllUsesWith(newEpliogueCvt.getResult()); + cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult()); cvtOp.erase(); }); }
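A note for readers of the LDS-usage comment in patch 1: the relation LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C is easiest to check with concrete numbers. The Python snippet below is a worked illustration only; the sizePerWarp and threadsPerWarp values are assumptions chosen for the example, not values taken from any particular kernel.

# Worked example of the LDS_USAGE relation from patch 1's comment.
# All layout parameters below are assumptions chosen for illustration.
sizeof_fp16 = 2
size_per_warp_1 = 4        # sizePerWarp(blocked_layout)[1], assumed
threads_per_warp_1 = 16    # threadsPerWarp(blocked_layout)[1], assumed
C = 32 * size_per_warp_1 * threads_per_warp_1 * sizeof_fp16  # 4096 bytes
LDS_SIZE = 65536           # the LDSSize constant from patch 1

for warps_mfma_0, warps_blocked_1 in [(8, 8), (2, 8), (8, 2)]:
    usage = warps_mfma_0 * warps_blocked_1 * C
    print((warps_mfma_0, warps_blocked_1), usage, usage <= LDS_SIZE)
# (8, 8) -> 262144 bytes, over the 64 KiB limit: the pass would decompose;
# (2, 8) and (8, 2) -> 65536 bytes each: both fit exactly.

With these assumed numbers, an 8x8 warp grid on the mfma side forces a decomposition, and the balanced splits are exactly the shapes the search sketched next ends up preferring.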
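The decomposition search itself, factorizePowerOf2 plus the selection loop in reduceCvtOpLDSUsage, can be read independently of MLIR. The sketch below mirrors that logic in Python under a caller-supplied cost model; lds_usage is a hypothetical stand-in for getCvtOpLDSUsage and is not a Triton API.

LDS_SIZE = 65536  # matches the LDSSize constant added in patch 1

def factorize_power_of_2(n):
    # All ordered pairs (2**i, 2**j) with i + j == log2(n), mirroring
    # factorizePowerOf2 in patch 1 (including the duplicate pair at i == j).
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    x = n.bit_length() - 1
    pairs = []
    for i in range(x // 2 + 1):
        j = x - i
        pairs.append((2 ** i, 2 ** j))
        pairs.append((2 ** j, 2 ** i))
    return pairs

def pick_warps_per_cta(num_warps, lds_usage):
    # Keep the factorization for which both conversions fit in LDS and the
    # summed usage is minimal; None means no shape fits, in which case the
    # pass leaves the original cvt op undecomposed.
    best_usage, best_pair = 2 * LDS_SIZE, None
    for pair in factorize_power_of_2(num_warps):
        tmp_bytes, epilogue_bytes = lds_usage(pair)
        if tmp_bytes <= LDS_SIZE and epilogue_bytes <= LDS_SIZE:
            if tmp_bytes + epilogue_bytes < best_usage:
                best_usage, best_pair = tmp_bytes + epilogue_bytes, pair
    return best_pair

# Toy cost model for demonstration only: tmp usage grows with the first
# warp dimension, epilogue usage with its complement.
C = 16384
demo_cost = lambda pair: (pair[0] * C, (8 // pair[0]) * C)
print(pick_warps_per_cta(8, demo_cost))  # -> (2, 4)

The pass builds and erases real ConvertLayoutOp pairs to measure each candidate, then re-creates only the winning pair; the sketch folds that measurement into a pure cost function for readability.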
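Finally, the one-line autotuner change in patch 7 is easiest to see with a toy model. In Autotuner.run the per-key cache is only populated when more than one config is benchmarked, so with a single config the old lookup in get_best_config missed the cache and returned an empty Config even though a best config had been chosen. The class below is a deliberately simplified stand-in for triton.runtime.autotuner.Autotuner, not its real implementation; only the cache and best_config interplay follows the patch.

# Toy model of the single-config failure mode fixed in patch 7.
class ToyAutotuner:
    def __init__(self, configs):
        self.configs = configs
        self.cache = {}          # tuning key -> benchmarked winner
        self.best_config = None

    def run(self, key):
        if len(self.configs) > 1:
            # Real code benchmarks each config here; pick the first as a stub.
            if key not in self.cache:
                self.cache[key] = self.configs[0]
            self.best_config = self.cache[key]
        else:
            # Single config: chosen directly, cache never touched.
            self.best_config = self.configs[0]

    def get_best_config(self, key):
        # Old behavior, `self.cache[key] if key in self.cache else Config({})`,
        # returned an empty config whenever the cache path was skipped.
        return self.best_config

tuner = ToyAutotuner(configs=["only-config"])
tuner.run(key=(4096, 3))
assert tuner.get_best_config(key=(4096, 3)) == "only-config"

This is what keeps the tutorial's forward pass in patch 4 able to read BLOCK_M back out of get_best_config to rebuild the grid for the backward kernel, even if the config list ever collapses to a single entry.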