diff --git a/src/Analyzers/CSharp/Analyzers/RemoveRedundantEquality/CSharpRemoveRedundantEqualityDiagnosticAnalyzer.cs b/src/Analyzers/CSharp/Analyzers/RemoveRedundantEquality/CSharpRemoveRedundantEqualityDiagnosticAnalyzer.cs index 6a953341dc596..3317b6041a785 100644 --- a/src/Analyzers/CSharp/Analyzers/RemoveRedundantEquality/CSharpRemoveRedundantEqualityDiagnosticAnalyzer.cs +++ b/src/Analyzers/CSharp/Analyzers/RemoveRedundantEquality/CSharpRemoveRedundantEqualityDiagnosticAnalyzer.cs @@ -9,10 +9,7 @@ namespace Microsoft.CodeAnalysis.CSharp.RemoveRedundantEquality; [DiagnosticAnalyzer(LanguageNames.CSharp)] -internal sealed class CSharpRemoveRedundantEqualityDiagnosticAnalyzer - : AbstractRemoveRedundantEqualityDiagnosticAnalyzer +internal sealed class CSharpRemoveRedundantEqualityDiagnosticAnalyzer() + : AbstractRemoveRedundantEqualityDiagnosticAnalyzer(CSharpSyntaxFacts.Instance) { - public CSharpRemoveRedundantEqualityDiagnosticAnalyzer() : base(CSharpSyntaxFacts.Instance) - { - } } diff --git a/src/Analyzers/CSharp/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.cs b/src/Analyzers/CSharp/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.cs index 88aeee52301c9..a5b541e8a8bac 100644 --- a/src/Analyzers/CSharp/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.cs +++ b/src/Analyzers/CSharp/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.cs @@ -15,7 +15,7 @@ namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.RemoveRedundantEquality CSharpRemoveRedundantEqualityDiagnosticAnalyzer, RemoveRedundantEqualityCodeFixProvider>; -public class RemoveRedundantEqualityTests +public sealed class RemoveRedundantEqualityTests { [Fact] public async Task TestSimpleCaseForEqualsTrue() @@ -88,7 +88,7 @@ public bool M1(bool x) } [Fact] - public async Task TestSimpleCaseForNotEqualsTrue_NoDiagnostics() + public async Task TestSimpleCaseForNotEqualsTrue() { await VerifyCS.VerifyCodeFixAsync(""" public class C @@ -252,7 +252,7 @@ public bool M3(bool x) } [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/48236")] - public async Task TestNullableValueTypes_DoesntCrash() + public async Task TestNullableValueTypes_DoesNotCrash() { var code = """ public class C @@ -265,4 +265,83 @@ public bool M1(int? x) """; await VerifyCS.VerifyAnalyzerAsync(code); } + + [Fact] + public async Task TestSimpleCaseForIsFalse() + { + var code = """ + public class C + { + public bool M1(bool x) + { + return x [|is|] false; + } + } + """; + var fixedCode = """ + public class C + { + public bool M1(bool x) + { + return !x; + } + } + """; + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [Fact] + public async Task TestSimpleCaseForIsTrue1() + { + var code = """ + public class C + { + public bool M1(bool x) + { + return x [|is|] true; + } + } + """; + var fixedCode = """ + public class C + { + public bool M1(bool x) + { + return x; + } + } + """; + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [Fact] + public async Task TestSimpleCaseForIsTrue2() + { + var code = """ + public class C + { + public const bool MyTrueConstant = true; + public bool M1(bool x) + { + return x is MyTrueConstant; + } + } + """; + await VerifyCS.VerifyCodeFixAsync(code, code); + } + + [Fact] + public async Task TestNotForNullableBool() + { + var code = """ + public class C + { + public bool M1(bool? x) + { + return x is true; + } + } + """; + await VerifyCS.VerifyCodeFixAsync(code, code); + } } diff --git a/src/Analyzers/Core/Analyzers/RemoveRedundantEquality/AbstractRemoveRedundantEqualityDiagnosticAnalyzer.cs b/src/Analyzers/Core/Analyzers/RemoveRedundantEquality/AbstractRemoveRedundantEqualityDiagnosticAnalyzer.cs index 8be778a4b7e10..7e1944ab19980 100644 --- a/src/Analyzers/Core/Analyzers/RemoveRedundantEquality/AbstractRemoveRedundantEqualityDiagnosticAnalyzer.cs +++ b/src/Analyzers/Core/Analyzers/RemoveRedundantEquality/AbstractRemoveRedundantEqualityDiagnosticAnalyzer.cs @@ -21,13 +21,42 @@ public override DiagnosticAnalyzerCategory GetAnalyzerCategory() => DiagnosticAnalyzerCategory.SemanticSpanAnalysis; protected override void InitializeWorker(AnalysisContext context) - => context.RegisterOperationAction(AnalyzeBinaryOperator, OperationKind.BinaryOperator); + { + context.RegisterOperationAction(AnalyzeBinaryOperator, OperationKind.BinaryOperator); + context.RegisterOperationAction(AnalyzeIsPatternOperator, OperationKind.IsPattern); + } + + private void AnalyzeIsPatternOperator(OperationAnalysisContext context) + { + if (ShouldSkipAnalysis(context, notification: null)) + return; + + var syntax = context.Operation.Syntax; + if (!syntaxFacts.IsIsPatternExpression(syntax)) + return; + + var operation = (IIsPatternOperation)context.Operation; + if (operation.Pattern is not IConstantPatternOperation { Value.ConstantValue.Value: true or false } constantPattern) + return; + + syntaxFacts.GetPartsOfIsPatternExpression(syntax, out _, out var isToken, out _); + AnalyzeOperator( + context, + leftOperand: operation.Value, + rightOperand: constantPattern.Value, + isOperatorEquals: true, + isToken); + } private void AnalyzeBinaryOperator(OperationAnalysisContext context) { if (ShouldSkipAnalysis(context, notification: null)) return; + var syntax = context.Operation.Syntax; + if (!syntaxFacts.IsBinaryExpression(syntax)) + return; + // We shouldn't report diagnostic on overloaded operator as the behavior can change. var operation = (IBinaryOperation)context.Operation; if (operation.OperatorMethod is not null) @@ -36,26 +65,32 @@ private void AnalyzeBinaryOperator(OperationAnalysisContext context) if (operation.OperatorKind is not (BinaryOperatorKind.Equals or BinaryOperatorKind.NotEquals)) return; - if (!syntaxFacts.IsBinaryExpression(operation.Syntax)) - { - return; - } - - var rightOperand = operation.RightOperand; - var leftOperand = operation.LeftOperand; + var isOperatorEquals = operation.OperatorKind == BinaryOperatorKind.Equals; + syntaxFacts.GetPartsOfBinaryExpression(syntax, out _, out var operatorToken, out _); + AnalyzeOperator( + context, + operation.LeftOperand, + operation.RightOperand, + isOperatorEquals, + operatorToken); + } - if (rightOperand.Type is null || leftOperand.Type is null) - return; + private void AnalyzeOperator( + OperationAnalysisContext context, + IOperation leftOperand, + IOperation rightOperand, + bool isOperatorEquals, + SyntaxToken operatorToken) + { + var leftType = leftOperand.Type; + var rightType = rightOperand.Type; - if (rightOperand.Type.SpecialType != SpecialType.System_Boolean || - leftOperand.Type.SpecialType != SpecialType.System_Boolean) + if (leftType?.SpecialType != SpecialType.System_Boolean || + rightType?.SpecialType != SpecialType.System_Boolean) { return; } - var isOperatorEquals = operation.OperatorKind == BinaryOperatorKind.Equals; - syntaxFacts.GetPartsOfBinaryExpression(operation.Syntax, out _, out var operatorToken, out _); - var properties = ImmutableDictionary.CreateBuilder(); if (TryGetLiteralValue(rightOperand) is bool rightBool) { @@ -76,12 +111,12 @@ private void AnalyzeBinaryOperator(OperationAnalysisContext context) context.ReportDiagnostic(Diagnostic.Create(Descriptor, operatorToken.GetLocation(), - additionalLocations: [operation.Syntax.GetLocation()], + additionalLocations: [operatorToken.GetLocation()], properties: properties.ToImmutable())); return; - static bool? TryGetLiteralValue(IOperation operand) + static bool? TryGetLiteralValue(IOperation? operand) { // Make sure we only simplify literals to avoid changing // something like the following example: diff --git a/src/Analyzers/Core/CodeFixes/RemoveRedundantEquality/RemoveRedundantEqualityCodeFixProvider.cs b/src/Analyzers/Core/CodeFixes/RemoveRedundantEquality/RemoveRedundantEqualityCodeFixProvider.cs index 15e00c76d5987..de18bbeca7810 100644 --- a/src/Analyzers/Core/CodeFixes/RemoveRedundantEquality/RemoveRedundantEqualityCodeFixProvider.cs +++ b/src/Analyzers/Core/CodeFixes/RemoveRedundantEquality/RemoveRedundantEqualityCodeFixProvider.cs @@ -49,19 +49,30 @@ protected override async Task FixAsync( SyntaxNode RewriteNode() { - // This should happen only in error cases. - if (!syntaxFacts.IsBinaryExpression(node)) - return node; + if (syntaxFacts.IsBinaryExpression(node)) + { + syntaxFacts.GetPartsOfBinaryExpression(node, out var left, out var right); + var rewritten = + properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Right ? left : + properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Left ? right : node; + + if (properties.ContainsKey(RedundantEqualityConstants.Negate)) + rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken); - syntaxFacts.GetPartsOfBinaryExpression(node, out var left, out var right); - var rewritten = - properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Right ? left : - properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Left ? right : node; + return rewritten; + } + else if (syntaxFacts.IsIsPatternExpression(node)) + { + syntaxFacts.GetPartsOfIsPatternExpression(node, out var left, out _, out var right); + var rewritten = left; + if (properties.ContainsKey(RedundantEqualityConstants.Negate)) + rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken); - if (properties.ContainsKey(RedundantEqualityConstants.Negate)) - rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken); + return rewritten; + } - return rewritten; + // This should happen only in error cases. + return node; } static SyntaxNode WithElasticTrailingTrivia(SyntaxNode node) diff --git a/src/Analyzers/VisualBasic/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.vb b/src/Analyzers/VisualBasic/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.vb index 994931792e16a..00888b6059179 100644 --- a/src/Analyzers/VisualBasic/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.vb +++ b/src/Analyzers/VisualBasic/Tests/RemoveRedundantEquality/RemoveRedundantEqualityTests.vb @@ -28,7 +28,7 @@ End Module End Function - Public Async Function TestSimpleCaseForEqualsFalse_NoDiagnostics() As Task + Public Async Function TestSimpleCaseForEqualsFalse() As Task Await VerifyVB.VerifyCodeFixAsync(" Public Module Module1 Public Function M1(x As Boolean) As Boolean