Skip to content

Commit

Permalink
fix s3 session creation
Browse files Browse the repository at this point in the history
Signed-off-by: Avi Deitcher <[email protected]>
  • Loading branch information
deitch committed Apr 7, 2024
1 parent 51c9ce6 commit cc54ab6
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 65 deletions.
9 changes: 7 additions & 2 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,13 @@ func rootCmd(execs execs) (*cobra.Command, error) {
// these are not from the config file, as they are generic credentials, used across all targets.
// the config file uses specific ones per target
cmdConfig.creds = credentials.Creds{
AWSEndpoint: v.GetString("aws-endpoint-url"),
SMBCredentials: credentials.SMBCreds{
AWS: credentials.AWSCreds{
Endpoint: v.GetString("aws-endpoint-url"),
AccessKeyID: v.GetString("aws-access-key-id"),
SecretAccessKey: v.GetString("aws-secret-access-key"),
Region: v.GetString("aws-region"),
},
SMB: credentials.SMBCreds{
Username: v.GetString("smb-user"),
Password: v.GetString("smb-pass"),
Domain: v.GetString("smb-domain"),
Expand Down
11 changes: 9 additions & 2 deletions pkg/storage/credentials/creds.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package credentials

type Creds struct {
SMBCredentials SMBCreds
AWSEndpoint string
SMB SMBCreds
AWS AWSCreds
}

type SMBCreds struct {
Username string
Password string
Domain string
}

type AWSCreds struct {
AccessKeyID string
SecretAccessKey string
Endpoint string
Region string
}
25 changes: 17 additions & 8 deletions pkg/storage/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,29 @@ func ParseURL(url string, creds credentials.Creds) (Storage, error) {
store = file.New(*u)
case "smb":
opts := []smb.Option{}
if creds.SMBCredentials.Domain != "" {
opts = append(opts, smb.WithDomain(creds.SMBCredentials.Domain))
if creds.SMB.Domain != "" {
opts = append(opts, smb.WithDomain(creds.SMB.Domain))
}
if creds.SMBCredentials.Username != "" {
opts = append(opts, smb.WithUsername(creds.SMBCredentials.Username))
if creds.SMB.Username != "" {
opts = append(opts, smb.WithUsername(creds.SMB.Username))
}
if creds.SMBCredentials.Password != "" {
opts = append(opts, smb.WithPassword(creds.SMBCredentials.Password))
if creds.SMB.Password != "" {
opts = append(opts, smb.WithPassword(creds.SMB.Password))
}
store = smb.New(*u, opts...)
case "s3":
opts := []s3.Option{}
if creds.AWSEndpoint != "" {
opts = append(opts, s3.WithEndpoint(creds.AWSEndpoint))
if creds.AWS.Endpoint != "" {
opts = append(opts, s3.WithEndpoint(creds.AWS.Endpoint))
}
if creds.AWS.Region != "" {
opts = append(opts, s3.WithRegion(creds.AWS.Region))
}
if creds.AWS.AccessKeyID != "" {
opts = append(opts, s3.WithAccessKeyId(creds.AWS.AccessKeyID))
}
if creds.AWS.SecretAccessKey != "" {
opts = append(opts, s3.WithSecretAccessKey(creds.AWS.SecretAccessKey))
}
store = s3.New(*u, opts...)
default:
Expand Down
106 changes: 54 additions & 52 deletions pkg/storage/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -64,18 +65,13 @@ func New(u url.URL, opts ...Option) *S3 {
}

func (s *S3) Pull(source, target string) (int64, error) {
// TODO: need to find way to include cli opts and cli_s3_cp_opts
// old was:
// aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE

bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)
// The session the S3 Downloader will use
cfg, err := getConfig(s.endpoint)
// get the s3 client
client, err := s.getClient()
if err != nil {
return 0, fmt.Errorf("failed to load AWS config: %v", err)
return 0, fmt.Errorf("failed to get AWS client: %v", err)
}

client := s3.NewFromConfig(cfg)
bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)

// Create a downloader with the session and default options
downloader := manager.NewDownloader(client)
Expand All @@ -99,18 +95,13 @@ func (s *S3) Pull(source, target string) (int64, error) {
}

func (s *S3) Push(target, source string) (int64, error) {
// TODO: need to find way to include cli opts and cli_s3_cp_opts
// old was:
// aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE

bucket, key := s.url.Hostname(), s.url.Path
// The session the S3 Downloader will use
cfg, err := getConfig(s.endpoint)
// get the s3 client
client, err := s.getClient()
if err != nil {
return 0, fmt.Errorf("failed to load AWS config: %v", err)
return 0, fmt.Errorf("failed to get AWS client: %v", err)
}
bucket, key := s.url.Hostname(), s.url.Path

client := s3.NewFromConfig(cfg)
// Create an uploader with the session and default options
uploader := manager.NewUploader(client)

Expand Down Expand Up @@ -142,17 +133,14 @@ func (s *S3) URL() string {
}

func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) {
// Get the AWS config
cfg, err := getConfig(s.endpoint)
// get the s3 client
client, err := s.getClient()
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %v", err)
return nil, fmt.Errorf("failed to get AWS client: %v", err)
}

// Create a new S3 service client
svc := s3.NewFromConfig(cfg)

// Call ListObjectsV2 with your bucket and prefix
result, err := svc.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)})
result, err := client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)})
if err != nil {
return nil, fmt.Errorf("failed to list objects, %v", err)
}
Expand All @@ -171,17 +159,14 @@ func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) {
}

func (s *S3) Remove(target string) error {
// Get the AWS config
cfg, err := getConfig(s.endpoint)
// Get the AWS client
client, err := s.getClient()
if err != nil {
return fmt.Errorf("failed to load AWS config: %v", err)
return fmt.Errorf("failed to get AWS client: %v", err)
}

// Create a new S3 service client
svc := s3.NewFromConfig(cfg)

// Call DeleteObject with your bucket and the key of the object you want to delete
_, err = svc.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
_, err = client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
Bucket: aws.String(s.url.Hostname()),
Key: aws.String(target),
})
Expand All @@ -192,9 +177,44 @@ func (s *S3) Remove(target string) error {
return nil
}

