Skip to content

Implement a fast lookup for wildcard certificates #2815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<Copyright>© Microsoft Corporation. All rights reserved.</Copyright>
<PackageIcon></PackageIcon>
<PackageIconFullPath></PackageIconFullPath>
<LangVersion>12.0</LangVersion>
<LangVersion>13.0</LangVersion>
<PackageLicenseExpression>MIT</PackageLicenseExpression>
<StrongNameKeyId>Microsoft</StrongNameKeyId>
<EmbedUntrackedSources>true</EmbedUntrackedSources>
Expand Down
121 changes: 121 additions & 0 deletions src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

#nullable enable
namespace Yarp.Kubernetes.Controller.Certificates;

public abstract class ImmutableCertificateCache<TCert> where TCert : class
{
private readonly List<WildCardDomain> _wildCardDomains = new();
private readonly Dictionary<string, TCert> _certificates = new(StringComparer.OrdinalIgnoreCase);

public ImmutableCertificateCache(IEnumerable<TCert> certificates, Func<TCert, IEnumerable<string>> getDomains)
{
foreach (var certificate in certificates)
{
foreach (var domain in getDomains(certificate))
{
if (domain.StartsWith("*."))
{
_wildCardDomains.Add(new (domain[1..], certificate));
}
else
{
_certificates[domain] = certificate;
}
}
}

_wildCardDomains.Sort(DomainNameComparer.Instance);
}



protected abstract TCert? GetDefaultCertificate();

public TCert? GetCertificate(string domain)
{
if (TryGetCertificateExact(domain, out var certificate))
{
return certificate;
}
if (TryGetWildcardCertificate(domain, out certificate))
{
return certificate;
}

return GetDefaultCertificate();
}

protected IReadOnlyList<WildCardDomain> WildcardCertificates => _wildCardDomains;

protected IReadOnlyDictionary<string, TCert> Certificates => _certificates;

protected record struct WildCardDomain(string Domain, TCert? Certificate);

private bool TryGetCertificateExact(string domain, [NotNullWhen(true)] out TCert? certificate) =>
_certificates.TryGetValue(domain, out certificate);

private bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate)
{
if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and > -1)
{
certificate = _wildCardDomains[index].Certificate!;
return true;
}

certificate = null;
return false;
}

