Skip to content

Commit

Permalink
coro/TaskWrapper.h, a helper for wrapping Task / TaskPromise
Browse files Browse the repository at this point in the history
Summary:
Here are few possible reasons to want to wrap `Task`:
 - `SafeTask` (upcoming in D61995850) -- quacks exactly like `Task` but includes some additional compile-time enforcement that makes it harder to make dangling-reference bugs.
 - `TaskDfatalIfNotAwaited` -- quacks exactly like `Task`, but used for those cases where you really want to make sure the task is awaited.

It would be possible to open up `Task` / `TaskPromise` for inheritance, but this comes with a risk of object slicing bugs, some quite subtle. Moreover, while for the simplest wrappers the "is-a-Task" implied by relationship can be okay, one can imagine more featureful wrappers where it is absolutely not appropriate. The composition approach is more explicit, and thus safer for the implementer of the wrapper.

The downside of composition is that we have to transparently forward the entire -- and rather large -- API surface of `TaskPromise` / `Task`. This would be far too fragile if it were copy-pasted in each individual wrapper. By collecting the implementation in a central `TaskWapper.h`, the cost of adapting to `Task.h` API changes should stay low.

Reviewed By: ispeters, skrueger

Differential Revision: D62994547

fbshipit-source-id: 5d7e11a386950b9e07d34b28cf7f986f7eb94e66
  • Loading branch information
Alexey Spiridonov authored and facebook-github-bot committed Oct 30, 2024
1 parent 1de3e85 commit df4cfee
Show file tree
Hide file tree
Showing 5 changed files with 443 additions and 12 deletions.
6 changes: 6 additions & 0 deletions folly/coro/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,12 @@ cpp_library(
],
)

cpp_library(
name = "task_wrapper",
headers = ["TaskWrapper.h"],
exported_deps = [":task"],
)

