[DispatchCreation] Modify the generated fused op to not use concats. #19980

Open · wants to merge 6 commits into main

Conversation

MaheshRavishankar (Contributor)

This is an almost complete rewrite of the pass that fuses contractions horizontally. Instead of concatenating operands to map to a single GEMM, followed by slices to extract the individual matmul results, the pass now creates a new operation whose operands are the common LHS, the RHS of each GEMM, and the output of each GEMM. The generated op yields the result of each constituent matmul.
This also allows the RHS/output indexing maps of the GEMMs to be mismatched, since only the LHS operand and indexing map need to match. The change also permutes the iteration space of the GEMMs to ensure that the same indexing map is used for the LHS across all the fused matmuls.

The rest of the compiler stack has already been fixed up to handle such operations.
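
To illustrate the shape of the result (a sketch only; the variable names are made up and the exact op the pass emits may differ), fusing two matmuls that share %lhs conceptually produces a single multi-result op along the lines of:

#lhs_map = affine_map<(d0, d1, d2) -> (d0, d2)>
#rhs_map = affine_map<(d0, d1, d2) -> (d2, d1)>
#out_map = affine_map<(d0, d1, d2) -> (d0, d1)>
// Two matmuls sharing %lhs fused into one op that yields both results.
%fused:2 = linalg.generic {
    indexing_maps = [#lhs_map, #rhs_map, #rhs_map, #out_map, #out_map],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs0, %rhs1 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>)
    outs(%fill, %fill : tensor<640x640xf32>, tensor<640x640xf32>) {
^bb0(%a: f32, %b0: f32, %b1: f32, %acc0: f32, %acc1: f32):
  %m0 = arith.mulf %a, %b0 : f32
  %m1 = arith.mulf %a, %b1 : f32
  %s0 = arith.addf %acc0, %m0 : f32
  %s1 = arith.addf %acc1, %m1 : f32
  linalg.yield %s0, %s1 : f32, f32
} -> (tensor<640x640xf32>, tensor<640x640xf32>)

There is no concatenation of %rhs0/%rhs1 and no slicing of the result; consumers use %fused#0 and %fused#1 directly.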

@MaheshRavishankar (Contributor, Author)

It might be better to just review the new changes by themselves and ignore the diff. The pass is essentially rewritten.

@IanWood1 (Contributor)

There's a problem with how ops are grouped: allOps never gets updated with the other ops determined to be fusible with the root op. Also, the candidates to fuse need to be iterated over in dominance order to ensure that, in the example below, %3 gets grouped before %4.

util.func public @test_partial_horizontal_fuse(%arg0: tensor<640x640xf32>, %arg1: tensor<640x640xf32>, %arg2: tensor<640x640xf32>, %arg3: tensor<640x640xf32>) -> (tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<640x640xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %3 = linalg.matmul ins(%arg0, %arg2 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  %4 = linalg.matmul ins(%arg0, %3 : tensor<640x640xf32>, tensor<640x640xf32>) outs(%1 : tensor<640x640xf32>) -> tensor<640x640xf32>
  util.return %2, %3, %4 : tensor<640x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>
}

This isn't directly related to the changes you made; I think this problem exists on main too.

@MaheshRavishankar (Contributor, Author)

> There's a problem with how ops are grouped. allOps never gets updated with the other ops determined to be fusible with the root op. [...]

Good catch. Let me see if I can fix that.

// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG4]] :
// CHECK: outs(%[[FILL]], %[[FILL]] :
// CHECK: %[[TRUNCF:.+]] = linalg.generic
// CHECK-SAME: ins(%[[FUSED_OP]]#1, %[[FUSED_OP]]#0 :
Contributor

I don't think we want to horizontally fuse the truncates like this. Instead we should be able to keep them separate and just put them in the same dispatch. That will keep tile + fuse simpler at the workgroup level.

Contributor (Author)

There's an issue open for fixing this at the workgroup level. This functionality is needed for the llama FFN layer horizontal fusion. llvm/llvm-project#125915 is meant to fix it, but I haven't reviewed it yet.

Contributor (Author)

Also, just to clarify: the trunc operator is not fused. It is an elementwise op that uses the results of both matmuls.
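
For example (an illustrative sketch only, not the exact IR from the test; the combining arith.addf is made up), the truncation stays a separate elementwise linalg.generic that reads both results of the fused op within the same dispatch:

#id = affine_map<(d0, d1) -> (d0, d1)>
// Hypothetical elementwise consumer of both fused-matmul results.
%trunc = linalg.generic {indexing_maps = [#id, #id, #id],
                         iterator_types = ["parallel", "parallel"]}
    ins(%fused#1, %fused#0 : tensor<640x640xf32>, tensor<640x640xf32>)
    outs(%init : tensor<640x640xf16>) {
^bb0(%in0: f32, %in1: f32, %out: f16):
  %sum = arith.addf %in0, %in1 : f32
  %t = arith.truncf %sum : f32 to f16
  linalg.yield %t : f16
} -> tensor<640x640xf16>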

Contributor

oooh that code looks spooky. I'm not sure how that's supposed to work without changing how we decide whether the fusion is legal.

> The trunc operator is not fused. It is an elementwise op that uses the results of both matmuls

Exactly, so what is the benefit of putting them all in a single op? I would expect multiple different elementwise ops to work the same way.

Contributor (Author)

Because they will be created and consumed in the same dispatch. I agree I need to actually work with it to see how this plays out.

MaheshRavishankar and others added 4 commits February 14, 2025 09:02
The change also allows doing horizontal fusion in cases where the LHS
operand is the same, but the RHS/Outputs might be transposed.

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
… this yet.

The previous implementation of horizontal fusion missed opportunities in SD3. These now get picked up, but the backend doesn't work on them yet. Dropping the flag is a no-op for the test since there was no horizontal fusion to start with.

Signed-off-by: MaheshRavishankar <[email protected]>
@hanhanW (Contributor) left a comment

We are already missing documentation for this pass (even before this PR). The PR description looks good to me; can you add it as documentation to Passes.td, i.e., add a description section to the pass definition?

def FuseHorizontalContractionsPass:
    InterfacePass<"iree-dispatch-creation-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
  let summary = "Fuses horizontal contraction ops without fusions";
  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::tensor::TensorDialect",
  ];
  let options = [
    Option<"fusionLimit", "fusion-limit", "int",
           /*default=*/"3", "Maximum number of contractions fused into one">
  ];
  let statistics = [
    Statistic<"numFusionGroups", "num-fusion-groups", "Number of fusion groups found">,
    Statistic<"numSize2FusionGroups", "num-size-2-groups", "Number of fusion groups of size 2">,
    Statistic<"numSize3FusionGroups", "num-size-3-groups", "Number of fusion groups of size 3">
  ];
}
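
For reference, a possible description section (a sketch only; the wording is adapted from the PR description, and the final text is up to the author) could look like:

// Sketch of the suggested addition inside FuseHorizontalContractionsPass.
let description = [{
  Fuses contraction operations horizontally when they share the same LHS
  operand (permuting the iteration space as needed so the LHS indexing maps
  match). Instead of concatenating operands into a single GEMM and slicing
  out the individual results, a single multi-result operation is created
  whose operands are the common LHS, the RHS of each contraction, and the
  output of each contraction; the op yields the result of each constituent
  matmul.
}];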

@MaheshRavishankar force-pushed the shared/noconcatHorizontalFusionChanges branch from bfca1d5 to dcfa537 on February 14, 2025 19:56
@MaheshRavishankar (Contributor, Author)

> There's a problem with how ops are grouped. allOps never gets updated with the other ops determined to be fusible with the root op. [...]

@IanWood1 pushed a fix for this issue.

@IanWood1 (Contributor) commented Feb 14, 2025

I merged this into #19847 (I think these changes enable more horizontal fusion in punet) and got a few failing dispatches https://gist.github.com/IanWood1/2ddd601970b9d0197cf01aa91346e7e8.

They are smaller sized, so I think they're going down a different pipeline.

@MaheshRavishankar (Contributor, Author)

> I merged this into #19847 [...] and got a few failing dispatches [...]

Really? I have been trying this on punet locally and didn't see any issue there.

Signed-off-by: MaheshRavishankar <[email protected]>