[FA fwd D=128] Reduce LDS usage in epilogue #340

Merged (9 commits) on Oct 25, 2023
150 changes: 150 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -64,6 +64,10 @@ static void addWSNamedAttrs(Operation *op,
op->setAttr(attr.getName(), attr.getValue());
}

#ifdef USE_ROCM
constexpr int LDSSize = 65536;
constexpr int kPtrBitWidth = 64;
#endif
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
@@ -410,6 +414,7 @@ struct ConvertTritonGPUToLLVM
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
#ifdef USE_ROCM
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
reduceCvtOpLDSUsage(mod);
#endif
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
@@ -710,6 +715,151 @@ struct ConvertTritonGPUToLLVM
}
});
}

int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) const {
unsigned inVec = 0;
unsigned outVec = 0;
auto smemShape = getScratchConfigForCvtLayout(cvtOp, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto bytes =
srcType.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcType.getElementTypeBitWidth()) / 8;

return bytes;
}

bool isPowerOfTwo(unsigned x) const { return x && (x & (x - 1)) == 0; }

std::vector<std::pair<int, int>> factorizePowerOf2(int n) const {
assert(isPowerOfTwo(n));
int x = log2(n);
std::vector<std::pair<int, int>> pairs;

for (int i = 0; i <= x / 2; ++i) {
int j = x - i;
pairs.push_back({pow(2, i), pow(2, j)});
pairs.push_back({pow(2, j), pow(2, i)});
}

return pairs;
}

std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
triton::gpu::ConvertLayoutOp &cvtOp,
std::pair<unsigned, unsigned> warpsPerCta) const {
unsigned warpsPerCtaX = warpsPerCta.first;
unsigned warpsPerCtaY = warpsPerCta.second;
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();

auto srcMfma =
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
srcMfma.getIsTransposed(), srcMfma.getCTALayout());

auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
auto newSrcType = RankedTensorType::get(
srcType.getShape(), srcType.getElementType(), newMfmaEnc);

auto tmpCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newSrcType, cvtOp.getOperand());
auto newEpilogueCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newDstType, tmpCvt);

return std::make_pair(tmpCvt, newEpilogueCvt);
}

// Try to reduce the LDS usage of a cvt(mfma->blocked) op by changing the
// shape of the warpsPerCTA attribute in the mfma layout. The implicit LDS
// usage of a cvt(mfma->blocked) op depends on the number of warps per CTA
// that the mfma layout uses along the x dimension and the blocked layout
// uses along the y dimension.
//
// clang-format off
//
// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCTA(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
[Review comment on lines +787 to +788] Are these lines redundant?

//
// clang-format on
//
// When LDS_USAGE exceeds the size of LDS, try to lower it by decomposing the
// cvt(mfma->blocked) op into two conversions: cvt(mfma->mfma_tmp) and
// cvt(mfma_tmp->blocked), where mfma_tmp has a warpsPerCTA attribute that
// minimizes the LDS usage of these conversions.
void reduceCvtOpLDSUsage(ModuleOp mod) const {
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);

auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();

auto srcMfma =
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
auto dstBlocked =
dstType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();

if (!srcMfma || !dstBlocked) {
return;
}

auto currLDSUsage = getCvtOpLDSUsage(cvtOp);
if (currLDSUsage <= LDSSize) {
return;
}

unsigned numWarps =
srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1];

triton::gpu::ConvertLayoutOp tmpCvt;
triton::gpu::ConvertLayoutOp newEpilogueCvt;

// Find all possible shapes of warpsPerCTA by enumerating all factorizations
// of numWarps. Pick the shape for which both conversions in the decomposition
// use less LDS than LDSSize and for which the sum of their LDS usage is
// minimal. If no such shape exists, do not decompose.
unsigned minLDSUsage = 2 * LDSSize;
int minIdx = -1;
auto factorizedNumWarps = factorizePowerOf2(numWarps);

for (int i = 0; i < factorizedNumWarps.size(); i++) {
auto warpsPerCTAPair = factorizedNumWarps[i];
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);

int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
int LDSUsage = tmpCvtLDS + newCvtLDS;

[Review comment] Do the lifetimes of the scratch buffers of tmpCvt and newEpilogueCvt overlap? If not, I think it is better to take the maximum of these values instead of their sum.

if (LDSUsage < minLDSUsage) {
minLDSUsage = LDSUsage;
minIdx = i;
}
}
newEpilogueCvt.erase();
tmpCvt.erase();
}

if (minIdx == -1) {
return;
}

assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
auto warpsPerCTAPair = factorizedNumWarps[minIdx];
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);

cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult());
cvtOp.erase();
});
}

#endif

void decomposeBlockedToDotOperand(ModuleOp mod) const {
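To make the LDS formula and the warpsPerCTA search described in the comments above concrete, here is a standalone sketch in plain Python. All numbers are hypothetical, and it models only the comment's estimate for the mfma_tmp->blocked step; the actual pass costs both conversions of the decomposition through getCvtOpLDSUsage and keeps the factorization with the smallest sum.

import math

LDS_SIZE = 65536  # bytes of LDS per workgroup, as in the pass


def factorize_power_of_2(n):
    """All ordered pairs (2**i, 2**j) with (2**i) * (2**j) == n, mirroring
    factorizePowerOf2 in the pass."""
    assert n > 0 and (n & (n - 1)) == 0
    x = int(math.log2(n))
    return [(2 ** i, 2 ** (x - i)) for i in range(x + 1)]


def epilogue_lds_bytes(mfma_warps_x, blocked_warps_y, c_bytes):
    # LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCTA(blocked_layout)[1] * C
    return mfma_warps_x * blocked_warps_y * c_bytes


# Hypothetical epilogue: 8 warps, mfma layout (8, 1), blocked layout (1, 8),
# and C = 4 KiB for this tile shape / dtype (illustrative only).
C = 4 * 1024
print(epilogue_lds_bytes(8, 8, C))          # 262144 bytes -> does not fit in 64 KiB

# Reshaping the mfma intermediate shrinks the first factor of the estimate:
for wx, wy in factorize_power_of_2(8):
    usage = epilogue_lds_bytes(wx, 8, C)    # usage of the mfma_tmp->blocked step
    print((wx, wy), usage, usage <= LDS_SIZE)
# e.g. an intermediate with warpsPerCTA = (1, 8) brings this step down to 32 KiB.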
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
@@ -100,7 +100,7 @@ def get_best_config(self, *args, **kwargs):
key_values.append(kwargs[name])
key = tuple(key_values)

return self.cache[key] if key in self.cache else Config({})
return self.best_config


def run(self, *args, **kwargs):
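A minimal, self-contained sketch of the new get_best_config behaviour above. SimpleAutotuner and the toy benchmark below are hypothetical stand-ins rather than Triton's real Autotuner; they only illustrate that the method now returns whichever config won the tuning sweep instead of looking up a cache entry built from its keyword arguments (the kwargs are kept for call-site compatibility, as in the tutorial change below).

class Config:
    def __init__(self, kwargs, num_warps=4, num_stages=1):
        self.kwargs = kwargs
        self.num_warps = num_warps
        self.num_stages = num_stages

    def __str__(self):
        fields = [f"{k}: {v}" for k, v in self.kwargs.items()]
        fields += [f"num_warps: {self.num_warps}", f"num_stages: {self.num_stages}"]
        return ", ".join(fields)


class SimpleAutotuner:
    def __init__(self, configs):
        self.configs = configs
        self.best_config = None

    def run(self, bench):
        # The tuning sweep picks the "fastest" config and remembers it.
        self.best_config = min(self.configs, key=bench)

    def get_best_config(self, *args, **kwargs):
        # After this PR: the keyword arguments no longer select a cache entry;
        # the method simply returns whichever config won the last sweep.
        return self.best_config


tuner = SimpleAutotuner([Config({"BLOCK_M": 128, "BLOCK_N": 64}),
                         Config({"BLOCK_M": 256, "BLOCK_N": 64})])
tuner.run(bench=lambda cfg: cfg.kwargs["BLOCK_M"])            # toy "timing"
best = tuner.get_best_config(N_CTX=4096, STAGE=3, BLOCK_DMODEL=128)
print(best.kwargs["BLOCK_M"])                                 # 128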
72 changes: 31 additions & 41 deletions python/tutorials/06-fused-attention.py
@@ -80,28 +80,12 @@ def _attn_fwd_inner(

@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
],
key=['N_CTX', 'STAGE'],
key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)


@@ -114,9 +98,9 @@ def _attn_fwd(
stride_oz, stride_oh, stride_om, stride_on,
Z, H,
N_CTX,
BLOCK_DMODEL: tl.constexpr,
STAGE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
pre_load_v: tl.constexpr,
):
@@ -562,7 +546,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
)

## restore the grid for bwd kernel
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
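The BLOCK_M extraction above relies on string parsing of the winning config. A small worked example shows what the two splits do (the sample string is hypothetical, mimicking how a Config renders), and why this only works while BLOCK_M stays the first field:

s = "BLOCK_M: 128, BLOCK_N: 64, waves_per_eu: 3, pre_load_v: False, num_warps: 4, num_stages: 1"
block_m = int(s.split(",")[0].split("BLOCK_M:")[1])   # "BLOCK_M: 128" -> " 128" -> 128
assert block_m == 128

If the field order ever changes, reading the value from the config's kwargs dict directly (as in the SimpleAutotuner sketch after the autotuner diff above, and assuming the real Config object exposes such a dict) would be more robust than string parsing.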

@@ -655,6 +639,9 @@ def backward(ctx, do):
[(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
(4, 48, 1024, 128),
(4, 48, 2048, 128),
(4, 48, 4096, 128),
#(4, 48, 8192, 64),
#(4, 48, 16384, 64)
])
@@ -747,30 +734,33 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX = 4, 48, 4096
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd', 'bwd']:
for causal in [False, True]:
if mode == 'bwd' and causal == False:
continue
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 15)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal})
)
for D_HEAD in [64, 128]:
if mode == 'bwd' and D_HEAD == 128:
continue
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 15)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal})
)


@triton.testing.perf_report(configs)