Skip to content

Commit

Permalink
新的 awaitable<T> 模板特化方法
Browse files Browse the repository at this point in the history
awaitable<T> 和 awaitable<void> 行为逻辑略有不同。

最初使用的是特化 awaitable_promise_type<void> 的方式

现在改为不特化 awaitable, 也不特化 awaitable_promise_type,
而是将针对 T 不同的部分,作为 awaitable_promise_value<T>
然后特化 awaitable_promise_value<void> 即可。

这样需要被特化的代码量有所减少。
  • Loading branch information
microcai committed Oct 14, 2024
1 parent 3f107d2 commit f58ddc6
Showing 1 changed file with 70 additions and 93 deletions.
163 changes: 70 additions & 93 deletions include/ucoro/awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,135 +160,117 @@ namespace ucoro
};

//////////////////////////////////////////////////////////////////////////
//

struct awaitable_promise_base
: public debug_coro_promise
// 存储协程 promise 的返回值
template<typename T>
struct awaitable_promise_value
{
auto initial_suspend()
template <typename V>
void return_value(V &&val) noexcept
{
return std::suspend_always{};
value_.template emplace<T>(std::forward<V>(val));
}

template <typename A>
requires(detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
void unhandled_exception() noexcept
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
value_.template emplace<std::exception_ptr>(std::current_exception());
}

template <typename A>
requires(!detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
return await_transformer<A>::await_transform(std::move(awaiter));
}
std::variant<std::exception_ptr, T> value_{ nullptr };
};

auto await_transform(local_storage_t<void>) noexcept
{
struct result
{
awaitable_promise_base *this_;
//////////////////////////////////////////////////////////////////////////
// 存储协程 promise 的返回值 void 的特化实现
template<>
struct awaitable_promise_value<void>
{
std::exception_ptr exception_{ nullptr };

constexpr bool await_ready() const noexcept
{
return true;
}
constexpr void return_void() noexcept { }

void await_suspend(std::coroutine_handle<void>) noexcept
{
}
void unhandled_exception() noexcept
{
exception_ = std::current_exception();
}

auto await_resume() const noexcept
{
return *this_->local_;
}
};
};

struct awaitable_promise_local_storage
{
std::shared_ptr<std::any> local_;

return result{this};
void set_local(std::any local)
{
local_ = std::make_shared<std::any>(local);
}

template <typename T>
auto await_transform(local_storage_t<T>)
struct local_storage_awaiter
{
struct result
{
awaitable_promise_base *this_;
awaitable_promise_local_storage *this_;

constexpr bool await_ready() const noexcept
{
return true;
}
constexpr bool await_ready() const noexcept { return true; }
void await_suspend(std::coroutine_handle<void>) noexcept {}

void await_suspend(std::coroutine_handle<void>) noexcept
auto await_resume() const noexcept
{
if constexpr (std::is_void_v<T>)
{
return *this_->local_;
}

auto await_resume()
else
{
return std::any_cast<T>(*this_->local_);
}
};
}
};

return result{this};
template <typename T>
auto await_transform(local_storage_t<T>)
{
return local_storage_awaiter<T>{this};
}

std::coroutine_handle<> continuation_;
std::shared_ptr<std::any> local_;
};

//////////////////////////////////////////////////////////////////////////
// 返回 T 的协程 awaitable_promise 实现.

// Promise 类型实现...
template <typename T>
struct awaitable_promise : public awaitable_promise_base
struct awaitable_promise : public awaitable_promise_value<T>, public awaitable_promise_local_storage, public debug_coro_promise
{
~awaitable_promise(){}

using awaitable_promise_local_storage::await_transform;

awaitable<T> get_return_object();

auto final_suspend() noexcept
{
return final_awaitable<T>{};
}

template <typename V>
void return_value(V &&val) noexcept
{
value_.template emplace<T>(std::forward<V>(val));
}

void unhandled_exception() noexcept
{
value_.template emplace<std::exception_ptr>(std::current_exception());
}

std::variant<std::exception_ptr, T> value_{ nullptr };
};

//////////////////////////////////////////////////////////////////////////
// 返回 void 的协程偏特化 awaitable_promise 实现

template <>
struct awaitable_promise<void> : public awaitable_promise_base
{
awaitable<void> get_return_object();

auto final_suspend() noexcept
auto initial_suspend()
{
return final_awaitable<void>{};
return std::suspend_always{};
}

void return_void()
template <typename A>
requires(detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
}

void unhandled_exception() noexcept
template <typename A>
requires(!detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
exception_ = std::current_exception();
return await_transformer<A>::await_transform(std::move(awaiter));
}

std::exception_ptr exception_{ nullptr };
std::coroutine_handle<> continuation_;
};

//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -381,23 +363,24 @@ namespace ucoro
}
}

template <typename PromiseType>
auto await_suspend(std::coroutine_handle<PromiseType> continuation)
template <typename OT>
auto await_suspend(std::coroutine_handle<awaitable_promise<OT>> continuation)
{
current_coro_handle_.promise().continuation_ = continuation;
current_coro_handle_.promise().local_ = continuation.promise().local_;

if constexpr (std::is_base_of_v<awaitable_promise_base, PromiseType>)
{
current_coro_handle_.promise().local_ = continuation.promise().local_;
}
return await_suspend(static_cast<std::coroutine_handle<void>>(continuation));
}

auto await_suspend(std::coroutine_handle<void> continuation)
{
current_coro_handle_.promise().continuation_ = continuation;
return current_coro_handle_;
}

void set_local(std::any local)
{
assert("local has value" && !current_coro_handle_.promise().local_);
current_coro_handle_.promise().local_ = std::make_shared<std::any>(local);
current_coro_handle_.promise().set_local(local);
}

void detach()
Expand Down Expand Up @@ -428,12 +411,6 @@ namespace ucoro
auto result = awaitable<T>{std::coroutine_handle<awaitable_promise<T>>::from_promise(*this)};
return result;
}

awaitable<void> awaitable_promise<void>::get_return_object()
{
auto result = awaitable<void>{std::coroutine_handle<awaitable_promise<void>>::from_promise(*this)};
return result;
}
} // namespace ucoro

//////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit f58ddc6

Please sign in to comment.