Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more boolean simplification cases #73806

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
namespace Microsoft.CodeAnalysis.CSharp.RemoveRedundantEquality;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
internal sealed class CSharpRemoveRedundantEqualityDiagnosticAnalyzer
: AbstractRemoveRedundantEqualityDiagnosticAnalyzer
internal sealed class CSharpRemoveRedundantEqualityDiagnosticAnalyzer()
: AbstractRemoveRedundantEqualityDiagnosticAnalyzer(CSharpSyntaxFacts.Instance)
{
public CSharpRemoveRedundantEqualityDiagnosticAnalyzer() : base(CSharpSyntaxFacts.Instance)
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.CodeAnalysis.Editor.CSharp.UnitTests.RemoveRedundantEquality
CSharpRemoveRedundantEqualityDiagnosticAnalyzer,
RemoveRedundantEqualityCodeFixProvider>;

public class RemoveRedundantEqualityTests
public sealed class RemoveRedundantEqualityTests
{
[Fact]
public async Task TestSimpleCaseForEqualsTrue()
Expand Down Expand Up @@ -88,7 +88,7 @@ public bool M1(bool x)
}

[Fact]
public async Task TestSimpleCaseForNotEqualsTrue_NoDiagnostics()
public async Task TestSimpleCaseForNotEqualsTrue()
{
await VerifyCS.VerifyCodeFixAsync("""
public class C
Expand Down Expand Up @@ -252,7 +252,7 @@ public bool M3(bool x)
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/48236")]
public async Task TestNullableValueTypes_DoesntCrash()
public async Task TestNullableValueTypes_DoesNotCrash()
{
var code = """
public class C
Expand All @@ -265,4 +265,83 @@ public bool M1(int? x)
""";
await VerifyCS.VerifyAnalyzerAsync(code);
}

[Fact]
public async Task TestSimpleCaseForIsFalse()
{
var code = """
public class C
{
public bool M1(bool x)
{
return x [|is|] false;
}
}
""";
var fixedCode = """
public class C
{
public bool M1(bool x)
{
return !x;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[Fact]
public async Task TestSimpleCaseForIsTrue1()
{
var code = """
public class C
{
public bool M1(bool x)
{
return x [|is|] true;
}
}
""";
var fixedCode = """
public class C
{
public bool M1(bool x)
{
return x;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[Fact]
public async Task TestSimpleCaseForIsTrue2()
{
var code = """
public class C
{
public const bool MyTrueConstant = true;
public bool M1(bool x)
{
return x is MyTrueConstant;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, code);
}

[Fact]
public async Task TestNotForNullableBool()
{
var code = """
public class C
{
public bool M1(bool? x)
{
return x is true;
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, code);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,42 @@ public override DiagnosticAnalyzerCategory GetAnalyzerCategory()
=> DiagnosticAnalyzerCategory.SemanticSpanAnalysis;

protected override void InitializeWorker(AnalysisContext context)
=> context.RegisterOperationAction(AnalyzeBinaryOperator, OperationKind.BinaryOperator);
{
context.RegisterOperationAction(AnalyzeBinaryOperator, OperationKind.BinaryOperator);
context.RegisterOperationAction(AnalyzeIsPatternOperator, OperationKind.IsPattern);
}

private void AnalyzeIsPatternOperator(OperationAnalysisContext context)
{
if (ShouldSkipAnalysis(context, notification: null))
return;

var syntax = context.Operation.Syntax;
if (!syntaxFacts.IsIsPatternExpression(syntax))
return;

var operation = (IIsPatternOperation)context.Operation;
if (operation.Pattern is not IConstantPatternOperation { Value.ConstantValue.Value: true or false } constantPattern)
return;

syntaxFacts.GetPartsOfIsPatternExpression(syntax, out _, out var isToken, out _);
AnalyzeOperator(
context,
leftOperand: operation.Value,
rightOperand: constantPattern.Value,
isOperatorEquals: true,
isToken);
}

private void AnalyzeBinaryOperator(OperationAnalysisContext context)
{
if (ShouldSkipAnalysis(context, notification: null))
return;

var syntax = context.Operation.Syntax;
if (!syntaxFacts.IsBinaryExpression(syntax))
return;

// We shouldn't report diagnostic on overloaded operator as the behavior can change.
var operation = (IBinaryOperation)context.Operation;
if (operation.OperatorMethod is not null)
Expand All @@ -36,26 +65,32 @@ private void AnalyzeBinaryOperator(OperationAnalysisContext context)
if (operation.OperatorKind is not (BinaryOperatorKind.Equals or BinaryOperatorKind.NotEquals))
return;

if (!syntaxFacts.IsBinaryExpression(operation.Syntax))
{
return;
}

var rightOperand = operation.RightOperand;
var leftOperand = operation.LeftOperand;
var isOperatorEquals = operation.OperatorKind == BinaryOperatorKind.Equals;
syntaxFacts.GetPartsOfBinaryExpression(syntax, out _, out var operatorToken, out _);
AnalyzeOperator(
context,
operation.LeftOperand,
operation.RightOperand,
isOperatorEquals,
operatorToken);
}

if (rightOperand.Type is null || leftOperand.Type is null)
return;
private void AnalyzeOperator(
OperationAnalysisContext context,
IOperation leftOperand,
IOperation rightOperand,
bool isOperatorEquals,
SyntaxToken operatorToken)
{
var leftType = leftOperand.Type;
var rightType = rightOperand.Type;

if (rightOperand.Type.SpecialType != SpecialType.System_Boolean ||
leftOperand.Type.SpecialType != SpecialType.System_Boolean)
if (leftType?.SpecialType != SpecialType.System_Boolean ||
rightType?.SpecialType != SpecialType.System_Boolean)
{
return;
}

var isOperatorEquals = operation.OperatorKind == BinaryOperatorKind.Equals;
syntaxFacts.GetPartsOfBinaryExpression(operation.Syntax, out _, out var operatorToken, out _);

var properties = ImmutableDictionary.CreateBuilder<string, string?>();
if (TryGetLiteralValue(rightOperand) is bool rightBool)
{
Expand All @@ -76,12 +111,12 @@ private void AnalyzeBinaryOperator(OperationAnalysisContext context)

context.ReportDiagnostic(Diagnostic.Create(Descriptor,
operatorToken.GetLocation(),
additionalLocations: [operation.Syntax.GetLocation()],
additionalLocations: [operatorToken.GetLocation()],
properties: properties.ToImmutable()));

return;

static bool? TryGetLiteralValue(IOperation operand)
static bool? TryGetLiteralValue(IOperation? operand)
{
// Make sure we only simplify literals to avoid changing
// something like the following example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,30 @@ protected override async Task FixAsync(

SyntaxNode RewriteNode()
{
// This should happen only in error cases.
if (!syntaxFacts.IsBinaryExpression(node))
return node;
if (syntaxFacts.IsBinaryExpression(node))
{
syntaxFacts.GetPartsOfBinaryExpression(node, out var left, out var right);
var rewritten =
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Right ? left :
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Left ? right : node;

if (properties.ContainsKey(RedundantEqualityConstants.Negate))
rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken);

syntaxFacts.GetPartsOfBinaryExpression(node, out var left, out var right);
var rewritten =
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Right ? left :
properties[RedundantEqualityConstants.RedundantSide] == RedundantEqualityConstants.Left ? right : node;
return rewritten;
}
else if (syntaxFacts.IsIsPatternExpression(node))
{
syntaxFacts.GetPartsOfIsPatternExpression(node, out var left, out _, out var right);
var rewritten = left;
if (properties.ContainsKey(RedundantEqualityConstants.Negate))
rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken);

if (properties.ContainsKey(RedundantEqualityConstants.Negate))
rewritten = generator.Negate(generatorInternal, rewritten, semanticModel, cancellationToken);
return rewritten;
}

return rewritten;
// This should happen only in error cases.
return node;
}

static SyntaxNode WithElasticTrailingTrivia(SyntaxNode node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ End Module
End Function

<Fact>
Public Async Function TestSimpleCaseForEqualsFalse_NoDiagnostics() As Task
Public Async Function TestSimpleCaseForEqualsFalse() As Task
Await VerifyVB.VerifyCodeFixAsync("
Public Module Module1
Public Function M1(x As Boolean) As Boolean
Expand Down
Loading