[KernelDispatch] Refactor the pass and utils for generic ops
yzhang93 committed Dec 17, 2024
1 parent f7cd097 commit fd2d65d
Showing 7 changed files with 218 additions and 176 deletions.
@@ -332,8 +332,9 @@ static SmallVector<int64_t> setInnerPermB(bool isMatmulTransposeB) {
 
 static LogicalResult setRootConfigForPackPeelPipeline(
     mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
-    LowerToAIEPassPipeline useLowerToAIEPipeline, bool isMatmulTransposeB,
-    AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols) {
+    LowerToAIEPassPipeline useLowerToAIEPipeline, AMDAIEDevice targetDevice,
+    uint32_t numRows, uint32_t numCols) {
+  bool isTransposeB = isMatmulTransposeB(linalgOp);
   bool isObjectFifo =
       useLowerToAIEPipeline == LowerToAIEPassPipeline::ObjectFifo;
   auto maybePackPeelTiling =
@@ -359,7 +360,7 @@ static LogicalResult setRootConfigForPackPeelPipeline(
   // There is no corresponding unpack for the specified pack operation
   // 0 is used when unpack is empty
   SmallVector<bool> unpackEmpty = {false};
-  SmallVector<int64_t> innerPermB = setInnerPermB(isMatmulTransposeB);
+  SmallVector<int64_t> innerPermB = setInnerPermB(isTransposeB);
   SmallVector<SmallVector<int64_t>> innerPerm = {innerPermB};
   SmallVector<int64_t> outerPermVec = {0, 1};
   if (isa<linalg::BatchMatmulOp>(linalgOp)) {
@@ -434,8 +435,8 @@
 
 static LogicalResult setRootConfigForPadPackPipeline(
     mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
-    bool isMatmulTransposeB, AMDAIEDevice targetDevice, uint32_t numRows,
-    uint32_t numCols) {
+    AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols) {
+  bool isTransposeB = isMatmulTransposeB(linalgOp);
   auto maybePadPackTiling = ParameterSetting::create(
       linalgOp, /*isPackPeel=*/false, /*isObjectFifo=*/false, targetDevice,
       numRows, numCols);
@@ -454,7 +455,7 @@ static LogicalResult setRootConfigForPadPackPipeline(
   // For matmul_transpose_b, transpose B matrix from [N K n k] to [K N n k]
   SmallVector<int64_t> transposePackIndices{0, 1, 2};
   SmallVector<bool> unpackEmpty{false, false, true};
-  SmallVector<int64_t> innerPermB = setInnerPermB(isMatmulTransposeB);
+  SmallVector<int64_t> innerPermB = setInnerPermB(isTransposeB);
   SmallVector<SmallVector<int64_t>> innerPerm{{0, 1}, innerPermB, {0, 1}};
   SmallVector<SmallVector<int64_t>> outerPerm{{1, 0}, {1, 0}, {1, 0}};
@@ -629,99 +630,12 @@ static LogicalResult setRootConfigForConvDecomposePipeline(
       IREE::Codegen::DispatchLoweringPassPipeline::Custom);
 }
 
-//===----------------------------------------------------------------------===//
-// Configuration for Matmul-Transpose
-//===----------------------------------------------------------------------===//
-
-/// TODO(avarma): This currently is skipping checking for ext* ops.
-static bool bodyMatcherForMatmulTranspose(Value yieldVal, Block *body) {
-  Operation *addOp = yieldVal.getDefiningOp();
-  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp)) {
-    return false;
-  }
-  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
-  if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp)) {
-    return false;
-  }
-  auto lhsBlockArg = dyn_cast<BlockArgument>(mulOp->getOperand(0));
-  auto rhsBlockArg = dyn_cast<BlockArgument>(mulOp->getOperand(1));
-  auto outBlockArg = dyn_cast<BlockArgument>(addOp->getOperand(0));
-  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
-      lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
-      outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
-      rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2) {
-    return false;
-  }
-  return true;
-}
-
-/// `isMatmulTransposeB` is a utility function that aims to indentify whether a
-/// linalg.generic op is a matmul with rhs operand transposed.
-static bool isMatmulTransposeB(linalg::GenericOp genericOp) {
-  // Step 1. Test the body of the generic to indeed be what we expect for a
-  // matmul transpose.
-  Block *body = genericOp.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  if (!bodyMatcherForMatmulTranspose(yieldVal, body)) {
-    return false;
-  }
-  // Step 2. Check iterator types.
-  SmallVector<utils::IteratorType> matmulTransposeIteratorTypes = {
-      utils::IteratorType::parallel, utils::IteratorType::parallel,
-      utils::IteratorType::reduction};
-  SmallVector<utils::IteratorType> opIteratorTypes =
-      genericOp.getIteratorTypesArray();
-  if (matmulTransposeIteratorTypes != opIteratorTypes) {
-    return false;
-  }
-  // Step 3. Test the indexing maps.
-  ArrayAttr indexingMaps = genericOp.getIndexingMaps();
-  if (indexingMaps.size() != 3) return false;
-
-  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
-  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
-  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
-
-  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
-      map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
-      map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
-    return false;
-  }
-
-  // Extract dimensions for MxK * NxK -> MxN
-  AffineExpr m = map2.getResult(0);
-  AffineExpr n = map2.getResult(1);
-  AffineExpr k = map0.getResult(1);
-  auto *context = indexingMaps.getContext();
-  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
-  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
-  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
-  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
-  return indexingMaps == maps;
-}
-
-/// Sets the lowering configuration for a generic op implementing a
-/// transposition.
-static LogicalResult setTransposeLikeOpRootConfig(
-    mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp linalgOp,
-    TilePassPipeline passPipeline, LowerToAIEPassPipeline useLowerToAIEPipeline,
-    AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols) {
-  if (passPipeline == TilePassPipeline::PackPeelPipeline)
-    return setRootConfigForPackPeelPipeline(entryPointFn, linalgOp,
-                                            useLowerToAIEPipeline, true,
-                                            targetDevice, numRows, numCols);
-  else if (passPipeline == TilePassPipeline::PadPackPipeline)
-    return setRootConfigForPadPackPipeline(entryPointFn, linalgOp, true,
-                                           targetDevice, numRows, numCols);
-  return linalgOp.emitError(
-      "Unhandled pass pipeline in setTransposeLikeOpRootConfig.");
-}
-
 //===----------------------------------------------------------------------===//
 // Root Configurations
 //===----------------------------------------------------------------------===//
 
 /// Sets the lowering configuration for dispatch region with root op that
 /// is a generic op.
 static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
                                    linalg::GenericOp genericOp,
                                    TilePassPipeline passPipeline,
@@ -730,15 +644,18 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
                                    uint32_t numCols) {
   assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(genericOp) &&
          "expected lowering_config is not set");
-  if (isMatmulTransposeB(genericOp) &&
-      succeeded(setTransposeLikeOpRootConfig(
-          entryPointFn, genericOp, passPipeline, useLowerToAIEPipeline,
-          targetDevice, numRows, numCols))) {
-    return success();
-  }
+  if (!isMatmul(genericOp) && !isMatmulTransposeB(genericOp))
+    return genericOp.emitOpError(
+        "Current pipelines are only set for matmul-like ops.");
 
-  return failure();
+  if (passPipeline == TilePassPipeline::PackPeelPipeline)
+    return setRootConfigForPackPeelPipeline(entryPointFn, genericOp,
+                                            useLowerToAIEPipeline, targetDevice,
+                                            numRows, numCols);
+  if (passPipeline == TilePassPipeline::PadPackPipeline)
+    return setRootConfigForPadPackPipeline(entryPointFn, genericOp,
+                                           targetDevice, numRows, numCols);
+  return genericOp.emitError("Unhandled pass pipeline in setRootConfig.");
 }
 
 /// Sets the lowering configuration for dispatch region with root op that
@@ -752,14 +669,6 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
   assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(contractionOp) &&
          "expected lowering_config is not set");
   auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
-  if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
-    if (succeeded(setTransposeLikeOpRootConfig(
-            entryPointFn, linalgOp, passPipeline, useLowerToAIEPipeline,
-            targetDevice, numRows, numCols))) {
-      return success();
-    }
-    return failure();
-  }
   unsigned numLoops = linalgOp.getNumLoops();
   {
     SmallVector<unsigned> dims;
@@ -776,14 +685,16 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
   // approach which will have different tile sizes and pass pipelines
   if (passPipeline == TilePassPipeline::PackPeelPipeline)
     return setRootConfigForPackPeelPipeline(entryPointFn, linalgOp,
-                                            useLowerToAIEPipeline, false,
-                                            targetDevice, numRows, numCols);
+                                            useLowerToAIEPipeline, targetDevice,
+                                            numRows, numCols);
   if (passPipeline == TilePassPipeline::PadPackPipeline)
-    return setRootConfigForPadPackPipeline(entryPointFn, linalgOp, false,
-                                           targetDevice, numRows, numCols);
+    return setRootConfigForPadPackPipeline(entryPointFn, linalgOp, targetDevice,
+                                           numRows, numCols);
   return linalgOp.emitError("Unhandled pass pipeline in setRootConfig.");
 }
 
 /// Sets the lowering configuration for dispatch region with root op that
 /// implements the convolution operation interface.
 static LogicalResult setConvRootConfig(mlir::FunctionOpInterface entryPointFn,
                                        linalg::ConvolutionOpInterface convOp,
                                        TilePassPipeline passPipeline,
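Note on the new call sites: both setRootConfigForPackPeelPipeline and setRootConfigForPadPackPipeline now compute bool isTransposeB = isMatmulTransposeB(linalgOp) themselves, so the predicate must accept any linalg::LinalgOp rather than only a linalg.generic. Below is a minimal sketch, not the commit's actual implementation, of what that shared utility could look like, assuming it folds the removed named-op special case (isa<linalg::MatmulTransposeBOp>) together with the removed linalg.generic matcher; the helper name isGenericMatmulTransposeB is hypothetical, and the snippet assumes the includes already present in this file.

    // Sketch only: merges the two transpose-B checks this diff removes.
    static bool isMatmulTransposeB(linalg::LinalgOp linalgOp) {
      // The named op is transpose-B by construction (previously special-cased
      // in the contraction setRootConfig above).
      if (isa<linalg::MatmulTransposeBOp>(linalgOp)) return true;
      // Otherwise only a linalg.generic with the matmul-transpose structure
      // (mul/add body, (parallel, parallel, reduction) iterators, and
      // (m, k) x (n, k) -> (m, n) indexing maps) qualifies.
      if (auto genericOp = dyn_cast<linalg::GenericOp>(linalgOp.getOperation()))
        return isGenericMatmulTransposeB(genericOp);  // hypothetical helper
      return false;
    }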
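Similarly, the new guard in the generic-op setRootConfig calls isMatmul(genericOp), which is not defined in this diff. Under the same structural assumptions as the removed matcher, a plain matmul differs from transpose-B only in the B indexing map: (m, k) x (k, n) -> (m, n) instead of (m, k) x (n, k) -> (m, n). A hedged sketch follows; the real utility presumably also matches the op body (mul feeding add), which is omitted here for brevity.

    // Sketch only: iterator and indexing-map checks mirroring the removed
    // matcher, specialized to the untransposed (m,k) x (k,n) -> (m,n) form.
    static bool isMatmul(linalg::GenericOp genericOp) {
      SmallVector<utils::IteratorType> matmulIteratorTypes = {
          utils::IteratorType::parallel, utils::IteratorType::parallel,
          utils::IteratorType::reduction};
      SmallVector<utils::IteratorType> opIteratorTypes =
          genericOp.getIteratorTypesArray();
      if (matmulIteratorTypes != opIteratorTypes) return false;
      ArrayAttr indexingMaps = genericOp.getIndexingMaps();
      if (indexingMaps.size() != 3) return false;
      AffineMap mapA = cast<AffineMapAttr>(indexingMaps[0]).getValue();
      AffineMap mapC = cast<AffineMapAttr>(indexingMaps[2]).getValue();
      if (mapA.getNumResults() != 2 || mapC.getNumResults() != 2 ||
          mapA.getNumInputs() != 3 || mapC.getNumInputs() != 3) {
        return false;
      }
      // Extract dimensions for MxK * KxN -> MxN.
      AffineExpr m = mapC.getResult(0);
      AffineExpr n = mapC.getResult(1);
      AffineExpr k = mapA.getResult(1);
      auto *context = indexingMaps.getContext();
      auto mapAExp = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
      auto mapBExp = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
      auto mapCExp = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
      return indexingMaps ==
             ArrayAttr::get(context, {mapAExp, mapBExp, mapCExp});
    }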
