Skip to content

Commit 6feb623

Browse files
authored
Re-sync extension/pytree/function_ref.h with LLVM (#10440)
We had an old version, and that version didn't work quite right when handed a const ref to a callable. Just resync and note divergences from the original in the comment. Our old version also had an `explicit` constructor, which is *not* desirable because it complicates callsites. It was also a divergence from both LLVM and [the upcoming C++26 std::function_ref](https://en.cppreference.com/w/cpp/utility/functional/function_ref/function_ref). The test made invalid use of a FunctionRef -- you can't keep one around past the lifetime of the function object it refers to, but the test held onto a temporary capturing lambda in an object. I fixed the test.
1 parent b415af0 commit 6feb623

File tree

2 files changed

+38
-113
lines changed

2 files changed

+38
-113
lines changed

extension/pytree/function_ref.h

+31-81
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
/// a FunctionRef.
3131

3232
// torch::executor: modified from llvm::function_ref
33-
// see https://www.foonathan.net/2017/01/function-ref-implementation/
33+
// - renamed to FunctionRef
34+
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
35+
// - use namespaced internal::remove_cvref_t
3436

3537
#pragma once
3638

@@ -64,99 +66,47 @@ class FunctionRef;
6466

6567
template <typename Ret, typename... Params>
6668
class FunctionRef<Ret(Params...)> {
67-
Ret (*callback_)(const void* memory, Params... params) = nullptr;
68-
union Storage {
69-
void* callable;
70-
Ret (*function)(Params...);
71-
} storage_;
69+
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
70+
intptr_t callable;
71+
72+
template <typename Callable>
73+
static Ret callback_fn(intptr_t callable, Params... params) {
74+
return (*reinterpret_cast<Callable*>(callable))(
75+
std::forward<Params>(params)...);
76+
}
7277

7378
public:
7479
FunctionRef() = default;
75-
explicit FunctionRef(std::nullptr_t) {}
76-
77-
/**
78-
* Case 1: A callable object passed by lvalue reference.
79-
* Taking rvalue reference is error prone because the object will be always
80-
* be destroyed immediately.
81-
*/
82-
template <
83-
typename Callable,
80+
FunctionRef(std::nullptr_t) {}
81+
82+
template <typename Callable>
83+
FunctionRef(
84+
Callable&& callable,
8485
// This is not the copy-constructor.
85-
typename std::enable_if<
86-
!std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
87-
int32_t>::type = 0,
88-
// Avoid lvalue reference to non-capturing lambda.
89-
typename std::enable_if<
90-
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
91-
int32_t>::type = 0,
86+
std::enable_if_t<!std::is_same<
87+
internal::remove_cvref_t<Callable>,
88+
FunctionRef>::value>* = nullptr,
9289
// Functor must be callable and return a suitable type.
93-
// To make this container type safe, we need to ensure either:
94-
// 1. The return type is void.
95-
// 2. Or the resulting type from calling the callable is convertible to
96-
// the declared return type.
97-
typename std::enable_if<
90+
std::enable_if_t<
9891
std::is_void<Ret>::value ||
99-
std::is_convertible<
100-
decltype(std::declval<Callable>()(std::declval<Params>()...)),
101-
Ret>::value,
102-
int32_t>::type = 0>
103-
explicit FunctionRef(Callable& callable)
104-
: callback_([](const void* memory, Params... params) {
105-
auto& storage = *static_cast<const Storage*>(memory);
106-
auto& callable = *static_cast<Callable*>(storage.callable);
107-
return static_cast<Ret>(callable(std::forward<Params>(params)...));
108-
}) {
109-
storage_.callable = &callable;
110-
}
111-
112-
/**
113-
* Case 2: A plain function pointer.
114-
* Instead of storing an opaque pointer to underlying callable object,
115-
* store a function pointer directly.
116-
* Note that in the future a variant which coerces compatible function
117-
* pointers could be implemented by erasing the storage type.
118-
*/
119-
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
120-
: callback_([](const void* memory, Params... params) {
121-
auto& storage = *static_cast<const Storage*>(memory);
122-
return storage.function(std::forward<Params>(params)...);
123-
}) {
124-
storage_.function = ptr;
125-
}
126-
127-
/**
128-
* Case 3: Implicit conversion from lambda to FunctionRef.
129-
* A common use pattern is like:
130-
* void foo(FunctionRef<...>) {...}
131-
* foo([](...){...})
132-
* Here constructors for non const lvalue reference or function pointer
133-
* would not work because they do not cover implicit conversion from rvalue
134-
* lambda.
135-
* We need to define a constructor for capturing temporary callables and
136-
* always try to convert the lambda to a function pointer behind the scene.
137-
*/
138-
template <
139-
typename Function,
140-
// This is not the copy-constructor.
141-
typename std::enable_if<
142-
!std::is_same<Function, FunctionRef>::value,
143-
int32_t>::type = 0,
144-
// Function is convertible to pointer of (Params...) -> Ret.
145-
typename std::enable_if<
146-
std::is_convertible<Function, Ret (*)(Params...)>::value,
147-
int32_t>::type = 0>
148-
/* implicit */ FunctionRef(const Function& function)
149-
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}
92+
std::is_convertible<
93+
decltype(std::declval<Callable>()(std::declval<Params>()...)),
94+
Ret>::value>* = nullptr)
95+
: callback(callback_fn<std::remove_reference_t<Callable>>),
96+
callable(reinterpret_cast<intptr_t>(&callable)) {}
15097

15198
Ret operator()(Params... params) const {
152-
return callback_(&storage_, std::forward<Params>(params)...);
99+
return callback(callable, std::forward<Params>(params)...);
153100
}
154101

155102
explicit operator bool() const {
156-
return callback_;
103+
return callback;
157104
}
158-
};
159105

