[Flow] Enable softmax-like fusion under aggressive fusion. (iree-org#17747)

Under aggressive fusion, drop the restriction that the consumer iteration
space has the same dimensionality as the producer iteration space.
Dropping this restriction can lead to large vectors if not handled
properly, so the relaxation is guarded by the
`--iree-flow-enable-aggressive-fusion` flag.
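
For illustration, here is a minimal sketch of the pattern that becomes fusable: a reduction producer with a three-loop iteration space feeding a consumer with a four-loop iteration space. The function name and shapes are hypothetical, distilled from the test case added in this commit.

```mlir
// Hypothetical example: a reduction producer (3 loops) feeding a consumer
// with a larger iteration space (4 loops). Previously the consumer was not
// fused because 3 < 4; under aggressive fusion the size check is skipped.
util.func public @reduction_into_larger_consumer(%arg0: tensor<2x8x16xf32>,
    %arg1: tensor<2x8x16x1xf32>) -> tensor<2x8x16x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<2x8xf32>
  %fill = linalg.fill ins(%cst : f32)
      outs(%empty : tensor<2x8xf32>) -> tensor<2x8xf32>
  // Producer: three loops (parallel, parallel, reduction).
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<2x8x16xf32>) outs(%fill : tensor<2x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %out : f32
    linalg.yield %0 : f32
  } -> tensor<2x8xf32>
  %init = tensor.empty() : tensor<2x8x16x1xf32>
  // Consumer: four parallel loops; its iteration space is larger than the
  // producer's, which is only allowed to fuse under aggressive fusion.
  %res = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg1, %sum : tensor<2x8x16x1xf32>, tensor<2x8xf32>)
      outs(%init : tensor<2x8x16x1xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.subf %in, %in_0 : f32
    linalg.yield %0 : f32
  } -> tensor<2x8x16x1xf32>
  util.return %res : tensor<2x8x16x1xf32>
}
```

Running this through `iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))"` (the pass option exercised by the updated RUN line below) should place both generics in a single `flow.dispatch.region`; without the option, the iteration-space size check keeps the consumer out of the producer's dispatch.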

Fixes nod-ai/SHARK-ModelDev#749

Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar authored Jun 28, 2024
1 parent e38cc7f commit 7090f64
Showing 2 changed files with 113 additions and 5 deletions.
@@ -636,10 +636,12 @@ isFusableWithConsumer(OpOperand &fusedOperand,
   // Check if the iteration spaces of the producer and consumer are same.
   // TODO(#12664): This is unnecessary requirement, but we need a better config
   // to tile the consumer with a larger iteration space.
-  auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
-  auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
-  if (producerIterationSpace.size() < consumerIterationSpace.size()) {
-    return false;
+  if (!options.aggressiveFusion) {
+    auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
+    auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
+    if (producerIterationSpace.size() < consumerIterationSpace.size()) {
+      return false;
+    }
   }
 
   // Under aggressive fusion assume that the dispatches are vectorized. In which
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s

util.func public @pack_elementwise_fusion(%arg0 : tensor<?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?x8x32xf32> {
@@ -640,3 +640,109 @@ util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>,
// CHECK-SAME: ins(%[[GENERIC]],
// CHECK: flow.return %[[MATMUL]]
// CHECK: return %[[RETURN]]

// -----

util.func @softmax_like_fusion(%arg0: tensor<2x4096x640xf16>,
%arg1: tensor<640xf16>, %arg2: tensor<640xf16>) -> tensor<2x4096x640x1xf16> {
%expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
output_shape [2, 4096, 640, 1] : tensor<2x4096x640xf16> into tensor<2x4096x640x1xf16>
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.100000e+01 : f32
%cst_1 = arith.constant 4.000000e+00 : f32
%0 = tensor.empty() : tensor<2x4096x640xf32>
%1 = tensor.empty() : tensor<2x4096x640x1xf16>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<2x4096x640xf16>) outs(%0 : tensor<2x4096x640xf32>) {
^bb0(%in: f16, %out: f32):
%9 = arith.extf %in : f16 to f32
linalg.yield %9 : f32
} -> tensor<2x4096x640xf32>
%3 = tensor.empty() : tensor<2x4096xf32>
%4 = linalg.fill ins(%cst : f32)
outs(%3 : tensor<2x4096xf32>) -> tensor<2x4096xf32>
%5 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%2 : tensor<2x4096x640xf32>) outs(%4 : tensor<2x4096xf32>) {
^bb0(%in: f32, %out: f32):
%9 = arith.addf %in, %out : f32
linalg.yield %9 : f32
} -> tensor<2x4096xf32>
%6 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%5 : tensor<2x4096xf32>) outs(%3 : tensor<2x4096xf32>) {
^bb0(%in: f32, %out: f32):
%9 = arith.divf %in, %cst_0 : f32
linalg.yield %9 : f32
} -> tensor<2x4096xf32>
%7 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%2, %6 : tensor<2x4096x640xf32>, tensor<2x4096xf32>)
outs(%4 : tensor<2x4096xf32>) {
^bb0(%in: f32, %in_4: f32, %out: f32):
%9 = arith.subf %in, %in_4 : f32
%10 = arith.mulf %9, %9 : f32
%11 = arith.addf %10, %out : f32
linalg.yield %11 : f32
} -> tensor<2x4096xf32>
%expanded_2 = tensor.expand_shape %arg1 [[0, 1]] output_shape [640, 1]
: tensor<640xf16> into tensor<640x1xf16>
%expanded_3 = tensor.expand_shape %arg2 [[0, 1]] output_shape [640, 1]
: tensor<640xf16> into tensor<640x1xf16>
%8 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%expanded, %6, %7, %expanded_2, %expanded_3
: tensor<2x4096x640x1xf16>, tensor<2x4096xf32>, tensor<2x4096xf32>,
tensor<640x1xf16>, tensor<640x1xf16>)
outs(%1 : tensor<2x4096x640x1xf16>) {
^bb0(%in: f16, %in_4: f32, %in_5: f32, %in_6: f16, %in_7: f16, %out: f16):
%9 = arith.divf %in_5, %cst_0 : f32
%10 = arith.addf %9, %cst_1 : f32
%11 = math.rsqrt %10 : f32
%12 = arith.extf %in : f16 to f32
%13 = arith.subf %12, %in_4 : f32
%14 = arith.mulf %13, %11 : f32
%15 = arith.extf %in_6 : f16 to f32
%16 = arith.mulf %14, %15 : f32
%17 = arith.extf %in_7 : f16 to f32
%18 = arith.addf %16, %17 : f32
%19 = arith.truncf %18 : f32 to f16
linalg.yield %19 : f16
} -> tensor<2x4096x640x1xf16>
util.return %8 : tensor<2x4096x640x1xf16>
}
// CHECK-LABEL: func public @softmax_like_fusion(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x4096x640xf16>
// CHECK: %[[BITEXTEND:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[BITEXTEND]] :
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[GENERIC1]] :
// CHECK: %[[GENERIC3:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[BITEXTEND]], %[[GENERIC2]] :
// CHECK: %[[GENERIC4:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC2]], %[[GENERIC3]]
// CHECK: flow.return %[[GENERIC4]]
// CHECK: util.return %[[RESULT]]
