Skip to content

Commit c19436e

Browse files
authored
[mlir][spirv] Fix a crash of typeConverter with non supported type (#79955)
Fixes a crash in the `convert-to-spirv-llvm` pass caused by unsupported types (e.g. `spirv.matrix` ). This PR fixes it by checking the converted type. Fixes #60017
1 parent ff4636a commit c19436e

File tree

1 file changed

+64
-55
lines changed

1 file changed

+64
-55
lines changed

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

+64-55
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
240240
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241241
auto dstType = typeConverter.convertType(loadOp.getType());
242242
if (!dstType)
243-
return failure();
243+
return rewriter.notifyMatchFailure(op, "type conversion failed");
244244
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
245245
loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
246246
isVolatile, isNonTemporal);
@@ -357,22 +357,23 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
357357
ConversionPatternRewriter &rewriter) const override {
358358
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
359359
if (!dstType)
360-
return failure();
360+
return rewriter.notifyMatchFailure(op, "type conversion failed");
361361
// To use GEP we need to add a first 0 index to go through the pointer.
362362
auto indices = llvm::to_vector<4>(adaptor.getIndices());
363363
Type indexType = op.getIndices().front().getType();
364364
auto llvmIndexType = typeConverter.convertType(indexType);
365365
if (!llvmIndexType)
366-
return failure();
366+
return rewriter.notifyMatchFailure(op, "type conversion failed");
367367
Value zero = rewriter.create<LLVM::ConstantOp>(
368368
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
369369
indices.insert(indices.begin(), zero);
370-
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
371-
op, dstType,
372-
typeConverter.convertType(
373-
cast<spirv::PointerType>(op.getBasePtr().getType())
374-
.getPointeeType()),
375-
adaptor.getBasePtr(), indices);
370+
371+
auto elementType = typeConverter.convertType(
372+
cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
373+
if (!elementType)
374+
return rewriter.notifyMatchFailure(op, "type conversion failed");
375+
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
376+
adaptor.getBasePtr(), indices);
376377
return success();
377378
}
378379
};
@@ -386,7 +387,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
386387
ConversionPatternRewriter &rewriter) const override {
387388
auto dstType = typeConverter.convertType(op.getPointer().getType());
388389
if (!dstType)
389-
return failure();
390+
return rewriter.notifyMatchFailure(op, "type conversion failed");
390391
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
391392
op.getVariable());
392393
return success();
@@ -404,7 +405,7 @@ class BitFieldInsertPattern
404405
auto srcType = op.getType();
405406
auto dstType = typeConverter.convertType(srcType);
406407
if (!dstType)
407-
return failure();
408+
return rewriter.notifyMatchFailure(op, "type conversion failed");
408409
Location loc = op.getLoc();
409410

410411
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -451,7 +452,7 @@ class ConstantScalarAndVectorPattern
451452

452453
auto dstType = typeConverter.convertType(srcType);
453454
if (!dstType)
454-
return failure();
455+
return rewriter.notifyMatchFailure(constOp, "type conversion failed");
455456

456457
// SPIR-V constant can be a signed/unsigned integer, which has to be
457458
// casted to signless integer when converting to LLVM dialect. Removing the
@@ -492,7 +493,7 @@ class BitFieldSExtractPattern
492493
auto srcType = op.getType();
493494
auto dstType = typeConverter.convertType(srcType);
494495
if (!dstType)
495-
return failure();
496+
return rewriter.notifyMatchFailure(op, "type conversion failed");
496497
Location loc = op.getLoc();
497498

498499
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -545,7 +546,7 @@ class BitFieldUExtractPattern
545546
auto srcType = op.getType();
546547
auto dstType = typeConverter.convertType(srcType);
547548
if (!dstType)
548-
return failure();
549+
return rewriter.notifyMatchFailure(op, "type conversion failed");
549550
Location loc = op.getLoc();
550551

