From d63a8fe6db9cef5e21706f12498b40a3caf04953 Mon Sep 17 00:00:00 2001
From: lnikon <bejanyan.vahag@protonmail.com>
Date: Sat, 18 Jan 2025 13:43:05 +0400
Subject: [PATCH] Graceful shutdown in progress

---
 examples/raft/raft.cpp      | 81 +++++++++++++++++++------------------
 examples/raft/raft.h        | 34 +++++++++++++---
 examples/raft/raft_test.cpp |  4 +-
 3 files changed, 73 insertions(+), 46 deletions(-)

diff --git a/examples/raft/raft.cpp b/examples/raft/raft.cpp
index b61d8df..1e0fde7 100644
--- a/examples/raft/raft.cpp
+++ b/examples/raft/raft.cpp
@@ -45,32 +45,6 @@ auto generate_random_timeout() -> int
 namespace raft
 {
 
-/*node_client_t::node_client_t(id_t nodeId, ip_t nodeIp)*/
-/*    : m_id{nodeId},*/
-/*      m_ip{std::move(nodeIp)},*/
-/*      m_channel(grpc::CreateChannel(m_ip, grpc::InsecureChannelCredentials())),*/
-/*m_stub(RaftService::NewStub(m_channel)),*/
-/*      m_kvStub(TinyKVPPService::NewStub(m_channel))*/
-/*{*/
-/*    assert(m_id > 0);*/
-/*    assert(!m_ip.empty());*/
-/**/
-/*    if (!m_channel)*/
-/*    {*/
-/*        throw std::runtime_error(fmt::format("Failed to establish a gRPC channel for node={} ip={}", m_id, m_ip));*/
-/*    }*/
-/**/
-/*    if (!m_stub)*/
-/*    {*/
-/*        throw std::runtime_error(fmt::format("Failed to create a stub for node={} ip={}", m_id, m_ip));*/
-/*    }*/
-/**/
-/*    if (!m_kvStub)*/
-/*    {*/
-/*        throw std::runtime_error(fmt::format("Failed to create a KV stub for node={} ip={}", m_id, m_ip));*/
-/*    }*/
-/*}*/
-
 node_client_t::node_client_t(node_config_t config, std::unique_ptr<RaftService::StubInterface> pRaftStub)
     : m_config{std::move(config)},
       m_stub{std::move(pRaftStub)}
@@ -318,6 +292,7 @@ auto consensus_module_t::RequestVote(grpc::ServerContext      *pContext,
     // Don't grant vote to the candidate if the nodes term is higher
     if (pRequest->term() < m_currentTerm)
     {
+        spdlog::debug("receivedTerm={} is lower than currentTerm={}", pRequest->term(), m_currentTerm);
         return grpc::Status::OK;
     }
 
@@ -325,6 +300,7 @@ auto consensus_module_t::RequestVote(grpc::ServerContext      *pContext,
     // candidates log is at least as up-to-date as receiver's log
     if (m_votedFor == 0 || m_votedFor == pRequest->candidateid())
     {
+        spdlog::debug("votedFor={} candidateid={}", m_votedFor, pRequest->candidateid());
         if (pRequest->lastlogterm() > getLastLogTerm() ||
             (pRequest->lastlogterm() == getLastLogTerm() && pRequest->lastlogindex() >= getLastLogIndex()))
         {
@@ -333,6 +309,8 @@ auto consensus_module_t::RequestVote(grpc::ServerContext      *pContext,
                 spdlog::error("Node={} is unable to persist votedFor", m_config.m_id, m_votedFor);
             }
 
+            spdlog::debug("Node={} votedFor={}", m_config.m_id, pRequest->candidateid());
+
             m_leaderHeartbeatReceived.store(true);
             pResponse->set_term(m_currentTerm);
             pResponse->set_votegranted(1);
@@ -453,16 +431,21 @@ auto consensus_module_t::init() -> bool
 
 void consensus_module_t::start()
 {
+    absl::WriterMutexLock locker{&m_stateMutex};
     m_electionThread = std::jthread(
         [this](std::stop_token token)
         {
             while (!token.stop_requested())
             {
+                if (m_shutdown)
+                {
+                    break;
+                }
+
                 {
-                    absl::MutexLock locker(&m_stateMutex);
+                    absl::ReaderMutexLock locker{&m_stateMutex};
                     if (getState() == NodeState::LEADER)
                     {
-                        /*std::this_thread::sleep_for(std::chrono::milliseconds(generate_random_timeout()));*/
                         continue;
                     }
                 }
@@ -482,7 +465,10 @@ void consensus_module_t::start()
                     int64_t timeToWaitMs = generate_random_timeout();
                     int64_t timeToWaitDeadlineMs = currentTimeMs() + timeToWaitMs;
 
-                    // Define the condition to wait for leader's heartbeat
+                    // Wake up when
+                    // 1) Thread should be stopped
+                    // 2) Leader sent a heartbeat
+                    // 3) Wait for the heartbeat was too long
                     auto heartbeatReceivedCondition = [this, &timeToWaitDeadlineMs, &token, currentTimeMs]()
                     {
                         return token.stop_requested() || m_leaderHeartbeatReceived.load() ||
@@ -531,9 +517,12 @@ void consensus_module_t::start()
 
 void consensus_module_t::stop()
 {
+    spdlog::info("before lock shutting down consensus module");
+    absl::ReaderMutexLock locker{&m_stateMutex};
+
     spdlog::info("Shutting down consensus module");
 
-    absl::WriterMutexLock locker{&m_stateMutex};
+    m_shutdown = true;
 
     if (m_electionThread.joinable())
     {
@@ -650,7 +639,9 @@ void consensus_module_t::startElection()
                               voteGranted,
                               response.responderid());
 
-                absl::MutexLock locker(&m_stateMutex);
+                spdlog::debug("Req lock before");
+                absl::WriterMutexLock locker(&m_stateMutex);
+                spdlog::debug("Req lock after");
                 if (responseTerm > m_currentTerm)
                 {
                     becomeFollower(responseTerm);
@@ -679,19 +670,22 @@ void consensus_module_t::becomeFollower(uint32_t newTerm)
 {
     m_currentTerm = newTerm;
     m_state = NodeState::FOLLOWER;
+    spdlog::debug("Node={} reverted to follower state in term={}", m_config.m_id, m_currentTerm);
+
     if (!updatePersistentState(std::nullopt, 0))
     {
         spdlog::error("Node={} is unable to persist votedFor={}", m_config.m_id, m_votedFor);
     }
 
+    spdlog::debug("Follower node={} is joining heartbeat threads");
+    m_shutdownHeartbeatThreads = true;
     for (auto &heartbeatThread : m_heartbeatThreads)
     {
         heartbeatThread.request_stop();
         heartbeatThread.join();
     }
     m_heartbeatThreads.clear();
-
-    spdlog::debug("Server reverted to follower state in term={}", m_currentTerm);
+    spdlog::debug("Follower node={} is joining heartbeat threads finished");
 }
 
 auto consensus_module_t::hasMajority(uint32_t votes) const -> bool
@@ -736,9 +730,14 @@ void consensus_module_t::sendHeartbeat(node_client_t &client)
             int consecutiveFailures = 0;
             while (!token.stop_requested())
             {
+                if (m_shutdown || m_shutdownHeartbeatThreads)
+                {
+                    break;
+                }
+
                 AppendEntriesRequest request;
                 {
-                    absl::ReaderMutexLock locker(&m_stateMutex);
+                    absl::ReaderMutexLock locker{&m_stateMutex};
                     if (m_state != NodeState::LEADER)
                     {
                         spdlog::debug("Node={} is no longer a leader. Stopping the heartbeat thread");
@@ -763,11 +762,10 @@ void consensus_module_t::sendHeartbeat(node_client_t &client)
                                       maxRetries);
                         if (consecutiveFailures >= maxRetries)
                         {
+                            spdlog::error(
+                                "Stopping heartbeat thread due to too much failed AppendEntries RPC attempts");
                             return;
                         }
-                        consecutiveFailures = 0;
-
-                        continue;
                     }
 
                     consecutiveFailures = 0;
@@ -783,8 +781,12 @@ void consensus_module_t::sendHeartbeat(node_client_t &client)
                                   response.responderid());
 
                     {
-                        absl::WriterMutexLock locker(&m_stateMutex);
+                        if (token.stop_requested())
+                        {
+                            return;
+                        }
 
+                        absl::WriterMutexLock locker(&m_stateMutex);
                         if (responseTerm > m_currentTerm)
                         {
                             becomeFollower(responseTerm);
@@ -795,8 +797,6 @@ void consensus_module_t::sendHeartbeat(node_client_t &client)
 
                 std::this_thread::sleep_for(heartbeatInterval);
             }
-
-            spdlog::debug("Stopping heartbeat thread for on the node={} for the client={}", m_config.m_id, client.id());
         });
 }
 
@@ -1034,6 +1034,7 @@ auto consensus_module_t::restorePersistentState() -> bool
         }
 
         ifs >> m_commitIndex >> m_votedFor;
+        m_votedFor = 0;
         spdlog::info("Node={} restored commitIndex={} and votedFor={}", m_config.m_id, m_commitIndex, m_votedFor);
     }
 
diff --git a/examples/raft/raft.h b/examples/raft/raft.h
index 0bca10c..eb6a291 100644
--- a/examples/raft/raft.h
+++ b/examples/raft/raft.h
@@ -37,6 +37,25 @@ using timepoint_t = std::chrono::high_resolution_clock::time_point;
 // Valid IDs start from 1
 constexpr const id_t gInvalidId = 0;
 
+struct tkvpp_absl_try_unlock
+{
+    tkvpp_absl_try_unlock(absl::Mutex *mu)
+        : m_mu{mu}
+    {
+    }
+
+    ~tkvpp_absl_try_unlock()
+    {
+        if (m_mu)
+        {
+            m_mu->Unlock();
+        }
+    }
+
+  private:
+    absl::Mutex *m_mu{nullptr};
+};
+
 struct node_config_t
 {
     id_t m_id{gInvalidId};
@@ -157,6 +176,8 @@ class consensus_module_t : public RaftService::Service,
     uint32_t m_votedFor         ABSL_GUARDED_BY(m_stateMutex);
     std::vector<LogEntry> m_log ABSL_GUARDED_BY(m_stateMutex);
 
+    absl::CondVar m_electionCV;
+
     // Volatile state on all servers.
     uint32_t m_commitIndex ABSL_GUARDED_BY(m_stateMutex);
     uint32_t m_lastApplied ABSL_GUARDED_BY(m_stateMutex);
@@ -167,16 +188,19 @@ class consensus_module_t : public RaftService::Service,
     std::unordered_map<id_t, uint32_t> m_nextIndex  ABSL_GUARDED_BY(m_stateMutex);
 
     // Election related fields
-    absl::Mutex           m_timerMutex;
-    std::atomic<bool>     m_leaderHeartbeatReceived{false};
-    std::jthread          m_electionThread;
-    std::atomic<uint32_t> m_voteCount{0};
+    absl::Mutex                   m_timerMutex;
+    std::atomic<bool>             m_leaderHeartbeatReceived{false};
+    std::jthread m_electionThread ABSL_GUARDED_BY(m_stateMutex);
+    std::atomic<uint32_t>         m_voteCount{0};
 
     // Stores clusterSize - 1 thread to send heartbeat to replicas
     std::vector<std::jthread> m_heartbeatThreads ABSL_GUARDED_BY(m_stateMutex);
 
     // Serves incoming RPC's
-    std::jthread m_serverThread;
+    std::jthread m_serverThread ABSL_GUARDED_BY(m_stateMutex);
+
+    bool m_shutdown{false};
+    bool m_shutdownHeartbeatThreads{false};
 
     // Temporary in-memory hashtable to store KVs
     std::unordered_map<std::string, std::string> m_kv;
diff --git a/examples/raft/raft_test.cpp b/examples/raft/raft_test.cpp
index 1971e8d..c511c40 100644
--- a/examples/raft/raft_test.cpp
+++ b/examples/raft/raft_test.cpp
@@ -70,7 +70,6 @@ TEST_CASE("ConsensusModule Initialization", "[ConsensusModule]")
         .WillRepeatedly(testing::DoAll(testing::SetArgPointee<2>(response), testing::Return(grpc::Status::OK)));
 
     AppendEntriesResponse aeResponse;
-    aeResponse.set_responderid(1);
     aeResponse.set_responderid(2);
     aeResponse.set_success(true);
     EXPECT_CALL(*mockStub2, AppendEntries)
@@ -88,6 +87,9 @@ TEST_CASE("ConsensusModule Initialization", "[ConsensusModule]")
 
     raft::consensus_module_t consensusModule{nodeConfig1, std::move(replicas)};
     consensusModule.start();
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+    consensusModule.stop();
 }
 
 TEST_CASE("ConsensusModule Leader Election", "[ConsensusModule]")