Skip to content

Commit

Permalink
[SplitLogicalObjectFifos] Add support for dma tranposed on the target…
Browse files Browse the repository at this point in the history
… side (nod-ai#850)

Previous PR nod-ai#812 added
support to transpose dma dimensions on the target side for control code
optimization. However, this new dma addressing wasn't supported in
`SplitLogicalObjectFifosForConnectionReuse` pass.

This PR keeps the original logic and most of the original codes while
adding the support for the new dma format.
  • Loading branch information
yzhang93 authored Oct 28, 2024
1 parent 08961d2 commit f16b450
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,17 +269,6 @@ static LogicalResult checkWhetherSplitIsPossible(
fetchL3ToL2DmaCpyNdOp(l2ToL1DmaOps[0]);
if (failed(maybeL3ToL2DmaOp)) return failure();
AMDAIE::DmaCpyNdOp l3ToL2DmaOp = maybeL3ToL2DmaOp.value();
if ((l3ToL2DmaOp.getTargetMixedOffsets().size() !=
l3ToL2DmaOp.getSourceMixedOffsets().size()) ||
(l3ToL2DmaOp.getTargetMixedSizes().size() !=
l3ToL2DmaOp.getSourceMixedSizes().size()) ||
(l3ToL2DmaOp.getTargetMixedStrides().size() !=
l3ToL2DmaOp.getSourceMixedStrides().size())) {
LLVM_DEBUG(llvm::dbgs() << "dimensionality of source and target's "
"offset/size/stride found different for "
<< l3ToL2DmaOp << "\n");
return failure();
}

SmallVector<OpFoldResult, 4> staticL2AsTargetSizes =
l3ToL2DmaOp.getTargetMixedSizes();
Expand Down Expand Up @@ -353,39 +342,81 @@ LogicalResult splitLogicalObjectFifos(
toBeErased.insert(sourceAllocOp);
toBeErased.insert(sourceObjectFifo);

SmallVector<OpFoldResult, 4> staticL2AsTargetOffsets =
SmallVector<OpFoldResult> staticL2AsTargetOffsets =
l3ToL2DmaOp.getTargetMixedOffsets();
SmallVector<OpFoldResult, 4> staticL2AsTargetSizes =
SmallVector<OpFoldResult> staticL2AsTargetSizes =
l3ToL2DmaOp.getTargetMixedSizes();
SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
SmallVector<OpFoldResult> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 4> staticL3AsSourceSizes =
SmallVector<OpFoldResult> staticL3AsSourceSizes =
l3ToL2DmaOp.getSourceMixedSizes();

LogicalObjectFifoFromMemrefOp l2TargetObjectFifo =
l3ToL2DmaOp.getTargetObjectFifo();
ArrayRef<int64_t> l2TargetShape =
l2TargetObjectFifo.getMemrefType().getShape();
if (l2TargetShape.size() != staticL2AsTargetSizes.size()) {
LLVM_DEBUG(llvm::dbgs() << "L2 target size should be the same");
return failure();
}

// Check if the L3->L2 dma is transposed on the target side.
bool dmaTransposeOnSource = true;
for (auto [s1, s2] : llvm::zip_equal(l2TargetShape, staticL2AsTargetSizes)) {
if (s1 != getConstantIntValue(s2)) {
dmaTransposeOnSource = false;
break;
}
}
if (staticL3AsSourceSizes.size() != staticL2AsTargetSizes.size()) {
dmaTransposeOnSource = false;
}

OpFoldResult zeroVal = getAsIndexOpFoldResult(context, 0);
OpFoldResult oneVal = getAsIndexOpFoldResult(context, 1);
// Update split dimensions' offset/size for L2 as target and L3 as source. We
// can afford to do this here because it's going to be the same for all L3->L2
// splits. Here we are setting offset = 0 and size = 1.
for (size_t dim : splitDimsForL2) {
staticL2AsTargetOffsets[dim] = zeroVal;
staticL2AsTargetSizes[dim] = oneVal;
staticL3AsSourceOffsets[dim] = zeroVal;
staticL3AsSourceSizes[dim] = oneVal;

if (dmaTransposeOnSource) {
// Update split dimensions' offset/size for L2 as target and L3 as source.
// We can afford to do this here because it's going to be the same for all
// L3->L2 splits. Here we are setting offset = 0 and size = 1.
for (size_t dim : splitDimsForL2) {
staticL2AsTargetOffsets[dim] = zeroVal;
staticL2AsTargetSizes[dim] = oneVal;
staticL3AsSourceOffsets[dim] = zeroVal;
staticL3AsSourceSizes[dim] = oneVal;
}
} else {
// The L2 target side has transposed dimensions, while the L3 source side
// data are continuous and don't have `nonSplitDim`. Then the L3 source
// sizes need to be modified to match the new L2 target sizes.
// Hardcoded the transposed dimensions for now.
const SmallVector<size_t> transposeDim = {0, 2, 1, 3};
for (auto &&[splitDim, nonSplitdim] :
llvm::zip_equal(splitDimsForL2, nonSplitDimsForL2)) {
staticL2AsTargetOffsets[transposeDim[splitDim]] = zeroVal;
staticL2AsTargetSizes[transposeDim[splitDim]] = oneVal;
staticL3AsSourceSizes[splitDim] =
staticL2AsTargetSizes[transposeDim[nonSplitdim]];
}
}

// Traverse each L2->L1 DmaCpyNd op and split them.
for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) {
SmallVector<OpFoldResult, 6> staticL2AsSourceOffsets =
SmallVector<OpFoldResult> staticL2AsSourceOffsets =
l2ToL1DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 6> staticL2AsSourceSizes =
SmallVector<OpFoldResult> staticL2AsSourceSizes =
l2ToL1DmaOp.getSourceMixedSizes();

// Now we'll create a new L2 buffer based on the new shape inferred earlier
// via `staticL2AsTargetSizes`.
LogicalObjectFifoFromMemrefOp oldL2ObjectFifo =
l2ToL1DmaOp.getSourceObjectFifo();
AMDAIE::LogicalObjectFifoFromMemrefOp source = createNewLogicalObjectFifo(
rewriter, oldL2ObjectFifo, staticL2AsTargetSizes);
// If the dma transpose is on the source(target) side, then the L2
// target(source) side has the sizes in order.
SmallVector<OpFoldResult> newL2Sizes =
dmaTransposeOnSource ? staticL2AsTargetSizes : staticL2AsSourceSizes;
AMDAIE::LogicalObjectFifoFromMemrefOp source =
createNewLogicalObjectFifo(rewriter, oldL2ObjectFifo, newL2Sizes);

// --------------------------------------------
// ---------- L3 -> L2 splitting --------------
Expand All @@ -404,15 +435,19 @@ LogicalResult splitLogicalObjectFifos(
<< splitDim;
}
std::optional<int64_t> constantSize =
getConstantIntValue(staticL2AsTargetSizes[nonSplitdim]);
getConstantIntValue(newL2Sizes[nonSplitdim]);
if (!constantSize) {
return l3ToL2DmaOp->emitOpError()
<< "found a non-constant value for target size at dim "
<< nonSplitdim;
}
int64_t offsetToAdd = constantOffset.value() * constantSize.value();

// If the dma transpose is on the target side, L3 source side data are
// continuous and don't have `nonSplitDim`.
size_t dim = dmaTransposeOnSource ? nonSplitdim : splitDim;
FailureOr<OpFoldResult> newOffset = updateL3SourceOffset(
rewriter, staticL3AsSourceOffsets[nonSplitdim], offsetToAdd, context);
rewriter, staticL3AsSourceOffsets[dim], offsetToAdd, context);
if (failed(newOffset)) {
// TODO: Ideally we should be able to handle even +, -, *, /, etc.
// But handle this later (if at all!) as such cases might not
Expand All @@ -421,8 +456,9 @@ LogicalResult splitLogicalObjectFifos(
<< "Unhandled expression for source offset at dim "
<< nonSplitdim;
}
staticL3AsSourceOffsets[nonSplitdim] = *newOffset;
staticL3AsSourceOffsets[dim] = *newOffset;
}

// Create new L3 -> L2 Dma Op.
rewriter.setInsertionPoint(l3ToL2DmaOp);
rewriter.create<AMDAIE::DmaCpyNdOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,14 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager,
// cause 'aie.dma_bd' error, so for now keep using transpose on source for
// both pack and unpack ops.
// TODO(vivian): explore the other options for conv ops.
AMDAIEConvertToDmaOptions dmaOptions;
dmaOptions.packTransposeOnSource =
(useTilePipeline == TilePassPipeline::ConvDecomposePipeline) ? true
: false;
dmaOptions.unpackTransposeOnSource = true;
passManager.addPass(createAMDAIEConvertToDmaPass(dmaOptions));
{
AMDAIEConvertToDmaOptions dmaOptions;
dmaOptions.packTransposeOnSource =
(useTilePipeline == TilePassPipeline::ConvDecomposePipeline) ? true
: false;
dmaOptions.unpackTransposeOnSource = true;
passManager.addPass(createAMDAIEConvertToDmaPass(dmaOptions));
}

passManager.addPass(createAMDAIENormalizeLoopBoundsPass());
passManager.addPass(createAMDAIEInsertCoresPass());
Expand All @@ -570,6 +572,7 @@ void addAMDAIEObjectFifoLoweringPasses(OpPassManager &passManager,
passManager.addPass(createAMDAIEDistributeCoresAndObjectFifosPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());

passManager.addPass(createAMDAIESplitLogicalObjFifosForConnectionReusePass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
Expand Down
Loading

0 comments on commit f16b450

Please sign in to comment.