Skip to content

Commit

Permalink
[LoweringStrategy] Use a more general method to fetch input dims and sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhang93 committed Feb 8, 2025
1 parent 2751586 commit bba961e
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,52 @@ FailureOr<std::array<uint32_t, 3>> getPackedSize(linalg::LinalgOp linalgOp,
return instructionSize;
}

// Groups the loop-dimension indices and the corresponding static loop sizes
// for the M, N, and K dimensions of a contraction (matmul-like) op. Each of
// M/N/K may map to more than one loop dimension (e.g. for multi-dim or packed
// contractions), hence vectors rather than scalars. `mDims[i]` is the loop
// index whose static range is `mSizes[i]` (and likewise for n/k). Populated
// by getInputDimsAndSizes.
struct InputDimsAndSizes {
  // Loop-dimension indices of the M/N/K dimensions.
  SmallVector<unsigned, 2> mDims;
  SmallVector<unsigned, 2> nDims;
  SmallVector<unsigned, 2> kDims;
  // Static loop ranges at the corresponding indices above.
  SmallVector<int64_t, 2> mSizes;
  SmallVector<int64_t, 2> nSizes;
  SmallVector<int64_t, 2> kSizes;
};

// Infers the contraction dimensions of `linalgOp` and returns the M/N/K loop
// indices together with their static loop sizes.
//
// Fails (emitting an op error) when:
//   - the contraction dimensions cannot be inferred;
//   - any of the m/n/k dimension lists is empty;
//   - the total number of m/n/k dims exceeds the op's loop count.
// Note: batch dims (if any) are allowed; they are simply not reported here.
FailureOr<InputDimsAndSizes> getInputDimsAndSizes(linalg::LinalgOp linalgOp) {
  FailureOr<linalg::ContractionDimensions> maybeContractionDims =
      linalg::inferContractionDims(linalgOp);
  if (failed(maybeContractionDims)) {
    return linalgOp.emitOpError("failed to infer the contraction dimensions.");
  }

  linalg::ContractionDimensions contractionDims = *maybeContractionDims;
  if (contractionDims.m.empty() || contractionDims.n.empty() ||
      contractionDims.k.empty()) {
    return linalgOp.emitOpError("failed to fetch m/n/k dims.");
  }

  SmallVector<int64_t> shapes = linalgOp.getStaticLoopRanges();
  if (contractionDims.m.size() + contractionDims.n.size() +
          contractionDims.k.size() >
      shapes.size()) {
    return linalgOp.emitOpError(
        "the total of m/n/k dims is larger than the number of loops.");
  }

  // Look up the static loop range for each dim index. Takes ArrayRef so any
  // vector of indices can be passed without a copy/conversion.
  auto getSizesAt = [&shapes](ArrayRef<unsigned> idx) {
    SmallVector<int64_t, 2> sizes;
    sizes.reserve(idx.size());
    for (unsigned i : idx) sizes.push_back(shapes[i]);
    return sizes;
  };

  InputDimsAndSizes inputDimsAndSizes;
  // Compute the sizes before moving the dim vectors into the result.
  inputDimsAndSizes.mSizes = getSizesAt(contractionDims.m);
  inputDimsAndSizes.nSizes = getSizesAt(contractionDims.n);
  inputDimsAndSizes.kSizes = getSizesAt(contractionDims.k);
  inputDimsAndSizes.mDims = std::move(contractionDims.m);
  inputDimsAndSizes.nDims = std::move(contractionDims.n);
  inputDimsAndSizes.kDims = std::move(contractionDims.k);
  return inputDimsAndSizes;
}

// Container class for the tiling at level 0 (the AIE shared memory) and level 1
// (the AIE core) in the M-, N-, and K-dimensions of a matmul operation, using
// the pad-pack approach to tiling a matmul. Also contains the packing sizes for
Expand Down Expand Up @@ -156,25 +202,24 @@ FailureOr<ParameterSetting> ParameterSetting::create(
auto initType =
llvm::cast<ShapedType>(linalgOp.getDpsInitOperand(0)->get().getType());
unsigned nBitsInit = initType.getElementTypeBitWidth();
ArrayRef<int64_t> initShape = initType.getShape();

auto lhsType =
llvm::cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType());
unsigned nBitsLhs = lhsType.getElementTypeBitWidth();
ArrayRef<int64_t> lhsShape = lhsType.getShape();

auto rhsType =
llvm::cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType());
unsigned nBitsRhs = rhsType.getElementTypeBitWidth();

