Skip to content

Commit

Permalink
Refactor aws cloud service and introduce a client provider (#3895)
Browse files Browse the repository at this point in the history
* add support for aws clients provider

* refactor aws cloud service

* fix typo
  • Loading branch information
oliviassss authored Oct 18, 2024
1 parent 1ea514f commit b40a257
Show file tree
Hide file tree
Showing 12 changed files with 528 additions and 152 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func main() {
ctrl.SetLogger(appLogger)
klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true))

cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log)
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil)
if err != nil {
setupLog.Error(err, "unable to initialize AWS cloud")
os.Exit(1)
Expand Down
25 changes: 16 additions & 9 deletions pkg/aws/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
amerrors "k8s.io/apimachinery/pkg/util/errors"
epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
)

Expand Down Expand Up @@ -59,7 +60,7 @@ type Cloud interface {
}

// NewCloud constructs new Cloud implementation.
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) {
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) {
hasIPv4 := true
addrs, err := net.InterfaceAddrs()
if err == nil {
Expand Down Expand Up @@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions)
}

ec2Service := services.NewEC2(awsConfig, endpointsResolver)
if awsClientsProvider == nil {
var err error
awsClientsProvider, err = provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
if err != nil {
return nil, errors.Wrap(err, "failed to create aws clients provider")
}
}
ec2Service := services.NewEC2(awsClientsProvider)

vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger)
if err != nil {
Expand All @@ -139,17 +147,16 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
return &defaultCloud{
cfg: cfg,
ec2: ec2Service,
elbv2: services.NewELBV2(awsConfig, endpointsResolver),
acm: services.NewACM(awsConfig, endpointsResolver),
wafv2: services.NewWAFv2(awsConfig, endpointsResolver),
wafRegional: services.NewWAFRegional(awsConfig, endpointsResolver, cfg.Region),
shield: services.NewShield(awsConfig, endpointsResolver), //done
rgt: services.NewRGT(awsConfig, endpointsResolver),
elbv2: services.NewELBV2(awsClientsProvider),
acm: services.NewACM(awsClientsProvider),
wafv2: services.NewWAFv2(awsClientsProvider),
wafRegional: services.NewWAFRegional(awsClientsProvider, cfg.Region),
shield: services.NewShield(awsClientsProvider),
rgt: services.NewRGT(awsClientsProvider),
}, nil
}

func getVpcID(cfg CloudConfig, ec2Service services.EC2, ec2Metadata services.EC2Metadata, logger logr.Logger) (string, error) {

if cfg.VpcID != "" {
logger.V(1).Info("vpcid is specified using flag --aws-vpc-id, controller will use the value", "vpc: ", cfg.VpcID)
return cfg.VpcID, nil
Expand Down
109 changes: 109 additions & 0 deletions pkg/aws/provider/default_aws_clients_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
)

type defaultAWSClientsProvider struct {
ec2Client *ec2.Client
elbv2Client *elasticloadbalancingv2.Client
acmClient *acm.Client
wafv2Client *wafv2.Client
wafRegionClient *wafregional.Client
shieldClient *shield.Client
rgtClient *resourcegroupstaggingapi.Client
}

func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID)
acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
wafv2CustomEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID)
wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID)
shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID)
rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID)

ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if ec2CustomEndpoint != nil {
o.BaseEndpoint = ec2CustomEndpoint
}
})
elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) {
if elbv2CustomEndpoint != nil {
o.BaseEndpoint = elbv2CustomEndpoint
}
})
acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) {
if acmCustomEndpoint != nil {
o.BaseEndpoint = acmCustomEndpoint
}
})
wafv2Client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) {
if wafv2CustomEndpoint != nil {
o.BaseEndpoint = wafv2CustomEndpoint
}
})
wafregionalClient := wafregional.NewFromConfig(cfg, func(o *wafregional.Options) {
o.Region = cfg.Region
o.BaseEndpoint = wafregionalCustomEndpoint
})
sheildClient := shield.NewFromConfig(cfg, func(o *shield.Options) {
o.Region = "us-east-1"
o.BaseEndpoint = shieldCustomEndpoint
})
rgtClient := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) {
if rgtCustomEndpoint != nil {
o.BaseEndpoint = rgtCustomEndpoint
}
})

return &defaultAWSClientsProvider{
ec2Client: ec2Client,
elbv2Client: elbv2Client,
acmClient: acmClient,
wafv2Client: wafv2Client,
wafRegionClient: wafregionalClient,
shieldClient: sheildClient,
rgtClient: rgtClient,
}, nil
}

// DO NOT REMOVE operationName as parameter, this is on purpose
// to retain the default behavior for OSS controller to use the default client for each aws service
// for our internal controller, we will choose different client based on operationName
func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) {
return p.ec2Client, nil
}

func (p *defaultAWSClientsProvider) GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) {
return p.elbv2Client, nil
}

func (p *defaultAWSClientsProvider) GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) {
return p.acmClient, nil
}

func (p *defaultAWSClientsProvider) GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) {
return p.wafv2Client, nil
}

func (p *defaultAWSClientsProvider) GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) {
return p.wafRegionClient, nil
}

func (p *defaultAWSClientsProvider) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) {
return p.shieldClient, nil
}

func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) {
return p.rgtClient, nil
}
22 changes: 22 additions & 0 deletions pkg/aws/provider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
)

type AWSClientsProvider interface {
GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error)
GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error)
GetACMClient(ctx context.Context, operationName string) (*acm.Client, error)
GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error)
GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error)
GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error)
GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error)
}
26 changes: 14 additions & 12 deletions pkg/aws/services/acm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package services

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/acm/types"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
)

type ACM interface {
Expand All @@ -15,24 +14,23 @@ type ACM interface {
}

// NewACM constructs new ACM implementation.
func NewACM(cfg aws.Config, endpointsResolver *endpoints.Resolver) ACM {
customEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
func NewACM(awsClientsProvider provider.AWSClientsProvider) ACM {
return &acmClient{
acmClient: acm.NewFromConfig(cfg, func(o *acm.Options) {
if customEndpoint != nil {
o.BaseEndpoint = customEndpoint
}
}),
awsClientsProvider: awsClientsProvider,
}
}

type acmClient struct {
acmClient *acm.Client
awsClientsProvider provider.AWSClientsProvider
}

func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListCertificatesInput) ([]types.CertificateSummary, error) {
var result []types.CertificateSummary
paginator := acm.NewListCertificatesPaginator(c.acmClient, input)
client, err := c.awsClientsProvider.GetACMClient(ctx, "ListCertificates")
if err != nil {
return nil, err
}
paginator := acm.NewListCertificatesPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -44,5 +42,9 @@ func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListC
}

func (c *acmClient) DescribeCertificateWithContext(ctx context.Context, input *acm.DescribeCertificateInput) (*acm.DescribeCertificateOutput, error) {
return c.acmClient.DescribeCertificate(ctx, input)
client, err := c.awsClientsProvider.GetACMClient(ctx, "DescribeCertificate")
if err != nil {
return nil, err
}
return client.DescribeCertificate(ctx, input)
}
Loading

0 comments on commit b40a257

Please sign in to comment.