From 972019d0a0ae50cd8cc44dc59748be7f6e152420 Mon Sep 17 00:00:00 2001 From: LLiuJJ Date: Thu, 5 Sep 2024 23:38:21 +0800 Subject: [PATCH] support raft snapshot --- raftcore/consts.go | 2 ++ raftcore/raft.go | 27 +++++++++++--- raftcore/raft_persistent_log.go | 24 +++++++++++-- shardkvserver/bucket.go | 4 +-- shardkvserver/shard_kvserver.go | 62 +++++++++++++++++++++++++++++++-- storage/kv.go | 2 +- storage/kv_leveldb.go | 6 +++- tests/integration_test.go | 2 +- 8 files changed, 115 insertions(+), 14 deletions(-) diff --git a/raftcore/consts.go b/raftcore/consts.go index c8877c80..dab7dfd5 100644 --- a/raftcore/consts.go +++ b/raftcore/consts.go @@ -39,3 +39,5 @@ var RAFT_STATE_KEY = []byte{0x19, 0x49} const INIT_LOG_INDEX = 0 var SNAPSHOT_STATE_KEY = []byte{0x19, 0x97} + +var BootstrapStateKey = []byte{0x19, 0x95} diff --git a/raftcore/raft.go b/raftcore/raft.go index 46712c68..0b5199ad 100644 --- a/raftcore/raft.go +++ b/raftcore/raft.go @@ -641,16 +641,35 @@ func (rf *Raft) Applier() { } } -func (rf *Raft) StartSnapshot(snap_idx uint64) error { +func (rf *Raft) Snapshot(snap_idx uint64, snapshotContext []byte) error { + rf.mu.Lock() + defer rf.mu.Unlock() rf.isSnapshoting = true if snap_idx <= rf.logs.GetFirstLogId() { rf.isSnapshoting = false return errors.New("ety index is larger than the first log index") } - rf.logs.EraseBefore(int64(snap_idx), true) - rf.logs.ResetFirstLogEntry(rf.curTerm, int64(snap_idx)) + _, err := rf.logs.EraseBefore(int64(snap_idx), true) + if err != nil { + rf.isSnapshoting = false + return err + } + if err := rf.logs.ResetFirstLogEntry(rf.curTerm, int64(snap_idx)); err != nil { + rf.isSnapshoting = false + return err + } // create checkpoint for db rf.isSnapshoting = false - return nil + return rf.logs.PersistSnapshot(snapshotContext) +} + +func (rf *Raft) ReadSnapshot() ([]byte, error) { + rf.mu.RLock() + defer rf.mu.RUnlock() + return rf.logs.ReadSnapshot() +} + +func (rf *Raft) LogCount() int { + return rf.logs.LogItemCount() } diff --git a/raftcore/raft_persistent_log.go b/raftcore/raft_persistent_log.go index 18224321..390addb4 100644 --- a/raftcore/raft_persistent_log.go +++ b/raftcore/raft_persistent_log.go @@ -52,11 +52,17 @@ type RaftPersistenState struct { // newdbEng: a LevelDBKvStore storage engine func MakePersistRaftLog(newdbEng storage_eng.KvStore) *RaftLog { _, err := newdbEng.GetBytesValue(EncodeRaftLogKey(INIT_LOG_INDEX)) - if err != nil { + _, readBootstrapStateErr := newdbEng.GetBytesValue(BootstrapStateKey) + if err != nil && readBootstrapStateErr != nil { logger.ELogger().Sugar().Debugf("init raft log state") emp_ent := &pb.Entry{} emp_ent_encode := EncodeEntry(emp_ent) - newdbEng.PutBytesKv(EncodeRaftLogKey(INIT_LOG_INDEX), emp_ent_encode) + if err := newdbEng.PutBytesKv(EncodeRaftLogKey(INIT_LOG_INDEX), emp_ent_encode); err != nil { + panic(err.Error()) + } + if err := newdbEng.PutBytesKv(BootstrapStateKey, []byte{}); err != nil { + panic(err.Error()) + } return &RaftLog{dbEng: newdbEng} } lidkBytes, _, err := newdbEng.SeekPrefixLast(RAFTLOG_PREFIX) @@ -96,6 +102,20 @@ func (rfLog *RaftLog) ReadRaftState() (curTerm int64, votedFor int64) { return rf_state.CurTerm, rf_state.VotedFor } +// PersistSnapshot ... +func (rfLog *RaftLog) PersistSnapshot(context []byte) error { + return rfLog.dbEng.PutBytesKv(SNAPSHOT_STATE_KEY, context) +} + +// ReadSnapshot ... +func (rfLog *RaftLog) ReadSnapshot() ([]byte, error) { + snapContext, err := rfLog.dbEng.GetBytesValue(SNAPSHOT_STATE_KEY) + if err != nil { + return nil, err + } + return snapContext, nil +} + // GetFirstLogId // get the first log id from storage engine func (rfLog *RaftLog) GetFirstLogId() uint64 { diff --git a/shardkvserver/bucket.go b/shardkvserver/bucket.go index dc471577..956dd204 100644 --- a/shardkvserver/bucket.go +++ b/shardkvserver/bucket.go @@ -85,9 +85,9 @@ func (bu *Bucket) Append(key, value string) error { } // copy all of the data in a bucket -func (bu *Bucket) deepCopy() (map[string]string, error) { +func (bu *Bucket) deepCopy(trimPrefix bool) (map[string]string, error) { encode_key_prefix := strconv.Itoa(bu.ID) + SPLIT - kvs, err := bu.KvDB.DumpPrefixKey(encode_key_prefix) + kvs, err := bu.KvDB.DumpPrefixKey(encode_key_prefix, trimPrefix) if err != nil { return nil, err } diff --git a/shardkvserver/shard_kvserver.go b/shardkvserver/shard_kvserver.go index 6302b7c8..00f0d37b 100644 --- a/shardkvserver/shard_kvserver.go +++ b/shardkvserver/shard_kvserver.go @@ -25,9 +25,12 @@ package shardkvserver import ( + "bytes" "context" + "encoding/gob" "encoding/json" "errors" + "maps" "strconv" "strings" "sync" @@ -66,6 +69,10 @@ type ShardKV struct { pb.UnimplementedRaftServiceServer } +type MemSnapshotDB struct { + KV map[string]string +} + // MakeShardKVServer make a new shard kv server // peerMaps: init peer map in the raft group // nodeId: the peer's nodeId in the raft group @@ -256,8 +263,13 @@ func (s *ShardKV) ApplingToStm(done <-chan interface{}) { logger.ELogger().Sugar().Debugf("appling msg %s", appliedMsg.String()) if appliedMsg.SnapshotValid { - // TODO: install snapshot data to leveldb - s.rf.CondInstallSnapshot(int(appliedMsg.SnapshotTerm), int(appliedMsg.SnapshotIndex)) + s.mu.Lock() + if s.rf.CondInstallSnapshot(int(appliedMsg.SnapshotTerm), int(appliedMsg.SnapshotIndex)) { + s.restoreSnapshot(appliedMsg.Snapshot) + s.lastApplied = int(appliedMsg.SnapshotIndex) + } + s.mu.Unlock() + return } req := &pb.CommandRequest{} @@ -340,6 +352,12 @@ func (s *ShardKV) ApplingToStm(done <-chan interface{}) { ch := s.getNotifyChan(int(appliedMsg.CommandIndex)) ch <- cmd_resp + + if _, isLeader := s.rf.GetState(); isLeader && s.GetRf().LogCount() > 500 { + s.mu.Lock() + s.takeSnapshot(uint64(appliedMsg.CommandIndex)) + s.mu.Unlock() + } } } } @@ -353,6 +371,44 @@ func (s *ShardKV) initStm(eng storage_eng.KvStore) { } } +// takeSnapshot +func (s *ShardKV) takeSnapshot(index uint64) { + var bytesState bytes.Buffer + enc := gob.NewEncoder(&bytesState) + memSnapshotDB := MemSnapshotDB{} + memSnapshotDB.KV = map[string]string{} + for i := 0; i < common.NBuckets; i++ { + if s.CanServe(i) { + kvs, err := s.stm[i].deepCopy(true) + if err != nil { + logger.ELogger().Sugar().Errorf(err.Error()) + } + maps.Copy(memSnapshotDB.KV, kvs) + } + } + enc.Encode(memSnapshotDB) + s.GetRf().Snapshot(index, bytesState.Bytes()) +} + +// restoreSnapshot +func (s *ShardKV) restoreSnapshot(snapData []byte) { + if snapData == nil { + return + } + buf := bytes.NewBuffer(snapData) + data := gob.NewDecoder(buf) + var memSnapshotDB MemSnapshotDB + if data.Decode(&memSnapshotDB) != nil { + logger.ELogger().Sugar().Error("decode memsnapshot error") + } + for k, v := range memSnapshotDB.KV { + bucketID := common.Key2BucketID(k) + if s.CanServe(bucketID) { + s.stm[bucketID].Put(k, v) + } + } +} + // rpc interface func (s *ShardKV) RequestVote(ctx context.Context, req *pb.RequestVoteRequest) (*pb.RequestVoteResponse, error) { resp := &pb.RequestVoteResponse{} @@ -403,7 +459,7 @@ func (s *ShardKV) DoBucketsOperation(ctx context.Context, req *pb.BucketOperatio bucket_datas := &BucketDatasVo{} bucket_datas.Datas = map[int]map[string]string{} for _, bucketID := range req.BucketIds { - sDatas, err := s.stm[int(bucketID)].deepCopy() + sDatas, err := s.stm[int(bucketID)].deepCopy(false) if err != nil { s.mu.RUnlock() return op_resp, err diff --git a/storage/kv.go b/storage/kv.go index 1b3aeb40..2b2ea141 100644 --- a/storage/kv.go +++ b/storage/kv.go @@ -29,7 +29,7 @@ type KvStore interface { Put(string, string) error Get(string) (string, error) Delete(string) error - DumpPrefixKey(string) (map[string]string, error) + DumpPrefixKey(string, bool) (map[string]string, error) PutBytesKv([]byte, []byte) error DeleteBytesK([]byte) error GetBytesValue([]byte) ([]byte, error) diff --git a/storage/kv_leveldb.go b/storage/kv_leveldb.go index 06777831..f55ca46f 100644 --- a/storage/kv_leveldb.go +++ b/storage/kv_leveldb.go @@ -24,6 +24,7 @@ package storage import ( "errors" + "strings" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/opt" @@ -74,11 +75,14 @@ func (levelDB *LevelDBKvStore) Delete(k string) error { return levelDB.db.Delete([]byte(k), nil) } -func (levelDB *LevelDBKvStore) DumpPrefixKey(prefix string) (map[string]string, error) { +func (levelDB *LevelDBKvStore) DumpPrefixKey(prefix string, trimPrefix bool) (map[string]string, error) { kvs := make(map[string]string) iter := levelDB.db.NewIterator(util.BytesPrefix([]byte(prefix)), nil) for iter.Next() { k := string(iter.Key()) + if trimPrefix { + k = strings.TrimPrefix(k, prefix) + } v := string(iter.Value()) kvs[k] = v } diff --git a/tests/integration_test.go b/tests/integration_test.go index 2faa332e..5e91aa57 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -160,7 +160,7 @@ func TestClusterSingleShardRwBench(t *testing.T) { // R-W test shardkvcli := shardkvserver.MakeKvClient("127.0.0.1:8088,127.0.0.1:8089,127.0.0.1:8090") - N := 1000 + N := 100 KEY_SIZE := 64 VAL_SIZE := 64 bench_kvs := map[string]string{}