Skip to content

Commit

Permalink
callback_awaitable 自适应executor模式
Browse files Browse the repository at this point in the history
原先为了避免协程循环导致的爆栈, callback_awaitable 实现成了两种。
一种是在非 executor 环境下调用
另一种是在 executor 环境下调用的 executor_awaitable

非 executor 环境下使用的 callback_awaitable 使用了新的 await_suspend 签名
通过直接返回 coroutine_handle 的方式避免对 .resume() 的直接调用
从而避免了爆栈问题

但是这也导致, callback_awaitable无法在 executor 环境下使用。

现在更新一下 callback_awaitable, 它可以自动判断出来 callback_awaitable 传给你
的 handle 有没有被投递给 executor。如果投递给了 executor 它就 让 await_suspend
返回 noop_coroutine, 等你调用 handle 的时候,它内部再调用对应协程的 resume 来恢
复协程。而如果你没有投递 handle, 而是在 callback_awaitable 传你 handle 的时候立
马调用, 则 await_suspend 就会通过向协程框架返回 协程句柄的方式避免嵌套resume导致的
爆栈。
  • Loading branch information
microcai committed Oct 16, 2024
1 parent b077b7b commit 1294e5c
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 24 deletions.
105 changes: 90 additions & 15 deletions include/ucoro/awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace std
#include <functional>
#include <memory>
#include <type_traits>
#include <thread>

#if defined(DEBUG) || defined(_DEBUG)
#if defined(ENABLE_DEBUG_CORO_LEAK)
Expand Down Expand Up @@ -311,7 +312,7 @@ namespace ucoro
auto await_transform(A&& awaiter) const
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
return std::move(awaiter);
}

template<typename A> requires (!detail::is_awaiter_v<std::decay_t<A>>)
Expand Down Expand Up @@ -375,7 +376,7 @@ namespace ucoro

~awaitable()
{
if (current_coro_handle_ && current_coro_handle_.done())
if (current_coro_handle_)
{
current_coro_handle_.destroy();
}
Expand Down Expand Up @@ -432,9 +433,6 @@ namespace ucoro
{
std::rethrow_exception(exception);
}

current_coro_handle_.destroy();
current_coro_handle_ = nullptr;
}
else
{
Expand All @@ -444,9 +442,6 @@ namespace ucoro
std::rethrow_exception(std::get<std::exception_ptr>(ret));
}

current_coro_handle_.destroy();
current_coro_handle_ = nullptr;

return std::get<T>(ret);
}
}
Expand Down Expand Up @@ -525,36 +520,116 @@ namespace ucoro
template<typename T, typename CallbackFunction>
struct CallbackAwaiter : public CallbackAwaiterBase<T>
{
CallbackAwaiter(const CallbackAwaiter&) = delete;
CallbackAwaiter& operator = (const CallbackAwaiter&) = delete;
public:
explicit CallbackAwaiter(CallbackFunction&& callback_function)
: callback_function_(std::move(callback_function))
{
}
CallbackAwaiter(CallbackAwaiter&&) = default;

constexpr bool await_ready() noexcept
{
return false;
}

// 用户调用 handle( ret_value ) 就是在这里执行的.
void handle_process_on_user_call(std::coroutine_handle<> handle, std::thread::id thread_id, std::shared_ptr<bool> handler_resume_once_flag)
{
// 这段代码,只有三个可能的执行环境

// 1. 在另一个线程里执行
// 2. 被同一个线程执行,但是是被 executor 重新调度后
// 3. 在 callback_function_ 里面被立即调用执行,也就是无 executor 环境

// 这里和 await_suspend 不在一个线程里执行
// 说明被多线程 executor 投递执行了
// 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine
// 因此这里就需要调用 resume 来恢复协程
if (thread_id != std::this_thread::get_id())
{
handle.resume();
}
else
{
// 否则,接下来的代码,还和 await_suspend 在同一个线程里执行

// 那么,下面的代码是在 await_suspend 里面被嵌套执行了呢?
// 还是被 单线程 executor 投递执行了呢?

// 如果是嵌套执行,那么 handler_resume_once_flag_ 一定还是 false
if (*handler_resume_once_flag)
{
// 说明这里的代码是被单线程的 executor 投递并延后执行了
// 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine
// 因此这里就需要调用 resume 来恢复协程
handle.resume();
}
else
{
// 设置为 true, 告诉外层的 await_suspend,一定不要返回 noop_coroutine
// await_suspend 一定要返回 handle 哦!这样可以避免爆栈
*handler_resume_once_flag = true;
// 设置完毕,这里就不用 resume 了,直接返回即可
std::terminate();
}
}
}

std::coroutine_handle<> decide_suspend_return(std::coroutine_handle<> handle, std::shared_ptr<bool> handler_resume_once_flag)
{
// 如果 callback_fuction 传入的 handler 在用户代码直接执行了
// 则 这里要 return handle, 让协程框架进行切换。这样就不会爆栈了
// 如果 handler 没有被直接执行,说明用户使用了 executor 对 handler 进行了延后调用
// 不管是 直接 post 还是作为一个异步操作的回调
// 那么这里就应该 返回 noop_coroutine,而不是立即进行协程切换
if (*handler_resume_once_flag)
{
return handle;
}
else
{
// 如果这里是 false,说明 callback_function_ 里的用户代码没有立即调用 handle()
// 而是进行了投递操作
// 那么 await_suspend 就必须要返回 noop_coroutine
// 与此同时,也将 handler_resume_once_flag_ 设置为 true
// 以便 callback_function_ 里的 handle 能知道,自己其实是被投递执行的.
*handler_resume_once_flag = true;
return std::noop_coroutine();
}
}

auto await_suspend(std::coroutine_handle<> handle)
{
auto thread_id = std::this_thread::get_id();

auto handler_resume_once_flag = std::make_shared<bool>(false);

if constexpr (std::is_void_v<T>)
{
callback_function_([]() {});
callback_function_([this, thread_id, handler_resume_once_flag, handle]() mutable
{
return handle_process_on_user_call(handle, thread_id, handler_resume_once_flag);
});
}
else
{
callback_function_([this](T t) mutable { this->result_ = std::move(t); });
callback_function_([this, thread_id, handler_resume_once_flag, handle](T t) mutable
{
this->result_ = std::move(t);
return handle_process_on_user_call(handle, thread_id, handler_resume_once_flag);
});
}
return handle;

return decide_suspend_return(handle, handler_resume_once_flag);
}

private:
CallbackFunction callback_function_;
bool handler_resume_once_flag_;
};