551552
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -621,7 +622,7 @@ class CompositeExtractPattern
621622
ConversionPatternRewriter &rewriter) const override {
622623
auto dstType = this->typeConverter.convertType(op.getType());
623624
if (!dstType)
624-
return failure();
625+
return rewriter.notifyMatchFailure(op, "type conversion failed");
625626

626627
Type containerType = op.getComposite().getType();
627628
if (isa<VectorType>(containerType)) {
@@ -653,7 +654,7 @@ class CompositeInsertPattern
653654
ConversionPatternRewriter &rewriter) const override {
654655
auto dstType = this->typeConverter.convertType(op.getType());
655656
if (!dstType)
656-
return failure();
657+
return rewriter.notifyMatchFailure(op, "type conversion failed");
657658

658659
Type containerType = op.getComposite().getType();
659660
if (isa<VectorType>(containerType)) {
@@ -680,13 +681,13 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
680681
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
681682

682683
LogicalResult
683-
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
684+
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
684685
ConversionPatternRewriter &rewriter) const override {
685-
auto dstType = this->typeConverter.convertType(operation.getType());
686+
auto dstType = this->typeConverter.convertType(op.getType());
686687
if (!dstType)
687-
return failure();
688+
return rewriter.notifyMatchFailure(op, "type conversion failed");
688689
rewriter.template replaceOpWithNewOp<LLVMOp>(
689-
operation, dstType, adaptor.getOperands(), operation->getAttrs());
690+
op, dstType, adaptor.getOperands(), op->getAttrs());
690691
return success();
691692
}
692693
};
@@ -790,7 +791,7 @@ class GlobalVariablePattern
790791
auto srcType = cast<spirv::PointerType>(op.getType());
791792
auto dstType = typeConverter.convertType(srcType.getPointeeType());
792793
if (!dstType)
793-
return failure();
794+
return rewriter.notifyMatchFailure(op, "type conversion failed");
794795

795796
// Limit conversion to the current invocation only or `StorageBuffer`
796797
// required by SPIR-V runner.
@@ -843,23 +844,23 @@ class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
843844
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
844845

845846
LogicalResult
846-
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
847+
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
847848
ConversionPatternRewriter &rewriter) const override {
848849

849-
Type fromType = operation.getOperand().getType();
850-
Type toType = operation.getType();
850+
Type fromType = op.getOperand().getType();
851+
Type toType = op.getType();
851852

852853
auto dstType = this->typeConverter.convertType(toType);
853854
if (!dstType)
854-
return failure();
855+
return rewriter.notifyMatchFailure(op, "type conversion failed");
855856

856857
if (getBitWidth(fromType) < getBitWidth(toType)) {
857-
rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
858+
rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
858859
adaptor.getOperands());
859860
return success();
860861
}
861862
if (getBitWidth(fromType) > getBitWidth(toType)) {
862-
rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
863+
rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
863864
adaptor.getOperands());
864865
return success();
865866
}
@@ -883,6 +884,8 @@ class FunctionCallPattern
883884

884885
// Function returns a single result.
885886
auto dstType = typeConverter.convertType(callOp.getType(0));
887+
if (!dstType)
888+
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
886889
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
887890
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
888891
return success();
@@ -896,16 +899,15 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
896899
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
897900

898901
LogicalResult
899-
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
902+
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
900903
ConversionPatternRewriter &rewriter) const override {
901904

902-
auto dstType = this->typeConverter.convertType(operation.getType());
905+
auto dstType = this->typeConverter.convertType(op.getType());
903906
if (!dstType)
904-
return failure();
907+
return rewriter.notifyMatchFailure(op, "type conversion failed");
905908

906909
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
907-
operation, dstType, predicate, operation.getOperand1(),
908-
operation.getOperand2());
910+
op, dstType, predicate, op.getOperand1(), op.getOperand2());
909911
return success();
910912
}
911913
};
@@ -917,16 +919,15 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
917919
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
918920

919921
LogicalResult
920-
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
922+
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
921923
ConversionPatternRewriter &rewriter) const override {
922924

923-
auto dstType = this->typeConverter.convertType(operation.getType());
925+
auto dstType = this->typeConverter.convertType(op.getType());
924926
if (!dstType)
925-
return failure();
927+
return rewriter.notifyMatchFailure(op, "type conversion failed");
926928

927929
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
928-
operation, dstType, predicate, operation.getOperand1(),
929-
operation.getOperand2());
930+
op, dstType, predicate, op.getOperand1(), op.getOperand2());
930931
return success();
931932
}
932933
};
@@ -942,7 +943,7 @@ class InverseSqrtPattern
942943
auto srcType = op.getType();
943944
auto dstType = typeConverter.convertType(srcType);
944945
if (!dstType)
945-
return failure();
946+
return rewriter.notifyMatchFailure(op, "type conversion failed");
946947

