-
Notifications
You must be signed in to change notification settings - Fork 658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DispatchCreation] Modify the generated fused op to not use concats. #19980
base: main
Are you sure you want to change the base?
[DispatchCreation] Modify the generated fused op to not use concats. #19980
Conversation
It might be better to just review the new changes by themselves and ignore the diff. The pass is essentially rewritten. |
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
9ebc531
to
bfca1d5
Compare
There's a problem with how ops are grouped. util.func public @test_partial_horizontal_fuse(%arg0: tensor<640x640xf32>, %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, %arg3: tensor<640x640xf32>) -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<640x640xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
%3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
%4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
} This isn't directly related to the changes you made, I think this problem is on main too. |
Good catch. Let me see if I can fix that. |
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG4]] : | ||
// CHECK: outs(%[[FILL]], %[[FILL]] : | ||
// CHECK: %[[TRUNCF:.+]] = linalg.generic | ||
// CHECK-SAME: ins(%[[FUSED_OP]]#1, %[[FUSED_OP]]#0 : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want to horizontally fuse the truncates like this. Instead we should be able to keep them separate and just put them in the same dispatch. That will keep tile + fuse simpler at the workgroup level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have an issue for fixing it at the workgroup level. This functionality is needed for llama ffn layer horizontal fusion. llvm/llvm-project#125915 is meant to fix it, but I havent reviewed it yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, just to clarify. The trunc operator is not fused. It is an elementwise op that using results of both the matmuls.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oooh that code looks spooky. I'm not sure how that's supposed to work without changing how we decide whether the fusion is legal.
The trunc operator is not fused. It is an elementwise op that using results of both the matmuls
Exactly, so what is the benefit of putting them all in a single op? I would expect multiple different elementwise ops to work the same way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cause they will be created and consumed in the same dispatch. I agree I need to actually work with it to see how this works.
The change also allows doing horizontal fusion in cases where the LHS operand is the same, but the RHS/Outputs might be transposed. Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
… this yet. Previous implementation of horizontal fusion missed opportunities for horizontal fusion in SD3, but now they do get picked up, but the backend doesnt work on these. Dropping the flag is a no-op for the test since there was no horizontal fusion to start with. Signed-off-by: MaheshRavishankar <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are (already) missing the documentation of the pass. The PR description looks good to me. Can you add such documentation to Passes.td
? I.e., add a description
section to the pass definition.
iree/compiler/src/iree/compiler/DispatchCreation/Passes.td
Lines 68 to 84 in 6ebfcaa
def FuseHorizontalContractionsPass: | |
InterfacePass<"iree-dispatch-creation-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> { | |
let summary = "Fuses horizontal contraction ops without fusions"; | |
let dependentDialects = [ | |
"mlir::arith::ArithDialect", | |
"mlir::tensor::TensorDialect", | |
]; | |
let options = [ | |
Option<"fusionLimit", "fusion-limit", "int", | |
/*default=*/"3", "Maximum number of contractions fused into one"> | |
]; | |
let statistics = [ | |
Statistic<"numFusionGroups", "num-fusion-groups", "Number of fusion groups found">, | |
Statistic<"numSize2FusionGroups", "num-size-2-groups", "Number of fusion groups of size 2">, | |
Statistic<"numSize3FusionGroups", "num-size-3-groups", "Number of fusion groups of size 3"> | |
]; | |
} |
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Show resolved
Hide resolved
compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: MaheshRavishankar <[email protected]>
bfca1d5
to
dcfa537
Compare
@IanWood1 pushed a fix for this issue. |
I merged this into #19847 (I think these changes enable more horizontal fusion in punet) and got a few failing dispatches https://gist.github.com/IanWood1/2ddd601970b9d0197cf01aa91346e7e8. They are smaller sized so I think there going down a different pipeline |
Really. I have been trying this on punet locally. I didnt see any issue there. |
Signed-off-by: MaheshRavishankar <[email protected]>
This is an almost complete rewrite of the pass to fuse contractions horizontally which instead of concatenating operands to map to a GEMM, followed by slices to extract the individual matmul results; the pass now just creates a new operation with the operands being the common LHS, the RHS of each of the gemms, and the output of each of the gemms. The generated op yields the result of each constituent matmul.
This also allows for the RHS/output indexing maps of the gemms to be mismatched, since only the LHS operand and indexing maps need to match. The change also permutes the iteration space of the gemms to ensure that the same indexing maps are used for the LHS across all the fused matmuls.
The rest of the compiler stack has already been fixed up to handle such operations.