Skip to content

Commit

Permalink
新的 awaitable<T> 模板特化方法
Browse files Browse the repository at this point in the history
awaitable<T> 和 awaitable<void> 行为逻辑略有不同。

最初使用的是特化 awaitable_promise_type<void> 的方式

现在改为不特化 awaitable, 也不特化 awaitable_promise_type,
而是将针对 T 不同的部分,作为 awaitable_promise_value<T>
然后特化 awaitable_promise_value<void> 即可。

这样需要被特化的代码量有所减少。
  • Loading branch information
microcai committed Oct 14, 2024
1 parent 3f107d2 commit 33960d6
Showing 1 changed file with 71 additions and 99 deletions.
170 changes: 71 additions & 99 deletions include/ucoro/awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ namespace ucoro
{ a.await_suspend(std::coroutine_handle<>{}) } -> is_valid_await_suspend_return_value;
{ a.await_resume() };
};
template <typename T>
concept is_awaitable_promise_type_v = requires (T a){
{ a.local_ } -> std::convertible_to<std::any> ;
};
} // namespace detail

struct debug_coro_promise
Expand Down Expand Up @@ -160,90 +164,46 @@ namespace ucoro
};

//////////////////////////////////////////////////////////////////////////
//

struct awaitable_promise_base
: public debug_coro_promise
// 存储协程 promise 的返回值
template<typename T>
struct awaitable_promise_value
{
auto initial_suspend()
{
return std::suspend_always{};
}

template <typename A>
requires(detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
template <typename V>
void return_value(V &&val) noexcept
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
value_.template emplace<T>(std::forward<V>(val));
}

template <typename A>
requires(!detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
void unhandled_exception() noexcept
{
return await_transformer<A>::await_transform(std::move(awaiter));
value_.template emplace<std::exception_ptr>(std::current_exception());
}

auto await_transform(local_storage_t<void>) noexcept
{
struct result
{
awaitable_promise_base *this_;

constexpr bool await_ready() const noexcept
{
return true;
}

void await_suspend(std::coroutine_handle<void>) noexcept
{
}
std::variant<std::exception_ptr, T> value_{ nullptr };
};

auto await_resume() const noexcept
{
return *this_->local_;
}
};
//////////////////////////////////////////////////////////////////////////
// 存储协程 promise 的返回值 void 的特化实现
template<>
struct awaitable_promise_value<void>
{
std::exception_ptr exception_{ nullptr };

return result{this};
}
constexpr void return_void() noexcept { }

template <typename T>
auto await_transform(local_storage_t<T>)
void unhandled_exception() noexcept
{
struct result
{
awaitable_promise_base *this_;

constexpr bool await_ready() const noexcept
{
return true;
}

void await_suspend(std::coroutine_handle<void>) noexcept
{
}

auto await_resume()
{
return std::any_cast<T>(*this_->local_);
}
};

return result{this};
exception_ = std::current_exception();
}

std::coroutine_handle<> continuation_;
std::shared_ptr<std::any> local_;
};

//////////////////////////////////////////////////////////////////////////
// 返回 T 的协程 awaitable_promise 实现.

// Promise 类型实现...
template <typename T>
struct awaitable_promise : public awaitable_promise_base
struct awaitable_promise : public awaitable_promise_value<T>, public debug_coro_promise
{
awaitable<T> get_return_object();

Expand All @@ -252,43 +212,60 @@ namespace ucoro
return final_awaitable<T>{};
}

template <typename V>
void return_value(V &&val) noexcept
auto initial_suspend()
{
value_.template emplace<T>(std::forward<V>(val));
return std::suspend_always{};
}

void unhandled_exception() noexcept
template <typename A>
requires (detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
value_.template emplace<std::exception_ptr>(std::current_exception());
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
}

std::variant<std::exception_ptr, T> value_{ nullptr };
};

//////////////////////////////////////////////////////////////////////////
// 返回 void 的协程偏特化 awaitable_promise 实现

template <>
struct awaitable_promise<void> : public awaitable_promise_base
{
awaitable<void> get_return_object();

auto final_suspend() noexcept
template <typename A>
requires (!detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
return final_awaitable<void>{};
return await_transformer<A>::await_transform(std::move(awaiter));
}

void return_void()
void set_local(std::any local)
{
local_ = std::make_shared<std::any>(local);
}

void unhandled_exception() noexcept
template <typename localtype>
struct local_storage_awaiter
{
exception_ = std::current_exception();
awaitable_promise *this_;

constexpr bool await_ready() const noexcept { return true; }
void await_suspend(std::coroutine_handle<void>) noexcept {}

auto await_resume() const noexcept
{
if constexpr (std::is_void_v<localtype>)
{
return *this_->local_;
}
else
{
return std::any_cast<localtype>(*this_->local_);
}
}
};

template <typename localtype>
auto await_transform(local_storage_t<localtype>)
{
return local_storage_awaiter<localtype>{this};
}

std::exception_ptr exception_{ nullptr };
std::coroutine_handle<> continuation_;
std::shared_ptr<std::any> local_;
};

//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -381,23 +358,24 @@ namespace ucoro
}
}

template <typename PromiseType>
template <typename PromiseType> requires (detail::is_awaitable_promise_type_v<PromiseType>)
auto await_suspend(std::coroutine_handle<PromiseType> continuation)
{
current_coro_handle_.promise().continuation_ = continuation;
current_coro_handle_.promise().local_ = continuation.promise().local_;

if constexpr (std::is_base_of_v<awaitable_promise_base, PromiseType>)
{
current_coro_handle_.promise().local_ = continuation.promise().local_;
}
return await_suspend(static_cast<std::coroutine_handle<void>>(continuation));
}

auto await_suspend(std::coroutine_handle<void> continuation)
{
current_coro_handle_.promise().continuation_ = continuation;
return current_coro_handle_;
}

void set_local(std::any local)
{
assert("local has value" && !current_coro_handle_.promise().local_);
current_coro_handle_.promise().local_ = std::make_shared<std::any>(local);
current_coro_handle_.promise().set_local(local);
}

void detach()
Expand Down Expand Up @@ -428,12 +406,6 @@ namespace ucoro
auto result = awaitable<T>{std::coroutine_handle<awaitable_promise<T>>::from_promise(*this)};
return result;
}

awaitable<void> awaitable_promise<void>::get_return_object()
{
auto result = awaitable<void>{std::coroutine_handle<awaitable_promise<void>>::from_promise(*this)};
return result;
}
} // namespace ucoro

//////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 33960d6

Please sign in to comment.