Skip to content

Commit 8c11bc8

Browse files
PhenXCopilot
andauthored
Apply suggestions from code review
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 9ff8eb6 commit 8c11bc8

2 files changed

Lines changed: 67 additions & 27 deletions

File tree

src/EntityFrameworkCore.Projectables.CodeFixes/FactoryMethodTransformationHelper.cs

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.CodeAnalysis.CSharp.Syntax;
44
using Microsoft.CodeAnalysis.FindSymbols;
55
using Microsoft.CodeAnalysis.Formatting;
6+
using Microsoft.CodeAnalysis.Simplification;
67

78
namespace EntityFrameworkCore.Projectables.CodeFixes;
89

@@ -42,6 +43,13 @@ static internal bool TryGetFactoryMethodPattern(
4243
return false;
4344
}
4445

46+
// Only support static factory methods; instance factories would drop the receiver
47+
// when transformed to a constructor call, which can change semantics.
48+
if (!method.Modifiers.Any(SyntaxKind.StaticKeyword))
49+
{
50+
return false;
51+
}
52+
4553
if (method.ExpressionBody is null)
4654
{
4755
return false;
@@ -127,24 +135,22 @@ async static internal Task<Solution> ConvertToConstructorAndUpdateCallersAsync(
127135
}
128136

129137
var solution = document.Project.Solution;
130-
var returnTypeName = containingType.Identifier.Text;
138+
var returnType = methodSymbol.ReturnType;
139+
var returnTypeSyntax = SyntaxFactory
140+
.ParseTypeName(returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
141+
.WithAdditionalAnnotations(Simplifier.Annotation);
131142

132143
// Find all callers BEFORE modifying the solution so that spans are still valid.
133144
var references = await SymbolFinder
134145
.FindReferencesAsync(methodSymbol, solution, cancellationToken)
135146
.ConfigureAwait(false);
136147

137-
// Group locations by document, excluding the declaring document (handled separately below).
148+
// Group locations by document (including the declaring document).
138149
var locationsByDoc = new Dictionary<DocumentId, List<ReferenceLocation>>();
139150
foreach (var referencedSymbol in references)
140151
{
141152
foreach (var refLocation in referencedSymbol.Locations)
142153
{
143-
if (refLocation.Document.Id == document.Id)
144-
{
145-
continue;
146-
}
147-
148154
if (!locationsByDoc.TryGetValue(refLocation.Document.Id, out var list))
149155
{
150156
list = [];
@@ -198,12 +204,19 @@ async static internal Task<Solution> ConvertToConstructorAndUpdateCallersAsync(
198204
continue;
199205
}
200206

207+
// Skip conditional-access invocations like x?.FactoryMethod(...)
208+
// to avoid producing invalid syntax such as x?.new ReturnType(...).
209+
if (invocation.Parent is ConditionalAccessExpressionSyntax)
210+
{
211+
continue;
212+
}
213+
201214
// Rewrite: instance.FactoryMethod(args) → new ReturnType(args)
202215
var newCreation = SyntaxFactory
203216
.ObjectCreationExpression(
204217
SyntaxFactory.Token(SyntaxKind.NewKeyword)
205218
.WithTrailingTrivia(SyntaxFactory.Space),
206-
SyntaxFactory.IdentifierName(returnTypeName),
219+
returnTypeSyntax,
207220
invocation.ArgumentList,
208221
initializer: null)
209222
.WithLeadingTrivia(invocation.GetLeadingTrivia())
@@ -231,18 +244,26 @@ private static SyntaxNode BuildRootWithConstructor(
231244
var creation = (ObjectCreationExpressionSyntax)method.ExpressionBody!.Expression;
232245
var initializer = creation.Initializer!;
233246

247+
// Only support simple object-initializer assignments (Prop = value). If there are
248+
// other initializer forms (e.g., collection initializers), bail out to avoid
249+
// producing a constructor that does not preserve behavior.
250+
if (initializer.Expressions.Any(e => e is not AssignmentExpressionSyntax))
251+
{
252+
return root;
253+
}
254+
234255
// Convert each object-initializer assignment (Prop = value) to a statement (Prop = value;).
235256
var statements = initializer.Expressions
236257
.OfType<AssignmentExpressionSyntax>()
237258
.Select(a => (StatementSyntax)SyntaxFactory.ExpressionStatement(a))
238259
.ToArray();
239260

261+
var ctorModifiers = GetConstructorModifiers(method);
262+
240263
var ctor = SyntaxFactory
241264
.ConstructorDeclaration(containingType.Identifier.WithoutTrivia())
242265
.WithAttributeLists(method.AttributeLists)
243-
.WithModifiers(SyntaxFactory.TokenList(
244-
SyntaxFactory.Token(SyntaxKind.PublicKeyword)
245-
.WithTrailingTrivia(SyntaxFactory.Space)))
266+
.WithModifiers(ctorModifiers)
246267
.WithParameterList(method.ParameterList)
247268
.WithBody(SyntaxFactory.Block(statements))
248269
.WithAdditionalAnnotations(Formatter.Annotation)
@@ -262,9 +283,7 @@ private static SyntaxNode BuildRootWithConstructor(
262283
{
263284
var paramlessCtor = SyntaxFactory
264285
.ConstructorDeclaration(containingType.Identifier.WithoutTrivia())
265-
.WithModifiers(SyntaxFactory.TokenList(
266-
SyntaxFactory.Token(SyntaxKind.PublicKeyword)
267-
.WithTrailingTrivia(SyntaxFactory.Space)))
286+
.WithModifiers(ctorModifiers)
268287
.WithParameterList(SyntaxFactory.ParameterList())
269288
.WithBody(SyntaxFactory.Block())
270289
.WithAdditionalAnnotations(Formatter.Annotation)
@@ -275,5 +294,19 @@ private static SyntaxNode BuildRootWithConstructor(
275294

276295
return root.ReplaceNode(containingType, containingType.WithMembers(newMembers));
277296
}
297+
298+
private static SyntaxTokenList GetConstructorModifiers(MethodDeclarationSyntax method)
299+
{
300+
// Derive constructor modifiers from the factory method, dropping modifiers that are
301+
// invalid or meaningless for instance constructors (e.g., static, async, extern, unsafe).
302+
var filteredModifiers = method.Modifiers
303+
.Where(m =>
304+
m.Kind() != SyntaxKind.StaticKeyword &&
305+
m.Kind() != SyntaxKind.AsyncKeyword &&
306+
m.Kind() != SyntaxKind.ExternKeyword &&
307+
m.Kind() != SyntaxKind.UnsafeKeyword);
308+
309+
return SyntaxFactory.TokenList(filteredModifiers);
310+
}
278311
}
279312

src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ private static void ReportFactoryMethodDiagnosticIfApplicable(
299299
MethodDeclarationSyntax method,
300300
SourceProductionContext context)
301301
{
302-
if (method.Parent is not TypeDeclarationSyntax containingType)
302+
if (method.Parent is not TypeDeclarationSyntax)
303303
{
304304
return;
305305
}
@@ -325,10 +325,25 @@ private static void ReportFactoryMethodDiagnosticIfApplicable(
325325
return;
326326
}
327327

328-
// The return type's simple name must equal the containing class name.
329-
var containingTypeName = containingType.Identifier.Text;
330-
if (GetFactorySimpleTypeName(method.ReturnType) != containingTypeName
331-
|| GetFactorySimpleTypeName(creation.Type) != containingTypeName)
328+
if (memberSymbol is not IMethodSymbol methodSymbol)
329+
{
330+
return;
331+
}
332+
333+
var containingTypeSymbol = methodSymbol.ContainingType;
334+
if (containingTypeSymbol is null)
335+
{
336+
return;
337+
}
338+
339+
var createdTypeSymbol = semanticModel.GetTypeInfo(creation).Type;
340+
if (createdTypeSymbol is null)
341+
{
342+
return;
343+
}
344+
345+
if (!SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType, containingTypeSymbol)
346+
|| !SymbolEqualityComparer.Default.Equals(createdTypeSymbol, containingTypeSymbol))
332347
{
333348
return;
334349
}
@@ -339,14 +354,6 @@ private static void ReportFactoryMethodDiagnosticIfApplicable(
339354
method.Identifier.Text));
340355
}
341356

342-
private static string? GetFactorySimpleTypeName(TypeSyntax type) =>
343-
type switch
344-
{
345-
IdentifierNameSyntax id => id.Identifier.Text,
346-
QualifiedNameSyntax qn => qn.Right.Identifier.Text,
347-
_ => null
348-
};
349-
350357
/// <summary>
351358
/// Extracts a <see cref="ProjectionRegistryEntry"/> from a member declaration.
352359
/// Returns null when the member does not have [Projectable], is an extension member,

0 commit comments

Comments
 (0)