Skip to content

Commit 15680fe

Browse files
committed
Move fully-featured FunctionRef from extension/pytree to ExecuTorch core
I just re-synced it with LLVM, and it seems harmless. We previously simplified FunctionRef because of an ARM baremetal toolchain issue around placement of a lambda (#555), but the current implementation has no lambda. Motivation: I want to be able to use it in the next PR in thread_parallel_interface.h (the "threadpool active" mode) and threadpool in order to save a ton of size that std::function is currently wasting by virtue of supporting ownership and copying. ghstack-source-id: 8cfb16ed5aae905e2ac65996683f32689c4e5f46 ghstack-comment-id: 2828755099 Pull-Request-resolved: #10441
1 parent 03b0938 commit 15680fe

File tree

9 files changed

+140
-124
lines changed

9 files changed

+140
-124
lines changed

extension/pytree/function_ref.h

+8-106
Original file line numberDiff line numberDiff line change
@@ -6,117 +6,19 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10-
//
11-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12-
// See https://llvm.org/LICENSE.txt for license information.
13-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14-
//
15-
//===----------------------------------------------------------------------===//
16-
//
17-
// This file contains some extension to <functional>.
18-
//
19-
// No library is required when using these functions.
20-
//
21-
//===----------------------------------------------------------------------===//
22-
// Extra additions to <functional>
23-
//===----------------------------------------------------------------------===//
24-
25-
/// An efficient, type-erasing, non-owning reference to a callable. This is
26-
/// intended for use as the type of a function parameter that is not used
27-
/// after the function in question returns.
28-
///
29-
/// This class does not own the callable, so it is not in general safe to store
30-
/// a FunctionRef.
31-
32-
// torch::executor: modified from llvm::function_ref
33-
// - renamed to FunctionRef
34-
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
35-
// - use namespaced internal::remove_cvref_t
36-
379
#pragma once
3810

39-
#include <cstdint>
40-
#include <type_traits>
41-
#include <utility>
42-
43-
namespace executorch {
44-
namespace extension {
45-
namespace pytree {
46-
47-
//===----------------------------------------------------------------------===//
48-
// Features from C++20
49-
//===----------------------------------------------------------------------===//
50-
51-
namespace internal {
52-
53-
template <typename T>
54-
struct remove_cvref {
55-
using type =
56-
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
57-
};
58-
59-
template <typename T>
60-
using remove_cvref_t = typename remove_cvref<T>::type;
61-
62-
} // namespace internal
63-
64-
template <typename Fn>
65-
class FunctionRef;
66-
67-
template <typename Ret, typename... Params>
68-
class FunctionRef<Ret(Params...)> {
69-
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
70-
intptr_t callable;
71-
72-
template <typename Callable>
73-
static Ret callback_fn(intptr_t callable, Params... params) {
74-
return (*reinterpret_cast<Callable*>(callable))(
75-
std::forward<Params>(params)...);
76-
}
77-
78-
public:
79-
FunctionRef() = default;
80-
FunctionRef(std::nullptr_t) {}
81-
82-
template <typename Callable>
83-
FunctionRef(
84-
Callable&& callable,
85-
// This is not the copy-constructor.
86-
std::enable_if_t<!std::is_same<
87-
internal::remove_cvref_t<Callable>,
88-
FunctionRef>::value>* = nullptr,
89-
// Functor must be callable and return a suitable type.
90-
std::enable_if_t<
91-
std::is_void<Ret>::value ||
92-
std::is_convertible<
93-
decltype(std::declval<Callable>()(std::declval<Params>()...)),
94-
Ret>::value>* = nullptr)
95-
: callback(callback_fn<std::remove_reference_t<Callable>>),
96-
callable(reinterpret_cast<intptr_t>(&callable)) {}
97-
98-
Ret operator()(Params... params) const {
99-
return callback(callable, std::forward<Params>(params)...);
100-
}
11+
#include <executorch/runtime/core/function_ref.h>
10112

102-
explicit operator bool() const {
103-
return callback;
104-
}
13+
/// This header is DEPRECATED; use executorch/runtime/core/function_ref.h
14+
/// directly instead.
10515

106-
bool operator==(const FunctionRef<Ret(Params...)>& Other) const {
107-
return callable == Other.callable;
108-
}
109-
};
110-
} // namespace pytree
111-
} // namespace extension
112-
} // namespace executorch
16+
namespace executorch::extension::pytree {
17+
using executorch::runtime::FunctionRef;
18+
} // namespace executorch::extension::pytree
11319

114-
namespace torch {
115-
namespace executor {
116-
namespace pytree {
20+
namespace torch::executor::pytree {
11721
// TODO(T197294990): Remove these deprecated aliases once all users have moved
11822
// to the new `::executorch` namespaces.
11923
using ::executorch::extension::pytree::FunctionRef;
120-
} // namespace pytree
121-
} // namespace executor
122-
} // namespace torch
24+
} // namespace torch::executor::pytree

extension/pytree/test/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
1919

2020
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
2121

22-
set(_test_srcs function_ref_test.cpp test_pytree.cpp)
22+
set(_test_srcs test_pytree.cpp)
2323

2424
et_cxx_test(extension_pytree_test SOURCES ${_test_srcs} EXTRA_LIBS)

extension/pytree/test/TARGETS

-6
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ cpp_unittest(
1010
deps = ["//executorch/extension/pytree:pytree"],
1111
)
1212

13-
cpp_unittest(
14-
name = "function_ref_test",
15-
srcs = ["function_ref_test.cpp"],
16-
deps = ["//executorch/extension/pytree:pytree"],
17-
)
18-
1913
python_unittest(
2014
name = "pybindings_test",
2115
srcs = [

runtime/core/function_ref.h

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10+
//
11+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://llvm.org/LICENSE.txt for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//===----------------------------------------------------------------------===//
16+
//
17+
// This file contains some extension to <functional>.
18+
//
19+
// No library is required when using these functions.
20+
//
21+
//===----------------------------------------------------------------------===//
22+
// Extra additions to <functional>
23+
//===----------------------------------------------------------------------===//
24+
25+
/// An efficient, type-erasing, non-owning reference to a callable. This is
26+
/// intended for use as the type of a function parameter that is not used
27+
/// after the function in question returns.
28+
///
29+
/// This class does not own the callable, so it is not in general safe to store
30+
/// a FunctionRef.
31+
32+
// torch::executor: modified from llvm::function_ref
33+
// - renamed to FunctionRef
34+
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
35+
// - use namespaced internal::remove_cvref_t
36+
37+
#pragma once
38+
39+
#include <cstdint>
40+
#include <type_traits>
41+
#include <utility>
42+
43+
namespace executorch::runtime {
44+
45+
//===----------------------------------------------------------------------===//
46+
// Features from C++20
47+
//===----------------------------------------------------------------------===//
48+
49+
namespace internal {
50+
51+
template <typename T>
52+
struct remove_cvref {
53+
using type =
54+
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
55+
};
56+
57+
template <typename T>
58+
using remove_cvref_t = typename remove_cvref<T>::type;
59+
60+
} // namespace internal
61+
62+
template <typename Fn>
63+
class FunctionRef;
64+
65+
template <typename Ret, typename... Params>
66+
class FunctionRef<Ret(Params...)> {
67+
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
68+
intptr_t callable;
69+
70+
template <typename Callable>
71+
static Ret callback_fn(intptr_t callable, Params... params) {
72+
return (*reinterpret_cast<Callable*>(callable))(
73+
std::forward<Params>(params)...);
74+
}
75+
76+
public:
77+
FunctionRef() = default;
78+
FunctionRef(std::nullptr_t) {}
79+
80+
template <typename Callable>
81+
FunctionRef(
82+
Callable&& callable,
83+
// This is not the copy-constructor.
84+
std::enable_if_t<!std::is_same<
85+
internal::remove_cvref_t<Callable>,
86+
FunctionRef>::value>* = nullptr,
87+
// Functor must be callable and return a suitable type.
88+
std::enable_if_t<
89+
std::is_void<Ret>::value ||
90+
std::is_convertible<
91+
decltype(std::declval<Callable>()(std::declval<Params>()...)),
92+
Ret>::value>* = nullptr)
93+
: callback(callback_fn<std::remove_reference_t<Callable>>),
94+
callable(reinterpret_cast<intptr_t>(&callable)) {}
95+
96+
Ret operator()(Params... params) const {
97+
return callback(callable, std::forward<Params>(params)...);
98+
}
99+
100+
explicit operator bool() const {
101+
return callback;
102+
}
103+
104+
bool operator==(const FunctionRef<Ret(Params...)>& Other) const {
105+
return callable == Other.callable;
106+
}
107+
};
108+
} // namespace executorch::runtime

runtime/core/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def define_common_targets():
4141
"defines.h",
4242
"error.h",
4343
"freeable_buffer.h",
44+
"function_ref.h",
4445
"result.h",
4546
"span.h",
4647
],

runtime/core/test/CMakeLists.txt

+5-4
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
2020
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
2121

2222
set(_test_srcs
23-
span_test.cpp
23+
array_ref_test.cpp
2424
error_handling_test.cpp
25+
evalue_test.cpp
2526
event_tracer_test.cpp
2627
freeable_buffer_test.cpp
27-
array_ref_test.cpp
28-
memory_allocator_test.cpp
28+
function_ref_test.cpp
2929
hierarchical_allocator_test.cpp
30-
evalue_test.cpp
30+
memory_allocator_test.cpp
31+
span_test.cpp
3132
)
3233

3334
et_cxx_test(runtime_core_test SOURCES ${_test_srcs} EXTRA_LIBS)

extension/pytree/test/function_ref_test.cpp renamed to runtime/core/test/function_ref_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/extension/pytree/function_ref.h>
9+
#include <executorch/runtime/core/function_ref.h>
1010

1111
#include <gtest/gtest.h>
1212

1313
using namespace ::testing;
1414

15-
using ::executorch::extension::pytree::FunctionRef;
15+
using ::executorch::runtime::FunctionRef;
1616

1717
namespace {
1818
void one(int32_t& i) {

runtime/core/test/targets.bzl

+10
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ def define_common_targets():
3333
],
3434
)
3535

36+
runtime.cxx_test(
37+
name = "function_ref_test",
38+
srcs = [
39+
"function_ref_test.cpp",
40+
],
41+
deps = [
42+
"//executorch/runtime/core:core",
43+
],
44+
)
45+
3646
runtime.cxx_test(
3747
name = "event_tracer_test",
3848
srcs = [

test/utils/OSSTestConfig.json

+5-5
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
{
4646
"directory": "extension/pytree/test",
4747
"sources": [
48-
"function_ref_test.cpp",
4948
"test_pytree.cpp"
5049
]
5150
},
@@ -96,14 +95,15 @@
9695
{
9796
"directory": "runtime/core/test",
9897
"sources": [
99-
"span_test.cpp",
98+
"array_ref_test.cpp",
10099
"error_handling_test.cpp",
100+
"evalue_test.cpp",
101101
"event_tracer_test.cpp",
102102
"freeable_buffer_test.cpp",
103-
"array_ref_test.cpp",
104-
"memory_allocator_test.cpp",
103+
"function_ref_test.cpp",
105104
"hierarchical_allocator_test.cpp",
106-
"evalue_test.cpp"
105+
"memory_allocator_test.cpp",
106+
"span_test.cpp"
107107
]
108108
},
109109
{

0 commit comments

Comments
 (0)