From cc54ab6db1778ea0ff6caa0f10ef89c5cbf3bcad Mon Sep 17 00:00:00 2001 From: Avi Deitcher Date: Sun, 7 Apr 2024 12:49:39 +0300 Subject: [PATCH] fix s3 session creation Signed-off-by: Avi Deitcher --- cmd/root.go | 9 ++- pkg/storage/credentials/creds.go | 11 +++- pkg/storage/parse.go | 25 +++++--- pkg/storage/s3/s3.go | 106 ++++++++++++++++--------------- test/backup_test.go | 2 +- 5 files changed, 88 insertions(+), 65 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 11bc3523..77adca83 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -127,8 +127,13 @@ func rootCmd(execs execs) (*cobra.Command, error) { // these are not from the config file, as they are generic credentials, used across all targets. // the config file uses specific ones per target cmdConfig.creds = credentials.Creds{ - AWSEndpoint: v.GetString("aws-endpoint-url"), - SMBCredentials: credentials.SMBCreds{ + AWS: credentials.AWSCreds{ + Endpoint: v.GetString("aws-endpoint-url"), + AccessKeyID: v.GetString("aws-access-key-id"), + SecretAccessKey: v.GetString("aws-secret-access-key"), + Region: v.GetString("aws-region"), + }, + SMB: credentials.SMBCreds{ Username: v.GetString("smb-user"), Password: v.GetString("smb-pass"), Domain: v.GetString("smb-domain"), diff --git a/pkg/storage/credentials/creds.go b/pkg/storage/credentials/creds.go index 3062df92..68745a5f 100644 --- a/pkg/storage/credentials/creds.go +++ b/pkg/storage/credentials/creds.go @@ -1,8 +1,8 @@ package credentials type Creds struct { - SMBCredentials SMBCreds - AWSEndpoint string + SMB SMBCreds + AWS AWSCreds } type SMBCreds struct { @@ -10,3 +10,10 @@ type SMBCreds struct { Password string Domain string } + +type AWSCreds struct { + AccessKeyID string + SecretAccessKey string + Endpoint string + Region string +} diff --git a/pkg/storage/parse.go b/pkg/storage/parse.go index 35f794ef..13975b51 100644 --- a/pkg/storage/parse.go +++ b/pkg/storage/parse.go @@ -24,20 +24,29 @@ func ParseURL(url string, creds credentials.Creds) (Storage, error) { store = file.New(*u) case "smb": opts := []smb.Option{} - if creds.SMBCredentials.Domain != "" { - opts = append(opts, smb.WithDomain(creds.SMBCredentials.Domain)) + if creds.SMB.Domain != "" { + opts = append(opts, smb.WithDomain(creds.SMB.Domain)) } - if creds.SMBCredentials.Username != "" { - opts = append(opts, smb.WithUsername(creds.SMBCredentials.Username)) + if creds.SMB.Username != "" { + opts = append(opts, smb.WithUsername(creds.SMB.Username)) } - if creds.SMBCredentials.Password != "" { - opts = append(opts, smb.WithPassword(creds.SMBCredentials.Password)) + if creds.SMB.Password != "" { + opts = append(opts, smb.WithPassword(creds.SMB.Password)) } store = smb.New(*u, opts...) case "s3": opts := []s3.Option{} - if creds.AWSEndpoint != "" { - opts = append(opts, s3.WithEndpoint(creds.AWSEndpoint)) + if creds.AWS.Endpoint != "" { + opts = append(opts, s3.WithEndpoint(creds.AWS.Endpoint)) + } + if creds.AWS.Region != "" { + opts = append(opts, s3.WithRegion(creds.AWS.Region)) + } + if creds.AWS.AccessKeyID != "" { + opts = append(opts, s3.WithAccessKeyId(creds.AWS.AccessKeyID)) + } + if creds.AWS.SecretAccessKey != "" { + opts = append(opts, s3.WithSecretAccessKey(creds.AWS.SecretAccessKey)) } store = s3.New(*u, opts...) default: diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index 2482dfdc..704c6517 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" log "github.com/sirupsen/logrus" @@ -64,18 +65,13 @@ func New(u url.URL, opts ...Option) *S3 { } func (s *S3) Pull(source, target string) (int64, error) { - // TODO: need to find way to include cli opts and cli_s3_cp_opts - // old was: - // aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE - - bucket, path := s.url.Hostname(), path.Join(s.url.Path, source) - // The session the S3 Downloader will use - cfg, err := getConfig(s.endpoint) + // get the s3 client + client, err := s.getClient() if err != nil { - return 0, fmt.Errorf("failed to load AWS config: %v", err) + return 0, fmt.Errorf("failed to get AWS client: %v", err) } - client := s3.NewFromConfig(cfg) + bucket, path := s.url.Hostname(), path.Join(s.url.Path, source) // Create a downloader with the session and default options downloader := manager.NewDownloader(client) @@ -99,18 +95,13 @@ func (s *S3) Pull(source, target string) (int64, error) { } func (s *S3) Push(target, source string) (int64, error) { - // TODO: need to find way to include cli opts and cli_s3_cp_opts - // old was: - // aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE - - bucket, key := s.url.Hostname(), s.url.Path - // The session the S3 Downloader will use - cfg, err := getConfig(s.endpoint) + // get the s3 client + client, err := s.getClient() if err != nil { - return 0, fmt.Errorf("failed to load AWS config: %v", err) + return 0, fmt.Errorf("failed to get AWS client: %v", err) } + bucket, key := s.url.Hostname(), s.url.Path - client := s3.NewFromConfig(cfg) // Create an uploader with the session and default options uploader := manager.NewUploader(client) @@ -142,17 +133,14 @@ func (s *S3) URL() string { } func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) { - // Get the AWS config - cfg, err := getConfig(s.endpoint) + // get the s3 client + client, err := s.getClient() if err != nil { - return nil, fmt.Errorf("failed to load AWS config: %v", err) + return nil, fmt.Errorf("failed to get AWS client: %v", err) } - // Create a new S3 service client - svc := s3.NewFromConfig(cfg) - // Call ListObjectsV2 with your bucket and prefix - result, err := svc.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)}) + result, err := client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)}) if err != nil { return nil, fmt.Errorf("failed to list objects, %v", err) } @@ -171,17 +159,14 @@ func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) { } func (s *S3) Remove(target string) error { - // Get the AWS config - cfg, err := getConfig(s.endpoint) + // Get the AWS client + client, err := s.getClient() if err != nil { - return fmt.Errorf("failed to load AWS config: %v", err) + return fmt.Errorf("failed to get AWS client: %v", err) } - // Create a new S3 service client - svc := s3.NewFromConfig(cfg) - // Call DeleteObject with your bucket and the key of the object you want to delete - _, err = svc.DeleteObject(context.TODO(), &s3.DeleteObjectInput{ + _, err = client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{ Bucket: aws.String(s.url.Hostname()), Key: aws.String(target), }) @@ -192,9 +177,44 @@ func (s *S3) Remove(target string) error { return nil } +func (s *S3) getClient() (*s3.Client, error) { + // Get the AWS config + cleanEndpoint := getEndpoint(s.endpoint) + opts := []func(*config.LoadOptions) error{ + config.WithEndpointResolverWithOptions( + aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + return aws.Endpoint{URL: cleanEndpoint}, nil + }), + ), + } + if log.IsLevelEnabled(log.TraceLevel) { + opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse)) + } + if s.region != "" { + opts = append(opts, config.WithRegion(s.region)) + } + if s.accessKeyId != "" { + opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + s.accessKeyId, + s.secretAccessKey, + "", + ))) + } + cfg, err := config.LoadDefaultConfig(context.TODO(), + opts..., + ) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %v", err) + } + + // Create a new S3 service client + return s3.NewFromConfig(cfg), nil +} + +// getEndpoint returns a clean (for AWS client) endpoint. Normally, this is unchanged, +// but for some reason, the lookup gets flaky when the endpoint is 127.0.0.1, +// so in that case, set it to localhost explicitly. func getEndpoint(endpoint string) string { - // for some reason, the lookup gets flaky when the endpoint is 127.0.0.1 - // so you have to set it to localhost explicitly. e := endpoint u, err := url.Parse(endpoint) if err == nil { @@ -210,24 +230,6 @@ func getEndpoint(endpoint string) string { return e } -func getConfig(endpoint string) (aws.Config, error) { - cleanEndpoint := getEndpoint(endpoint) - opts := []func(*config.LoadOptions) error{ - config.WithEndpointResolverWithOptions( - aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { - return aws.Endpoint{URL: cleanEndpoint}, nil - }), - ), - } - if log.IsLevelEnabled(log.TraceLevel) { - opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse)) - } - return config.LoadDefaultConfig(context.TODO(), - opts..., - ) - -} - type s3FileInfo struct { name string lastModified time.Time diff --git a/test/backup_test.go b/test/backup_test.go index 1e62bebe..6b771f6c 100644 --- a/test/backup_test.go +++ b/test/backup_test.go @@ -434,7 +434,7 @@ func runDumpTest(dc *dockerContext, compact bool, base string, targets []backupT if err := os.MkdirAll(localPath, 0o755); err != nil { return fmt.Errorf("failed to create local path %s: %w", localPath, err) } - store, err := storage.ParseURL(t, credentials.Creds{AWSEndpoint: s3}) + store, err := storage.ParseURL(t, credentials.Creds{AWS: credentials.AWSCreds{Endpoint: s3}}) if err != nil { return fmt.Errorf("invalid target url: %v", err) }