Skip to content

Commit cc54ab6

Browse files
committed
fix s3 session creation
Signed-off-by: Avi Deitcher <[email protected]>
1 parent 51c9ce6 commit cc54ab6

File tree

5 files changed

+88
-65
lines changed

5 files changed

+88
-65
lines changed

cmd/root.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,13 @@ func rootCmd(execs execs) (*cobra.Command, error) {
127127
// these are not from the config file, as they are generic credentials, used across all targets.
128128
// the config file uses specific ones per target
129129
cmdConfig.creds = credentials.Creds{
130-
AWSEndpoint: v.GetString("aws-endpoint-url"),
131-
SMBCredentials: credentials.SMBCreds{
130+
AWS: credentials.AWSCreds{
131+
Endpoint: v.GetString("aws-endpoint-url"),
132+
AccessKeyID: v.GetString("aws-access-key-id"),
133+
SecretAccessKey: v.GetString("aws-secret-access-key"),
134+
Region: v.GetString("aws-region"),
135+
},
136+
SMB: credentials.SMBCreds{
132137
Username: v.GetString("smb-user"),
133138
Password: v.GetString("smb-pass"),
134139
Domain: v.GetString("smb-domain"),

pkg/storage/credentials/creds.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
package credentials
22

33
type Creds struct {
4-
SMBCredentials SMBCreds
5-
AWSEndpoint string
4+
SMB SMBCreds
5+
AWS AWSCreds
66
}
77

88
type SMBCreds struct {
99
Username string
1010
Password string
1111
Domain string
1212
}
13+
14+
type AWSCreds struct {
15+
AccessKeyID string
16+
SecretAccessKey string
17+
Endpoint string
18+
Region string
19+
}

pkg/storage/parse.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,29 @@ func ParseURL(url string, creds credentials.Creds) (Storage, error) {
2424
store = file.New(*u)
2525
case "smb":
2626
opts := []smb.Option{}
27-
if creds.SMBCredentials.Domain != "" {
28-
opts = append(opts, smb.WithDomain(creds.SMBCredentials.Domain))
27+
if creds.SMB.Domain != "" {
28+
opts = append(opts, smb.WithDomain(creds.SMB.Domain))
2929
}
30-
if creds.SMBCredentials.Username != "" {
31-
opts = append(opts, smb.WithUsername(creds.SMBCredentials.Username))
30+
if creds.SMB.Username != "" {
31+
opts = append(opts, smb.WithUsername(creds.SMB.Username))
3232
}
33-
if creds.SMBCredentials.Password != "" {
34-
opts = append(opts, smb.WithPassword(creds.SMBCredentials.Password))
33+
if creds.SMB.Password != "" {
34+
opts = append(opts, smb.WithPassword(creds.SMB.Password))
3535
}
3636
store = smb.New(*u, opts...)
3737
case "s3":
3838
opts := []s3.Option{}
39-
if creds.AWSEndpoint != "" {
40-
opts = append(opts, s3.WithEndpoint(creds.AWSEndpoint))
39+
if creds.AWS.Endpoint != "" {
40+
opts = append(opts, s3.WithEndpoint(creds.AWS.Endpoint))
41+
}
42+
if creds.AWS.Region != "" {
43+
opts = append(opts, s3.WithRegion(creds.AWS.Region))
44+
}
45+
if creds.AWS.AccessKeyID != "" {
46+
opts = append(opts, s3.WithAccessKeyId(creds.AWS.AccessKeyID))
47+
}
48+
if creds.AWS.SecretAccessKey != "" {
49+
opts = append(opts, s3.WithSecretAccessKey(creds.AWS.SecretAccessKey))
4150
}
4251
store = s3.New(*u, opts...)
4352
default:

pkg/storage/s3/s3.go

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/aws/aws-sdk-go-v2/aws"
1313
"github.com/aws/aws-sdk-go-v2/config"
14+
"github.com/aws/aws-sdk-go-v2/credentials"
1415
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
1516
"github.com/aws/aws-sdk-go-v2/service/s3"
1617
log "github.com/sirupsen/logrus"
@@ -64,18 +65,13 @@ func New(u url.URL, opts ...Option) *S3 {
6465
}
6566

6667
func (s *S3) Pull(source, target string) (int64, error) {
67-
// TODO: need to find way to include cli opts and cli_s3_cp_opts
68-
// old was:
69-
// aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE
70-
71-
bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)
72-
// The session the S3 Downloader will use
73-
cfg, err := getConfig(s.endpoint)
68+
// get the s3 client
69+
client, err := s.getClient()
7470
if err != nil {
75-
return 0, fmt.Errorf("failed to load AWS config: %v", err)
71+
return 0, fmt.Errorf("failed to get AWS client: %v", err)
7672
}
7773

78-
client := s3.NewFromConfig(cfg)
74+
bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)
7975

8076
// Create a downloader with the session and default options
8177
downloader := manager.NewDownloader(client)
@@ -99,18 +95,13 @@ func (s *S3) Pull(source, target string) (int64, error) {
9995
}
10096

10197
func (s *S3) Push(target, source string) (int64, error) {
102-
// TODO: need to find way to include cli opts and cli_s3_cp_opts
103-
// old was:
104-
// aws ${AWS_CLI_OPTS} s3 cp ${AWS_CLI_S3_CP_OPTS} "${DB_RESTORE_TARGET}" $TMPRESTORE
105-
106-
bucket, key := s.url.Hostname(), s.url.Path
107-
// The session the S3 Downloader will use
108-
cfg, err := getConfig(s.endpoint)
98+
// get the s3 client
99+
client, err := s.getClient()
109100
if err != nil {
110-
return 0, fmt.Errorf("failed to load AWS config: %v", err)
101+
return 0, fmt.Errorf("failed to get AWS client: %v", err)
111102
}
103+
bucket, key := s.url.Hostname(), s.url.Path
112104

113-
client := s3.NewFromConfig(cfg)
114105
// Create an uploader with the session and default options
115106
uploader := manager.NewUploader(client)
116107

@@ -142,17 +133,14 @@ func (s *S3) URL() string {
142133
}
143134

144135
func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) {
145-
// Get the AWS config
146-
cfg, err := getConfig(s.endpoint)
136+
// get the s3 client
137+
client, err := s.getClient()
147138
if err != nil {
148-
return nil, fmt.Errorf("failed to load AWS config: %v", err)
139+
return nil, fmt.Errorf("failed to get AWS client: %v", err)
149140
}
150141

151-
// Create a new S3 service client
152-
svc := s3.NewFromConfig(cfg)
153-
154142
// Call ListObjectsV2 with your bucket and prefix
155-
result, err := svc.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)})
143+
result, err := client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(dirname)})
156144
if err != nil {
157145
return nil, fmt.Errorf("failed to list objects, %v", err)
158146
}
@@ -171,17 +159,14 @@ func (s *S3) ReadDir(dirname string) ([]fs.FileInfo, error) {
171159
}
172160

173161
func (s *S3) Remove(target string) error {
174-
// Get the AWS config
175-
cfg, err := getConfig(s.endpoint)
162+
// Get the AWS client
163+
client, err := s.getClient()
176164
if err != nil {
177-
return fmt.Errorf("failed to load AWS config: %v", err)
165+
return fmt.Errorf("failed to get AWS client: %v", err)
178166
}
179167

180-
// Create a new S3 service client
181-
svc := s3.NewFromConfig(cfg)
182-
183168
// Call DeleteObject with your bucket and the key of the object you want to delete
184-
_, err = svc.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
169+
_, err = client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
185170
Bucket: aws.String(s.url.Hostname()),
186171
Key: aws.String(target),
187172
})
@@ -192,9 +177,44 @@ func (s *S3) Remove(target string) error {
192177
return nil
193178
}
194179

180+
func (s *S3) getClient() (*s3.Client, error) {
181+
// Get the AWS config
182+
cleanEndpoint := getEndpoint(s.endpoint)
183+
opts := []func(*config.LoadOptions) error{
184+
config.WithEndpointResolverWithOptions(
185+
aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
186+
return aws.Endpoint{URL: cleanEndpoint}, nil
187+
}),
188+
),
189+
}
190+
if log.IsLevelEnabled(log.TraceLevel) {
191+
opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse))
192+
}
193+
if s.region != "" {
194+
opts = append(opts, config.WithRegion(s.region))
195+
}
196+
if s.accessKeyId != "" {
197+
opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
198+
s.accessKeyId,
199+
s.secretAccessKey,
200+
"",
201+
)))
202+
}
203+
cfg, err := config.LoadDefaultConfig(context.TODO(),
204+
opts...,
205+
)
206+
if err != nil {
207+
return nil, fmt.Errorf("failed to load AWS config: %v", err)
208+
}
209+
210+
// Create a new S3 service client
211+
return s3.NewFromConfig(cfg), nil
212+
}
213+
214+
// getEndpoint returns a clean (for AWS client) endpoint. Normally, this is unchanged,
215+
// but for some reason, the lookup gets flaky when the endpoint is 127.0.0.1,
216+
// so in that case, set it to localhost explicitly.
195217
func getEndpoint(endpoint string) string {
196-
// for some reason, the lookup gets flaky when the endpoint is 127.0.0.1
197-
// so you have to set it to localhost explicitly.
198218
e := endpoint
199219
u, err := url.Parse(endpoint)
200220
if err == nil {
@@ -210,24 +230,6 @@ func getEndpoint(endpoint string) string {
210230
return e
211231
}
212232

213-
func getConfig(endpoint string) (aws.Config, error) {
214-
cleanEndpoint := getEndpoint(endpoint)
215-
opts := []func(*config.LoadOptions) error{
216-
config.WithEndpointResolverWithOptions(
217-
aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
218-
return aws.Endpoint{URL: cleanEndpoint}, nil
219-
}),
220-
),
221-
}
222-
if log.IsLevelEnabled(log.TraceLevel) {
223-
opts = append(opts, config.WithClientLogMode(aws.LogRequestWithBody|aws.LogResponse))
224-
}
225-
return config.LoadDefaultConfig(context.TODO(),
226-
opts...,
227-
)
228-
229-
}
230-
231233
type s3FileInfo struct {
232234
name string
233235
lastModified time.Time

test/backup_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ func runDumpTest(dc *dockerContext, compact bool, base string, targets []backupT
434434
if err := os.MkdirAll(localPath, 0o755); err != nil {
435435
return fmt.Errorf("failed to create local path %s: %w", localPath, err)
436436
}
437-
store, err := storage.ParseURL(t, credentials.Creds{AWSEndpoint: s3})
437+
store, err := storage.ParseURL(t, credentials.Creds{AWS: credentials.AWSCreds{Endpoint: s3}})
438438
if err != nil {
439439
return fmt.Errorf("invalid target url: %v", err)
440440
}

0 commit comments

Comments
 (0)