From b5f21671ef04984bc00770263234dfb94833a274 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 5 Jan 2025 11:02:49 -0500 Subject: [PATCH] MLIR: Enable importing inlineasm calls (#121624) --- .../include/mlir/Target/LLVMIR/ModuleImport.h | 6 +- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 109 ++++++++++-------- .../Target/LLVMIR/Import/import-failure.ll | 9 -- .../test/Target/LLVMIR/Import/instructions.ll | 11 ++ 4 files changed, 79 insertions(+), 56 deletions(-) diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index eea0647895b01b..33c9af7c6335a4 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -319,9 +319,13 @@ class ModuleImport { /// Appends the converted result type and operands of `callInst` to the /// `types` and `operands` arrays. For indirect calls, the method additionally /// inserts the called function at the beginning of the `operands` array. + /// If `allowInlineAsm` is set to false (the default), it will return failure + /// if the called operand is an inline asm which isn't convertible to MLIR as + /// a value. LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst, SmallVectorImpl &types, - SmallVectorImpl &operands); + SmallVectorImpl &operands, + bool allowInlineAsm = false); /// Converts the parameter attributes attached to `func` and adds them to the /// `funcOp`. void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index b0d5e635248d3f..95fb673fc72e39 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1473,18 +1473,20 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch, return success(); } -LogicalResult -ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst, - SmallVectorImpl &types, - SmallVectorImpl &operands) { +LogicalResult ModuleImport::convertCallTypeAndOperands( + llvm::CallBase *callInst, SmallVectorImpl &types, + SmallVectorImpl &operands, bool allowInlineAsm) { if (!callInst->getType()->isVoidTy()) types.push_back(convertType(callInst->getType())); if (!callInst->getCalledFunction()) { - FailureOr called = convertValue(callInst->getCalledOperand()); - if (failed(called)) - return failure(); - operands.push_back(*called); + if (!allowInlineAsm || + !isa(callInst->getCalledOperand())) { + FailureOr called = convertValue(callInst->getCalledOperand()); + if (failed(called)) + return failure(); + operands.push_back(*called); + } } SmallVector args(callInst->args()); FailureOr> arguments = convertValues(args); @@ -1579,7 +1581,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { SmallVector types; SmallVector operands; - if (failed(convertCallTypeAndOperands(callInst, types, operands))) + if (failed(convertCallTypeAndOperands(callInst, types, operands, + /*allowInlineAsm=*/true))) return failure(); auto funcTy = @@ -1587,45 +1590,59 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (!funcTy) return failure(); - CallOp callOp; - - if (llvm::Function *callee = callInst->getCalledFunction()) { - callOp = builder.create( - loc, funcTy, SymbolRefAttr::get(context, callee->getName()), - operands); + if (auto asmI = dyn_cast(callInst->getCalledOperand())) { + auto callOp = builder.create( + loc, funcTy.getReturnType(), operands, + builder.getStringAttr(asmI->getAsmString()), + builder.getStringAttr(asmI->getConstraintString()), + /*has_side_effects=*/true, + /*is_align_stack=*/false, /*asm_dialect=*/nullptr, + /*operand_attrs=*/nullptr); + if (!callInst->getType()->isVoidTy()) + mapValue(inst, callOp.getResult(0)); + else + mapNoResultOp(inst, callOp); } else { - callOp = builder.create(loc, funcTy, operands); + CallOp callOp; + + if (llvm::Function *callee = callInst->getCalledFunction()) { + callOp = builder.create( + loc, funcTy, SymbolRefAttr::get(context, callee->getName()), + operands); + } else { + callOp = builder.create(loc, funcTy, operands); + } + callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv())); + callOp.setTailCallKind( + convertTailCallKindFromLLVM(callInst->getTailCallKind())); + setFastmathFlagsAttr(inst, callOp); + + // Handle function attributes. + if (callInst->hasFnAttr(llvm::Attribute::Convergent)) + callOp.setConvergent(true); + if (callInst->hasFnAttr(llvm::Attribute::NoUnwind)) + callOp.setNoUnwind(true); + if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) + callOp.setWillReturn(true); + + llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); + ModRefInfo othermem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::Other)); + ModRefInfo argMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); + ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); + auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, + argMem, inaccessibleMem); + // Only set the attribute when it does not match the default value. + if (!memAttr.isReadWrite()) + callOp.setMemoryEffectsAttr(memAttr); + + if (!callInst->getType()->isVoidTy()) + mapValue(inst, callOp.getResult()); + else + mapNoResultOp(inst, callOp); } - callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv())); - callOp.setTailCallKind( - convertTailCallKindFromLLVM(callInst->getTailCallKind())); - setFastmathFlagsAttr(inst, callOp); - - // Handle function attributes. - if (callInst->hasFnAttr(llvm::Attribute::Convergent)) - callOp.setConvergent(true); - if (callInst->hasFnAttr(llvm::Attribute::NoUnwind)) - callOp.setNoUnwind(true); - if (callInst->hasFnAttr(llvm::Attribute::WillReturn)) - callOp.setWillReturn(true); - - llvm::MemoryEffects memEffects = callInst->getMemoryEffects(); - ModRefInfo othermem = convertModRefInfoFromLLVM( - memEffects.getModRef(llvm::MemoryEffects::Location::Other)); - ModRefInfo argMem = convertModRefInfoFromLLVM( - memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); - ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM( - memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); - auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem, - inaccessibleMem); - // Only set the attribute when it does not match the default value. - if (!memAttr.isReadWrite()) - callOp.setMemoryEffectsAttr(memAttr); - - if (!callInst->getType()->isVoidTy()) - mapValue(inst, callOp.getResult()); - else - mapNoResultOp(inst, callOp); return success(); } if (inst->getOpcode() == llvm::Instruction::LandingPad) { diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index 6bde174642d540..b616cb81e0a8a5 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -12,15 +12,6 @@ bb2: ; // ----- -; CHECK: -; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r" -define i32 @unhandled_value(i32 %arg1) { - %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1) - ret i32 %1 -} - -; // ----- - ; CHECK: ; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported ; CHECK: diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index fff48bbc486bc1..7377e2584110b5 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -535,6 +535,17 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) { ; // ----- +; CHECK-LABEL: @inlineasm +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +define i32 @inlineasm(i32 %arg1) { + ; CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects "bswap $0", "=r,r" %[[ARG1]] : (i32) -> i32 + %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1) + ; CHECK: return %[[RES]] + ret i32 %1 +} + +; // ----- + ; CHECK-LABEL: @gep_static_idx ; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]] define void @gep_static_idx(ptr %ptr) {