diff --git a/extension/pytree/function_ref.h b/extension/pytree/function_ref.h index 0458610c4db..c81add169c8 100644 --- a/extension/pytree/function_ref.h +++ b/extension/pytree/function_ref.h @@ -30,7 +30,9 @@ /// a FunctionRef. // torch::executor: modified from llvm::function_ref -// see https://www.foonathan.net/2017/01/function-ref-implementation/ +// - renamed to FunctionRef +// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses +// - use namespaced internal::remove_cvref_t #pragma once @@ -64,99 +66,47 @@ class FunctionRef; template class FunctionRef { - Ret (*callback_)(const void* memory, Params... params) = nullptr; - union Storage { - void* callable; - Ret (*function)(Params...); - } storage_; + Ret (*callback)(intptr_t callable, Params... params) = nullptr; + intptr_t callable; + + template + static Ret callback_fn(intptr_t callable, Params... params) { + return (*reinterpret_cast(callable))( + std::forward(params)...); + } public: FunctionRef() = default; - explicit FunctionRef(std::nullptr_t) {} - - /** - * Case 1: A callable object passed by lvalue reference. - * Taking rvalue reference is error prone because the object will be always - * be destroyed immediately. - */ - template < - typename Callable, + FunctionRef(std::nullptr_t) {} + + template + FunctionRef( + Callable&& callable, // This is not the copy-constructor. - typename std::enable_if< - !std::is_same, FunctionRef>::value, - int32_t>::type = 0, - // Avoid lvalue reference to non-capturing lambda. - typename std::enable_if< - !std::is_convertible::value, - int32_t>::type = 0, + std::enable_if_t, + FunctionRef>::value>* = nullptr, // Functor must be callable and return a suitable type. - // To make this container type safe, we need to ensure either: - // 1. The return type is void. - // 2. Or the resulting type from calling the callable is convertible to - // the declared return type. - typename std::enable_if< + std::enable_if_t< std::is_void::value || - std::is_convertible< - decltype(std::declval()(std::declval()...)), - Ret>::value, - int32_t>::type = 0> - explicit FunctionRef(Callable& callable) - : callback_([](const void* memory, Params... params) { - auto& storage = *static_cast(memory); - auto& callable = *static_cast(storage.callable); - return static_cast(callable(std::forward(params)...)); - }) { - storage_.callable = &callable; - } - - /** - * Case 2: A plain function pointer. - * Instead of storing an opaque pointer to underlying callable object, - * store a function pointer directly. - * Note that in the future a variant which coerces compatible function - * pointers could be implemented by erasing the storage type. - */ - /* implicit */ FunctionRef(Ret (*ptr)(Params...)) - : callback_([](const void* memory, Params... params) { - auto& storage = *static_cast(memory); - return storage.function(std::forward(params)...); - }) { - storage_.function = ptr; - } - - /** - * Case 3: Implicit conversion from lambda to FunctionRef. - * A common use pattern is like: - * void foo(FunctionRef<...>) {...} - * foo([](...){...}) - * Here constructors for non const lvalue reference or function pointer - * would not work because they do not cover implicit conversion from rvalue - * lambda. - * We need to define a constructor for capturing temporary callables and - * always try to convert the lambda to a function pointer behind the scene. - */ - template < - typename Function, - // This is not the copy-constructor. - typename std::enable_if< - !std::is_same::value, - int32_t>::type = 0, - // Function is convertible to pointer of (Params...) -> Ret. - typename std::enable_if< - std::is_convertible::value, - int32_t>::type = 0> - /* implicit */ FunctionRef(const Function& function) - : FunctionRef(static_cast(function)) {} + std::is_convertible< + decltype(std::declval()(std::declval()...)), + Ret>::value>* = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} Ret operator()(Params... params) const { - return callback_(&storage_, std::forward(params)...); + return callback(callable, std::forward(params)...); } explicit operator bool() const { - return callback_; + return callback; } -}; + bool operator==(const FunctionRef& Other) const { + return callable == Other.callable; + } +}; } // namespace pytree } // namespace extension } // namespace executorch diff --git a/extension/pytree/test/function_ref_test.cpp b/extension/pytree/test/function_ref_test.cpp index a3cdbd824bf..cdb2c0538fd 100644 --- a/extension/pytree/test/function_ref_test.cpp +++ b/extension/pytree/test/function_ref_test.cpp @@ -15,21 +15,6 @@ using namespace ::testing; using ::executorch::extension::pytree::FunctionRef; namespace { -class Item { - private: - int32_t val_; - FunctionRef ref_; - - public: - /* implicit */ Item(int32_t val, FunctionRef ref) - : val_(val), ref_(ref) {} - - int32_t get() { - ref_(val_); - return val_; - } -}; - void one(int32_t& i) { i = 1; } @@ -39,8 +24,9 @@ void one(int32_t& i) { TEST(FunctionRefTest, CapturingLambda) { auto one = 1; auto f = [&](int32_t& i) { i = one; }; - Item item(0, FunctionRef{f}); - EXPECT_EQ(item.get(), 1); + int32_t val = 0; + FunctionRef{f}(val); + EXPECT_EQ(val, 1); // ERROR: // Item item1(0, f); // Item item2(0, [&](int32_t& i) { i = 2; }); @@ -58,16 +44,6 @@ TEST(FunctionRefTest, NonCapturingLambda) { FunctionRef ref1(lambda); ref1(val); EXPECT_EQ(val, 1); - - Item item(0, [](int32_t& i) { i = 1; }); - EXPECT_EQ(item.get(), 1); - - auto f = [](int32_t& i) { i = 1; }; - Item item1(0, f); - EXPECT_EQ(item1.get(), 1); - - Item item2(0, std::move(f)); - EXPECT_EQ(item2.get(), 1); } TEST(FunctionRefTest, FunctionPointer) { @@ -76,9 +52,8 @@ TEST(FunctionRefTest, FunctionPointer) { ref(val); EXPECT_EQ(val, 1); - Item item(0, one); - EXPECT_EQ(item.get(), 1); - - Item item1(0, &one); - EXPECT_EQ(item1.get(), 1); + val = 0; + FunctionRef ref2(one); + ref2(val); + EXPECT_EQ(val, 1); }