Skip to content

Commit

Permalink
support raft snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
LLiuJJ committed Sep 5, 2024
1 parent 92216c5 commit 972019d
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 14 deletions.
2 changes: 2 additions & 0 deletions raftcore/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ var RAFT_STATE_KEY = []byte{0x19, 0x49}
const INIT_LOG_INDEX = 0

var SNAPSHOT_STATE_KEY = []byte{0x19, 0x97}

var BootstrapStateKey = []byte{0x19, 0x95}
27 changes: 23 additions & 4 deletions raftcore/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,16 +641,35 @@ func (rf *Raft) Applier() {
}
}

func (rf *Raft) StartSnapshot(snap_idx uint64) error {
func (rf *Raft) Snapshot(snap_idx uint64, snapshotContext []byte) error {
rf.mu.Lock()
defer rf.mu.Unlock()
rf.isSnapshoting = true
if snap_idx <= rf.logs.GetFirstLogId() {
rf.isSnapshoting = false
return errors.New("ety index is larger than the first log index")
}
rf.logs.EraseBefore(int64(snap_idx), true)
rf.logs.ResetFirstLogEntry(rf.curTerm, int64(snap_idx))
_, err := rf.logs.EraseBefore(int64(snap_idx), true)
if err != nil {
rf.isSnapshoting = false
return err
}
if err := rf.logs.ResetFirstLogEntry(rf.curTerm, int64(snap_idx)); err != nil {
rf.isSnapshoting = false
return err
}

// create checkpoint for db
rf.isSnapshoting = false
return nil
return rf.logs.PersistSnapshot(snapshotContext)
}

func (rf *Raft) ReadSnapshot() ([]byte, error) {
rf.mu.RLock()
defer rf.mu.RUnlock()
return rf.logs.ReadSnapshot()
}

func (rf *Raft) LogCount() int {
return rf.logs.LogItemCount()
}
24 changes: 22 additions & 2 deletions raftcore/raft_persistent_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ type RaftPersistenState struct {
// newdbEng: a LevelDBKvStore storage engine
func MakePersistRaftLog(newdbEng storage_eng.KvStore) *RaftLog {
_, err := newdbEng.GetBytesValue(EncodeRaftLogKey(INIT_LOG_INDEX))
if err != nil {
_, readBootstrapStateErr := newdbEng.GetBytesValue(BootstrapStateKey)
if err != nil && readBootstrapStateErr != nil {
logger.ELogger().Sugar().Debugf("init raft log state")
emp_ent := &pb.Entry{}
emp_ent_encode := EncodeEntry(emp_ent)
newdbEng.PutBytesKv(EncodeRaftLogKey(INIT_LOG_INDEX), emp_ent_encode)
if err := newdbEng.PutBytesKv(EncodeRaftLogKey(INIT_LOG_INDEX), emp_ent_encode); err != nil {
panic(err.Error())
}
if err := newdbEng.PutBytesKv(BootstrapStateKey, []byte{}); err != nil {
panic(err.Error())
}
return &RaftLog{dbEng: newdbEng}
}
lidkBytes, _, err := newdbEng.SeekPrefixLast(RAFTLOG_PREFIX)
Expand Down Expand Up @@ -96,6 +102,20 @@ func (rfLog *RaftLog) ReadRaftState() (curTerm int64, votedFor int64) {
return rf_state.CurTerm, rf_state.VotedFor
}

// PersistSnapshot ...
func (rfLog *RaftLog) PersistSnapshot(context []byte) error {
return rfLog.dbEng.PutBytesKv(SNAPSHOT_STATE_KEY, context)
}

// ReadSnapshot ...
func (rfLog *RaftLog) ReadSnapshot() ([]byte, error) {
snapContext, err := rfLog.dbEng.GetBytesValue(SNAPSHOT_STATE_KEY)
if err != nil {
return nil, err
}
return snapContext, nil
}

// GetFirstLogId
// get the first log id from storage engine
func (rfLog *RaftLog) GetFirstLogId() uint64 {
Expand Down
4 changes: 2 additions & 2 deletions shardkvserver/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ func (bu *Bucket) Append(key, value string) error {
}

// copy all of the data in a bucket
func (bu *Bucket) deepCopy() (map[string]string, error) {
func (bu *Bucket) deepCopy(trimPrefix bool) (map[string]string, error) {
encode_key_prefix := strconv.Itoa(bu.ID) + SPLIT
kvs, err := bu.KvDB.DumpPrefixKey(encode_key_prefix)
kvs, err := bu.KvDB.DumpPrefixKey(encode_key_prefix, trimPrefix)
if err != nil {
return nil, err
}
Expand Down
62 changes: 59 additions & 3 deletions shardkvserver/shard_kvserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
package shardkvserver

import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"errors"
"maps"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -66,6 +69,10 @@ type ShardKV struct {
pb.UnimplementedRaftServiceServer
}

type MemSnapshotDB struct {
KV map[string]string
}

// MakeShardKVServer make a new shard kv server
// peerMaps: init peer map in the raft group
// nodeId: the peer's nodeId in the raft group
Expand Down Expand Up @@ -256,8 +263,13 @@ func (s *ShardKV) ApplingToStm(done <-chan interface{}) {
logger.ELogger().Sugar().Debugf("appling msg %s", appliedMsg.String())

if appliedMsg.SnapshotValid {
// TODO: install snapshot data to leveldb
s.rf.CondInstallSnapshot(int(appliedMsg.SnapshotTerm), int(appliedMsg.SnapshotIndex))
s.mu.Lock()
if s.rf.CondInstallSnapshot(int(appliedMsg.SnapshotTerm), int(appliedMsg.SnapshotIndex)) {
s.restoreSnapshot(appliedMsg.Snapshot)
s.lastApplied = int(appliedMsg.SnapshotIndex)
}
s.mu.Unlock()
return
}

req := &pb.CommandRequest{}
Expand Down Expand Up @@ -340,6 +352,12 @@ func (s *ShardKV) ApplingToStm(done <-chan interface{}) {

ch := s.getNotifyChan(int(appliedMsg.CommandIndex))
ch <- cmd_resp

if _, isLeader := s.rf.GetState(); isLeader && s.GetRf().LogCount() > 500 {
s.mu.Lock()
s.takeSnapshot(uint64(appliedMsg.CommandIndex))
s.mu.Unlock()
}
}
}
}
Expand All @@ -353,6 +371,44 @@ func (s *ShardKV) initStm(eng storage_eng.KvStore) {
}
}

// takeSnapshot
func (s *ShardKV) takeSnapshot(index uint64) {
var bytesState bytes.Buffer
enc := gob.NewEncoder(&bytesState)
memSnapshotDB := MemSnapshotDB{}
memSnapshotDB.KV = map[string]string{}
for i := 0; i < common.NBuckets; i++ {
if s.CanServe(i) {
kvs, err := s.stm[i].deepCopy(true)
if err != nil {
logger.ELogger().Sugar().Errorf(err.Error())
}
maps.Copy(memSnapshotDB.KV, kvs)
}
}
enc.Encode(memSnapshotDB)
s.GetRf().Snapshot(index, bytesState.Bytes())
}

// restoreSnapshot
func (s *ShardKV) restoreSnapshot(snapData []byte) {
if snapData == nil {
return
}
buf := bytes.NewBuffer(snapData)
data := gob.NewDecoder(buf)
var memSnapshotDB MemSnapshotDB
if data.Decode(&memSnapshotDB) != nil {
logger.ELogger().Sugar().Error("decode memsnapshot error")
}
for k, v := range memSnapshotDB.KV {
bucketID := common.Key2BucketID(k)
if s.CanServe(bucketID) {
s.stm[bucketID].Put(k, v)
}
}
}

// rpc interface
func (s *ShardKV) RequestVote(ctx context.Context, req *pb.RequestVoteRequest) (*pb.RequestVoteResponse, error) {
resp := &pb.RequestVoteResponse{}
Expand Down Expand Up @@ -403,7 +459,7 @@ func (s *ShardKV) DoBucketsOperation(ctx context.Context, req *pb.BucketOperatio
bucket_datas := &BucketDatasVo{}
bucket_datas.Datas = map[int]map[string]string{}
for _, bucketID := range req.BucketIds {
sDatas, err := s.stm[int(bucketID)].deepCopy()
sDatas, err := s.stm[int(bucketID)].deepCopy(false)
if err != nil {
s.mu.RUnlock()
return op_resp, err
Expand Down
2 changes: 1 addition & 1 deletion storage/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type KvStore interface {
Put(string, string) error
Get(string) (string, error)
Delete(string) error
DumpPrefixKey(string) (map[string]string, error)
DumpPrefixKey(string, bool) (map[string]string, error)
PutBytesKv([]byte, []byte) error
DeleteBytesK([]byte) error
GetBytesValue([]byte) ([]byte, error)
Expand Down
6 changes: 5 additions & 1 deletion storage/kv_leveldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package storage

import (
"errors"
"strings"

"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/opt"
Expand Down Expand Up @@ -74,11 +75,14 @@ func (levelDB *LevelDBKvStore) Delete(k string) error {
return levelDB.db.Delete([]byte(k), nil)
}

func (levelDB *LevelDBKvStore) DumpPrefixKey(prefix string) (map[string]string, error) {
func (levelDB *LevelDBKvStore) DumpPrefixKey(prefix string, trimPrefix bool) (map[string]string, error) {
kvs := make(map[string]string)
iter := levelDB.db.NewIterator(util.BytesPrefix([]byte(prefix)), nil)
for iter.Next() {
k := string(iter.Key())
if trimPrefix {
k = strings.TrimPrefix(k, prefix)
}
v := string(iter.Value())
kvs[k] = v
}
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func TestClusterSingleShardRwBench(t *testing.T) {
// R-W test
shardkvcli := shardkvserver.MakeKvClient("127.0.0.1:8088,127.0.0.1:8089,127.0.0.1:8090")

N := 1000
N := 100
KEY_SIZE := 64
VAL_SIZE := 64
bench_kvs := map[string]string{}
Expand Down

0 comments on commit 972019d

Please sign in to comment.