Skip to content
This repository was archived by the owner on May 18, 2021. It is now read-only.

Commit 5959494

Browse files
authored
Add STS Regional Endpoint Support To Other STS Clients (#308)
1 parent 37d8632 commit 5959494

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

lib/provider.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99

1010
"github.com/segmentio/aws-okta/sessioncache"
11+
"github.com/aws/aws-sdk-go/aws/endpoints"
1112
log "github.com/sirupsen/logrus"
1213

1314
"github.com/99designs/keyring"
@@ -301,11 +302,19 @@ func (p *Provider) GetSAMLLoginURL() (*url.URL, error) {
301302

302303
// assumeRoleFromSession takes a session created with an okta SAML login and uses that to assume a role
303304
func (p *Provider) assumeRoleFromSession(creds sts.Credentials, roleArn string) (sts.Credentials, error) {
304-
client := sts.New(aws_session.New(&aws.Config{Credentials: credentials.NewStaticCredentials(
305-
*creds.AccessKeyId,
306-
*creds.SecretAccessKey,
307-
*creds.SessionToken,
308-
)}))
305+
conf := &aws.Config{
306+
Credentials: credentials.NewStaticCredentials(
307+
*creds.AccessKeyId,
308+
*creds.SecretAccessKey,
309+
*creds.SessionToken,
310+
),
311+
}
312+
if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" {
313+
conf.WithRegion(region)
314+
conf.WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
315+
}
316+
sess := aws_session.Must(aws_session.NewSession(conf))
317+
client := sts.New(sess)
309318

310319
input := &sts.AssumeRoleInput{
311320
RoleArn: aws.String(roleArn),
@@ -341,15 +350,19 @@ func (p *Provider) roleSessionName() string {
341350
// GetRoleARN uses temporary credentials to call AWS's get-caller-identity and
342351
// returns the assumed role's ARN
343352
func (p *Provider) GetRoleARNWithRegion(creds credentials.Value) (string, error) {
344-
config := aws.Config{Credentials: credentials.NewStaticCredentials(
345-
creds.AccessKeyID,
346-
creds.SecretAccessKey,
347-
creds.SessionToken,
348-
)}
353+
conf := &aws.Config{
354+
Credentials: credentials.NewStaticCredentials(
355+
creds.AccessKeyID,
356+
creds.SecretAccessKey,
357+
creds.SessionToken,
358+
),
359+
}
349360
if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" {
350-
config.WithRegion(region)
361+
conf.WithRegion(region)
362+
conf.WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
351363
}
352-
client := sts.New(aws_session.New(&config))
364+
sess := aws_session.Must(aws_session.NewSession(conf))
365+
client := sts.New(sess)
353366

354367
indentity, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
355368
if err != nil {

0 commit comments

Comments
 (0)