diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 132f411..acd9e9d 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -160,82 +160,76 @@ namespace ucoro }; ////////////////////////////////////////////////////////////////////////// - // - - struct awaitable_promise_base - : public debug_coro_promise + // 存储协程 promise 的返回值 + template + struct awaitable_promise_value { - auto initial_suspend() + template + void return_value(V &&val) noexcept { - return std::suspend_always{}; + value_.template emplace(std::forward(val)); } - template - requires(detail::is_awaiter_v>) - auto await_transform(A&& awaiter) const + void unhandled_exception() noexcept { - static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); - return std::forward(awaiter); + value_.template emplace(std::current_exception()); } - template - requires(!detail::is_awaiter_v>) - auto await_transform(A&& awaiter) const - { - return await_transformer::await_transform(std::move(awaiter)); - } + std::variant value_{ nullptr }; + }; - auto await_transform(local_storage_t) noexcept - { - struct result - { - awaitable_promise_base *this_; + ////////////////////////////////////////////////////////////////////////// + // 存储协程 promise 的返回值 void 的特化实现 + template<> + struct awaitable_promise_value + { + std::exception_ptr exception_{ nullptr }; - constexpr bool await_ready() const noexcept - { - return true; - } + constexpr void return_void() noexcept { } - void await_suspend(std::coroutine_handle) noexcept - { - } + void unhandled_exception() noexcept + { + exception_ = std::current_exception(); + } - auto await_resume() const noexcept - { - return *this_->local_; - } - }; + }; - return result{this}; + struct awaitable_promise_local_storage + { + std::shared_ptr local_; + + void set_local(std::any local) + { + local_ = std::make_shared(local); } template - auto await_transform(local_storage_t) + struct local_storage_awaiter { - struct result - { - awaitable_promise_base *this_; + awaitable_promise_local_storage *this_; - constexpr bool await_ready() const noexcept - { - return true; - } + constexpr bool await_ready() const noexcept { return true; } + void await_suspend(std::coroutine_handle) noexcept {} - void await_suspend(std::coroutine_handle) noexcept + auto await_resume() const noexcept + { + if constexpr (std::is_void_v) { + return *this_->local_; } - - auto await_resume() + else { return std::any_cast(*this_->local_); } - }; + } + }; - return result{this}; + template + auto await_transform(local_storage_t) + { + return local_storage_awaiter{this}; } - std::coroutine_handle<> continuation_; - std::shared_ptr local_; }; ////////////////////////////////////////////////////////////////////////// @@ -243,8 +237,10 @@ namespace ucoro // Promise 类型实现... template - struct awaitable_promise : public awaitable_promise_base + struct awaitable_promise : public awaitable_promise_value, public awaitable_promise_local_storage, public debug_coro_promise { + using awaitable_promise_local_storage::await_transform; + awaitable get_return_object(); auto final_suspend() noexcept @@ -252,43 +248,27 @@ namespace ucoro return final_awaitable{}; } - template - void return_value(V &&val) noexcept - { - value_.template emplace(std::forward(val)); - } - - void unhandled_exception() noexcept - { - value_.template emplace(std::current_exception()); - } - - std::variant value_{ nullptr }; - }; - - ////////////////////////////////////////////////////////////////////////// - // 返回 void 的协程偏特化 awaitable_promise 实现 - - template <> - struct awaitable_promise : public awaitable_promise_base - { - awaitable get_return_object(); - - auto final_suspend() noexcept + auto initial_suspend() { - return final_awaitable{}; + return std::suspend_always{}; } - void return_void() + template + requires(detail::is_awaiter_v>) + auto await_transform(A&& awaiter) const { + static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); + return std::forward(awaiter); } - void unhandled_exception() noexcept + template + requires(!detail::is_awaiter_v>) + auto await_transform(A&& awaiter) const { - exception_ = std::current_exception(); + return await_transformer::await_transform(std::move(awaiter)); } - std::exception_ptr exception_{ nullptr }; + std::coroutine_handle<> continuation_; }; ////////////////////////////////////////////////////////////////////////// @@ -381,23 +361,24 @@ namespace ucoro } } - template - auto await_suspend(std::coroutine_handle continuation) + template + auto await_suspend(std::coroutine_handle> continuation) { - current_coro_handle_.promise().continuation_ = continuation; + current_coro_handle_.promise().local_ = continuation.promise().local_; - if constexpr (std::is_base_of_v) - { - current_coro_handle_.promise().local_ = continuation.promise().local_; - } + return await_suspend(static_cast>(continuation)); + } + auto await_suspend(std::coroutine_handle continuation) + { + current_coro_handle_.promise().continuation_ = continuation; return current_coro_handle_; } void set_local(std::any local) { assert("local has value" && !current_coro_handle_.promise().local_); - current_coro_handle_.promise().local_ = std::make_shared(local); + current_coro_handle_.promise().set_local(local); } void detach() @@ -428,12 +409,6 @@ namespace ucoro auto result = awaitable{std::coroutine_handle>::from_promise(*this)}; return result; } - - awaitable awaitable_promise::get_return_object() - { - auto result = awaitable{std::coroutine_handle>::from_promise(*this)}; - return result; - } } // namespace ucoro //////////////////////////////////////////////////////////////////////////