diff --git a/src/EFCore.InMemory/Query/Pipeline/Translator.cs b/src/EFCore.InMemory/Query/Pipeline/Translator.cs index 0418699c6f5..0121977d1a7 100644 --- a/src/EFCore.InMemory/Query/Pipeline/Translator.cs +++ b/src/EFCore.InMemory/Query/Pipeline/Translator.cs @@ -189,7 +189,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return methodCallExpression.Update(@object, arguments); } - private static MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; + private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) { diff --git a/src/EFCore/Query/Pipeline/NullCheckRemovingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/NullCheckRemovingExpressionVisitor.cs index 03c93e03ad0..6a98170ee9a 100644 --- a/src/EFCore/Query/Pipeline/NullCheckRemovingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/NullCheckRemovingExpressionVisitor.cs @@ -12,6 +12,8 @@ public class NullCheckRemovingExpressionVisitor : ExpressionVisitor { private readonly NullSafeAccessVerifyingExpressionVisitor _nullSafeAccessVerifyingExpressionVisitor = new NullSafeAccessVerifyingExpressionVisitor(); + private readonly NullConditionalRemovingExpressionVisitor _nullConditionalRemovingExpressionVisitor + = new NullConditionalRemovingExpressionVisitor(); public NullCheckRemovingExpressionVisitor() { @@ -19,7 +21,7 @@ public NullCheckRemovingExpressionVisitor() protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { - var test = conditionalExpression.Test; + var test = Visit(conditionalExpression.Test); if (test is BinaryExpression binaryTest && (binaryTest.NodeType == ExpressionType.Equal @@ -42,6 +44,15 @@ protected override Expression VisitConditional(ConditionalExpression conditional ? conditionalExpression.IfFalse : conditionalExpression.IfTrue; + // Unwrap nested nullConditional + if (caller is NullConditionalExpression nullConditionalCaller) + { + accessOperation = ReplacingExpressionVisitor.Replace( + _nullConditionalRemovingExpressionVisitor.Visit(nullConditionalCaller.AccessOperation), + nullConditionalCaller, + accessOperation); + } + if (_nullSafeAccessVerifyingExpressionVisitor.Verify(caller, accessOperation)) { return new NullConditionalExpression(caller, accessOperation); @@ -51,6 +62,19 @@ protected override Expression VisitConditional(ConditionalExpression conditional return base.VisitConditional(conditionalExpression); } + private class NullConditionalRemovingExpressionVisitor : ExpressionVisitor + { + public override Expression Visit(Expression expression) + { + if (expression is NullConditionalExpression nullConditionalExpression) + { + return Visit(nullConditionalExpression.AccessOperation); + } + + return base.Visit(expression); + } + } + private class NullSafeAccessVerifyingExpressionVisitor : ExpressionVisitor { private readonly ISet _nullSafeAccesses = new HashSet(ExpressionEqualityComparer.Instance); diff --git a/src/EFCore/Query/Pipeline/ReplacingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/ReplacingExpressionVisitor.cs index 784d2a200e9..3f0afb7d870 100644 --- a/src/EFCore/Query/Pipeline/ReplacingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/ReplacingExpressionVisitor.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Query.Internal; namespace Microsoft.EntityFrameworkCore.Query.Pipeline { @@ -15,14 +16,18 @@ public class ReplacingExpressionVisitor : ExpressionVisitor public static Expression Replace(Expression original, Expression replacement, Expression tree) { return new ReplacingExpressionVisitor( - new Dictionary { { original, replacement } }).Visit(tree); + new Dictionary(ExpressionEqualityComparer.Instance) + { + { original, replacement } + }).Visit(tree); } public static Expression Replace( Expression original1, Expression replacement1, Expression original2, Expression replacement2, Expression tree) { return new ReplacingExpressionVisitor( - new Dictionary { + new Dictionary(ExpressionEqualityComparer.Instance) + { { original1, replacement1 }, { original2, replacement2 } }).Visit(tree); diff --git a/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs index d0141b7ab81..db18104f9c4 100644 --- a/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs @@ -6278,5 +6278,41 @@ public virtual Task Include_multiple_collections_on_same_level(bool isAsync) }, assertOrder: true); } + + [ConditionalTheory(Skip = "Issue#16088")] + [MemberData(nameof(IsAsyncData))] + public virtual Task Null_check_removal_applied_recursively(bool isAsync) + { + return AssertQuery( + isAsync, + l1s => l1s.Where(l1 => + ((((l1.OneToOne_Optional_FK1 == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2) == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3) == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3) == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3.Name) == "L4 01")); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Null_check_different_structure_does_not_remove_null_checks(bool isAsync) + { + return AssertQuery( + isAsync, + l1s => l1s.Where(l1 => + (l1.OneToOne_Optional_FK1 == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2 == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3 == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3 == null + ? null + : l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3.Name) == "L4 01")); + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs index 7cfbecb8389..b33f0a5cc08 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs @@ -4698,6 +4698,44 @@ public override async Task Member_pushdown_with_multiple_collections(bool isAsyn @""); } + public override async Task Null_check_removal_applied_recursively(bool isAsync) + { + await base.Null_check_removal_applied_recursively(isAsync); + + AssertSql(" "); + } + + public override async Task Null_check_different_structure_does_not_remove_null_checks(bool isAsync) + { + await base.Null_check_different_structure_does_not_remove_null_checks(isAsync); + + AssertSql( + @"SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id] +FROM [LevelOne] AS [l] +LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id] +LEFT JOIN [LevelThree] AS [l1] ON [l0].[Id] = [l1].[Level2_Optional_Id] +LEFT JOIN [LevelFour] AS [l2] ON [l1].[Id] = [l2].[Level3_Optional_Id] +WHERE (CASE + WHEN [l0].[Id] IS NULL THEN NULL + ELSE CASE + WHEN [l1].[Id] IS NULL THEN NULL + ELSE CASE + WHEN [l2].[Id] IS NULL THEN NULL + ELSE [l2].[Name] + END + END +END = N'L4 01') AND CASE + WHEN [l0].[Id] IS NULL THEN NULL + ELSE CASE + WHEN [l1].[Id] IS NULL THEN NULL + ELSE CASE + WHEN [l2].[Id] IS NULL THEN NULL + ELSE [l2].[Name] + END + END +END IS NOT NULL"); + } + private void AssertSql(params string[] expected) { //Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index b5cb4080367..038de871108 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -6,7 +6,6 @@ using System.Collections.ObjectModel; using System.ComponentModel.DataAnnotations; using System.Data; -using System.Data.Common; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -1251,7 +1250,7 @@ public virtual void Repro3101_simple_coalesce1() { var query = from eVersion in ctx.Entities.Include(e => e.Children) join eRoot in ctx.Entities - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() select eRootJoined ?? eVersion; @@ -1271,7 +1270,7 @@ public virtual void Repro3101_simple_coalesce2() { var query = from eVersion in ctx.Entities join eRoot in ctx.Entities.Include(e => e.Children) - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() select eRootJoined ?? eVersion; @@ -1291,7 +1290,7 @@ public virtual void Repro3101_simple_coalesce3() { var query = from eVersion in ctx.Entities.Include(e => e.Children) join eRoot in ctx.Entities.Include(e => e.Children) - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() select eRootJoined ?? eVersion; @@ -1311,7 +1310,7 @@ public virtual void Repro3101_complex_coalesce1() { var query = from eVersion in ctx.Entities.Include(e => e.Children) join eRoot in ctx.Entities - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() select new @@ -1335,7 +1334,7 @@ public virtual void Repro3101_complex_coalesce2() { var query = from eVersion in ctx.Entities join eRoot in ctx.Entities.Include(e => e.Children) - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() select new @@ -1359,7 +1358,7 @@ public virtual void Repro3101_nested_coalesce1() { var query = from eVersion in ctx.Entities join eRoot in ctx.Entities.Include(e => e.Children) - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() // ReSharper disable once ConstantNullCoalescingCondition @@ -1384,7 +1383,7 @@ public virtual void Repro3101_nested_coalesce2() { var query = from eVersion in ctx.Entities.Include(e => e.Children) join eRoot in ctx.Entities - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() // ReSharper disable once ConstantNullCoalescingCondition @@ -1410,7 +1409,7 @@ public virtual void Repro3101_conditional() { var query = from eVersion in ctx.Entities.Include(e => e.Children) join eRoot in ctx.Entities - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() // ReSharper disable once MergeConditionalExpression @@ -1433,7 +1432,7 @@ public virtual void Repro3101_coalesce_tracking() { var query = from eVersion in ctx.Entities join eRoot in ctx.Entities - on eVersion.RootEntityId equals (int?)eRoot.Id + on eVersion.RootEntityId equals eRoot.Id into RootEntities from eRootJoined in RootEntities.DefaultIfEmpty() select new @@ -3813,8 +3812,8 @@ from m in grouping2.DefaultIfEmpty() }) .ToList(); - AssertSql( - @"SELECT [t].[Name] AS [MyKey], COUNT(*) + 5 AS [cnt] + AssertSql( + @"SELECT [t].[Name] AS [MyKey], COUNT(*) + 5 AS [cnt] FROM [Table] AS [t0] LEFT JOIN [Table] AS [t] ON [t0].[Id] = [t].[Id] LEFT JOIN [Table] AS [t1] ON [t0].[Id] = [t1].[Id] @@ -5661,7 +5660,7 @@ public class MyContext15684 : DbContext public DbSet Categories { get; set; } public DbSet Products { get; set; } - public MyContext15684(DbContextOptions options) : base(options) {} + public MyContext15684(DbContextOptions options) : base(options) { } protected override void OnModelCreating(ModelBuilder modelBuilder) { @@ -5710,7 +5709,149 @@ public enum CategoryStatus15684 Removed = 1, } - #endregion Bug15684 + #endregion + + #region Bug15204 + + private MemberInfo GetMemberInfo(Type type, string name) + { + return type.GetTypeInfo().GetProperty(name); + } + + [ConditionalFact] + public virtual void Null_check_removal_applied_recursively() + { + using (CreateDatabase15204()) + { + var userParam = Expression.Parameter(typeof(TBuilding15204), "s"); + var builderProperty = Expression.MakeMemberAccess(userParam, GetMemberInfo(typeof(TBuilding15204), "Builder")); + var cityProperty = Expression.MakeMemberAccess(builderProperty, GetMemberInfo(typeof(TBuilder15204), "City")); + var nameProperty = Expression.MakeMemberAccess(cityProperty, GetMemberInfo(typeof(TCity15204), "Name")); + + //{s => (IIF((IIF((s.Builder == null), null, s.Builder.City) == null), null, s.Builder.City.Name) == "Leeds")} + var selection = Expression.Lambda>( + Expression.Equal( + Expression.Condition( + Expression.Equal( + Expression.Condition( + Expression.Equal( + builderProperty, + Expression.Constant(null, typeof(TBuilder15204))), + Expression.Constant(null, typeof(TCity15204)), + cityProperty), + Expression.Constant(null, typeof(TCity15204))), + Expression.Constant(null, typeof(string)), + nameProperty), + Expression.Constant("Leeds", typeof(string))), + userParam); + + + using (var context = new MyContext15204(_options)) + { + var query = context.BuildingSet + .Where(selection) + .Include(a => a.Builder).ThenInclude(a => a.City) + .Include(a => a.Mandator).ToList(); + + Assert.True(query.Count == 1); + Assert.True(query.First().Builder.City.Name == "Leeds"); + Assert.True(query.First().LongName == "Two L2"); + + AssertSql( + @"SELECT [b].[Id], [b].[BuilderId], [b].[Identity], [b].[LongName], [b].[MandatorId], [b0].[Id], [b0].[CityId], [b0].[Name], [c].[Id], [c].[Name], [m].[Id], [m].[Identity], [m].[Name] +FROM [BuildingSet] AS [b] +INNER JOIN [Builder] AS [b0] ON [b].[BuilderId] = [b0].[Id] +INNER JOIN [City] AS [c] ON [b0].[CityId] = [c].[Id] +INNER JOIN [MandatorSet] AS [m] ON [b].[MandatorId] = [m].[Id] +WHERE ([c].[Name] = N'Leeds') AND [c].[Name] IS NOT NULL"); + } + } + } + + private SqlServerTestStore CreateDatabase15204() + => CreateTestStore( + () => new MyContext15204(_options), + context => + { + var london = new TCity15204 { Name = "London" }; + var sam = new TBuilder15204 { Name = "Sam", City = london }; + + context.MandatorSet.Add(new TMandator15204 + { + Identity = Guid.NewGuid(), + Name = "One", + Buildings = new List + { + new TBuilding15204 { Identity = Guid.NewGuid(), LongName = "One L1", Builder = sam }, + new TBuilding15204 { Identity = Guid.NewGuid(), LongName = "One L2", Builder = sam } + } + }); + context.MandatorSet.Add(new TMandator15204 + { + Identity = Guid.NewGuid(), + Name = "Two", + Buildings = new List + { + new TBuilding15204 { Identity = Guid.NewGuid(), LongName = "Two L1", + Builder = new TBuilder15204 { Name = "John", City = london }}, + new TBuilding15204 { Identity = Guid.NewGuid(), LongName = "Two L2", + Builder = new TBuilder15204 { Name = "Mark", City = new TCity15204 { Name = "Leeds" }}} + } + }); + + context.SaveChanges(); + + ClearLog(); + }); + + public class MyContext15204 : DbContext + { + public DbSet MandatorSet { get; set; } + public DbSet BuildingSet { get; set; } + public DbSet Builder { get; set; } + public DbSet City { get; set; } + + public MyContext15204(DbContextOptions options) : base(options) + { + ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking; + ChangeTracker.AutoDetectChangesEnabled = false; + } + } + + public class TBuilding15204 + { + public int Id { get; set; } + public Guid Identity { get; set; } + public string LongName { get; set; } + public int BuilderId { get; set; } + public TBuilder15204 Builder { get; set; } + public TMandator15204 Mandator { get; set; } + public int MandatorId { get; set; } + } + + public class TBuilder15204 + { + public int Id { get; set; } + public string Name { get; set; } + public int CityId { get; set; } + public TCity15204 City { get; set; } + } + + public class TCity15204 + { + public int Id { get; set; } + public string Name { get; set; } + } + + public class TMandator15204 + { + public int Id { get; set; } + public Guid Identity { get; set; } + public string Name { get; set; } + public virtual ICollection Buildings { get; set; } + } + + #endregion private DbContextOptions _options;