diff --git a/src/Analyzers/MSTest.Analyzers.CodeFixes/UseNewerAssertThrowsFixer.cs b/src/Analyzers/MSTest.Analyzers.CodeFixes/UseNewerAssertThrowsFixer.cs index 1077813834..5281f497ac 100644 --- a/src/Analyzers/MSTest.Analyzers.CodeFixes/UseNewerAssertThrowsFixer.cs +++ b/src/Analyzers/MSTest.Analyzers.CodeFixes/UseNewerAssertThrowsFixer.cs @@ -100,7 +100,7 @@ private static SyntaxNode UpdateMethodName(SyntaxEditor editor, InvocationExpres { editor.ReplaceNode( lambdaSyntax.ExpressionBody, - AssignToDiscard(lambdaSyntax.ExpressionBody)); + AssignToDiscardIfNeeded(lambdaSyntax.ExpressionBody)); } else if (lambdaSyntax.Block is not null) { @@ -126,7 +126,7 @@ private static SyntaxNode UpdateMethodName(SyntaxEditor editor, InvocationExpres continue; } - ExpressionStatementSyntax returnReplacement = SyntaxFactory.ExpressionStatement(AssignToDiscard(returnStatement.Expression)); + ExpressionStatementSyntax returnReplacement = SyntaxFactory.ExpressionStatement(AssignToDiscardIfNeeded(returnStatement.Expression)); if (returnStatement.Parent is BlockSyntax blockSyntax) { @@ -151,12 +151,24 @@ private static SyntaxNode UpdateMethodName(SyntaxEditor editor, InvocationExpres SyntaxFactory.ParenthesizedLambdaExpression( SyntaxFactory.ParameterList(), block: null, - expressionBody: AssignToDiscard(SyntaxFactory.InvocationExpression(SyntaxFactory.ParenthesizedExpression(expressionSyntax).WithAdditionalAnnotations(Simplifier.Annotation))))); + expressionBody: SyntaxFactory.InvocationExpression(SyntaxFactory.ParenthesizedExpression(expressionSyntax).WithAdditionalAnnotations(Simplifier.Annotation)))); } return editor.GetChangedRoot(); } - private static AssignmentExpressionSyntax AssignToDiscard(ExpressionSyntax expression) - => SyntaxFactory.AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, SyntaxFactory.IdentifierName("_"), expression); + private static ExpressionSyntax AssignToDiscardIfNeeded(ExpressionSyntax expression) + => NeedsDiscard(expression) + ? SyntaxFactory.AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, SyntaxFactory.IdentifierName("_"), expression) + : expression; + + private static bool NeedsDiscard(ExpressionSyntax expression) + => expression is not InvocationExpressionSyntax && + expression is not AssignmentExpressionSyntax && + !expression.IsKind(SyntaxKind.PostIncrementExpression) && + !expression.IsKind(SyntaxKind.PostDecrementExpression) && + !expression.IsKind(SyntaxKind.PreIncrementExpression) && + !expression.IsKind(SyntaxKind.PreDecrementExpression) && + expression is not AwaitExpressionSyntax && + expression is not ObjectCreationExpressionSyntax; } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseNewerAssertThrowsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseNewerAssertThrowsAnalyzerTests.cs index 1bc0223a6b..38a1dc0f47 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseNewerAssertThrowsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseNewerAssertThrowsAnalyzerTests.cs @@ -250,7 +250,7 @@ public class MyTestClass public void MyTestMethod() { Func action = () => _ = 5; - Assert.ThrowsExactly(() => _ = action()); + Assert.ThrowsExactly(() => action()); } } """; @@ -288,11 +288,85 @@ public class MyTestClass public void MyTestMethod() { Func action = () => _ = 5; - Assert.ThrowsExactly(() => _ = (action + action)()); + Assert.ThrowsExactly(() => (action + action)()); } } """; await VerifyCS.VerifyCodeFixAsync(code, fixedCode); } + + [TestMethod] + public async Task VariousTestCasesForDiscard() + { + string code = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public sealed class Test1 + { + [TestMethod] + public void TestMethod1() + { + int[] numbers = [1]; + int x = 0; + string s = ""; + + [|Assert.ThrowsException(() => VoidMethod(1))|]; + [|Assert.ThrowsException(() => NonVoidMethod(1))|]; + [|Assert.ThrowsException(() => _ = NonVoidMethod(1))|]; + [|Assert.ThrowsException(() => new Test1())|]; + [|Assert.ThrowsException(() => _ = new Test1())|]; + [|Assert.ThrowsException(() => numbers[0] = 4)|]; + [|Assert.ThrowsException(() => x++)|]; + [|Assert.ThrowsException(() => x--)|]; + [|Assert.ThrowsException(() => ++x)|]; + [|Assert.ThrowsException(() => --x)|]; + [|Assert.ThrowsException(() => s!)|]; + [|Assert.ThrowsException(() => !true)|]; + } + + private void VoidMethod(object o) => _ = o; + + private int NonVoidMethod(int i) => i; + } + """; + + string fixedCode = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public sealed class Test1 + { + [TestMethod] + public void TestMethod1() + { + int[] numbers = [1]; + int x = 0; + string s = ""; + + Assert.ThrowsExactly(() => VoidMethod(1)); + Assert.ThrowsExactly(() => NonVoidMethod(1)); + Assert.ThrowsExactly(() => _ = NonVoidMethod(1)); + Assert.ThrowsExactly(() => new Test1()); + Assert.ThrowsExactly(() => _ = new Test1()); + Assert.ThrowsExactly(() => numbers[0] = 4); + Assert.ThrowsExactly(() => x++); + Assert.ThrowsExactly(() => x--); + Assert.ThrowsExactly(() => ++x); + Assert.ThrowsExactly(() => --x); + Assert.ThrowsExactly(() => _ = s!); + Assert.ThrowsExactly(() => _ = !true); + } + + private void VoidMethod(object o) => _ = o; + + private int NonVoidMethod(int i) => i; + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } }