Flux Dev transformer RoPE IREE custom kernel bad performance #19822
The major deterioration is from the attention dispatch. This kernel goes from 2.29 ms up to 97 ms. The baseline dispatch is attention plus a transpose:

```mlir
// -----// IR Dump Before StripDebugOpsPass (iree-util-strip-debug-ops) //----- //
flow.executable private @forward_bs1$async_dispatch_41 {
flow.executable.export public @forward_bs1$async_dispatch_41_attention_24x4608x128xbf16_generic workgroups() -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @forward_bs1$async_dispatch_41_attention_24x4608x128xbf16_generic(%arg0: !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>>, %arg2: !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>>, %arg3: !flow.dispatch.tensor<writeonly:tensor<4608x24x128xbf16>>) {
%cst = arith.constant 8.837890e-02 : bf16
%0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>> -> tensor<24x4608x128xbf16>
%1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>> -> tensor<24x4608x128xbf16>
%2 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0], sizes = [24, 128, 4608], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>> -> tensor<24x128x4608xbf16>
%3 = tensor.empty() : tensor<24x4608x128xbf16>
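// Attention over Q = %0, K = %1 (24x4608x128) and V = %2 (24x128x4608) with scale %cst;
// all operands are loaded directly as whole dense tensors.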
%4 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} ins(%0, %1, %2, %cst : tensor<24x4608x128xbf16>, tensor<24x4608x128xbf16>, tensor<24x128x4608xbf16>, bf16) outs(%3 : tensor<24x4608x128xbf16>) {
^bb0(%arg4: f32):
iree_linalg_ext.yield %arg4 : f32
} -> tensor<24x4608x128xbf16>
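// Transpose the attention result from (head, seq, dim) = 24x4608x128
// to (seq, head, dim) = 4608x24x128.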
%5 = tensor.empty() : tensor<4608x24x128xbf16>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<24x4608x128xbf16>) outs(%5 : tensor<4608x24x128xbf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
} -> tensor<4608x24x128xbf16>
flow.dispatch.tensor.store %6, %arg3, offsets = [0, 0, 0], sizes = [4608, 24, 128], strides = [1, 1, 1] : tensor<4608x24x128xbf16> -> !flow.dispatch.tensor<writeonly:tensor<4608x24x128xbf16>>
return
}
}
}
```

In the model variant with the RoPE kernel, the RoPE application onto the q and k tensors gets fused into the attention dispatch:

```mlir
// -----// IR Dump Before StripDebugOpsPass (iree-util-strip-debug-ops) //----- //
flow.executable private @forward_bs1$async_dispatch_31 {
flow.executable.export public @forward_bs1$async_dispatch_31_attention_24x1x4608x1x128xbf16_generic workgroups() -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @forward_bs1$async_dispatch_31_attention_24x1x4608x1x128xbf16_generic(%arg0: !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<1x128x1x4608xf32>>, %arg2: !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>>, %arg3: !flow.dispatch.tensor<readonly:tensor<1x128x4608xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x4608x24x128xbf16>>) {
%c128 = arith.constant 128 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 8.837890e-02 : bf16
%0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0, 0], sizes = [1, 4608, 24, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>> -> tensor<1x4608x24x128xbf16>
%1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0, 0], sizes = [1, 128, 1, 4608], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x128x1x4608xf32>> -> tensor<1x128x1x4608xf32>
%2 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0, 0], sizes = [1, 4608, 24, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>> -> tensor<1x4608x24x128xbf16>
%3 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [1, 128, 4608], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x128x4608xf32>> -> tensor<1x128x4608xf32>
%4 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0], sizes = [24, 128, 4608], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>> -> tensor<24x128x4608xbf16>
%5 = tensor.empty() : tensor<1x4608x24x128xbf16>
%6 = tensor.empty() : tensor<24x1x4608x128xbf16>
%7 = tensor.empty() : tensor<24x4608x1x128xbf16>
%8 = tensor.empty() : tensor<24x1x4608x1x128xbf16>
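// RoPE applied to Q, fused as a gather: each output lane recomputes its even/odd
// pair index and issues two scalar tensor.extract loads from %0 (the Q tensor),
// then rotates by cos/sin of the angle read from %1.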
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x128x1x4608xf32>) outs(%8 : tensor<24x1x4608x1x128xbf16>) {
^bb0(%in: f32, %out: bf16):
%13 = linalg.index 0 : index
%14 = linalg.index 1 : index
%15 = linalg.index 2 : index
%16 = linalg.index 3 : index
%17 = linalg.index 4 : index
%18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4608)>()[%15, %14]
%19 = arith.muli %16, %c128 overflow<nsw> : index
%20 = arith.addi %17, %19 : index
%21 = arith.divui %20, %c2 : index
%22 = arith.remui %20, %c2 : index
%23 = math.cos %in : f32
%24 = math.sin %in : f32
%25 = arith.muli %21, %c2 : index
%26 = arith.addi %25, %c1 : index
%extracted = tensor.extract %0[%c0, %18, %13, %25] : tensor<1x4608x24x128xbf16>
%27 = arith.extf %extracted : bf16 to f32
%extracted_0 = tensor.extract %0[%c0, %18, %13, %26] : tensor<1x4608x24x128xbf16>
%28 = arith.extf %extracted_0 : bf16 to f32
%29 = arith.cmpi eq, %22, %c0 : index
%30 = arith.mulf %27, %23 : f32
%31 = arith.mulf %28, %24 : f32
%32 = arith.subf %30, %31 : f32
%33 = arith.mulf %28, %23 : f32
%34 = arith.mulf %27, %24 : f32
%35 = arith.addf %33, %34 : f32
%36 = arith.select %29, %32, %35 : f32
%37 = arith.truncf %36 : f32 to bf16
linalg.yield %37 : bf16
} -> tensor<24x1x4608x1x128xbf16>
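// The same interleaved rotation applied to K: two tensor.extract gathers
// from %2 per output lane, with angles read from %3.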
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3 : tensor<1x128x4608xf32>) outs(%7 : tensor<24x4608x1x128xbf16>) {
^bb0(%in: f32, %out: bf16):
%13 = linalg.index 0 : index
%14 = linalg.index 1 : index
%15 = linalg.index 2 : index
%16 = linalg.index 3 : index
%17 = arith.muli %15, %c128 overflow<nsw> : index
%18 = arith.addi %16, %17 : index
%19 = arith.divui %18, %c2 : index
%20 = arith.remui %18, %c2 : index
%21 = math.cos %in : f32
%22 = math.sin %in : f32
%23 = arith.muli %19, %c2 : index
%24 = arith.addi %23, %c1 : index
%extracted = tensor.extract %2[%c0, %14, %13, %23] : tensor<1x4608x24x128xbf16>
%25 = arith.extf %extracted : bf16 to f32
%extracted_0 = tensor.extract %2[%c0, %14, %13, %24] : tensor<1x4608x24x128xbf16>
%26 = arith.extf %extracted_0 : bf16 to f32
%27 = arith.cmpi eq, %20, %c0 : index
%28 = arith.mulf %25, %21 : f32
%29 = arith.mulf %26, %22 : f32
%30 = arith.subf %28, %29 : f32
%31 = arith.mulf %26, %21 : f32
%32 = arith.mulf %25, %22 : f32
%33 = arith.addf %31, %32 : f32
%34 = arith.select %27, %30, %33 : f32
%35 = arith.truncf %34 : f32 to bf16
linalg.yield %35 : bf16
} -> tensor<24x4608x1x128xbf16>
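// The attention op now consumes the two gather-producing RoPE generics directly;
// note the extra unit dims (24x1x4608x1x128) introduced by the fusion.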
%11 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d6, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} ins(%9, %10, %4, %cst : tensor<24x1x4608x1x128xbf16>, tensor<24x4608x1x128xbf16>, tensor<24x128x4608xbf16>, bf16) outs(%6 : tensor<24x1x4608x128xbf16>) {
^bb0(%arg6: f32):
iree_linalg_ext.yield %arg6 : f32
} -> tensor<24x1x4608x128xbf16>
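// Transpose the fused result back to the 1x4608x24x128 output layout.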
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11 : tensor<24x1x4608x128xbf16>) outs(%5 : tensor<1x4608x24x128xbf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
} -> tensor<1x4608x24x128xbf16>
flow.dispatch.tensor.store %12, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 4608, 24, 128], strides = [1, 1, 1, 1] : tensor<1x4608x24x128xbf16> -> !flow.dispatch.tensor<writeonly:tensor<1x4608x24x128xbf16>>
return
}
}
}
```
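
For intuition, here is a minimal NumPy sketch (hypothetical function and variable names) of the interleaved rotation that the two fused `linalg.generic` ops implement: each output lane reads both elements of its even/odd pair and selects one of the two rotated combinations by lane parity, mirroring the `tensor.extract` pairs and the `arith.select` in the IR above.

```python
import numpy as np

def rope_interleaved(x, angles):
    """Apply interleaved RoPE. x: (heads, seq, head_dim), angles: (seq, head_dim)."""
    x = x.astype(np.float32)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]  # the (2i, 2i+1) lane pairs
    out = np.empty_like(x)
    # Even output lanes get x_even*cos - x_odd*sin; odd lanes get x_odd*cos + x_even*sin,
    # which is what the parity-based arith.select in the dumped IR computes.
    out[..., 0::2] = x_even * cos[..., 0::2] - x_odd * sin[..., 0::2]
    out[..., 1::2] = x_odd * cos[..., 1::2] + x_even * sin[..., 1::2]
    return out

# Shapes from the dispatch above; the angle table is a random stand-in here.
q = np.random.randn(24, 4608, 128).astype(np.float32)
angles = np.random.randn(4608, 128).astype(np.float32)
print(rope_interleaved(q, angles).shape)  # (24, 4608, 128)
```
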
Are you saying you don't want that fusion? Also, fusing those ops into the dispatch was enabled by #19745, since they would be considered "gather"-like. Maybe the heuristic will have to be more strict.

That was maybe too aggressive. I should try fusing only the Q computation. Give me a sec.

Fixes iree-org#19822. Signed-off-by: MaheshRavishankar <[email protected]>

@MaheshRavishankar, with the #19829 patch the mean execution time is 6.5% higher (worse) than the baseline without the RoPE kernel. I will look in more detail into what is going on.

Ok cool. So it's an improvement over the previous result?

I had the wrong setup, so the above result is meaningless. I will rerun the benchmark correctly.

I reran the benchmark with only the RoPE application to q fused. I will explore the trace for other differences, since the preparation of the positional embeddings is different.

I played a bit with changing the kernel so that we don't have to transpose the sequence and head dimensions, but got worse performance: around 573 ms (+3.7%). It may be related to not reading q and k in chunks due to the loop interchange.

cc @manupak: maybe this form of attention + RoPE needs some work?

From an initial skim on this: the gather-like behaviour increases the total number of loads in the RoPE kernel to 338, vs 20 in the baseline. Generally, what is the expectation here?
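
Part of that is visible directly in the dumped generics: both `tensor.extract` loads of a lane pair are issued unconditionally for every output lane, so each Q/K element is read at least twice as a scalar gather, before any re-execution of the producer inside the tiled attention loop. A small sketch of that access pattern (assuming no cross-lane CSE survives to the final ISA):

```python
from collections import Counter

head_dim = 8  # small stand-in for the real head_dim of 128
loads = Counter()
for lane in range(head_dim):
    even = (lane // 2) * 2
    # The generic extracts both elements of the pair for every output lane,
    # then an arith.select picks one rotated combination by lane parity.
    loads[even] += 1
    loads[even + 1] += 1
print(sum(loads.values()) / head_dim)  # 2.0 scalar loads per input element
```
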
Do you have IR for that?

Try with this patch: #19829

What happened?
In this PR I switched the Flux transformer to use the custom kernel for rotary embeddings and got a significant drop in performance.
The runtime increased from 552 ms to 6131 ms, roughly an 11x slowdown.
I have not yet investigated the cause.
Steps to reproduce your issue
Refer to #19751 for more detailed instructions on how to run the benchmark.
Download tracy-profile.zip to get the modified MLIR that uses the custom RoPE kernel. It also contains a Tracy capture of a benchmark with 2 iterations.
What component(s) does this issue relate to?
Compiler
Version information
ebb9615
Additional context
No response