Skip to content

Commit

Permalink
feat: added refinement service
Browse files Browse the repository at this point in the history
* Should improve compile time refinement performance
* TODO need to add many more tests around that compile time refinement
  • Loading branch information
bmazzarol committed Jan 4, 2025
1 parent a258eb3 commit 0ac08d4
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 68 deletions.
122 changes: 60 additions & 62 deletions Tuxedo.SourceGenerator/Analysers/InvalidConstAssignmentAnalyser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class InvalidConstAssignmentAnalyser : DiagnosticAnalyzer
/// <inheritdoc />
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => [Rule];

private static Lazy<Assembly?>? _currentAssembly;
private static Lazy<Type?>? _refinementServiceType;

/// <inheritdoc />
public override void Initialize(AnalysisContext context)
Expand All @@ -50,47 +50,56 @@ public override void Initialize(AnalysisContext context)

// we build an in memory assembly for the compilation so we can run
// the refinement predicates
if (_currentAssembly is null)
if (_refinementServiceType is null)
{
#pragma warning disable S2696
_currentAssembly = new Lazy<Assembly?>(() =>
_refinementServiceType = new Lazy<Type?>(
#pragma warning restore S2696
{
using var ms = new MemoryStream();
var result = compilationContext.Compilation.Emit(
ms,
cancellationToken: compilationContext.CancellationToken
);

if (!result.Success)
{
return null;
}

ms.Seek(0, SeekOrigin.Begin);
#pragma warning disable RS1035 // we need to load it here, as we want to run code against the current state of the compilation.
return Assembly.Load(ms.ToArray());
#pragma warning restore RS1035
});
() =>
BuildIntoRuntimeRefinementServiceType(
compilationContext.Compilation,
compilationContext.CancellationToken
)
);
}

if (_currentAssembly.Value is not { } analysedAssembly)
if (_refinementServiceType.Value is not { } refinementServiceType)
{
return;
}

compilationContext.RegisterOperationAction(
analysisContext => Analyze(analysisContext, analysedAssembly),
analysisContext => Analyze(analysisContext, refinementServiceType),
OperationKind.Invocation
);
});
}

private static readonly ConcurrentDictionary<string, MethodInfo?> TryParseDelegates =
private static Type? BuildIntoRuntimeRefinementServiceType(
Compilation compilation,
CancellationToken token
)
{
using var ms = new MemoryStream();
var result = compilation.Emit(ms, cancellationToken: token);

if (!result.Success)
{
return null;
}

ms.Seek(0, SeekOrigin.Begin);
#pragma warning disable RS1035 // we need to load it here, as we want to run code against the current state of the compilation
var assembly = Assembly.Load(ms.ToArray());
#pragma warning restore RS1035
return assembly.GetType("Tuxedo.RefinementService");
}

private static readonly ConcurrentDictionary<string, Func<object, string?>?> TryParseDelegates =
new(StringComparer.Ordinal);

