From 15a96e434fef02aa76d8d378803ed2ce50d3b3ac Mon Sep 17 00:00:00 2001 From: Avi Deitcher Date: Sun, 21 Apr 2024 18:10:44 +0300 Subject: [PATCH] implement remote logging Signed-off-by: Avi Deitcher --- cmd/common_test.go | 57 ++++++++++++++--- cmd/dump.go | 55 ++++++++-------- cmd/dump_test.go | 11 ++-- cmd/prune.go | 24 +++---- cmd/prune_test.go | 10 ++- cmd/restore.go | 28 ++++++--- cmd/restore_test.go | 38 ++++++----- cmd/root.go | 31 ++++++--- docs/configuration.md | 3 +- docs/logging.md | 21 +++++++ go.mod | 1 + go.sum | 2 + pkg/config/local.go | 3 + pkg/core/dump.go | 19 +++--- pkg/core/dumpoptions.go | 2 + pkg/core/executor.go | 17 +++++ pkg/core/prune.go | 29 +++++---- pkg/core/prune_test.go | 9 ++- pkg/core/pruneoptions.go | 2 + pkg/core/restore.go | 26 ++++---- pkg/core/restoreoptions.go | 17 +++++ pkg/core/timer.go | 7 +-- pkg/internal/remote/certs.go | 61 ++++++++++++++++++ pkg/internal/test/README.md | 4 ++ pkg/internal/test/remote.go | 77 +++++++++++++++++++++++ pkg/log/telemetry.go | 119 +++++++++++++++++++++++++++++++++++ pkg/log/telemetry_test.go | 119 +++++++++++++++++++++++++++++++++++ pkg/log/type.go | 9 +++ pkg/remote/const.go | 8 --- pkg/remote/get.go | 84 +++++++++---------------- pkg/remote/get_test.go | 62 ++++-------------- pkg/storage/file/file.go | 10 +-- pkg/storage/s3/s3.go | 20 +++--- pkg/storage/smb/smb.go | 9 +-- pkg/storage/storage.go | 14 +++-- test/backup_test.go | 7 ++- 36 files changed, 742 insertions(+), 273 deletions(-) create mode 100644 docs/logging.md create mode 100644 pkg/core/executor.go create mode 100644 pkg/core/restoreoptions.go create mode 100644 pkg/internal/remote/certs.go create mode 100644 pkg/internal/test/README.md create mode 100644 pkg/internal/test/remote.go create mode 100644 pkg/log/telemetry.go create mode 100644 pkg/log/telemetry_test.go create mode 100644 pkg/log/type.go diff --git a/cmd/common_test.go b/cmd/common_test.go index ec39025f..adaeedfe 100644 --- a/cmd/common_test.go +++ b/cmd/common_test.go @@ -1,15 +1,17 @@ package cmd import ( - "github.com/databacker/mysql-backup/pkg/compression" + "reflect" + "github.com/databacker/mysql-backup/pkg/core" - "github.com/databacker/mysql-backup/pkg/database" - "github.com/databacker/mysql-backup/pkg/storage" + "github.com/go-test/deep" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/mock" ) type mockExecs struct { mock.Mock + logger *log.Logger } func newMockExecs() *mockExecs { @@ -17,21 +19,21 @@ func newMockExecs() *mockExecs { return m } -func (m *mockExecs) dump(opts core.DumpOptions) error { +func (m *mockExecs) Dump(opts core.DumpOptions) error { args := m.Called(opts) return args.Error(0) } -func (m *mockExecs) restore(target storage.Storage, targetFile string, dbconn database.Connection, databasesMap map[string]string, compressor compression.Compressor) error { - args := m.Called(target, targetFile, dbconn, databasesMap, compressor) +func (m *mockExecs) Restore(opts core.RestoreOptions) error { + args := m.Called(opts) return args.Error(0) } -func (m *mockExecs) prune(opts core.PruneOptions) error { +func (m *mockExecs) Prune(opts core.PruneOptions) error { args := m.Called(opts) return args.Error(0) } -func (m *mockExecs) timer(timerOpts core.TimerOptions, cmd func() error) error { +func (m *mockExecs) Timer(timerOpts core.TimerOptions, cmd func() error) error { args := m.Called(timerOpts) err := args.Error(0) if err != nil { @@ -39,3 +41,42 @@ func (m *mockExecs) timer(timerOpts core.TimerOptions, cmd func() error) error { } return cmd() } + +func (m *mockExecs) SetLogger(logger *log.Logger) { + m.logger = logger +} + +func (m *mockExecs) GetLogger() *log.Logger { + return m.logger +} + +func equalIgnoreFields(a, b interface{}, fields []string) bool { + va := reflect.ValueOf(a) + vb := reflect.ValueOf(b) + + // Check if both a and b are struct types + if va.Kind() != reflect.Struct || vb.Kind() != reflect.Struct { + return false + } + + // Make a map of fields to ignore for quick lookup + ignoreMap := make(map[string]bool) + for _, f := range fields { + ignoreMap[f] = true + } + + // Compare fields that are not in the ignore list + for i := 0; i < va.NumField(); i++ { + field := va.Type().Field(i).Name + if !ignoreMap[field] { + vaField := va.Field(i).Interface() + vbField := vb.Field(i).Interface() + diff := deep.Equal(vaField, vbField) + if diff != nil { + return false + } + } + } + + return true +} diff --git a/cmd/dump.go b/cmd/dump.go index 5d3d2f92..102c6574 100644 --- a/cmd/dump.go +++ b/cmd/dump.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - log "github.com/sirupsen/logrus" + "github.com/google/uuid" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -20,7 +20,7 @@ const ( defaultMaxAllowedPacket = 4194304 ) -func dumpCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { +func dumpCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { if cmdConfig == nil { return nil, fmt.Errorf("cmdConfig is nil") } @@ -37,7 +37,7 @@ func dumpCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { bindFlags(cmd, v) }, RunE: func(cmd *cobra.Command, args []string) error { - log.Debug("starting dump") + cmdConfig.logger.Debug("starting dump") // check targets targetURLs := v.GetStringSlice("target") var ( @@ -130,19 +130,6 @@ func dumpCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { return fmt.Errorf("failure to get compression '%s': %v", compressionAlgo, err) } } - dumpOpts := core.DumpOptions{ - Targets: targets, - Safechars: safechars, - DBNames: include, - DBConn: cmdConfig.dbconn, - Compressor: compressor, - Exclude: exclude, - PreBackupScripts: preBackupScripts, - PostBackupScripts: preBackupScripts, - SuppressUseDatabase: noDatabaseName, - Compact: compact, - MaxAllowedPacket: maxAllowedPacket, - } // retention, if enabled retention := v.GetString("retention") @@ -173,23 +160,37 @@ func dumpCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { Begin: begin, Frequency: frequency, } - dump := core.Dump - prune := core.Prune - timer := core.TimerCommand - if execs != nil { - dump = execs.dump - prune = execs.prune - timer = execs.timer + var executor execs + executor = &core.Executor{} + if passedExecs != nil { + executor = passedExecs } + executor.SetLogger(cmdConfig.logger) + // at this point, any errors should not have usage cmd.SilenceUsage = true - if err := timer(timerOpts, func() error { - err := dump(dumpOpts) + if err := executor.Timer(timerOpts, func() error { + uid := uuid.New() + dumpOpts := core.DumpOptions{ + Targets: targets, + Safechars: safechars, + DBNames: include, + DBConn: cmdConfig.dbconn, + Compressor: compressor, + Exclude: exclude, + PreBackupScripts: preBackupScripts, + PostBackupScripts: preBackupScripts, + SuppressUseDatabase: noDatabaseName, + Compact: compact, + MaxAllowedPacket: maxAllowedPacket, + Run: uid, + } + err := executor.Dump(dumpOpts) if err != nil { return fmt.Errorf("error running dump: %w", err) } if retention != "" { - if err := prune(core.PruneOptions{Targets: targets, Retention: retention}); err != nil { + if err := executor.Prune(core.PruneOptions{Targets: targets, Retention: retention}); err != nil { return fmt.Errorf("error running prune: %w", err) } } @@ -197,7 +198,7 @@ func dumpCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { }); err != nil { return fmt.Errorf("error running command: %w", err) } - log.Info("Backup complete") + executor.GetLogger().Info("Backup complete") return nil }, } diff --git a/cmd/dump_test.go b/cmd/dump_test.go index 380c395a..d9e16505 100644 --- a/cmd/dump_test.go +++ b/cmd/dump_test.go @@ -110,15 +110,14 @@ func TestDumpCmd(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := newMockExecs() - m.On("dump", mock.MatchedBy(func(dumpOpts core.DumpOptions) bool { - diff := deep.Equal(dumpOpts, tt.expectedDumpOptions) - if diff == nil { + m.On("Dump", mock.MatchedBy(func(dumpOpts core.DumpOptions) bool { + if equalIgnoreFields(dumpOpts, tt.expectedDumpOptions, []string{"Run"}) { return true } - t.Errorf("dumpOpts compare failed: %v", diff) + t.Errorf("dumpOpts compare failed: %#v %#v", dumpOpts, tt.expectedDumpOptions) return false })).Return(nil) - m.On("timer", mock.MatchedBy(func(timerOpts core.TimerOptions) bool { + m.On("Timer", mock.MatchedBy(func(timerOpts core.TimerOptions) bool { diff := deep.Equal(timerOpts, tt.expectedTimerOptions) if diff == nil { return true @@ -127,7 +126,7 @@ func TestDumpCmd(t *testing.T) { return false })).Return(nil) if tt.expectedPruneOptions != nil { - m.On("prune", mock.MatchedBy(func(pruneOpts core.PruneOptions) bool { + m.On("Prune", mock.MatchedBy(func(pruneOpts core.PruneOptions) bool { diff := deep.Equal(pruneOpts, *tt.expectedPruneOptions) if diff == nil { return true diff --git a/cmd/prune.go b/cmd/prune.go index 9e77939b..1eabe5a8 100644 --- a/cmd/prune.go +++ b/cmd/prune.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - log "github.com/sirupsen/logrus" + "github.com/google/uuid" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -12,7 +12,7 @@ import ( "github.com/databacker/mysql-backup/pkg/storage" ) -func pruneCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { +func pruneCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { if cmdConfig == nil { return nil, fmt.Errorf("cmdConfig is nil") } @@ -31,7 +31,7 @@ func pruneCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) bindFlags(cmd, v) }, RunE: func(cmd *cobra.Command, args []string) error { - log.Debug("starting prune") + cmdConfig.logger.Debug("starting prune") retention := v.GetString("retention") targetURLs := v.GetStringSlice("target") var ( @@ -96,18 +96,20 @@ func pruneCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) Frequency: frequency, } - prune := core.Prune - timer := core.TimerCommand - if execs != nil { - prune = execs.prune - timer = execs.timer + var executor execs + executor = &core.Executor{} + if passedExecs != nil { + executor = passedExecs } - if err := timer(timerOpts, func() error { - return prune(core.PruneOptions{Targets: targets, Retention: retention}) + executor.SetLogger(cmdConfig.logger) + + if err := executor.Timer(timerOpts, func() error { + uid := uuid.New() + return executor.Prune(core.PruneOptions{Targets: targets, Retention: retention, Run: uid}) }); err != nil { return fmt.Errorf("error running prune: %w", err) } - log.Info("Pruning complete") + executor.GetLogger().Info("Pruning complete") return nil }, } diff --git a/cmd/prune_test.go b/cmd/prune_test.go index 99622b31..ef502969 100644 --- a/cmd/prune_test.go +++ b/cmd/prune_test.go @@ -7,7 +7,6 @@ import ( "github.com/databacker/mysql-backup/pkg/core" "github.com/databacker/mysql-backup/pkg/storage" "github.com/databacker/mysql-backup/pkg/storage/file" - "github.com/go-test/deep" "github.com/stretchr/testify/mock" ) @@ -32,15 +31,14 @@ func TestPruneCmd(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := newMockExecs() - m.On("prune", mock.MatchedBy(func(pruneOpts core.PruneOptions) bool { - diff := deep.Equal(pruneOpts, tt.expectedPruneOptions) - if diff == nil { + m.On("Prune", mock.MatchedBy(func(pruneOpts core.PruneOptions) bool { + if equalIgnoreFields(pruneOpts, tt.expectedPruneOptions, []string{"Run"}) { return true } - t.Errorf("pruneOpts compare failed: %v", diff) + t.Errorf("pruneOpts compare failed: %#v %#v", pruneOpts, tt.expectedPruneOptions) return false })).Return(nil) - m.On("timer", tt.expectedTimerOptions).Return(nil) + m.On("Timer", tt.expectedTimerOptions).Return(nil) cmd, err := rootCmd(m) if err != nil { t.Fatal(err) diff --git a/cmd/restore.go b/cmd/restore.go index 1334b481..cdaf9b0e 100644 --- a/cmd/restore.go +++ b/cmd/restore.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - log "github.com/sirupsen/logrus" + "github.com/google/uuid" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -14,7 +14,7 @@ import ( "github.com/databacker/mysql-backup/pkg/util" ) -func restoreCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { +func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error) { if cmdConfig == nil { return nil, fmt.Errorf("cmdConfig is nil") } @@ -28,7 +28,7 @@ func restoreCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error }, Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - log.Debug("starting restore") + cmdConfig.logger.Debug("starting restore") targetFile := args[0] target := v.GetString("target") // get databases namesand mappings @@ -94,16 +94,28 @@ func restoreCmd(execs execs, cmdConfig *cmdConfiguration) (*cobra.Command, error return fmt.Errorf("invalid target url: %v", err) } } - restore := core.Restore - if execs != nil { - restore = execs.restore + var executor execs + executor = &core.Executor{} + if passedExecs != nil { + executor = passedExecs } + executor.SetLogger(cmdConfig.logger) + // at this point, any errors should not have usage cmd.SilenceUsage = true - if err := restore(store, targetFile, cmdConfig.dbconn, databasesMap, compressor); err != nil { + uid := uuid.New() + restoreOpts := core.RestoreOptions{ + Target: store, + TargetFile: targetFile, + Compressor: compressor, + DatabasesMap: databasesMap, + DBConn: cmdConfig.dbconn, + Run: uid, + } + if err := executor.Restore(restoreOpts); err != nil { return fmt.Errorf("error restoring: %v", err) } - log.Info("Restore complete") + passedExecs.GetLogger().Info("Restore complete") return nil }, } diff --git a/cmd/restore_test.go b/cmd/restore_test.go index 97d221df..cd280aa0 100644 --- a/cmd/restore_test.go +++ b/cmd/restore_test.go @@ -5,9 +5,10 @@ import ( "testing" "github.com/databacker/mysql-backup/pkg/compression" + "github.com/databacker/mysql-backup/pkg/core" "github.com/databacker/mysql-backup/pkg/database" - "github.com/databacker/mysql-backup/pkg/storage" "github.com/databacker/mysql-backup/pkg/storage/file" + "github.com/stretchr/testify/mock" ) func TestRestoreCmd(t *testing.T) { @@ -17,26 +18,33 @@ func TestRestoreCmd(t *testing.T) { fileTargetURL, _ := url.Parse(fileTarget) tests := []struct { - name string - args []string // "restore" will be prepended automatically - config string - wantErr bool - expectedTarget storage.Storage - expectedTargetFile string - expectedDbconn database.Connection - expectedDatabasesMap map[string]string - expectedCompressor compression.Compressor + name string + args []string // "restore" will be prepended automatically + config string + wantErr bool + expectedRestoreOptions core.RestoreOptions + //expectedTarget storage.Storage + //expectedTargetFile string + //expectedDbconn database.Connection + //expectedDatabasesMap map[string]string + //expectedCompressor compression.Compressor }{ - {"missing server and target options", []string{""}, "", true, nil, "", database.Connection{}, nil, &compression.GzipCompressor{}}, - {"invalid target URL", []string{"--server", "abc", "--target", "def"}, "", true, nil, "", database.Connection{Host: "abc"}, nil, &compression.GzipCompressor{}}, - {"valid URL missing dump filename", []string{"--server", "abc", "--target", "file:///foo/bar"}, "", true, nil, "", database.Connection{Host: "abc"}, nil, &compression.GzipCompressor{}}, - {"valid file URL", []string{"--server", "abc", "--target", fileTarget, "filename.tgz", "--verbose", "2"}, "", false, file.New(*fileTargetURL), "filename.tgz", database.Connection{Host: "abc", Port: defaultPort}, map[string]string{}, &compression.GzipCompressor{}}, + {"missing server and target options", []string{""}, "", true, core.RestoreOptions{}}, + {"invalid target URL", []string{"--server", "abc", "--target", "def"}, "", true, core.RestoreOptions{}}, + {"valid URL missing dump filename", []string{"--server", "abc", "--target", "file:///foo/bar"}, "", true, core.RestoreOptions{}}, + {"valid file URL", []string{"--server", "abc", "--target", fileTarget, "filename.tgz", "--verbose", "2"}, "", false, core.RestoreOptions{Target: file.New(*fileTargetURL), TargetFile: "filename.tgz", DBConn: database.Connection{Host: "abc", Port: defaultPort}, DatabasesMap: map[string]string{}, Compressor: &compression.GzipCompressor{}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := newMockExecs() - m.On("restore", tt.expectedTarget, tt.expectedTargetFile, tt.expectedDbconn, tt.expectedDatabasesMap, tt.expectedCompressor).Return(nil) + m.On("Restore", mock.MatchedBy(func(restoreOpts core.RestoreOptions) bool { + if equalIgnoreFields(restoreOpts, tt.expectedRestoreOptions, []string{"Run"}) { + return true + } + t.Errorf("restoreOpts compare failed: %#v %#v", restoreOpts, tt.expectedRestoreOptions) + return false + })).Return(nil) cmd, err := rootCmd(m) if err != nil { t.Fatal(err) diff --git a/cmd/root.go b/cmd/root.go index 9d2a6bc8..bae60697 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,11 +5,10 @@ import ( "os" "strings" - "github.com/databacker/mysql-backup/pkg/compression" "github.com/databacker/mysql-backup/pkg/config" "github.com/databacker/mysql-backup/pkg/core" "github.com/databacker/mysql-backup/pkg/database" - "github.com/databacker/mysql-backup/pkg/storage" + databacklog "github.com/databacker/mysql-backup/pkg/log" "github.com/databacker/mysql-backup/pkg/storage/credentials" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -18,10 +17,12 @@ import ( ) type execs interface { - dump(opts core.DumpOptions) error - restore(target storage.Storage, targetFile string, dbconn database.Connection, databasesMap map[string]string, compressor compression.Compressor) error - prune(opts core.PruneOptions) error - timer(timerOpts core.TimerOptions, cmd func() error) error + SetLogger(logger *log.Logger) + GetLogger() *log.Logger + Dump(opts core.DumpOptions) error + Restore(opts core.RestoreOptions) error + Prune(opts core.PruneOptions) error + Timer(timerOpts core.TimerOptions, cmd func() error) error } type subCommand func(execs, *cmdConfiguration) (*cobra.Command, error) @@ -32,6 +33,7 @@ type cmdConfiguration struct { dbconn database.Connection creds credentials.Creds configuration *config.ConfigSpec + logger *log.Logger } const ( @@ -57,14 +59,15 @@ func rootCmd(execs execs) (*cobra.Command, error) { `, PersistentPreRunE: func(c *cobra.Command, args []string) error { bindFlags(cmd, v) + var logger = log.New() logLevel := v.GetInt("verbose") switch logLevel { case 0: - log.SetLevel(log.InfoLevel) + logger.SetLevel(log.InfoLevel) case 1: - log.SetLevel(log.DebugLevel) + logger.SetLevel(log.DebugLevel) case 2: - log.SetLevel(log.TraceLevel) + logger.SetLevel(log.TraceLevel) } // read the config file, if needed; the structure of the config differs quite some @@ -105,6 +108,15 @@ func rootCmd(execs execs) (*cobra.Command, error) { cmdConfig.dbconn.Pass = actualConfig.Database.Credentials.Password } cmdConfig.configuration = actualConfig + + if actualConfig.Telemetry.URL != "" { + // set up telemetry + loggerHook, err := databacklog.NewTelemetry(actualConfig.Telemetry, nil) + if err != nil { + return fmt.Errorf("unable to set up telemetry: %w", err) + } + logger.AddHook(loggerHook) + } } // override config with env var or CLI flag, if set @@ -140,6 +152,7 @@ func rootCmd(execs execs) (*cobra.Command, error) { Domain: v.GetString("smb-domain"), }, } + cmdConfig.logger = logger return nil }, } diff --git a/docs/configuration.md b/docs/configuration.md index f287857b..d3bb3803 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -60,8 +60,9 @@ Various sample configuration files are available in the [sample-configs](../samp The following are the environment variables, CLI flags and configuration file options for: backup(B), restore (R), prune (P). -| Purpose | Backup / Restore | CLI Flag | Env Var | Config Key | Default | +| Purpose | Backup / Restore / Prune | CLI Flag | Env Var | Config Key | Default | | --- | --- | --- | --- | --- | --- | +| config file path | BRP | `config` | `DB_DUMP_CONFIG` | | | | hostname or unix domain socket path (starting with a slash) to connect to database. Required. | BR | `server` | `DB_SERVER` | `database.server` | | | port to use to connect to database. Optional. | BR | `port` | `DB_PORT` | `database.port` | 3306 | | username for the database | BR | `user` | `DB_USER` | `database.credentials.username` | | diff --git a/docs/logging.md b/docs/logging.md new file mode 100644 index 00000000..fb0ca56c --- /dev/null +++ b/docs/logging.md @@ -0,0 +1,21 @@ +# Logging + +Logging is provided on standard out (stdout) and standard error (stderr). The log level can be set +using `--v` to the following levels: + +- `--v=0` (default): log level set to `INFO` +- `--v=1`: log level set to `DEBUG` +- `--v=2`: log level set to `TRACE` + +Log output and basic metrics can be sent to a remote service, using the +[configuration options](./configuration.md). + +The remote log service includes the following information: + +- backup start timestamp +- backup config, including command-line options (scrubbed of sensitive data) +- backup logs +- backup success or failure timestamp and duration + +Log levels up to debug are sent to the remote service. Trace logs are not sent to the remote service, and are +used for local debugging only. diff --git a/go.mod b/go.mod index 4b809e11..2d6e6eff 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/geoffgarside/ber v1.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/uuid v1.6.0 github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect diff --git a/go.sum b/go.sum index 96a49a0f..c8e64054 100644 --- a/go.sum +++ b/go.sum @@ -113,6 +113,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= diff --git a/pkg/config/local.go b/pkg/config/local.go index ba532f07..1b703ee9 100644 --- a/pkg/config/local.go +++ b/pkg/config/local.go @@ -74,6 +74,9 @@ type DBCredentials struct { type Telemetry struct { remote.Connection + // BufferSize is the size of the buffer for telemetry messages. It keeps BufferSize messages + // in memory before sending them remotely. The default of 0 is the same as 1, i.e. send every message. + BufferSize int `yaml:"buffer-size"` } var _ yaml.Unmarshaler = &Target{} diff --git a/pkg/core/dump.go b/pkg/core/dump.go index 322d61ce..551fe3ef 100644 --- a/pkg/core/dump.go +++ b/pkg/core/dump.go @@ -20,7 +20,7 @@ const ( ) // Dump run a single dump, based on the provided opts -func Dump(opts DumpOptions) error { +func (e *Executor) Dump(opts DumpOptions) error { targets := opts.Targets safechars := opts.Safechars dbnames := opts.DBNames @@ -29,10 +29,11 @@ func Dump(opts DumpOptions) error { compact := opts.Compact suppressUseDatabase := opts.SuppressUseDatabase maxAllowedPacket := opts.MaxAllowedPacket + logger := e.Logger.WithField("run", opts.Run.String()) now := time.Now() timepart := now.Format(time.RFC3339) - log.Infof("beginning dump %s", timepart) + logger.Infof("beginning dump %s", timepart) if safechars { timepart = strings.ReplaceAll(timepart, ":", "-") } @@ -49,7 +50,7 @@ func Dump(opts DumpOptions) error { } defer os.RemoveAll(tmpdir) // execute pre-backup scripts if any - if err := preBackup(timepart, path.Join(tmpdir, targetFilename), tmpdir, opts.PreBackupScripts, log.GetLevel() == log.DebugLevel); err != nil { + if err := preBackup(timepart, path.Join(tmpdir, targetFilename), tmpdir, opts.PreBackupScripts, logger.Level == log.DebugLevel); err != nil { return fmt.Errorf("error running pre-restore: %v", err) } @@ -106,12 +107,12 @@ func Dump(opts DumpOptions) error { f.Close() // execute post-backup scripts if any - if err := postBackup(timepart, path.Join(tmpdir, targetFilename), tmpdir, opts.PostBackupScripts, log.GetLevel() == log.DebugLevel); err != nil { + if err := postBackup(timepart, path.Join(tmpdir, targetFilename), tmpdir, opts.PostBackupScripts, logger.Level == log.DebugLevel); err != nil { return fmt.Errorf("error running pre-restore: %v", err) } // perform any renaming - newName, err := renameSource(timepart, path.Join(tmpdir, targetFilename), tmpdir, log.GetLevel() == log.DebugLevel) + newName, err := renameSource(timepart, path.Join(tmpdir, targetFilename), tmpdir, logger.Level == log.DebugLevel) if err != nil { return fmt.Errorf("failed rename source: %v", err) } @@ -120,7 +121,7 @@ func Dump(opts DumpOptions) error { } // perform any renaming - newName, err = renameTarget(timepart, path.Join(tmpdir, targetFilename), tmpdir, log.GetLevel() == log.DebugLevel) + newName, err = renameTarget(timepart, path.Join(tmpdir, targetFilename), tmpdir, logger.Level == log.DebugLevel) if err != nil { return fmt.Errorf("failed rename target: %v", err) } @@ -130,12 +131,12 @@ func Dump(opts DumpOptions) error { // upload to each destination for _, t := range targets { - log.Debugf("uploading via protocol %s from %s", t.Protocol(), targetFilename) - copied, err := t.Push(targetFilename, filepath.Join(tmpdir, sourceFilename)) + logger.Debugf("uploading via protocol %s from %s", t.Protocol(), targetFilename) + copied, err := t.Push(targetFilename, filepath.Join(tmpdir, sourceFilename), logger) if err != nil { return fmt.Errorf("failed to push file: %v", err) } - log.Debugf("completed copying %d bytes", copied) + logger.Debugf("completed copying %d bytes", copied) } return nil diff --git a/pkg/core/dumpoptions.go b/pkg/core/dumpoptions.go index 4a14f347..4a41cba9 100644 --- a/pkg/core/dumpoptions.go +++ b/pkg/core/dumpoptions.go @@ -4,6 +4,7 @@ import ( "github.com/databacker/mysql-backup/pkg/compression" "github.com/databacker/mysql-backup/pkg/database" "github.com/databacker/mysql-backup/pkg/storage" + "github.com/google/uuid" ) type DumpOptions struct { @@ -18,4 +19,5 @@ type DumpOptions struct { Compact bool SuppressUseDatabase bool MaxAllowedPacket int + Run uuid.UUID } diff --git a/pkg/core/executor.go b/pkg/core/executor.go new file mode 100644 index 00000000..5887105b --- /dev/null +++ b/pkg/core/executor.go @@ -0,0 +1,17 @@ +package core + +import ( + log "github.com/sirupsen/logrus" +) + +type Executor struct { + Logger *log.Logger +} + +func (e *Executor) SetLogger(logger *log.Logger) { + e.Logger = logger +} + +func (e *Executor) GetLogger() *log.Logger { + return e.Logger +} diff --git a/pkg/core/prune.go b/pkg/core/prune.go index 82475c2a..3b4d10f5 100644 --- a/pkg/core/prune.go +++ b/pkg/core/prune.go @@ -7,16 +7,15 @@ import ( "slices" "strconv" "time" - - log "github.com/sirupsen/logrus" ) // filenameRE is a regular expression to match a backup filename var filenameRE = regexp.MustCompile(`^db_backup_(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z\.\w+$`) // Prune prune older backups -func Prune(opts PruneOptions) error { - log.Info("beginning prune") +func (e *Executor) Prune(opts PruneOptions) error { + logger := e.Logger.WithField("run", opts.Run.String()) + logger.Info("beginning prune") var ( candidates []string now = opts.Now @@ -36,8 +35,8 @@ func Prune(opts PruneOptions) error { for _, target := range opts.Targets { var pruned int - log.Debugf("pruning target %s", target) - files, err := target.ReadDir(".") + logger.Debugf("pruning target %s", target) + files, err := target.ReadDir(".", logger) if err != nil { return fmt.Errorf("failed to read directory: %v", err) } @@ -49,17 +48,17 @@ func Prune(opts PruneOptions) error { filename := fileInfo.Name() matches := filenameRE.FindStringSubmatch(filename) if matches == nil { - log.Debugf("ignoring filename that is not standard backup pattern: %s", filename) + logger.Debugf("ignoring filename that is not standard backup pattern: %s", filename) continue } - log.Debugf("checking filename that is standard backup pattern: %s", filename) + logger.Debugf("checking filename that is standard backup pattern: %s", filename) // Parse the date from the filename year, month, day, hour, minute, second := matches[1], matches[2], matches[3], matches[4], matches[5], matches[6] dateTimeStr := fmt.Sprintf("%s-%s-%sT%s:%s:%sZ", year, month, day, hour, minute, second) filetime, err := time.Parse(time.RFC3339, dateTimeStr) if err != nil { - log.Debugf("Error parsing date from filename %s: %v; ignoring", filename, err) + logger.Debugf("Error parsing date from filename %s: %v; ignoring", filename, err) continue } filesWithTimes = append(filesWithTimes, fileWithTime{ @@ -75,11 +74,11 @@ func Prune(opts PruneOptions) error { // Check if the file is within 'retain' hours from 'now' age := now.Sub(f.filetime).Hours() if age < float64(retainHours) { - log.Debugf("file %s is %f hours old", f.filename, age) - log.Debugf("keeping file %s", f.filename) + logger.Debugf("file %s is %f hours old", f.filename, age) + logger.Debugf("keeping file %s", f.filename) continue } - log.Debugf("Adding candidate file: %s", f.filename) + logger.Debugf("Adding candidate file: %s", f.filename) candidates = append(candidates, f.filename) } case retainCount > 0: @@ -96,7 +95,7 @@ func Prune(opts PruneOptions) error { slices.Reverse(filesWithTimes) if retainCount >= len(filesWithTimes) { for i := 0 + retainCount; i < len(filesWithTimes); i++ { - log.Debugf("Adding candidate file %s:", filesWithTimes[i].filename) + logger.Debugf("Adding candidate file %s:", filesWithTimes[i].filename) candidates = append(candidates, filesWithTimes[i].filename) } } @@ -106,12 +105,12 @@ func Prune(opts PruneOptions) error { // we have the list, remove them all for _, filename := range candidates { - if err := target.Remove(filename); err != nil { + if err := target.Remove(filename, logger); err != nil { return fmt.Errorf("failed to remove file %s: %v", filename, err) } pruned++ } - log.Debugf("pruning %d files from target %s", pruned, target) + logger.Debugf("pruning %d files from target %s", pruned, target) } return nil diff --git a/pkg/core/prune_test.go b/pkg/core/prune_test.go index 9197c821..cd3a4339 100644 --- a/pkg/core/prune_test.go +++ b/pkg/core/prune_test.go @@ -2,6 +2,7 @@ package core import ( "fmt" + "io" "os" "slices" "testing" @@ -9,6 +10,7 @@ import ( "github.com/databacker/mysql-backup/pkg/storage" "github.com/databacker/mysql-backup/pkg/storage/credentials" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) @@ -107,7 +109,12 @@ func TestPrune(t *testing.T) { } // run Prune - err := Prune(tt.opts) + logger := log.New() + logger.Out = io.Discard + executor := Executor{ + Logger: logger, + } + err := executor.Prune(tt.opts) switch { case (err == nil && tt.err != nil) || (err != nil && tt.err == nil): t.Errorf("expected error %v, got %v", tt.err, err) diff --git a/pkg/core/pruneoptions.go b/pkg/core/pruneoptions.go index 07efd365..993bdd39 100644 --- a/pkg/core/pruneoptions.go +++ b/pkg/core/pruneoptions.go @@ -4,10 +4,12 @@ import ( "time" "github.com/databacker/mysql-backup/pkg/storage" + "github.com/google/uuid" ) type PruneOptions struct { Targets []storage.Storage Retention string Now time.Time + Run uuid.UUID } diff --git a/pkg/core/restore.go b/pkg/core/restore.go index 6f36fa0e..135f7c0a 100644 --- a/pkg/core/restore.go +++ b/pkg/core/restore.go @@ -6,12 +6,8 @@ import ( "os" "path" - log "github.com/sirupsen/logrus" - "github.com/databacker/mysql-backup/pkg/archive" - "github.com/databacker/mysql-backup/pkg/compression" "github.com/databacker/mysql-backup/pkg/database" - "github.com/databacker/mysql-backup/pkg/storage" ) const ( @@ -21,20 +17,22 @@ const ( ) // Restore restore a specific backup into the database -func Restore(target storage.Storage, targetFile string, dbconn database.Connection, databasesMap map[string]string, compressor compression.Compressor) error { - log.Info("beginning restore") +func (e *Executor) Restore(opts RestoreOptions) error { + logger := e.Logger.WithField("run", opts.Run.String()) + + logger.Info("beginning restore") // execute pre-restore scripts if any - if err := preRestore(target.URL()); err != nil { + if err := preRestore(opts.Target.URL()); err != nil { return fmt.Errorf("error running pre-restore: %v", err) } - log.Debugf("restoring via %s protocol, temporary file location %s", target.Protocol(), tmpRestoreFile) + logger.Debugf("restoring via %s protocol, temporary file location %s", opts.Target.Protocol(), tmpRestoreFile) - copied, err := target.Pull(targetFile, tmpRestoreFile) + copied, err := opts.Target.Pull(opts.TargetFile, tmpRestoreFile, logger) if err != nil { - return fmt.Errorf("failed to pull target %s: %v", target, err) + return fmt.Errorf("failed to pull target %s: %v", opts.Target, err) } - log.Debugf("completed copying %d bytes", copied) + logger.Debugf("completed copying %d bytes", copied) // successfully download file, now restore it tmpdir, err := os.MkdirTemp("", "restore") @@ -50,7 +48,7 @@ func Restore(target storage.Storage, targetFile string, dbconn database.Connecti os.Remove(tmpRestoreFile) // create my tar reader to put the files in the directory - cr, err := compressor.Uncompress(f) + cr, err := opts.Compressor.Uncompress(f) if err != nil { return fmt.Errorf("unable to create an uncompressor: %v", err) } @@ -76,12 +74,12 @@ func Restore(target storage.Storage, targetFile string, dbconn database.Connecti defer file.Close() readers = append(readers, file) } - if err := database.Restore(dbconn, databasesMap, readers); err != nil { + if err := database.Restore(opts.DBConn, opts.DatabasesMap, readers); err != nil { return fmt.Errorf("failed to restore database: %v", err) } // execute post-restore scripts if any - if err := postRestore(target.URL()); err != nil { + if err := postRestore(opts.Target.URL()); err != nil { return fmt.Errorf("error running post-restove: %v", err) } return nil diff --git a/pkg/core/restoreoptions.go b/pkg/core/restoreoptions.go new file mode 100644 index 00000000..dfda97b0 --- /dev/null +++ b/pkg/core/restoreoptions.go @@ -0,0 +1,17 @@ +package core + +import ( + "github.com/databacker/mysql-backup/pkg/compression" + "github.com/databacker/mysql-backup/pkg/database" + "github.com/databacker/mysql-backup/pkg/storage" + "github.com/google/uuid" +) + +type RestoreOptions struct { + Target storage.Storage + TargetFile string + DBConn database.Connection + DatabasesMap map[string]string + Compressor compression.Compressor + Run uuid.UUID +} diff --git a/pkg/core/timer.go b/pkg/core/timer.go index 2a9c8f7c..1d3723d4 100644 --- a/pkg/core/timer.go +++ b/pkg/core/timer.go @@ -8,7 +8,6 @@ import ( "time" "github.com/robfig/cron/v3" - log "github.com/sirupsen/logrus" ) type TimerOptions struct { @@ -164,11 +163,11 @@ func waitForCron(cronExpr string, from time.Time) (time.Duration, error) { return next.Sub(from), nil } -// TimerCommand runs a command on a timer -func TimerCommand(timerOpts TimerOptions, cmd func() error) error { +// Timer runs a command on a timer +func (e *Executor) Timer(timerOpts TimerOptions, cmd func() error) error { c, err := Timer(timerOpts) if err != nil { - log.Errorf("error creating timer: %v", err) + e.Logger.Errorf("error creating timer: %v", err) os.Exit(1) } // block and wait for it diff --git a/pkg/internal/remote/certs.go b/pkg/internal/remote/certs.go new file mode 100644 index 00000000..f45274f5 --- /dev/null +++ b/pkg/internal/remote/certs.go @@ -0,0 +1,61 @@ +package remote + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "time" +) + +const ( + clientOrg = "client.databack.io" + certValidity = 5 * time.Minute + DigestSha256 = "sha256" +) + +// SelfSignedCertFromPrivateKey creates a self-signed certificate from an ed25519 private key +func SelfSignedCertFromPrivateKey(privateKey ed25519.PrivateKey, hostname string) (*tls.Certificate, error) { + if privateKey == nil || len(privateKey) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid private key") + } + publicKey := privateKey.Public() + + // Create a template for the certificate + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{clientOrg}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(certValidity), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + if hostname != "" { + template.DNSNames = append(template.DNSNames, hostname) + } + + // Self-sign the certificate + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Encode and print the certificate + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + + // Create the TLS certificate to use in tls.Config + marshaledPrivateKey, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal private key: %w", err) + } + cert, err := tls.X509KeyPair(certPEM, pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledPrivateKey})) + return &cert, err +} diff --git a/pkg/internal/test/README.md b/pkg/internal/test/README.md new file mode 100644 index 00000000..43ccb996 --- /dev/null +++ b/pkg/internal/test/README.md @@ -0,0 +1,4 @@ +# test package + +Contains common utilities used in tests across other packages. Part of `pkg/internal/` to ensure it is not +exported elsewhere. diff --git a/pkg/internal/test/remote.go b/pkg/internal/test/remote.go new file mode 100644 index 00000000..fa8bf5c2 --- /dev/null +++ b/pkg/internal/test/remote.go @@ -0,0 +1,77 @@ +package test + +import ( + "crypto" + "crypto/ed25519" + cryptorand "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "net/http" + "net/http/httptest" + + utilremote "github.com/databacker/mysql-backup/pkg/internal/remote" +) + +func StartServer(clientKeyCount int, handler http.HandlerFunc) (server *httptest.Server, serverFingerprint string, clientKeys [][]byte, err error) { + // Generate new private keys for each of the clients + var clientPublicKeys []crypto.PublicKey + for i := 0; i < clientKeyCount; i++ { + clientSeed := make([]byte, ed25519.SeedSize) + if _, err := io.ReadFull(cryptorand.Reader, clientSeed); err != nil { + return nil, "", nil, fmt.Errorf("failed to generate client random seed: %w", err) + } + clientKeys = append(clientKeys, clientSeed) + clientKey := ed25519.NewKeyFromSeed(clientSeed) + clientPublicKeys = append(clientPublicKeys, clientKey.Public()) + } + + serverSeed := make([]byte, ed25519.SeedSize) + if _, err := io.ReadFull(cryptorand.Reader, serverSeed); err != nil { + return nil, "", nil, fmt.Errorf("failed to generate server random seed: %w", err) + } + serverKey := ed25519.NewKeyFromSeed(serverSeed) + + // Create a self-signed certificate from the private key + serverCert, err := utilremote.SelfSignedCertFromPrivateKey(serverKey, "127.0.0.1") + if err != nil { + return nil, "", nil, fmt.Errorf("failed to create self-signed certificate: %v", err) + } + serverFingerprint = fmt.Sprintf("%s:%s", utilremote.DigestSha256, fmt.Sprintf("%x", sha256.Sum256(serverCert.Certificate[0]))) + + // Start a local HTTPS server + server = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if the client's public key is in the known list + peerCerts := r.TLS.PeerCertificates + if len(peerCerts) == 0 { + w.WriteHeader(http.StatusForbidden) + return + } + peerPublicKey := peerCerts[0].PublicKey.(ed25519.PublicKey) + // make sure the client's public key is in the list of known keys + var matched bool + for _, publicKey := range clientPublicKeys { + if peerPublicKey.Equal(publicKey.(ed25519.PublicKey)) { + matched = true + break + } + } + if !matched { + w.WriteHeader(http.StatusForbidden) + return + } + // was any custom handler passed? + if handler != nil { + handler(w, r) + } + })) + server.TLS = &tls.Config{ + ClientAuth: tls.RequestClientCert, + ClientCAs: x509.NewCertPool(), + Certificates: []tls.Certificate{*serverCert}, + } + server.StartTLS() + return server, serverFingerprint, clientKeys, nil +} diff --git a/pkg/log/telemetry.go b/pkg/log/telemetry.go new file mode 100644 index 00000000..9adb3c71 --- /dev/null +++ b/pkg/log/telemetry.go @@ -0,0 +1,119 @@ +package log + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "github.com/databacker/mysql-backup/pkg/config" + "github.com/databacker/mysql-backup/pkg/remote" + log "github.com/sirupsen/logrus" +) + +const ( + sourceField = "source" + sourceTelemetry = "telemetry" +) + +// NewTelemetry creates a new telemetry writer, which writes to the configured telemetry endpoint. +// NewTelemetry creates an initial connection, which it keeps open and then can reopen as needed for each write. +func NewTelemetry(conf config.Telemetry, ch chan<- int) (log.Hook, error) { + client, err := remote.GetTLSClient(conf.Certificates, conf.Credentials) + if err != nil { + return nil, err + } + req, err := http.NewRequest(http.MethodGet, conf.URL, nil) + if err != nil { + return nil, fmt.Errorf("error creating HTTP request: %w", err) + } + + // GET the telemetry endpoint; this is just done to check that it is valid. + // Other requests will be POSTs. + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error requesting telemetry endpoint: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("error requesting telemetry endpoint: %s", resp.Status) + } + return &telemetry{conf: conf, client: client, ch: ch}, nil +} + +type telemetry struct { + conf config.Telemetry + client *http.Client + buffer []*log.Entry + // ch channel to indicate when done sending a message, in case needed for synchronization, e.g. testing. + // sends a count down the channel when done sending a message to the remote. The count is the number. + // of messages sent. + ch chan<- int +} + +// Levels the levels for which the hook should fire +func (t *telemetry) Levels() []log.Level { + return []log.Level{log.PanicLevel, log.FatalLevel, log.ErrorLevel, log.WarnLevel, log.InfoLevel, log.DebugLevel} +} + +// Fire send off a log entry. +func (t *telemetry) Fire(entry *log.Entry) error { + // send the log entry to the telemetry endpoint + // this is blocking, and we do not want to do so, so do it in a go routine + // and do not wait for the response. + + // if this message is from ourself, do not try to send it again + if entry.Data[sourceField] == sourceTelemetry { + return nil + } + t.buffer = append(t.buffer, entry) + if t.conf.BufferSize <= 1 || len(t.buffer) >= t.conf.BufferSize { + entries := t.buffer + t.buffer = nil + go func(entries []*log.Entry, ch chan<- int) { + if ch != nil { + defer func() { ch <- len(entries) }() + } + l := entry.Logger.WithField(sourceField, sourceTelemetry) + remoteEntries := make([]LogEntry, len(entries)) + for i, entry := range entries { + // send the structured data to the telemetry endpoint + var runID string + if v, ok := entry.Data["run"]; ok { + runID = v.(string) + } + remoteEntries[i] = LogEntry{ + Run: runID, + Timestamp: entry.Time.Format("2006-01-02T15:04:05.000Z"), + Level: entry.Level.String(), + Fields: entry.Data, + Message: entry.Message, + } + } + // marshal to json + b, err := json.Marshal(remoteEntries) + if err != nil { + l.Errorf("error marshalling log entry: %v", err) + return + } + req, err := http.NewRequest(http.MethodPost, t.conf.URL, bytes.NewReader(b)) + if err != nil { + l.Errorf("error creating telemetry HTTP request: %v", err) + return + } + req.Header.Set("Content-Type", "application/json") + + // POST to the telemetry endpoint + resp, err := t.client.Do(req) + if err != nil { + l.Errorf("error connecting to telemetry endpoint: %v", err) + return + } + + if resp.StatusCode != http.StatusCreated { + l.Errorf("failed sending data telemetry endpoint: %s", resp.Status) + return + } + }(entries, t.ch) + } + return nil +} diff --git a/pkg/log/telemetry_test.go b/pkg/log/telemetry_test.go new file mode 100644 index 00000000..a87ad1e0 --- /dev/null +++ b/pkg/log/telemetry_test.go @@ -0,0 +1,119 @@ +package log + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "testing" + "time" + + "github.com/databacker/mysql-backup/pkg/config" + utiltest "github.com/databacker/mysql-backup/pkg/internal/test" + "github.com/databacker/mysql-backup/pkg/remote" + + log "github.com/sirupsen/logrus" +) + +// TestSendLog tests sending logs. There is no `SendLog` function in the codebase, +// as it is all just a hook for logrus. This test is a test of the actual functionality. +func TestSendLog(t *testing.T) { + tests := []struct { + name string + level log.Level + fields map[string]interface{} + bufSize int + expected bool + }{ + {"normal", log.InfoLevel, nil, 1, true}, + {"fatal", log.FatalLevel, nil, 1, true}, + {"error", log.ErrorLevel, nil, 1, true}, + {"warn", log.WarnLevel, nil, 1, true}, + {"debug", log.DebugLevel, nil, 1, true}, + {"debug", log.DebugLevel, nil, 3, true}, + {"trace", log.TraceLevel, nil, 1, false}, + {"self-log", log.InfoLevel, map[string]interface{}{ + sourceField: sourceTelemetry, + }, 1, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + server, fingerprint, clientKeys, err := utiltest.StartServer(1, func(w http.ResponseWriter, r *http.Request) { + _, err := buf.ReadFrom(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + }) + if err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer server.Close() + + ch := make(chan int, 1) + logger := log.New() + hook, err := NewTelemetry(config.Telemetry{ + Connection: remote.Connection{ + URL: server.URL, + Certificates: []string{fingerprint}, + Credentials: base64.StdEncoding.EncodeToString(clientKeys[0]), + }, + BufferSize: tt.bufSize, + }, ch) + if err != nil { + t.Fatalf("failed to create telemetry hook: %v", err) + } + // add the hook and set the writer + logger.SetLevel(log.TraceLevel) + logger.AddHook(hook) + var localBuf bytes.Buffer + logger.SetOutput(&localBuf) + + buf.Reset() + var msgs []string + for i := 0; i < tt.bufSize; i++ { + msg := fmt.Sprintf("test message %d random %d", i, rand.Intn(1000)) + msgs = append(msgs, msg) + logger.WithFields(tt.fields).Log(tt.level, msg) + } + // wait for the message to get across, but only one second maximum, as it should be quick + // this allows us to handle those that should not have a message and never send anything + var msgCount int + select { + case msgCount = <-ch: + case <-time.After(1 * time.Second): + } + if tt.expected { + if buf.Len() == 0 { + t.Fatalf("expected log message, got none") + } + // message is sent as json, so convert to our structure and compare + var entries []LogEntry + if err := json.Unmarshal(buf.Bytes(), &entries); err != nil { + t.Fatalf("failed to unmarshal log entries: %v", err) + } + if len(entries) != msgCount { + t.Fatalf("channel sent %d log entries, actual got %d", msgCount, len(entries)) + } + if len(entries) != tt.bufSize { + t.Fatalf("expected %d log entries, got %d", tt.bufSize, len(entries)) + } + for i, le := range entries { + if le.Message != msgs[i] { + t.Errorf("message %d: expected message %q, got %q", i, msgs[i], le.Message) + } + if le.Level != tt.level.String() { + t.Errorf("expected level %q, got %q", tt.level.String(), le.Level) + } + } + } else { + if buf.Len() != 0 { + t.Fatalf("expected no log message, got one") + } + } + }) + } +} diff --git a/pkg/log/type.go b/pkg/log/type.go new file mode 100644 index 00000000..ffb57f1b --- /dev/null +++ b/pkg/log/type.go @@ -0,0 +1,9 @@ +package log + +type LogEntry struct { + Run string `json:"run"` + Timestamp string `json:"timestamp"` + Level string `json:"level"` + Fields map[string]interface{} `json:"fields"` + Message string `json:"message"` +} diff --git a/pkg/remote/const.go b/pkg/remote/const.go index 576667f0..fbe5b64e 100644 --- a/pkg/remote/const.go +++ b/pkg/remote/const.go @@ -1,9 +1 @@ package remote - -import "time" - -const ( - clientOrg = "client.databack.io" - certValidity = 5 * time.Minute - digestSha256 = "sha256" -) diff --git a/pkg/remote/get.go b/pkg/remote/get.go index 53ca39e5..2a5e9631 100644 --- a/pkg/remote/get.go +++ b/pkg/remote/get.go @@ -3,23 +3,21 @@ package remote import ( "context" "crypto/ed25519" - "crypto/rand" "crypto/sha256" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "encoding/base64" - "encoding/pem" "fmt" - "math/big" "net" "net/http" "strings" "time" + + utilremote "github.com/databacker/mysql-backup/pkg/internal/remote" ) var ( - validAlgos = []string{digestSha256} + validAlgos = []string{utilremote.DigestSha256} validAlgosHash = map[string]bool{} ) @@ -37,6 +35,27 @@ func OpenConnection(u string, certs []string, credentials string) (resp *http.Re // open a connection to the URL. // Uses mTLS, but rather than verifying the CA that signed the client cert, // server should accept a self-signed cert. It then should check if the client's public key is in a known good list. + client, err := GetTLSClient(certs, credentials) + if err != nil { + return nil, fmt.Errorf("error creating TLS client: %w", err) + } + + req, err := http.NewRequest(http.MethodGet, u, nil) + if err != nil { + return nil, fmt.Errorf("error creating HTTP request: %w", err) + } + + return client.Do(req) +} + +// GetTLSClient gets a TLS client for a connection to a TLS server, given the URL, digests of acceptable certs, and curve25519 key for authentication. +// The credentials should be base64-encoded curve25519 private key. This is curve25519 and *not* ed25519; ed25519 calls this +// the "seed key". It must be 32 bytes long. +// The certs should be a list of fingerprints in the format "algo:hex-fingerprint". +func GetTLSClient(certs []string, credentials string) (client *http.Client, err error) { + // open a connection to the URL. + // Uses mTLS, but rather than verifying the CA that signed the client cert, + // server should accept a self-signed cert. It then should check if the client's public key is in a known good list. var trustedCertsByAlgo = map[string]map[string]bool{} for _, fingerprint := range certs { @@ -63,12 +82,12 @@ func OpenConnection(u string, certs []string, credentials string) (resp *http.Re } key := ed25519.NewKeyFromSeed(keyBytes) - clientCert, err := selfSignedCertFromPrivateKey(key, "") + clientCert, err := utilremote.SelfSignedCertFromPrivateKey(key, "") if err != nil { return nil, fmt.Errorf("error creating client certificate: %w", err) } - client := &http.Client{ + client = &http.Client{ Transport: &http.Transport{ // Configure TLS via DialTLS DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -117,7 +136,7 @@ func OpenConnection(u string, certs []string, credentials string) (resp *http.Re // the cert presented by the server was not signed by a known CA, so fall back to our own list for _, rawCert := range rawCerts { fingerprint := fmt.Sprintf("%x", sha256.Sum256(rawCert)) - if trustedFingerprints, ok := trustedCertsByAlgo[digestSha256]; ok { + if trustedFingerprints, ok := trustedCertsByAlgo[utilremote.DigestSha256]; ok { if _, ok := trustedFingerprints[fingerprint]; ok { if validateCert(certs[0], host) { return nil @@ -135,54 +154,7 @@ func OpenConnection(u string, certs []string, credentials string) (resp *http.Re }, }, } - req, err := http.NewRequest(http.MethodGet, u, nil) - if err != nil { - return nil, fmt.Errorf("error creating HTTP request: %w", err) - } - - return client.Do(req) -} - -// selfSignedCertFromPrivateKey creates a self-signed certificate from an ed25519 private key -func selfSignedCertFromPrivateKey(privateKey ed25519.PrivateKey, hostname string) (*tls.Certificate, error) { - if privateKey == nil || len(privateKey) != ed25519.PrivateKeySize { - return nil, fmt.Errorf("invalid private key") - } - publicKey := privateKey.Public() - - // Create a template for the certificate - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{clientOrg}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(certValidity), - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - BasicConstraintsValid: true, - } - if hostname != "" { - template.DNSNames = append(template.DNSNames, hostname) - } - - // Self-sign the certificate - certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey, privateKey) - if err != nil { - return nil, fmt.Errorf("failed to create certificate: %w", err) - } - - // Encode and print the certificate - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) - - // Create the TLS certificate to use in tls.Config - marshaledPrivateKey, err := x509.MarshalPKCS8PrivateKey(privateKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal private key: %w", err) - } - cert, err := tls.X509KeyPair(certPEM, pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: marshaledPrivateKey})) - return &cert, err + return client, nil } // validateCert given a cert that we decided to trust its cert or signature, make sure its properties are correct: diff --git a/pkg/remote/get_test.go b/pkg/remote/get_test.go index d8a768b8..bde45672 100644 --- a/pkg/remote/get_test.go +++ b/pkg/remote/get_test.go @@ -3,15 +3,14 @@ package remote import ( "crypto/ed25519" cryptorand "crypto/rand" - "crypto/sha256" - "crypto/tls" "crypto/x509" "encoding/base64" - "fmt" "io" "net/http" - "net/http/httptest" "testing" + + utilremote "github.com/databacker/mysql-backup/pkg/internal/remote" + utiltest "github.com/databacker/mysql-backup/pkg/internal/test" ) func TestSelfSignedCertFromPrivateKey(t *testing.T) { @@ -41,7 +40,7 @@ func TestSelfSignedCertFromPrivateKey(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Call the function with the private key - cert, err := selfSignedCertFromPrivateKey(test.privateKey, "") + cert, err := utilremote.SelfSignedCertFromPrivateKey(test.privateKey, "") if (err != nil) != test.expectError { t.Fatalf("selfSignedCertFromPrivateKey returned an error: %v", err) } @@ -68,52 +67,15 @@ func TestSelfSignedCertFromPrivateKey(t *testing.T) { } func TestOpenConnection(t *testing.T) { - // Generate a new private key - clientSeed1 := make([]byte, ed25519.SeedSize) - if _, err := io.ReadFull(cryptorand.Reader, clientSeed1); err != nil { - t.Fatalf("failed to generate random seed: %v", err) - } - clientKey1 := ed25519.NewKeyFromSeed(clientSeed1) - - clientSeed2 := make([]byte, ed25519.SeedSize) - if _, err := io.ReadFull(cryptorand.Reader, clientSeed2); err != nil { + // Generate a private key that is not in the list of known keys + clientSeedUnknown := make([]byte, ed25519.SeedSize) + if _, err := io.ReadFull(cryptorand.Reader, clientSeedUnknown); err != nil { t.Fatalf("failed to generate random seed: %v", err) } - - serverSeed := make([]byte, ed25519.SeedSize) - if _, err := io.ReadFull(cryptorand.Reader, serverSeed); err != nil { - t.Fatalf("failed to generate random seed: %v", err) - } - serverKey := ed25519.NewKeyFromSeed(serverSeed) - - // Create a self-signed certificate from the private key - serverCert, err := selfSignedCertFromPrivateKey(serverKey, "127.0.0.1") + server, fingerprint, clientKeys, err := utiltest.StartServer(1, nil) if err != nil { - t.Fatalf("failed to create self-signed certificate: %v", err) - } - fingerprint := fmt.Sprintf("%s:%s", digestSha256, fmt.Sprintf("%x", sha256.Sum256(serverCert.Certificate[0]))) - - // Start a local HTTPS server - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if the client's public key is in the known list - peerCerts := r.TLS.PeerCertificates - if len(peerCerts) == 0 { - w.WriteHeader(http.StatusForbidden) - return - } - peerPublicKey := peerCerts[0].PublicKey.(ed25519.PublicKey) - expectedPublicKey := clientKey1.Public().(ed25519.PublicKey) - if !peerPublicKey.Equal(expectedPublicKey) { - w.WriteHeader(http.StatusForbidden) - return - } - })) - server.TLS = &tls.Config{ - ClientAuth: tls.RequestClientCert, - ClientCAs: x509.NewCertPool(), - Certificates: []tls.Certificate{*serverCert}, + t.Fatalf("failed to start server: %v", err) } - server.StartTLS() defer server.Close() tests := []struct { @@ -125,21 +87,21 @@ func TestOpenConnection(t *testing.T) { }{ { name: "client key in list", - clientPrivateKey: clientSeed1, + clientPrivateKey: clientKeys[0], certs: []string{fingerprint}, expectError: false, expectedStatus: http.StatusOK, }, { name: "client key not in list", - clientPrivateKey: clientSeed2, + clientPrivateKey: clientSeedUnknown, certs: []string{fingerprint}, expectError: false, expectedStatus: http.StatusForbidden, }, { name: "no certs", - clientPrivateKey: clientSeed1, + clientPrivateKey: clientKeys[0], certs: []string{}, expectError: true, expectedStatus: http.StatusForbidden, diff --git a/pkg/storage/file/file.go b/pkg/storage/file/file.go index cc5653c8..195495b2 100644 --- a/pkg/storage/file/file.go +++ b/pkg/storage/file/file.go @@ -7,6 +7,8 @@ import ( "os" "path" "path/filepath" + + log "github.com/sirupsen/logrus" ) type File struct { @@ -18,11 +20,11 @@ func New(u url.URL) *File { return &File{u, u.Path} } -func (f *File) Pull(source, target string) (int64, error) { +func (f *File) Pull(source, target string, logger *log.Entry) (int64, error) { return copyFile(path.Join(f.path, source), target) } -func (f *File) Push(target, source string) (int64, error) { +func (f *File) Push(target, source string, logger *log.Entry) (int64, error) { return copyFile(source, filepath.Join(f.path, target)) } @@ -34,7 +36,7 @@ func (f *File) URL() string { return f.url.String() } -func (f *File) ReadDir(dirname string) ([]fs.FileInfo, error) { +func (f *File) ReadDir(dirname string, logger *log.Entry) ([]fs.FileInfo, error) { entries, err := os.ReadDir(filepath.Join(f.path, dirname)) if err != nil { @@ -51,7 +53,7 @@ func (f *File) ReadDir(dirname string) ([]fs.FileInfo, error) { return files, nil } -func (f *File) Remove(target string) error { +func (f *File) Remove(target string, logger *log.Entry) error { return os.Remove(filepath.Join(f.path, target)) } diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index fc76a80f..855256d2 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -64,9 +64,9 @@ func New(u url.URL, opts ...Option) *S3 { return s } -func (s *S3) Pull(source, target string) (int64, error) { +func (s *S3) Pull(source, target string, logger *log.Entry) (int64, error) { // get the s3 client - client, err := s.getClient() + client, err := s.getClient(logger) if err != nil { return 0, fmt.Errorf("failed to get AWS client: %v", err) } @@ -94,9 +94,9 @@ func (s *S3) Pull(source, target string) (int64, error) { return n, nil } -func (s *S3) Push(target, source string) (int64, error) { +func (s *S3) Push(target, source string, logger *log.Entry) (int64, error) { // get the s3 client - client, err := s.getClient() + client, err := s.getClient(logger) if err != nil { return 0, fmt.Errorf("failed to get AWS client: %v", err) } @@ -132,9 +132,9 @@ func (s *S3) URL() string { return s.url.String() } -func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) { +func (s *S3) ReadDir(dirname string, logger *log.Entry) ([]fs.FileInfo, error) { // get the s3 client - client, err := s.getClient() + client, err := s.getClient(logger) if err != nil { return nil, fmt.Errorf("failed to get AWS client: %v", err) } @@ -158,9 +158,9 @@ func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) { return files, nil } -func (s *S3) Remove(target string) error { +func (s *S3) Remove(target string, logger *log.Entry) error { // Get the AWS client - client, err := s.getClient() + client, err := s.getClient(logger) if err != nil { return fmt.Errorf("failed to get AWS client: %v", err) } @@ -177,7 +177,7 @@ func (s *S3) Remove(target string) error { return nil } -func (s *S3) getClient() (*s3.Client, error) { +func (s *S3) getClient(logger *log.Entry) (*s3.Client, error) { // Get the AWS config var opts []func(*config.LoadOptions) error if s.endpoint != "" { @@ -190,7 +190,7 @@ func (s *S3) getClient() (*s3.Client, error) { ), ) } - if log.IsLevelEnabled(log.TraceLevel) { + if logger.Level == log.TraceLevel { opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse)) } if s.region != "" { diff --git a/pkg/storage/smb/smb.go b/pkg/storage/smb/smb.go index eab2ca65..38a043a9 100644 --- a/pkg/storage/smb/smb.go +++ b/pkg/storage/smb/smb.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/cloudsoda/go-smb2" + log "github.com/sirupsen/logrus" ) const ( @@ -49,7 +50,7 @@ func New(u url.URL, opts ...Option) *SMB { return s } -func (s *SMB) Pull(source, target string) (int64, error) { +func (s *SMB) Pull(source, target string, logger *log.Entry) (int64, error) { var ( copied int64 err error @@ -73,7 +74,7 @@ func (s *SMB) Pull(source, target string) (int64, error) { return copied, err } -func (s *SMB) Push(target, source string) (int64, error) { +func (s *SMB) Push(target, source string, logger *log.Entry) (int64, error) { var ( copied int64 err error @@ -104,7 +105,7 @@ func (s *SMB) URL() string { return s.url.String() } -func (s *SMB) ReadDir(dirname string) ([]os.FileInfo, error) { +func (s *SMB) ReadDir(dirname string, logger *log.Entry) ([]os.FileInfo, error) { var ( err error infos []os.FileInfo @@ -116,7 +117,7 @@ func (s *SMB) ReadDir(dirname string) ([]os.FileInfo, error) { return infos, err } -func (s *SMB) Remove(target string) error { +func (s *SMB) Remove(target string, logger *log.Entry) error { return s.exec(s.url, func(fs *smb2.Share, sharepath string) error { smbFilename := fmt.Sprintf("%s%c%s", sharepath, smb2.PathSeparator, filepath.Base(strings.ReplaceAll(target, ":", "-"))) return fs.Remove(smbFilename) diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 15b244bf..9c31ab44 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -1,13 +1,17 @@ package storage -import "io/fs" +import ( + "io/fs" + + log "github.com/sirupsen/logrus" +) type Storage interface { - Push(target, source string) (int64, error) - Pull(source, target string) (int64, error) Protocol() string URL() string - ReadDir(dirname string) ([]fs.FileInfo, error) + Push(target, source string, logger *log.Entry) (int64, error) + Pull(source, target string, logger *log.Entry) (int64, error) + ReadDir(dirname string, logger *log.Entry) ([]fs.FileInfo, error) // Remove remove a particular file - Remove(string) error + Remove(target string, logger *log.Entry) error } diff --git a/test/backup_test.go b/test/backup_test.go index 956f7482..4b40be35 100644 --- a/test/backup_test.go +++ b/test/backup_test.go @@ -453,8 +453,11 @@ func runDumpTest(dc *dockerContext, compact bool, base string, targets []backupT timerOpts := core.TimerOptions{ Once: true, } - return core.TimerCommand(timerOpts, func() error { - return core.Dump(dumpOpts) + executor := &core.Executor{} + executor.SetLogger(log.New()) + + return executor.Timer(timerOpts, func() error { + return executor.Dump(dumpOpts) }) }