Skip to content

Commit

Permalink
Merge pull request #73806 from CyrusNajmabadi/boolSimplification
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi committed May 31, 2024
2 parents f1642d7 + f4d1b94 commit 5d852d4
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
namespace Microsoft.CodeAnalysis.CSharp.RemoveRedundantEquality;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
internal sealed class CSharpRemoveRedundantEqualityDiagnosticAnalyzer
: AbstractRemoveRedundantEqualityDiagnosticAnalyzer
internal sealed class CSharpRemoveRedundantEqualityDiagnosticAnalyzer()
: AbstractRemoveRedundantEqualityDiagnosticAnalyzer(CSharpSyntaxFacts.Instance)
{
public CSharpRemoveRedundantEqualityDiagnosticAnalyzer() : base(CSharpSyntaxFacts.Instance)
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.RemoveRedundantEquality
CSharpRemoveRedundantEqualityDiagnosticAnalyzer,
RemoveRedundantEqualityCodeFixProvider>;

public class RemoveRedundantEqualityTests
public sealed class RemoveRedundantEqualityTests
{
[Fact]
public async Task TestSimpleCaseForEqualsTrue()
Expand Down Expand Up @@ -88,7 +88,7 @@ public bool M1(bool x)
}

[Fact]
public async Task TestSimpleCaseForNotEqualsTrue_NoDiagnostics()
public async Task TestSimpleCaseForNotEqualsTrue()
{
await VerifyCS.VerifyCodeFixAsync("""
public class C
Expand Down Expand Up @@ -252,7 +252,7 @@ public bool M3(bool x)
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/48236")]
public async Task TestNullableValueTypes_DoesntCrash()
public async Task TestNullableValueTypes_DoesNotCrash()
{
var code = """
public class C
Expand All @@ -265,4 +265,83 @@ public bool M1(int? x)
""";
await VerifyCS.VerifyAnalyzerAsync(code);
}

[Fact]
public async Task TestSimpleCaseForIsFalse()
{
var code = """
public class C
{
public bool M1(bool x)
{
return x [|is|] false;
}
}
""";
var fixedCode = """
public class C
{
public bool M1(bool x)
{
return !x;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[Fact]
public async Task TestSimpleCaseForIsTrue1()
{
var code = """
public class C
{
public bool M1(bool x)
{
return x [|is|] true;
}
}
""";
var fixedCode = """
public class C
{
public bool M1(bool x)
{
return x;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[Fact]
public async Task TestSimpleCaseForIsTrue2()
{
var code = """
public class C
{
public const bool MyTrueConstant = true;
public bool M1(bool x)
{
return x is MyTrueConstant;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, code);
}

[Fact]
public async Task TestNotForNullableBool()
{
var code = """
public class C
{
public bool M1(bool? x)
{
return x is true;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, code);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,42 @@ public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
=> DiagnosticAnalyzerCategory.SemanticSpanAnalysis;

protected override void InitializeWorker(AnalysisContext context)
=> context.RegisterOperationAction(AnalyzeBinaryOperator, OperationKind.BinaryOperator);
{
context.RegisterOperationAction(AnalyzeBinaryOperator, OperationKind.BinaryOperator);
context.RegisterOperationAction(AnalyzeIsPatternOperator, OperationKind.IsPattern);
}

private void AnalyzeIsPatternOperator(OperationAnalysisContext context)
{
if (ShouldSkipAnalysis(context, notification: null))
return;

var syntax = context.Operation.Syntax;
if (!syntaxFacts.IsIsPatternExpression(syntax))
return;

var operation = (IIsPatternOperation)context.Operation;
if (operation.Pattern is not IConstantPatternOperation { Value.ConstantValue.Value: true or false } constantPattern)
return;

syntaxFacts.GetPartsOfIsPatternExpression(syntax, out _, out var isToken, out _);
AnalyzeOperator(
context,
leftOperand: operation.Value,
rightOperand: constantPattern.Value,
isOperatorEquals: true,
isToken);
}

private void AnalyzeBinaryOperator(OperationAnalysisContext context)
{
if (ShouldSkipAnalysis(context, notification: null))
return;

var syntax = context.Operation.Syntax;
if (!syntaxFacts.IsBinaryExpression(syntax))
return;

// We shouldn't report diagnostic on overloaded operator as the behavior can change.
var operation = (IBinaryOperation)context.Operation;
if (operation.OperatorMethod is not null)
Expand All @@ -36,26 +65,32 @@ private void AnalyzeBinaryOperator(OperationAnalysisContext context)
if (operation.OperatorKind is not (BinaryOperatorKind.Equals or BinaryOperatorKind.NotEquals))
return;

if (!syntaxFacts.IsBinaryExpression(operation.Syntax))
{
return;
}

var rightOperand = operation.RightOperand;
var leftOperand = operation.LeftOperand;
var isOperatorEquals = operation.OperatorKind == BinaryOperatorKind.Equals;
syntaxFacts.GetPartsOfBinaryExpression(syntax, out _, out var operatorToken, out _);
AnalyzeOperator(
context,
operation.LeftOperand,
operation.RightOperand,
isOperatorEquals,
operatorToken);
}

if (rightOperand.Type is null || leftOperand.Type is null)
return;
private void AnalyzeOperator(
OperationAnalysisContext context,
IOperation leftOperand,
IOperation rightOperand,
bool isOperatorEquals,
SyntaxToken operatorToken)
{
var leftType = leftOperand.Type;
var rightType = rightOperand.Type;

if (rightOperand.Type.SpecialType != SpecialType.System_Boolean ||
leftOperand.Type.SpecialType != SpecialType.System_Boolean)
if (leftType?.SpecialType != SpecialType.System_Boolean ||
rightType?.SpecialType != SpecialType.System_Boolean)
{
return;
}

var isOperatorEquals = operation.OperatorKind == BinaryOperatorKind.Equals;
syntaxFacts.GetPartsOfBinaryExpression(operation.Syntax, out _, out var operatorToken, out _);

var properties = ImmutableDictionary.CreateBuilder<string, string?>();
if (TryGetLiteralValue(rightOperand) is bool rightBool)
{
Expand All @@ -76,12 +111,12 @@ private void AnalyzeBinaryOperator(OperationAnalysisContext context)

context.ReportDiagnostic(Diagnostic.Create(Descriptor,
operatorToken.GetLocation(),
additionalLocations: [operation.Syntax.GetLocation()],
additionalLocations: [operatorToken.GetLocation()],
properties: properties.ToImmutable()));

return;

static bool? TryGetLiteralValue(IOperation operand)
static bool? TryGetLiteralValue(IOperation? operand)
{
// Make sure we only simplify literals to avoid changing
// something like the following example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,30 @@ protected override async Task FixAsync(

SyntaxNode RewriteNode()
{
// This should happen only in error cases.
if (!syntaxFacts.IsBinaryExpression(node))
return node;
if (syntaxFacts.IsBinaryExpression(node))
{
syntaxFacts.GetPartsOfBinaryExpression(node, out var left, out var right);
var rewritten =
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Right ? left :
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Left ? right : node;

if (properties.ContainsKey(RedundantEqualityConstants.Negate))
rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken);

syntaxFacts.GetPartsOfBinaryExpression(node, out var left, out var right);
var rewritten =
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Right ? left :
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Left ? right : node;
return rewritten;
}
else if (syntaxFacts.IsIsPatternExpression(node))
{
syntaxFacts.GetPartsOfIsPatternExpression(node, out var left, out _, out var right);
var rewritten = left;
if (properties.ContainsKey(RedundantEqualityConstants.Negate))
rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken);

if (properties.ContainsKey(RedundantEqualityConstants.Negate))
rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken);
return rewritten;
}

return rewritten;
// This should happen only in error cases.
return node;
}

static SyntaxNode WithElasticTrailingTrivia(SyntaxNode node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ End Module
End Function

<Fact>
Public Async Function TestSimpleCaseForEqualsFalse_NoDiagnostics() As Task
Public Async Function TestSimpleCaseForEqualsFalse() As Task
Await VerifyVB.VerifyCodeFixAsync("
Public Module Module1
Public Function M1(x As Boolean) As Boolean
Expand Down

0 comments on commit 5d852d4

Please sign in to comment.