Skip to content

Commit

Permalink
Query: Make null check removal recursive
Browse files Browse the repository at this point in the history
Resolves #15204
  • Loading branch information
smitpatel committed Jul 12, 2019
1 parent aec9bcf commit e3bb027
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/EFCore.InMemory/Query/Pipeline/Translator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(@object, arguments);
}

private static MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0];
private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0];

protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression)
{
Expand Down
26 changes: 25 additions & 1 deletion src/EFCore/Query/Pipeline/NullCheckRemovingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ public class NullCheckRemovingExpressionVisitor : ExpressionVisitor
{
private readonly NullSafeAccessVerifyingExpressionVisitor _nullSafeAccessVerifyingExpressionVisitor
= new NullSafeAccessVerifyingExpressionVisitor();
private readonly NullConditionalRemovingExpressionVisitor _nullConditionalRemovingExpressionVisitor
= new NullConditionalRemovingExpressionVisitor();

public NullCheckRemovingExpressionVisitor()
{
}

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
var test = conditionalExpression.Test;
var test = Visit(conditionalExpression.Test);

if (test is BinaryExpression binaryTest
&& (binaryTest.NodeType == ExpressionType.Equal
Expand All @@ -42,6 +44,15 @@ protected override Expression VisitConditional(ConditionalExpression conditional
? conditionalExpression.IfFalse
: conditionalExpression.IfTrue;

// Unwrap nested nullConditional
if (caller is NullConditionalExpression nullConditionalCaller)
{
accessOperation = ReplacingExpressionVisitor.Replace(
_nullConditionalRemovingExpressionVisitor.Visit(nullConditionalCaller.AccessOperation),
nullConditionalCaller,
accessOperation);
}

if (_nullSafeAccessVerifyingExpressionVisitor.Verify(caller, accessOperation))
{
return new NullConditionalExpression(caller, accessOperation);
Expand All @@ -51,6 +62,19 @@ protected override Expression VisitConditional(ConditionalExpression conditional
return base.VisitConditional(conditionalExpression);
}

private class NullConditionalRemovingExpressionVisitor : ExpressionVisitor
{
public override Expression Visit(Expression expression)
{
if (expression is NullConditionalExpression nullConditionalExpression)
{
return Visit(nullConditionalExpression.AccessOperation);
}

return base.Visit(expression);
}
}

private class NullSafeAccessVerifyingExpressionVisitor : ExpressionVisitor
{
private readonly ISet<Expression> _nullSafeAccesses = new HashSet<Expression>(ExpressionEqualityComparer.Instance);
Expand Down
9 changes: 7 additions & 2 deletions src/EFCore/Query/Pipeline/ReplacingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
Expand All @@ -15,14 +16,18 @@ public class ReplacingExpressionVisitor : ExpressionVisitor
public static Expression Replace(Expression original, Expression replacement, Expression tree)
{
return new ReplacingExpressionVisitor(
new Dictionary<Expression, Expression> { { original, replacement } }).Visit(tree);
new Dictionary<Expression, Expression>(ExpressionEqualityComparer.Instance)
{
{ original, replacement }
}).Visit(tree);
}

public static Expression Replace(
Expression original1, Expression replacement1, Expression original2, Expression replacement2, Expression tree)
{
return new ReplacingExpressionVisitor(
new Dictionary<Expression, Expression> {
new Dictionary<Expression, Expression>(ExpressionEqualityComparer.Instance)
{
{ original1, replacement1 },
{ original2, replacement2 }
}).Visit(tree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6278,5 +6278,41 @@ public virtual Task Include_multiple_collections_on_same_level(bool isAsync)
},
assertOrder: true);
}

[ConditionalTheory(Skip = "Issue#16088")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_check_removal_applied_recursively(bool isAsync)
{
return AssertQuery<Level1>(
isAsync,
l1s => l1s.Where(l1 =>
((((l1.OneToOne_Optional_FK1 == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2) == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3) == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3) == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3.Name) == "L4 01"));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_check_different_structure_does_not_remove_null_checks(bool isAsync)
{
return AssertQuery<Level1>(
isAsync,
l1s => l1s.Where(l1 =>
(l1.OneToOne_Optional_FK1 == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2 == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3 == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3 == null
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3.Name) == "L4 01"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4698,6 +4698,44 @@ public override async Task Member_pushdown_with_multiple_collections(bool isAsyn
@"");
}

public override async Task Null_check_removal_applied_recursively(bool isAsync)
{
await base.Null_check_removal_applied_recursively(isAsync);

AssertSql(" ");
}

public override async Task Null_check_different_structure_does_not_remove_null_checks(bool isAsync)
{
await base.Null_check_different_structure_does_not_remove_null_checks(isAsync);

AssertSql(
@"SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id]
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id]
LEFT JOIN [LevelThree] AS [l1] ON [l0].[Id] = [l1].[Level2_Optional_Id]
LEFT JOIN [LevelFour] AS [l2] ON [l1].[Id] = [l2].[Level3_Optional_Id]
WHERE (CASE
WHEN [l0].[Id] IS NULL THEN NULL
ELSE CASE
WHEN [l1].[Id] IS NULL THEN NULL
ELSE CASE
WHEN [l2].[Id] IS NULL THEN NULL
ELSE [l2].[Name]
END
END
END = N'L4 01') AND CASE
WHEN [l0].[Id] IS NULL THEN NULL
ELSE CASE
WHEN [l1].[Id] IS NULL THEN NULL
ELSE CASE
WHEN [l2].[Id] IS NULL THEN NULL
ELSE [l2].[Name]
END
END
END IS NOT NULL");
}

private void AssertSql(params string[] expected)
{
//Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
Expand Down
Loading

0 comments on commit e3bb027

Please sign in to comment.