Skip to content

Commit

Permalink
extract BaseIdGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
dombrovsky committed Jun 22, 2024
1 parent 2746c54 commit e7d298f
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 173 deletions.
118 changes: 118 additions & 0 deletions StrongTypeIdGenerator/BaseIdGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
namespace StrongTypeIdGenerator.Analyzer
{
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Immutable;

public abstract class BaseIdGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
IncrementalValuesProvider<ClassDeclarationSyntax> classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null)!;

IncrementalValueProvider<(Compilation, ImmutableArray<ClassDeclarationSyntax>)> compilationAndClasses =
context.CompilationProvider.Combine(classDeclarations.Collect());

context.RegisterSourceOutput(compilationAndClasses,
(spc, source) => Execute(source.Item1, source.Item2, spc));
}

protected static bool IsSyntaxTargetForGeneration(SyntaxNode node) =>
node is ClassDeclarationSyntax classDeclaration &&
classDeclaration.AttributeLists.Count > 0;

private ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node;

foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists)
{
foreach (var attributeSyntax in attributeListSyntax.Attributes)
{
if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol)
{
if (attributeSymbol.ContainingType.ToDisplayString() == MarkerAttributeFullName)
{
return classDeclarationSyntax;
}
}
}
}

return null;
}

protected abstract string MarkerAttributeFullName { get; }

protected abstract INamedTypeSymbol GetIdTypeSymbol(Compilation compilation);

protected abstract void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context);

protected static string? GetNamespace(ClassDeclarationSyntax classDeclaration)
{
if (classDeclaration is null)
{
throw new ArgumentNullException(nameof(classDeclaration));
}

// Walk upwards until we find the namespace declaration
SyntaxNode? potentialNamespaceParent = classDeclaration.Parent;
while (potentialNamespaceParent != null &&
potentialNamespaceParent is not NamespaceDeclarationSyntax &&
potentialNamespaceParent is not FileScopedNamespaceDeclarationSyntax)
{
potentialNamespaceParent = potentialNamespaceParent.Parent;
}

// Return the namespace name if it was found, otherwise null
if (potentialNamespaceParent is BaseNamespaceDeclarationSyntax namespaceDeclaration)
{
return namespaceDeclaration.Name.ToString();
}

return null;
}

protected bool HasCheckValueMethod(Compilation compilation, ClassDeclarationSyntax classDeclaration)
{
if (compilation is null)
{
throw new ArgumentNullException(nameof(compilation));
}

if (classDeclaration is null)
{
throw new ArgumentNullException(nameof(classDeclaration));
}

var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration) as INamedTypeSymbol;
if (classSymbol is null)
{
return false;
}

var idTypeSymbol = GetIdTypeSymbol(compilation);

foreach (var member in classSymbol.GetMembers("CheckValue"))
{
if (member is IMethodSymbol methodSymbol &&
methodSymbol.IsStatic &&
methodSymbol.DeclaredAccessibility == Accessibility.Private &&
methodSymbol.Parameters.Length == 1 &&
SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[0].Type, idTypeSymbol))
{
return true;
}
}

return false;
}
}
}
93 changes: 5 additions & 88 deletions StrongTypeIdGenerator/GuidIdGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,20 @@ namespace StrongTypeIdGenerator.Analyzer
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Immutable;
using System.Text;

[Generator]
public sealed class GuidIdGenerator : IIncrementalGenerator
public sealed class GuidIdGenerator : BaseIdGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
IncrementalValuesProvider<ClassDeclarationSyntax> classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null)!;

// Combine the selected class declarations with the compilation
IncrementalValueProvider<(Compilation, ImmutableArray<ClassDeclarationSyntax>)> compilationAndClasses =
context.CompilationProvider.Combine(classDeclarations.Collect());

// Generate the source code
context.RegisterSourceOutput(compilationAndClasses,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
}
protected override string MarkerAttributeFullName => "StrongTypeIdGenerator.GuidIdAttribute";

static bool IsSyntaxTargetForGeneration(SyntaxNode node) =>
node is ClassDeclarationSyntax classDeclaration &&
classDeclaration.AttributeLists.Count > 0;

static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
protected override INamedTypeSymbol GetIdTypeSymbol(Compilation compilation)
{
var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node;

// Check if the class has the GuidIdAttribute
foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists)
{
foreach (var attributeSyntax in attributeListSyntax.Attributes)
{
if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol)
{
if (attributeSymbol.ContainingType.ToDisplayString() == "StrongTypeIdGenerator.GuidIdAttribute")
{
return classDeclarationSyntax;
}
}
}
}

return null;
return compilation.GetTypeByMetadataName("System.Guid")!;
}

static void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context)
protected override void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context)
{
if (classes.IsDefaultOrEmpty)
{
Expand All @@ -73,52 +36,6 @@ static void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSynt
}
}

