Skip to content

Commit

Permalink
Get tenant ID from Azure SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
lbussell committed May 20, 2024
1 parent d9e568c commit 354896e
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 23 deletions.
32 changes: 30 additions & 2 deletions src/Microsoft.DotNet.ImageBuilder/src/AuthHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,49 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Azure.ResourceManager;

namespace Microsoft.DotNet.ImageBuilder
{
public static class AuthHelper
{
private const string DefaultScope = "https://management.azure.com/.default";

public static async Task<string> GetDefaultAccessTokenAsync(string resource = DefaultScope)
public static async Task<(string token, Guid tenantId)> GetDefaultAccessTokenAsync(ILoggerService loggerService, string resource = DefaultScope)
{
DefaultAzureCredential credential = new();
AccessToken token = await credential.GetTokenAsync(new TokenRequestContext([ resource ]));
return token.Token;
Guid tenantId = GetTenantId(loggerService, credential);
return (token.Token, tenantId);
}

public static Guid GetTenantId(ILoggerService loggerService, TokenCredential credential)
{
ArmClient armClient = new(credential);
IEnumerable<Guid> tenants = armClient.GetTenants().ToList()
.Select(tenantResource => tenantResource.Data.TenantId)
.Where(guid => guid != null)
.Select(guid => (Guid)guid);

if (!tenants.Any())
{
throw new Exception("Found no tenants for given credential.");
}

if (tenants.Count() > 1)
{
string allTenantIds = string.Join(' ', tenants.Select(guid => guid.ToString()));
loggerService.WriteMessage("Found more than one tenant. Selecting the first one.");
loggerService.WriteMessage($"Tenants: {allTenantIds}");
}

return tenants.First();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ public class RegistryCredentialsOptions : IRegistryCredentialsHost
{
public IDictionary<string, RegistryCredentials> Credentials { get; set; } =
new Dictionary<string, RegistryCredentials>();
public string? Tenant { get; set; }
}

public class RegistryCredentialsOptionsBuilder
Expand All @@ -25,9 +24,7 @@ public IEnumerable<Option> GetCliOptions() =>
{
(string username, string password) = val.ParseKeyValuePair(';');
return new RegistryCredentials(username, password);
}),
CreateOption<string?>("tenant", nameof(RegistryCredentialsOptions.Tenant),
"Tenant containing the ACR to authenticate to"),
})
];

public IEnumerable<Argument> GetCliArguments() => [];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ namespace Microsoft.DotNet.ImageBuilder;
public interface IRegistryCredentialsHost
{
IDictionary<string, RegistryCredentials> Credentials { get; }
string? Tenant { get; }
}
#nullable disable
10 changes: 6 additions & 4 deletions src/Microsoft.DotNet.ImageBuilder/src/McrStatusClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class McrStatusClient : IMcrStatusClient
private readonly HttpClient _httpClient;
private readonly AsyncLockedValue<string> _accessToken = new AsyncLockedValue<string>();
private readonly AsyncPolicy<HttpResponseMessage> _httpPolicy;
private readonly ILoggerService _loggerService;

[ImportingConstructor]
public McrStatusClient(IHttpClientProvider httpClientProvider, ILoggerService loggerService)
Expand All @@ -36,6 +37,7 @@ public McrStatusClient(IHttpClientProvider httpClientProvider, ILoggerService lo
.WithRefreshAccessTokenPolicy(RefreshAccessTokenAsync, loggerService)
.WithNotFoundRetryPolicy(TimeSpan.FromHours(1), TimeSpan.FromSeconds(10), loggerService)
.Build() ?? throw new InvalidOperationException("Policy should not be null");
_loggerService = loggerService;
}

public Task<ImageResult> GetImageResultAsync(string imageDigest)
Expand Down Expand Up @@ -69,11 +71,11 @@ private async Task<T> SendRequestAsync<T>(Func<HttpRequestMessage> message)
}

private Task<string> GetAccessTokenAsync() =>
_accessToken.GetValueAsync(
() => AuthHelper.GetDefaultAccessTokenAsync(McrStatusResource));
_accessToken.GetValueAsync(async () =>
(await AuthHelper.GetDefaultAccessTokenAsync(_loggerService, McrStatusResource)).token);

