Skip to content

Commit

Permalink
[Dispatch] Bubble extract_slice through all parallel generics (#20161)
Browse files Browse the repository at this point in the history
Fixes the llama fp8 perf regression introduced by
#20106. That PR stopped the
linalg.generic from getting hoisted, which caused a broadcast to get
fused and `tensor<1x1x131072x131072xi1>` to be recomputed on each
prefill call.

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Mar 5, 2025
1 parent ec128bf commit a09be42
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> {
"expected generic op to have all projected permutation maps");
}

if (genericOp.hasIndexSemantics()) {
return rewriter.notifyMatchFailure(
genericOp, "pattern doesn't support index semantics");
}

Value replacement;
linalg::GenericOp swappedOp;
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,26 @@ func.func @bubble_up_extract_slice_single_use(%arg0: tensor<131072xi64>, %arg1:
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: return %[[GENERIC]]

// -----

// An all-parallel generic that broadcasts %arg0 along d3 and reads
// `linalg.index` in its body; only a dynamically sized slice of the result is
// returned.
func.func @bubble_extract_broadcast(%arg0: tensor<1x1x131072xi64>, %arg2: index) -> tensor<?x?xi1> {
  %init = tensor.empty() : tensor<1x1x131072x131072xi1>
  %mask = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
      ],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
    } ins(%arg0 : tensor<1x1x131072xi64>)
      outs(%init : tensor<1x1x131072x131072xi1>) {
  ^bb0(%lhs: i64, %unused: i1):
    // Compare the innermost iteration index against the broadcast input value.
    %col = linalg.index 3 : index
    %col_i64 = arith.index_cast %col : index to i64
    %is_ge = arith.cmpi sge, %col_i64, %lhs : i64
    linalg.yield %is_ge : i1
  } -> tensor<1x1x131072x131072xi1>
  // Rank-reducing slice: the two unit leading dims are dropped, the trailing
  // two are sized by %arg2.
  %slice = tensor.extract_slice %mask[0, 0, 0, 0] [1, 1, %arg2, %arg2] [1, 1, 1, 1] : tensor<1x1x131072x131072xi1> to tensor<?x?xi1>
  return %slice : tensor<?x?xi1>
}
// CHECK-LABEL: func @bubble_extract_broadcast
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x131072xi64>
// CHECK-SAME: %[[ARG2:.+]]: index
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: tensor<1x1x131072xi64> to tensor<?xi64>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXTRACT]] : tensor<?xi64>)
// CHECK: return %[[GENERIC]]

0 comments on commit a09be42

Please sign in to comment.