106+
bool operator==(const FunctionRef<Ret(Params...)>& Other) const {
107+
return callable == Other.callable;
108+
}
109+
};
160110
} // namespace pytree
161111
} // namespace extension
162112
} // namespace executorch

extension/pytree/test/function_ref_test.cpp

+7-32
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,6 @@ using namespace ::testing;
1515
using ::executorch::extension::pytree::FunctionRef;
1616

1717
namespace {
18-
class Item {
19-
private:
20-
int32_t val_;
21-
FunctionRef<void(int32_t&)> ref_;
22-
23-
public:
24-
/* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
25-
: val_(val), ref_(ref) {}
26-
27-
int32_t get() {
28-
ref_(val_);
29-
return val_;
30-
}
31-
};
32-
3318
void one(int32_t& i) {
3419
i = 1;
3520
}
@@ -39,8 +24,9 @@ void one(int32_t& i) {
3924
TEST(FunctionRefTest, CapturingLambda) {
4025
auto one = 1;
4126
auto f = [&](int32_t& i) { i = one; };
42-
Item item(0, FunctionRef<void(int32_t&)>{f});
43-
EXPECT_EQ(item.get(), 1);
27+
int32_t val = 0;
28+
FunctionRef<void(int32_t&)>{f}(val);
29+
EXPECT_EQ(val, 1);
4430
// ERROR:
4531
// Item item1(0, f);
4632
// Item item2(0, [&](int32_t& i) { i = 2; });
@@ -58,16 +44,6 @@ TEST(FunctionRefTest, NonCapturingLambda) {
5844
FunctionRef<void(int32_t&)> ref1(lambda);
5945
ref1(val);
6046
EXPECT_EQ(val, 1);
61-
62-
Item item(0, [](int32_t& i) { i = 1; });
63-
EXPECT_EQ(item.get(), 1);
64-
65-
auto f = [](int32_t& i) { i = 1; };
66-
Item item1(0, f);
67-
EXPECT_EQ(item1.get(), 1);
68-
69-
Item item2(0, std::move(f));
70-
EXPECT_EQ(item2.get(), 1);
7147
}
7248

7349
TEST(FunctionRefTest, FunctionPointer) {
@@ -76,9 +52,8 @@ TEST(FunctionRefTest, FunctionPointer) {
7652
ref(val);
7753
EXPECT_EQ(val, 1);
7854

79-
Item item(0, one);
80-
EXPECT_EQ(item.get(), 1);
81-
82-
Item item1(0, &one);
83-
EXPECT_EQ(item1.get(), 1);
55+
val = 0;
56+
FunctionRef<void(int32_t&)> ref2(one);
57+
ref2(val);
58+
EXPECT_EQ(val, 1);
8459
}

0 commit comments

Comments
 (0)