Skip to content

Commit

Permalink
Use SharedCredentialsProvider instead of StaticProvider for shared cr…
Browse files Browse the repository at this point in the history
…edentials

Previously, we were using the StaticProvider for all credentials resolved
while reading shared config files. There is logic in the AfterRetryHandler that
marks a credential object as expired whenever an expired token exception
occurs. When that happens, the ideal workflow is for the provider to trigger
a credentials refresh by re-reading the profile file, but this is not possible
with the StaticProvider.

This commit addresses that problem by using the SharedCredentialsProvider
instead of the StaticProvider when resolving such credentials. With this
change, the SDK should automatically refreshes the credentials whenever we get
a expired token exception from AWS. This should also implicitly address the
feature requested in aws#1993, but instead of refreshing explicitly, the trigger
here would be an exception from AWS.
  • Loading branch information
jaylim-crl committed Nov 16, 2023
1 parent e635384 commit 7384cb9
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
24 changes: 20 additions & 4 deletions aws/session/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package session
import (
"fmt"
"os"
"regexp"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -18,6 +19,12 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
)

// sharedCfgCredProviderNameRE is a regex that describes the ProviderName of
// credentials retrieved via sharedConfig. This will be used to extract the
// resolved filename and profile fields that indicate where the credentials
// came from.
var sharedCfgCredProviderNameRE = regexp.MustCompile("SharedConfigCredentials: filename=(.*?),profile=(.*?)$")

// CredentialsProviderOptions specifies additional options for configuring
// credentials providers.
type CredentialsProviderOptions struct {
Expand Down Expand Up @@ -113,10 +120,19 @@ func resolveCredsFromProfile(cfg *aws.Config,
)

case sharedCfg.Creds.HasKeys():
// Static Credentials from Shared Config/Credentials file.
creds = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
// Credentials from Shared Config/Credentials file. We will use a
// SharedCredentialsProvider that has the ability to re-read the file
// if the credential was marked as expired.
matches := sharedCfgCredProviderNameRE.FindStringSubmatch(sharedCfg.Creds.ProviderName)
if len(matches) == 3 {
filename, profile := matches[1], matches[2]
creds = credentials.NewSharedCredentials(filename, profile)
} else {
// sharedCfg.Creds must be populated via SharedConfigCredentials,
// so this case is not possible. Use a static credential regardless
// to maintain the existing behavior.
creds = credentials.NewStaticCredentialsFromCreds(sharedCfg.Creds)
}

case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg,
Expand Down
2 changes: 1 addition & 1 deletion aws/session/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) {
if e, a := "assume_role_w_creds_secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) {
if e, a := "SharedCredentialsProvider", creds.ProviderName; !strings.Contains(a, e) {
t.Errorf("expect %v, to be in %v", e, a)
}
}
Expand Down
14 changes: 7 additions & 7 deletions aws/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestNewSessionWithOptions_OverrideProfile(t *testing.T) {
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) {
if e, a := "SharedCredentialsProvider", creds.ProviderName; !strings.Contains(a, e) {
t.Errorf("expect %v, to be in %v", e, a)
}
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestNewSessionWithOptions_OverrideSharedConfigEnable(t *testing.T) {
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) {
if e, a := "SharedCredentialsProvider", creds.ProviderName; !strings.Contains(a, e) {
t.Errorf("expect %v, to be in %v", e, a)
}
}
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestNewSessionWithOptions_OverrideSharedConfigDisable(t *testing.T) {
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) {
if e, a := "SharedCredentialsProvider", creds.ProviderName; !strings.Contains(a, e) {
t.Errorf("expect %v, to be in %v", e, a)
}
}
Expand Down Expand Up @@ -349,7 +349,7 @@ func TestNewSessionWithOptions_OverrideSharedConfigFiles(t *testing.T) {
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) {
if e, a := "SharedCredentialsProvider", creds.ProviderName; !strings.Contains(a, e) {
t.Errorf("expect %v, to be in %v", e, a)
}
}
Expand All @@ -372,7 +372,7 @@ func TestNewSessionWithOptions_Overrides(t *testing.T) {
OutCreds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: "SharedConfigCredentials",
ProviderName: "SharedCredentialsProvider",
},
},
"env creds with env profile": {
Expand Down Expand Up @@ -405,7 +405,7 @@ func TestNewSessionWithOptions_Overrides(t *testing.T) {
OutCreds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: "SharedConfigCredentials",
ProviderName: "SharedCredentialsProvider",
},
},
"cfg and cred file with opt profile": {
Expand All @@ -420,7 +420,7 @@ func TestNewSessionWithOptions_Overrides(t *testing.T) {
OutCreds: credentials.Value{
AccessKeyID: "shared_config_akid",
SecretAccessKey: "shared_config_secret",
ProviderName: "SharedConfigCredentials",
ProviderName: "SharedCredentialsProvider",
},
},
}
Expand Down
2 changes: 1 addition & 1 deletion aws/session/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, e
AccessKeyID: section.String(accessKeyIDKey),
SecretAccessKey: section.String(secretAccessKey),
SessionToken: section.String(sessionTokenKey),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=%s", file.Filename, profile),
}
if creds.HasKeys() {
cfg.Creds = creds
Expand Down
18 changes: 9 additions & 9 deletions aws/session/shared_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestLoadSharedConfig(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "shared_config_akid",
SecretAccessKey: "shared_config_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=config_file_load_order", testConfigFilename),
},
},
},
Expand All @@ -65,7 +65,7 @@ func TestLoadSharedConfig(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "shared_config_other_akid",
SecretAccessKey: "shared_config_other_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigOtherFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=config_file_load_order", testConfigOtherFilename),
},
},
},
Expand All @@ -81,7 +81,7 @@ func TestLoadSharedConfig(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=complete_creds", testConfigFilename),
},
},
},
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestLoadSharedConfig(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "assume_role_w_creds_akid",
SecretAccessKey: "assume_role_w_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=assume_role_w_creds", testConfigFilename),
},
},
},
Expand Down Expand Up @@ -161,7 +161,7 @@ func TestLoadSharedConfig(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=complete_creds", testConfigFilename),
},
},
},
Expand Down Expand Up @@ -252,7 +252,7 @@ func TestLoadSharedConfig(t *testing.T) {
AccessKeyID: "sso_and_static_akid",
SecretAccessKey: "sso_and_static_secret",
SessionToken: "sso_and_static_token",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=sso_and_static", testConfigFilename),
},
SSOAccountID: "012345678901",
SSORegion: "us-west-2",
Expand Down Expand Up @@ -496,7 +496,7 @@ func TestLoadSharedConfigFromFile(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "complete_creds_akid",
SecretAccessKey: "complete_creds_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=complete_creds", testConfigFilename),
},
},
},
Expand All @@ -507,7 +507,7 @@ func TestLoadSharedConfigFromFile(t *testing.T) {
AccessKeyID: "complete_creds_with_token_akid",
SecretAccessKey: "complete_creds_with_token_secret",
SessionToken: "complete_creds_with_token_token",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=complete_creds_with_token", testConfigFilename),
},
},
},
Expand All @@ -517,7 +517,7 @@ func TestLoadSharedConfigFromFile(t *testing.T) {
Creds: credentials.Value{
AccessKeyID: "full_profile_akid",
SecretAccessKey: "full_profile_secret",
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename),
ProviderName: fmt.Sprintf("SharedConfigCredentials: filename=%s,profile=full_profile", testConfigFilename),
},
Region: "full_profile_region",
},
Expand Down

0 comments on commit 7384cb9

Please sign in to comment.