|
30 | 30 | /// a FunctionRef.
|
31 | 31 |
|
32 | 32 | // torch::executor: modified from llvm::function_ref
|
33 |
| -// see https://www.foonathan.net/2017/01/function-ref-implementation/ |
| 33 | +// - renamed to FunctionRef |
| 34 | +// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses |
| 35 | +// - use namespaced internal::remove_cvref_t |
34 | 36 |
|
35 | 37 | #pragma once
|
36 | 38 |
|
@@ -64,99 +66,47 @@ class FunctionRef;
|
64 | 66 |
|
65 | 67 | template <typename Ret, typename... Params>
|
66 | 68 | class FunctionRef<Ret(Params...)> {
|
67 |
| - Ret (*callback_)(const void* memory, Params... params) = nullptr; |
68 |
| - union Storage { |
69 |
| - void* callable; |
70 |
| - Ret (*function)(Params...); |
71 |
| - } storage_; |
| 69 | + Ret (*callback)(intptr_t callable, Params... params) = nullptr; |
| 70 | + intptr_t callable; |
| 71 | + |
| 72 | + template <typename Callable> |
| 73 | + static Ret callback_fn(intptr_t callable, Params... params) { |
| 74 | + return (*reinterpret_cast<Callable*>(callable))( |
| 75 | + std::forward<Params>(params)...); |
| 76 | + } |
72 | 77 |
|
73 | 78 | public:
|
74 | 79 | FunctionRef() = default;
|
75 |
| - explicit FunctionRef(std::nullptr_t) {} |
76 |
| - |
77 |
| - /** |
78 |
| - * Case 1: A callable object passed by lvalue reference. |
79 |
| - * Taking rvalue reference is error prone because the object will be always |
80 |
| - * be destroyed immediately. |
81 |
| - */ |
82 |
| - template < |
83 |
| - typename Callable, |
| 80 | + FunctionRef(std::nullptr_t) {} |
| 81 | + |
| 82 | + template <typename Callable> |
| 83 | + FunctionRef( |
| 84 | + Callable&& callable, |
84 | 85 | // This is not the copy-constructor.
|
85 |
| - typename std::enable_if< |
86 |
| - !std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value, |
87 |
| - int32_t>::type = 0, |
88 |
| - // Avoid lvalue reference to non-capturing lambda. |
89 |
| - typename std::enable_if< |
90 |
| - !std::is_convertible<Callable, Ret (*)(Params...)>::value, |
91 |
| - int32_t>::type = 0, |
| 86 | + std::enable_if_t<!std::is_same< |
| 87 | + internal::remove_cvref_t<Callable>, |
| 88 | + FunctionRef>::value>* = nullptr, |
92 | 89 | // Functor must be callable and return a suitable type.
|
93 |
| - // To make this container type safe, we need to ensure either: |
94 |
| - // 1. The return type is void. |
95 |
| - // 2. Or the resulting type from calling the callable is convertible to |
96 |
| - // the declared return type. |
97 |
| - typename std::enable_if< |
| 90 | + std::enable_if_t< |
98 | 91 | std::is_void<Ret>::value ||
|
99 |
| - std::is_convertible< |
100 |
| - decltype(std::declval<Callable>()(std::declval<Params>()...)), |
101 |
| - Ret>::value, |
102 |
| - int32_t>::type = 0> |
103 |
| - explicit FunctionRef(Callable& callable) |
104 |
| - : callback_([](const void* memory, Params... params) { |
105 |
| - auto& storage = *static_cast<const Storage*>(memory); |
106 |
| - auto& callable = *static_cast<Callable*>(storage.callable); |
107 |
| - return static_cast<Ret>(callable(std::forward<Params>(params)...)); |
108 |
| - }) { |
109 |
| - storage_.callable = &callable; |
110 |
| - } |
111 |
| - |
112 |
| - /** |
113 |
| - * Case 2: A plain function pointer. |
114 |
| - * Instead of storing an opaque pointer to underlying callable object, |
115 |
| - * store a function pointer directly. |
116 |
| - * Note that in the future a variant which coerces compatible function |
117 |
| - * pointers could be implemented by erasing the storage type. |
118 |
| - */ |
119 |
| - /* implicit */ FunctionRef(Ret (*ptr)(Params...)) |
120 |
| - : callback_([](const void* memory, Params... params) { |
121 |
| - auto& storage = *static_cast<const Storage*>(memory); |
122 |
| - return storage.function(std::forward<Params>(params)...); |
123 |
| - }) { |
124 |
| - storage_.function = ptr; |
125 |
| - } |
126 |
| - |
127 |
| - /** |
128 |
| - * Case 3: Implicit conversion from lambda to FunctionRef. |
129 |
| - * A common use pattern is like: |
130 |
| - * void foo(FunctionRef<...>) {...} |
131 |
| - * foo([](...){...}) |
132 |
| - * Here constructors for non const lvalue reference or function pointer |
133 |
| - * would not work because they do not cover implicit conversion from rvalue |
134 |
| - * lambda. |
135 |
| - * We need to define a constructor for capturing temporary callables and |
136 |
| - * always try to convert the lambda to a function pointer behind the scene. |
137 |
| - */ |
138 |
| - template < |
139 |
| - typename Function, |
140 |
| - // This is not the copy-constructor. |
141 |
| - typename std::enable_if< |
142 |
| - !std::is_same<Function, FunctionRef>::value, |
143 |
| - int32_t>::type = 0, |
144 |
| - // Function is convertible to pointer of (Params...) -> Ret. |
145 |
| - typename std::enable_if< |
146 |
| - std::is_convertible<Function, Ret (*)(Params...)>::value, |
147 |
| - int32_t>::type = 0> |
148 |
| - /* implicit */ FunctionRef(const Function& function) |
149 |
| - : FunctionRef(static_cast<Ret (*)(Params...)>(function)) {} |
| 92 | + std::is_convertible< |
| 93 | + decltype(std::declval<Callable>()(std::declval<Params>()...)), |
| 94 | + Ret>::value>* = nullptr) |
| 95 | + : callback(callback_fn<std::remove_reference_t<Callable>>), |
| 96 | + callable(reinterpret_cast<intptr_t>(&callable)) {} |
150 | 97 |
|
151 | 98 | Ret operator()(Params... params) const {
|
152 |
| - return callback_(&storage_, std::forward<Params>(params)...); |
| 99 | + return callback(callable, std::forward<Params>(params)...); |
153 | 100 | }
|
154 | 101 |
|
155 | 102 | explicit operator bool() const {
|
156 |
| - return callback_; |
| 103 | + return callback; |
157 | 104 | }
|
158 |
| -}; |
159 | 105 |
|
| 106 | + bool operator==(const FunctionRef<Ret(Params...)>& Other) const { |
| 107 | + return callable == Other.callable; |
| 108 | + } |
| 109 | +}; |
160 | 110 | } // namespace pytree
|
161 | 111 | } // namespace extension
|
162 | 112 | } // namespace executorch
|
|
0 commit comments