[SuppressMessage("Design", "MA0051:Method is too long")]
private static void Analyze(OperationAnalysisContext ctx, Assembly assembly)
private static void Analyze(OperationAnalysisContext ctx, Type refinementServiceType)
{
var operation = (IInvocationOperation)ctx.Operation;
if (
Expand All @@ -111,56 +120,25 @@ private static void Analyze(OperationAnalysisContext ctx, Assembly assembly)
return;
}

var fullName = operation.TargetMethod.ContainingType.ToDisplayString();

try
{
var tryParseDelegate = TryParseDelegates.GetOrAdd(
fullName,
fn =>
{
// we try and get the refined type out of the assembly
var type = assembly.GetType(fn);

if (type == null)
{
return null;
}

// now we get the static method and call it to find out if the
// constant would pass the refinement predicate
var methodInfo = type.GetMethod(
"TryParse",
BindingFlags.Public | BindingFlags.Static
);

if (methodInfo == null)
{
return null;
}

#pragma warning disable MA0026, S1135
return methodInfo; // TODO: use expressions to make this faster
#pragma warning restore S1135, MA0026
}
// build a refinement delegate and run the constant value against it
var refinementDelegate = TryParseDelegates.GetOrAdd(
operation.TargetMethod.ContainingType.ToDisplayString(),
s => BuildRefinementDelegate(s, refinementServiceType)
);

if (tryParseDelegate is null)
{
return;
}
var failureMessage = refinementDelegate?.Invoke(constantValue);

// out parameters
object?[] parameters = [constantValue, null, null];
if (tryParseDelegate.Invoke(null, parameters) is true)
if (failureMessage is null)
{
return;
}

var diagnostic = Diagnostic.Create(
descriptor: Rule,
location: ctx.Operation.Syntax.GetLocation(),
messageArgs: parameters[2]
messageArgs: failureMessage
);

ctx.ReportDiagnostic(diagnostic);
Expand All @@ -170,4 +148,24 @@ private static void Analyze(OperationAnalysisContext ctx, Assembly assembly)
// we just give up, we tried
}
}

private static Func<object, string?>? BuildRefinementDelegate(
string fn,
Type refinementServiceType
)
{
var methodInfo = refinementServiceType.GetMethod(
$"TestAgainst{fn.RemoveNamespace()}",
#pragma warning disable S3011
BindingFlags.NonPublic | BindingFlags.Static
#pragma warning restore S3011
);

if (methodInfo == null)
{
return null;
}

return (Func<object, string?>?)methodInfo.CreateDelegate(typeof(Func<object, string?>));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ private readonly record struct RefinedTypeDetails(
)
{
public string? RefinedTypeXmlSafeName => (RefinedType + Generics).EscapeXml();

public bool IsTuple =>
RawType?.StartsWith("(", StringComparison.Ordinal) == true
&& RawType.EndsWith(")", StringComparison.Ordinal);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System.Collections.Immutable;
using Tuxedo.SourceGenerator.Extensions;

namespace Tuxedo.SourceGenerator;

public sealed partial class RefinementSourceGenerator
{
private static string RenderRefinementService(
ImmutableArray<RefinedTypeDetails> refinedTypeDetails
)
{
return $$"""
// <auto-generated/>
#nullable enable
namespace Tuxedo;
/// <summary>
/// Provides compile time support for running refinement predicates against compile time known values
/// </summary>
/// <remarks>
/// This is for compile time use and should not be used in application code
/// </remarks>
internal static class RefinementService
{
{{string.Join("\n\n", refinedTypeDetails.Select(RenderTestMethod))}}
}
""";
}

private static string RenderTestMethod(RefinedTypeDetails model)
{
return $$"""
private static string? TestAgainst{{model.RefinedType}}{{model.Generics}}(object value){{model.GenericConstraints.PrependIfNotNull(
"\n\t\t"
)}}
{
return {{(model.IsTuple
? $"!{model.Namespace}.{model.RefinedType}{model.Generics}.TryParse(({model.RawType})value, out _, out var errorMessage) ? errorMessage : null"
: $"value is {model.RawType} rt && !{model.Namespace}.{model.RefinedType}{model.Generics}.TryParse(rt, out _, out var errorMessage) ? errorMessage : null")}};
}
""";
}
}
20 changes: 18 additions & 2 deletions Tuxedo.SourceGenerator/Generators/RefinementSourceGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma warning disable MA0051

using System.Collections.Immutable;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -31,6 +32,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
);

context.RegisterSourceOutput(refinedTypeDetailsProvider, GenerateRefinedTypes);
context.RegisterSourceOutput(
refinedTypeDetailsProvider.Collect(),
GenerateRefinementService
);
}

private static bool IsRefinementMethod(SyntaxNode s, CancellationToken cancellationToken)
Expand Down Expand Up @@ -138,8 +143,8 @@ out string? constraints
}

generics = $"<{string.Join(", ", genericTypeArguments.Select(t => t.ToDisplayString()))}>";
var parts = methodDeclaration.ConstraintClauses.Select(x => x.ToString());
constraints = string.Join("\n", parts);
var parts = methodDeclaration.ConstraintClauses.Select(x => x.ToString()).ToArray();
constraints = parts.Length > 0 ? string.Join("\n", parts) : null;
}

