diff --git a/go.mod b/go.mod
index 0e1cad6a..708de916 100644
--- a/go.mod
+++ b/go.mod
@@ -23,6 +23,7 @@ require (
 	github.com/segmentio/go-sqlite3 v1.12.0
 	github.com/segmentio/stats/v4 v4.6.2
 	github.com/stretchr/testify v1.8.1
+	golang.org/x/sync v0.3.0
 )
 
 require (
diff --git a/go.sum b/go.sum
index 64a25a05..bac4ac19 100644
--- a/go.sum
+++ b/go.sum
@@ -156,6 +156,8 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
+golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
diff --git a/ldb_reader.go b/ldb_reader.go
index 1aea253c..ef70e411 100644
--- a/ldb_reader.go
+++ b/ldb_reader.go
@@ -28,6 +28,7 @@ import (
 // across multiple processes.
 type LDBReader struct {
 	Db                          *sql.DB
+	path                        string
 	pkCache                     map[string]schema.PrimaryKey // keyed by ldbTableName()
 	getRowByKeyStmtCache        map[string]*sql.Stmt         // keyed by ldbTableName()
 	getRowsByKeyPrefixStmtCache map[prefixCacheKey]*sql.Stmt
@@ -46,13 +47,17 @@ var (
 	ErrNoLedgerUpdates = errors.New("no ledger updates have been received yet")
 )
 
+type RowRetriever interface {
+	GetRowsByKeyPrefix(ctx context.Context, familyName string, tableName string, key ...interface{}) (*Rows, error)
+	GetRowByKey(ctx context.Context, out interface{}, familyName string, tableName string, key ...interface{}) (found bool, err error)
+}
+
 func newLDBReader(path string) (*LDBReader, error) {
 	db, err := newLDB(path)
 	if err != nil {
 		return nil, err
 	}
-
-	return &LDBReader{Db: db}, nil
+	return &LDBReader{Db: db, path: path}, nil
 }
 
 func newVersionedLDBReader(dirPath string) (*LDBReader, error) {
@@ -344,7 +349,6 @@ func (reader *LDBReader) closeDB() error {
 	if reader.Db != nil {
 		return reader.Db.Close()
 	}
-
 	return nil
 }
 
@@ -369,7 +373,7 @@ func (reader *LDBReader) Ping(ctx context.Context) bool {
 // to the type of each PK column.
 func convertKeyBeforeQuery(pk schema.PrimaryKey, key []interface{}) error {
 	for i, k := range key {
-		// sanity check on th elength of the pk field type slice
+		// sanity check on the length of the pk field type slice
 		if i >= len(pk.Types) {
 			return errors.New("insufficient key field type data")
 		}
diff --git a/ldb_rotating_reader.go b/ldb_rotating_reader.go
new file mode 100644
index 00000000..d32e3a45
--- /dev/null
+++ b/ldb_rotating_reader.go
@@ -0,0 +1,168 @@
+package ctlstore
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"github.com/segmentio/ctlstore/pkg/errs"
+	"github.com/segmentio/ctlstore/pkg/globalstats"
+	"github.com/segmentio/ctlstore/pkg/ldb"
+	"github.com/segmentio/events/v2"
+	"github.com/segmentio/stats/v4"
+	"path"
+	"strconv"
+	"sync/atomic"
+	"time"
+)
+
+// LDBRotatingReader reads data from multiple LDBs on a rotating schedule.
+// The main benefit is relieving read pressure on a particular ldb file while it is inactive,
+// allowing SQLite maintenance such as WAL checkpointing.
+type LDBRotatingReader struct {
+	active         int32
+	dbs            []*LDBReader
+	schedule       []int8
+	now            func() time.Time
+	tickerInterval time.Duration
+}
+
+// RotationPeriod is how many minutes each reader is active before rotating to the next
+type RotationPeriod int
+
+const (
+	// Every30 rotates on the 30 minute marks in an hour
+	Every30 RotationPeriod = 30
+	// Every20 rotates on the 20 minute marks in an hour
+	Every20 RotationPeriod = 20
+	// Every15 rotates on the 15 minute marks in an hour
+	Every15 RotationPeriod = 15
+	// Every10 rotates on the 10 minute marks in an hour
+	Every10 RotationPeriod = 10
+	// Every6 rotates on the 6 minute marks in an hour
+	Every6 RotationPeriod = 6
+
+	// for simpler migration, the first ldb retains the original name
+	defaultPath = DefaultCtlstorePath + ldb.DefaultLDBFilename
+	ldbFormat   = DefaultCtlstorePath + "ldb-%d.db"
+)
+
+func defaultPaths(count int) []string {
+	paths := []string{defaultPath}
+	for i := 1; i < count; i++ {
+		paths = append(paths, fmt.Sprintf(ldbFormat, i+1))
+	}
+	return paths
+}
+
+// RotatingReader creates a new reader that rotates which ldb it reads from on a rotation period with the default location in /var/spool/ctlstore
+func RotatingReader(ctx context.Context, minutesPerRotation RotationPeriod, ldbsCount int) (RowRetriever, error) {
+	return CustomerRotatingReader(ctx, minutesPerRotation, defaultPaths(ldbsCount)...)
+}
+
+// CustomerRotatingReader creates a new reader that rotates which ldb it reads from on a rotation period
+func CustomerRotatingReader(ctx context.Context, minutesPerRotation RotationPeriod, ldbPaths ...string) (RowRetriever, error) {
+	r, err := rotatingReader(minutesPerRotation, ldbPaths...)
+	if err != nil {
+		return nil, err
+	}
+	r.setActive()
+	go r.rotate(ctx)
+	return r, nil
+}
+
+func rotatingReader(minutesPerRotation RotationPeriod, ldbPaths ...string) (*LDBRotatingReader, error) {
+	if len(ldbPaths) < 2 {
+		return nil, errors.New("RotatingReader requires more than 1 ldb")
+	}
+	if !isValid(minutesPerRotation) {
+		return nil, fmt.Errorf("invalid rotation period: %v", minutesPerRotation)
+	}
+	if len(ldbPaths) > 60/int(minutesPerRotation) {
+		return nil, errors.New("cannot have more ldbs than rotations per hour")
+	}
+	var r LDBRotatingReader
+	for _, p := range ldbPaths {
+		events.Log("Opening ldb %s for reading", p)
+		reader, err := newLDBReader(p)
+		if err != nil {
+			return nil, err
+		}
+		r.dbs = append(r.dbs, reader)
+	}
+	r.schedule = make([]int8, 60)
+	idx := 0
+	for i := 1; i < 61; i++ {
+		r.schedule[i-1] = int8(idx % len(ldbPaths))
+		if i%int(minutesPerRotation) == 0 {
+			idx++
+		}
+	}
+	return &r, nil
+}
+
+func (r *LDBRotatingReader) setActive() {
+	if r.now == nil {
+		r.now = time.Now
+	}
+	atomic.StoreInt32(&r.active, int32(r.schedule[r.now().Minute()]))
+}
+
+// GetRowsByKeyPrefix delegates to the active LDBReader
+func (r *LDBRotatingReader) GetRowsByKeyPrefix(ctx context.Context, familyName string, tableName string, key ...interface{}) (*Rows, error) {
+	return r.dbs[atomic.LoadInt32(&r.active)].GetRowsByKeyPrefix(ctx, familyName, tableName, key...)
+}
+
+// GetRowByKey delegates to the active LDBReader
+func (r *LDBRotatingReader) GetRowByKey(ctx context.Context, out interface{}, familyName string, tableName string, key ...interface{}) (found bool, err error) {
+	return r.dbs[atomic.LoadInt32(&r.active)].GetRowByKey(ctx, out, familyName, tableName, key...)
+}
+
+// rotate by default checks every 1 minute if the active db has changed according to schedule
+func (r *LDBRotatingReader) rotate(ctx context.Context) {
+	if r.tickerInterval == 0 {
+		r.tickerInterval = 1 * time.Minute
+	}
+	ticker := time.NewTicker(r.tickerInterval)
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			next := r.schedule[r.now().Minute()]
+			last := atomic.LoadInt32(&r.active)
+
+			// make the next reader active, then close and reopen the last one
+			if int32(next) != last {
+				atomic.StoreInt32(&r.active, int32(next))
+				stats.Incr("rotating_reader.rotate")
+				globalstats.Set("rotating_reader.active", next)
+				err := r.dbs[last].Close()
+				if err != nil {
+					events.Log("failed to close LDBReader for %s on rotation: %{error}v", r.dbs[last].path, err)
+					errs.Incr("rotating_reader.closing_ldbreader", stats.T("id", strconv.Itoa(int(last))))
+					return
+				}
+
+				reader, err := newLDBReader(r.dbs[last].path)
+				if err != nil {
+					events.Log("failed to open LDBReader for %s on rotation: %{error}v", r.dbs[last].path, err)
+					errs.Incr("rotating_reader.opening_ldbreader",
+						stats.T("id", strconv.Itoa(int(last))),
+						stats.T("path", path.Base(r.dbs[last].path)))
+					return
+				}
+				r.dbs[last] = reader
+
+			}
+		}
+	}
+}
+
+func isValid(rf RotationPeriod) bool {
+	switch rf {
+	case Every6, Every10, Every15, Every20, Every30:
+		return true
+	}
+	return false
+}
diff --git a/ldb_rotating_reader_test.go b/ldb_rotating_reader_test.go
new file mode 100644
index 00000000..223c4440
--- /dev/null
+++ b/ldb_rotating_reader_test.go
@@ -0,0 +1,382 @@
+package ctlstore
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"github.com/segmentio/ctlstore/pkg/ldb"
+	"github.com/stretchr/testify/require"
+	"strconv"
+	"strings"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func getMultiDBs(t *testing.T, count int) (dbs []*sql.DB, paths []string) {
+	var tds []func()
+	for i := 0; i < count; i++ {
+		d, td, p := ldb.LDBForTestWithPath(t)
+		dbs = append(dbs, d)
+		tds = append(tds, td)
+		paths = append(paths, p)
+	}
+	t.Cleanup(func() {
+		for _, fn := range tds {
+			fn()
+		}
+	})
+	return dbs, paths
+}
+
+type basic struct {
+	x int32 `ctlstore:"x"`
+}
+
+func TestBasicRotatingReader(t *testing.T) {
+	dbs, paths := getMultiDBs(t, 2)
+	for i, db := range dbs {
+		_, err := db.Exec("CREATE TABLE family___table (x integer primary key);")
+		if err != nil {
+			t.Fatalf("failed to setup table: %v", err)
+		}
+		_, err = db.Exec(fmt.Sprintf("INSERT INTO family___table VALUES ('%d')", i+1))
+		if err != nil {
+			t.Fatalf("failed to insert into table: %v", err)
+		}
+	}
+
+	rr, err := CustomerRotatingReader(context.Background(), Every30, paths...)
+	if err != nil {
+		t.Fatalf("failed to create rotating reader: %v", err)
+	}
+
+	var out basic
+	reader := rr.(*LDBRotatingReader)
+	found, err := rr.GetRowByKey(context.Background(), &out, "family", "table", reader.active+1)
+	if err != nil || !found {
+		t.Errorf("failed to find key 1: %v", err)
+	}
+	require.Equal(t, reader.active+1, out.x)
+
+	var out2 basic
+	atomic.StoreInt32(&reader.active, (reader.active+1)%2)
+	found, err = reader.GetRowByKey(context.Background(), &out2, "family", "table", reader.active+1)
+	if err != nil || !found {
+		t.Errorf("failed to find key 2: %v", err)
+	}
+	require.Equal(t, reader.active+1, out2.x)
+
+}
+
+func TestValidRotatingReader(t *testing.T) {
+	tests := []struct {
+		name   string
+		expErr string
+		paths  []string
+		rp     RotationPeriod
+	}{
+		{
+			"1 ldb",
+			"more than 1 ldb",
+			[]string{"1path"},
+			Every30,
+		},
+		{
+			"No ldb",
+			"more than 1 ldb",
+			[]string{},
+			Every30,
+		},
+		{
+			"Nil ldb",
+			"more than 1 ldb",
+			nil,
+			Every30,
+		},
+		{
+			"bad rotation",
+			"invalid rotation",
+			[]string{"path1", "path2"},
+			RotationPeriod(2),
+		},
+		{
+			"more ldbs than period, max",
+			"cannot have more",
+			[]string{"path1", "path2", "path3", "path4", "path5", "path6", "path7", "path8", "path9", "path10", "path11"},
+			Every6,
+		},
+		{
+			"more ldbs than period, min",
+			"cannot have more",
+			[]string{"path1", "path2", "path3"},
+			Every30,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := rotatingReader(tt.rp, tt.paths...)
+			if err == nil {
+				t.Fatal("error expected, none found")
+			}
+
+			if !strings.Contains(err.Error(), tt.expErr) {
+				t.Errorf("Did not find right error: got %v", err)
+			}
+		})
+	}
+}
+
+func TestRotation(t *testing.T) {
+	_, paths := getMultiDBs(t, 6)
+
+	rr, err := rotatingReader(Every6, paths...)
+	if err != nil {
+		t.Fatal("unexpected error creating reader")
+	}
+
+	tests := []struct {
+		name    string
+		nowFunc func() time.Time
+		exp     int
+	}{
+		{
+			"0-5",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 1, 0, 0, time.UTC)
+			},
+			0,
+		},
+		{
+			"6-11",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 8, 0, 0, time.UTC)
+			},
+			1,
+		},
+		{
+			"12-17",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 17, 0, 0, time.UTC)
+			},
+			2,
+		},
+		{
+			"18-23",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 21, 0, 0, time.UTC)
+			},
+			3,
+		},
+		{
+			"24-29",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 24, 0, 0, time.UTC)
+			},
+			4,
+		},
+		{
+			"30-35",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 32, 0, 0, time.UTC)
+			},
+			5,
+		},
+		{
+			"36-41",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 41, 0, 0, time.UTC)
+			},
+			0,
+		},
+		{
+			"42-47",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 42, 0, 0, time.UTC)
+			},
+			1,
+		},
+		{
+			"48-53",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 53, 0, 0, time.UTC)
+			},
+			2,
+		},
+		{
+			"54-59",
+			func() time.Time {
+				return time.Date(2023, 8, 17, 9, 59, 0, 0, time.UTC)
+			},
+			3,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			rr.now = tt.nowFunc
+			rr.setActive()
+			if rr.active != int32(tt.exp) {
+				t.Errorf("expected %d to be active, got %d instead", tt.exp, rr.active)
+			}
+		})
+	}
+}
+
+func TestMultipleReaders(t *testing.T) {
+	ctx := context.Background()
+	dbs, paths := getMultiDBs(t, 4)
+
+	// create the tables in each db, and add a row unique to that db
+	for i, db := range dbs {
+		_, err := db.Exec("CREATE TABLE family___foo (id varchar primary key );")
+		if err != nil {
+			t.Fatalf("failure creating table, %v", err)
+		}
+		_, err = db.Exec(fmt.Sprintf("INSERT INTO family___foo values ('%d');", i))
+		if err != nil {
+			t.Fatalf("failure inserting into table, %v", err)
+		}
+	}
+
+	rr, err := rotatingReader(Every15, paths...)
+	if err != nil {
+		t.Fatalf("unexpected error creating reader, %v", err)
+	}
+	i := 0
+	wait := make(chan interface{})
+
+	rr.now = func() time.Time {
+		defer func() {
+			if i != 0 {
+				wait <- 1
+			}
+			i = i + 15
+		}()
+		return time.Date(2023, 8, 17, 10, 0+i, 59, 999_999_999, time.UTC)
+	}
+	rr.tickerInterval = 1 * time.Millisecond
+	rr.setActive()
+	go rr.rotate(ctx)
+
+	for x := range dbs {
+		// for each db, ensure that we read its unique row
+		out := make(map[string]interface{})
+		val, err := rr.GetRowByKey(ctx, out, "family", "foo", x)
+		if err != nil || !val {
+			t.Errorf("unexpected error on GetRowByKey %v", err)
+		}
+
+		require.EqualValues(t, out, map[string]interface{}{"id": strconv.Itoa(x)}, "did not read correct value from table")
+
+		// also ensure we can't read any other unique rows from other dbs
+		for y := range dbs {
+			if y == x {
+				continue
+			}
+			val, err = rr.GetRowByKey(ctx, out, "family", "foo", y)
+			if val || err != nil {
+				t.Errorf("row with key %d should not be found", y)
+			}
+		}
+
+		// allow the ticker to proceed with its rotation
+		<-wait
+		time.Sleep(500 * time.Microsecond)
+	}
+
+}
+
+type kv struct {
+	id  string `ctlstore:"id"`
+	bar string `ctlstore:"bar"`
+}
+
+// verifies that the rows cursor returned by GetRowsByKeyPrefix is still valid even if a rotation occurs while iterating over the row set
+func TestGetRowByPrefixAfterRotation(t *testing.T) {
+	ctx := context.Background()
+	dbs, paths := getMultiDBs(t, 4)
+
+	// create the tables and multiple rows
+	for i, db := range dbs {
+		_, err := db.Exec("CREATE TABLE family___foo (id varchar, bar varchar, primary key (id, bar));")
+		if err != nil {
+			t.Fatalf("failure creating table, %v", err)
+		}
+		_, err = db.Exec(fmt.Sprintf("INSERT INTO family___foo values ('%d', '0'), ('%d', '1'), ('%d', '2'), ('%d', '3');", i, i, i, i))
+		if err != nil {
+			t.Fatalf("failure inserting into table, %v", err)
+		}
+	}
+
+	rr, err := rotatingReader(Every15, paths...)
+	if err != nil {
+		t.Fatalf("unexpected error creating reader, %v", err)
+	}
+
+	i := 0
+	wait := make(chan interface{})
+	rr.now = func() time.Time {
+		defer func() {
+			if i != 0 {
+				wait <- 1
+			}
+			i = i + 15
+		}()
+		return time.Date(2023, 8, 17, 10, (0+i)%60, 59, 999_999_999, time.UTC)
+	}
+	rr.tickerInterval = 1 * time.Millisecond
+	rr.setActive()
+	// get an active rows cursor for the results set from db 0
+	rows, err := rr.GetRowsByKeyPrefix(ctx, "family", "foo", "0")
+	if err != nil {
+		t.Fatalf("unexpected error on GetRowsByKeyPrefix: %v", err)
+	}
+
+	go rr.rotate(ctx)
+
+	count := 0
+	for rows.Next() {
+		var tar kv
+		err := rows.Scan(&tar)
+		if err != nil {
+			t.Fatalf("scan error: %v", err)
+		}
+		require.Equal(t, "0", tar.id)
+		require.Equal(t, strconv.Itoa(count), tar.bar)
+		// trigger a rotation
+		<-wait
+		time.Sleep(500 * time.Microsecond)
+		var out kv
+		count++
+
+		// should rotate by now, check if different result set is returned
+		found, err := rr.GetRowByKey(ctx, &out, "family", "foo", "0", "0")
+		if count == 4 {
+			// on the 4th rotation, we're back at the beginning
+			require.EqualValues(t, kv{"0", "0"}, out, "should have rotated all the way back to the first reader")
+		} else if found || err != nil {
+			t.Errorf("should not have found the key since it rotated: %v", err)
+		}
+	}
+
+	require.Equal(t, 4, count, "should've returned 4 rows")
+}
+
+func TestPath(t *testing.T) {
+	paths := defaultPaths(5)
+	if len(paths) != 5 {
+		t.Fatal("should be 5 paths")
+	}
+
+	if paths[0] != defaultPath {
+		t.Fatalf("First path should be the default, %s", paths[0])
+	}
+
+	for i := 1; i < 5; i++ {
+		if !strings.Contains(paths[i], strconv.Itoa(i+1)) {
+			t.Errorf("path %s should've contained its number, %d", paths[i], i)
+		}
+	}
+}
diff --git a/pkg/cmd/ctlstore/main.go b/pkg/cmd/ctlstore/main.go
index 67523cc3..4fbe3972 100644
--- a/pkg/cmd/ctlstore/main.go
+++ b/pkg/cmd/ctlstore/main.go
@@ -5,10 +5,15 @@ import (
 	"fmt"
 	"net/http"
 	"os"
+	"path"
+	"reflect"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 
+	"golang.org/x/sync/errgroup"
+
 	"github.com/segmentio/conf"
 	"github.com/segmentio/errors-go"
 	"github.com/segmentio/events/v2"
@@ -32,6 +37,8 @@ import (
 	"github.com/segmentio/ctlstore/pkg/utils"
 )
 
+var DebugEnabled = false
+
 type dogstatsdConfig struct {
 	Address    string `conf:"address" help:"Address of the dogstatsd agent that will receive metrics"`
 	BufferSize int    `conf:"buffer-size" help:"Size of the statsd metrics buffer" validate:"min=0"`
@@ -67,6 +74,11 @@ type reflectorCliConfig struct {
 	WALCheckpointThresholdSize int                      `conf:"wal-checkpoint-threshold-size" help:"Performs a checkpoint after the WAL file exceeds this size in bytes"`
 	WALCheckpointType          ldbwriter.CheckpointType `conf:"wal-checkpoint-type" help:"what type of checkpoint to manually perform once the wal size is exceeded"`
 	BusyTimeoutMS              int                      `conf:"busy-timeout-ms" help:"Set a busy timeout on the connection string for sqlite in milliseconds"`
+	MultiReflector             multiReflectorConfig     `conf:"multi-reflector" help:"Configuration for running multiple reflectors at once"`
+}
+
+type multiReflectorConfig struct {
+	LDBPaths []string `conf:"ldb-paths" help:"list of ldbs, each ldb is managed by a unique reflector" validate:"nonzero"`
 }
 
 type executiveCliConfig struct {
@@ -152,6 +164,7 @@ func main() {
 		Commands: []conf.Command{
 			{Name: "version", Help: "Get the ctlstore version"},
 			{Name: "reflector", Help: "Run the ctlstore Reflector"},
+			{Name: "multi-reflector", Help: "Run the ctlstore Reflector in multi mode"},
 			{Name: "sidecar", Help: "Run the ctlstore Sidecar"},
 			{Name: "executive", Help: "Run the ctlstore Executive service"},
 			{Name: "supervisor", Help: "Run the ctlstore Supervisor service"},
@@ -171,6 +184,8 @@ func main() {
 		fmt.Println(ctlstore.Version)
 	case "reflector":
 		reflector(ctx, args)
+	case "multi-reflector":
+		multiReflector(ctx, args)
 	case "sidecar":
 		sidecar(ctx, args)
 	case "executive":
@@ -191,6 +206,7 @@ func main() {
 func enableDebug() {
 	events.DefaultLogger.EnableDebug = true
 	events.DefaultLogger.EnableSource = true
+	DebugEnabled = true
 }
 
 func defaultDogstatsdConfig() dogstatsdConfig {
@@ -295,7 +311,7 @@ func supervisor(ctx context.Context, args []string) {
 		return errors.Wrap(err, "ensure ldb dir")
 	}
 
-	reflector, err := newReflector(cliCfg.ReflectorConfig, true)
+	reflector, err := newReflector(cliCfg.ReflectorConfig, true, 0)
 	if err != nil {
 		return errors.Wrap(err, "build supervisor reflector")
 	}
@@ -462,7 +478,7 @@ func reflector(ctx context.Context, args []string) {
 		prometheusHandler: promHandler,
 	})
 	defer teardown()
-	reflector, err := newReflector(cliCfg, false)
+	reflector, err := newReflector(cliCfg, false, 0)
 	if err != nil {
 		events.Log("Fatal error starting Reflector: %{error}+v", err)
 		errs.IncrDefault(stats.T("op", "startup"))
@@ -471,6 +487,86 @@ func reflector(ctx context.Context, args []string) {
 	reflector.Start(ctx)
 }
 
+func multiReflector(ctx context.Context, args []string) {
+	cliCfg := defaultReflectorCLIConfig(false)
+	loadConfig(&cliCfg, "reflector", args)
+
+	if cliCfg.Debug {
+		enableDebug()
+	}
+
+	var promHandler *prometheus.Handler
+	if len(cliCfg.MetricsBind) > 0 {
+		promHandler = &prometheus.Handler{}
+
+		http.Handle("/metrics", promHandler)
+
+		go func() {
+			events.Log("Serving Prometheus metrics on %s", cliCfg.MetricsBind)
+			err := http.ListenAndServe(cliCfg.MetricsBind, nil)
+			if err != nil {
+				events.Log("Failed to serve Prometheus metrics: %s", err)
+			}
+		}()
+	}
+	_, teardown := configureDogstatsd(ctx, dogstatsdOpts{
+		config:            cliCfg.Dogstatsd,
+		statsPrefix:       "reflector",
+		prometheusHandler: promHandler,
+	})
+	defer teardown()
+
+	reflectors := make([]*reflectorpkg.Reflector, len(cliCfg.MultiReflector.LDBPaths))
+	var wg sync.WaitGroup
+	errChan := make(chan error, len(cliCfg.MultiReflector.LDBPaths))
+	wg.Add(len(cliCfg.MultiReflector.LDBPaths))
+	for i, ldbPath := range cliCfg.MultiReflector.LDBPaths {
+		p := ldbPath
+		x := cliCfg
+		x.LDBPath = p
+		if i > 0 {
+			events.Log("changelog only created for 1st ldb path: %{path}s, skipping #%{num}d", cliCfg.MultiReflector.LDBPaths[0], i+1)
+			x.ChangelogPath = ""
+			x.ChangelogSize = 0
+
+		}
+		go func(x reflectorCliConfig, idx int) {
+			defer wg.Done()
+			r, err := newReflector(x, false, idx)
+			if err != nil {
+				events.Log("Fatal error starting Reflector: %{error}+v", err)
+				errs.IncrDefault(stats.T("op", "startup"), stats.T("path", p))
+				errChan <- err
+				return
+			}
+			reflectors[idx] = r
+		}(x, i)
+	}
+
+	wg.Wait()
+
+	select {
+	case <-errChan:
+		return
+	default:
+	}
+
+	grp, grpCtx := errgroup.WithContext(ctx)
+	for _, reflector := range reflectors {
+		r := reflector
+		grp.Go(func() error {
+			return r.Start(grpCtx)
+		})
+	}
+
+	err := grp.Wait()
+	if err != nil {
+		events.Log("reflectors ended in error %{error}v", err)
+		errs.Incr("multi.shutdown", stats.T("err", reflect.ValueOf(err).Type().String()))
+		return
+	}
+}
+
 func defaultReflectorCLIConfig(isSupervisor bool) reflectorCliConfig {
 	config := reflectorCliConfig{
 		LDBPath: "",
@@ -527,10 +623,13 @@ func newSidecar(config sidecarConfig) (*sidecarpkg.Sidecar, error) {
 	})
 }
 
-func newReflector(cliCfg reflectorCliConfig, isSupervisor bool) (*reflectorpkg.Reflector, error) {
+func newReflector(cliCfg reflectorCliConfig, isSupervisor bool, i int) (*reflectorpkg.Reflector, error) {
 	if cliCfg.LedgerHealth.Disable {
 		events.Log("DEPRECATION NOTICE: use --disable-ecs-behavior instead of --disable to control this ledger monitor behavior")
 	}
+	id := fmt.Sprintf("%s-%d", path.Base(cliCfg.LDBPath), i)
+	l := events.NewLogger(events.DefaultHandler).With(events.Args{{"id", id}})
+	l.EnableDebug = cliCfg.Debug
 	return reflectorpkg.ReflectorFromConfig(reflectorpkg.ReflectorConfig{
 		LDBPath:       cliCfg.LDBPath,
 		ChangelogPath: cliCfg.ChangelogPath,
@@ -561,5 +660,7 @@ func newReflector(cliCfg reflectorCliConfig, isSupervisor bool) (*reflectorpkg.R
 		WALCheckpointThresholdSize: cliCfg.WALCheckpointThresholdSize,
 		WALCheckpointType:          cliCfg.WALCheckpointType,
 		BusyTimeoutMS:              cliCfg.BusyTimeoutMS,
+		ID:                         id,
+		Logger:                     l,
 	})
 }
diff --git a/pkg/globalstats/stats.go b/pkg/globalstats/stats.go
index 0be6f196..7a895572 100644
--- a/pkg/globalstats/stats.go
+++ b/pkg/globalstats/stats.go
@@ -34,6 +34,13 @@ type (
 		value interface{}
 		tags  []stats.Tag
 	}
+
+	gaugeVal struct {
+		name  string
+		value interface{}
+		tags  []stats.Tag
+	}
+
 	counterKey struct {
 		name   string
 		family string
@@ -46,6 +53,7 @@ type (
 		cfg     Config
 		incr    counterKey
 		observe observation
+		set     gaugeVal
 	}
 )
 
@@ -55,6 +63,7 @@ const (
 	statEventTypeIncr
 	statEventTypeObserve
 	statEventTypeClose
+	statEventTypeGauge
 )
 
 var (
@@ -80,6 +89,21 @@ func Incr(name, family, table string) {
 	}
 }
 
+func Set(name string, value interface{}, tags ...stats.Tag) {
+	k := gaugeVal{
+		name:  name,
+		value: value,
+		tags:  tags,
+	}
+	select {
+	case eventChan <- statEvent{typ: statEventTypeGauge, set: k}:
+	default:
+		// eventChan is full, drop this stat
+		incrDroppedStats()
+	}
+
+}
+
 func Observe(name string, value interface{}, tags ...stats.Tag) {
 	k := observation{name: name, value: value, tags: tags}
 	select {
@@ -207,6 +231,12 @@ func loop() {
 		// We're shutting down stats, so stop recording and flushing metrics.
 		case statEventTypeClose:
 			closed = true
+
+		case statEventTypeGauge:
+			if engine = lazyInitEngine(cfg, engine); engine == nil || closed {
+				continue
+			}
+			engine.Set(event.set.name, event.set.value, event.set.tags...)
 		}
 	}
 }
diff --git a/pkg/ldbwriter/ldb_writer.go b/pkg/ldbwriter/ldb_writer.go
index 44824603..a3dfe84f 100644
--- a/pkg/ldbwriter/ldb_writer.go
+++ b/pkg/ldbwriter/ldb_writer.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
-	"github.com/pkg/errors"
 
 	"github.com/segmentio/events/v2"
 	"github.com/segmentio/stats/v4"
@@ -38,41 +37,45 @@ type LDBWriteMetadata struct {
 type SqlLdbWriter struct {
 	Db       *sql.DB
 	LedgerTx *sql.Tx
+	// uniquely identify this SqlLdbWriter
+	Logger *events.Logger
+	ID     string
 }
 
 // Applies a DML statement to the writer's db, updating the sequence
 // tracking table in the same transaction
-func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schema.DMLStatement) error {
+func (w *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schema.DMLStatement) error {
 	var tx *sql.Tx
 	var err error
 
-	stats.Incr("sql_ldb_writer.apply")
+	stats.Incr("sql_ldb_writer.apply", stats.T("id", w.ID))
 
 	// Fill in the tx var
-	if writer.LedgerTx == nil {
+	if w.LedgerTx == nil {
 		// Not applying a ledger transaction, so need a local transaction
-		tx, err = writer.Db.Begin()
+		tx, err = w.Db.Begin()
 		if err != nil {
-			errs.Incr("sql_ldb_writer.begin_tx.error")
+			errs.Incr("sql_ldb_writer.begin_tx.error", stats.T("id", w.ID))
 			return errors.Wrap(err, "open tx error")
 		}
 	} else {
 		// Applying a ledger transaction, so bring it into scope
-		tx = writer.LedgerTx
+		tx = w.LedgerTx
 	}
+	logger := w.logger()
 
 	// Handle begin ledger transaction control statements
 	if statement.Statement == schema.DMLTxBeginKey {
-		if writer.LedgerTx != nil {
+		if w.LedgerTx != nil {
 			// Attempted to open a transaction without committing the last one,
 			// which is a violation of our invariants. Something is very, very
 			// wrong with the ledger processing.
 			tx.Rollback()
-			errs.Incr("sql_ldb_writer.ledgerTx.begin_invariant_violation")
+			errs.Incr("sql_ldb_writer.ledgerTx.begin_invariant_violation", stats.T("id", w.ID))
 			return errors.New("invariant violation")
 		}
-		writer.LedgerTx = tx
-		events.Debug("Begin TX at %{sequence}v", statement.Sequence)
+		w.LedgerTx = tx
+		logger.Debug("Begin TX at %{sequence}v", statement.Sequence)
 	}
 
 	// Update the last update table. This will allow the ldb reader
@@ -84,7 +87,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem
 	_, err = tx.Exec(qs, ldb.LDBLastLedgerUpdateColumn, statement.Timestamp)
 	if err != nil {
 		tx.Rollback()
-		errs.Incr("sql_ldb_writer.upsert_last_update.error")
+		errs.Incr("sql_ldb_writer.upsert_last_update.error", stats.T("id", w.ID))
 		return errors.Wrap(err, "update last_update")
 	}
 
@@ -103,7 +106,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem
 	res, err := tx.Exec(qs, statement.Sequence.Int())
 	if err != nil {
 		tx.Rollback()
-		errs.Incr("sql_ldb_writer.upsert_seq.error")
+		errs.Incr("sql_ldb_writer.upsert_seq.error", stats.T("id", w.ID))
 		return errors.Wrap(err, "update seq tracker error")
 	}
 
@@ -111,12 +114,12 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem
 	rowsAffected, err := res.RowsAffected()
 	if err != nil {
 		tx.Rollback()
-		errs.Incr("sql_ldb_writer.upsert_seq.rows_affected_error")
+		errs.Incr("sql_ldb_writer.upsert_seq.rows_affected_error", stats.T("id", w.ID))
 		return errors.Wrap(err, "update seq tracker rows affected error")
 	}
 	if rowsAffected == 0 {
 		tx.Rollback()
-		errs.Incr("sql_ldb_writer.upsert_seq.replay_detected")
+		errs.Incr("sql_ldb_writer.upsert_seq.replay_detected", stats.T("id", w.ID))
 		return errors.New("update seq tracker replay detected")
 	}
 
@@ -130,27 +133,27 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem
 
 	// Handle end ledger transaction control statements
 	if statement.Statement == schema.DMLTxEndKey {
-		if writer.LedgerTx == nil {
+		if w.LedgerTx == nil {
 			// Attempted to commit a transaction when there is no transaction
 			// open, which is a violation of our invariants. Something is very,
 			// very wrong with the ledger processing!
 			tx.Rollback()
-			errs.Incr("sql_ldb_writer.ledgerTx.end_invariant_violation")
+			errs.Incr("sql_ldb_writer.ledgerTx.end_invariant_violation", stats.T("id", w.ID))
 			return errors.New("invariant violation")
 		}
 
 		err = tx.Commit()
 		if err != nil {
 			tx.Rollback()
-			errs.Incr("sql_ldb_writer.ledgerTx.commit.error")
-			events.Log("Failed to commit Tx at seq %{seq}s: %{error}+v",
+			errs.Incr("sql_ldb_writer.ledgerTx.commit.error", stats.T("id", w.ID))
+			logger.Log("Failed to commit Tx at seq %{seq}s: %{error}+v",
 				statement.Sequence, err)
 			return errors.Wrap(err, "commit multi-statement dml tx error")
 		}
 
-		stats.Incr("sql_ldb_writer.ledgerTx.commit.success")
-		events.Debug("Committed TX at %{sequence}v", statement.Sequence)
-		writer.LedgerTx = nil
+		stats.Incr("sql_ldb_writer.ledgerTx.commit.success", stats.T("id", w.ID))
+		logger.Debug("Committed TX at %{sequence}v", statement.Sequence)
+		w.LedgerTx = nil
 		return nil
 	}
 
@@ -158,37 +161,37 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem
 	_, err = tx.Exec(statement.Statement)
 	if err != nil {
 		tx.Rollback()
-		errs.Incr("sql_ldb_writer.exec.error")
+		errs.Incr("sql_ldb_writer.exec.error", stats.T("id", w.ID))
 		return errors.Wrap(err, "exec dml statement error")
 	}
-	stats.Incr("sql_ldb_writer.exec.success")
+	stats.Incr("sql_ldb_writer.exec.success", stats.T("id", w.ID))
 
-	events.Debug("Applying DML[%{sequence}d]: '%{statement}s'",
+	logger.Debug("Applying DML[%{sequence}d]: '%{statement}s'",
 		statement.Sequence, statement.Statement)
 
 	// Commit if not inside a ledger transaction, since that would be
 	// a single statement transaction.
-	if writer.LedgerTx == nil {
+	if w.LedgerTx == nil {
 		err = tx.Commit()
 		if err != nil {
 			tx.Rollback()
-			errs.Incr("sql_ldb_writer.single.commit.error")
-			errs.Incr("sql_ldb_writer.commit.error")
+			errs.Incr("sql_ldb_writer.single.commit.error", stats.T("id", w.ID))
+			errs.Incr("sql_ldb_writer.commit.error", stats.T("id", w.ID))
 			return errors.Wrap(err, "commit one-statement dml tx error")
 		}
 	}
 
-	stats.Incr("sql_ldb_writer.commit.success")
+	stats.Incr("sql_ldb_writer.commit.success", stats.T("id", w.ID))
 	return nil
 }
 
-func (writer *SqlLdbWriter) Close() error {
-	if writer.LedgerTx != nil {
-		writer.LedgerTx.Rollback()
-		writer.LedgerTx = nil
+func (w *SqlLdbWriter) Close() error {
+	if w.LedgerTx != nil {
+		w.LedgerTx.Rollback()
+		w.LedgerTx = nil
 	}
 	return nil
 }
@@ -222,11 +225,11 @@ var (
 // Checkpoint initiates a wal checkpoint, returning stats on the checkpoint's progress
 // see https://www.sqlite.org/pragma.html#pragma_wal_checkpoint for more details
 // requires write access
-func (writer *SqlLdbWriter) Checkpoint(checkpointingType CheckpointType) (*PragmaWALResult, error) {
-	res, err := writer.Db.Query(fmt.Sprintf("PRAGMA wal_checkpoint(%s)", string(checkpointingType)))
+func (w *SqlLdbWriter) Checkpoint(checkpointingType CheckpointType) (*PragmaWALResult, error) {
+	res, err := w.Db.Query(fmt.Sprintf("PRAGMA wal_checkpoint(%s)", string(checkpointingType)))
 	if err != nil {
-		events.Log("error in checkpointing, %{error}", err)
-		errs.Incr("sql_ldb_writer.wal_checkpoint.query.error")
+		w.logger().Log("error in checkpointing, %{error}", err)
+		errs.Incr("sql_ldb_writer.wal_checkpoint.query.error", stats.T("id", w.ID))
 		return nil, err
 	}
 
@@ -235,11 +238,18 @@ func (writer *SqlLdbWriter) Checkpoint(checkpointingType CheckpointType) (*Pragm
 	if res.Next() {
 		err := res.Scan(&p.Busy, &p.Log, &p.Checkpointed)
 		if err != nil {
-			events.Log("error in scanning checkpointing, %{error}", err)
-			errs.Incr("sql_ldb_writer.wal_checkpoint.scan.error")
+			w.logger().Log("error in scanning checkpointing, %{error}", err)
+			errs.Incr("sql_ldb_writer.wal_checkpoint.scan.error", stats.T("id", w.ID))
 			return nil, err
 		}
 	}
 	p.Type = checkpointingType
 	return &p, nil
 }
+
+func (w *SqlLdbWriter) logger() *events.Logger {
+	if w.Logger == nil {
+		w.Logger = events.DefaultLogger
+	}
+	return w.Logger
+}
diff --git a/pkg/ldbwriter/ldb_writer_test.go b/pkg/ldbwriter/ldb_writer_test.go
index ecd69ff0..a266d93a 100644
--- a/pkg/ldbwriter/ldb_writer_test.go
+++ b/pkg/ldbwriter/ldb_writer_test.go
@@ -357,6 +357,9 @@ func TestCheckpointQuery(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(string(tt.cpType), func(t *testing.T) {
 			res, err := writer.Checkpoint(tt.cpType)
+			if err != nil {
+				t.Fatalf("expected no error, got %v", err)
+			}
 			err = writer.ApplyDMLStatement(context.Background(), schema.NewTestDMLStatement("INSERT INTO foo VALUES('hello');"))
 			if err != nil {
 				t.Fatalf("expected no error, got %v", err)
diff --git a/pkg/reflector/reflector.go b/pkg/reflector/reflector.go
index 324c2d25..71fdaaaa 100644
--- a/pkg/reflector/reflector.go
+++ b/pkg/reflector/reflector.go
@@ -37,6 +37,7 @@ import (
 type Reflector struct {
 	shovel        func() (*shovel, error)
 	ldb           *sql.DB
+	logger        *events.Logger
 	upstreamdb    *sql.DB
 	ledgerMonitor *ledger.Monitor
 	walMonitor    starter
@@ -74,6 +75,8 @@ type ReflectorConfig struct {
 	WALCheckpointType ldbwriter.CheckpointType // optional
 	DoMonitorWAL      bool                     // optional
 	BusyTimeoutMS     int                      // optional
+	ID                string
+	Logger            *events.Logger
 }
 
 type DownloadMetric struct {
@@ -192,7 +195,6 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) {
 	events.Log("Successfully emitted metric from file")
 
 	// TODO: check Upstream fields
-
 	stop := make(chan struct{})
 
 	// This is a function so that initialization can be redone each
@@ -204,7 +206,10 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) {
 	// as a function allows recovering all the way back to initializing
 	// the and fetching the last known good sequence in the LDB.
 	shovel := func() (*shovel, error) {
-		sqlDBWriter := &ldbwriter.SqlLdbWriter{Db: ldbDB}
+		sqlDBWriter := &ldbwriter.SqlLdbWriter{Db: ldbDB,
+			ID:     config.ID,
+			Logger: config.Logger,
+		}
 		var writer ldbwriter.LDBWriter = sqlDBWriter
 		var ldbWriteCallbacks []ldbwriter.LDBWriteCallback
 
@@ -240,6 +245,7 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) {
 		}
 
 		lastSeq, err := ldb.FetchSeqFromLdb(context.TODO(), ldbDB)
 		if err != nil {
 			return nil, fmt.Errorf("Error when fetching last sequence from LDB: %v", err)
 		}
+		events.Log("Latest seq from %s: %d", config.ID, lastSeq.Int())
@@ -261,6 +267,7 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) {
 			abortOnSeqSkip:  true,
 			maxSeqOnStartup: maxKnownSeq.Int64,
 			stop:            stop,
+			log:             config.Logger,
 		}, nil
 	}
 
@@ -289,6 +296,7 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) {
 	return &Reflector{
 		shovel:        shovel,
 		ldb:           ldbDB,
+		logger:        config.Logger,
 		upstreamdb:    upstreamdb,
 		ledgerMonitor: ledgerMon,
 		stop:          stop,
@@ -336,7 +344,7 @@ func emitMetricFromFile(path string) error {
 }
 
 func (r *Reflector) Start(ctx context.Context) error {
-	events.Log("Starting Reflector.")
+	r.logger.Log("Starting Reflector.")
 	go r.ledgerMonitor.Start(ctx)
 	go r.walMonitor.Start(ctx)
 	for {
@@ -346,7 +355,7 @@ func (r *Reflector) Start(ctx context.Context) error {
 			return errors.Wrap(err, "build shovel")
 		}
 		defer shovel.Close()
-		events.Log("Shoveling...")
+		r.logger.Log("Shoveling...")
 		stats.Incr("reflector.shovel_start")
 		err = shovel.Start(ctx)
 		return errors.Wrap(err, "shovel")
@@ -354,7 +363,7 @@ func (r *Reflector) Start(ctx context.Context) error {
 		switch {
 		case errs.IsCanceled(err): // this is normal
 		case events.IsTermination(errors.Cause(err)): // this is normal
-			events.Log("Reflector received termination signal")
+			r.logger.Log("Reflector received termination signal")
 		case err != nil:
 			switch {
 			case errors.Is("SkippedSequence", err):
@@ -364,7 +373,7 @@ func (r *Reflector) Start(ctx context.Context) error {
 			default:
 				errs.Incr("reflector.shovel_error")
 			}
-			events.Log("Error encountered during shoveling: %{error}+v", err)
+			r.logger.Log("Error encountered during shoveling: %{error}+v", err)
 		}
 		select {
 		case <-r.stop:
@@ -384,7 +393,7 @@ func (r *Reflector) Stop() {
 func (r *Reflector) Close() error {
 	var err error
 
-	events.Log("Close() reflector")
+	r.logger.Log("Close() reflector")
 	err = r.ldb.Close()
 	if err != nil {
diff --git a/pkg/reflector/reflector_test.go b/pkg/reflector/reflector_test.go
index 10e26119..9d8fda93 100644
--- a/pkg/reflector/reflector_test.go
+++ b/pkg/reflector/reflector_test.go
@@ -66,6 +66,7 @@ func TestShovelSequenceReset(t *testing.T) {
 		LedgerHealth: ledger.HealthConfig{
 			DisableECSBehavior: true,
 		},
+		Logger: events.DefaultLogger,
 	}
 	reflector, err := ReflectorFromConfig(cfg)
 	require.NoError(t, err)
@@ -155,6 +156,7 @@ func TestReflector(t *testing.T) {
 			DisableECSBehavior: true,
 			PollInterval:       10 * time.Second,
 		},
+		Logger: events.DefaultLogger,
 	}
 
 	ctx, cancel := context.WithCancel(context.Background())
diff --git a/pkg/reflector/shovel.go b/pkg/reflector/shovel.go
index a7abae51..a2b8d628 100644
--- a/pkg/reflector/shovel.go
+++ b/pkg/reflector/shovel.go
@@ -23,6 +23,7 @@ type shovel struct {
 	abortOnSeqSkip  bool
 	maxSeqOnStartup int64
 	stop            chan struct{}
+	log             *events.Logger
 }
 
 func (s *shovel) Start(ctx context.Context) error {
@@ -44,7 +45,7 @@ func (s *shovel) Start(ctx context.Context) error {
 	// early exit here if the shovel should be stopped
 	select {
 	case <-s.stop:
-		events.Log("Shovel stopping normally")
+		s.logger().Log("Shovel stopping normally")
 		return nil
 	default:
 	}
@@ -56,7 +57,7 @@ func (s *shovel) Start(ctx context.Context) error {
 		sctx, cancel = context.WithTimeout(ctx, s.pollTimeout)
 
 		stats.Incr("shovel.loop_enter")
-		events.Debug("shovel polling...")
+		s.logger().Debug("shovel polling...")
 
 		st, err := s.source.Next(sctx)
 		if err != nil {
@@ -79,7 +80,7 @@ func (s *shovel) Start(ctx context.Context) error {
 			//
 			pollSleep := jitr.Jitter(s.pollInterval, s.jitterCoefficient)
 
-			events.Debug("Poll sleep %{sleepTime}s", pollSleep)
+			s.logger().Debug("Poll sleep %{sleepTime}s", pollSleep)
 
 			select {
 			case <-ctx.Done():
@@ -91,12 +92,12 @@ func (s *shovel) Start(ctx context.Context) error {
 			continue
 		}
 
-		events.Debug("Shovel applying %{statement}v", st)
+		s.logger().Debug("Shovel applying %{statement}v", st)
 
 		if lastSeq != 0 {
 			if st.Sequence > lastSeq+1 && st.Sequence.Int() > s.maxSeqOnStartup {
 				stats.Incr("shovel.skipped_sequence")
-				events.Log("shovel skip sequence from:%{fromSeq}d to:%{toSeq}d", lastSeq, st.Sequence)
+				s.logger().Log("shovel skip sequence from:%{fromSeq}d to:%{toSeq}d", lastSeq, st.Sequence)
 
 				if s.abortOnSeqSkip {
 					// Mitigation for a bug that we haven't found yet
@@ -133,8 +134,15 @@ func (s *shovel) Close() error {
 	for _, closer := range s.closers {
 		err := closer.Close()
 		if err != nil {
-			events.Log("shovel encountered error during close: %{error}s", err)
+			s.logger().Log("shovel encountered error during close: %{error}s", err)
 		}
 	}
 	return nil
 }
+
+func (s *shovel) logger() *events.Logger {
+	if s.log == nil {
+		s.log = events.DefaultLogger
+	}
+	return s.log
+}
diff --git a/pkg/reflector/wal_monitor.go b/pkg/reflector/wal_monitor.go
index 762db76d..dfa27b62 100644
--- a/pkg/reflector/wal_monitor.go
+++ b/pkg/reflector/wal_monitor.go
@@ -3,6 +3,7 @@ package reflector
 import (
 	"context"
 	"os"
+	"path"
 	"time"
 
 	"github.com/segmentio/events/v2"
@@ -93,7 +94,9 @@ func (m *WALMonitor) Start(ctx context.Context) {
 			}
 			return
 		}
-		stats.Set("wal-file-size", size)
+
+		ldbFileName := path.Base(m.walPath)
+		stats.Set("wal-file-size", size, stats.T("ldb", ldbFileName))
 
 		if size <= m.walCheckpointThresholdSize {
 			stats.Incr("wal-no-checkpoint")
@@ -116,9 +119,9 @@ func (m *WALMonitor) Start(ctx context.Context) {
 			if res.Busy == 1 {
 				isBusy = "true"
 			}
-			stats.Set("wal-checkpoint-status", 1, stats.T("busy", isBusy))
-			stats.Set("wal-total-pages", res.Log)
-			stats.Set("wal-checkpointed-pages", res.Checkpointed)
+			stats.Set("wal-checkpoint-status", 1, stats.T("busy", isBusy), stats.T("ldb", ldbFileName))
+			stats.Set("wal-total-pages", res.Log, stats.T("ldb", ldbFileName))
+			stats.Set("wal-checkpointed-pages", res.Checkpointed, stats.T("ldb", ldbFileName))
 
 			failedInARow = 0
 		})
diff --git a/scripts/download.sh b/scripts/download.sh
index 98559c99..c0ba5c83 100755
--- a/scripts/download.sh
+++ b/scripts/download.sh
@@ -7,28 +7,33 @@ PREFIX="$(echo $CTLSTORE_BOOTSTRAP_URL | grep :// | sed -e's,^\(.*://\).*,\1,g')"
 URL="$(echo $CTLSTORE_BOOTSTRAP_URL | sed -e s,$PREFIX,,g)"
 BUCKET="$(echo $URL | grep / | cut -d/ -f1)"
 KEY="$(echo $URL | grep / | cut -d/ -f2)"
+CTLSTORE_DIR="/var/spool/ctlstore"
 CONCURRENCY=${2:-20}
+NUM_LDB=${3:-1}
 DOWNLOADED="false"
 COMPRESSED="false"
-METRICS="/var/spool/ctlstore/metrics.json"
+METRICS="$CTLSTORE_DIR/metrics.json"
+mkdir -p $CTLSTORE_DIR
+cd $CTLSTORE_DIR
+
+# busybox does not support sub-second resolution
 START=$(date +%s)
 END=$(date +%s)
 SHA_START=$(date +%s)
 SHA_END=$(date +%s)
 
 get_head_object() {
-  head_object=$(aws s3api head-object --bucket "${BUCKET}" --key "${KEY}")
-  echo "$head_object"
+	head_object=$(aws s3api head-object --bucket "${BUCKET}" --key "${KEY}")
+	echo "$head_object"
 }
 
-if [ ! -f /var/spool/ctlstore/ldb.db ]; then
-  # busybox does not support sub-second resolution
-  START=$(date +%s)
-
-  mkdir -p /var/spool/ctlstore
-  cd /var/spool/ctlstore
+cleanup() {
+	echo "Removing snapshot.db"
+	rm -f $CTLSTORE_DIR/snapshot.*
+}
 
+download_snapshot() {
 	echo "Downloading head object from ${CTLSTORE_BOOTSTRAP_URL}"
 	head_object=$(get_head_object)
 
@@ -44,15 +49,17 @@
 	DOWNLOADED="true"
 	if [[ ${CTLSTORE_BOOTSTRAP_URL: -2} == gz ]]; then
 		echo "Decompressing"
-		pigz -d snapshot.db.gz
+		pigz -df snapshot.db.gz
 		COMPRESSED="true"
 	fi
+}
 
+check_sha() {
 	SHA_START=$(date +%s)
 	if [ -z $remote_checksum ]; then
 		echo "Remote checksum sha1 is null, skipping checksum validation"
 	else
-		local_checksum=$(shasum snapshot.db | cut -f1 -d\  | xxd -r -p | base64)
+		local_checksum=$(shasum snapshot.db | cut -f1 -d\  | xxd -r -p | base64)
 		echo "Local snapshot checksum in sha1: $local_checksum"
 
 		if [[ "$local_checksum" == "$remote_checksum" ]]; then
@@ -60,11 +67,27 @@
 		else
 			echo "Checksum does not match"
 			echo "Failed to download intact snapshot"
+			cleanup
 			exit 1
 		fi
 	fi
 	SHA_END=$(date +%s)
 	echo "Local checksum calculation took $(($SHA_END - $SHA_START)) seconds"
+}
+
+if [ ! -f "$CTLSTORE_DIR/ldb.db" ]; then
+	echo "No ldb found, downloading snapshot"
+	download_snapshot
+	check_sha
+
+	i=2
+	while [ "$i" -le $NUM_LDB ]; do
+		if [ ! -f ldb-$i.db ]; then
+			echo "creating copy ldb-$i.db"
+			cp snapshot.db ldb-$i.db
+		fi
+		i=$((i + 1))
+	done
 
 	mv snapshot.db ldb.db
 	END=$(date +%s)
@@ -73,5 +96,27 @@ else
 	echo "Snapshot already present"
 fi
 
-echo "{\"startTime\": $(($END - $START)), \"downloaded\": \"$DOWNLOADED\", \"compressed\": \"$COMPRESSED\"}" > $METRICS
+# on existing nodes, we may already have the ldb file.
+# We should download a new snapshot to avoid copying an in-use ldb.db file and risking a malformed db
+i=2
+while [ "$i" -le $NUM_LDB ]; do
+
+	# make sure it's not already downloaded
+	if [ ! -f ldb-$i.db ]; then
+		echo "Preparing ldb-$i.db"
+		# download the snapshot if it's not present
+		if [ ! -f "$CTLSTORE_DIR/snapshot.db" ]; then
+			download_snapshot
+			check_sha
+		fi
+
+		echo "creating copy ldb-$i.db"
+		cp snapshot.db ldb-$i.db
+	fi
+	i=$((i + 1))
+done
+
+cleanup
+
+echo "{\"startTime\": $(($END - $START)), \"downloaded\": \"$DOWNLOADED\", \"compressed\": \"$COMPRESSED\"}" >$METRICS
 cat $METRICS
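
Reviewer note (not part of the patch): a minimal usage sketch for the new rotating reader. Both LDBReader and LDBRotatingReader satisfy the RowRetriever interface added in ldb_reader.go, so existing call sites need no changes. The paths below are illustrative and assume the two-LDB layout scripts/download.sh produces with NUM_LDB=2; the family/table/key names are hypothetical:

	package main

	import (
		"context"
		"fmt"
		"log"

		"github.com/segmentio/ctlstore"
	)

	func main() {
		ctx := context.Background()

		// With Every30, reader 0 serves minutes 0-29 of each hour and reader 1
		// serves minutes 30-59; on each rotation the now-inactive LDB is closed
		// and reopened so SQLite maintenance can proceed against it.
		reader, err := ctlstore.CustomerRotatingReader(ctx, ctlstore.Every30,
			"/var/spool/ctlstore/ldb.db", "/var/spool/ctlstore/ldb-2.db")
		if err != nil {
			log.Fatal(err)
		}

		// reader is a ctlstore.RowRetriever; reads delegate to whichever
		// LDB is currently active.
		var row struct {
			ID string `ctlstore:"id"`
		}
		found, err := reader.GetRowByKey(ctx, &row, "family", "foo", "some-key")
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(found, row.ID)
	}

Each LDB in the rotation is kept up to date by its own reflector (the multi-reflector command added in pkg/cmd/ctlstore/main.go), so a reader never observes a stale file after rotation.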