@@ -240,7 +240,7 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
240
240
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241
241
auto dstType = typeConverter.convertType (loadOp.getType ());
242
242
if (!dstType)
243
- return failure ( );
243
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
244
244
rewriter.replaceOpWithNewOp <LLVM::LoadOp>(
245
245
loadOp, dstType, spirv::LoadOpAdaptor (operands).getPtr (), alignment,
246
246
isVolatile, isNonTemporal);
@@ -357,22 +357,23 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
357
357
ConversionPatternRewriter &rewriter) const override {
358
358
auto dstType = typeConverter.convertType (op.getComponentPtr ().getType ());
359
359
if (!dstType)
360
- return failure ( );
360
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
361
361
// To use GEP we need to add a first 0 index to go through the pointer.
362
362
auto indices = llvm::to_vector<4 >(adaptor.getIndices ());
363
363
Type indexType = op.getIndices ().front ().getType ();
364
364
auto llvmIndexType = typeConverter.convertType (indexType);
365
365
if (!llvmIndexType)
366
- return failure ( );
366
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
367
367
Value zero = rewriter.create <LLVM::ConstantOp>(
368
368
op.getLoc (), llvmIndexType, rewriter.getIntegerAttr (indexType, 0 ));
369
369
indices.insert (indices.begin (), zero);
370
- rewriter.replaceOpWithNewOp <LLVM::GEPOp>(
371
- op, dstType,
372
- typeConverter.convertType (
373
- cast<spirv::PointerType>(op.getBasePtr ().getType ())
374
- .getPointeeType ()),
375
- adaptor.getBasePtr (), indices);
370
+
371
+ auto elementType = typeConverter.convertType (
372
+ cast<spirv::PointerType>(op.getBasePtr ().getType ()).getPointeeType ());
373
+ if (!elementType)
374
+ return rewriter.notifyMatchFailure (op, " type conversion failed" );
375
+ rewriter.replaceOpWithNewOp <LLVM::GEPOp>(op, dstType, elementType,
376
+ adaptor.getBasePtr (), indices);
376
377
return success ();
377
378
}
378
379
};
@@ -386,7 +387,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
386
387
ConversionPatternRewriter &rewriter) const override {
387
388
auto dstType = typeConverter.convertType (op.getPointer ().getType ());
388
389
if (!dstType)
389
- return failure ( );
390
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
390
391
rewriter.replaceOpWithNewOp <LLVM::AddressOfOp>(op, dstType,
391
392
op.getVariable ());
392
393
return success ();
@@ -404,7 +405,7 @@ class BitFieldInsertPattern
404
405
auto srcType = op.getType ();
405
406
auto dstType = typeConverter.convertType (srcType);
406
407
if (!dstType)
407
- return failure ( );
408
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
408
409
Location loc = op.getLoc ();
409
410
410
411
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -451,7 +452,7 @@ class ConstantScalarAndVectorPattern
451
452
452
453
auto dstType = typeConverter.convertType (srcType);
453
454
if (!dstType)
454
- return failure ( );
455
+ return rewriter. notifyMatchFailure (constOp, " type conversion failed " );
455
456
456
457
// SPIR-V constant can be a signed/unsigned integer, which has to be
457
458
// casted to signless integer when converting to LLVM dialect. Removing the
@@ -492,7 +493,7 @@ class BitFieldSExtractPattern
492
493
auto srcType = op.getType ();
493
494
auto dstType = typeConverter.convertType (srcType);
494
495
if (!dstType)
495
- return failure ( );
496
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
496
497
Location loc = op.getLoc ();
497
498
498
499
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -545,7 +546,7 @@ class BitFieldUExtractPattern
545
546
auto srcType = op.getType ();
546
547
auto dstType = typeConverter.convertType (srcType);
547
548
if (!dstType)
548
- return failure ( );
549
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
549
550
Location loc = op.getLoc ();
550
551
551
552
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
@@ -621,7 +622,7 @@ class CompositeExtractPattern
621
622
ConversionPatternRewriter &rewriter) const override {
622
623
auto dstType = this ->typeConverter .convertType (op.getType ());
623
624
if (!dstType)
624
- return failure ( );
625
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
625
626
626
627
Type containerType = op.getComposite ().getType ();
627
628
if (isa<VectorType>(containerType)) {
@@ -653,7 +654,7 @@ class CompositeInsertPattern
653
654
ConversionPatternRewriter &rewriter) const override {
654
655
auto dstType = this ->typeConverter .convertType (op.getType ());
655
656
if (!dstType)
656
- return failure ( );
657
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
657
658
658
659
Type containerType = op.getComposite ().getType ();
659
660
if (isa<VectorType>(containerType)) {
@@ -680,13 +681,13 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
680
681
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
681
682
682
683
LogicalResult
683
- matchAndRewrite (SPIRVOp operation , typename SPIRVOp::Adaptor adaptor,
684
+ matchAndRewrite (SPIRVOp op , typename SPIRVOp::Adaptor adaptor,
684
685
ConversionPatternRewriter &rewriter) const override {
685
- auto dstType = this ->typeConverter .convertType (operation .getType ());
686
+ auto dstType = this ->typeConverter .convertType (op .getType ());
686
687
if (!dstType)
687
- return failure ( );
688
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
688
689
rewriter.template replaceOpWithNewOp <LLVMOp>(
689
- operation , dstType, adaptor.getOperands (), operation ->getAttrs ());
690
+ op , dstType, adaptor.getOperands (), op ->getAttrs ());
690
691
return success ();
691
692
}
692
693
};
@@ -790,7 +791,7 @@ class GlobalVariablePattern
790
791
auto srcType = cast<spirv::PointerType>(op.getType ());
791
792
auto dstType = typeConverter.convertType (srcType.getPointeeType ());
792
793
if (!dstType)
793
- return failure ( );
794
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
794
795
795
796
// Limit conversion to the current invocation only or `StorageBuffer`
796
797
// required by SPIR-V runner.
@@ -843,23 +844,23 @@ class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
843
844
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
844
845
845
846
LogicalResult
846
- matchAndRewrite (SPIRVOp operation , typename SPIRVOp::Adaptor adaptor,
847
+ matchAndRewrite (SPIRVOp op , typename SPIRVOp::Adaptor adaptor,
847
848
ConversionPatternRewriter &rewriter) const override {
848
849
849
- Type fromType = operation .getOperand ().getType ();
850
- Type toType = operation .getType ();
850
+ Type fromType = op .getOperand ().getType ();
851
+ Type toType = op .getType ();
851
852
852
853
auto dstType = this ->typeConverter .convertType (toType);
853
854
if (!dstType)
854
- return failure ( );
855
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
855
856
856
857
if (getBitWidth (fromType) < getBitWidth (toType)) {
857
- rewriter.template replaceOpWithNewOp <LLVMExtOp>(operation , dstType,
858
+ rewriter.template replaceOpWithNewOp <LLVMExtOp>(op , dstType,
858
859
adaptor.getOperands ());
859
860
return success ();
860
861
}
861
862
if (getBitWidth (fromType) > getBitWidth (toType)) {
862
- rewriter.template replaceOpWithNewOp <LLVMTruncOp>(operation , dstType,
863
+ rewriter.template replaceOpWithNewOp <LLVMTruncOp>(op , dstType,
863
864
adaptor.getOperands ());
864
865
return success ();
865
866
}
@@ -883,6 +884,8 @@ class FunctionCallPattern
883
884
884
885
// Function returns a single result.
885
886
auto dstType = typeConverter.convertType (callOp.getType (0 ));
887
+ if (!dstType)
888
+ return rewriter.notifyMatchFailure (callOp, " type conversion failed" );
886
889
rewriter.replaceOpWithNewOp <LLVM::CallOp>(
887
890
callOp, dstType, adaptor.getOperands (), callOp->getAttrs ());
888
891
return success ();
@@ -896,16 +899,15 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
896
899
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
897
900
898
901
LogicalResult
899
- matchAndRewrite (SPIRVOp operation , typename SPIRVOp::Adaptor adaptor,
902
+ matchAndRewrite (SPIRVOp op , typename SPIRVOp::Adaptor adaptor,
900
903
ConversionPatternRewriter &rewriter) const override {
901
904
902
- auto dstType = this ->typeConverter .convertType (operation .getType ());
905
+ auto dstType = this ->typeConverter .convertType (op .getType ());
903
906
if (!dstType)
904
- return failure ( );
907
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
905
908
906
909
rewriter.template replaceOpWithNewOp <LLVM::FCmpOp>(
907
- operation, dstType, predicate, operation.getOperand1 (),
908
- operation.getOperand2 ());
910
+ op, dstType, predicate, op.getOperand1 (), op.getOperand2 ());
909
911
return success ();
910
912
}
911
913
};
@@ -917,16 +919,15 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
917
919
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
918
920
919
921
LogicalResult
920
- matchAndRewrite (SPIRVOp operation , typename SPIRVOp::Adaptor adaptor,
922
+ matchAndRewrite (SPIRVOp op , typename SPIRVOp::Adaptor adaptor,
921
923
ConversionPatternRewriter &rewriter) const override {
922
924
923
- auto dstType = this ->typeConverter .convertType (operation .getType ());
925
+ auto dstType = this ->typeConverter .convertType (op .getType ());
924
926
if (!dstType)
925
- return failure ( );
927
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
926
928
927
929
rewriter.template replaceOpWithNewOp <LLVM::ICmpOp>(
928
- operation, dstType, predicate, operation.getOperand1 (),
929
- operation.getOperand2 ());
930
+ op, dstType, predicate, op.getOperand1 (), op.getOperand2 ());
930
931
return success ();
931
932
}
932
933
};
@@ -942,7 +943,7 @@ class InverseSqrtPattern
942
943
auto srcType = op.getType ();
943
944
auto dstType = typeConverter.convertType (srcType);
944
945
if (!dstType)
945
- return failure ( );
946
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
946
947
947
948
Location loc = op.getLoc ();
948
949
Value one = createFPConstant (loc, srcType, dstType, rewriter, 1.0 );
@@ -1000,7 +1001,7 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1000
1001
auto srcType = notOp.getType ();
1001
1002
auto dstType = this ->typeConverter .convertType (srcType);
1002
1003
if (!dstType)
1003
- return failure ( );
1004
+ return rewriter. notifyMatchFailure (notOp, " type conversion failed " );
1004
1005
1005
1006
Location loc = notOp.getLoc ();
1006
1007
IntegerAttr minusOne = minusOneIntegerAttribute (srcType, rewriter);
@@ -1226,18 +1227,18 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1226
1227
using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1227
1228
1228
1229
LogicalResult
1229
- matchAndRewrite (SPIRVOp operation , typename SPIRVOp::Adaptor adaptor,
1230
+ matchAndRewrite (SPIRVOp op , typename SPIRVOp::Adaptor adaptor,
1230
1231
ConversionPatternRewriter &rewriter) const override {
1231
1232
1232
- auto dstType = this ->typeConverter .convertType (operation .getType ());
1233
+ auto dstType = this ->typeConverter .convertType (op .getType ());
1233
1234
if (!dstType)
1234
- return failure ( );
1235
+ return rewriter. notifyMatchFailure (op, " type conversion failed " );
1235
1236
1236
- Type op1Type = operation .getOperand1 ().getType ();
1237
- Type op2Type = operation .getOperand2 ().getType ();
1237
+ Type op1Type = op .getOperand1 ().getType ();
1238
+ Type op2Type = op .getOperand2 ().getType ();
1238
1239
1239
1240
if (op1Type == op2Type) {
1240
- rewriter.template replaceOpWithNewOp <LLVMOp>(operation , dstType,
1241
+ rewriter.template replaceOpWithNewOp <LLVMOp>(op , dstType,
1241
1242
adaptor.getOperands ());
1242
1243
return success ();
1243
1244
}
@@ -1250,7 +1251,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1250
1251
if (!dstTypeWidth || !op2TypeWidth)
1251
1252
return failure ();
1252
1253
1253
- Location loc = operation .getLoc ();
1254
+ Location loc = op .getLoc ();
1254
1255
Value extended;
1255
1256
if (op2TypeWidth < dstTypeWidth) {
1256
1257
if (isUnsignedIntegerOrVector (op2Type)) {
@@ -1268,7 +1269,7 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1268
1269
1269
1270
Value result = rewriter.template create <LLVMOp>(
1270
1271
loc, dstType, adaptor.getOperand1 (), extended);
1271
- rewriter.replaceOp (operation , result);
1272
+ rewriter.replaceOp (op , result);
1272
1273
return success ();
1273
1274
}
1274
1275
};
@@ -1282,7 +1283,7 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1282
1283
ConversionPatternRewriter &rewriter) const override {
1283
1284
auto dstType = typeConverter.convertType (tanOp.getType ());
1284
1285
if (!dstType)
1285
- return failure ( );
1286
+ return rewriter. notifyMatchFailure (tanOp, " type conversion failed " );
1286
1287
1287
1288
Location loc = tanOp.getLoc ();
1288
1289
Value sin = rewriter.create <LLVM::SinOp>(loc, dstType, tanOp.getOperand ());
@@ -1308,7 +1309,7 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1308
1309
auto srcType = tanhOp.getType ();
1309
1310
auto dstType = typeConverter.convertType (srcType);
1310
1311
if (!dstType)
1311
- return failure ( );
1312
+ return rewriter. notifyMatchFailure (tanhOp, " type conversion failed " );
1312
1313
1313
1314
Location loc = tanhOp.getLoc ();
1314
1315
Value two = createFPConstant (loc, srcType, dstType, rewriter, 2.0 );
@@ -1342,17 +1343,23 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1342
1343
1343
1344
auto dstType = typeConverter.convertType (srcType);
1344
1345
if (!dstType)
1345
- return failure ( );
1346
+ return rewriter. notifyMatchFailure (varOp, " type conversion failed " );
1346
1347
1347
1348
Location loc = varOp.getLoc ();
1348
1349
Value size = createI32ConstantOf (loc, rewriter, 1 );
1349
1350
if (!init) {
1350
- rewriter.replaceOpWithNewOp <LLVM::AllocaOp>(
1351
- varOp, dstType, typeConverter.convertType (pointerTo), size);
1351
+ auto elementType = typeConverter.convertType (pointerTo);
1352
+ if (!elementType)
1353
+ return rewriter.notifyMatchFailure (varOp, " type conversion failed" );
1354
+ rewriter.replaceOpWithNewOp <LLVM::AllocaOp>(varOp, dstType, elementType,
1355
+ size);
1352
1356
return success ();
1353
1357
}
1354
- Value allocated = rewriter.create <LLVM::AllocaOp>(
1355
- loc, dstType, typeConverter.convertType (pointerTo), size);
1358
+ auto elementType = typeConverter.convertType (pointerTo);
1359
+ if (!elementType)
1360
+ return rewriter.notifyMatchFailure (varOp, " type conversion failed" );
1361
+ Value allocated =
1362
+ rewriter.create <LLVM::AllocaOp>(loc, dstType, elementType, size);
1356
1363
rewriter.create <LLVM::StoreOp>(loc, adaptor.getInitializer (), allocated);
1357
1364
rewriter.replaceOp (varOp, allocated);
1358
1365
return success ();
@@ -1373,7 +1380,7 @@ class BitcastConversionPattern
1373
1380
ConversionPatternRewriter &rewriter) const override {
1374
1381
auto dstType = typeConverter.convertType (bitcastOp.getType ());
1375
1382
if (!dstType)
1376
- return failure ( );
1383
+ return rewriter. notifyMatchFailure (bitcastOp, " type conversion failed " );
1377
1384
1378
1385
// LLVM's opaque pointers do not require bitcasts.
1379
1386
if (isa<LLVM::LLVMPointerType>(dstType)) {
@@ -1499,6 +1506,8 @@ class VectorShufflePattern
1499
1506
}
1500
1507
1501
1508
auto dstType = typeConverter.convertType (op.getType ());
1509
+ if (!dstType)
1510
+ return rewriter.notifyMatchFailure (op, " type conversion failed" );
1502
1511
auto scalarType = cast<VectorType>(dstType).getElementType ();
1503
1512
auto componentsArray = components.getValue ();
1504
1513
auto *context = rewriter.getContext ();
0 commit comments