
Flux Dev transformer RoPE IREE custom kernel bad performance #19822

Open
sogartar opened this issue Jan 27, 2025 · 14 comments · May be fixed by #19829
Labels
bug 🐞 Something isn't working

Comments

@sogartar (Contributor) commented Jan 27, 2025

What happened?

In this PR I switched the Flux transformer to use the custom kernel for rotary embeddings and got a significant drop in performance: the runtime increased from 552 ms to 6131 ms.

I have not yet investigated the cause.

Steps to reproduce your issue

Refer to #19751 for more detailed instructions on how to run the benchmark.
Download tracy-profile.zip to get the modified MLIR that uses the custom RoPE kernel. The archive also contains a Tracy capture of a benchmark run with 2 iterations.

What component(s) does this issue relate to?

Compiler

Version information

ebb9615

Additional context

No response

sogartar added the bug 🐞 Something isn't working label on Jan 27, 2025
@sogartar (Contributor, Author) commented Jan 27, 2025

The major deterioration comes from the attention dispatch: this kernel goes from 2.29 ms up to 97 ms.

The baseline dispatch is attention plus a transpose.

// -----// IR Dump Before StripDebugOpsPass (iree-util-strip-debug-ops) //----- //
flow.executable private @forward_bs1$async_dispatch_41 {
  flow.executable.export public @forward_bs1$async_dispatch_41_attention_24x4608x128xbf16_generic workgroups() -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
    flow.return %x, %y, %z : index, index, index
  }
  builtin.module {
    func.func @forward_bs1$async_dispatch_41_attention_24x4608x128xbf16_generic(%arg0: !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>>, %arg2: !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>>, %arg3: !flow.dispatch.tensor<writeonly:tensor<4608x24x128xbf16>>) {
      %cst = arith.constant 8.837890e-02 : bf16
      %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>> -> tensor<24x4608x128xbf16>
      %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xbf16>> -> tensor<24x4608x128xbf16>
      %2 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0], sizes = [24, 128, 4608], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>> -> tensor<24x128x4608xbf16>
      %3 = tensor.empty() : tensor<24x4608x128xbf16>
      %4 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} ins(%0, %1, %2, %cst : tensor<24x4608x128xbf16>, tensor<24x4608x128xbf16>, tensor<24x128x4608xbf16>, bf16) outs(%3 : tensor<24x4608x128xbf16>) {
      ^bb0(%arg4: f32):
        iree_linalg_ext.yield %arg4 : f32
      } -> tensor<24x4608x128xbf16>
      %5 = tensor.empty() : tensor<4608x24x128xbf16>
      %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<24x4608x128xbf16>) outs(%5 : tensor<4608x24x128xbf16>) {
      ^bb0(%in: bf16, %out: bf16):
        linalg.yield %in : bf16
      } -> tensor<4608x24x128xbf16>
      flow.dispatch.tensor.store %6, %arg3, offsets = [0, 0, 0], sizes = [4608, 24, 128], strides = [1, 1, 1] : tensor<4608x24x128xbf16> -> !flow.dispatch.tensor<writeonly:tensor<4608x24x128xbf16>>
      return
    }
  }
}

In the model variant with the RoPE kernel, the application of RoPE to the q and k tensors gets fused into the attention dispatch.
My hypothesis is that this fusion throws off the subsequent passes.
In the baseline these are not fused, because the incoming positional embeddings are precomputed into a different form shared by all MMDiT layers: the cos/sin application is hoisted out of the MMDiT layer loop, which makes the input embeddings 4x larger. My thinking is that we don't want this, because we would like to decrease memory bandwidth in exchange for some recomputation on each iteration. In the custom RoPE kernel variant the cos/sin computation is embedded into the kernel.
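For intuition, here is a minimal NumPy sketch of the interleaved rotation that the fused producers (%9 and %10 in the dump below) apply to q and k. The shapes and names are illustrative only, assuming the standard interleaved RoPE layout where each pair (2i, 2i + 1) shares one angle; the actual dispatch works on bf16 tensors with the heads dimension included.

import numpy as np

def rope_interleaved(x, angles):
    # x: (seq, head_dim) activations, angles: (seq, head_dim // 2) rotation angles.
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]       # the tensor.extract pairs (2i, 2i + 1)
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin    # even lanes (the arith.subf branch)
    out[:, 1::2] = x_odd * cos + x_even * sin    # odd lanes (the arith.addf branch)
    return out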

// -----// IR Dump Before StripDebugOpsPass (iree-util-strip-debug-ops) //----- //
flow.executable private @forward_bs1$async_dispatch_31 {
  flow.executable.export public @forward_bs1$async_dispatch_31_attention_24x1x4608x1x128xbf16_generic workgroups() -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
    flow.return %x, %y, %z : index, index, index
  }
  builtin.module {
    func.func @forward_bs1$async_dispatch_31_attention_24x1x4608x1x128xbf16_generic(%arg0: !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<1x128x1x4608xf32>>, %arg2: !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>>, %arg3: !flow.dispatch.tensor<readonly:tensor<1x128x4608xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x4608x24x128xbf16>>) {
      %c128 = arith.constant 128 : index
      %c2 = arith.constant 2 : index
      %c1 = arith.constant 1 : index
      %c0 = arith.constant 0 : index
      %cst = arith.constant 8.837890e-02 : bf16
      %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0, 0], sizes = [1, 4608, 24, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>> -> tensor<1x4608x24x128xbf16>
      %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0, 0], sizes = [1, 128, 1, 4608], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x128x1x4608xf32>> -> tensor<1x128x1x4608xf32>
      %2 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0, 0], sizes = [1, 4608, 24, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4608x24x128xbf16>> -> tensor<1x4608x24x128xbf16>
      %3 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [1, 128, 4608], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x128x4608xf32>> -> tensor<1x128x4608xf32>
      %4 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0], sizes = [24, 128, 4608], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x128x4608xbf16>> -> tensor<24x128x4608xbf16>
      %5 = tensor.empty() : tensor<1x4608x24x128xbf16>
      %6 = tensor.empty() : tensor<24x1x4608x128xbf16>
      %7 = tensor.empty() : tensor<24x4608x1x128xbf16>
      %8 = tensor.empty() : tensor<24x1x4608x1x128xbf16>
      %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x128x1x4608xf32>) outs(%8 : tensor<24x1x4608x1x128xbf16>) {
      ^bb0(%in: f32, %out: bf16):
        %13 = linalg.index 0 : index
        %14 = linalg.index 1 : index
        %15 = linalg.index 2 : index
        %16 = linalg.index 3 : index
        %17 = linalg.index 4 : index
        %18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4608)>()[%15, %14]
        %19 = arith.muli %16, %c128 overflow<nsw> : index
        %20 = arith.addi %17, %19 : index
        %21 = arith.divui %20, %c2 : index
        %22 = arith.remui %20, %c2 : index
        %23 = math.cos %in : f32
        %24 = math.sin %in : f32
        %25 = arith.muli %21, %c2 : index
        %26 = arith.addi %25, %c1 : index
        %extracted = tensor.extract %0[%c0, %18, %13, %25] : tensor<1x4608x24x128xbf16>
        %27 = arith.extf %extracted : bf16 to f32
        %extracted_0 = tensor.extract %0[%c0, %18, %13, %26] : tensor<1x4608x24x128xbf16>
        %28 = arith.extf %extracted_0 : bf16 to f32
        %29 = arith.cmpi eq, %22, %c0 : index
        %30 = arith.mulf %27, %23 : f32
        %31 = arith.mulf %28, %24 : f32
        %32 = arith.subf %30, %31 : f32
        %33 = arith.mulf %28, %23 : f32
        %34 = arith.mulf %27, %24 : f32
        %35 = arith.addf %33, %34 : f32
        %36 = arith.select %29, %32, %35 : f32
        %37 = arith.truncf %36 : f32 to bf16
        linalg.yield %37 : bf16
      } -> tensor<24x1x4608x1x128xbf16>
      %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3 : tensor<1x128x4608xf32>) outs(%7 : tensor<24x4608x1x128xbf16>) {
      ^bb0(%in: f32, %out: bf16):
        %13 = linalg.index 0 : index
        %14 = linalg.index 1 : index
        %15 = linalg.index 2 : index
        %16 = linalg.index 3 : index
        %17 = arith.muli %15, %c128 overflow<nsw> : index
        %18 = arith.addi %16, %17 : index
        %19 = arith.divui %18, %c2 : index
        %20 = arith.remui %18, %c2 : index
        %21 = math.cos %in : f32
        %22 = math.sin %in : f32
        %23 = arith.muli %19, %c2 : index
        %24 = arith.addi %23, %c1 : index
        %extracted = tensor.extract %2[%c0, %14, %13, %23] : tensor<1x4608x24x128xbf16>
        %25 = arith.extf %extracted : bf16 to f32
        %extracted_0 = tensor.extract %2[%c0, %14, %13, %24] : tensor<1x4608x24x128xbf16>
        %26 = arith.extf %extracted_0 : bf16 to f32
        %27 = arith.cmpi eq, %20, %c0 : index
        %28 = arith.mulf %25, %21 : f32
        %29 = arith.mulf %26, %22 : f32
        %30 = arith.subf %28, %29 : f32
        %31 = arith.mulf %26, %21 : f32
        %32 = arith.mulf %25, %22 : f32
        %33 = arith.addf %31, %32 : f32
        %34 = arith.select %27, %30, %33 : f32
        %35 = arith.truncf %34 : f32 to bf16
        linalg.yield %35 : bf16
      } -> tensor<24x4608x1x128xbf16>
      %11 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d6, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} ins(%9, %10, %4, %cst : tensor<24x1x4608x1x128xbf16>, tensor<24x4608x1x128xbf16>, tensor<24x128x4608xbf16>, bf16) outs(%6 : tensor<24x1x4608x128xbf16>) {
      ^bb0(%arg6: f32):
        iree_linalg_ext.yield %arg6 : f32
      } -> tensor<24x1x4608x128xbf16>
      %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11 : tensor<24x1x4608x128xbf16>) outs(%5 : tensor<1x4608x24x128xbf16>) {
      ^bb0(%in: bf16, %out: bf16):
        linalg.yield %in : bf16
      } -> tensor<1x4608x24x128xbf16>
      flow.dispatch.tensor.store %12, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 4608, 24, 128], strides = [1, 1, 1, 1] : tensor<1x4608x24x128xbf16> -> !flow.dispatch.tensor<writeonly:tensor<1x4608x24x128xbf16>>
      return
    }
  }
}

@IanWood1 (Contributor) commented:

Are you saying you don't want %9 and %10 fused into the same dispatch with the attention op? My understanding is that fusing the rope computation decreases memory bandwidth but increases re-computation.

Also, fusing those ops into the dispatch was enabled by #19745 since they would be considered "gather"-like. Maybe the heuristic will have to be more strict.
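For illustration, a rough sketch of the two strategies being compared; the function names and the layer loop are hypothetical, not the actual model or compiler code. The baseline hoists the cos/sin tables out of the MMDiT layer loop and streams them from memory in every layer, while the fused variant rematerializes them inside each attention dispatch.

import numpy as np

def rope_with_table(x, cos, sin):
    # Apply an already-materialized cos/sin table to x of shape (seq, head_dim).
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin
    out[:, 1::2] = x[:, 1::2] * cos + x[:, 0::2] * sin
    return out

def baseline(q_per_layer, angles):
    # cos/sin hoisted out of the layer loop: computed once, then read back from
    # memory by every layer (extra bandwidth, no recomputation).
    cos, sin = np.cos(angles), np.sin(angles)
    return [rope_with_table(q, cos, sin) for q in q_per_layer]

def fused(q_per_layer, angles):
    # cos/sin rematerialized inside every layer's attention dispatch
    # (less memory traffic, transcendentals recomputed per layer).
    return [rope_with_table(q, np.cos(angles), np.sin(angles)) for q in q_per_layer]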

cc @MaheshRavishankar

@MaheshRavishankar (Contributor) commented:

That was maybe too aggressive. I should try fusing only the Q computation. Give me a sec.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this issue Jan 28, 2025
@MaheshRavishankar (Contributor) commented:

@sogartar, can you try with #19829?

@sogartar (Contributor, Author) commented:

@MaheshRavishankar, with the #19829 patch the mean execution time is 6.5% higher (worse) than the baseline without the RoPE kernel. I will look in more detail at what is going on.

@MaheshRavishankar (Contributor) commented:

OK, cool. So it's an improvement over the previous result?

@sogartar (Contributor, Author) commented:

I had the wrong setup, so the above result is meaningless. I will rerun the benchmark correctly.

@sogartar (Contributor, Author) commented:

I reran the benchmark with only the RoPE application to q fused.
The result is 574 ms mean real time, which is 4.0% higher than the baseline.
I confirmed that the fusion is correct; the fused attention kernel time for the MMDiT double block increased from 2.29 ms to 3.46 ms, and the attention for the MMDiT single block increased from a baseline of 3.05 ms to 3.53 ms.

I will explore the trace for other differences, as the preparation of the positional embeddings is different between the two variants.
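For reference, the relative increases implied by the numbers quoted above (simple arithmetic on the reported times):

double_block = (3.46 - 2.29) / 2.29   # ~0.51, i.e. ~51% slower attention in the double block
single_block = (3.53 - 3.05) / 3.05   # ~0.16, i.e. ~16% slower attention in the single block
end_to_end   = (574 - 552) / 552      # ~0.04, the reported ~4% end-to-end regression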

@sogartar (Contributor, Author) commented:

I played a bit with changing the kernel so that we don't have to transpose the sequence and head dimensions, but I got worse performance, around 573 ms (+3.7%). It may be related to not reading q and k in chunks due to the loop interchange.
One other model-side change that may yield some improvement is to change the other layers so that the sequence and head dimensions are transposed. They should be treated the same except in attention, as it is the only operation that should distinguish them (see the sketch below). I am not able to confirm this, because the modification would require a lot of changes and we may hit a problem somewhere.
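A small sketch of that argument (illustrative NumPy with toy shapes, not the model code): layers acting on head_dim are agnostic to whether the layout is (seq, heads, head_dim) or (heads, seq, head_dim), while attention is the one op that has to know which axis is which, because the score matrix is formed per head over seq x seq.

import numpy as np

seq, heads, dim = 64, 4, 16                               # toy shapes, not the model's
q = np.random.randn(seq, heads, dim)
k = np.random.randn(seq, heads, dim)
qt, kt = q.transpose(1, 0, 2), k.transpose(1, 0, 2)       # (heads, seq, dim) layout

# Projections, norms, RoPE and elementwise ops act on head_dim and don't care
# which of the two leading axes is sequence and which is heads:
w = np.random.randn(dim, dim)
assert np.allclose((q @ w).transpose(1, 0, 2), qt @ w)

# Attention must distinguish them: scores are computed per head over seq x seq.
scores_a = np.einsum('qhd,khd->hqk', q, k)                # from the (seq, heads, dim) layout
scores_b = np.einsum('hqd,hkd->hqk', qt, kt)              # from the (heads, seq, dim) layout
assert np.allclose(scores_a, scores_b)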

@MaheshRavishankar (Contributor) commented:

cc @manupak, maybe this form of attention + RoPE needs some work?

@manupak (Contributor) commented Jan 31, 2025

@MaheshRavishankar,

From an initial skim, this gather-like behaviour turns all loads from global_load_dwordx4 into global_load_dword + global_load_ushort.

Total number of loads in the RoPE kernel: 338 vs base: 20.
The number of mfma instructions stays the same (so it's not unrolling related).

Generally, what is the expectation here?
I'm asking to get the intuition for why this kind of fusion is being pursued.

@MaheshRavishankar (Contributor) commented:

The q RoPE computation fusion makes sense overall, right? It should be profitable to fuse this. The k RoPE computation fusion does not make sense. I have a patch, mentioned in #19822 (comment), that drops the k RoPE computation fusion. But with only the q RoPE fused it is still about 7% slower. We need to see if that difference can be bridged.
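One way to see why the k-side fusion is less attractive (a hedged sketch of a generic flash-attention-style tiling, not IREE's actual codegen): the outer loop walks Q tiles and the inner loop revisits every K tile for each Q tile, so a RoPE producer fused on the k operand is re-executed once per Q tile, while a producer fused on q runs only once per tile. A toy count with a hypothetical tile size of 64:

def tiled_rope_invocations(seq=4608, tile=64):
    # Count how often each fused RoPE producer would run in a tiled attention loop nest.
    q_rope_runs, k_rope_runs = 0, 0
    for _q_tile in range(0, seq, tile):
        q_rope_runs += 1                    # q RoPE fused here: once per Q tile
        for _k_tile in range(0, seq, tile):
            k_rope_runs += 1                # k RoPE fused here: reruns for every Q tile
    return q_rope_runs, k_rope_runs

print(tiled_rope_invocations())             # (72, 5184): the k-side work is recomputed seq/tile times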

@manupak (Contributor) commented Jan 31, 2025

Do you have IR for that?
Something that I can compile to see what is going on, because for now I see both fusions.

@MaheshRavishankar (Contributor) commented:

Try with this patch: #19829
