diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 3c4f287..6852057 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -36,6 +36,7 @@ namespace std #include #include #include +#include #if defined(DEBUG) || defined(_DEBUG) #if defined(ENABLE_DEBUG_CORO_LEAK) @@ -525,32 +526,112 @@ namespace ucoro template struct CallbackAwaiter : public CallbackAwaiterBase { + CallbackAwaiter(const CallbackAwaiter&) = delete; + CallbackAwaiter& operator = (const CallbackAwaiter&) = delete; public: explicit CallbackAwaiter(CallbackFunction&& callback_function) : callback_function_(std::move(callback_function)) { } + CallbackAwaiter(CallbackAwaiter&&) = default; constexpr bool await_ready() noexcept { return false; } - auto await_suspend(std::coroutine_handle<> handle) + // 用户调用 handle( ret_value ) 就是在这里执行的. + void resume_by_user(std::coroutine_handle<> handle, std::thread::id thread_id, std::shared_ptr handler_resume_once_flag) { + // 这段代码,只有三个可能的执行环境 + + // 1. 在另一个线程里执行 + // 2. 被同一个线程执行,但是是被 executor 重新调度后 + // 3. 在 callback_function_ 里面被立即调用执行,也就是无 executor 环境 + + // 这里和 await_suspend 不在一个线程里执行 + // 说明被多线程 executor 投递执行了 + // 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine + // 因此这里就需要调用 resume 来恢复协程 + if (thread_id != std::this_thread::get_id()) + { + handle.resume(); + } + else + { + // 否则,接下来的代码,还和 await_suspend 在同一个线程里执行 + + // 那么,下面的代码是在 await_suspend 里面被嵌套执行了呢? + // 还是被 单线程 executor 投递执行了呢? + + // 如果是嵌套执行,那么 handler_resume_once_flag_ 一定还是 false + if (*handler_resume_once_flag) + { + // 说明这里的代码是被单线程的 executor 投递并延后执行了 + // 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine + // 因此这里就需要调用 resume 来恢复协程 + handle.resume(); + } + else + { + // 设置为 true, 告诉外层的 await_suspend,一定不要返回 noop_coroutine + // await_suspend 一定要返回 handle 哦!这样可以避免爆栈 + *handler_resume_once_flag = true; + // 设置完毕,这里就不用 resume 了,直接返回即可 + } + } + } + + std::coroutine_handle<> await_suspend(std::coroutine_handle<> handle) + { + auto thread_id = std::this_thread::get_id(); + + auto handler_resume_once_flag = std::make_shared(false); + if constexpr (std::is_void_v) { - callback_function_([]() {}); + callback_function_([this, handle, thread_id, handler_resume_once_flag]() mutable + { + return resume_by_user(handle, thread_id, handler_resume_once_flag); + }); + } + else + { + callback_function_([this, thread_id, handler_resume_once_flag, handle](T t) mutable + { + this->result_ = std::move(t); + return resume_by_user(handle, thread_id, handler_resume_once_flag); + }); + } + + // 如果 resume_by_user 在用户代码直接执行了 + // 则 这里要 return handle, 让协程框架进行切换。这样就不会爆栈了 + // 如果 handler_resume_once_flag == false, 说明用户使用了 executor 对 handler 进行了延后调用 + // 不管是 直接 post 还是作为一个异步操作的回调 + // 那么这里就应该 返回 noop_coroutine,而不是立即进行协程切换 + if (*handler_resume_once_flag) + { + // 这里 handler_resume_once_flag == true + // 说明 resume_by_user 在 callback_function_ 里面就已经被执行了 + // 那么 resume_by_user 会很识趣的不去调用 handle.resume() + // 于是这里就得返回 handle, 让协程框架进行切换。这样就不会爆栈了 + return handle; } else { - callback_function_([this](T t) mutable { this->result_ = std::move(t); }); + // 如果这里是 false,说明 callback_function_ 里的用户代码没有立即调用 resume_by_user() + // 而是进行了投递操作 + // 那么 await_suspend 就必须要返回 noop_coroutine + // 与此同时,也将 handler_resume_once_flag 设置为 true + // 以便 resume_by_user() 能知道,自己其实是被投递执行的. 要在内部调用 handle.resume() + *handler_resume_once_flag = true; + return std::noop_coroutine(); } - return handle; } private: CallbackFunction callback_function_; + bool handler_resume_once_flag_; }; ////////////////////////////////////////////////////////////////////////// diff --git a/tests/test5/test.cpp b/tests/test5/test.cpp index fe2735c..bf2b794 100644 --- a/tests/test5/test.cpp +++ b/tests/test5/test.cpp @@ -7,7 +7,7 @@ boost::asio::io_context main_ioc; ucoro::awaitable coro_compute_int(int value) { - auto ret = co_await executor_awaitable([value](auto handle) { + auto ret = co_await callback_awaitable([value](auto handle) { main_ioc.post([value, handle = std::move(handle)]() mutable { std::this_thread::sleep_for(std::chrono::seconds(0)); std::cout << value << " value\n"; diff --git a/tests/test_asio/test_asio.cpp b/tests/test_asio/test_asio.cpp index 250684f..fd65aa2 100644 --- a/tests/test_asio/test_asio.cpp +++ b/tests/test_asio/test_asio.cpp @@ -19,7 +19,7 @@ boost::asio::awaitable asio_coro_test() ucoro::awaitable coro_compute_int(int value) { - auto ret = co_await executor_awaitable([value](auto handle) { + auto ret = co_await callback_awaitable([value](auto handle) { main_ioc.post([value, handle = std::move(handle)]() mutable { std::this_thread::sleep_for(std::chrono::seconds(0)); std::cout << value << " value\n"; diff --git a/tests/test_executor/test_executor.cpp b/tests/test_executor/test_executor.cpp index e285717..a80f0e9 100644 --- a/tests/test_executor/test_executor.cpp +++ b/tests/test_executor/test_executor.cpp @@ -90,7 +90,7 @@ ucoro::awaitable coro_compute_int(int value) { executor_service* executor = co_await ucoro::local_storage_t(); - auto ret = co_await executor_awaitable([executor, value](auto handle) + auto ret = co_await callback_awaitable([executor, value](auto handle) { executor->enqueue([value, handle = std::move(handle)]() mutable { diff --git a/tests/testlibuv/test.cpp b/tests/testlibuv/test.cpp index ca9fb5f..184bba1 100644 --- a/tests/testlibuv/test.cpp +++ b/tests/testlibuv/test.cpp @@ -5,7 +5,7 @@ ucoro::awaitable async_sleep_with_uv_timer(int ms) { - co_await executor_awaitable([ms](auto continuation) + co_await callback_awaitable([ms](auto continuation) { struct uv_timer_with_data : uv_timer_s { @@ -15,16 +15,14 @@ ucoro::awaitable async_sleep_with_uv_timer(int ms) : continuation_(c){} }; - uv_timer_with_data* timer_handle = new uv_timer_with_data { std::move(continuation) }; + uv_timer_with_data* timer_handle = new uv_timer_with_data { std::forward(continuation) }; uv_timer_init(uv_default_loop(), timer_handle); uv_timer_start(timer_handle, [](uv_timer_t* handle) { uv_timer_stop(handle); - decltype(continuation) continuation_ = std::move(reinterpret_cast(handle)->continuation_); + reinterpret_cast(handle)->continuation_(); delete handle; - - continuation_(); }, ms, false); }); diff --git a/tests/testqt/testqt.cpp b/tests/testqt/testqt.cpp index f30e1ad..bee0625 100644 --- a/tests/testqt/testqt.cpp +++ b/tests/testqt/testqt.cpp @@ -8,7 +8,7 @@ ucoro::awaitable coro_compute_int(int value) { - auto ret = co_await executor_awaitable([value](auto handle) { + auto ret = co_await callback_awaitable([value](auto handle) { QTimer::singleShot(0, [value, handle = std::move(handle)]() mutable { std::cout << value << " value\n"; handle(value * 100);