Skip to content

Commit

Permalink
MLIR: Enable importing inlineasm calls (llvm#121624)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 5, 2025
1 parent f48884d commit b5f2167
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 56 deletions.
6 changes: 5 additions & 1 deletion mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,13 @@ class ModuleImport {
/// Appends the converted result type and operands of `callInst` to the
/// `types` and `operands` arrays. For indirect calls, the method additionally
/// inserts the called function at the beginning of the `operands` array.
/// If `allowInlineAsm` is set to false (the default), it will return failure
/// if the called operand is an inline asm which isn't convertible to MLIR as
/// a value.
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
SmallVectorImpl<Type> &types,
SmallVectorImpl<Value> &operands);
SmallVectorImpl<Value> &operands,
bool allowInlineAsm = false);
/// Converts the parameter attributes attached to `func` and adds them to the
/// `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
Expand Down
109 changes: 63 additions & 46 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1473,18 +1473,20 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
return success();
}

LogicalResult
ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
SmallVectorImpl<Type> &types,
SmallVectorImpl<Value> &operands) {
LogicalResult ModuleImport::convertCallTypeAndOperands(
llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));

if (!callInst->getCalledFunction()) {
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
if (failed(called))
return failure();
operands.push_back(*called);
if (!allowInlineAsm ||
!isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
if (failed(called))
return failure();
operands.push_back(*called);
}
}
SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
Expand Down Expand Up @@ -1579,53 +1581,68 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {

SmallVector<Type> types;
SmallVector<Value> operands;
if (failed(convertCallTypeAndOperands(callInst, types, operands)))
if (failed(convertCallTypeAndOperands(callInst, types, operands,
/*allowInlineAsm=*/true)))
return failure();

auto funcTy =
dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
if (!funcTy)
return failure();

CallOp callOp;

if (llvm::Function *callee = callInst->getCalledFunction()) {
callOp = builder.create<CallOp>(
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
operands);
if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
auto callOp = builder.create<InlineAsmOp>(
loc, funcTy.getReturnType(), operands,
builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
/*operand_attrs=*/nullptr);
if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult(0));
else
mapNoResultOp(inst, callOp);
} else {
callOp = builder.create<CallOp>(loc, funcTy, operands);
CallOp callOp;

if (llvm::Function *callee = callInst->getCalledFunction()) {
callOp = builder.create<CallOp>(
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
operands);
} else {
callOp = builder.create<CallOp>(loc, funcTy, operands);
}
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);

// Handle function attributes.
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
callOp.setConvergent(true);
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
callOp.setNoUnwind(true);
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
callOp.setWillReturn(true);

llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
ModRefInfo argMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
argMem, inaccessibleMem);
// Only set the attribute when it does not match the default value.
if (!memAttr.isReadWrite())
callOp.setMemoryEffectsAttr(memAttr);

if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
else
mapNoResultOp(inst, callOp);
}
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);

// Handle function attributes.
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
callOp.setConvergent(true);
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
callOp.setNoUnwind(true);
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
callOp.setWillReturn(true);

llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
ModRefInfo argMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
inaccessibleMem);
// Only set the attribute when it does not match the default value.
if (!memAttr.isReadWrite())
callOp.setMemoryEffectsAttr(memAttr);

if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
else
mapNoResultOp(inst, callOp);
return success();
}
if (inst->getOpcode() == llvm::Instruction::LandingPad) {
Expand Down
9 changes: 0 additions & 9 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,6 @@ bb2:

; // -----

; CHECK: <unknown>
; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
define i32 @unhandled_value(i32 %arg1) {
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
ret i32 %1
}

; // -----

; CHECK: <unknown>
; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
; CHECK: <unknown>
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,17 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {

; // -----

; CHECK-LABEL: @inlineasm
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
define i32 @inlineasm(i32 %arg1) {
; CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects "bswap $0", "=r,r" %[[ARG1]] : (i32) -> i32
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
; CHECK: return %[[RES]]
ret i32 %1
}

; // -----

; CHECK-LABEL: @gep_static_idx
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
define void @gep_static_idx(ptr %ptr) {
Expand Down

0 comments on commit b5f2167

Please sign in to comment.