diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 132f411..bb9325b 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -86,6 +86,10 @@ namespace ucoro { a.await_suspend(std::coroutine_handle<>{}) } -> is_valid_await_suspend_return_value; { a.await_resume() }; }; + template + concept is_awaitable_promise_type_v = requires (T a){ + { a.local_ } -> std::convertible_to ; + }; } // namespace detail struct debug_coro_promise @@ -160,82 +164,38 @@ namespace ucoro }; ////////////////////////////////////////////////////////////////////////// - // - - struct awaitable_promise_base - : public debug_coro_promise + // 存储协程 promise 的返回值 + template + struct awaitable_promise_value { - auto initial_suspend() - { - return std::suspend_always{}; - } - - template - requires(detail::is_awaiter_v>) - auto await_transform(A&& awaiter) const + template + void return_value(V &&val) noexcept { - static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); - return std::forward(awaiter); + value_.template emplace(std::forward(val)); } - template - requires(!detail::is_awaiter_v>) - auto await_transform(A&& awaiter) const + void unhandled_exception() noexcept { - return await_transformer::await_transform(std::move(awaiter)); + value_.template emplace(std::current_exception()); } - auto await_transform(local_storage_t) noexcept - { - struct result - { - awaitable_promise_base *this_; - - constexpr bool await_ready() const noexcept - { - return true; - } - - void await_suspend(std::coroutine_handle) noexcept - { - } + std::variant value_{ nullptr }; + }; - auto await_resume() const noexcept - { - return *this_->local_; - } - }; + ////////////////////////////////////////////////////////////////////////// + // 存储协程 promise 的返回值 void 的特化实现 + template<> + struct awaitable_promise_value + { + std::exception_ptr exception_{ nullptr }; - return result{this}; - } + constexpr void return_void() noexcept { } - template - auto await_transform(local_storage_t) + void unhandled_exception() noexcept { - struct result - { - awaitable_promise_base *this_; - - constexpr bool await_ready() const noexcept - { - return true; - } - - void await_suspend(std::coroutine_handle) noexcept - { - } - - auto await_resume() - { - return std::any_cast(*this_->local_); - } - }; - - return result{this}; + exception_ = std::current_exception(); } - std::coroutine_handle<> continuation_; - std::shared_ptr local_; }; ////////////////////////////////////////////////////////////////////////// @@ -243,7 +203,7 @@ namespace ucoro // Promise 类型实现... template - struct awaitable_promise : public awaitable_promise_base + struct awaitable_promise : public awaitable_promise_value, public debug_coro_promise { awaitable get_return_object(); @@ -252,43 +212,60 @@ namespace ucoro return final_awaitable{}; } - template - void return_value(V &&val) noexcept + auto initial_suspend() { - value_.template emplace(std::forward(val)); + return std::suspend_always{}; } - void unhandled_exception() noexcept + template + requires (detail::is_awaiter_v>) + auto await_transform(A&& awaiter) const { - value_.template emplace(std::current_exception()); + static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); + return std::forward(awaiter); } - std::variant value_{ nullptr }; - }; - - ////////////////////////////////////////////////////////////////////////// - // 返回 void 的协程偏特化 awaitable_promise 实现 - - template <> - struct awaitable_promise : public awaitable_promise_base - { - awaitable get_return_object(); - - auto final_suspend() noexcept + template + requires (!detail::is_awaiter_v>) + auto await_transform(A&& awaiter) const { - return final_awaitable{}; + return await_transformer::await_transform(std::move(awaiter)); } - void return_void() + void set_local(std::any local) { + local_ = std::make_shared(local); } - void unhandled_exception() noexcept + template + struct local_storage_awaiter { - exception_ = std::current_exception(); + awaitable_promise *this_; + + constexpr bool await_ready() const noexcept { return true; } + void await_suspend(std::coroutine_handle) noexcept {} + + auto await_resume() const noexcept + { + if constexpr (std::is_void_v) + { + return *this_->local_; + } + else + { + return std::any_cast(*this_->local_); + } + } + }; + + template + auto await_transform(local_storage_t) + { + return local_storage_awaiter{this}; } - std::exception_ptr exception_{ nullptr }; + std::coroutine_handle<> continuation_; + std::shared_ptr local_; }; ////////////////////////////////////////////////////////////////////////// @@ -384,20 +361,25 @@ namespace ucoro template auto await_suspend(std::coroutine_handle continuation) { - current_coro_handle_.promise().continuation_ = continuation; - - if constexpr (std::is_base_of_v) + if constexpr (detail::is_awaitable_promise_type_v) { current_coro_handle_.promise().local_ = continuation.promise().local_; } + current_coro_handle_.promise().continuation_ = continuation; + return current_coro_handle_; + } + + auto await_suspend(std::coroutine_handle continuation) + { + current_coro_handle_.promise().continuation_ = continuation; return current_coro_handle_; } void set_local(std::any local) { assert("local has value" && !current_coro_handle_.promise().local_); - current_coro_handle_.promise().local_ = std::make_shared(local); + current_coro_handle_.promise().set_local(local); } void detach() @@ -428,12 +410,6 @@ namespace ucoro auto result = awaitable{std::coroutine_handle>::from_promise(*this)}; return result; } - - awaitable awaitable_promise::get_return_object() - { - auto result = awaitable{std::coroutine_handle>::from_promise(*this)}; - return result; - } } // namespace ucoro //////////////////////////////////////////////////////////////////////////