
Commit

fix
Ognjen committed Feb 22, 2024
1 parent 1687710 commit 3ae1238
Showing 4 changed files with 154 additions and 98 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -42,8 +42,8 @@ add_definitions( -DROCM_DEFAULT_DIR="${ROCM_DEFAULT_DIR}")
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O0 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O0 -g")

# Default build type
if(NOT CMAKE_BUILD_TYPE)
150 changes: 102 additions & 48 deletions lib/Dialect/TritonGPU/Transforms/AMDReorderInstructions.cpp
@@ -205,6 +205,7 @@ class TritonAMDGPUReorderInstructionsPass
}
}


void processStage(Operation *currDot, Operation *moveBeforeDot,
SmallVector<Operation *> &operations, bool init,
int operandIdx) {
@@ -222,9 +223,64 @@ class TritonAMDGPUReorderInstructionsPass
return std::distance(value.user_begin(), value.user_end());
}

void scheduleSlicedDot(ModuleOp m, int stages, bool sinkLDSRd, bool sinkLDSWr) {
SmallVector<SmallVector<Operation *>> dotChains;
// Rearrange the instructions of a dot chain in a pipelined manner.
// Note that not only the load instruction is hoisted, but all
// instructions from the load up to the cvt(shared, dot).
// Let's say there are two dots:
//
// k0 = load(k0_ptr)
// k0_shared = cvt(k0, blocked, shared)
// k0_dot = cvt(k0_shared, shared, dot)
// dot0 = dot(..., k0_dot, 0)
//
// k1 = load(k1_ptr)
// k1_shared = cvt(k1, blocked, shared)
// k1_dot = cvt(k1_shared, shared, dot)
// dot1 = dot(..., k1_dot, dot0)
//
// processStage will rearrange the instructions in the following manner:
//
// k0 = load(k0_ptr)
// k1 = load(k1_ptr)
//
// k0_shared = cvt(k0, blocked, shared)
// k1_shared = cvt(k1, blocked, shared)
//
// k0_dot = cvt(k0_shared, shared, dot)
// k1_dot = cvt(k1_shared, shared, dot)
//
// dot0 = dot(..., k0_dot, 0)
// dot1 = dot(..., k1_dot, dot0)
void doPipelining(SmallVector<SmallVector<Operation *>> &dotChains,
                  int pipelineStages) {
for (auto chain : dotChains) {
for (int i = 0; i < (chain.size() - 1) / pipelineStages; i++) {
SmallVector<Operation *> operations;
SmallVector<Operation *> operationsIdx0;
int startStageIdx = i == 0 ? 0 : 1;
for (int j = startStageIdx; j <= pipelineStages; j++) {
processStage(/*currDot*/ chain[i * pipelineStages + j],
/*moveBeforeDot*/ chain[i * pipelineStages],
operationsIdx0, j == startStageIdx, /*operandIdx*/ 0);
processStage(/*currDot*/ chain[i * pipelineStages + j],
/*moveBeforeDot*/ chain[i * pipelineStages], operations,
j == startStageIdx, /*operandIdx*/ 1);
}
}

int startDotIdx = ((chain.size() - 1) / pipelineStages) * pipelineStages;
SmallVector<Operation *> operations;
SmallVector<Operation *> operationsIdx0;
for (int i = 1; i <= (chain.size() - 1) % pipelineStages; i++) {
processStage(chain[startDotIdx + i], chain[startDotIdx], operationsIdx0,
i == 1, 0);
processStage(chain[startDotIdx + i], chain[startDotIdx], operations,
i == 1, 1);
}
}
}

void findDotChains(ModuleOp &m,
SmallVector<SmallVector<Operation *>> &dotChains) {
m.walk([&](tt::DotOp dotOp) {
if (!containsInAnyChain(dotChains, dotOp)) {
SmallVector<Operation *> newChain;
@@ -248,33 +304,10 @@ class TritonAMDGPUReorderInstructionsPass
}
}
});
}

for (auto chain : dotChains) {
for (int i = 0; i < chain.size() / stages; i++) {
SmallVector<Operation *> operations;
SmallVector<Operation *> operationsIdx0;
for (int j = 0; j < stages; j++) {
processStage(chain[i * stages + j], chain[i], operationsIdx0, j == 0,
0);
processStage(chain[i * stages + j], chain[i], operations, j == 0, 1);
}
}

int startDotIdx = (chain.size() / stages) * stages;
SmallVector<Operation *> operations;
SmallVector<Operation *> operationsIdx0;
for (int i = 0; i < chain.size() % stages; i++) {
processStage(chain[startDotIdx + i], chain[chain.size() / stages],
operationsIdx0, i == 0, 0);
processStage(chain[startDotIdx + i], chain[chain.size() / stages],
operations, i == 0, 1);
}
}

if (!sinkLDSRd) {
return;
}

void sinkLDSConverts(SmallVector<SmallVector<Operation *>> &dotChains,
bool sinkLDSWr) {
for (auto chain : dotChains) {
for (int i = 0; i < chain.size(); i++) {
Operation *dotOp = chain[i];
@@ -290,36 +323,57 @@ class TritonAMDGPUReorderInstructionsPass
}
}

void
interleaveLoadsAndLDS(SmallVector<SmallVector<Operation *>> &dotChains) {
for (auto chain : dotChains) {
for (int i = 1; i < chain.size(); i++) {
Operation *dotOp = chain[i-1];
Operation *ldsRd = dotOp->getOperand(1).getDefiningOp();
assert(isLDSRead(ldsRd));

Operation *dotOpCurr = chain[i];
Operation *curr = dotOpCurr->getOperand(1).getDefiningOp();
while (!isa<tt::LoadOp>(curr)) {
curr = curr->getOperand(0).getDefiningOp();
}
moveBefore(curr, ldsRd);
}
}
}

void scheduleSlicedDot(ModuleOp m, int stages, bool sinkLDSRd, bool sinkLDSWr,
bool interleaveLoadWithLDSOps) {
SmallVector<SmallVector<Operation *>> dotChains;
int pipelineLoads = stages - 1;

findDotChains(m, dotChains);

if (stages > 1)
doPipelining(dotChains, pipelineLoads);

if (sinkLDSRd)
sinkLDSConverts(dotChains, sinkLDSWr);

// Arrange ops in a CK-like (Composable Kernel) fashion.
if (interleaveLoadWithLDSOps && stages == 2) {
interleaveLoadsAndLDS(dotChains);
}
}

void runOnOperation() override {
SmallVector<Operation *> movedOperations;
ModuleOp m = getOperation();

moveQTensorOutOfTheLoop(m);
int stages = 4;
int stages = 2;
bool sinkLDSRd = true;
bool sinkLDSWr = true;
scheduleSlicedDot(m, stages, sinkLDSRd, sinkLDSWr);
bool interleaveLoadWithLDSOps = true;
scheduleSlicedDot(m, stages, sinkLDSRd, sinkLDSWr,
interleaveLoadWithLDSOps);
}
};

std::unique_ptr<Pass> mlir::createTritonAMDGPUReorderInstructionsPass() {
return std::make_unique<TritonAMDGPUReorderInstructionsPass>();
}

// m.walk([&](tt::DotOp dotOp) {
// auto *operandA = dotOp.getOperand(0).getDefiningOp();
// auto convert = dyn_cast<ttg::ConvertLayoutOp>(operandA);
// auto srcTy = convert.getSrc().getType().cast<RankedTensorType>();
// Attribute srcLayout = srcTy.getEncoding();

// if (isa<ttg::MfmaEncodingAttr>(srcLayout)) {
// Operation *currOp = operandA;
// Operation *moveBeforeOp = dotOp;
// while (!isa<ttg::ViewSliceOp>(currOp)) {
// moveBefore(currOp, moveBeforeOp);
// moveBeforeOp = currOp;
// currOp = currOp->getOperand(0).getDefiningOp();
// }
// moveBefore(currOp, moveBeforeOp);
// }
// });
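
The sketch below is a minimal, standalone Python model of the index arithmetic in doPipelining above; it is not part of the commit. Each integer stands in for one tt.DotOp in a chain, and the returned pairs record which dot's operand chain processStage would hoist before which anchor dot. The function name simulate_pipelining and the printed example are illustrative assumptions only.

# Pure-Python sketch of doPipelining's grouping (illustrative, not from the pass).
def simulate_pipelining(chain_len: int, pipeline_stages: int):
    schedule = []  # (curr_dot_idx, move_before_dot_idx) pairs, in emission order
    full_groups = (chain_len - 1) // pipeline_stages
    for i in range(full_groups):
        start_stage = 0 if i == 0 else 1
        anchor = i * pipeline_stages
        for j in range(start_stage, pipeline_stages + 1):
            schedule.append((anchor + j, anchor))
    # Remainder: trailing dots that do not fill a whole group.
    start_dot = full_groups * pipeline_stages
    for i in range(1, (chain_len - 1) % pipeline_stages + 1):
        schedule.append((start_dot + i, start_dot))
    return schedule

if __name__ == "__main__":
    # A chain of 5 sliced dots with stages = 2 (so pipeline_stages = 1):
    # each dot's load/LDS ops end up hoisted before the previous dot.
    print(simulate_pipelining(5, 1))  # [(0, 0), (1, 0), (2, 1), (3, 2), (4, 3)]
    print(simulate_pipelining(5, 2))  # [(0, 0), (1, 0), (2, 0), (3, 2), (4, 2)]
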
44 changes: 23 additions & 21 deletions python/perf-kernels/06-fused-attention-fwd-transV.py
@@ -26,8 +26,8 @@
# AMD E4M3B8
# Note: When picking this f8 data type, scaling is required when using f8
# for the second gemm
TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz')
float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz
# TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz')
# float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz

@triton.jit
def max_fn(x, y):
@@ -164,7 +164,7 @@ def forward(ctx, q, k, v, sm_scale):
## For fp16, pick BLOCK_M=256, num_warps=8
## For fp8, pick BLOCK_M=128, num_warps=4
## TODO (zhanglx): add tuning infra for FA
BLOCK_M = 128 if TORCH_HAS_FP8E4 and q.dtype == torch.float8_e4m3fnuz else 256
BLOCK_M = 128 #if TORCH_HAS_FP8E4 and q.dtype == torch.float8_e4m3fnuz else 256
BLOCK_N = 128
waves_per_eu = 2
num_warps = BLOCK_M // 32
@@ -189,6 +189,8 @@ def forward(ctx, q, k, v, sm_scale):
num_warps = num_warps,
num_stages = num_stages,
pre_load_v = pre_load_v,
slice_k_tile = 32,
kpack=2,
)

return o
@@ -198,19 +200,19 @@ def forward(ctx, q, k, v, sm_scale):

name_to_torch_types = {
'fp16': torch.float16,
'bf16': torch.bfloat16,
'fp8': float8
# 'bf16': torch.bfloat16,
# 'fp8': float8
}

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype',
[ (*shape, dtype)
for shape in [(4, 48, 1024, 128),
(4, 48, 2048, 128),
(4, 48, 4096, 128)]
for dtype in ['fp16', 'bf16', 'fp8']])
for dtype in ['fp16']])
def test_op_fwd(Z, H, N_CTX, D_HEAD, dtype):
torch.manual_seed(20)
init_dtype = torch.float16 if dtype == 'fp8' else name_to_torch_types[dtype]
init_dtype = torch.float16 # if dtype == 'fp8' else name_to_torch_types[dtype]
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda")
.normal_(mean=0., std=0.5)
@@ -239,8 +241,8 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, dtype):
dout = torch.randn_like(q, dtype=torch.float16)
tri_out = attention(q, k, v, sm_scale)
# compare
atol = 1.4e-1 if dtype == 'fp8' else 1e-2
rtol = 1e-2 if dtype == 'fp8' else 3e-3
atol = 1e-2
rtol = 3e-3
torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol)


@@ -258,21 +260,21 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, dtype):

# vary seq length for fixed head and batch=4
configs = []
for dtype in ['fp16', 'bf16', 'fp8']:
for dtype in ['fp16']:
for D_HEAD in [128]:
for causal in [False]:
configs.append(triton.testing.Benchmark(
x_names=['BATCH', 'H','N_CTX'],
x_vals=[(16, 16, 1024),
(8, 16, 2048),
(4, 16, 4096),
(2, 16, 8192),
(1, 16, 16384),
(4, 48, 1024),
(4, 48, 2048),
x_vals=[#(16, 16, 1024),
# (8, 16, 2048),
# (4, 16, 4096),
# (2, 16, 8192),
# (1, 16, 16384),
# (4, 48, 1024),
# (4, 48, 2048),
(4, 48, 4096),
(4, 48, 8192),
(4, 48, 16384),
# (4, 48, 8192),
# (4, 48, 16384),
],
line_arg='provider',
line_vals=['triton'],
@@ -289,8 +291,8 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, dtype):

@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, provider, dtype, device="cuda"):
if dtype == 'fp8' and not TORCH_HAS_FP8E4:
sys.exit("fp8 is not available")
# if dtype == 'fp8' and not TORCH_HAS_FP8E4:
# sys.exit("fp8 is not available")
warmup = 25
rep = 100
init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16
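
A note on the fixed fp16 tolerances above (atol=1e-2, rtol=3e-3, now that only fp16 is tested): torch.testing.assert_close accepts each element when |actual - expected| <= atol + rtol * |expected|. A tiny self-contained illustration with made-up numbers, not taken from the test:

import torch

expected = torch.tensor([1.0, 100.0])
actual = torch.tensor([1.009, 100.25])

# element 0: |0.009| <= 0.01 + 0.003 * 1.0   = 0.013 -> passes
# element 1: |0.25|  <= 0.01 + 0.003 * 100.0 = 0.31  -> passes
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=3e-3)
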
54 changes: 27 additions & 27 deletions python/perf-kernels/06-fused-attention-transV.py
@@ -84,23 +84,23 @@ def _attn_fwd_inner(

@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=8),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
],
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)
@@ -762,21 +762,21 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
for dtype in ["fp16", "bf16"]:
for D_HEAD in [128, 64]:
for causal in [False, True]:
for dtype in ["fp16"]:
for D_HEAD in [128]:
for causal in [False]:
configs.append(triton.testing.Benchmark(
x_names=['BATCH', 'H','N_CTX'],
x_vals=[(16, 16, 1024),
(8, 16, 2048),
(4, 16, 4096),
(2, 16, 8192),
(1, 16, 16384),
(4, 48, 1024),
(4, 48, 2048),
x_vals=[#(16, 16, 1024),
# (8, 16, 2048),
# (4, 16, 4096),
# (2, 16, 8192),
# (1, 16, 16384),
# (4, 48, 1024),
# (4, 48, 2048),
(4, 48, 4096),
(4, 48, 8192),
(4, 48, 16384),
# (4, 48, 8192),
# (4, 48, 16384),
],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
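
For readers unfamiliar with @triton.autotune as used above: on the first launch for a new combination of the key arguments, Triton benchmarks every listed triton.Config and caches the fastest one. The sketch below is a generic, minimal example on a vector-add kernel; the kernel and every name in it are illustrative assumptions, not taken from this file, and it omits the meta-parameters specific to these attention kernels (slice_k_tile, waves_per_eu, pre_load_v) as well as the ROCm-fork kpack option.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, num_stages=1),
    ],
    key=['n_elements'],  # re-benchmark the configs whenever this value changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    # BLOCK_SIZE is supplied by the autotuner, so it is not passed here.
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, out, n)
    return out

Keying the tuner on problem-size arguments means re-tuning happens only when shapes change, which is why the attention kernel above keys on Z, H, N_CTX, STAGE, and BLOCK_DMODEL.
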