static string? GetNamespace(ClassDeclarationSyntax classDeclaration)
{
// Walk upwards until we find the namespace declaration
SyntaxNode? potentialNamespaceParent = classDeclaration.Parent;
while (potentialNamespaceParent != null &&
potentialNamespaceParent is not NamespaceDeclarationSyntax &&
potentialNamespaceParent is not FileScopedNamespaceDeclarationSyntax)
{
potentialNamespaceParent = potentialNamespaceParent.Parent;
}

// Return the namespace name if it was found, otherwise null
if (potentialNamespaceParent is BaseNamespaceDeclarationSyntax namespaceDeclaration)
{
return namespaceDeclaration.Name.ToString();
}

return null;
}

static bool HasCheckValueMethod(Compilation compilation, ClassDeclarationSyntax classDeclaration)
{
var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration) as INamedTypeSymbol;
if (classSymbol is null)
{
return false;
}

var guidTypeSymbol = compilation.GetTypeByMetadataName("System.Guid");

foreach (var member in classSymbol.GetMembers("CheckValue"))
{
if (member is IMethodSymbol methodSymbol &&
methodSymbol.IsStatic &&
methodSymbol.DeclaredAccessibility == Accessibility.Private &&
methodSymbol.Parameters.Length == 1 &&
SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[0].Type, guidTypeSymbol))
{
return true;
}
}

return false;
}

static string GenerateStrongTypeIdClass(string? namespaceName, string className, bool hasCheckValueMethod)
{
const string TIdentifier = "Guid";
Expand Down
90 changes: 5 additions & 85 deletions StrongTypeIdGenerator/StringIdGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,53 +8,16 @@ namespace StrongTypeIdGenerator.Analyzer
using System.Text;

[Generator]
public sealed class StringIdGenerator : IIncrementalGenerator
public sealed class StringIdGenerator : BaseIdGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Find all classes with the StringIdentityAttribute
IncrementalValuesProvider<ClassDeclarationSyntax> classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null)!;

// Combine the selected class declarations with the compilation
IncrementalValueProvider<(Compilation, ImmutableArray<ClassDeclarationSyntax>)> compilationAndClasses =
context.CompilationProvider.Combine(classDeclarations.Collect());

// Generate the source code
context.RegisterSourceOutput(compilationAndClasses,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
}

static bool IsSyntaxTargetForGeneration(SyntaxNode node) =>
node is ClassDeclarationSyntax classDeclaration &&
classDeclaration.AttributeLists.Count > 0;
protected override string MarkerAttributeFullName => "StrongTypeIdGenerator.StringIdAttribute";

static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
protected override INamedTypeSymbol GetIdTypeSymbol(Compilation compilation)
{
var classDeclarationSyntax = (ClassDeclarationSyntax)context.Node;

// Check if the class has the StrongTypeId attribute
foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists)
{
foreach (var attributeSyntax in attributeListSyntax.Attributes)
{
if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is IMethodSymbol attributeSymbol)
{
if (attributeSymbol.ContainingType.ToDisplayString() == "StrongTypeIdGenerator.StringIdAttribute")
{
return classDeclarationSyntax;
}
}
}
}

return null;
return compilation.GetTypeByMetadataName("System.String")!;
}

static void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context)
protected override void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes, SourceProductionContext context)
{
if (classes.IsDefaultOrEmpty)
{
Expand All @@ -75,25 +38,6 @@ static void Execute(Compilation compilation, ImmutableArray<ClassDeclarationSynt
}
}

static string? GetNamespace(ClassDeclarationSyntax classDeclaration)
{
// Walk upwards until we find the namespace declaration
SyntaxNode? potentialNamespaceParent = classDeclaration.Parent;
while (potentialNamespaceParent != null &&
potentialNamespaceParent is not NamespaceDeclarationSyntax &&
potentialNamespaceParent is not FileScopedNamespaceDeclarationSyntax)
{
potentialNamespaceParent = potentialNamespaceParent.Parent;
}

// Return the namespace name if it was found, otherwise null
if (potentialNamespaceParent is BaseNamespaceDeclarationSyntax namespaceDeclaration)
{
return namespaceDeclaration.Name.ToString();
}

return null;
}
static AttributeSyntax GetStringIdAttributeSyntax(ClassDeclarationSyntax classDeclaration)
{
foreach (var attributeListSyntax in classDeclaration.AttributeLists)
Expand Down Expand Up @@ -133,30 +77,6 @@ static bool GetGenerateConstructorPrivate(Compilation compilation, AttributeSynt
return false;
}

static bool HasCheckValueMethod(Compilation compilation, ClassDeclarationSyntax classDeclaration)
{
var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration) as INamedTypeSymbol;
if (classSymbol == null)
{
return false;
}

foreach (var member in classSymbol.GetMembers("CheckValue"))
{
if (member is IMethodSymbol methodSymbol &&
methodSymbol.IsStatic &&
methodSymbol.DeclaredAccessibility == Accessibility.Private &&
methodSymbol.Parameters.Length == 1 &&
methodSymbol.Parameters[0].Type.SpecialType == SpecialType.System_String)
{
return true;
}
}

return false;
}

static string GenerateStrongTypeIdClass(string? namespaceName, string className, bool generateConstructorPrivate, bool hasCheckValueMethod)
{
const string TIdentifier = "string";
Expand Down

0 comments on commit e7d298f

Please sign in to comment.