Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix s3 session creation #295

Merged
merged 1 commit into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,13 @@ func rootCmd(execs execs) (*cobra.Command, error) {
// these are not from the config file, as they are generic credentials, used across all targets.
// the config file uses specific ones per target
cmdConfig.creds = credentials.Creds{
AWSEndpoint: v.GetString("aws-endpoint-url"),
SMBCredentials: credentials.SMBCreds{
AWS: credentials.AWSCreds{
Endpoint: v.GetString("aws-endpoint-url"),
AccessKeyID: v.GetString("aws-access-key-id"),
SecretAccessKey: v.GetString("aws-secret-access-key"),
Region: v.GetString("aws-region"),
},
SMB: credentials.SMBCreds{
Username: v.GetString("smb-user"),
Password: v.GetString("smb-pass"),
Domain: v.GetString("smb-domain"),
Expand Down
11 changes: 9 additions & 2 deletions pkg/storage/credentials/creds.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package credentials

type Creds struct {
SMBCredentials SMBCreds
AWSEndpoint string
SMB SMBCreds
AWS AWSCreds
}

type SMBCreds struct {
Username string
Password string
Domain string
}

type AWSCreds struct {
AccessKeyID string
SecretAccessKey string
Endpoint string
Region string
}
25 changes: 17 additions & 8 deletions pkg/storage/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,29 @@ func ParseURL(url string, creds credentials.Creds) (Storage, error) {
store = file.New(*u)
case "smb":
opts := []smb.Option{}
if creds.SMBCredentials.Domain != "" {
opts = append(opts, smb.WithDomain(creds.SMBCredentials.Domain))
if creds.SMB.Domain != "" {
opts = append(opts, smb.WithDomain(creds.SMB.Domain))
}
if creds.SMBCredentials.Username != "" {
opts = append(opts, smb.WithUsername(creds.SMBCredentials.Username))
if creds.SMB.Username != "" {
opts = append(opts, smb.WithUsername(creds.SMB.Username))
}
if creds.SMBCredentials.Password != "" {
opts = append(opts, smb.WithPassword(creds.SMBCredentials.Password))
if creds.SMB.Password != "" {
opts = append(opts, smb.WithPassword(creds.SMB.Password))
}
store = smb.New(*u, opts...)
case "s3":
opts := []s3.Option{}
if creds.AWSEndpoint != "" {
opts = append(opts, s3.WithEndpoint(creds.AWSEndpoint))
if creds.AWS.Endpoint != "" {
opts = append(opts, s3.WithEndpoint(creds.AWS.Endpoint))
}
if creds.AWS.Region != "" {
opts = append(opts, s3.WithRegion(creds.AWS.Region))
}
if creds.AWS.AccessKeyID != "" {
opts = append(opts, s3.WithAccessKeyId(creds.AWS.AccessKeyID))
}
if creds.AWS.SecretAccessKey != "" {
opts = append(opts, s3.WithSecretAccessKey(creds.AWS.SecretAccessKey))
}
store = s3.New(*u, opts...)
default:
Expand Down
106 changes: 54 additions & 52 deletions pkg/storage/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -64,18 +65,13 @@ func New(u url.URL, opts ...Option) *S3 {
}

func (s *S3) Pull(source, target string) (int64, error) {
// TODO: need to find way to include cli opts and cli_s3_cp_opts
// old was:
// aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE

bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)
// The session the S3 Downloader will use
cfg, err := getConfig(s.endpoint)
// get the s3 client
client, err := s.getClient()
if err != nil {
return 0, fmt.Errorf("failed to load AWS config: %v", err)
return 0, fmt.Errorf("failed to get AWS client: %v", err)
}

client := s3.NewFromConfig(cfg)
bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)

// Create a downloader with the session and default options
downloader := manager.NewDownloader(client)
Expand All @@ -99,18 +95,13 @@ func (s *S3) Pull(source, target string) (int64, error) {
}

func (s *S3) Push(target, source string) (int64, error) {
// TODO: need to find way to include cli opts and cli_s3_cp_opts
// old was:
// aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE

bucket, key := s.url.Hostname(), s.url.Path
// The session the S3 Downloader will use
cfg, err := getConfig(s.endpoint)
// get the s3 client
client, err := s.getClient()
if err != nil {
return 0, fmt.Errorf("failed to load AWS config: %v", err)
return 0, fmt.Errorf("failed to get AWS client: %v", err)
}
bucket, key := s.url.Hostname(), s.url.Path

client := s3.NewFromConfig(cfg)
// Create an uploader with the session and default options
uploader := manager.NewUploader(client)

Expand Down Expand Up @@ -142,17 +133,14 @@ func (s *S3) URL() string {
}

func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) {
// Get the AWS config
cfg, err := getConfig(s.endpoint)
// get the s3 client
client, err := s.getClient()
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %v", err)
return nil, fmt.Errorf("failed to get AWS client: %v", err)
}

// Create a new S3 service client
svc := s3.NewFromConfig(cfg)

// Call ListObjectsV2 with your bucket and prefix
result, err := svc.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)})
result, err := client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)})
if err != nil {
return nil, fmt.Errorf("failed to list objects, %v", err)
}
Expand All @@ -171,17 +159,14 @@ func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) {
}

func (s *S3) Remove(target string) error {
// Get the AWS config
cfg, err := getConfig(s.endpoint)
// Get the AWS client
client, err := s.getClient()
if err != nil {
return fmt.Errorf("failed to load AWS config: %v", err)
return fmt.Errorf("failed to get AWS client: %v", err)
}

// Create a new S3 service client
svc := s3.NewFromConfig(cfg)

// Call DeleteObject with your bucket and the key of the object you want to delete
_, err = svc.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
_, err = client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
Bucket: aws.String(s.url.Hostname()),
Key: aws.String(target),
})
Expand All @@ -192,9 +177,44 @@ func (s *S3) Remove(target string) error {
return nil
}