/// <summary>
/// Sorts domain names right to left.
/// This allows us to use a Binary Search to achieve a suffix
/// search.
/// </summary>
private class DomainNameComparer : IComparer<WildCardDomain>
{
public static readonly DomainNameComparer Instance = new();

public int Compare(WildCardDomain x, WildCardDomain y)
{
var ret = Compare(x.Domain.AsSpan(), y.Domain.AsSpan());
if (ret != 0)
{
return ret;
}

return (x.Certificate, y.Certificate) switch
{
(null, not null) when x.Domain.Length > y.Domain.Length => 0,
(not null, null) when x.Domain.Length < y.Domain.Length => 0,
_ => x.Domain.Length - y.Domain.Length
};
}

private static int Compare(ReadOnlySpan<char> x, ReadOnlySpan<char> y)
{

var length = Math.Min(x.Length, y.Length);

for (var i = 1; i <= length; i++)
Comment on lines +103 to +105
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the JIT can't elide the bound-check for indexing into the spans.
I think it's better to slice the spans to the common (= min) length and use one span in the loop, in order to elide at least one bound check (the JIT should handle the downward loop (see other comment) in a recent version).

Roughly:

Suggested change
var length = Math.Min(x.Length, y.Length);
for (var i = 1; i <= length; i++)
if (x.Length < y.Length)
{
y = y.Slice(0, x.Length);
}
else if (y.Length < x.Length)
{
x = x.Slice(0, y.Length);
}
for (var i = x.Length; i >= 1; --i)

{
var charA = x[^i] & 0x5F;
var charB = y[^i] & 0x5F;
Comment on lines +107 to +108
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You compare it backwards?
Is it better then, to write the loop backwards and have normal index access?


if (charA == charB)
{
continue;
}

return charB - charA;
}

return 0;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#nullable enable
using System.Collections.Generic;
using System.Formats.Asn1;
using System.Linq;
using System.Security.Cryptography.X509Certificates;

namespace Yarp.Kubernetes.Controller.Certificates;

internal partial class ServerCertificateSelector
{
private class ImmutableX509CertificateCache(IEnumerable<X509Certificate2> certificates)
: ImmutableCertificateCache<X509Certificate2>(certificates, GetDomains)
{
protected override X509Certificate2? GetDefaultCertificate()
{
if (WildcardCertificates.Count != 0)
{
return WildcardCertificates[0].Certificate;
}
return Certificates.Values.FirstOrDefault();
}
}

private static IEnumerable<string> GetDomains(X509Certificate2 certificate)
{
if (certificate.GetNameInfo(X509NameType.DnsName, false) is { } dnsName)
{
yield return dnsName;
}

const string SAN_OID = "2.5.29.17";
var extension = certificate.Extensions[SAN_OID];
if (extension is null)
{
yield break;
}

var dnsNameTag = new Asn1Tag(TagClass.ContextSpecific, tagValue: 2, isConstructed: false);

var asnReader = new AsnReader(extension.RawData, AsnEncodingRules.BER);
var sequenceReader = asnReader.ReadSequence(Asn1Tag.Sequence);
while (sequenceReader.HasData)
{
var tag = sequenceReader.PeekTag();
if (tag != dnsNameTag)
{
sequenceReader.ReadEncodedValue();
continue;
}

var alternativeName = sequenceReader.ReadCharacterString(UniversalTagNumber.IA5String, dnsNameTag);
yield return alternativeName;
}

}



}
Original file line number Diff line number Diff line change
@@ -1,27 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using System.Timers;
using Microsoft.AspNetCore.Connections;
using Microsoft.Extensions.Hosting;

namespace Yarp.Kubernetes.Controller.Certificates;

internal class ServerCertificateSelector : IServerCertificateSelector
internal partial class ServerCertificateSelector
: BackgroundService
, IServerCertificateSelector
{
private X509Certificate2 _defaultCertificate;
private readonly ConcurrentDictionary<NamespacedName, X509Certificate2> _certificates = new();
private bool _hasBeenUpdated;

private ImmutableX509CertificateCache _certificateStore = new(Array.Empty<X509Certificate2>());

public void AddCertificate(NamespacedName certificateName, X509Certificate2 certificate)
{
_defaultCertificate = certificate;
_certificates[certificateName] = certificate;
_hasBeenUpdated = true;
}

public X509Certificate2 GetCertificate(ConnectionContext connectionContext, string domainName)
{
return _defaultCertificate;
return _certificateStore.GetCertificate(domainName);
}

public void RemoveCertificate(NamespacedName certificateName)
{
_defaultCertificate = null;
_ = _certificates.TryRemove(certificateName, out _);
_hasBeenUpdated = true;
}

protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
// Poll every 10 seconds for updates to
while (!stoppingToken.IsCancellationRequested)
{
await Task.Delay(TimeSpan.FromSeconds(10), stoppingToken);
if (_hasBeenUpdated)
{
_hasBeenUpdated = false;
_certificateStore = new ImmutableX509CertificateCache(_certificates.Values);
}
}
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ public static IServiceCollection AddKubernetesControllerRuntime(this IServiceCol
services.RegisterResourceInformer<V1Secret, V1SecretResourceInformer>("type=kubernetes.io/tls");

// Add the Ingress/Secret to certificate management
services.AddSingleton<IServerCertificateSelector, ServerCertificateSelector>();
services.AddSingleton<ServerCertificateSelector>();
services.AddHostedService(x => x.GetRequiredService<ServerCertificateSelector>());
services.AddSingleton<IServerCertificateSelector>(x => x.GetRequiredService<ServerCertificateSelector>());
services.AddSingleton<ICertificateHelper, CertificateHelper>();

// ingress status updater
Expand Down
56 changes: 56 additions & 0 deletions test/Kubernetes.Tests/Certificates/CertificateCacheTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Xunit;
#nullable enable
namespace Yarp.Kubernetes.Controller.Certificates.Tests;

public class CertificateCacheTests
{

private static readonly FakeCertificateCache Cache = new(
new FakeCertificate("Acme", "mail.acme.com", "www.acme.com"),
new FakeCertificate("Initech", "*.initech.com", "initech.com"),
new FakeCertificate("Northwind", "*.northwind.com")
);

[Theory]
[InlineData("www.acme.com", "Acme")]
[InlineData("www.ACME.com", "Acme")]
[InlineData("mail.acme.com", "Acme")]
[InlineData("acme.com", null)]
[InlineData("store.acme.com", null)]
[InlineData("www.northwind.com", "Northwind")]
[InlineData("mail.northwind.com", "Northwind")]
[InlineData("northwind.com", null)]
[InlineData("initech.com", "Initech")]
[InlineData("www.initech.com", "Initech")]
[InlineData("www.IniTech.coM", "Initech")]
public void CertificateConversionFromPem(string requestedDomain, string? expectedCompanyName)
{
var certificate = Cache.GetCertificate(requestedDomain);
if (expectedCompanyName != null)
{
Assert.Equal(expectedCompanyName, certificate?.Name);
}
else
{
Assert.Null(certificate?.Name);
}
}

private record FakeCertificate(string Name, params string[] Domains);

private class FakeCertificateCache(params IEnumerable<FakeCertificate> certificates)
: ImmutableCertificateCache<FakeCertificate>(certificates, static cert => cert.Domains)
{
protected override FakeCertificate? GetDefaultCertificate()
{
return null;
}
}
}



7 changes: 4 additions & 3 deletions test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ public async Task RequestWithCookieHeaders(params string[] cookies)
{
var events = TestEventListener.Collect();


var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "GET";
httpContext.Request.Headers[HeaderNames.Cookie] = cookies;
Expand Down Expand Up @@ -1260,7 +1260,7 @@ public static IEnumerable<object[]> ResponseMultiHeadersData()
{
foreach (var header in ResponseMultiHeaderNames())
{
foreach (var version in new[] { "1.1", "2.0" })
foreach (var version in new[] { "1.1", "2.0" })
{
foreach (var value in MultiValues())
{
Expand Down Expand Up @@ -2567,8 +2567,9 @@ public async Task Response_RemoveProhibitedHeaders(string protocol, string prohi

await sut.SendAsync(httpContext, destinationPrefix, client, new ForwarderRequestConfig { Version = Version.Parse(protocol) });

string[] headers = httpContext.Response.Headers[PreservedHeaderName];
Assert.Equal((int)HttpStatusCode.OK, httpContext.Response.StatusCode);
Assert.Equal(PreservedHeaderValue, string.Join(", ", httpContext.Response.Headers[PreservedHeaderName]));
Assert.Equal(PreservedHeaderValue, string.Join(", ", headers));

foreach (var (name, _) in prohibitedHeaders)
{
Expand Down