Skip to content

Re-sync extension/pytree/function_ref.h with LLVM #10440

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 31 additions & 81 deletions extension/pytree/function_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/
// - renamed to FunctionRef
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
// - use namespaced internal::remove_cvref_t

#pragma once

Expand Down Expand Up @@ -64,99 +66,47 @@ class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback_)(const void* memory, Params... params) = nullptr;
union Storage {
void* callable;
Ret (*function)(Params...);
} storage_;
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
intptr_t callable;

template <typename Callable>
static Ret callback_fn(intptr_t callable, Params... params) {
return (*reinterpret_cast<Callable*>(callable))(
std::forward<Params>(params)...);
}

public:
FunctionRef() = default;
explicit FunctionRef(std::nullptr_t) {}

/**
* Case 1: A callable object passed by lvalue reference.
* Taking rvalue reference is error prone because the object will be always
* be destroyed immediately.
*/
template <
typename Callable,
FunctionRef(std::nullptr_t) {}

template <typename Callable>
FunctionRef(
Callable&& callable,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
int32_t>::type = 0,
// Avoid lvalue reference to non-capturing lambda.
typename std::enable_if<
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
int32_t>::type = 0,
std::enable_if_t<!std::is_same<
internal::remove_cvref_t<Callable>,
FunctionRef>::value>* = nullptr,
// Functor must be callable and return a suitable type.
// To make this container type safe, we need to ensure either:
// 1. The return type is void.
// 2. Or the resulting type from calling the callable is convertible to
// the declared return type.
typename std::enable_if<
std::enable_if_t<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value,
int32_t>::type = 0>
explicit FunctionRef(Callable& callable)
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
auto& callable = *static_cast<Callable*>(storage.callable);
return static_cast<Ret>(callable(std::forward<Params>(params)...));
}) {
storage_.callable = &callable;
}

/**
* Case 2: A plain function pointer.
* Instead of storing an opaque pointer to underlying callable object,
* store a function pointer directly.
* Note that in the future a variant which coerces compatible function
* pointers could be implemented by erasing the storage type.
*/
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
return storage.function(std::forward<Params>(params)...);
}) {
storage_.function = ptr;
}

/**
* Case 3: Implicit conversion from lambda to FunctionRef.
* A common use pattern is like:
* void foo(FunctionRef<...>) {...}
* foo([](...){...})
* Here constructors for non const lvalue reference or function pointer
* would not work because they do not cover implicit conversion from rvalue
* lambda.
* We need to define a constructor for capturing temporary callables and
* always try to convert the lambda to a function pointer behind the scene.
*/
template <
typename Function,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<Function, FunctionRef>::value,
int32_t>::type = 0,
// Function is convertible to pointer of (Params...) -> Ret.
typename std::enable_if<
std::is_convertible<Function, Ret (*)(Params...)>::value,
int32_t>::type = 0>
/* implicit */ FunctionRef(const Function& function)
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value>* = nullptr)
: callback(callback_fn<std::remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}

Ret operator()(Params... params) const {
return callback_(&storage_, std::forward<Params>(params)...);
return callback(callable, std::forward<Params>(params)...);
}

explicit operator bool() const {
return callback_;
return callback;
}
};

bool operator==(const FunctionRef<Ret(Params...)>& Other) const {
return callable == Other.callable;
}
};
} // namespace pytree
} // namespace extension
} // namespace executorch
Expand Down
39 changes: 7 additions & 32 deletions extension/pytree/test/function_ref_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,6 @@ using namespace ::testing;
using ::executorch::extension::pytree::FunctionRef;

namespace {
class Item {
private:
int32_t val_;
FunctionRef<void(int32_t&)> ref_;

public:
/* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
: val_(val), ref_(ref) {}

int32_t get() {
ref_(val_);
return val_;
}
};

void one(int32_t& i) {
i = 1;
}
Expand All @@ -39,8 +24,9 @@ void one(int32_t& i) {
TEST(FunctionRefTest, CapturingLambda) {
auto one = 1;
auto f = [&](int32_t& i) { i = one; };
Item item(0, FunctionRef<void(int32_t&)>{f});
EXPECT_EQ(item.get(), 1);
int32_t val = 0;
FunctionRef<void(int32_t&)>{f}(val);
EXPECT_EQ(val, 1);
// ERROR:
// Item item1(0, f);
// Item item2(0, [&](int32_t& i) { i = 2; });
Expand All @@ -58,16 +44,6 @@ TEST(FunctionRefTest, NonCapturingLambda) {
FunctionRef<void(int32_t&)> ref1(lambda);
ref1(val);
EXPECT_EQ(val, 1);

Item item(0, [](int32_t& i) { i = 1; });
EXPECT_EQ(item.get(), 1);

auto f = [](int32_t& i) { i = 1; };
Item item1(0, f);
EXPECT_EQ(item1.get(), 1);

Item item2(0, std::move(f));
EXPECT_EQ(item2.get(), 1);
}

TEST(FunctionRefTest, FunctionPointer) {
Expand All @@ -76,9 +52,8 @@ TEST(FunctionRefTest, FunctionPointer) {
ref(val);
EXPECT_EQ(val, 1);

Item item(0, one);
EXPECT_EQ(item.get(), 1);

Item item1(0, &one);
EXPECT_EQ(item1.get(), 1);
val = 0;
FunctionRef<void(int32_t&)> ref2(one);
ref2(val);
EXPECT_EQ(val, 1);
}
Loading