func (s *S3) getClient() (*s3.Client, error) {
// Get the AWS config
cleanEndpoint := getEndpoint(s.endpoint)
opts := []func(*config.LoadOptions) error{
config.WithEndpointResolverWithOptions(
aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{URL: cleanEndpoint}, nil
}),
),
}
if log.IsLevelEnabled(log.TraceLevel) {
opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse))
}
if s.region != "" {
opts = append(opts, config.WithRegion(s.region))
}
if s.accessKeyId != "" {
opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
s.accessKeyId,
s.secretAccessKey,
"",
)))
}
cfg, err := config.LoadDefaultConfig(context.TODO(),
opts...,
)
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %v", err)
}

// Create a new S3 service client
return s3.NewFromConfig(cfg), nil
}

// getEndpoint returns a clean (for AWS client) endpoint. Normally, this is unchanged,
// but for some reason, the lookup gets flaky when the endpoint is 127.0.0.1,
// so in that case, set it to localhost explicitly.
func getEndpoint(endpoint string) string {
// for some reason, the lookup gets flaky when the endpoint is 127.0.0.1
// so you have to set it to localhost explicitly.
e := endpoint
u, err := url.Parse(endpoint)
if err == nil {
Expand All @@ -210,24 +230,6 @@ func getEndpoint(endpoint string) string {
return e
}

func getConfig(endpoint string) (aws.Config, error) {
cleanEndpoint := getEndpoint(endpoint)
opts := []func(*config.LoadOptions) error{
config.WithEndpointResolverWithOptions(
aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{URL: cleanEndpoint}, nil
}),
),
}
if log.IsLevelEnabled(log.TraceLevel) {
opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse))
}
return config.LoadDefaultConfig(context.TODO(),
opts...,
)

}

type s3FileInfo struct {
name string
lastModified time.Time
Expand Down
2 changes: 1 addition & 1 deletion test/backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func runDumpTest(dc *dockerContext, compact bool, base string, targets []backupT
if err := os.MkdirAll(localPath, 0o755); err != nil {
return fmt.Errorf("failed to create local path %s: %w", localPath, err)
}
store, err := storage.ParseURL(t, credentials.Creds{AWSEndpoint: s3})
store, err := storage.ParseURL(t, credentials.Creds{AWS: credentials.AWSCreds{Endpoint: s3}})
if err != nil {
return fmt.Errorf("invalid target url: %v", err)
}
Expand Down
Loading