947948
Location loc = op.getLoc();
948949
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
@@ -1000,7 +1001,7 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
10001001
auto srcType = notOp.getType();
10011002
auto dstType = this->typeConverter.convertType(srcType);
10021003
if (!dstType)
1003-
return failure();
1004+
return rewriter.notifyMatchFailure(notOp, "type conversion failed");
10041005

10051006
Location loc = notOp.getLoc();
10061007
IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
@@ -1226,18 +1227,18 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
12261227
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
12271228

12281229
LogicalResult
1229-
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1230+
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
12301231
ConversionPatternRewriter &rewriter) const override {
12311232

1232-
auto dstType = this->typeConverter.convertType(operation.getType());
1233+
auto dstType = this->typeConverter.convertType(op.getType());
12331234
if (!dstType)
1234-
return failure();
1235+
return rewriter.notifyMatchFailure(op, "type conversion failed");
12351236

1236-
Type op1Type = operation.getOperand1().getType();
1237-
Type op2Type = operation.getOperand2().getType();
1237+
Type op1Type = op.getOperand1().getType();
1238+
Type op2Type = op.getOperand2().getType();
12381239

12391240
if (op1Type == op2Type) {
1240-
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1241+
rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
12411242
adaptor.getOperands());
12421243
return success();
12431244
}
@@ -1250,7 +1251,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
12501251
if (!dstTypeWidth || !op2TypeWidth)
12511252
return failure();
12521253

1253-
Location loc = operation.getLoc();
1254+
Location loc = op.getLoc();
12541255
Value extended;
12551256
if (op2TypeWidth < dstTypeWidth) {
12561257
if (isUnsignedIntegerOrVector(op2Type)) {
@@ -1268,7 +1269,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
12681269

12691270
Value result = rewriter.template create<LLVMOp>(
12701271
loc, dstType, adaptor.getOperand1(), extended);
1271-
rewriter.replaceOp(operation, result);
1272+
rewriter.replaceOp(op, result);
12721273
return success();
12731274
}
12741275
};
@@ -1282,7 +1283,7 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
12821283
ConversionPatternRewriter &rewriter) const override {
12831284
auto dstType = typeConverter.convertType(tanOp.getType());
12841285
if (!dstType)
1285-
return failure();
1286+
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
12861287

12871288
Location loc = tanOp.getLoc();
12881289
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
@@ -1308,7 +1309,7 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
13081309
auto srcType = tanhOp.getType();
13091310
auto dstType = typeConverter.convertType(srcType);
13101311
if (!dstType)
1311-
return failure();
1312+
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
13121313

13131314
Location loc = tanhOp.getLoc();
13141315
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
@@ -1342,17 +1343,23 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
13421343

13431344
auto dstType = typeConverter.convertType(srcType);
13441345
if (!dstType)
1345-
return failure();
1346+
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
13461347

13471348
Location loc = varOp.getLoc();
13481349
Value size = createI32ConstantOf(loc, rewriter, 1);
13491350
if (!init) {
1350-
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
1351-
varOp, dstType, typeConverter.convertType(pointerTo), size);
1351+
auto elementType = typeConverter.convertType(pointerTo);
1352+
if (!elementType)
1353+
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1354+
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1355+
size);
13521356
return success();
13531357
}
1354-
Value allocated = rewriter.create<LLVM::AllocaOp>(
1355-
loc, dstType, typeConverter.convertType(pointerTo), size);
1358+
auto elementType = typeConverter.convertType(pointerTo);
1359+
if (!elementType)
1360+
return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1361+
Value allocated =
1362+
rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
13561363
rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
13571364
rewriter.replaceOp(varOp, allocated);
13581365
return success();
@@ -1373,7 +1380,7 @@ class BitcastConversionPattern
13731380
ConversionPatternRewriter &rewriter) const override {
13741381
auto dstType = typeConverter.convertType(bitcastOp.getType());
13751382
if (!dstType)
1376-
return failure();
1383+
return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
13771384

13781385
// LLVM's opaque pointers do not require bitcasts.
13791386
if (isa<LLVM::LLVMPointerType>(dstType)) {
@@ -1499,6 +1506,8 @@ class VectorShufflePattern
14991506
}
15001507

15011508
auto dstType = typeConverter.convertType(op.getType());
1509+
if (!dstType)
1510+
return rewriter.notifyMatchFailure(op, "type conversion failed");
15021511
auto scalarType = cast<VectorType>(dstType).getElementType();
15031512
auto componentsArray = components.getValue();
15041513
auto *context = rewriter.getContext();

0 commit comments

Comments
 (0)