Skip to content

Commit

Permalink
Replaced dyn_cast in lib/Dialect/AIEVec
Browse files Browse the repository at this point in the history
  • Loading branch information
abisca committed Feb 7, 2024
1 parent bbc3717 commit af4be2f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions lib/Dialect/AIEVec/IR/AIEVecOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1171,14 +1171,14 @@ ConcatOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
adaptor.getSources().end());
unsigned totalLength = 0;
for (auto source : srcs) {
VectorType type = llvm::dyn_cast<VectorType>(source.getType());
VectorType type = llvm::cast<VectorType>(source.getType());
assert(type.getRank() == 1 &&
"only rank 1 vectors currently supported by concat");
totalLength += type.getDimSize(0);
}
inferredReturnTypes.push_back(VectorType::get(
{totalLength},
srcs[0].getType().dyn_cast<VectorType>().getElementType()));
srcs[0].getType().cast<VectorType>().getElementType()));
return success();
}

Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/AIEVec/Transforms/AIEVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2185,17 +2185,17 @@ static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) {
// lhs of current FMAOp should be an upd operation with 512-bit vector width.
// For AIE-ML, we can directly load 512 bits vectors. Thus, we can delete the
// upd operation with index 1.
auto lUpdOp = dyn_cast<aievec::UPDOp>(lhs.getDefiningOp());
auto lUpdOp = cast<aievec::UPDOp>(lhs.getDefiningOp());
if (lUpdOp.getIndex() == 1) {
auto lUpdOp0 = dyn_cast<aievec::UPDOp>(lUpdOp.getVector().getDefiningOp());
auto lUpdOp0 = cast<aievec::UPDOp>(lUpdOp.getVector().getDefiningOp());
lUpdOp->replaceAllUsesWith(lUpdOp0);
lUpdOp->erase();
}

// 2. Deal with the rhs:
// Since vector size of current FMAOp rhs is 256 bits, we need to generate a
// concat op to make the vector size to 512 bits.
auto rUpdOp = dyn_cast<aievec::UPDOp>(curOp->getOperand(1).getDefiningOp());
auto rUpdOp = cast<aievec::UPDOp>(curOp->getOperand(1).getDefiningOp());
state->builder.setInsertionPointAfter(rUpdOp);
AIEVecAttributes rstat = getOperandVecStats(curOp, state, 1);
assert(rstat.vecSizeInBits % 256 == 0);
Expand All @@ -2212,7 +2212,7 @@ static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) {
Operation *convOp = nullptr;
Operation *mulOrFMAOp = Op->getOperand(2).getDefiningOp();
auto mulOp = dyn_cast<aievec::MulOp>(mulOrFMAOp);
auto fmaOp = dyn_cast<aievec::FMAOp>(mulOrFMAOp);
auto fmaOp = cast<aievec::FMAOp>(mulOrFMAOp);
int32_t zStart;

if (mulOp) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/AIEVec/Transforms/FoldMulAddChainToConvOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ struct LongestConvMACChainAnalysis {
isa<aievec::ExtOp>(opBwdSlices[sliceSz - 1]))) {
convMacRhs = opBwdSlices[sliceSz - 3]->getOperand(0);
convMacBcastIdx =
dyn_cast<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]).getIdx();
cast<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]).getIdx();
return true;
}
}
Expand Down
14 changes: 7 additions & 7 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
"shiftIndex must be power of 2");

Location loc = srcOp.getLoc();
auto vType = dyn_cast<VectorType>(curValue.getType());
auto vType = cast<VectorType>(curValue.getType());
Type scalarType = vType.getElementType();
Type vecType = curValue.getType();
DstOpTy curOp = nullptr;
Expand Down Expand Up @@ -1341,7 +1341,7 @@ struct LowerVectorCmpOpToAIEVecCmpOp : OpConversionPattern<SrcOpTy> {
if (!aieCmpOp)
return failure();

VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
VectorType resultType = cast<VectorType>(srcOp.getResult().getType());
// Convert vector i1 type to unsigned interger type by built-in unrealized
// conversion cast op.
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
Expand Down Expand Up @@ -1571,7 +1571,7 @@ struct LowerVectorReductionAddBfloat16Op
Location loc = srcOp.getLoc();
Type accType = getVectorOpDestType(vType, /*AIEML =*/true);
unsigned accWidth =
dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();

auto upsOp =
rewriter.create<aievec::UPSOp>(loc, accType, srcOp.getVector());
Expand Down Expand Up @@ -2026,8 +2026,8 @@ struct LowerExtOpPattern : OpConversionPattern<SrcOpTy> {
LogicalResult
matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
VectorType srcType = cast<VectorType>(extOp.getIn().getType());
VectorType dstType = cast<VectorType>(extOp.getOut().getType());

auto accType = getVectorOpDestType(srcType, /*AIEML =*/true);
auto upsOp =
Expand Down Expand Up @@ -2057,8 +2057,8 @@ struct LowerTruncOpPattern : OpConversionPattern<SrcOpTy> {
LogicalResult
matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
VectorType srcType = cast<VectorType>(truncOp.getIn().getType());
VectorType dstType = cast<VectorType>(truncOp.getOut().getType());
Type scalarType = srcType.getElementType();
unsigned elWidth = scalarType.getIntOrFloatBitWidth();

Expand Down

0 comments on commit af4be2f

Please sign in to comment.