diff --git a/aws.go b/aws.go index e11f2e1..52a4b19 100644 --- a/aws.go +++ b/aws.go @@ -2,6 +2,7 @@ package main import "strings" import "errors" +import "fmt" import "time" import "os/user" import "github.com/tj/go-debug" @@ -12,6 +13,11 @@ import "github.com/aws/aws-sdk-go/service/sts" var debugAws = debug.Debug("oktad:aws") +type SamlProviderArns struct { + PrincipalArn string + RoleArn string +} + // assumes the first role and returns the credentials you need for // the second assumeRole... // returns those credentials, the expiration time, and error if any @@ -24,39 +30,48 @@ func assumeFirstRole(acfg AwsConfig, saml *OktaSamlResponse) (*credentials.Crede sess, ) - var arns string + var arns *SamlProviderArns + var found bool = false + var err error for _, a := range saml.Attributes { if a.Name == "https://aws.amazon.com/SAML/Attributes/Role" { - arns = a.Value - debugAws("found principal ARN %s", a.Value) - break + var crossAccountArn string + if acfg.CrossAcctArn != "" { + crossAccountArn = acfg.CrossAcctArn + } else { + crossAccountArn, err = selectCrossAccount(a.Value) + + if err != nil { + return nil, emptyExpire, err + } + } + + for _, v := range a.Value { + arns, err = splitSamlProviderArns(v) + + if err != nil { + return nil, emptyExpire, err + } + + debugAws("found principal ARN: %s, role ARN: %s", arns.PrincipalArn, arns.RoleArn) + + if crossAccountArn == arns.RoleArn { + found = true + break + } + } } } - if arns == "" { + if !found { return nil, emptyExpire, errors.New("no arn found from saml data!") } - parts := strings.Split(arns, ",") - - if len(parts) != 2 { - return nil, emptyExpire, errors.New("invalid initial role ARN") - } - - var roleArn, principalArn string - for _, part := range parts { - if strings.Contains(part, "saml-provider") { - principalArn = part - } else { - roleArn = part - } - } - res, err := scl.AssumeRoleWithSAML( &sts.AssumeRoleWithSAMLInput{ - PrincipalArn: &principalArn, - RoleArn: &roleArn, + PrincipalArn: &arns.PrincipalArn, + RoleArn: &arns.RoleArn, SAMLAssertion: &saml.raw, DurationSeconds: aws.Int64(3600), }, @@ -119,3 +134,65 @@ func assumeDestinationRole(acfg AwsConfig, creds *credentials.Credentials) (*cre return mCreds, *res.Credentials.Expiration, nil } + +func splitSamlProviderArns(arns string) (*SamlProviderArns, error) { + var res SamlProviderArns + parts := strings.Split(arns, ",") + + if len(parts) != 2 { + return nil, errors.New("invalid SAML Provider ARN") + } + + for _, part := range parts { + if strings.Contains(part, "saml-provider") { + res.PrincipalArn = part + } else { + res.RoleArn = part + } + } + + return &res, nil +} + +func selectCrossAccount(values []string) (crossAccountArn string, err error) { + choices := len(values) + if choices < 1 { + return "", errors.New("empty array of cross-account ARNs received") + } + + if choices == 1 { + return values[0], nil + } + + var arns []string + fmt.Println("Roles available: ") + for i, a := range values { + debugAws("index: %d, value: %s", i, a) + arn, _ := splitSamlProviderArns(a) + arns = append(arns, arn.RoleArn) + debugAws("arn.RoleArn: %s", arn.RoleArn) + fmt.Println(i, "- ", arns[i]) + } + fmt.Println("Select cross-account ARN number: ") + var roleIndex int + tries := 0 + +TRYROLE: + _, err = fmt.Scanf("%d", &roleIndex) + if err != nil { + return "", err + } + + if roleIndex < choices { + debugAws("selected cross-account Arn %s", arns[roleIndex]) + return arns[roleIndex], nil + } + + if tries < 2 { + tries++ + fmt.Println("Invalid role number, please try again") + goto TRYROLE + } + + return "", errors.New("Invalid role selection. Aborting") +} diff --git a/config-loader.go b/config-loader.go index aee360e..5eba896 100644 --- a/config-loader.go +++ b/config-loader.go @@ -20,8 +20,9 @@ type OktaConfig struct { // in your aws config type AwsConfig struct { // destination ARN - DestArn string - Region string + DestArn string + CrossAcctArn string + Region string } // loads configuration data from the file specified @@ -150,6 +151,32 @@ func readAwsProfile(name string) (AwsConfig, error) { arnKey, _ := s.GetKey("role_arn") cfg.DestArn = arnKey.String() + cfg.CrossAcctArn = "" + if s.HasKey("source_profile") { + debugCfg("aws profile %s has a source_profile, looking for cross-account arn", name) + spKey, _ := s.GetKey("source_profile") + + // profile's key is [profile ] unless it is the default profile + crossAcctProfile := spKey.String() + if spKey.String() != "default" { + crossAcctProfile = fmt.Sprintf("profile %s", spKey.String()) + } + + debugCfg("source_profile is %s", crossAcctProfile) + caas, err := asec.GetSection(crossAcctProfile) + + if err != nil { + debugCfg("aws cross-account profile load err, %s", err) + return cfg, err + } + + if caas.HasKey("role_arn") { + arnKey, _ := caas.GetKey("role_arn") + debugCfg("CrossAcctArn: %s", arnKey.String()) + cfg.CrossAcctArn = arnKey.String() + } + } + // try to figure out a region... // try to look for a region key in current section // if fail: try to look for source_profile diff --git a/okta.go b/okta.go index 7825fd1..5e17e28 100644 --- a/okta.go +++ b/okta.go @@ -55,9 +55,9 @@ type OktaSamlResponse struct { raw string XMLname xml.Name `xml:"Response"` Attributes []struct { - Name string `xml:",attr"` - NameFormat string `xml:",attr"` - Value string `xml:"AttributeValue"` + Name string `xml:",attr"` + NameFormat string `xml:",attr"` + Value []string `xml:"AttributeValue"` } `xml:"Assertion>AttributeStatement>Attribute"` }