From 4575f8ce6d8daf50153245e20a75e1d01739d01f Mon Sep 17 00:00:00 2001 From: microcai Date: Mon, 14 Oct 2024 20:57:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E7=9A=84=20awaitable=20=E6=A8=A1?= =?UTF-8?q?=E6=9D=BF=E7=89=B9=E5=8C=96=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit awaitable 和 awaitable 行为逻辑略有不同。 最初使用的是特化 awaitable_promise_type 的方式 现在改为不特化 awaitable, 也不特化 awaitable_promise_type, 而是将针对 T 不同的部分,作为 awaitable_promise_value 然后特化 awaitable_promise_value 即可。 这样需要被特化的代码量有所减少。 --- include/ucoro/awaitable.hpp | 164 +++++++++++++++--------------------- 1 file changed, 68 insertions(+), 96 deletions(-) diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 132f411..b19dab9 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -86,6 +86,12 @@ namespace ucoro { a.await_suspend(std::coroutine_handle<>{}) } -> is_valid_await_suspend_return_value; { a.await_resume() }; }; + + // 用于判定 T 是否是一个 awaitable<>::promise_type 的类型, 即: 拥有 local_ 成员。 + template + concept is_awaitable_promise_type_v = requires (T a){ + { a.local_ } -> std::convertible_to> ; + }; } // namespace detail struct debug_coro_promise @@ -160,82 +166,38 @@ namespace ucoro }; ////////////////////////////////////////////////////////////////////////// - // - - struct awaitable_promise_base - : public debug_coro_promise + // 存储协程 promise 的返回值 + template + struct awaitable_promise_value { - auto initial_suspend() - { - return std::suspend_always{}; - } - - template - requires(detail::is_awaiter_v>) - auto await_transform(A&& awaiter) const + template + void return_value(V &&val) noexcept { - static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); - return std::forward(awaiter); + value_.template emplace(std::forward(val)); } - template - requires(!detail::is_awaiter_v>) - auto await_transform(A&& awaiter) const + void unhandled_exception() noexcept { - return await_transformer::await_transform(std::move(awaiter)); + value_.template emplace(std::current_exception()); } - auto await_transform(local_storage_t) noexcept - { - struct result - { - awaitable_promise_base *this_; - - constexpr bool await_ready() const noexcept - { - return true; - } - - void await_suspend(std::coroutine_handle) noexcept - { - } + std::variant value_{ nullptr }; + }; - auto await_resume() const noexcept - { - return *this_->local_; - } - }; + ////////////////////////////////////////////////////////////////////////// + // 存储协程 promise 的返回值 void 的特化实现 + template<> + struct awaitable_promise_value + { + std::exception_ptr exception_{ nullptr }; - return result{this}; - } + constexpr void return_void() noexcept { } - template - auto await_transform(local_storage_t) + void unhandled_exception() noexcept { - struct result - { - awaitable_promise_base *this_; - - constexpr bool await_ready() const noexcept - { - return true; - } - - void await_suspend(std::coroutine_handle) noexcept - { - } - - auto await_resume() - { - return std::any_cast(*this_->local_); - } - }; - - return result{this}; + exception_ = std::current_exception(); } - std::coroutine_handle<> continuation_; - std::shared_ptr local_; }; ////////////////////////////////////////////////////////////////////////// @@ -243,7 +205,7 @@ namespace ucoro // Promise 类型实现... template - struct awaitable_promise : public awaitable_promise_base + struct awaitable_promise : public awaitable_promise_value, public debug_coro_promise { awaitable get_return_object(); @@ -252,43 +214,60 @@ namespace ucoro return final_awaitable{}; } - template - void return_value(V &&val) noexcept + auto initial_suspend() { - value_.template emplace(std::forward(val)); + return std::suspend_always{}; } - void unhandled_exception() noexcept + template + requires (detail::is_awaiter_v>) + auto await_transform(A&& awaiter) const { - value_.template emplace(std::current_exception()); + static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); + return std::forward(awaiter); } - std::variant value_{ nullptr }; - }; - - ////////////////////////////////////////////////////////////////////////// - // 返回 void 的协程偏特化 awaitable_promise 实现 - - template <> - struct awaitable_promise : public awaitable_promise_base - { - awaitable get_return_object(); - - auto final_suspend() noexcept + template + requires (!detail::is_awaiter_v>) + auto await_transform(A&& awaiter) const { - return final_awaitable{}; + return await_transformer::await_transform(std::move(awaiter)); } - void return_void() + void set_local(std::any local) { + local_ = std::make_shared(local); } - void unhandled_exception() noexcept + template + struct local_storage_awaiter { - exception_ = std::current_exception(); + awaitable_promise *this_; + + constexpr bool await_ready() const noexcept { return true; } + void await_suspend(std::coroutine_handle) noexcept {} + + auto await_resume() const noexcept + { + if constexpr (std::is_void_v) + { + return *this_->local_; + } + else + { + return std::any_cast(*this_->local_); + } + } + }; + + template + auto await_transform(local_storage_t) + { + return local_storage_awaiter{this}; } - std::exception_ptr exception_{ nullptr }; + std::coroutine_handle<> continuation_; + std::shared_ptr local_; }; ////////////////////////////////////////////////////////////////////////// @@ -384,20 +363,19 @@ namespace ucoro template auto await_suspend(std::coroutine_handle continuation) { - current_coro_handle_.promise().continuation_ = continuation; - - if constexpr (std::is_base_of_v) + if constexpr (detail::is_awaitable_promise_type_v) { current_coro_handle_.promise().local_ = continuation.promise().local_; } + current_coro_handle_.promise().continuation_ = continuation; return current_coro_handle_; } void set_local(std::any local) { assert("local has value" && !current_coro_handle_.promise().local_); - current_coro_handle_.promise().local_ = std::make_shared(local); + current_coro_handle_.promise().set_local(local); } void detach() @@ -428,12 +406,6 @@ namespace ucoro auto result = awaitable{std::coroutine_handle>::from_promise(*this)}; return result; } - - awaitable awaitable_promise::get_return_object() - { - auto result = awaitable{std::coroutine_handle>::from_promise(*this)}; - return result; - } } // namespace ucoro //////////////////////////////////////////////////////////////////////////