private Task RefreshAccessTokenAsync() =>
_accessToken.ResetValueAsync(
() => AuthHelper.GetDefaultAccessTokenAsync(McrStatusResource));
_accessToken.ResetValueAsync(async () =>
(await AuthHelper.GetDefaultAccessTokenAsync(_loggerService, McrStatusResource)).token);
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.DotNet.ImageBuilder/src/OAuthHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public static class OAuthHelper
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
};

public static async Task<string> GetRefreshTokenAsync(HttpClient httpClient, string acrName, string tenant, string eidToken)
public static async Task<string> GetRefreshTokenAsync(HttpClient httpClient, string acrName, Guid tenant, string eidToken)
{
StringContent requestContent = new(
$"grant_type=access_token&service={acrName}&tenant={tenant}&access_token={eidToken}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,22 @@ namespace Microsoft.DotNet.ImageBuilder;
#nullable enable
[Export(typeof(IRegistryCredentialsProvider))]
[method: ImportingConstructor]
public class RegistryCredentialsProvider(IHttpClientProvider httpClientProvider) : IRegistryCredentialsProvider
public class RegistryCredentialsProvider(ILoggerService loggerService, IHttpClientProvider httpClientProvider) : IRegistryCredentialsProvider
{
private readonly ILoggerService _loggerService = loggerService;
private readonly IHttpClientProvider _httpClientProvider = httpClientProvider;

/// <summary>
/// Dynamically gets the RegistryCredentials for the specified registry in the following order of preference:
/// 1. If we own the registry, use OAuth to get the credentials.
/// 2. Read the credentials passed in from the command line.
/// 1. If we own the ACR, use the Azure SDK for authentication via the DefaultAzureCredential (no explicit credentials needed).
/// 2. If we don't own the ACR, try to read the username/password passed in from the command line.
/// 3. Return null if there are no credentials to be found.
/// </summary>
/// <param name="registry">The container registry to get credentials for.</param>
/// <returns>Registry credentials</returns>
public async ValueTask<RegistryCredentials?> GetCredentialsAsync(
string registry, string? ownedAcr, IRegistryCredentialsHost? credsHost)
{
string? tenant = credsHost?.Tenant;

// Docker Hub's registry has a separate host name for its API
string apiRegistry = registry == DockerHelper.DockerHubRegistry ?
DockerHelper.DockerHubApiRegistry :
Expand All @@ -37,18 +36,18 @@ public class RegistryCredentialsProvider(IHttpClientProvider httpClientProvider)
ownedAcr = DockerHelper.FormatAcrName(ownedAcr);
}

if (apiRegistry == ownedAcr && !string.IsNullOrEmpty(tenant))
if (apiRegistry == ownedAcr)
{
return await GetAcrCredentialsWithOAuthAsync(apiRegistry, tenant);
return await GetAcrCredentialsWithOAuthAsync(_loggerService, apiRegistry);
}

return credsHost?.TryGetCredentials(apiRegistry) ?? null;
}

private async ValueTask<RegistryCredentials> GetAcrCredentialsWithOAuthAsync(string apiRegistry, string tenant)
private async ValueTask<RegistryCredentials> GetAcrCredentialsWithOAuthAsync(ILoggerService logger, string apiRegistry)
{
string eidToken = await AuthHelper.GetDefaultAccessTokenAsync();
string refreshToken = await OAuthHelper.GetRefreshTokenAsync(_httpClientProvider.GetClient(), apiRegistry, tenant, eidToken);
(string token, Guid tenantId) = await AuthHelper.GetDefaultAccessTokenAsync(logger);
string refreshToken = await OAuthHelper.GetRefreshTokenAsync(_httpClientProvider.GetClient(), apiRegistry, tenantId, token);
return new RegistryCredentials(Guid.Empty.ToString(), refreshToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public void CreateAcrClient(string ownedAcr)
{
const string AcrName = "my-acr.azurecr.io";
const string RepoName = "repo-name";
IRegistryCredentialsHost credsHost = Mock.Of<IRegistryCredentialsHost>(o => o.Tenant == s_tenant);
IRegistryCredentialsHost credsHost = Mock.Of<IRegistryCredentialsHost>();

IContainerRegistryContentClient contentClient = Mock.Of<IContainerRegistryContentClient>();

Expand Down

0 comments on commit 354896e

Please sign in to comment.