|
8 | 8 | "errors"
|
9 | 9 |
|
10 | 10 | "github.com/segmentio/aws-okta/sessioncache"
|
| 11 | + "github.com/aws/aws-sdk-go/aws/endpoints" |
11 | 12 | log "github.com/sirupsen/logrus"
|
12 | 13 |
|
13 | 14 | "github.com/99designs/keyring"
|
@@ -301,11 +302,19 @@ func (p *Provider) GetSAMLLoginURL() (*url.URL, error) {
|
301 | 302 |
|
302 | 303 | // assumeRoleFromSession takes a session created with an okta SAML login and uses that to assume a role
|
303 | 304 | func (p *Provider) assumeRoleFromSession(creds sts.Credentials, roleArn string) (sts.Credentials, error) {
|
304 |
| - client := sts.New(aws_session.New(&aws.Config{Credentials: credentials.NewStaticCredentials( |
305 |
| - *creds.AccessKeyId, |
306 |
| - *creds.SecretAccessKey, |
307 |
| - *creds.SessionToken, |
308 |
| - )})) |
| 305 | + conf := &aws.Config{ |
| 306 | + Credentials: credentials.NewStaticCredentials( |
| 307 | + *creds.AccessKeyId, |
| 308 | + *creds.SecretAccessKey, |
| 309 | + *creds.SessionToken, |
| 310 | + ), |
| 311 | + } |
| 312 | + if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" { |
| 313 | + conf.WithRegion(region) |
| 314 | + conf.WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint) |
| 315 | + } |
| 316 | + sess := aws_session.Must(aws_session.NewSession(conf)) |
| 317 | + client := sts.New(sess) |
309 | 318 |
|
310 | 319 | input := &sts.AssumeRoleInput{
|
311 | 320 | RoleArn: aws.String(roleArn),
|
@@ -341,15 +350,19 @@ func (p *Provider) roleSessionName() string {
|
341 | 350 | // GetRoleARN uses temporary credentials to call AWS's get-caller-identity and
|
342 | 351 | // returns the assumed role's ARN
|
343 | 352 | func (p *Provider) GetRoleARNWithRegion(creds credentials.Value) (string, error) {
|
344 |
| - config := aws.Config{Credentials: credentials.NewStaticCredentials( |
345 |
| - creds.AccessKeyID, |
346 |
| - creds.SecretAccessKey, |
347 |
| - creds.SessionToken, |
348 |
| - )} |
| 353 | + conf := &aws.Config{ |
| 354 | + Credentials: credentials.NewStaticCredentials( |
| 355 | + creds.AccessKeyID, |
| 356 | + creds.SecretAccessKey, |
| 357 | + creds.SessionToken, |
| 358 | + ), |
| 359 | + } |
349 | 360 | if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" {
|
350 |
| - config.WithRegion(region) |
| 361 | + conf.WithRegion(region) |
| 362 | + conf.WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint) |
351 | 363 | }
|
352 |
| - client := sts.New(aws_session.New(&config)) |
| 364 | + sess := aws_session.Must(aws_session.NewSession(conf)) |
| 365 | + client := sts.New(sess) |
353 | 366 |
|
354 | 367 | indentity, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
|
355 | 368 | if err != nil {
|
|
0 commit comments