diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 3c4f287..e58ecdf 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -36,6 +36,7 @@ namespace std #include #include #include +#include #if defined(DEBUG) || defined(_DEBUG) #if defined(ENABLE_DEBUG_CORO_LEAK) @@ -311,7 +312,7 @@ namespace ucoro auto await_transform(A&& awaiter) const { static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); - return std::forward(awaiter); + return std::move(awaiter); } template requires (!detail::is_awaiter_v>) @@ -375,7 +376,7 @@ namespace ucoro ~awaitable() { - if (current_coro_handle_ && current_coro_handle_.done()) + if (current_coro_handle_) { current_coro_handle_.destroy(); } @@ -432,9 +433,6 @@ namespace ucoro { std::rethrow_exception(exception); } - - current_coro_handle_.destroy(); - current_coro_handle_ = nullptr; } else { @@ -444,9 +442,6 @@ namespace ucoro std::rethrow_exception(std::get(ret)); } - current_coro_handle_.destroy(); - current_coro_handle_ = nullptr; - return std::get(ret); } } @@ -525,36 +520,116 @@ namespace ucoro template struct CallbackAwaiter : public CallbackAwaiterBase { + CallbackAwaiter(const CallbackAwaiter&) = delete; + CallbackAwaiter& operator = (const CallbackAwaiter&) = delete; public: explicit CallbackAwaiter(CallbackFunction&& callback_function) : callback_function_(std::move(callback_function)) { } + CallbackAwaiter(CallbackAwaiter&&) = default; constexpr bool await_ready() noexcept { return false; } + // 用户调用 handle( ret_value ) 就是在这里执行的. + void handle_process_on_user_call(std::coroutine_handle<> handle, std::thread::id thread_id, std::shared_ptr handler_resume_once_flag) + { + // 这段代码,只有三个可能的执行环境 + + // 1. 在另一个线程里执行 + // 2. 被同一个线程执行,但是是被 executor 重新调度后 + // 3. 在 callback_function_ 里面被立即调用执行,也就是无 executor 环境 + + // 这里和 await_suspend 不在一个线程里执行 + // 说明被多线程 executor 投递执行了 + // 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine + // 因此这里就需要调用 resume 来恢复协程 + if (thread_id != std::this_thread::get_id()) + { + handle.resume(); + } + else + { + // 否则,接下来的代码,还和 await_suspend 在同一个线程里执行 + + // 那么,下面的代码是在 await_suspend 里面被嵌套执行了呢? + // 还是被 单线程 executor 投递执行了呢? + + // 如果是嵌套执行,那么 handler_resume_once_flag_ 一定还是 false + if (*handler_resume_once_flag) + { + // 说明这里的代码是被单线程的 executor 投递并延后执行了 + // 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine + // 因此这里就需要调用 resume 来恢复协程 + handle.resume(); + } + else + { + // 设置为 true, 告诉外层的 await_suspend,一定不要返回 noop_coroutine + // await_suspend 一定要返回 handle 哦!这样可以避免爆栈 + *handler_resume_once_flag = true; + // 设置完毕,这里就不用 resume 了,直接返回即可 + std::terminate(); + } + } + } + + std::coroutine_handle<> decide_suspend_return(std::coroutine_handle<> handle, std::shared_ptr handler_resume_once_flag) + { + // 如果 callback_fuction 传入的 handler 在用户代码直接执行了 + // 则 这里要 return handle, 让协程框架进行切换。这样就不会爆栈了 + // 如果 handler 没有被直接执行,说明用户使用了 executor 对 handler 进行了延后调用 + // 不管是 直接 post 还是作为一个异步操作的回调 + // 那么这里就应该 返回 noop_coroutine,而不是立即进行协程切换 + if (*handler_resume_once_flag) + { + return handle; + } + else + { + // 如果这里是 false,说明 callback_function_ 里的用户代码没有立即调用 handle() + // 而是进行了投递操作 + // 那么 await_suspend 就必须要返回 noop_coroutine + // 与此同时,也将 handler_resume_once_flag_ 设置为 true + // 以便 callback_function_ 里的 handle 能知道,自己其实是被投递执行的. + *handler_resume_once_flag = true; + return std::noop_coroutine(); + } + } + auto await_suspend(std::coroutine_handle<> handle) { + auto thread_id = std::this_thread::get_id(); + + auto handler_resume_once_flag = std::make_shared(false); + if constexpr (std::is_void_v) { - callback_function_([]() {}); + callback_function_([this, thread_id, handler_resume_once_flag, handle]() mutable + { + return handle_process_on_user_call(handle, thread_id, handler_resume_once_flag); + }); } else { - callback_function_([this](T t) mutable { this->result_ = std::move(t); }); + callback_function_([this, thread_id, handler_resume_once_flag, handle](T t) mutable + { + this->result_ = std::move(t); + return handle_process_on_user_call(handle, thread_id, handler_resume_once_flag); + }); } - return handle; + + return decide_suspend_return(handle, handler_resume_once_flag); } private: CallbackFunction callback_function_; + bool handler_resume_once_flag_; }; - ////////////////////////////////////////////////////////////////////////// - template struct ExecutorAwaiter : public CallbackAwaiterBase { @@ -594,13 +669,13 @@ namespace ucoro ////////////////////////////////////////////////////////////////////////// template -ucoro::CallbackAwaiter callback_awaitable(callback&& cb) +auto callback_awaitable(callback&& cb) { return ucoro::CallbackAwaiter{std::forward(cb)}; } template -ucoro::ExecutorAwaiter executor_awaitable(callback&& cb) +auto executor_awaitable(callback&& cb) { return ucoro::ExecutorAwaiter{std::forward(cb)}; } diff --git a/tests/test5/test.cpp b/tests/test5/test.cpp index fe2735c..bf2b794 100644 --- a/tests/test5/test.cpp +++ b/tests/test5/test.cpp @@ -7,7 +7,7 @@ boost::asio::io_context main_ioc; ucoro::awaitable coro_compute_int(int value) { - auto ret = co_await executor_awaitable([value](auto handle) { + auto ret = co_await callback_awaitable([value](auto handle) { main_ioc.post([value, handle = std::move(handle)]() mutable { std::this_thread::sleep_for(std::chrono::seconds(0)); std::cout << value << " value\n"; diff --git a/tests/test_asio/test_asio.cpp b/tests/test_asio/test_asio.cpp index 250684f..fd65aa2 100644 --- a/tests/test_asio/test_asio.cpp +++ b/tests/test_asio/test_asio.cpp @@ -19,7 +19,7 @@ boost::asio::awaitable asio_coro_test() ucoro::awaitable coro_compute_int(int value) { - auto ret = co_await executor_awaitable([value](auto handle) { + auto ret = co_await callback_awaitable([value](auto handle) { main_ioc.post([value, handle = std::move(handle)]() mutable { std::this_thread::sleep_for(std::chrono::seconds(0)); std::cout << value << " value\n"; diff --git a/tests/test_executor/test_executor.cpp b/tests/test_executor/test_executor.cpp index e285717..a80f0e9 100644 --- a/tests/test_executor/test_executor.cpp +++ b/tests/test_executor/test_executor.cpp @@ -90,7 +90,7 @@ ucoro::awaitable coro_compute_int(int value) { executor_service* executor = co_await ucoro::local_storage_t(); - auto ret = co_await executor_awaitable([executor, value](auto handle) + auto ret = co_await callback_awaitable([executor, value](auto handle) { executor->enqueue([value, handle = std::move(handle)]() mutable { diff --git a/tests/testlibuv/test.cpp b/tests/testlibuv/test.cpp index ca9fb5f..184bba1 100644 --- a/tests/testlibuv/test.cpp +++ b/tests/testlibuv/test.cpp @@ -5,7 +5,7 @@ ucoro::awaitable async_sleep_with_uv_timer(int ms) { - co_await executor_awaitable([ms](auto continuation) + co_await callback_awaitable([ms](auto continuation) { struct uv_timer_with_data : uv_timer_s { @@ -15,16 +15,14 @@ ucoro::awaitable async_sleep_with_uv_timer(int ms) : continuation_(c){} }; - uv_timer_with_data* timer_handle = new uv_timer_with_data { std::move(continuation) }; + uv_timer_with_data* timer_handle = new uv_timer_with_data { std::forward(continuation) }; uv_timer_init(uv_default_loop(), timer_handle); uv_timer_start(timer_handle, [](uv_timer_t* handle) { uv_timer_stop(handle); - decltype(continuation) continuation_ = std::move(reinterpret_cast(handle)->continuation_); + reinterpret_cast(handle)->continuation_(); delete handle; - - continuation_(); }, ms, false); }); diff --git a/tests/testqt/testqt.cpp b/tests/testqt/testqt.cpp index f30e1ad..bee0625 100644 --- a/tests/testqt/testqt.cpp +++ b/tests/testqt/testqt.cpp @@ -8,7 +8,7 @@ ucoro::awaitable coro_compute_int(int value) { - auto ret = co_await executor_awaitable([value](auto handle) { + auto ret = co_await callback_awaitable([value](auto handle) { QTimer::singleShot(0, [value, handle = std::move(handle)]() mutable { std::cout << value << " value\n"; handle(value * 100);