//////////////////////////////////////////////////////////////////////////

template<typename T, typename CallbackFunction>
struct ExecutorAwaiter : public CallbackAwaiterBase<T>
{
Expand Down Expand Up @@ -594,13 +669,13 @@ namespace ucoro
//////////////////////////////////////////////////////////////////////////

template<typename T, typename callback>
ucoro::CallbackAwaiter<T, callback> callback_awaitable(callback&& cb)
auto callback_awaitable(callback&& cb)
{
return ucoro::CallbackAwaiter<T, callback>{std::forward<callback>(cb)};
}

template<typename T, typename callback>
ucoro::ExecutorAwaiter<T, callback> executor_awaitable(callback&& cb)
auto executor_awaitable(callback&& cb)
{
return ucoro::ExecutorAwaiter<T, callback>{std::forward<callback>(cb)};
}
Expand Down
2 changes: 1 addition & 1 deletion tests/test5/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ boost::asio::io_context main_ioc;

ucoro::awaitable<int> coro_compute_int(int value)
{
auto ret = co_await executor_awaitable<int>([value](auto handle) {
auto ret = co_await callback_awaitable<int>([value](auto handle) {
main_ioc.post([value, handle = std::move(handle)]() mutable {
std::this_thread::sleep_for(std::chrono::seconds(0));
std::cout << value << " value\n";
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asio/test_asio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ boost::asio::awaitable<int> asio_coro_test()

ucoro::awaitable<int> coro_compute_int(int value)
{
auto ret = co_await executor_awaitable<int>([value](auto handle) {
auto ret = co_await callback_awaitable<int>([value](auto handle) {
main_ioc.post([value, handle = std::move(handle)]() mutable {
std::this_thread::sleep_for(std::chrono::seconds(0));
std::cout << value << " value\n";
Expand Down
2 changes: 1 addition & 1 deletion tests/test_executor/test_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ ucoro::awaitable<int> coro_compute_int(int value)
{
executor_service* executor = co_await ucoro::local_storage_t<executor_service*>();

auto ret = co_await executor_awaitable<int>([executor, value](auto handle)
auto ret = co_await callback_awaitable<int>([executor, value](auto handle)
{
executor->enqueue([value, handle = std::move(handle)]() mutable
{
Expand Down
8 changes: 3 additions & 5 deletions tests/testlibuv/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

ucoro::awaitable<void> async_sleep_with_uv_timer(int ms)
{
co_await executor_awaitable<void>([ms](auto continuation)
co_await callback_awaitable<void>([ms](auto continuation)
{
struct uv_timer_with_data : uv_timer_s
{
Expand All @@ -15,16 +15,14 @@ ucoro::awaitable<void> async_sleep_with_uv_timer(int ms)
: continuation_(c){}
};

uv_timer_with_data* timer_handle = new uv_timer_with_data { std::move(continuation) };
uv_timer_with_data* timer_handle = new uv_timer_with_data { std::forward<decltype(continuation)>(continuation) };

uv_timer_init(uv_default_loop(), timer_handle);
uv_timer_start(timer_handle, [](uv_timer_t* handle)
{
uv_timer_stop(handle);
decltype(continuation) continuation_ = std::move(reinterpret_cast<uv_timer_with_data*>(handle)->continuation_);
reinterpret_cast<uv_timer_with_data*>(handle)->continuation_();
delete handle;

continuation_();
}, ms, false);

});
Expand Down
2 changes: 1 addition & 1 deletion tests/testqt/testqt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

ucoro::awaitable<int> coro_compute_int(int value)
{
auto ret = co_await executor_awaitable<int>([value](auto handle) {
auto ret = co_await callback_awaitable<int>([value](auto handle) {
QTimer::singleShot(0, [value, handle = std::move(handle)]() mutable {
std::cout << value << " value\n";
handle(value * 100);
Expand Down

0 comments on commit 1294e5c

Please sign in to comment.