From df4cfee9a5fb34609f33e969a72028d3e76c4e75 Mon Sep 17 00:00:00 2001 From: Alexey Spiridonov Date: Wed, 30 Oct 2024 01:03:33 -0700 Subject: [PATCH] `coro/TaskWrapper.h`, a helper for wrapping `Task` / `TaskPromise` Summary: Here are few possible reasons to want to wrap `Task`: - `SafeTask` (upcoming in D61995850) -- quacks exactly like `Task` but includes some additional compile-time enforcement that makes it harder to make dangling-reference bugs. - `TaskDfatalIfNotAwaited` -- quacks exactly like `Task`, but used for those cases where you really want to make sure the task is awaited. It would be possible to open up `Task` / `TaskPromise` for inheritance, but this comes with a risk of object slicing bugs, some quite subtle. Moreover, while for the simplest wrappers the "is-a-Task" implied by relationship can be okay, one can imagine more featureful wrappers where it is absolutely not appropriate. The composition approach is more explicit, and thus safer for the implementer of the wrapper. The downside of composition is that we have to transparently forward the entire -- and rather large -- API surface of `TaskPromise` / `Task`. This would be far too fragile if it were copy-pasted in each individual wrapper. By collecting the implementation in a central `TaskWapper.h`, the cost of adapting to `Task.h` API changes should stay low. Reviewed By: ispeters, skrueger Differential Revision: D62994547 fbshipit-source-id: 5d7e11a386950b9e07d34b28cf7f986f7eb94e66 --- folly/coro/BUCK | 6 + folly/coro/Task.h | 45 ++++-- folly/coro/TaskWrapper.h | 214 ++++++++++++++++++++++++++++ folly/coro/test/BUCK | 11 ++ folly/coro/test/TaskWrapperTest.cpp | 179 +++++++++++++++++++++++ 5 files changed, 443 insertions(+), 12 deletions(-) create mode 100644 folly/coro/TaskWrapper.h create mode 100644 folly/coro/test/TaskWrapperTest.cpp diff --git a/folly/coro/BUCK b/folly/coro/BUCK index 8944ca4f58e..dc70aff440a 100644 --- a/folly/coro/BUCK +++ b/folly/coro/BUCK @@ -543,6 +543,12 @@ cpp_library( ], ) +cpp_library( + name = "task_wrapper", + headers = ["TaskWrapper.h"], + exported_deps = [":task"], +) + cpp_library( name = "timed_wait", headers = ["TimedWait.h"], diff --git a/folly/coro/Task.h b/folly/coro/Task.h index 4ffda52dbca..d328d2f7a7d 100644 --- a/folly/coro/Task.h +++ b/folly/coro/Task.h @@ -62,7 +62,17 @@ class TaskWithExecutor; namespace detail { +class TaskPromiseBase; + +class TaskPromisePrivate { + private: + friend TaskPromiseBase; + TaskPromisePrivate() = default; +}; + class TaskPromiseBase { + static TaskPromisePrivate privateTag() { return TaskPromisePrivate{}; } + class FinalAwaiter { public: bool await_ready() noexcept { return false; } @@ -77,23 +87,26 @@ class TaskPromiseBase { // // This is a bit untidy, and hopefully something we can replace with // a virtual wrapper over coroutine_handle that handles the pop for us. - if (promise.scopeExit_) { - promise.scopeExit_.promise().setContext( - promise.continuation_, - &promise.asyncFrame_, - promise.executor_.get_alias(), - promise.result_.hasException() ? promise.result_.exception() - : exception_wrapper{}); - return promise.scopeExit_; + if (promise.scopeExitRef(privateTag())) { + promise.scopeExitRef(privateTag()) + .promise() + .setContext( + promise.continuationRef(privateTag()), + &promise.getAsyncFrame(), + promise.executorRef(privateTag()).get_alias(), + promise.result().hasException() ? promise.result().exception() + : exception_wrapper{}); + return promise.scopeExitRef(privateTag()); } - folly::popAsyncStackFrameCallee(promise.asyncFrame_); - if (promise.result_.hasException()) { + folly::popAsyncStackFrameCallee(promise.getAsyncFrame()); + if (promise.result().hasException()) { auto [handle, frame] = - promise.continuation_.getErrorHandle(promise.result_.exception()); + promise.continuationRef(privateTag()) + .getErrorHandle(promise.result().exception()); return handle.getHandle(); } - return promise.continuation_.getHandle(); + return promise.continuationRef(privateTag()).getHandle(); } [[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); } @@ -167,6 +180,14 @@ class TaskPromiseBase { return executor_; } + // These getters exist so that `FinalAwaiter` can interact with wrapped + // `TaskPromise`s, and not just `TaskPromiseBase` descendants. We use a + // private tag to let `TaskWrapper` call them without becoming a `friend`. + auto& scopeExitRef(TaskPromisePrivate) { return scopeExit_; } + auto& continuationRef(TaskPromisePrivate) { return continuation_; } + // Unlike `getExecutor()`, does not copy an atomic. + auto& executorRef(TaskPromisePrivate) { return executor_; } + private: template friend class folly::coro::TaskWithExecutor; diff --git a/folly/coro/TaskWrapper.h b/folly/coro/TaskWrapper.h new file mode 100644 index 00000000000..b8ac36bdd28 --- /dev/null +++ b/folly/coro/TaskWrapper.h @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +/// The header provides base classes for wrapping `folly::coro::Task` with +/// custom functionality. These work by composition, which avoids the +/// pitfalls of inheritance -- your custom wrapper will not be "is-a-Task", +/// and will not implicitly "object slice" to a `Task`. +/// +/// Keep in mind that some destructive APIs, like `.semi()`, effectively +/// unwrap the `Task`. If this is important for your use-case, you may need +/// to add features (e.g. `TaskWithExecutorWrapper`, on-unwrap callbacks). +/// +/// The point of this header is to uniformly forward the large API surface +/// of `TaskPromise` & `Task`, leaving just the "new logic" in each wrapper. +/// As `Task.h` evolves, a central `TaskWrapper.h` is easier to maintain. +/// +/// You'll derive from `TaskWrapperPromise` -- which must reference a +/// derived class of `TaskWrapperCrtp` that is your new user-facing coro. +/// +/// To discourage inheritance and object-slicing bugs, mark your derived +/// wrappers `final` -- they can be wrapped recursively. +/// +/// Read `TaskWrapperTest.cpp` for examples of a minimal & recursive wrapper. +/// +/// Future: Once this has a benchmark, see if `FOLLY_ALWAYS_INLINE` makes +/// any difference on the wrapped functions (it shouldn't). + +namespace folly::coro { + +namespace detail { +template +class TaskPromiseWrapperBase; +} // namespace detail + +template +class TaskWrapperCrtp; + +namespace detail { + +template +using task_wrapper_underlying_semiawaitable_t = + typename Wrapper::TaskWrapperUnderlyingSemiAwaitable; + +template +inline constexpr bool is_task_or_wrapper_v = + (!std::is_same_v && // Does not wrap Task + (std::is_same_v> || // Wraps Task + is_task_or_wrapper_v< + detected_t, + T>)); + +template +using task_wrapper_underlying_promise_t = + typename Wrapper::TaskWrapperUnderlyingPromise; + +template +inline constexpr bool is_task_promise_or_wrapper_v = + (!std::is_same_v && // Does not wrap TaskPromise + (std::is_same_v> || // Wraps TaskPromise + is_task_promise_or_wrapper_v< + detected_t, + T>)); + +template +class TaskPromiseWrapperBase { + protected: + static_assert( + is_task_or_wrapper_v, + "SemiAwaitable must be a sequence of wrappers ending in Task"); + static_assert( + is_task_promise_or_wrapper_v, + "Promise must be a sequence of wrappers ending in TaskPromise"); + + Promise promise_; + + TaskPromiseWrapperBase() noexcept = default; + ~TaskPromiseWrapperBase() = default; + + public: + using TaskWrapperUnderlyingPromise = Promise; + + WrappedSemiAwaitable get_return_object() noexcept { + return WrappedSemiAwaitable{promise_.get_return_object()}; + } + + static void* operator new(std::size_t size) { + return ::folly_coro_async_malloc(size); + } + static void operator delete(void* ptr, std::size_t size) { + ::folly_coro_async_free(ptr, size); + } + + auto initial_suspend() noexcept { return promise_.initial_suspend(); } + auto final_suspend() noexcept { return promise_.final_suspend(); } + + auto await_transform(auto&& what) { + return promise_.await_transform(std::forward(what)); + } + + auto yield_value(auto&& v) + requires requires { promise_.yield_value(std::forward(v)); } + { + return promise_.yield_value(std::forward(v)); + } + + void unhandled_exception() noexcept { promise_.unhandled_exception(); } + + // These getters are all interposed for `TaskPromiseBase::FinalAwaiter` + decltype(auto) result() { return promise_.result(); } + decltype(auto) getAsyncFrame() { return promise_.getAsyncFrame(); } + auto& scopeExitRef(TaskPromisePrivate tag) { + return promise_.scopeExitRef(tag); + } + auto& continuationRef(TaskPromisePrivate tag) { + return promise_.continuationRef(tag); + } + auto& executorRef(TaskPromisePrivate tag) { + return promise_.executorRef(tag); + } +}; + +template +class TaskPromiseWrapper + : public TaskPromiseWrapperBase { + protected: + TaskPromiseWrapper() noexcept = default; + ~TaskPromiseWrapper() = default; + + public: + template // see `returnImplicitCtor` test + auto return_value(U&& value) { + return this->promise_.return_value(std::forward(value)); + } +}; + +template +class TaskPromiseWrapper + : public TaskPromiseWrapperBase { + protected: + TaskPromiseWrapper() noexcept = default; + ~TaskPromiseWrapper() = default; + + public: + void return_void() noexcept { this->promise_.return_void(); } +}; + +} // namespace detail + +template +class TaskWrapperCrtp { + private: + static_assert( + detail::is_task_or_wrapper_v, + "TaskWrapperCrtp must wrap a sequence of wrappers ending in Task"); + + SemiAwaitable task_; + + protected: + template + friend class ::folly::coro::detail::TaskPromiseWrapperBase; + + explicit TaskWrapperCrtp(SemiAwaitable t) : task_(std::move(t)) {} + + SemiAwaitable unwrap() && { return std::move(task_); } + + public: + using TaskWrapperUnderlyingSemiAwaitable = SemiAwaitable; + + // NB: In the future, this might ALSO produce a wrapped object. + FOLLY_NODISCARD + TaskWithExecutor scheduleOn(Executor::KeepAlive<> executor) && noexcept { + return std::move(task_).scheduleOn(std::move(executor)); + } + + FOLLY_NOINLINE auto semi() && { return std::move(task_).semi(); } + + friend Derived co_withCancellation( + folly::CancellationToken cancelToken, Derived&& tw) noexcept { + return Derived{ + co_withCancellation(std::move(cancelToken), std::move(tw.task_))}; + } + + friend auto co_viaIfAsync( + folly::Executor::KeepAlive<> executor, Derived&& tw) noexcept { + return co_viaIfAsync(std::move(executor), std::move(tw.task_)); + } + // At least in Clang 15, the `static_assert` isn't enough to get a usable + // error message (it is instantiated too late), but the deprecation + // warning does show up. + [[deprecated( + "Error: Use `co_await std::move(lvalue)`, not `co_await lvalue`.")]] + friend Derived co_viaIfAsync(folly::Executor::KeepAlive<>, const Derived&) { + static_assert("Use `co_await std::move(lvalue)`, not `co_await lvalue`."); + } +}; + +} // namespace folly::coro diff --git a/folly/coro/test/BUCK b/folly/coro/test/BUCK index 0398acb9d2f..0a33cd8327f 100644 --- a/folly/coro/test/BUCK +++ b/folly/coro/test/BUCK @@ -346,6 +346,17 @@ cpp_benchmark( ], ) +cpp_unittest( + name = "task_wrapper_test", + srcs = ["TaskWrapperTest.cpp"], + deps = [ + "//folly/coro:gtest_helpers", + "//folly/coro:task_wrapper", + "//folly/coro:timeout", + "//folly/fibers:semaphore", + ], +) + cpp_unittest( name = "RustAdaptorsTest", srcs = [ diff --git a/folly/coro/test/TaskWrapperTest.cpp b/folly/coro/test/TaskWrapperTest.cpp new file mode 100644 index 00000000000..e27e027ba38 --- /dev/null +++ b/folly/coro/test/TaskWrapperTest.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace std::literals::chrono_literals; + +namespace folly::coro { + +template +struct TinyTask; + +namespace detail { +template +class TinyTaskPromise final + : public TaskPromiseWrapper, TaskPromise> {}; +} // namespace detail + +template +struct TinyTask final : public TaskWrapperCrtp, T, Task> { + using promise_type = detail::TinyTaskPromise; + using TaskWrapperCrtp, T, Task>::TaskWrapperCrtp; +}; + +CO_TEST(TaskWrapper, trivial) { + EXPECT_EQ( + 1337, co_await [](int x) -> TinyTask { co_return 1300 + x; }(37)); +} + +namespace { +TinyTask intFunc(auto x) { + co_return *x; +} +} // namespace + +CO_TEST(TaskWrapper, returnsNonVoid) { + auto x = std::make_unique(17); + auto lambdaTmpl = [](auto x) -> TinyTask { co_return x; }; + EXPECT_EQ(20, co_await intFunc(std::move(x)) + co_await lambdaTmpl(3)); +} + +namespace { +TinyTask voidFunc(auto x, int* ran) { + EXPECT_EQ(17, *x); + ++*ran; + co_return; +} +} // namespace + +CO_TEST(TaskWrapper, returnsVoidLambda) { + int ran = 0; + auto lambdaTmpl = [&](auto x) -> TinyTask { + EXPECT_EQ(3, x); + ++ran; + co_return; + }; + co_await lambdaTmpl(3); + EXPECT_EQ(1, ran); +} + +CO_TEST(TaskWrapper, returnsVoidFn) { + int ran = 0; + auto x = std::make_unique(17); + co_await voidFunc(std::move(x), &ran); + EXPECT_EQ(1, ran); +} + +CO_TEST(TaskWrapper, awaitsTask) { + EXPECT_EQ( + 1337, co_await []() -> TinyTask { + co_return 1300 + co_await ([]() -> Task { co_return 37; }()); + }()); +} + +CO_TEST(TaskWrapper, cancellation) { + bool ran = false; + EXPECT_THROW( + co_await timeout( + [&]() -> TinyTask { + ran = true; + folly::fibers::Semaphore stuck{0}; // a cancellable baton + co_await stuck.co_wait(); + }(), + 200ms), + folly::FutureTimeout); + EXPECT_TRUE(ran); +} + +namespace { +struct MyError : std::exception {}; +} // namespace + +CO_TEST(TaskWrapper, throws) { + EXPECT_THROW( + co_await []() -> TinyTask { co_yield co_error(MyError{}); }(), + MyError); +} + +CO_TEST(TaskWrapper, co_awaitTry) { + auto res = co_await co_awaitTry( + []() -> TinyTask { co_yield co_error(MyError{}); }()); + EXPECT_TRUE(res.hasException()); +} + +CO_TEST(TaskWrapper, returnImplicitCtor) { + auto t = []() -> TinyTask> { co_return {3, 4}; }; + EXPECT_EQ(std::pair(3, 4), co_await t()); +} + +template +struct RecursiveWrapTask; + +namespace detail { +template +class RecursiveWrapTaskPromise final + : public TaskPromiseWrapper< + T, + RecursiveWrapTask, + InnerPromise> {}; +} // namespace detail + +template +struct RecursiveWrapTask final + : public TaskWrapperCrtp< + RecursiveWrapTask, + T, + InnerSemiAwaitable> { + using promise_type = + detail::RecursiveWrapTaskPromise; + using TaskWrapperCrtp< + RecursiveWrapTask, + T, + InnerSemiAwaitable>::TaskWrapperCrtp; + using TaskWrapperCrtp< + RecursiveWrapTask, + T, + InnerSemiAwaitable>::unwrap; +}; + +template +using TwoWrapTask = + RecursiveWrapTask, detail::TinyTaskPromise>; +template +using TwoWrapTaskPromise = detail:: + RecursiveWrapTaskPromise, detail::TinyTaskPromise>; + +template +using ThreeWrapTask = + RecursiveWrapTask, TwoWrapTaskPromise>; +template +using ThreeWrapTaskPromise = + detail::RecursiveWrapTaskPromise, TwoWrapTaskPromise>; + +CO_TEST(TaskWrapper, recursiveUnwrap) { + auto t = []() -> ThreeWrapTask { co_return 3; }; + EXPECT_EQ(3, co_await t()); + static_assert(std::is_same_v>); + EXPECT_EQ(3, co_await t().unwrap()); + static_assert(std::is_same_v>); + EXPECT_EQ(3, co_await t().unwrap().unwrap()); +} + +} // namespace folly::coro