Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add msi source detect logic #4761

Merged
merged 9 commits into from
May 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ internal class AppServiceManagedIdentitySource : AbstractManagedIdentity
private readonly Uri _endpoint;
private readonly string _secret;

public static AbstractManagedIdentity TryCreate(RequestContext requestContext)
public static AbstractManagedIdentity Create(RequestContext requestContext)
{
var msiSecret = EnvironmentVariables.IdentityHeader;
requestContext.Logger.Info(() => "[Managed Identity] App service managed identity is available.");

return TryValidateEnvVars(EnvironmentVariables.IdentityEndpoint, msiSecret, requestContext.Logger, out Uri endpointUri)
? new AppServiceManagedIdentitySource(requestContext, endpointUri, msiSecret)
return TryValidateEnvVars(EnvironmentVariables.IdentityEndpoint, requestContext.Logger, out Uri endpointUri)
? new AppServiceManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader)
: null;
}

Expand All @@ -38,17 +38,10 @@ private AppServiceManagedIdentitySource(RequestContext requestContext, Uri endpo
_secret = secret;
}

private static bool TryValidateEnvVars(string msiEndpoint, string secret, ILoggerAdapter logger, out Uri endpointUri)
private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger, out Uri endpointUri)
{
endpointUri = null;

// if BOTH the env vars endpoint and secret values are null, this MSI provider is unavailable.
if (string.IsNullOrEmpty(msiEndpoint) || string.IsNullOrEmpty(secret))
{
logger.Verbose(()=>"[Managed Identity] App service managed identity is unavailable.");
return false;
}

try
{
endpointUri = new Uri(msiEndpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,11 @@ internal class AzureArcManagedIdentitySource : AbstractManagedIdentity

private readonly Uri _endpoint;

public static AbstractManagedIdentity TryCreate(RequestContext requestContext)
public static AbstractManagedIdentity Create(RequestContext requestContext)
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
string imdsEndpoint = EnvironmentVariables.ImdsEndpoint;

// if BOTH the env vars IDENTITY_ENDPOINT and IMDS_ENDPOINT are set the MsiType is Azure Arc
if (string.IsNullOrEmpty(identityEndpoint) || string.IsNullOrEmpty(imdsEndpoint))
{
requestContext.Logger.Verbose(()=>"[Managed Identity] Azure Arc managed identity is unavailable.");
return null;
}
requestContext.Logger.Info(() => "[Managed Identity] Azure Arc managed identity is available.");

if (!Uri.TryCreate(identityEndpoint, UriKind.Absolute, out Uri endpointUri))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@ internal class CloudShellManagedIdentitySource : AbstractManagedIdentity
private readonly Uri _endpoint;
private const string CloudShell = "Cloud Shell";

public static AbstractManagedIdentity TryCreate(RequestContext requestContext)
public static AbstractManagedIdentity Create(RequestContext requestContext)
{
string msiEndpoint = EnvironmentVariables.MsiEndpoint;

// if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell
if (string.IsNullOrEmpty(msiEndpoint))
{
requestContext.Logger.Verbose(()=>"[Managed Identity] Cloud shell managed identity is unavailable.");
return null;
}

Uri endpointUri;

requestContext.Logger.Info(() => "[Managed Identity] Cloud shell managed identity is available.");

try
{
endpointUri = new Uri(msiEndpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity
internal ImdsManagedIdentitySource(RequestContext requestContext) :
base(requestContext, ManagedIdentitySource.Imds)
{
requestContext.Logger.Info(() => "[Managed Identity] Defaulting to IMDS endpoint for managed identity.");

if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint))
{
requestContext.Logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
// Licensed under the MIT License.

using System;
using Microsoft.Identity.Client.Extensibility;
using System.Threading.Tasks;
using System.Threading;
using System.Linq;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Internal;
using Microsoft.IdentityModel.Abstractions;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Utils;
using Microsoft.Identity.Client.OAuth2;
using Microsoft.Identity.Client.ApiConfig.Parameters;

namespace Microsoft.Identity.Client.ManagedIdentity
Expand All @@ -23,6 +16,13 @@ namespace Microsoft.Identity.Client.ManagedIdentity
internal class ManagedIdentityClient
{
private readonly AbstractManagedIdentity _identitySource;
internal static Lazy<ManagedIdentitySource> s_managedIdentitySourceDetected = new Lazy<ManagedIdentitySource>(() => GetManagedIdentitySource());

// To reset the cached source for testing purposes.
internal static void resetCachedSource()
{
s_managedIdentitySourceDetected = new Lazy<ManagedIdentitySource>(() => GetManagedIdentitySource());
}

public ManagedIdentityClient(RequestContext requestContext)
{
Expand All @@ -40,12 +40,51 @@ internal Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityAsync(A
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContext requestContext)
{
return
ServiceFabricManagedIdentitySource.TryCreate(requestContext) ??
AppServiceManagedIdentitySource.TryCreate(requestContext) ??
CloudShellManagedIdentitySource.TryCreate(requestContext) ??
AzureArcManagedIdentitySource.TryCreate(requestContext) ??
new ImdsManagedIdentitySource(requestContext);
return s_managedIdentitySourceDetected.Value switch
{
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
_ => new ImdsManagedIdentitySource(requestContext)
};
}

// Detect managed identity source based on the availability of environment variables.
private static ManagedIdentitySource GetManagedIdentitySource()
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
neha-bhargava marked this conversation as resolved.
Show resolved Hide resolved
string identityHeader = EnvironmentVariables.IdentityHeader;
string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint;
string msiSecret = EnvironmentVariables.IdentityHeader;
string msiEndpoint = EnvironmentVariables.MsiEndpoint;
string imdsEndpoint = EnvironmentVariables.ImdsEndpoint;
string podIdentityEndpoint = EnvironmentVariables.PodIdentityEndpoint;


if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader))
{
if (!string.IsNullOrEmpty(identityServerThumbprint))
{
return ManagedIdentitySource.ServiceFabric;
}
else
{
return ManagedIdentitySource.AppService;
}
}
else if (!string.IsNullOrEmpty(msiEndpoint))
{
return ManagedIdentitySource.CloudShell;
}
else if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(imdsEndpoint))
{
return ManagedIdentitySource.AzureArc;
}
else
{
return ManagedIdentitySource.DefaultToImds;
neha-bhargava marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ public enum ManagedIdentitySource
/// <summary>
/// The source to acquire token for managed identity is Service Fabric.
/// </summary>
ServiceFabric
ServiceFabric,

/// <summary>
/// Indicates that the source is defaulted to IMDS since no environment variables are set.
/// This is used to detect the managed identity source.
/// </summary>
DefaultToImds
neha-bhargava marked this conversation as resolved.
Show resolved Hide resolved
neha-bhargava marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@ internal class ServiceFabricManagedIdentitySource : AbstractManagedIdentity
private readonly Uri _endpoint;
private readonly string _identityHeaderValue;

public static AbstractManagedIdentity TryCreate(RequestContext requestContext)
public static AbstractManagedIdentity Create(RequestContext requestContext)
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
string identityHeader = EnvironmentVariables.IdentityHeader;
string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint;

if (string.IsNullOrEmpty(identityEndpoint) || string.IsNullOrEmpty(identityHeader) || string.IsNullOrEmpty(identityServerThumbprint))
{
requestContext.Logger.Verbose(() => "[Managed Identity] Service Fabric managed identity unavailable.");
return null;
}
requestContext.Logger.Info(() => "[Managed Identity] Service fabric managed identity is available.");

if (!Uri.TryCreate(identityEndpoint, UriKind.Absolute, out Uri endpointUri))
{
Expand All @@ -48,7 +42,7 @@ public static AbstractManagedIdentity TryCreate(RequestContext requestContext)
}

requestContext.Logger.Verbose(() => "[Managed Identity] Creating Service Fabric managed identity. Endpoint URI: " + identityEndpoint);
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, identityHeader);
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader);
}

private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,14 @@ public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIden
ClientExecutorFactory.CreateManagedIdentityExecutor(this),
resource);
}

/// <summary>
/// Detects and returns the managed identity source available on the environment.
/// </summary>
/// <returns>Managed identity source detected on the environment if any.</returns>
public static ManagedIdentitySource GetManagedIdentitySource()
neha-bhargava marked this conversation as resolved.
Show resolved Hide resolved
{
return ManagedIdentityClient.s_managedIdentitySourceDetected.Value;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", "thumbprint");
break;
}

ManagedIdentityClient.resetCachedSource();
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ public async Task AcquireMSITokenAsync(MsiAzureResource azureResource, string us
//Set the Environment Variables
SetEnvironmentVariables(envVariables);

//Reset cached source with update in environment variables
ManagedIdentityClient.resetCachedSource();

//form the http proxy URI
string uri = s_baseURL + $"MSIToken?" +
$"azureresource={azureResource}&uri=";
Expand Down Expand Up @@ -137,7 +140,10 @@ public async Task AcquireMsiToken_ForTokenExchangeResource_Successfully()

//Set the Environment Variables
SetEnvironmentVariables(envVariables);


//Reset cached source with update in environment variables
ManagedIdentityClient.resetCachedSource();

//form the http proxy URI
string uri = s_baseURL + $"MSIToken?" +
$"azureresource={MsiAzureResource.WebApp}&uri=";
Expand Down Expand Up @@ -202,6 +208,9 @@ public async Task ManagedIdentityRequestFailureCheckAsync(MsiAzureResource azure
//Set the Environment Variables
SetEnvironmentVariables(envVariables);

//Reset cached source with update in environment variables
ManagedIdentityClient.resetCachedSource();

//form the http proxy URI
string uri = s_baseURL + $"MSIToken?" +
$"azureresource={azureResource}&uri=";
Expand Down Expand Up @@ -244,6 +253,9 @@ public async Task MSIWrongScopesAsync(MsiAzureResource azureResource, string use
//Set the Environment Variables
SetEnvironmentVariables(envVariables);

//Reset cached source with update in environment variables
ManagedIdentityClient.resetCachedSource();

//form the http proxy URI
string uri = s_baseURL + $"MSIToken?" +
$"azureresource={azureResource}&uri=";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ public class ManagedIdentityTests : TestBase
internal const string ExpectedErrorCode = "ErrorCode";
internal const string ExpectedCorrelationId = "Some GUID";

[DataTestMethod]
[DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)]
[DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)]
[DataRow(ImdsEndpoint, ManagedIdentitySource.Imds, ManagedIdentitySource.DefaultToImds)]
[DataRow(null, ManagedIdentitySource.Imds, ManagedIdentitySource.DefaultToImds)]
[DataRow(AzureArcEndpoint, ManagedIdentitySource.AzureArc, ManagedIdentitySource.AzureArc)]
[DataRow(CloudShellEndpoint, ManagedIdentitySource.CloudShell, ManagedIdentitySource.CloudShell)]
[DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, ManagedIdentitySource.ServiceFabric)]
public void GetManagedIdentityTests(
neha-bhargava marked this conversation as resolved.
Show resolved Hide resolved
string endpoint,
ManagedIdentitySource managedIdentitySource,
ManagedIdentitySource expectedManagedIdentitySource)
{
using (new EnvVariableContext())
{
SetEnvironmentVariables(managedIdentitySource, endpoint);

Assert.AreEqual(expectedManagedIdentitySource, ManagedIdentityApplication.GetManagedIdentitySource());
}
}

[DataTestMethod]
[DataRow("http://127.0.0.1:41564/msi/token/", Resource, ManagedIdentitySource.AppService)]
[DataRow(AppServiceEndpoint, Resource, ManagedIdentitySource.AppService)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ public async Task AcquireTokenWithMSITelemetryTestAsync()
string endpoint = "http://localhost:40342/metadata/identity/oauth2/token";
string resource = "https://management.azure.com";

Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint);
SetEnvironmentVariables(ManagedIdentitySource.CloudShell, endpoint);

var mia = ManagedIdentityApplicationBuilder
.Create(ManagedIdentityId.SystemAssigned)
Expand Down