Skip to content

Commit

Permalink
callback_awaitable 自适应executor模式
Browse files Browse the repository at this point in the history
原先为了避免协程循环导致的爆栈, callback_awaitable 实现成了两种。
一种是在非 executor 环境下调用
另一种是在 executor 环境下调用的 executor_awaitable

非 executor 环境下使用的 callback_awaitable 使用了新的 await_suspend 签名
通过直接返回 coroutine_handle 的方式避免对 .resume() 的直接调用
从而避免了爆栈问题

但是这也导致, callback_awaitable无法在 executor 环境下使用。

现在更新一下 callback_awaitable, 它可以自动判断出来 callback_awaitable 传给你
的 handle 有没有被投递给 executor。如果投递给了 executor 它就 让 await_suspend
返回 noop_coroutine, 等你调用 handle 的时候,它内部再调用对应协程的 resume 来恢
复协程。而如果你没有投递 handle, 而是在 callback_awaitable 传你 handle 的时候立
马调用, 则 await_suspend 就会通过向协程框架返回 协程句柄的方式避免嵌套resume导致的
爆栈。

下面简述原理:

callback_awaitable 调用 用户的回调函数的时候,会传入一个 handle
用户通过调用这个 handle 实现 恢复协程。让 co_await callback_awaitable 得以返回。

实现原理就是检测 handle 被调用的时候,是否是在 callback_awaitable 回调用户的上下文里。
也就是说,如果调用栈是

callback_awaitable::await_suspend -> user_lambda -> handle

那么,在 handle 的处理代码里,就标记一下,而不调用 coro_handle 的 resume

于是,等 user_lambda返回的时候,callback_awaitable::await_suspend 的代码通过检查
标记,就可以知道 handle 是不是被直接调用了。如果是,就 返回coro_handle,
否则返回 noop_coroutine.

如果 handle 的处理代码发现自己的调用栈不是 callback_awaitable::await_suspend 过
来的,则不做这个标记。

检查的方式如下:

1. 如果它发现自己运行的线程甚至不是 callback_awaitable::await_suspend
所运行的线程,则必然不在 callback_awaitable::await_suspend 的上下文里。

2. 通过检查一个共享的变量判断自己是否在 callback_awaitable::await_suspend 里面。

为啥 1. 要单独提出来呢? 因为 方法 2. 里有个隐含的条件,就是 handle 的处理代码,
和 callback_awaitable::await_suspend 调用 user_lambda 后的后续代码,是串行执行的。
如果只依赖 2. 这个方法,则可能判断出错。
  • Loading branch information
microcai committed Oct 16, 2024
1 parent b077b7b commit e18a6e6
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 13 deletions.
89 changes: 85 additions & 4 deletions include/ucoro/awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace std
#include <functional>
#include <memory>
#include <type_traits>
#include <thread>

#if defined(DEBUG) || defined(_DEBUG)
#if defined(ENABLE_DEBUG_CORO_LEAK)
Expand Down Expand Up @@ -525,32 +526,112 @@ namespace ucoro
template<typename T, typename CallbackFunction>
struct CallbackAwaiter : public CallbackAwaiterBase<T>
{
CallbackAwaiter(const CallbackAwaiter&) = delete;
CallbackAwaiter& operator = (const CallbackAwaiter&) = delete;
public:
explicit CallbackAwaiter(CallbackFunction&& callback_function)
: callback_function_(std::move(callback_function))
{
}
CallbackAwaiter(CallbackAwaiter&&) = default;

constexpr bool await_ready() noexcept
{
return false;
}

auto await_suspend(std::coroutine_handle<> handle)
// 用户调用 handle( ret_value ) 就是在这里执行的.
void resume_by_user(std::coroutine_handle<> handle, std::thread::id thread_id, std::shared_ptr<bool> handler_resume_once_flag)
{
// 这段代码,只有三个可能的执行环境

// 1. 在另一个线程里执行
// 2. 被同一个线程执行,但是是被 executor 重新调度后
// 3. 在 callback_function_ 里面被立即调用执行,也就是无 executor 环境

// 这里和 await_suspend 不在一个线程里执行
// 说明被多线程 executor 投递执行了
// 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine
// 因此这里就需要调用 resume 来恢复协程
if (thread_id != std::this_thread::get_id())
{
handle.resume();
}
else
{
// 否则,接下来的代码,还和 await_suspend 在同一个线程里执行

// 那么,下面的代码是在 await_suspend 里面被嵌套执行了呢?
// 还是被 单线程 executor 投递执行了呢?

// 如果是嵌套执行,那么 handler_resume_once_flag_ 一定还是 false
if (*handler_resume_once_flag)
{
// 说明这里的代码是被单线程的 executor 投递并延后执行了
// 既然是投递执行的,显然 await_suspend 得返回 noop_coroutine
// 因此这里就需要调用 resume 来恢复协程
handle.resume();
}
else
{
// 设置为 true, 告诉外层的 await_suspend,一定不要返回 noop_coroutine
// await_suspend 一定要返回 handle 哦!这样可以避免爆栈
*handler_resume_once_flag = true;
// 设置完毕,这里就不用 resume 了,直接返回即可
}
}
}

std::coroutine_handle<> await_suspend(std::coroutine_handle<> handle)
{
auto thread_id = std::this_thread::get_id();

auto handler_resume_once_flag = std::make_shared<bool>(false);

if constexpr (std::is_void_v<T>)
{
callback_function_([]() {});
callback_function_([this, handle, thread_id, handler_resume_once_flag]() mutable
{
return resume_by_user(handle, thread_id, handler_resume_once_flag);
});
}
else
{
callback_function_([this, thread_id, handler_resume_once_flag, handle](T t) mutable
{
this->result_ = std::move(t);
return resume_by_user(handle, thread_id, handler_resume_once_flag);
});
}

// 如果 resume_by_user 在用户代码直接执行了
// 则 这里要 return handle, 让协程框架进行切换。这样就不会爆栈了
// 如果 handler_resume_once_flag == false, 说明用户使用了 executor 对 handler 进行了延后调用
// 不管是 直接 post 还是作为一个异步操作的回调
// 那么这里就应该 返回 noop_coroutine,而不是立即进行协程切换
if (*handler_resume_once_flag)
{
// 这里 handler_resume_once_flag == true
// 说明 resume_by_user 在 callback_function_ 里面就已经被执行了
// 那么 resume_by_user 会很识趣的不去调用 handle.resume()
// 于是这里就得返回 handle, 让协程框架进行切换。这样就不会爆栈了
return handle;
}
else
{
callback_function_([this](T t) mutable { this->result_ = std::move(t); });
// 如果这里是 false,说明 callback_function_ 里的用户代码没有立即调用 resume_by_user()
// 而是进行了投递操作
// 那么 await_suspend 就必须要返回 noop_coroutine
// 与此同时,也将 handler_resume_once_flag 设置为 true
// 以便 resume_by_user() 能知道,自己其实是被投递执行的. 要在内部调用 handle.resume()
*handler_resume_once_flag = true;
return std::noop_coroutine();
}
return handle;
}

private:
CallbackFunction callback_function_;
bool handler_resume_once_flag_;
};

