Skip to content

Commit

Permalink
refactor: split startup state into a state for awaiting renewal and a…
Browse files Browse the repository at this point in the history
…nother for beginning the creation of a certificate
  • Loading branch information
natemcmaster committed Jun 2, 2020
1 parent 4c68d81 commit 93926b8
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 149 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Nate McMaster.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using LettuceEncrypt.Internal.IO;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace LettuceEncrypt.Internal.AcmeStates
{
internal class BeginCertificateCreationState : AcmeState
{
private readonly ILogger<ServerStartupState> _logger;
private readonly IOptions<LettuceEncryptOptions> _options;
private readonly AcmeCertificateFactory _acmeCertificateFactory;
private readonly CertificateSelector _selector;
private readonly IEnumerable<ICertificateRepository> _certificateRepositories;
private readonly IClock _clock;

public BeginCertificateCreationState(AcmeStateMachineContext context, ILogger<ServerStartupState> logger,
IOptions<LettuceEncryptOptions> options, AcmeCertificateFactory acmeCertificateFactory,
CertificateSelector selector, IEnumerable<ICertificateRepository> certificateRepositories,
IClock clock) : base(context)
{
_logger = logger;
_options = options;
_acmeCertificateFactory = acmeCertificateFactory;
_selector = selector;
_certificateRepositories = certificateRepositories;
_clock = clock;
}

public override async Task<IAcmeState> MoveNextAsync(CancellationToken cancellationToken)
{
var domainNames = _options.Value.DomainNames;

try
{
var account = await _acmeCertificateFactory.GetOrCreateAccountAsync(cancellationToken);
_logger.LogInformation("Using account {accountId}", account.Id);

_logger.LogInformation("Creating certificate for {hostname}",
string.Join(",", domainNames));

var cert = await _acmeCertificateFactory.CreateCertificateAsync(cancellationToken);

_logger.LogInformation("Created certificate {subjectName} ({thumbprint})",
cert.Subject,
cert.Thumbprint);

await SaveCertificateAsync(cert, cancellationToken);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to automatically create a certificate for {hostname}", domainNames);
throw;
}

return MoveTo<CheckForRenewalState>();
}

private async Task SaveCertificateAsync(X509Certificate2 cert, CancellationToken cancellationToken)
{
_selector.Add(cert);

var saveTasks = new List<Task>
{
Task.Delay(TimeSpan.FromMinutes(5), cancellationToken)
};

var errors = new List<Exception>();
foreach (var repo in _certificateRepositories)
{
try
{
saveTasks.Add(repo.SaveAsync(cert, cancellationToken));
}
catch (Exception ex)
{
// synchronous saves may fail immediately
errors.Add(ex);
}
}

await Task.WhenAll(saveTasks);

if (errors.Count > 0)
{
throw new AggregateException("Failed to save cert to repositories", errors);
}
}
}
}
68 changes: 68 additions & 0 deletions src/LettuceEncrypt/Internal/AcmeStates/CheckForRenewalState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Nate McMaster.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Threading;
using System.Threading.Tasks;
using LettuceEncrypt.Internal.IO;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace LettuceEncrypt.Internal.AcmeStates
{
class CheckForRenewalState : AcmeState
{
private readonly ILogger<CheckForRenewalState> _logger;
private readonly IOptions<LettuceEncryptOptions> _options;
private readonly CertificateSelector _selector;
private readonly IClock _clock;

public CheckForRenewalState(
AcmeStateMachineContext context,
ILogger<CheckForRenewalState> logger,
IOptions<LettuceEncryptOptions> options,
CertificateSelector selector,
IClock clock) : base(context)
{
_logger = logger;
_options = options;
_selector = selector;
_clock = clock;
}

public override async Task<IAcmeState> MoveNextAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
var checkPeriod = _options.Value.RenewalCheckPeriod;
var daysInAdvance = _options.Value.RenewDaysInAdvance;
if (!checkPeriod.HasValue || !daysInAdvance.HasValue)
{
_logger.LogInformation("Automatic certificate renewal is not configured. Stopping {service}",
nameof(AcmeCertificateLoader));
return MoveTo<TerminalState>();
}

await Task.Delay(checkPeriod.Value, cancellationToken);

var domainNames = _options.Value.DomainNames;
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Checking certificates' renewals for {hostname}",
string.Join(", ", domainNames));
}

foreach (var domainName in domainNames)
{
if (!_selector.TryGet(domainName, out var cert)
|| cert == null
|| cert.NotAfter <= _clock.Now.DateTime + daysInAdvance.Value)
{
return MoveTo<BeginCertificateCreationState>();
}
}
}