cpp_library(
name = "timed_wait",
headers = ["TimedWait.h"],
Expand Down
45 changes: 33 additions & 12 deletions folly/coro/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,17 @@ class TaskWithExecutor;

namespace detail {

class TaskPromiseBase;

class TaskPromisePrivate {
private:
friend TaskPromiseBase;
TaskPromisePrivate() = default;
};

class TaskPromiseBase {
static TaskPromisePrivate privateTag() { return TaskPromisePrivate{}; }

class FinalAwaiter {
public:
bool await_ready() noexcept { return false; }
Expand All @@ -77,23 +87,26 @@ class TaskPromiseBase {
//
// This is a bit untidy, and hopefully something we can replace with
// a virtual wrapper over coroutine_handle that handles the pop for us.
if (promise.scopeExit_) {
promise.scopeExit_.promise().setContext(
promise.continuation_,
&promise.asyncFrame_,
promise.executor_.get_alias(),
promise.result_.hasException() ? promise.result_.exception()
: exception_wrapper{});
return promise.scopeExit_;
if (promise.scopeExitRef(privateTag())) {
promise.scopeExitRef(privateTag())
.promise()
.setContext(
promise.continuationRef(privateTag()),
&promise.getAsyncFrame(),
promise.executorRef(privateTag()).get_alias(),
promise.result().hasException() ? promise.result().exception()
: exception_wrapper{});
return promise.scopeExitRef(privateTag());
}

folly::popAsyncStackFrameCallee(promise.asyncFrame_);
if (promise.result_.hasException()) {
folly::popAsyncStackFrameCallee(promise.getAsyncFrame());
if (promise.result().hasException()) {
auto [handle, frame] =
promise.continuation_.getErrorHandle(promise.result_.exception());
promise.continuationRef(privateTag())
.getErrorHandle(promise.result().exception());
return handle.getHandle();
}
return promise.continuation_.getHandle();
return promise.continuationRef(privateTag()).getHandle();
}

[[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); }
Expand Down Expand Up @@ -167,6 +180,14 @@ class TaskPromiseBase {
return executor_;
}

// These getters exist so that `FinalAwaiter` can interact with wrapped
// `TaskPromise`s, and not just `TaskPromiseBase` descendants. We use a
// private tag to let `TaskWrapper` call them without becoming a `friend`.
auto& scopeExitRef(TaskPromisePrivate) { return scopeExit_; }
auto& continuationRef(TaskPromisePrivate) { return continuation_; }
// Unlike `getExecutor()`, does not copy an atomic.
auto& executorRef(TaskPromisePrivate) { return executor_; }

private:
template <typename>
friend class folly::coro::TaskWithExecutor;
Expand Down
214 changes: 214 additions & 0 deletions folly/coro/TaskWrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <folly/coro/Task.h>

/// The header provides base classes for wrapping `folly::coro::Task` with
/// custom functionality. These work by composition, which avoids the
/// pitfalls of inheritance -- your custom wrapper will not be "is-a-Task",
/// and will not implicitly "object slice" to a `Task`.
///
/// Keep in mind that some destructive APIs, like `.semi()`, effectively
/// unwrap the `Task`. If this is important for your use-case, you may need
/// to add features (e.g. `TaskWithExecutorWrapper`, on-unwrap callbacks).
///
/// The point of this header is to uniformly forward the large API surface
/// of `TaskPromise` & `Task`, leaving just the "new logic" in each wrapper.
/// As `Task.h` evolves, a central `TaskWrapper.h` is easier to maintain.
///
/// You'll derive from `TaskWrapperPromise` -- which must reference a
/// derived class of `TaskWrapperCrtp` that is your new user-facing coro.
///
/// To discourage inheritance and object-slicing bugs, mark your derived
/// wrappers `final` -- they can be wrapped recursively.
///
/// Read `TaskWrapperTest.cpp` for examples of a minimal & recursive wrapper.
///
/// Future: Once this has a benchmark, see if `FOLLY_ALWAYS_INLINE` makes
/// any difference on the wrapped functions (it shouldn't).

namespace folly::coro {

namespace detail {
template <typename, typename, typename>
class TaskPromiseWrapperBase;
} // namespace detail

template <typename, typename, typename>
class TaskWrapperCrtp;

namespace detail {

template <typename Wrapper>
using task_wrapper_underlying_semiawaitable_t =
typename Wrapper::TaskWrapperUnderlyingSemiAwaitable;

template <typename SemiAwaitable, typename T>
inline constexpr bool is_task_or_wrapper_v =
(!std::is_same_v<nonesuch, SemiAwaitable> && // Does not wrap Task
(std::is_same_v<SemiAwaitable, Task<T>> || // Wraps Task
is_task_or_wrapper_v<
detected_t<task_wrapper_underlying_semiawaitable_t, SemiAwaitable>,
T>));

template <typename Wrapper>
using task_wrapper_underlying_promise_t =
typename Wrapper::TaskWrapperUnderlyingPromise;

template <typename Promise, typename T>
inline constexpr bool is_task_promise_or_wrapper_v =
(!std::is_same_v<nonesuch, Promise> && // Does not wrap TaskPromise
(std::is_same_v<Promise, TaskPromise<T>> || // Wraps TaskPromise
is_task_promise_or_wrapper_v<
detected_t<task_wrapper_underlying_promise_t, Promise>,
T>));

template <typename T, typename WrappedSemiAwaitable, typename Promise>
class TaskPromiseWrapperBase {
protected:
static_assert(
is_task_or_wrapper_v<WrappedSemiAwaitable, T>,
"SemiAwaitable must be a sequence of wrappers ending in Task<T>");
static_assert(
is_task_promise_or_wrapper_v<Promise, T>,
"Promise must be a sequence of wrappers ending in TaskPromise<T>");

Promise promise_;

TaskPromiseWrapperBase() noexcept = default;
~TaskPromiseWrapperBase() = default;

public:
using TaskWrapperUnderlyingPromise = Promise;

WrappedSemiAwaitable get_return_object() noexcept {
return WrappedSemiAwaitable{promise_.get_return_object()};
}

static void* operator new(std::size_t size) {
return ::folly_coro_async_malloc(size);
}
static void operator delete(void* ptr, std::size_t size) {
::folly_coro_async_free(ptr, size);
}

auto initial_suspend() noexcept { return promise_.initial_suspend(); }
auto final_suspend() noexcept { return promise_.final_suspend(); }

auto await_transform(auto&& what) {
return promise_.await_transform(std::forward<decltype(what)>(what));
}

auto yield_value(auto&& v)
requires requires { promise_.yield_value(std::forward<decltype(v)>(v)); }
{
return promise_.yield_value(std::forward<decltype(v)>(v));
}

void unhandled_exception() noexcept { promise_.unhandled_exception(); }

// These getters are all interposed for `TaskPromiseBase::FinalAwaiter`
decltype(auto) result() { return promise_.result(); }
decltype(auto) getAsyncFrame() { return promise_.getAsyncFrame(); }
auto& scopeExitRef(TaskPromisePrivate tag) {
return promise_.scopeExitRef(tag);
}
auto& continuationRef(TaskPromisePrivate tag) {
return promise_.continuationRef(tag);
}
auto& executorRef(TaskPromisePrivate tag) {
return promise_.executorRef(tag);
}
};

template <typename T, typename SemiAwaitable, typename Promise>
class TaskPromiseWrapper
: public TaskPromiseWrapperBase<T, SemiAwaitable, Promise> {
protected:
TaskPromiseWrapper() noexcept = default;
~TaskPromiseWrapper() = default;

public:
template <typename U = T> // see `returnImplicitCtor` test
auto return_value(U&& value) {
return this->promise_.return_value(std::forward<U>(value));
}
};

template <typename SemiAwaitable, typename Promise>
class TaskPromiseWrapper<void, SemiAwaitable, Promise>
: public TaskPromiseWrapperBase<void, SemiAwaitable, Promise> {
protected:
TaskPromiseWrapper() noexcept = default;
~TaskPromiseWrapper() = default;

public:
void return_void() noexcept { this->promise_.return_void(); }
};

} // namespace detail

template <typename Derived, typename T, typename SemiAwaitable>
class TaskWrapperCrtp {
private:
static_assert(
detail::is_task_or_wrapper_v<SemiAwaitable, T>,
"TaskWrapperCrtp must wrap a sequence of wrappers ending in Task<T>");

SemiAwaitable task_;

protected:
template <typename, typename, typename>
friend class ::folly::coro::detail::TaskPromiseWrapperBase;

explicit TaskWrapperCrtp(SemiAwaitable t) : task_(std::move(t)) {}

SemiAwaitable unwrap() && { return std::move(task_); }

public:
using TaskWrapperUnderlyingSemiAwaitable = SemiAwaitable;

// NB: In the future, this might ALSO produce a wrapped object.
FOLLY_NODISCARD
TaskWithExecutor<T> scheduleOn(Executor::KeepAlive<> executor) && noexcept {
return std::move(task_).scheduleOn(std::move(executor));
}

FOLLY_NOINLINE auto semi() && { return std::move(task_).semi(); }

friend Derived co_withCancellation(
folly::CancellationToken cancelToken, Derived&& tw) noexcept {
return Derived{
co_withCancellation(std::move(cancelToken), std::move(tw.task_))};
}

friend auto co_viaIfAsync(
folly::Executor::KeepAlive<> executor, Derived&& tw) noexcept {
return co_viaIfAsync(std::move(executor), std::move(tw.task_));
}
// At least in Clang 15, the `static_assert` isn't enough to get a usable
// error message (it is instantiated too late), but the deprecation
// warning does show up.
[[deprecated(
"Error: Use `co_await std::move(lvalue)`, not `co_await lvalue`.")]]
friend Derived co_viaIfAsync(folly::Executor::KeepAlive<>, const Derived&) {
static_assert("Use `co_await std::move(lvalue)`, not `co_await lvalue`.");
}
};

} // namespace folly::coro
11 changes: 11 additions & 0 deletions folly/coro/test/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,17 @@ cpp_benchmark(
],
)

cpp_unittest(
name = "task_wrapper_test",
srcs = ["TaskWrapperTest.cpp"],
deps = [
"//folly/coro:gtest_helpers",
"//folly/coro:task_wrapper",
"//folly/coro:timeout",
"//folly/fibers:semaphore",
],
)

cpp_unittest(
name = "RustAdaptorsTest",
srcs = [
Expand Down
Loading

0 comments on commit df4cfee

Please sign in to comment.