Skip to content

Commit

Permalink
[CIR][CIRGen] Fix calling a function through a function pointer (llvm…
Browse files Browse the repository at this point in the history
…#467)

CIR codegen always casts the no-proto function pointer to `FuncOp`. But
the function pointer may be result of cir operations (f.e. `cir.load`).
As a result in such cases the function pointer sets to `nullptr`. That
leads to compilation error.
So this PR removes the unecessary cast to 'FuncOp' and resolves the
issue.
  • Loading branch information
YazZz1k authored and lanza committed Apr 17, 2024
1 parent 54d03d2 commit 6a4b4a3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 12 deletions.
21 changes: 19 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,9 +1145,26 @@ RValue CIRGenFunction::buildCall(clang::QualType CalleeType,
if (isa<FunctionNoProtoType>(FnType) || Chain) {
assert(!UnimplementedFeature::chainCalls());
assert(!UnimplementedFeature::addressSpace());
auto CalleeTy = getTypes().GetFunctionType(FnInfo);
// get non-variadic function type
CalleeTy = mlir::cir::FuncType::get(CalleeTy.getInputs(),
CalleeTy.getReturnType(), false);
auto CalleePtrTy =
mlir::cir::PointerType::get(builder.getContext(), CalleeTy);

auto *Fn = Callee.getFunctionPointer();
mlir::Value Addr;
if (auto funcOp = llvm::dyn_cast<mlir::cir::FuncOp>(Fn)) {
Addr = builder.create<mlir::cir::GetGlobalOp>(
getLoc(E->getSourceRange()),
mlir::cir::PointerType::get(builder.getContext(),
funcOp.getFunctionType()),
funcOp.getSymName());
} else {
Addr = Fn->getResult(0);
}

// Set no-proto function as callee.
auto Fn = llvm::dyn_cast<mlir::cir::FuncOp>(Callee.getFunctionPointer());
Fn = builder.createBitcast(Addr, CalleePtrTy).getDefiningOp();
Callee.setFunctionPointer(Fn);
}

Expand Down
12 changes: 6 additions & 6 deletions clang/test/CIR/CodeGen/agg-copy.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ void foo4(A* a1) {
A create() { A a; return a; }

// CHECK: cir.func {{.*@foo5}}
// CHECK: [[TMP0]] = cir.alloca !ty_22A22, cir.ptr <!ty_22A22>,
// CHECK: [[TMP1]] = cir.alloca !ty_22A22, cir.ptr <!ty_22A22>, ["tmp"] {alignment = 4 : i64}
// CHECK: [[TMP2]] = cir.call @create() : () -> !ty_22A22
// CHECK: [[TMP0:%.*]] = cir.alloca !ty_22A22, cir.ptr <!ty_22A22>,
// CHECK: [[TMP1:%.*]] = cir.alloca !ty_22A22, cir.ptr <!ty_22A22>, ["tmp"] {alignment = 4 : i64}
// CHECK: [[TMP2:%.*]] = cir.call @create() : () -> !ty_22A22
// CHECK: cir.store [[TMP2]], [[TMP1]] : !ty_22A22, cir.ptr <!ty_22A22>
// CHECK: cir.copy [[TMP1]] to [[TMP0]] : !cir.ptr<!ty_22A22>
void foo5() {
Expand All @@ -77,9 +77,9 @@ void foo5() {
void foo6(A* a1) {
A a2 = (*a1);
// CHECK: cir.func {{.*@foo6}}
// CHECK: [[TMP0]] = cir.alloca !cir.ptr<!ty_22A22>, cir.ptr <!cir.ptr<!ty_22A22>>, ["a1", init] {alignment = 8 : i64}
// CHECK: [[TMP1]] = cir.alloca !ty_22A22, cir.ptr <!ty_22A22>, ["a2", init] {alignment = 4 : i64}
// CHECK: [[TMP0:%.*]] = cir.alloca !cir.ptr<!ty_22A22>, cir.ptr <!cir.ptr<!ty_22A22>>, ["a1", init] {alignment = 8 : i64}
// CHECK: [[TMP1:%.*]] = cir.alloca !ty_22A22, cir.ptr <!ty_22A22>, ["a2", init] {alignment = 4 : i64}
// CHECK: cir.store %arg0, [[TMP0]] : !cir.ptr<!ty_22A22>, cir.ptr <!cir.ptr<!ty_22A22>>
// CHECK: [[TMP2]] = cir.load deref [[TMP0]] : cir.ptr <!cir.ptr<!ty_22A22>>, !cir.ptr<!ty_22A22>
// CHECK: [[TMP2:%.*]] = cir.load deref [[TMP0]] : cir.ptr <!cir.ptr<!ty_22A22>>, !cir.ptr<!ty_22A22>
// CHECK: cir.copy [[TMP2]] to [[TMP1]] : !cir.ptr<!ty_22A22>
}
11 changes: 11 additions & 0 deletions clang/test/CIR/CodeGen/no-proto-fun-ptr.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,14 @@ void check_noproto_ptr() {

void empty(void) {}

void buz() {
void (*func)();
(*func)();
}

// CHECK: cir.func no_proto @buz()
// CHECK: [[FNPTR_ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<!void (...)>>, cir.ptr <!cir.ptr<!cir.func<!void (...)>>>, ["func"] {alignment = 8 : i64}
// CHECK: [[FNPTR:%.*]] = cir.load deref [[FNPTR_ALLOC]] : cir.ptr <!cir.ptr<!cir.func<!void (...)>>>, !cir.ptr<!cir.func<!void (...)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, %1 : !cir.ptr<!cir.func<!void (...)>>), !cir.ptr<!cir.func<!void ()>>
// CHECK: cir.call [[CAST]]() : (!cir.ptr<!cir.func<!void ()>>) -> ()
// CHECK: cir.return
16 changes: 12 additions & 4 deletions clang/test/CIR/CodeGen/no-prototype.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ int test1(int x) {
int noProto2();
int test2(int x) {
return noProto2(x);
// CHECK: %{{.+}} = cir.call @noProto2(%{{[0-9]+}}) : (!s32i) -> !s32i
// CHECK: [[GGO:%.*]] = cir.get_global @noProto2 : cir.ptr <!cir.func<!s32i (!s32i)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, %3 : !cir.ptr<!cir.func<!s32i (!s32i)>>), !cir.ptr<!cir.func<!s32i (!s32i)>>
// CHECK: {{.*}} = cir.call [[CAST]](%{{[0-9]+}}) : (!cir.ptr<!cir.func<!s32i (!s32i)>>, !s32i) -> !s32i
}
int noProto2(int x) { return x; }
// CHECK: cir.func no_proto @noProto2(%arg0: !s32i {{.+}}) -> !s32i
Expand All @@ -49,7 +51,9 @@ int noProto3();
int test3(int x) {
// CHECK: cir.func @test3
return noProto3(x);
// CHECK: %{{.+}} = cir.call @noProto3(%{{[0-9]+}}) : (!s32i) -> !s32i
// CHECK: [[GGO:%.*]] = cir.get_global @noProto3 : cir.ptr <!cir.func<!s32i (...)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, [[GGO]] : !cir.ptr<!cir.func<!s32i (...)>>), !cir.ptr<!cir.func<!s32i (!s32i)>>
// CHECK: {{%.*}} = cir.call [[CAST]](%{{[0-9]+}}) : (!cir.ptr<!cir.func<!s32i (!s32i)>>, !s32i) -> !s32i
}


Expand All @@ -64,14 +68,18 @@ int noProto4() { return 0; }
// cir.func private no_proto @noProto4() -> !s32i
int test4(int x) {
return noProto4(x); // Even if we know the definition, this should compile.
// CHECK: %{{.+}} = cir.call @noProto4(%{{.+}}) : (!s32i) -> !s32i
// CHECK: [[GGO:%.*]] = cir.get_global @noProto4 : cir.ptr <!cir.func<!s32i ()>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, [[GGO]] : !cir.ptr<!cir.func<!s32i ()>>), !cir.ptr<!cir.func<!s32i (!s32i)>>
// CHECK: {{%.*}} = cir.call [[CAST]]({{%.*}}) : (!cir.ptr<!cir.func<!s32i (!s32i)>>, !s32i) -> !s32i
}

// No-proto definition followed by an incorrect call due to lack of args.
int noProto5();
int test5(int x) {
return noProto5();
// CHECK: %{{.+}} = cir.call @noProto5() : () -> !s32i
// CHECK: [[GGO:%.*]] = cir.get_global @noProto5 : cir.ptr <!cir.func<!s32i (!s32i)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, [[GGO]] : !cir.ptr<!cir.func<!s32i (!s32i)>>), !cir.ptr<!cir.func<!s32i ()>>
// CHECK: {{%.*}} = cir.call [[CAST]]() : (!cir.ptr<!cir.func<!s32i ()>>) -> !s32i
}
int noProto5(int x) { return x; }
// CHECK: cir.func no_proto @noProto5(%arg0: !s32i {{.+}}) -> !s32i

0 comments on commit 6a4b4a3

Please sign in to comment.