//////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion tests/test5/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ boost::asio::io_context main_ioc;

ucoro::awaitable<int> coro_compute_int(int value)
{
auto ret = co_await executor_awaitable<int>([value](auto handle) {
auto ret = co_await callback_awaitable<int>([value](auto handle) {
main_ioc.post([value, handle = std::move(handle)]() mutable {
std::this_thread::sleep_for(std::chrono::seconds(0));
std::cout << value << " value\n";
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asio/test_asio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ boost::asio::awaitable<int> asio_coro_test()

ucoro::awaitable<int> coro_compute_int(int value)
{
auto ret = co_await executor_awaitable<int>([value](auto handle) {
auto ret = co_await callback_awaitable<int>([value](auto handle) {
main_ioc.post([value, handle = std::move(handle)]() mutable {
std::this_thread::sleep_for(std::chrono::seconds(0));
std::cout << value << " value\n";
Expand Down
2 changes: 1 addition & 1 deletion tests/test_executor/test_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ ucoro::awaitable<int> coro_compute_int(int value)
{
executor_service* executor = co_await ucoro::local_storage_t<executor_service*>();

auto ret = co_await executor_awaitable<int>([executor, value](auto handle)
auto ret = co_await callback_awaitable<int>([executor, value](auto handle)
{
executor->enqueue([value, handle = std::move(handle)]() mutable
{
Expand Down
8 changes: 3 additions & 5 deletions tests/testlibuv/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

ucoro::awaitable<void> async_sleep_with_uv_timer(int ms)
{
co_await executor_awaitable<void>([ms](auto continuation)
co_await callback_awaitable<void>([ms](auto continuation)
{
struct uv_timer_with_data : uv_timer_s
{
Expand All @@ -15,16 +15,14 @@ ucoro::awaitable<void> async_sleep_with_uv_timer(int ms)
: continuation_(c){}
};

uv_timer_with_data* timer_handle = new uv_timer_with_data { std::move(continuation) };
uv_timer_with_data* timer_handle = new uv_timer_with_data { std::forward<decltype(continuation)>(continuation) };

uv_timer_init(uv_default_loop(), timer_handle);
uv_timer_start(timer_handle, [](uv_timer_t* handle)
{
uv_timer_stop(handle);
decltype(continuation) continuation_ = std::move(reinterpret_cast<uv_timer_with_data*>(handle)->continuation_);
reinterpret_cast<uv_timer_with_data*>(handle)->continuation_();
delete handle;

continuation_();
}, ms, false);

});
Expand Down
2 changes: 1 addition & 1 deletion tests/testqt/testqt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

ucoro::awaitable<int> coro_compute_int(int value)
{
auto ret = co_await executor_awaitable<int>([value](auto handle) {
auto ret = co_await callback_awaitable<int>([value](auto handle) {
QTimer::singleShot(0, [value, handle = std::move(handle)]() mutable {
std::cout << value << " value\n";
handle(value * 100);
Expand Down

0 comments on commit e18a6e6

Please sign in to comment.