Skip to content
This repository has been archived by the owner on May 18, 2021. It is now read-only.

Commit

Permalink
Add STS Regional Endpoint Support To Other STS Clients (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tensho authored Jan 4, 2021
1 parent 37d8632 commit 5959494
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions lib/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"

"github.com/segmentio/aws-okta/sessioncache"
"github.com/aws/aws-sdk-go/aws/endpoints"
log "github.com/sirupsen/logrus"

"github.com/99designs/keyring"
Expand Down Expand Up @@ -301,11 +302,19 @@ func (p *Provider) GetSAMLLoginURL() (*url.URL, error) {

// assumeRoleFromSession takes a session created with an okta SAML login and uses that to assume a role
func (p *Provider) assumeRoleFromSession(creds sts.Credentials, roleArn string) (sts.Credentials, error) {
client := sts.New(aws_session.New(&aws.Config{Credentials: credentials.NewStaticCredentials(
*creds.AccessKeyId,
*creds.SecretAccessKey,
*creds.SessionToken,
)}))
conf := &aws.Config{
Credentials: credentials.NewStaticCredentials(
*creds.AccessKeyId,
*creds.SecretAccessKey,
*creds.SessionToken,
),
}
if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" {
conf.WithRegion(region)
conf.WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
}
sess := aws_session.Must(aws_session.NewSession(conf))
client := sts.New(sess)

input := &sts.AssumeRoleInput{
RoleArn: aws.String(roleArn),
Expand Down Expand Up @@ -341,15 +350,19 @@ func (p *Provider) roleSessionName() string {
// GetRoleARN uses temporary credentials to call AWS's get-caller-identity and
// returns the assumed role's ARN
func (p *Provider) GetRoleARNWithRegion(creds credentials.Value) (string, error) {
config := aws.Config{Credentials: credentials.NewStaticCredentials(
creds.AccessKeyID,
creds.SecretAccessKey,
creds.SessionToken,
)}
conf := &aws.Config{
Credentials: credentials.NewStaticCredentials(
creds.AccessKeyID,
creds.SecretAccessKey,
creds.SessionToken,
),
}
if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" {
config.WithRegion(region)
conf.WithRegion(region)
conf.WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
}
client := sts.New(aws_session.New(&config))
sess := aws_session.Must(aws_session.NewSession(conf))
client := sts.New(sess)

indentity, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
Expand Down

0 comments on commit 5959494

Please sign in to comment.