func (s *S3) getClient() (*s3.Client, error) {
// Get the AWS config
cleanEndpoint := getEndpoint(s.endpoint)
opts := []func(*config.LoadOptions) error{
config.WithEndpointResolverWithOptions(
aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{URL: cleanEndpoint}, nil
}),
),
}
if log.IsLevelEnabled(log.TraceLevel) {
opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse))
}
if s.region != "" {
opts = append(opts, config.WithRegion(s.region))
}
if s.accessKeyId != "" {
opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
s.accessKeyId,
s.secretAccessKey,
"",
)))
}
cfg, err := config.LoadDefaultConfig(context.TODO(),
opts...,
)
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %v", err)
}

// Create a new S3 service client
return s3.NewFromConfig(cfg), nil
}

// getEndpoint returns a clean (for AWS client) endpoint. Normally, this is unchanged,
// but for some reason, the lookup gets flaky when the endpoint is 127.0.0.1,
// so in that case, set it to localhost explicitly.
func getEndpoint(endpoint string) string {
// for some reason, the lookup gets flaky when the endpoint is 127.0.0.1
// so you have to set it to localhost explicitly.
e := endpoint
u, err := url.Parse(endpoint)
if err == nil {
Expand All @@ -210,24 +230,6 @@ func getEndpoint(endpoint string) string {
return e
}

func getConfig(endpoint string) (aws.Config, error) {
cleanEndpoint := getEndpoint(endpoint)
opts := []func(*config.LoadOptions) error{
config.WithEndpointResolverWithOptions(
aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{URL: cleanEndpoint}, nil
}),
),
}
if log.IsLevelEnabled(log.TraceLevel) {
opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse))
}
return config.LoadDefaultConfig(context.TODO(),
opts...,
)

}

type s3FileInfo struct {
name string
lastModified time.Time
Expand Down
2 changes: 1 addition & 1 deletion test/backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func runDumpTest(dc *dockerContext, compact bool, base string, targets []backupT
if err := os.MkdirAll(localPath, 0o755); err != nil {
return fmt.Errorf("failed to create local path %s: %w", localPath, err)
}
store, err := storage.ParseURL(t, credentials.Creds{AWSEndpoint: s3})
store, err := storage.ParseURL(t, credentials.Creds{AWS: credentials.AWSCreds{Endpoint: s3}})
if err != nil {
return fmt.Errorf("invalid target url: %v", err)
}
Expand Down

0 comments on commit cc54ab6

Please sign in to comment.