private static string BuildSafeStructName(string name, string? parameterType)
Expand All @@ -161,4 +166,15 @@ RefinedTypeDetails refinedTypeDetails
SourceText.From(source, Encoding.UTF8)
);
}

private static void GenerateRefinementService(
SourceProductionContext context,
ImmutableArray<RefinedTypeDetails> refinedTypeDetails
)
{
context.AddSource(
"RefinementService.g.cs",
SourceText.From(RenderRefinementService(refinedTypeDetails), Encoding.UTF8)
);
}
}
8 changes: 8 additions & 0 deletions Tuxedo.Tests/BoolRefinementsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ namespace Tuxedo;
[AttributeUsage(AttributeTargets.Struct)]
internal sealed class RefinedTypeAttribute : Attribute {}
internal static class RefinementService
{
private static string? TestAgainstTrueBool(object value)
{
return value is bool rt && !Tuxedo.TrueBool.TryParse(rt, out _, out var errorMessage) ? errorMessage : null;
}
}
[RefinedType]
internal readonly partial struct TrueBool
{
Expand Down
5 changes: 4 additions & 1 deletion Tuxedo.Tests/Extensions/GeneratorDriverExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ internal static class Test
""".BuildDriver();
return Verify(driver, sourceFile: sourceFile)
.IgnoreGeneratedResult(x =>
x.HintName is "RefinementAttribute.g.cs" or "RefinedTypeAttribute.g.cs"
x.HintName
is "RefinementAttribute.g.cs"
or "RefinedTypeAttribute.g.cs"
or "RefinementService.g.cs"
);
}
}
35 changes: 35 additions & 0 deletions Tuxedo.Tests/SharedTests.Case1#RefinementService.g.verified.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//HintName: RefinementService.g.cs
// <auto-generated/>
#nullable enable

namespace Tuxedo;

/// <summary>
/// Provides compile time support for running refinement predicates against compile time known values
/// </summary>
/// <remarks>
/// This is for compile time use and should not be used in application code
/// </remarks>
internal static class RefinementService
{
private static string? TestAgainstTest1(object value)
{
return value is bool rt && !<global namespace>.Test1.TryParse(rt, out _, out var errorMessage) ? errorMessage : null;
}

private static string? TestAgainstTest2(object value)
{
return !<global namespace>.Test2.TryParse(((int a, int b))value, out _, out var errorMessage) ? errorMessage : null;
}

private static string? TestAgainstTest3<T>(object value)
{
return value is List<T> rt && !<global namespace>.Test3<T>.TryParse(rt, out _, out var errorMessage) ? errorMessage : null;
}

private static string? TestAgainstTest4<T>(object value)
where T: struct
{
return value is List<T> rt && !<global namespace>.Test4<T>.TryParse(rt, out _, out var errorMessage) ? errorMessage : null;
}
}
23 changes: 20 additions & 3 deletions Tuxedo.Tests/SharedTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,27 @@ namespace Tuxedo.Tests;

public sealed class SharedTests
{
[Fact(DisplayName = "A")]
[Fact(DisplayName = "All shared code renders correctly")]
public Task Case1()
{
var driver = GeneratorDriverExtensions.BuildDriver(source: null);
return Verify(driver);
var driver = """
using Tuxedo;
internal static class Test
{
[Refinement("test1", Name = "Test1")]
internal static bool Pred1(bool value) => !value;
[Refinement("test2", Name = "Test2")]
internal static bool Pred2((int a, int b) value) => true;
[Refinement("test2", Name = "Test3")]
internal static bool Pred3<T>(List<T> value) => true;
[Refinement("test2", Name = "Test4")]
internal static bool Pred4<T>(List<T> value) where T: struct => true;
}
""".BuildDriver();
return Verify(driver).IgnoreGeneratedResult(result => result.HintName.StartsWith("Test"));
}
}

0 comments on commit 0ac08d4

Please sign in to comment.