Skip to content

Commit

Permalink
[KernelDispatch] Add matmul RHS outer permutation (#1016)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls authored Jan 9, 2025
1 parent f519ca2 commit 53b96d5
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,23 @@ FailureOr<ParameterSetting> ParameterSetting::create(
}
} // namespace

/// Utility to set the packing inner permutation for A/LHS so that is packed as
/// [? ? m k] in case of matmul and [? ? ? m k] in case of batch_matmul.
static SmallVector<int64_t> setInnerPermA(bool isMatmulTransposeA) {
SmallVector<int64_t> innerPerm;
if (isMatmulTransposeA) {
innerPerm = {1, 0};
} else {
innerPerm = {0, 1};
}
return innerPerm;
}

/// Utility to set the packing inner permutation for B/RHS so that is packed as
/// - [? ? k n] in case of matmul
/// - [? ? ? k n] in case of batch_matmul
/// - [? ? n k] in case of matmul_transpose_b
/// - [? ? ? n k] in case of batch_matmul_transpose_b.
static SmallVector<int64_t> setInnerPermB(bool isMatmulTransposeB) {
SmallVector<int64_t> innerPerm;
if (isMatmulTransposeB) {
Expand All @@ -326,14 +343,34 @@ static SmallVector<int64_t> setInnerPermB(bool isMatmulTransposeB) {
return innerPerm;
}

static SmallVector<int64_t> setInnerPermA(bool isMatmulTransposeA) {
SmallVector<int64_t> innerPerm;
/// Utility to set the packing outer permutation for A/LHS so that is packed as
/// [M K ? ?] in case of matmul and [Batch M K ? ?] in case of batch_matmul.
static SmallVector<int64_t> setOuterPermA(bool isMatmulTransposeA,
bool isBatchMatmul) {
SmallVector<int64_t> outerPerm;
if (isMatmulTransposeA) {
innerPerm = {1, 0};
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 2, 1}
: SmallVector<int64_t>{1, 0};
} else {
innerPerm = {0, 1};
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 1, 2}
: SmallVector<int64_t>{0, 1};
}
return innerPerm;
return outerPerm;
}

/// Utility to set the packing outer permutation for B/RHS so that is packed as
/// [N K ? ?] in case of matmul and [Batch N K ? ?] in case of batch_matmul.
static SmallVector<int64_t> setOuterPermB(bool isMatmulTransposeB,
bool isBatchMatmul) {
SmallVector<int64_t> outerPerm;
if (isMatmulTransposeB) {
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 1, 2}
: SmallVector<int64_t>{0, 1};
} else {
outerPerm = isBatchMatmul ? SmallVector<int64_t>{0, 2, 1}
: SmallVector<int64_t>{1, 0};
}
return outerPerm;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -362,7 +399,7 @@ static LogicalResult setRootConfigForPackPeelPipeline(
packedSizesL0.insert(packedSizesL0.begin(), 0);
}

// For matmul, transpose B matrix from [K N n k] to [K N k n]
// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
// since it is already [N K n k]
SmallVector<int64_t> transposePackIndices = {0, 1};
Expand All @@ -372,11 +409,12 @@ static LogicalResult setRootConfigForPackPeelPipeline(
SmallVector<int64_t> innerPermA = setInnerPermA(isMatmulTransposeA(linalgOp));
SmallVector<int64_t> innerPermB = setInnerPermB(isMatmulTransposeB(linalgOp));
SmallVector<SmallVector<int64_t>> innerPerm = {innerPermA, innerPermB};
SmallVector<int64_t> outerPermVec = {0, 1};
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
outerPermVec.push_back(2);
}
SmallVector<SmallVector<int64_t>> outerPerm = {outerPermVec, outerPermVec};
bool isBatchMatmul = isa<linalg::BatchMatmulOp>(linalgOp);
SmallVector<int64_t> outerPermA =
setOuterPermA(isMatmulTransposeA(linalgOp), isBatchMatmul);
SmallVector<int64_t> outerPermB =
setOuterPermB(isMatmulTransposeB(linalgOp), isBatchMatmul);
SmallVector<SmallVector<int64_t>> outerPerm = {outerPermA, outerPermB};
if (isObjectFifo) {
// Add outer permutation for unpack. NOTE: This currently fails for some
// tests in the AIR pipeline.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-amdaie-lowering-strategy{use-lower-to-aie-pipeline=air use-tile-pipeline=pack-peel})' %s | FileCheck %s --check-prefix=CHECK-PACK-PEEL

// CHECK-PAD-PACK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [0, 0, 256], [16, 16], [0, 0, 2]]>
// CHECK-PAD-PACK{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
// CHECK-PAD-PACK{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
<storage_buffer>,
Expand All @@ -29,7 +29,7 @@ builtin.module {
// -----

// CHECK-PAD-PACK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 256], [32, 32], [0, 0, 4]]>
// CHECK-PAD-PACK{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
// CHECK-PAD-PACK{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
<storage_buffer>,
Expand Down Expand Up @@ -190,7 +190,7 @@ builtin.module {
// -----

// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [128, 128, 128], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 8, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
<storage_buffer>,
Expand All @@ -217,7 +217,7 @@ builtin.module {
// -----

// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [44, 32, 64], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [44, 32, 64], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
<storage_buffer>,
Expand Down Expand Up @@ -245,7 +245,7 @@ module {
// CHECK-PAD-PACK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 256], [32, 32], [0, 0, 4]]>
// CHECK-PAD-PACK{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [0, 1]], outerPerm = [[0, 1], [0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [0, 1]], outerPerm = [[0, 1], [0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
<storage_buffer>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Test generic version of matmul.

// CHECK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1], [0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
// CHECK{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1], [1, 0], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
module {
func.func @matmul_generic_128x128x256_i32() {
%c0_i32 = arith.constant 0 : i32
Expand Down Expand Up @@ -63,7 +63,7 @@ module {
// Test generic version of matmul_transpose_a.

// CHECK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[1, 0], [1, 0], [0, 1]], outerPerm = [[0, 1], [0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[1, 0], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
// CHECK{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[1, 0], [1, 0], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[1, 0], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
module {
func.func @matmul_transpose_a_generic_128x128x256_i32() {
%c0_i32 = arith.constant 0 : i32
Expand Down
Loading

0 comments on commit 53b96d5

Please sign in to comment.