From e7d298f0f00d4f70adac1f5fcabc6352cf54b514 Mon Sep 17 00:00:00 2001 From: Volodymyr Dombrovskyi Date: Sat, 22 Jun 2024 21:49:40 +0200 Subject: [PATCH] extract BaseIdGenerator --- StrongTypeIdGenerator/BaseIdGenerator.cs | 118 +++++++++++++++++++++ StrongTypeIdGenerator/GuidIdGenerator.cs | 93 +--------------- StrongTypeIdGenerator/StringIdGenerator.cs | 90 +--------------- 3 files changed, 128 insertions(+), 173 deletions(-) create mode 100644 StrongTypeIdGenerator/BaseIdGenerator.cs diff --git a/StrongTypeIdGenerator/BaseIdGenerator.cs b/StrongTypeIdGenerator/BaseIdGenerator.cs new file mode 100644 index 0000000..86d3489 --- /dev/null +++ b/StrongTypeIdGenerator/BaseIdGenerator.cs @@ -0,0 +1,118 @@ +namespace StrongTypeIdGenerator.Analyzer +{ + using Microsoft.CodeAnalysis; + using Microsoft.CodeAnalysis.CSharp; + using Microsoft.CodeAnalysis.CSharp.Syntax; + using System; + using System.Collections.Immutable; + + public abstract class BaseIdGenerator : IIncrementalGenerator + { + public void Initialize(IncrementalGeneratorInitializationContext context) + { + IncrementalValuesProvider classDeclarations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (s, _) => IsSyntaxTargetForGeneration(s), + transform: (ctx, _) => GetSemanticTargetForGeneration(ctx)) + .Where(static m => m is not null)!; + + IncrementalValueProvider<(Compilation, ImmutableArray)> compilationAndClasses = + context.CompilationProvider.Combine(classDeclarations.Collect()); + + context.RegisterSourceOutput(compilationAndClasses, + (spc, source) => Execute(source.Item1, source.Item2, spc)); + } + + protected static bool IsSyntaxTargetForGeneration(SyntaxNode node) => + node is ClassDeclarationSyntax classDeclaration && + classDeclaration.AttributeLists.Count > 0; + + private ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context) + { + var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node; + + foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists) + { + foreach (var attributeSyntax in attributeListSyntax.Attributes) + { + if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol) + { + if (attributeSymbol.ContainingType.ToDisplayString() == MarkerAttributeFullName) + { + return classDeclarationSyntax; + } + } + } + } + + return null; + } + + protected abstract string MarkerAttributeFullName { get; } + + protected abstract INamedTypeSymbol GetIdTypeSymbol(Compilation compilation); + + protected abstract void Execute(Compilation compilation, ImmutableArray classes, SourceProductionContext context); + + protected static string? GetNamespace(ClassDeclarationSyntax classDeclaration) + { + if (classDeclaration is null) + { + throw new ArgumentNullException(nameof(classDeclaration)); + } + + // Walk upwards until we find the namespace declaration + SyntaxNode? potentialNamespaceParent = classDeclaration.Parent; + while (potentialNamespaceParent != null && + potentialNamespaceParent is not NamespaceDeclarationSyntax && + potentialNamespaceParent is not FileScopedNamespaceDeclarationSyntax) + { + potentialNamespaceParent = potentialNamespaceParent.Parent; + } + + // Return the namespace name if it was found, otherwise null + if (potentialNamespaceParent is BaseNamespaceDeclarationSyntax namespaceDeclaration) + { + return namespaceDeclaration.Name.ToString(); + } + + return null; + } + + protected bool HasCheckValueMethod(Compilation compilation, ClassDeclarationSyntax classDeclaration) + { + if (compilation is null) + { + throw new ArgumentNullException(nameof(compilation)); + } + + if (classDeclaration is null) + { + throw new ArgumentNullException(nameof(classDeclaration)); + } + + var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree); + var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration) as INamedTypeSymbol; + if (classSymbol is null) + { + return false; + } + + var idTypeSymbol = GetIdTypeSymbol(compilation); + + foreach (var member in classSymbol.GetMembers("CheckValue")) + { + if (member is IMethodSymbol methodSymbol && + methodSymbol.IsStatic && + methodSymbol.DeclaredAccessibility == Accessibility.Private && + methodSymbol.Parameters.Length == 1 && + SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[0].Type, idTypeSymbol)) + { + return true; + } + } + + return false; + } + } +} \ No newline at end of file diff --git a/StrongTypeIdGenerator/GuidIdGenerator.cs b/StrongTypeIdGenerator/GuidIdGenerator.cs index 57e1fba..760c1d8 100644 --- a/StrongTypeIdGenerator/GuidIdGenerator.cs +++ b/StrongTypeIdGenerator/GuidIdGenerator.cs @@ -4,57 +4,20 @@ namespace StrongTypeIdGenerator.Analyzer using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; - using System; using System.Collections.Immutable; using System.Text; [Generator] - public sealed class GuidIdGenerator : IIncrementalGenerator + public sealed class GuidIdGenerator : BaseIdGenerator { - public void Initialize(IncrementalGeneratorInitializationContext context) - { - IncrementalValuesProvider classDeclarations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: static (s, _) => IsSyntaxTargetForGeneration(s), - transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)) - .Where(static m => m is not null)!; - - // Combine the selected class declarations with the compilation - IncrementalValueProvider<(Compilation, ImmutableArray)> compilationAndClasses = - context.CompilationProvider.Combine(classDeclarations.Collect()); - - // Generate the source code - context.RegisterSourceOutput(compilationAndClasses, - static (spc, source) => Execute(source.Item1, source.Item2, spc)); - } + protected override string MarkerAttributeFullName => "StrongTypeIdGenerator.GuidIdAttribute"; - static bool IsSyntaxTargetForGeneration(SyntaxNode node) => - node is ClassDeclarationSyntax classDeclaration && - classDeclaration.AttributeLists.Count > 0; - - static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context) + protected override INamedTypeSymbol GetIdTypeSymbol(Compilation compilation) { - var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node; - - // Check if the class has the GuidIdAttribute - foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists) - { - foreach (var attributeSyntax in attributeListSyntax.Attributes) - { - if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol) - { - if (attributeSymbol.ContainingType.ToDisplayString() == "StrongTypeIdGenerator.GuidIdAttribute") - { - return classDeclarationSyntax; - } - } - } - } - - return null; + return compilation.GetTypeByMetadataName("System.Guid")!; } - static void Execute(Compilation compilation, ImmutableArray classes, SourceProductionContext context) + protected override void Execute(Compilation compilation, ImmutableArray classes, SourceProductionContext context) { if (classes.IsDefaultOrEmpty) { @@ -73,52 +36,6 @@ static void Execute(Compilation compilation, ImmutableArray classDeclarations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: static (s, _) => IsSyntaxTargetForGeneration(s), - transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)) - .Where(static m => m is not null)!; - - // Combine the selected class declarations with the compilation - IncrementalValueProvider<(Compilation, ImmutableArray)> compilationAndClasses = - context.CompilationProvider.Combine(classDeclarations.Collect()); - - // Generate the source code - context.RegisterSourceOutput(compilationAndClasses, - static (spc, source) => Execute(source.Item1, source.Item2, spc)); - } - - static bool IsSyntaxTargetForGeneration(SyntaxNode node) => - node is ClassDeclarationSyntax classDeclaration && - classDeclaration.AttributeLists.Count > 0; + protected override string MarkerAttributeFullName => "StrongTypeIdGenerator.StringIdAttribute"; - static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context) + protected override INamedTypeSymbol GetIdTypeSymbol(Compilation compilation) { - var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node; - - // Check if the class has the StrongTypeId attribute - foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists) - { - foreach (var attributeSyntax in attributeListSyntax.Attributes) - { - if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol) - { - if (attributeSymbol.ContainingType.ToDisplayString() == "StrongTypeIdGenerator.StringIdAttribute") - { - return classDeclarationSyntax; - } - } - } - } - - return null; + return compilation.GetTypeByMetadataName("System.String")!; } - static void Execute(Compilation compilation, ImmutableArray classes, SourceProductionContext context) + protected override void Execute(Compilation compilation, ImmutableArray classes, SourceProductionContext context) { if (classes.IsDefaultOrEmpty) { @@ -75,25 +38,6 @@ static void Execute(Compilation compilation, ImmutableArray