// Shape of the full matmul operation.
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
initShape = initShape.drop_front();
lhsShape = lhsShape.drop_front();
}
const uint64_t M = initShape[0];
const uint64_t N = initShape[1];
const uint64_t K = lhsShape[1];
auto getTotalSize = [](const SmallVector<int64_t, 2> &sizes) {
return std::accumulate(sizes.begin(), sizes.end(), 1,
std::multiplies<int64_t>());
};

// Get the shape (M, N, K) of the full Matmul operation.
auto maybeInputDimsAndSizes = getInputDimsAndSizes(linalgOp);
if (failed(maybeInputDimsAndSizes)) return failure();
int64_t M = getTotalSize(maybeInputDimsAndSizes.value().mSizes);
int64_t N = getTotalSize(maybeInputDimsAndSizes.value().nSizes);
int64_t K = getTotalSize(maybeInputDimsAndSizes.value().kSizes);

// If we are conservative with ensuring that tiles A, B, and C fit at the
// different memory levels, we should choose the scale factor based
Expand Down Expand Up @@ -389,26 +434,35 @@ static SmallVector<int64_t> setOuterPermB(bool isMatmulTransposeB,

static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols) {
// Scale the L1 K with a factor of 2 compared with the outer dimenions M and N
// to increase the L1 memory usage.
AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols,
uint32_t numLoops) {
// Scale the L1 K with a factor of 2 compared with the outer dimensions M and
// N to increase the L1 memory usage.
auto maybePackPeelTiling = ParameterSetting::create(
linalgOp, /*isPackPeel=*/true, /*isObjectFifo=*/true, targetDevice,
numRows, numCols, /*kPackScaleL1=*/2);
if (failed(maybePackPeelTiling)) return failure();
auto packPeelTiling = maybePackPeelTiling.value();

// Get M, N, K dimension indices from the input indexing map.
auto maybeInputDimsAndSizes = getInputDimsAndSizes(linalgOp);
if (failed(maybeInputDimsAndSizes)) return failure();
SmallVector<unsigned, 2> mDims = maybeInputDimsAndSizes.value().mDims;
SmallVector<unsigned, 2> nDims = maybeInputDimsAndSizes.value().nDims;
SmallVector<unsigned, 2> kDims = maybeInputDimsAndSizes.value().kDims;

AMDAIEDeviceModel deviceModel = getDeviceModel(targetDevice);

// ------------------------------------------------------
// --------------- Set packing config -------------------
// ------------------------------------------------------
MLIRContext *context = entryPointFn.getContext();

SmallVector<int64_t> packedSizesL0 = packPeelTiling.getPackSizeL0();
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL0.insert(packedSizesL0.begin(), 0);
}
// Pack level => 1.
SmallVector<int64_t> packedSizesL0(numLoops, 0);
packedSizesL0[mDims.back()] = packPeelTiling.m0Pack;
packedSizesL0[nDims.back()] = packPeelTiling.n0Pack;
packedSizesL0[kDims.back()] = packPeelTiling.k0Pack;

// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
Expand Down Expand Up @@ -440,17 +494,11 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
outerPerm);

// Pack level => 2.
// packed size for [M, N, K, m, n, k]
SmallVector<int64_t> packedSizesL1 = {0,
0,
0,
packPeelTiling.m1Pack,
packPeelTiling.n1Pack,
packPeelTiling.k1Pack};

if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL1.insert(packedSizesL1.begin(), 0);
}
// The number of loops has increased by 3 due to the first level pack.
SmallVector<int64_t> packedSizesL1(numLoops + 3, 0);
packedSizesL1[mDims.back() + 3] = packPeelTiling.m1Pack;
packedSizesL1[nDims.back() + 3] = packPeelTiling.n1Pack;
packedSizesL1[kDims.back() + 3] = packPeelTiling.k1Pack;

// Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0]
// Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0]
Expand Down Expand Up @@ -492,18 +540,24 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
bool fitsInL2 = (l2SizeA + l2SizeB + l2SizeInit) <
(deviceModel.getMemTileSizeInBytes() * numCols);
int64_t scaleL0 = !isBatchMatmul && fitsInL2 ? 2 : 1;
SmallVector<int64_t> tileSizeLevel0 = {packPeelTiling.M0 * scaleL0,
packPeelTiling.N0 * scaleL0};
SmallVector<int64_t> tileSizeLevel1 = {numRows, numCols, 0};
SmallVector<int64_t> tileSizeLevel2 = {0, 0, 1};
SmallVector<int64_t> tileSizeLevel3 = {1, 1, 0, 0, 0, 0};

SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
tileSizeLevel0.insert(tileSizeLevel0.begin(), 1);
tileSizeLevel1.insert(tileSizeLevel1.begin(), 0);
tileSizeLevel2.insert(tileSizeLevel2.begin(), 0);
tileSizeLevel3.insert(tileSizeLevel3.begin(), 0);
tileSizeLevel0[0] = 1;
}
tileSizeLevel0[mDims[0]] = packPeelTiling.M0 * scaleL0;
tileSizeLevel0[nDims[0]] = packPeelTiling.N0 * scaleL0;

SmallVector<int64_t> tileSizeLevel1(numLoops, 0);
tileSizeLevel1[mDims[0]] = numRows;
tileSizeLevel1[nDims[0]] = numCols;

SmallVector<int64_t> tileSizeLevel2(numLoops, 0);
tileSizeLevel2[kDims[0]] = 1;

SmallVector<int64_t> tileSizeLevel3(numLoops, 0);
tileSizeLevel3[mDims[0]] = 1;
tileSizeLevel3[nDims[0]] = 1;

TileSizesListType tileSizes = {tileSizeLevel0, tileSizeLevel1, tileSizeLevel2,
tileSizeLevel3};
Expand All @@ -518,7 +572,7 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
static LogicalResult setRootConfigForPackPeelPipeline(
mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
LowerToAIEPassPipeline useLowerToAIEPipeline, AMDAIEDevice targetDevice,
uint32_t numRows, uint32_t numCols) {
uint32_t numRows, uint32_t numCols, uint32_t numLoops) {
bool isObjectFifo =
useLowerToAIEPipeline == LowerToAIEPassPipeline::ObjectFifo;
auto maybePackPeelTiling =
Expand All @@ -527,15 +581,23 @@ static LogicalResult setRootConfigForPackPeelPipeline(
if (failed(maybePackPeelTiling)) return failure();
auto packPeelTiling = maybePackPeelTiling.value();

// Get M, N, K dimension indices from the input indexing map.
auto maybeInputDimsAndSizes = getInputDimsAndSizes(linalgOp);
if (failed(maybeInputDimsAndSizes)) return failure();
SmallVector<unsigned, 2> mDims = maybeInputDimsAndSizes.value().mDims;
SmallVector<unsigned, 2> nDims = maybeInputDimsAndSizes.value().nDims;
SmallVector<unsigned, 2> kDims = maybeInputDimsAndSizes.value().kDims;

// ------------------------------------------------------
// --------------- Set packing config -------------------
// ------------------------------------------------------
MLIRContext *context = entryPointFn.getContext();

SmallVector<int64_t> packedSizesL0 = packPeelTiling.getPackSizeL0();
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL0.insert(packedSizesL0.begin(), 0);
}
// Pack level => 1.
SmallVector<int64_t> packedSizesL0(numLoops, 0);
packedSizesL0[mDims.back()] = packPeelTiling.m0Pack;
packedSizesL0[nDims.back()] = packPeelTiling.n0Pack;
packedSizesL0[kDims.back()] = packPeelTiling.k0Pack;

// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
Expand Down Expand Up @@ -571,17 +633,11 @@ static LogicalResult setRootConfigForPackPeelPipeline(
outerPerm);

// Pack level => 2.
// packed size for [M, N, K, m, n, k]
SmallVector<int64_t> packedSizesL1 = {0,
0,
0,
packPeelTiling.m1Pack,
packPeelTiling.n1Pack,
packPeelTiling.k1Pack};

if (isa<linalg::BatchMatmulOp>(linalgOp)) {
packedSizesL1.insert(packedSizesL1.begin(), 0);
}
// The number of loops has increased by 3 due to the first level pack.
SmallVector<int64_t> packedSizesL1(numLoops + 3, 0);
packedSizesL1[mDims.back() + 3] = packPeelTiling.m1Pack;
packedSizesL1[nDims.back() + 3] = packPeelTiling.n1Pack;
packedSizesL1[kDims.back() + 3] = packPeelTiling.k1Pack;

// Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0]
// Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0]
Expand Down Expand Up @@ -611,15 +667,19 @@ static LogicalResult setRootConfigForPackPeelPipeline(
// ------------------------------------------------------
// -------------- Set lowering config -------------------
// ------------------------------------------------------
SmallVector<int64_t> tileSizeLevel0 = {packPeelTiling.M0, packPeelTiling.N0};
SmallVector<int64_t> tileSizeLevel1 = {0, 0, packPeelTiling.K0};
SmallVector<int64_t> tileSizeLevel2 = {1, 1, 0, 0, 0, 0};

SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
tileSizeLevel0.insert(tileSizeLevel0.begin(), 1);
tileSizeLevel1.insert(tileSizeLevel1.begin(), 0);
tileSizeLevel2.insert(tileSizeLevel2.begin(), 0);
tileSizeLevel0[0] = 1;
}
tileSizeLevel0[mDims[0]] = packPeelTiling.M0;
tileSizeLevel0[nDims[0]] = packPeelTiling.N0;

SmallVector<int64_t> tileSizeLevel1(numLoops, 0);
tileSizeLevel1[kDims[0]] = 1;

SmallVector<int64_t> tileSizeLevel2(numLoops, 0);
tileSizeLevel2[mDims[0]] = 1;
tileSizeLevel2[nDims[0]] = 1;

TileSizesListType tileSizes = {tileSizeLevel0, tileSizeLevel1,
tileSizeLevel2};
Expand Down Expand Up @@ -842,6 +902,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
uint32_t numCols) {
assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(genericOp) &&
"expected lowering_config is not set");
unsigned numLoops = genericOp.getNumLoops();
assert(numLoops <= 7 && "expected input number of loops no more than 7");
if (!isMatmul(genericOp) && !isMatmulTransposeA(genericOp) &&
!isMatmulTransposeB(genericOp))
return genericOp.emitOpError(
Expand All @@ -850,11 +912,11 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
if (passPipeline == TilePassPipeline::PackPeelPipeline) {
return setRootConfigForPackPeelPipeline(entryPointFn, genericOp,
useLowerToAIEPipeline, targetDevice,
numRows, numCols);
numRows, numCols, numLoops);
}
if (passPipeline == TilePassPipeline::PackPeel4LevelTilingPipeline) {
return setRootConfigForPackPeel4LevelTilingPipeline(
entryPointFn, genericOp, targetDevice, numRows, numCols);
entryPointFn, genericOp, targetDevice, numRows, numCols, numLoops);
}
if (passPipeline == TilePassPipeline::PadPackPipeline) {
return setRootConfigForPadPackPipeline(entryPointFn, genericOp,
Expand All @@ -875,27 +937,19 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
"expected lowering_config is not set");
auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
unsigned numLoops = linalgOp.getNumLoops();
{
SmallVector<unsigned> dims;
linalgOp.getReductionDims(dims);
if (dims.size() != 1 || dims[0] != numLoops - 1) {
return linalgOp.emitOpError(
"is expected to have exactly one reduction dim, ")
<< "and that it is the innermost dim (" << numLoops - 1 << ").";
}
}
assert(numLoops <= 7 && "expected input number of loops no more than 7");

// TODO (nmeshram) : This needs to be moved in a separate more generalized
// logic. Also, need a flag to experiment between pad based and pack based
// approach which will have different tile sizes and pass pipelines
if (passPipeline == TilePassPipeline::PackPeelPipeline) {
return setRootConfigForPackPeelPipeline(entryPointFn, linalgOp,
useLowerToAIEPipeline, targetDevice,
numRows, numCols);
numRows, numCols, numLoops);
}
if (passPipeline == TilePassPipeline::PackPeel4LevelTilingPipeline) {
return setRootConfigForPackPeel4LevelTilingPipeline(
entryPointFn, linalgOp, targetDevice, numRows, numCols);
entryPointFn, linalgOp, targetDevice, numRows, numCols, numLoops);
}
if (passPipeline == TilePassPipeline::PadPackPipeline) {
return setRootConfigForPadPackPipeline(entryPointFn, linalgOp, targetDevice,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ builtin.module {

// -----

// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0], [0, 0, 1], [1, 1, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [128, 128, 128], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 8, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
Expand All @@ -216,7 +216,7 @@ builtin.module {

// -----

// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[44, 128, 0], [0, 0, 1], [1, 1, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [44, 32, 64], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [1, 0]], outerPerm = [[0, 1], [1, 0]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
Expand Down Expand Up @@ -244,7 +244,7 @@ module {

// CHECK-PAD-PACK{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 256], [32, 32], [0, 0, 4]]>
// CHECK-PAD-PACK{LITERAL}: #packingConfig = #amdaie.packing_config<packing_config = [{packedSizes = [4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[1, 0], [1, 0], [1, 0]]}]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #config = #iree_codegen.lowering_config<tile_sizes = [[128, 128, 0], [0, 0, 1], [1, 1, 0]]>
// CHECK-PACK-PEEL{LITERAL}: #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [0, 1], unpackEmpty = [false, false], innerPerm = [[0, 1], [0, 1]], outerPerm = [[0, 1], [0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [0, 1], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
<storage_buffer>,
Expand Down
Loading

0 comments on commit bba961e

Please sign in to comment.