Skip to content

Commit

Permalink
Dispatch net query only after locked check of close_flag_.
Browse files Browse the repository at this point in the history
  • Loading branch information
levlam committed May 15, 2024
1 parent 29cd56c commit 12c1689
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
48 changes: 35 additions & 13 deletions td/telegram/net/NetQueryDispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,18 @@ void NetQueryDispatcher::complete_net_query(NetQueryPtr net_query) {
}
}

void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {
// net_query->debug("dispatch");
bool NetQueryDispatcher::check_stop_flag(NetQueryPtr &net_query) const {
if (stop_flag_.load(std::memory_order_relaxed)) {
net_query->set_error(Global::request_aborted_error());
return complete_net_query(std::move(net_query));
complete_net_query(std::move(net_query));
return true;
}
return false;
}

void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {
if (check_stop_flag(net_query)) {
return;
}
if (G()->get_option_boolean("test_flood_wait")) {
net_query->set_error(Status::Error(429, "Too Many Requests: retry after 10"));
Expand All @@ -55,13 +62,17 @@ void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {
// net_query->set_error(Status::Error(420, "FLOOD_WAIT_10"));
// }
}
if (net_query->tl_constructor() == telegram_api::account_getPassword::ID && false) {
if (false && net_query->tl_constructor() == telegram_api::account_getPassword::ID) {
net_query->set_error(Status::Error(429, "Too Many Requests: retry after 10"));
return complete_net_query(std::move(net_query));
}

if (!net_query->in_sequence_dispatcher() && !net_query->get_chain_ids().empty()) {
net_query->debug("sent to main sequence dispatcher");
std::lock_guard<std::mutex> guard(mutex_);
if (check_stop_flag(net_query)) {
return;
}
send_closure_later(sequence_dispatcher_, &MultiSequenceDispatcher::send, std::move(net_query));
return;
}
Expand All @@ -76,6 +87,10 @@ void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {
(code == 420 && !begins_with(net_query->error().message(), "STORY_SEND_FLOOD_") &&
!begins_with(net_query->error().message(), "PREMIUM_SUB_ACTIVE_UNTIL_"))) {
net_query->debug("sent to NetQueryDelayer");
std::lock_guard<std::mutex> guard(mutex_);
if (check_stop_flag(net_query)) {
return;
}
return send_closure_later(delayer_, &NetQueryDelayer::delay, std::move(net_query));
}
}
Expand Down Expand Up @@ -104,6 +119,10 @@ void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {

auto dc_pos = static_cast<size_t>(dest_dc_id.get_raw_id() - 1);
CHECK(dc_pos < dcs_.size());
std::lock_guard<std::mutex> guard(mutex_);
if (check_stop_flag(net_query)) {
return;
}
switch (net_query->type()) {
case NetQuery::Type::Common:
net_query->debug(PSTRING() << "sent to main session multi proxy " << dest_dc_id);
Expand Down Expand Up @@ -148,7 +167,7 @@ Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) {
}

if (should_init) {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
std::lock_guard<std::mutex> guard(mutex_);
if (stop_flag_.load(std::memory_order_relaxed) || need_destroy_auth_key_) {
return Status::Error("Closing");
}
Expand Down Expand Up @@ -210,8 +229,7 @@ void NetQueryDispatcher::dispatch_with_callback(NetQueryPtr net_query, ActorShar
}

void NetQueryDispatcher::stop() {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
td_guard_.reset();
std::lock_guard<std::mutex> guard(mutex_);
stop_flag_ = true;
delayer_.reset();
for (auto &dc : dcs_) {
Expand All @@ -223,10 +241,11 @@ void NetQueryDispatcher::stop() {
public_rsa_key_watchdog_.reset();
dc_auth_manager_.reset();
sequence_dispatcher_.reset();
td_guard_.reset();
}

void NetQueryDispatcher::update_session_count() {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
std::lock_guard<std::mutex> guard(mutex_);
int32 session_count = get_session_count();
bool use_pfs = get_use_pfs();
for (int32 i = 1; i < DcId::MAX_RAW_DC_ID; i++) {
Expand All @@ -247,7 +266,7 @@ void NetQueryDispatcher::destroy_auth_keys(Promise<> promise) {
}
}

std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
std::lock_guard<std::mutex> guard(mutex_);
LOG(INFO) << "Destroy auth keys";
need_destroy_auth_key_ = true;
for (int32 i = 1; i < DcId::MAX_RAW_DC_ID; i++) {
Expand All @@ -259,7 +278,7 @@ void NetQueryDispatcher::destroy_auth_keys(Promise<> promise) {
}

void NetQueryDispatcher::update_use_pfs() {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
std::lock_guard<std::mutex> guard(mutex_);
bool use_pfs = get_use_pfs();
for (int32 i = 1; i < DcId::MAX_RAW_DC_ID; i++) {
if (is_dc_inited(i)) {
Expand All @@ -272,7 +291,7 @@ void NetQueryDispatcher::update_use_pfs() {
}

void NetQueryDispatcher::update_mtproto_header() {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
std::lock_guard<std::mutex> guard(mutex_);
for (int32 i = 1; i < DcId::MAX_RAW_DC_ID; i++) {
if (is_dc_inited(i)) {
send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_mtproto_header);
Expand Down Expand Up @@ -343,8 +362,7 @@ void NetQueryDispatcher::set_main_dc_id(int32 new_main_dc_id) {
return;
}

// Very rare event; mutex is ok.
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
std::lock_guard<std::mutex> guard(mutex_);
if (new_main_dc_id == main_dc_id_) {
return;
}
Expand All @@ -363,6 +381,10 @@ void NetQueryDispatcher::set_main_dc_id(int32 new_main_dc_id) {
}

void NetQueryDispatcher::check_authorization_is_ok() {
std::lock_guard<std::mutex> guard(mutex_);
if (stop_flag_.load(std::memory_order_relaxed)) {
return;
}
send_closure(dc_auth_manager_, &DcAuthManager::check_authorization_is_ok);
}

Expand Down
3 changes: 2 additions & 1 deletion td/telegram/net/NetQueryDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class NetQueryDispatcher {
std::atomic<int32> main_dc_id_{1};
#endif
ActorOwn<PublicRsaKeyWatchdog> public_rsa_key_watchdog_;
std::mutex main_dc_id_mutex_;
std::mutex mutex_;
std::shared_ptr<Guard> td_guard_;

Status wait_dc_init(DcId dc_id, bool force);
Expand All @@ -90,6 +90,7 @@ class NetQueryDispatcher {
static bool get_use_pfs();

static void complete_net_query(NetQueryPtr net_query);
bool check_stop_flag(NetQueryPtr &net_query) const;

void try_fix_migrate(NetQueryPtr &net_query);
};
Expand Down

0 comments on commit 12c1689

Please sign in to comment.