diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index cd29d39..6a77176 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -73,18 +73,33 @@ namespace ucoro template struct local_storage_t { - typedef T value_type; - typedef void local_storage_type_detect_tag; }; inline constexpr local_storage_t local_storage; - template - concept is_a_local_storage_t = std::is_same_v; - ////////////////////////////////////////////////////////////////////////// namespace detail { + + // 用于判定 T 是否是一个 U 的类型 + // 比如 + // is_instance_of_v,std::vector>; // true + // is_instance_of_v,std::list>; // false + template class U> + inline constexpr bool is_instance_of_v = std::false_type{}; + + template class U, class... Vs> + inline constexpr bool is_instance_of_v,U> = std::true_type{}; + + template + struct local_storage_value_type; + + template + struct local_storage_value_type> + { + typedef ValueType value_type; + }; + template concept is_valid_await_suspend_return_value = std::convertible_to> || std::is_void_v || std::is_same_v; @@ -106,11 +121,7 @@ namespace ucoro template concept is_awaitable_v = is_awaiter_v || has_operator_co_await; - // 用于判定 T 是否是一个 awaitable<>::promise_type 的类型, 即: 拥有 local_ 成员。 - template - concept is_awaitable_promise_type_v = requires (T a) { - { a.local_ } -> std::convertible_to>; - }; + } // namespace detail struct debug_coro_promise @@ -266,16 +277,24 @@ namespace ucoro template auto await_transform(A&& awaiter) const { - if constexpr ( is_a_local_storage_t> ) + if constexpr (detail::is_instance_of_v, local_storage_t>) { - return local_storage_awaiter::value_type>{this}; + // 类型 A 是 local_storage_t<> 的一种 + return local_storage_awaiter>::value_type>{this}; + } + else if constexpr (detail::is_instance_of_v, awaitable>) + { + // 类型 A 是 awaitable<> 的一种 + static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); + return std::forward(awaiter); } - else if constexpr ( detail::is_awaitable_v> ) + else if constexpr (detail::is_awaitable_v>) { + // 类型 A 有 三件套 static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); return std::forward(awaiter); } - else if constexpr ( requires (A ) { await_transformer::await_transform; }) + else if constexpr (requires (A) { await_transformer::await_transform; }) { return await_transformer::await_transform(std::move(awaiter)); } @@ -397,7 +416,8 @@ namespace ucoro template auto await_suspend(std::coroutine_handle continuation) { - if constexpr (detail::is_awaitable_promise_type_v) + // PromiseType 是 awaitable::promise_type 的一种 + if constexpr (detail::is_instance_of_v) { auto& calee_promise = this_->current_coro_handle_.promise(); auto& caller_promise = continuation.promise(); diff --git a/tests/test3/test.cpp b/tests/test3/test.cpp index 5a8b77c..65f683f 100644 --- a/tests/test3/test.cpp +++ b/tests/test3/test.cpp @@ -7,7 +7,7 @@ int main(int argc, char **argv) using CallbackAwaiterType0 = ucoro::CallbackAwaiter; using CallbackAwaiterType1 = ucoro::CallbackAwaiter ; - static_assert(ucoro::is_a_local_storage_t>, "not a local_storage_t"); + static_assert(ucoro::detail::is_instance_of_v, ucoro::local_storage_t>, "not a local_storage_t"); static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine"); static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine");