Skip to content

Commit

Permalink
[DMS-408] Add oauth/token endpoint to handle form content (#359)
Browse files Browse the repository at this point in the history
* Add oauth/token endpoint to handle form content

* Fix failing tests
  • Loading branch information
CSR2017 authored Nov 26, 2024
1 parent 111a489 commit 0ce7020
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 19 deletions.
1 change: 1 addition & 0 deletions eng/bulkLoad/modules/BulkLoad.psm1
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ function Write-XmlFiles {
"-f", # Force reload of metadata from metadata url
"-n", # Do not validate the XML document against the XSD before processing
"-x", $Paths.XsdDirectory
"-o", "$BaseUrl/oauth/token"
)

$previousForegroundColor = $host.UI.RawUI.ForegroundColor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ public class When_Getting_An_Access_Token

private readonly TraceId TraceId = new("trace-id");

public async Task Act(string authHeader)
private const string GrantType = "client_credentials";

public async Task Act(string authHeader, string grantType = GrantType)
{
await Act(authHeader, HttpStatusCode.OK, "{}");
await Act(authHeader, HttpStatusCode.OK, "{}", grantType);
}

public async Task Act(string authHeader, HttpStatusCode responseCode, string responseMessage)
public async Task Act(string authHeader, HttpStatusCode responseCode, string responseMessage, string grantType = "")
{
// Arrange
var fakeResponse = A.Fake<HttpResponseMessage>();
Expand All @@ -46,7 +48,7 @@ public async Task Act(string authHeader, HttpStatusCode responseCode, string res
var system = new OAuthManager(_logger);

// Act
_response = await system.GetAccessTokenAsync(_httpClient, authHeader, DestinationUri, TraceId);
_response = await system.GetAccessTokenAsync(_httpClient, grantType, authHeader, DestinationUri, TraceId);
}

public async Task ActWithException(string message)
Expand All @@ -60,6 +62,7 @@ public async Task ActWithException(string message)
// Act
_response = await system.GetAccessTokenAsync(
_httpClient,
GrantType,
"basic 123:abc",
DestinationUri,
TraceId
Expand Down Expand Up @@ -228,7 +231,7 @@ await Act(
"error_description": "Invalid client or Invalid client credentials"
}
"""
);
, GrantType);
}

[Test]
Expand Down Expand Up @@ -279,7 +282,7 @@ public class Given_An_Error_Occurred_Without_Exception : When_Getting_An_Access_
[SetUp]
public async Task SetUp()
{
await Act(AuthHeader, HttpStatusCode.BadRequest, "{}");
await Act(AuthHeader, HttpStatusCode.BadRequest, "{}", GrantType);
}

[Test]
Expand Down Expand Up @@ -355,5 +358,39 @@ public async Task Then_The_Response_Content_Contains_TraceId()
// This is a rare case where it would be nice to check that the log
// has been accessed - but extension methods cannot be verified.
}

[TestFixture]
public class Given_An_Invalid_Grant_Type : When_Getting_An_Access_Token
{
[SetUp]
public async Task SetUp()
{
string authHeader = "basic abc:123";
await Act(authHeader, "invalid_grant_type");
}

[Test]
public void Then_It_Responds_With_BadRequest()
{
_response.StatusCode.Should().Be(HttpStatusCode.BadRequest);
}

[Test]
public void Then_The_Response_ContentType_Is_Problem_JSON()
{
_response
.Content.Headers.ContentType!.MediaType.Should()
.NotBeNull()
.And.Be("application/problem+json");
}

[Test]
public async Task Then_The_Response_Content_Mentions_Malformed_Header()
{
var content = await _response.Content.ReadAsStringAsync();
content.Should().NotBeNull();
content!.Should().Contain("\"detail\": \"Unsupported grant type\"");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public interface IOAuthManager
{
public Task<HttpResponseMessage> GetAccessTokenAsync(
IHttpClientWrapper httpClient,
string grantType,
string authHeaderString,
string upstreamUri,
TraceId traceId
Expand Down
12 changes: 10 additions & 2 deletions src/dms/core/EdFi.DataManagementService.Core/OAuth/OAuthManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class OAuthManager(ILogger<OAuthManager> logger) : IOAuthManager

public async Task<HttpResponseMessage> GetAccessTokenAsync(
IHttpClientWrapper httpClient,
string grantType,
string authHeaderString,
string upstreamUri,
TraceId traceId
Expand All @@ -37,12 +38,19 @@ TraceId traceId
);
}

if (!grantType.Equals("client_credentials", StringComparison.InvariantCultureIgnoreCase))
{
return GenerateProblemDetailResponse(
HttpStatusCode.BadRequest,
FailureResponse.ForBadRequest("Unsupported grant type", traceId, [], [])
);
}

HttpRequestMessage upstreamRequest = new(HttpMethod.Post, upstreamUri);
upstreamRequest.Headers.Add("Authorization", authHeaderString);

// TODO(DMS-408): Replace hard-coded with forwarded request body.
upstreamRequest.Content = new StringContent(
"grant_type=client_credentials",
$"grant_type={grantType}",
Encoding.UTF8,
"application/x-www-form-urlencoded"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,21 @@ namespace EdFi.DataManagementService.Frontend.AspNetCore.Tests.Unit.Modules;
[TestFixture]
public class TokenEndpointModuleTests
{
public static HttpRequestMessage ProxyRequest(string requestContent, string contentType)
{
var proxyRequest = new HttpRequestMessage(HttpMethod.Post, "/oauth/token");
var encodedCredentials = Convert.ToBase64String(Encoding.UTF8.GetBytes("clientId:clientSecret"));
proxyRequest.Headers.Add("Authorization", $"Basic {encodedCredentials}");
proxyRequest!.Content = new StringContent(
requestContent,
Encoding.UTF8,
contentType
);
return proxyRequest;
}

[TestFixture]
public class When_Posting_To_The_Internal_Token_Endpoint
public class When_Posting_To_The_Internal_Token_Endpoint_With_Json_Content : TokenEndpointModuleTests
{
private JsonNode? _jsonContent;
private HttpResponseMessage? _response;
Expand All @@ -46,6 +59,7 @@ public void SetUp()
A<IHttpClientWrapper>.Ignored,
A<string>.Ignored,
A<string>.Ignored,
A<string>.Ignored,
A<TraceId>.Ignored
)
)
Expand All @@ -61,18 +75,83 @@ public void SetUp()
}
);
});
using var client = factory.CreateClient();
var requestContent = """{"grant_type":"client_credentials"}""";
var proxyRequest = ProxyRequest(requestContent, "application/json");

// Act
_response = client.SendAsync(proxyRequest).GetAwaiter().GetResult();
var content = _response.Content.ReadAsStringAsync().GetAwaiter().GetResult();
_jsonContent = JsonNode.Parse(content) ?? throw new Exception("JSON parsing failed");
}

[TearDown]
public void TearDownAttribute()
{
_response?.Dispose();
}

[Test]
public void Then_it_returns_the_upstream_response_code()
{
_response!.StatusCode.Should().Be(HttpStatusCode.OK);
}

[Test]
public void Then_it_returns_the_upstream_response_body()
{
_jsonContent?["access_token"]?.ToString().Should().Be("fake_access_token");
_jsonContent?["expires_in"]?.Should().Be(300);
_jsonContent?["token_type"]?.ToString().Should().Be("bearer");
}
}

[TestFixture]
public class When_Posting_To_The_Internal_Token_Endpoint_With_Form_Content : TokenEndpointModuleTests
{
private JsonNode? _jsonContent;
private HttpResponseMessage? _response;

[SetUp]
public void SetUp()
{
// Arrange
var oAuthManager = A.Fake<IOAuthManager>();
var json =
"""{"status_code":200, "body":{"token":"fake_access_token","token_type":"bearer","expires_in":300}}""";
JsonNode _fake_responseJson = JsonNode.Parse(json)!;
var _fake_response_200 = new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(_fake_responseJson.ToString(), Encoding.UTF8, "application/json"),
};

A.CallTo(
() =>
oAuthManager.GetAccessTokenAsync(
A<IHttpClientWrapper>.Ignored,
A<string>.Ignored,
A<string>.Ignored,
A<string>.Ignored,
A<TraceId>.Ignored
)
)
.Returns(_fake_response_200);

using var factory = new WebApplicationFactory<Program>().WithWebHostBuilder(builder =>
{
builder.UseEnvironment("Test");
builder.ConfigureServices(
(collection) =>
{
collection.AddTransient((x) => oAuthManager);
}
);
});
using var client = factory.CreateClient();
var proxyRequest = new HttpRequestMessage(HttpMethod.Post, "/oauth/token");
var encodedCredentials = Convert.ToBase64String(Encoding.UTF8.GetBytes("clientId:clientSecret"));
proxyRequest.Headers.Add("Authorization", $"Basic {encodedCredentials}");
var requestContent = "grant_type=client_credentials";
var proxyRequest = ProxyRequest(requestContent, "application/x-www-form-urlencoded");

// Act
proxyRequest!.Content = new StringContent(
"""{"grant_type"="client_credentials"}""",
Encoding.UTF8,
"application/json"
);
_response = client.SendAsync(proxyRequest).GetAwaiter().GetResult();
var content = _response.Content.ReadAsStringAsync().GetAwaiter().GetResult();
_jsonContent = JsonNode.Parse(content) ?? throw new Exception("JSON parsing failed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using EdFi.DataManagementService.Core.OAuth;
using EdFi.DataManagementService.Frontend.AspNetCore.Configuration;
using EdFi.DataManagementService.Frontend.AspNetCore.Infrastructure.Extensions;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options;

namespace EdFi.DataManagementService.Frontend.AspNetCore.Modules;
Expand All @@ -15,16 +16,44 @@ public class TokenEndpointModule : IEndpointModule
{
public void MapEndpoints(IEndpointRouteBuilder endpoints)
{
endpoints.MapPost("/oauth/token", GenerateToken);
endpoints.MapPost("/oauth/token", HandleFormData)
.Accepts<TokenRequest>(contentType: "application/x-www-form-urlencoded")
.DisableAntiforgery();
endpoints.MapPost("/oauth/token", HandleJsonData)
.Accepts<TokenRequest>(contentType: "application/json")
.DisableAntiforgery();
}

internal static async Task GenerateToken(
internal static async Task HandleFormData(
HttpContext httpContext,
[FromForm] TokenRequest tokenRequest,
IOptions<AppSettings> appSettings,
IOAuthManager oAuthManager,
ILogger<TokenEndpointModule> logger,
IHttpClientFactory httpClientFactory
)
{
await GenerateToken(httpContext, tokenRequest, appSettings, oAuthManager, logger, httpClientFactory);
}

internal static async Task HandleJsonData(
HttpContext httpContext,
TokenRequest tokenRequest,
IOptions<AppSettings> appSettings,
IOAuthManager oAuthManager,
ILogger<TokenEndpointModule> logger,
IHttpClientFactory httpClientFactory
)
{
await GenerateToken(httpContext, tokenRequest, appSettings, oAuthManager, logger, httpClientFactory);
}

private static async Task GenerateToken(HttpContext httpContext,
TokenRequest tokenRequest,
IOptions<AppSettings> appSettings,
IOAuthManager oAuthManager,
ILogger<TokenEndpointModule> logger,
IHttpClientFactory httpClientFactory)
{
var traceId = AspNetCoreFrontend.ExtractTraceIdFrom(httpContext.Request, appSettings);
logger.LogInformation(
Expand All @@ -38,6 +67,7 @@ IHttpClientFactory httpClientFactory

var response = await oAuthManager.GetAccessTokenAsync(
client,
tokenRequest.grant_type,
authHeader.ToString(),
appSettings.Value.AuthenticationService,
traceId
Expand All @@ -55,3 +85,8 @@ IHttpClientFactory httpClientFactory
await response.Content.CopyToAsync(httpContext.Response.Body);
}
}

public class TokenRequest
{
public string grant_type { get; set; } = "";
}

0 comments on commit 0ce7020

Please sign in to comment.