Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreasBurger committed Jun 13, 2024
1 parent 1e5c473 commit 1e0b5c4
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 149 deletions.
14 changes: 4 additions & 10 deletions pkg/azure/api/providerspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,20 +314,14 @@ type AzureDiagnosticsProfile struct {

// The (currently) supported values for the names of clouds to use in the CloudConfiguration.
const (
AzureChinaCloudName string = "AzureChina"
AzureGovCloudName string = "AzureGovernment"
AzurePublicCloudName string = "AzurePublic"
)

// The known prefixes in of region names for the various instances.
var (
AzureGovRegionPrefixes = []string{"usgov", "usdod", "ussec"}
AzureChinaRegionPrefixes = []string{"china"}
CloudNameChina string = "AzureChina"
CloudNameGov string = "AzureGovernment"
CloudNamePublic string = "AzurePublic"
)

// CloudConfiguration contains detailed config for the cloud to connect to. Currently we only support selection of well-
// known Azure-instances by name, but this could be extended in future to support private clouds.
type CloudConfiguration struct {
// Name is the name of the cloud to connect to, e.g. "AzurePublic" or "AzureChina".
Name string `json:"name,omitempty"`
Name string `json:"name"`
}
14 changes: 8 additions & 6 deletions pkg/azure/api/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ func ValidateProviderSpec(spec api.AzureProviderSpec) field.ErrorList {
allErrs = append(allErrs, validateSubnetInfo(spec.SubnetInfo, specPath.Child("subnetInfo"))...)
allErrs = append(allErrs, validateProperties(spec.Properties, specPath.Child("properties"))...)
allErrs = append(allErrs, validateTags(spec.Tags, specPath.Child("tags"))...)

if spec.CloudConfiguration != nil {
allErrs = append(allErrs, validateCloudConfiguration(*spec.CloudConfiguration, specPath.Child("cloudConfiguration"))...)
}
allErrs = append(allErrs, validateCloudConfiguration(spec.CloudConfiguration, specPath.Child("cloudConfiguration"))...)

return allErrs
}
Expand Down Expand Up @@ -126,9 +123,14 @@ func validateProperties(properties api.AzureVirtualMachineProperties, fldPath *f
return allErrs
}

func validateCloudConfiguration(cloudConfiguration api.CloudConfiguration, fldPath *field.Path) field.ErrorList {
func validateCloudConfiguration(cloudConfiguration *api.CloudConfiguration, fldPath *field.Path) field.ErrorList {
var allErrs field.ErrorList
knownCloudInstances := []string{api.AzurePublicCloudName, api.AzureChinaCloudName, api.AzureGovCloudName}

if cloudConfiguration == nil {
return allErrs
}

knownCloudInstances := []string{api.CloudNamePublic, api.CloudNameChina, api.CloudNameGov}

if cloudName := cloudConfiguration.Name; !slices.Contains(knownCloudInstances, cloudName) {
allErrs = append(allErrs, field.NotSupported(fldPath.Child("name"), cloudName, knownCloudInstances))
Expand Down
10 changes: 1 addition & 9 deletions pkg/azure/provider/helpers/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ func ExtractProviderSpecAndConnectConfig(mcc *v1alpha1.MachineClass, secret *cor
err error
providerSpec api.AzureProviderSpec
connectConfig access.ConnectConfig
cloudConfiguration *api.CloudConfiguration
azCloudConfiguration cloud.Configuration
region *string
)
// validate provider Spec provider. Exit early if it is not azure.
if err = validation.ValidateMachineClassProvider(mcc); err != nil {
Expand All @@ -60,13 +58,7 @@ func ExtractProviderSpecAndConnectConfig(mcc *v1alpha1.MachineClass, secret *cor
return api.AzureProviderSpec{}, access.ConnectConfig{}, err
}

if providerSpec.CloudConfiguration != nil {
cloudConfiguration = providerSpec.CloudConfiguration
}
if mcc != nil && mcc.NodeTemplate != nil {
region = &mcc.NodeTemplate.Region
}
if azCloudConfiguration, err = DetermineCloudConfiguration(cloudConfiguration, region); err != nil {
if azCloudConfiguration, err = DetermineCloudConfiguration(providerSpec.CloudConfiguration); err != nil {
return api.AzureProviderSpec{}, access.ConnectConfig{}, err
}

Expand Down
38 changes: 6 additions & 32 deletions pkg/azure/provider/helpers/providerspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,23 @@ import (
"github.com/gardener/machine-controller-manager-provider-azure/pkg/azure/api"
)

// DetermineCloudConfiguration returns the Azure cloud.Configuration corresponding to the instance given by the provided input. If both cloudConfiguration and
// region are provided, cloudConfiguration takes precedence.
func DetermineCloudConfiguration(cloudConfiguration *api.CloudConfiguration, region *string) (cloud.Configuration, error) {

// DetermineCloudConfiguration returns the Azure cloud.Configuration corresponding to the instance given by the provided api.Configuration.
func DetermineCloudConfiguration(cloudConfiguration *api.CloudConfiguration) (cloud.Configuration, error) {
if cloudConfiguration != nil {
cloudConfigurationName := cloudConfiguration.Name
switch {
case strings.EqualFold(cloudConfigurationName, api.AzurePublicCloudName):
case strings.EqualFold(cloudConfigurationName, api.CloudNamePublic):
return cloud.AzurePublic, nil
case strings.EqualFold(cloudConfigurationName, api.AzureGovCloudName):
case strings.EqualFold(cloudConfigurationName, api.CloudNameGov):
return cloud.AzureGovernment, nil
case strings.EqualFold(cloudConfigurationName, api.AzureChinaCloudName):
case strings.EqualFold(cloudConfigurationName, api.CloudNameChina):
return cloud.AzureChina, nil

default:
return cloud.Configuration{}, fmt.Errorf("unknown cloud configuration name '%s'", cloudConfigurationName)
}
} else if region != nil {
return cloudConfigurationFromRegion(*region), nil
} else {
// Fallback, this case should only occur during testing as we expect the region to always be given in an actual live scenario.
// Fallback
return cloud.AzurePublic, nil
}
}

// cloudConfigurationFromRegion returns a matching cloudConfiguration corresponding to a well known cloud instance for the given region
func cloudConfigurationFromRegion(region string) cloud.Configuration {
switch {
case hasAnyPrefix(region, api.AzureGovRegionPrefixes...):
return cloud.AzureGovernment
case hasAnyPrefix(region, api.AzureChinaRegionPrefixes...):
return cloud.AzureChina
default:
return cloud.AzurePublic
}
}

func hasAnyPrefix(s string, prefixes ...string) bool {
lString := strings.ToLower(s)
for _, p := range prefixes {
if strings.HasPrefix(lString, strings.ToLower(p)) {
return true
}
}
return false
}
117 changes: 25 additions & 92 deletions pkg/azure/provider/helpers/providerspec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,101 +10,34 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/gardener/machine-controller-manager-provider-azure/pkg/azure/api"
. "github.com/onsi/gomega"
"k8s.io/utils/ptr"
)

func TestNilConfig(t *testing.T) {
func TestCloudConfigurationDetermination(t *testing.T) {
g := NewWithT(t)

var (
testConfig *api.CloudConfiguration = nil
testRegion = ptr.To("Foo")
)

configuration, err := DetermineCloudConfiguration(testConfig, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzurePublic))
}

func TestNilRegion(t *testing.T) {
g := NewWithT(t)

var (
testConfig = &api.CloudConfiguration{Name: api.AzurePublicCloudName}
testRegion *string = nil
)

configuration, err := DetermineCloudConfiguration(testConfig, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzurePublic))
}

func TestNilConfigAndRegion(t *testing.T) {
g := NewWithT(t)

var (
testConfig *api.CloudConfiguration = nil
testRegion *string = nil
)

configuration, err := DetermineCloudConfiguration(testConfig, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzurePublic))
}

func TestInvalidConfigName(t *testing.T) {
g := NewWithT(t)

var (
testConfig = &api.CloudConfiguration{Name: "Foo"}
testRegion *string = nil
)

_, err := DetermineCloudConfiguration(testConfig, testRegion)

g.Expect(err).To(HaveOccurred())
}

func TestPredefinedClouds(t *testing.T) {
g := NewWithT(t)

var (
testPublicConfiguration = &api.CloudConfiguration{Name: api.AzurePublicCloudName}
testGovConfiguration = &api.CloudConfiguration{Name: api.AzureGovCloudName}
testChinaConfigration = &api.CloudConfiguration{Name: api.AzureChinaCloudName}
testRegion *string = nil
)

configuration, err := DetermineCloudConfiguration(testPublicConfiguration, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzurePublic))

configuration, err = DetermineCloudConfiguration(testGovConfiguration, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzureGovernment))

configuration, err = DetermineCloudConfiguration(testChinaConfigration, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzureChina))
}

func TestRegionMatching(t *testing.T) {
g := NewWithT(t)

var (
testConfig *api.CloudConfiguration = nil
testRegion *string = ptr.To("ussecFoo")
)

configuration, err := DetermineCloudConfiguration(testConfig, testRegion)

g.Expect(err).ToNot(HaveOccurred())
g.Expect(configuration).To(Equal(cloud.AzureGovernment))
type testData struct {
testConfiguration *api.CloudConfiguration
expectedOutput *cloud.Configuration
shouldFail bool
}

tests := []testData{
{testConfiguration: &api.CloudConfiguration{Name: api.CloudNamePublic}, expectedOutput: &cloud.AzurePublic},
{testConfiguration: &api.CloudConfiguration{Name: api.CloudNameChina}, expectedOutput: &cloud.AzureChina},
{testConfiguration: &api.CloudConfiguration{Name: api.CloudNameGov}, expectedOutput: &cloud.AzureGovernment},
{testConfiguration: &api.CloudConfiguration{Name: "FooCloud"}, expectedOutput: nil, shouldFail: true},
{testConfiguration: nil, expectedOutput: &cloud.AzurePublic},
}

for _, t := range tests {
cloudConfiguration, err := DetermineCloudConfiguration(t.testConfiguration)
if t.shouldFail {
g.Expect(err).To(HaveOccurred())
}
if t.expectedOutput != nil {
g.Expect(cloudConfiguration).To(Equal(*t.expectedOutput))
}

}

}

0 comments on commit 1e0b5c4

Please sign in to comment.