Graceful shutdown in progress
lnikon committed Jan 18, 2025
1 parent 9d5cd21 commit d63a8fe
Showing 3 changed files with 73 additions and 46 deletions.
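
This commit threads shutdown flags (m_shutdown, m_shutdownHeartbeatThreads) and stop-token checks through the election and heartbeat loops. A minimal sketch of the underlying pattern, assuming C++20 std::jthread; the names here are illustrative, not the repository's exact code:

#include <atomic>
#include <thread>

std::atomic<bool> shutdown{false};

std::jthread worker(
    [&shutdown](std::stop_token token)
    {
        while (!token.stop_requested() && !shutdown.load())
        {
            // ... one unit of work, then wait with a deadline ...
        }
    });

// Graceful shutdown: raise the flag, request a stop, and let the
// jthread destructor join the worker.
shutdown = true;
worker.request_stop();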
81 changes: 41 additions & 40 deletions examples/raft/raft.cpp
@@ -45,32 +45,6 @@ auto generate_random_timeout() -> int
namespace raft
{

/*node_client_t::node_client_t(id_t nodeId, ip_t nodeIp)*/
/* : m_id{nodeId},*/
/* m_ip{std::move(nodeIp)},*/
/* m_channel(grpc::CreateChannel(m_ip, grpc::InsecureChannelCredentials())),*/
/*m_stub(RaftService::NewStub(m_channel)),*/
/* m_kvStub(TinyKVPPService::NewStub(m_channel))*/
/*{*/
/* assert(m_id > 0);*/
/* assert(!m_ip.empty());*/
/**/
/* if (!m_channel)*/
/* {*/
/* throw std::runtime_error(fmt::format("Failed to establish a gRPC channel for node={} ip={}", m_id, m_ip));*/
/* }*/
/**/
/* if (!m_stub)*/
/* {*/
/* throw std::runtime_error(fmt::format("Failed to create a stub for node={} ip={}", m_id, m_ip));*/
/* }*/
/**/
/* if (!m_kvStub)*/
/* {*/
/* throw std::runtime_error(fmt::format("Failed to create a KV stub for node={} ip={}", m_id, m_ip));*/
/* }*/
/*}*/

node_client_t::node_client_t(node_config_t config, std::unique_ptr<RaftService::StubInterface> pRaftStub)
: m_config{std::move(config)},
m_stub{std::move(pRaftStub)}
@@ -318,13 +292,15 @@ auto consensus_module_t::RequestVote(grpc::ServerContext *pContext,
// Don't grant the vote if the candidate's term is lower than this node's current term
if (pRequest->term() < m_currentTerm)
{
spdlog::debug("receivedTerm={} is lower than currentTerm={}", pRequest->term(), m_currentTerm);
return grpc::Status::OK;
}

// Grant vote to the candidate if the node hasn't voted yet and
// candidate's log is at least as up-to-date as the receiver's log
if (m_votedFor == 0 || m_votedFor == pRequest->candidateid())
{
spdlog::debug("votedFor={} candidateid={}", m_votedFor, pRequest->candidateid());
if (pRequest->lastlogterm() > getLastLogTerm() ||
(pRequest->lastlogterm() == getLastLogTerm() && pRequest->lastlogindex() >= getLastLogIndex()))
{
@@ -333,6 +309,8 @@
spdlog::error("Node={} is unable to persist votedFor", m_config.m_id, m_votedFor);
}

spdlog::debug("Node={} votedFor={}", m_config.m_id, pRequest->candidateid());

m_leaderHeartbeatReceived.store(true);
pResponse->set_term(m_currentTerm);
pResponse->set_votegranted(1);
@@ -453,16 +431,21 @@ auto consensus_module_t::init() -> bool

void consensus_module_t::start()
{
absl::WriterMutexLock locker{&m_stateMutex};
m_electionThread = std::jthread(
[this](std::stop_token token)
{
while (!token.stop_requested())
{
if (m_shutdown)
{
break;
}

{
absl::MutexLock locker(&m_stateMutex);
absl::ReaderMutexLock locker{&m_stateMutex};
if (getState() == NodeState::LEADER)
{
/*std::this_thread::sleep_for(std::chrono::milliseconds(generate_random_timeout()));*/
continue;
}
}
@@ -482,7 +465,10 @@ void consensus_module_t::start()
int64_t timeToWaitMs = generate_random_timeout();
int64_t timeToWaitDeadlineMs = currentTimeMs() + timeToWaitMs;

// Define the condition to wait for leader's heartbeat
// Wake up when
// 1) Thread should be stopped
// 2) Leader sent a heartbeat
// 3) The wait for the heartbeat exceeded the deadline
auto heartbeatReceivedCondition = [this, &timeToWaitDeadlineMs, &token, currentTimeMs]()
{
return token.stop_requested() || m_leaderHeartbeatReceived.load() ||
@@ -531,9 +517,12 @@

void consensus_module_t::stop()
{
spdlog::info("before lock shutting down consensus module");
absl::ReaderMutexLock locker{&m_stateMutex};

spdlog::info("Shutting down consensus module");

absl::WriterMutexLock locker{&m_stateMutex};
m_shutdown = true;

if (m_electionThread.joinable())
{
@@ -650,7 +639,9 @@ void consensus_module_t::startElection()
voteGranted,
response.responderid());

absl::MutexLock locker(&m_stateMutex);
spdlog::debug("Req lock before");
absl::WriterMutexLock locker(&m_stateMutex);
spdlog::debug("Req lock after");
if (responseTerm > m_currentTerm)
{
becomeFollower(responseTerm);
@@ -679,19 +670,22 @@ void consensus_module_t::becomeFollower(uint32_t newTerm)
{
m_currentTerm = newTerm;
m_state = NodeState::FOLLOWER;
spdlog::debug("Node={} reverted to follower state in term={}", m_config.m_id, m_currentTerm);

if (!updatePersistentState(std::nullopt, 0))
{
spdlog::error("Node={} is unable to persist votedFor={}", m_config.m_id, m_votedFor);
}

spdlog::debug("Follower node={} is joining heartbeat threads");
m_shutdownHeartbeatThreads = true;
for (auto &heartbeatThread : m_heartbeatThreads)
{
heartbeatThread.request_stop();
heartbeatThread.join();
}
m_heartbeatThreads.clear();

spdlog::debug("Server reverted to follower state in term={}", m_currentTerm);
spdlog::debug("Follower node={} is joining heartbeat threads finished");
}

auto consensus_module_t::hasMajority(uint32_t votes) const -> bool
@@ -736,9 +730,14 @@ void consensus_module_t::sendHeartbeat(node_client_t &client)
int consecutiveFailures = 0;
while (!token.stop_requested())
{
if (m_shutdown || m_shutdownHeartbeatThreads)
{
break;
}

AppendEntriesRequest request;
{
absl::ReaderMutexLock locker(&m_stateMutex);
absl::ReaderMutexLock locker{&m_stateMutex};
if (m_state != NodeState::LEADER)
{
spdlog::debug("Node={} is no longer a leader. Stopping the heartbeat thread");
@@ -763,11 +762,10 @@
maxRetries);
if (consecutiveFailures >= maxRetries)
{
spdlog::error(
"Stopping heartbeat thread due to too many failed AppendEntries RPC attempts");
return;
}
consecutiveFailures = 0;

continue;
}

consecutiveFailures = 0;
Expand All @@ -783,8 +781,12 @@ void consensus_module_t::sendHeartbeat(node_client_t &client)
response.responderid());

{
absl::WriterMutexLock locker(&m_stateMutex);
if (token.stop_requested())
{
return;
}

absl::WriterMutexLock locker(&m_stateMutex);
if (responseTerm > m_currentTerm)
{
becomeFollower(responseTerm);
@@ -795,8 +797,6 @@

std::this_thread::sleep_for(heartbeatInterval);
}

spdlog::debug("Stopping heartbeat thread for on the node={} for the client={}", m_config.m_id, client.id());
});
}

@@ -1034,6 +1034,7 @@ auto consensus_module_t::restorePersistentState() -> bool
}

ifs >> m_commitIndex >> m_votedFor;
m_votedFor = 0;
spdlog::info("Node={} restored commitIndex={} and votedFor={}", m_config.m_id, m_commitIndex, m_votedFor);
}

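The raft.h diff below declares an absl::CondVar m_electionCV that none of the hunks shown here use yet, consistent with the "in progress" title. A sketch of how the election wait could be expressed with it, assuming Abseil's CondVar::WaitWithDeadline (which returns true once the deadline passes); this is a guess at intent, not the commit's code:

#include <atomic>
#include <stop_token>
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"

// Returns true if a heartbeat arrived (or a stop was requested) before the deadline.
bool waitForHeartbeat(absl::Mutex *mu, absl::CondVar *cv, std::atomic<bool> &heartbeat,
                      std::stop_token token, absl::Time deadline)
{
    absl::MutexLock lock(mu);
    while (!token.stop_requested() && !heartbeat.load())
    {
        if (cv->WaitWithDeadline(mu, deadline))
        {
            return false; // timed out: no heartbeat, start an election
        }
    }
    return true;
}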
34 changes: 29 additions & 5 deletions examples/raft/raft.h
@@ -37,6 +37,25 @@ using timepoint_t = std::chrono::high_resolution_clock::time_point;
// Valid IDs start from 1
constexpr const id_t gInvalidId = 0;

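// RAII guard: releases the given mutex, if any, when the guard goes out of scope.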
struct tkvpp_absl_try_unlock
{
tkvpp_absl_try_unlock(absl::Mutex *mu)
: m_mu{mu}
{
}

~tkvpp_absl_try_unlock()
{
if (m_mu)
{
m_mu->Unlock();
}
}

private:
absl::Mutex *m_mu{nullptr};
};

struct node_config_t
{
id_t m_id{gInvalidId};
@@ -157,6 +176,8 @@ class consensus_module_t : public RaftService::Service,
uint32_t m_votedFor ABSL_GUARDED_BY(m_stateMutex);
std::vector<LogEntry> m_log ABSL_GUARDED_BY(m_stateMutex);

absl::CondVar m_electionCV;

// Volatile state on all servers.
uint32_t m_commitIndex ABSL_GUARDED_BY(m_stateMutex);
uint32_t m_lastApplied ABSL_GUARDED_BY(m_stateMutex);
@@ -167,16 +188,19 @@
std::unordered_map<id_t, uint32_t> m_nextIndex ABSL_GUARDED_BY(m_stateMutex);

// Election related fields
absl::Mutex m_timerMutex;
std::atomic<bool> m_leaderHeartbeatReceived{false};
std::jthread m_electionThread;
std::atomic<uint32_t> m_voteCount{0};
absl::Mutex m_timerMutex;
std::atomic<bool> m_leaderHeartbeatReceived{false};
std::jthread m_electionThread ABSL_GUARDED_BY(m_stateMutex);
std::atomic<uint32_t> m_voteCount{0};

// Stores clusterSize - 1 thread to send heartbeat to replicas
std::vector<std::jthread> m_heartbeatThreads ABSL_GUARDED_BY(m_stateMutex);

// Serves incoming RPC's
std::jthread m_serverThread;
std::jthread m_serverThread ABSL_GUARDED_BY(m_stateMutex);

bool m_shutdown{false};
bool m_shutdownHeartbeatThreads{false};

// Temporary in-memory hashtable to store KVs
std::unordered_map<std::string, std::string> m_kv;
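The new tkvpp_absl_try_unlock guard above releases a mutex it did not lock itself. A hypothetical call site (the commit does not show one), assuming the mutex was locked earlier on the same path:

void example(absl::Mutex *mu)
{
    mu->Lock();
    tkvpp_absl_try_unlock guard{mu}; // mu->Unlock() runs at scope exit
    // ... any early return below still releases the mutex ...
}

Since the struct carries no thread-safety annotations (e.g. ABSL_UNLOCK_FUNCTION on the destructor), Clang's analysis will not track the release at such call sites.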
4 changes: 3 additions & 1 deletion examples/raft/raft_test.cpp
@@ -70,7 +70,6 @@ TEST_CASE("ConsensusModule Initialization", "[ConsensusModule]")
.WillRepeatedly(testing::DoAll(testing::SetArgPointee<2>(response), testing::Return(grpc::Status::OK)));

AppendEntriesResponse aeResponse;
aeResponse.set_responderid(1);
aeResponse.set_responderid(2);
aeResponse.set_success(true);
EXPECT_CALL(*mockStub2, AppendEntries)
@@ -88,6 +87,9 @@

raft::consensus_module_t consensusModule{nodeConfig1, std::move(replicas)};
consensusModule.start();

std::this_thread::sleep_for(std::chrono::milliseconds(1000));
consensusModule.stop();
}

TEST_CASE("ConsensusModule Leader Election", "[ConsensusModule]")
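The test above now sleeps for a fixed second and then calls stop(). A failed Catch2 assertion throws, so a small RAII runner, a hypothetical helper rather than part of this commit, would guarantee stop() on every exit path:

struct scoped_module_runner
{
    explicit scoped_module_runner(raft::consensus_module_t &module)
        : m_module{module}
    {
        m_module.start();
    }

    ~scoped_module_runner()
    {
        m_module.stop();
    }

  private:
    raft::consensus_module_t &m_module;
};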