return MoveTo<TerminalState>();
}
}
}
156 changes: 8 additions & 148 deletions src/LettuceEncrypt/Internal/AcmeStates/ServerStartupState.cs
Original file line number Diff line number Diff line change
@@ -1,181 +1,41 @@
// Copyright (c) Nate McMaster.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using LettuceEncrypt.Internal.IO;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace LettuceEncrypt.Internal.AcmeStates
{
internal class ServerStartupState : AcmeState
internal class ServerStartupState : SyncAcmeState
{
private readonly ILogger<ServerStartupState> _logger;
private readonly IOptions<LettuceEncryptOptions> _options;
private readonly AcmeCertificateFactory _acmeCertificateFactory;
private readonly CertificateSelector _selector;
private readonly IEnumerable<ICertificateRepository> _certificateRepositories;
private readonly IClock _clock;
private readonly ILogger<ServerStartupState> _logger;

public ServerStartupState(
AcmeStateMachineContext context,
ILogger<ServerStartupState> logger,
IOptions<LettuceEncryptOptions> options,
AcmeCertificateFactory acmeCertificateFactory,
CertificateSelector selector,
IEnumerable<ICertificateRepository> certificateRepositories,
IClock clock)
: base(context)
ILogger<ServerStartupState> logger) :
base(context)
{
_logger = logger;
_options = options;
_acmeCertificateFactory = acmeCertificateFactory;
_selector = selector;
_certificateRepositories = certificateRepositories;
_clock = clock;
}

private const string ErrorMessage = "Failed to create certificate";

public override async Task<IAcmeState> MoveNextAsync(CancellationToken cancellationToken)
{
try
{
await LoadCerts(cancellationToken);
}
catch (AggregateException ex) when (ex.InnerException != null)
{
_logger.LogError(0, ex.InnerException, ErrorMessage);
}
catch (Exception ex)
{
_logger.LogError(0, ex, ErrorMessage);
}

await MonitorRenewal(cancellationToken);
return this;
_logger = logger;
}

private async Task LoadCerts(CancellationToken cancellationToken)
public override IAcmeState MoveNext()
{
cancellationToken.ThrowIfCancellationRequested();

var domainNames = _options.Value.DomainNames;
var hasCertForAllDomains = domainNames.All(_selector.HasCertForDomain);
if (hasCertForAllDomains)
{
_logger.LogDebug("Certificate for {domainNames} already found.", domainNames);
return;
return MoveTo<CheckForRenewalState>();
}

await CreateCertificateAsync(domainNames, cancellationToken);
}

private async Task CreateCertificateAsync(string[] domainNames, CancellationToken cancellationToken)
{
var account = await _acmeCertificateFactory.GetOrCreateAccountAsync(cancellationToken);
_logger.LogInformation("Using account {accountId}", account.Id);

try
{
_logger.LogInformation("Creating certificate for {hostname}",
string.Join(",", domainNames));

var cert = await _acmeCertificateFactory.CreateCertificateAsync(cancellationToken);

_logger.LogInformation("Created certificate {subjectName} ({thumbprint})",
cert.Subject,
cert.Thumbprint);

await SaveCertificateAsync(cert, cancellationToken);
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to automatically create a certificate for {hostname}", domainNames);
throw;
}
}

private async Task SaveCertificateAsync(X509Certificate2 cert, CancellationToken cancellationToken)
{
_selector.Add(cert);

var saveTasks = new List<Task>
{
Task.Delay(TimeSpan.FromMinutes(5), cancellationToken)
};

var errors = new List<Exception>();
foreach (var repo in _certificateRepositories)
{
try
{
saveTasks.Add(repo.SaveAsync(cert, cancellationToken));
}
catch (Exception ex)
{
// synchronous saves may fail immediately
errors.Add(ex);
}
}

await Task.WhenAll(saveTasks);

if (errors.Count > 0)
{
throw new AggregateException("Failed to save cert to repositories", errors);
}
}

private async Task MonitorRenewal(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
var checkPeriod = _options.Value.RenewalCheckPeriod;
var daysInAdvance = _options.Value.RenewDaysInAdvance;
if (!checkPeriod.HasValue || !daysInAdvance.HasValue)
{
_logger.LogInformation("Automatic certificate renewal is not configured. Stopping {service}",
nameof(AcmeCertificateLoader));
return;
}

await Task.Delay(checkPeriod.Value, cancellationToken);

try
{
var domainNames = _options.Value.DomainNames;
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Checking certificates' renewals for {hostname}",
string.Join(", ", domainNames));
}

foreach (var domainName in domainNames)
{
if (!_selector.TryGet(domainName, out var cert)
|| cert == null
|| cert.NotAfter <= _clock.Now.DateTime + daysInAdvance.Value)
{
await CreateCertificateAsync(domainNames, cancellationToken);
break;
}
}
}
catch (AggregateException ex) when (ex.InnerException != null)
{
_logger.LogError(0, ex.InnerException, ErrorMessage);
}
catch (Exception ex)
{
_logger.LogError(0, ex, ErrorMessage);
}
}
return MoveTo<BeginCertificateCreationState>();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ public static ILettuceEncryptServiceBuilder AddLettuceEncrypt(this IServiceColle
services.AddSingleton(TerminalState.Singleton);

// States should always be transient
services.AddTransient<ServerStartupState>();
services
.AddTransient<ServerStartupState>()
.AddTransient<CheckForRenewalState>()
.AddTransient<BeginCertificateCreationState>();

return new LettuceEncryptServiceBuilder(services);
}
Expand Down

0 comments on commit 93926b8

Please sign in to comment.