diff --git a/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs index 8d9c1f200c8..19579681e28 100644 --- a/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/NullSemanticsRewritingExpressionVisitor.cs @@ -2,25 +2,37 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Storage; namespace Microsoft.EntityFrameworkCore.Query.Internal { public class NullSemanticsRewritingExpressionVisitor : SqlExpressionVisitor { + private readonly bool _useRelationalNulls; private readonly ISqlExpressionFactory _sqlExpressionFactory; + private readonly List _nonNullableColumns = new List(); + private readonly IReadOnlyDictionary _parameterValues; private bool _isNullable; private bool _canOptimize; - private readonly List _nonNullableColumns = new List(); - public NullSemanticsRewritingExpressionVisitor(ISqlExpressionFactory sqlExpressionFactory) + public virtual bool CanCache { get; set; } + + public NullSemanticsRewritingExpressionVisitor( + bool useRelationalNulls, + ISqlExpressionFactory sqlExpressionFactory, + IReadOnlyDictionary parameterValues) { + _useRelationalNulls = useRelationalNulls; _sqlExpressionFactory = sqlExpressionFactory; + _parameterValues = parameterValues; _canOptimize = true; + CanCache = true; } protected override Expression VisitCase(CaseExpression caseExpression) @@ -111,56 +123,164 @@ protected override Expression VisitIn(InExpression inExpression) _canOptimize = false; _isNullable = false; var item = (SqlExpression)Visit(inExpression.Item); - var isNullable = _isNullable; + var itemNullable = _isNullable; _isNullable = false; var subquery = (SelectExpression)Visit(inExpression.Subquery); - isNullable |= _isNullable; + var subqueryNullable = _isNullable; + + if (inExpression.Values == null) + { + _isNullable |= itemNullable; + _canOptimize = canOptimize; + + return inExpression.Update(item, values: null, subquery); + } + _isNullable = false; - var values = (SqlExpression)Visit(inExpression.Values); - _isNullable |= isNullable; + SqlExpression inValues; + + // for relational null semantics just leave as is + if (_useRelationalNulls) + { + inValues = (SqlExpression)Visit(inExpression.Values); + _isNullable = _isNullable || itemNullable || subqueryNullable; + _canOptimize = canOptimize; + + return inExpression.Update(item, inValues, subquery); + } + + // for c# null semantics we need to remove nulls from Values and add IsNull/IsNotNull when necessary + var hasNullValue = default(bool); + (inValues, hasNullValue) = ProcessInExpressionValues(inExpression.Values); + _canOptimize = canOptimize; - return inExpression.Update(item, values, subquery); + // either values array is empty or only contains null + if (inValues == null) + { + _isNullable = false; + + // a IN () -> false + // non_nullable IN (NULL) -> false + // a NOT IN () -> true + // non_nullable NOT IN (NULL) -> true + // nullable IN (NULL) -> nullable IS NULL + // nullable NOT IN (NULL) -> nullable IS NOT NULL + return !hasNullValue || !itemNullable + ? (SqlExpression)_sqlExpressionFactory.Constant( + inExpression.IsNegated, + inExpression.TypeMapping) + : inExpression.IsNegated + ? _sqlExpressionFactory.IsNotNull(item) + : _sqlExpressionFactory.IsNull(item); + } + + _isNullable = _isNullable | itemNullable || subqueryNullable; + + if (!itemNullable + || (_canOptimize && !inExpression.IsNegated && !hasNullValue)) + { + // non_nullable IN (1, 2) -> non_nullable IN (1, 2) + // non_nullable IN (1, 2, NULL) -> non_nullable IN (1, 2) + // nullable IN (1, 2) -> nullable IN (1, 2) (optimized) + return inExpression.Update(item, inValues, subquery); + } + + // adding null comparison term to remove nulls completely from the resulting expression + _isNullable = false; + + // nullable IN (1, 2) -> nullable IN (1, 2) AND nullable IS NOT NULL (full) + // nullable IN (1, 2, NULL) -> nullable IN (1, 2) OR nullable IS NULL (full) + // nullable NOT IN (1, 2) -> nullable NOT IN (1, 2) OR nullable IS NULL (full) + // nullable NOT IN (1, 2, NULL) -> nullable NOT IN (1, 2) AND nullable IS NOT NULL (full) + return inExpression.IsNegated == hasNullValue + ? _sqlExpressionFactory.AndAlso( + inExpression.Update(item, inValues, subquery), + _sqlExpressionFactory.IsNotNull(item)) + : _sqlExpressionFactory.OrElse( + inExpression.Update(item, inValues, subquery), + _sqlExpressionFactory.IsNull(item)); } - protected override Expression VisitIntersect(IntersectExpression intersectExpression) + private (SqlExpression processedValues, bool hasNullValue) ProcessInExpressionValues(SqlExpression valuesExpression) { - var canOptimize = _canOptimize; - _canOptimize = false; - var source1 = (SelectExpression)Visit(intersectExpression.Source1); - var source2 = (SelectExpression)Visit(intersectExpression.Source2); - _canOptimize = canOptimize; + if (valuesExpression == null) + { + return (processedValues: null, hasNullValue: false); + } - return intersectExpression.Update(source1, source2); + var inValues = new List(); + var hasNullValue = false; + RelationalTypeMapping typeMapping = null; + + switch (valuesExpression) + { + case SqlConstantExpression sqlConstant: + { + CanCache = false; + typeMapping = sqlConstant.TypeMapping; + var values = (IEnumerable)sqlConstant.Value; + foreach (var value in values) + { + if (value == null) + { + hasNullValue = true; + continue; + } + + inValues.Add(value); + } + + break; + } + + case SqlParameterExpression sqlParameter: + { + CanCache = false; + typeMapping = sqlParameter.TypeMapping; + var values = (IEnumerable)_parameterValues[sqlParameter.Name]; + foreach (var value in values) + { + if (value == null) + { + hasNullValue = true; + continue; + } + + inValues.Add(value); + } + + break; + } + } + + var processedValues = inValues.Count > 0 + ? (SqlExpression)Visit(_sqlExpressionFactory.Constant(inValues, typeMapping)) + : null; + + return (processedValues, hasNullValue); } - protected override Expression VisitLike(LikeExpression likeExpression) + protected override Expression VisitInnerJoin(InnerJoinExpression innerJoinExpression) { var canOptimize = _canOptimize; _canOptimize = false; - _isNullable = false; - var newMatch = (SqlExpression)Visit(likeExpression.Match); - var isNullable = _isNullable; - _isNullable = false; - var newPattern = (SqlExpression)Visit(likeExpression.Pattern); - isNullable |= _isNullable; - _isNullable = false; - var newEscapeChar = (SqlExpression)Visit(likeExpression.EscapeChar); - _isNullable |= isNullable; + var newTable = (TableExpressionBase)Visit(innerJoinExpression.Table); + var newJoinPredicate = VisitJoinPredicate((SqlBinaryExpression)innerJoinExpression.JoinPredicate); _canOptimize = canOptimize; - return likeExpression.Update(newMatch, newPattern, newEscapeChar); + return innerJoinExpression.Update(newTable, newJoinPredicate); } - protected override Expression VisitInnerJoin(InnerJoinExpression innerJoinExpression) + protected override Expression VisitIntersect(IntersectExpression intersectExpression) { var canOptimize = _canOptimize; _canOptimize = false; - var newTable = (TableExpressionBase)Visit(innerJoinExpression.Table); - var newJoinPredicate = VisitJoinPredicate((SqlBinaryExpression)innerJoinExpression.JoinPredicate); + var source1 = (SelectExpression)Visit(intersectExpression.Source1); + var source2 = (SelectExpression)Visit(intersectExpression.Source2); _canOptimize = canOptimize; - return innerJoinExpression.Update(newTable, newJoinPredicate); + return intersectExpression.Update(source1, source2); } protected override Expression VisitLeftJoin(LeftJoinExpression leftJoinExpression) @@ -181,11 +301,17 @@ private SqlExpression VisitJoinPredicate(SqlBinaryExpression predicate) if (predicate.OperatorType == ExpressionType.Equal) { - var newLeft = (SqlExpression)Visit(predicate.Left); - var newRight = (SqlExpression)Visit(predicate.Right); + _isNullable = false; + var left = (SqlExpression)Visit(predicate.Left); + var leftNullable = _isNullable; + _isNullable = false; + var right = (SqlExpression)Visit(predicate.Right); + var rightNullable = _isNullable; + + var result = OptimizeComparison(predicate.Update(left, right), left, right, leftNullable, rightNullable, _canOptimize); _canOptimize = canOptimize; - return predicate.Update(newLeft, newRight); + return result; } if (predicate.OperatorType == ExpressionType.AndAlso) @@ -199,6 +325,24 @@ private SqlExpression VisitJoinPredicate(SqlBinaryExpression predicate) throw new InvalidOperationException("Unexpected join predicate shape: " + predicate); } + protected override Expression VisitLike(LikeExpression likeExpression) + { + var canOptimize = _canOptimize; + _canOptimize = false; + _isNullable = false; + var newMatch = (SqlExpression)Visit(likeExpression.Match); + var isNullable = _isNullable; + _isNullable = false; + var newPattern = (SqlExpression)Visit(likeExpression.Pattern); + isNullable |= _isNullable; + _isNullable = false; + var newEscapeChar = (SqlExpression)Visit(likeExpression.EscapeChar); + _isNullable |= isNullable; + _canOptimize = canOptimize; + + return likeExpression.Update(newMatch, newPattern, newEscapeChar); + } + protected override Expression VisitOrdering(OrderingExpression orderingExpression) { var expression = (SqlExpression)Visit(orderingExpression.Expression); @@ -330,9 +474,6 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres _isNullable = false; var canOptimize = _canOptimize; - // for SqlServer we could also allow optimize on children of ExpressionType.Equal - // because they get converted to CASE blocks anyway, but for other providers it's incorrect - // once/if null semantics optimizations are provider-specific we can enable it _canOptimize = _canOptimize && (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso || sqlBinaryExpression.OperatorType == ExpressionType.OrElse); @@ -370,169 +511,315 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres if (sqlBinaryExpression.OperatorType == ExpressionType.Equal || sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) { - var leftConstantNull = newLeft is SqlConstantExpression leftConstant && leftConstant.Value == null; - var rightConstantNull = newRight is SqlConstantExpression rightConstant && rightConstant.Value == null; - - // a == null -> a IS NULL - // a != null -> a IS NOT NULL - if (rightConstantNull) + var updated = sqlBinaryExpression.Update(newLeft, newRight); + + var optimized = OptimizeComparison( + updated, + newLeft, + newRight, + leftNullable, + rightNullable, + canOptimize); + + // we assume that NullSemantics rewrite is only needed (on the current level) if the optimization didn't make any changes. + // Reason is that optimization can/will change the nullability of the resulting expression and that is not tracked/stored, + // so we can no longer rely on nullabilities that we computer earlier (leftNullable, rightNullable) + // when performing null semantics rewrite. + // It should be fine because current optimizations *radically* change the expression (e.g. binary -> unary, or binary -> constant) + // but we need to pay attention in the future if we introduce more subtle transformations here + if (optimized == updated + && (leftNullable || rightNullable) + && !_useRelationalNulls) { - _isNullable = false; + var result = RewriteNullSemantics( + updated, + updated.Left, + updated.Right, + leftNullable, + rightNullable, + canOptimize); + _canOptimize = canOptimize; - return sqlBinaryExpression.OperatorType == ExpressionType.Equal - ? _sqlExpressionFactory.IsNull(newLeft) - : _sqlExpressionFactory.IsNotNull(newLeft); + return result; } - // null == a -> a IS NULL - // null != a -> a IS NOT NULL - if (leftConstantNull) - { - _isNullable = false; - _canOptimize = canOptimize; + _canOptimize = canOptimize; - return sqlBinaryExpression.OperatorType == ExpressionType.Equal - ? _sqlExpressionFactory.IsNull(newRight) - : _sqlExpressionFactory.IsNotNull(newRight); - } + return optimized; + } + + _isNullable = leftNullable || rightNullable; + _canOptimize = canOptimize; + + return sqlBinaryExpression.Update(newLeft, newRight); + } + + private SqlExpression OptimizeComparison( + SqlBinaryExpression sqlBinaryExpression, + SqlExpression left, + SqlExpression right, + bool leftNullable, + bool rightNullable, + bool canOptimize) + { + var leftNullValue = leftNullable && (left is SqlConstantExpression || left is SqlParameterExpression); + var rightNullValue = rightNullable && (right is SqlConstantExpression || right is SqlParameterExpression); + + // a == null -> a IS NULL + // a != null -> a IS NOT NULL + if (rightNullValue) + { + var result = sqlBinaryExpression.OperatorType == ExpressionType.Equal + ? ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), left, leftNullable) + : ProcessNullNotNull(_sqlExpressionFactory.IsNotNull(left), left, leftNullable); + + _isNullable = false; + _canOptimize = canOptimize; + + return result; + } + + // null == a -> a IS NULL + // null != a -> a IS NOT NULL + if (leftNullValue) + { + var result = sqlBinaryExpression.OperatorType == ExpressionType.Equal + ? ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), right, rightNullable) + : ProcessNullNotNull(_sqlExpressionFactory.IsNotNull(right), right, rightNullable); + + _isNullable = false; + _canOptimize = canOptimize; + + return result; + } - var leftUnary = newLeft as SqlUnaryExpression; - var rightUnary = newRight as SqlUnaryExpression; + if (IsTrueOrFalse(right) is bool rightTrueFalseValue + && !leftNullable) + { + _isNullable = leftNullable; + _canOptimize = canOptimize; + + // only correct in 2-value logic + // a == true -> a + // a == false -> !a + // a != true -> !a + // a != false -> a + return sqlBinaryExpression.OperatorType == ExpressionType.Equal + ? rightTrueFalseValue + ? left + : _sqlExpressionFactory.Not(left) + : rightTrueFalseValue + ? _sqlExpressionFactory.Not(left) + : left; + } + + if (IsTrueOrFalse(left) is bool leftTrueFalseValue + && !rightNullable) + { + _isNullable = rightNullable; + _canOptimize = canOptimize; + + // only correct in 2-value logic + // true == a -> a + // false == a -> !a + // true != a -> !a + // false != a -> a + return sqlBinaryExpression.OperatorType == ExpressionType.Equal + ? leftTrueFalseValue + ? right + : _sqlExpressionFactory.Not(right) + : leftTrueFalseValue + ? _sqlExpressionFactory.Not(right) + : right; + } + + // only correct in 2-value logic + // a == a -> true + // a != a -> false + if (!leftNullable + && left.Equals(right)) + { + _isNullable = false; + _canOptimize = canOptimize; + + return _sqlExpressionFactory.Constant( + sqlBinaryExpression.OperatorType == ExpressionType.Equal, + sqlBinaryExpression.TypeMapping); + } + + if (!leftNullable + && !rightNullable + && (sqlBinaryExpression.OperatorType == ExpressionType.Equal || sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)) + { + var leftUnary = left as SqlUnaryExpression; + var rightUnary = right as SqlUnaryExpression; var leftNegated = leftUnary?.IsLogicalNot() == true; var rightNegated = rightUnary?.IsLogicalNot() == true; if (leftNegated) { - newLeft = leftUnary.Operand; + left = leftUnary.Operand; } if (rightNegated) { - newRight = rightUnary.Operand; + right = rightUnary.Operand; } - var leftIsNull = _sqlExpressionFactory.IsNull(newLeft); - var rightIsNull = _sqlExpressionFactory.IsNull(newRight); + // a == b <=> !a == !b -> a == b + // !a == b <=> a == !b -> a != b + // a != b <=> !a != !b -> a != b + // !a != b <=> a != !b -> a == b + return sqlBinaryExpression.OperatorType == ExpressionType.Equal + ? leftNegated == rightNegated + ? _sqlExpressionFactory.Equal(left, right) + : _sqlExpressionFactory.NotEqual(left, right) + : leftNegated == rightNegated + ? _sqlExpressionFactory.NotEqual(left, right) + : _sqlExpressionFactory.Equal(left, right); + } + + return sqlBinaryExpression.Update(left, right); - // optimized expansion which doesn't distinguish between null and false - if (canOptimize - && sqlBinaryExpression.OperatorType == ExpressionType.Equal - && !leftNegated - && !rightNegated) + bool? IsTrueOrFalse(SqlExpression sqlExpression) + { + if (sqlExpression is SqlConstantExpression sqlConstantExpression && sqlConstantExpression.Value is bool boolConstant) { - // when we use optimized form, the result can still be nullable - if (leftNullable && rightNullable) - { - _isNullable = true; - _canOptimize = canOptimize; + return boolConstant; + } - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.Equal(newLeft, newRight), - _sqlExpressionFactory.AndAlso(leftIsNull, rightIsNull)); - } + // TODO: should we do this for parameters also? + //if (sqlExpression is SqlParameterExpression sqlParameterExpression && _parameterValues[sqlParameterExpression.Name] is bool boolParameter) + //{ + // return boolParameter; + //} - if ((leftNullable && !rightNullable) - || (!leftNullable && rightNullable)) - { - _isNullable = true; - _canOptimize = canOptimize; + return null; + } + } - return _sqlExpressionFactory.Equal(newLeft, newRight); - } - } + private SqlExpression RewriteNullSemantics( + SqlBinaryExpression sqlBinaryExpression, + SqlExpression left, + SqlExpression right, + bool leftNullable, + bool rightNullable, + bool canOptimize) + { + var leftUnary = left as SqlUnaryExpression; + var rightUnary = right as SqlUnaryExpression; - // doing a full null semantics rewrite - removing all nulls from truth table - // this will NOT be correct once we introduce simplified null semantics - _isNullable = false; - _canOptimize = canOptimize; + var leftNegated = leftUnary?.IsLogicalNot() == true; + var rightNegated = rightUnary?.IsLogicalNot() == true; - if (sqlBinaryExpression.OperatorType == ExpressionType.Equal) - { - if (!leftNullable - && !rightNullable) - { - // a == b <=> !a == !b -> a == b - // !a == b <=> a == !b -> a != b - return leftNegated == rightNegated - ? _sqlExpressionFactory.Equal(newLeft, newRight) - : _sqlExpressionFactory.NotEqual(newLeft, newRight); - } + if (leftNegated) + { + left = leftUnary.Operand; + } - if (leftNullable && rightNullable) - { - // ?a == ?b <=> !(?a) == !(?b) -> [(a == b) && (a != null && b != null)] || (a == null && b == null)) - // !(?a) == ?b <=> ?a == !(?b) -> [(a != b) && (a != null && b != null)] || (a == null && b == null) - return leftNegated == rightNegated - ? ExpandNullableEqualNullable(newLeft, newRight, leftIsNull, rightIsNull) - : ExpandNegatedNullableEqualNullable(newLeft, newRight, leftIsNull, rightIsNull); - } + if (rightNegated) + { + right = rightUnary.Operand; + } - if (leftNullable && !rightNullable) - { - // ?a == b <=> !(?a) == !b -> (a == b) && (a != null) - // !(?a) == b <=> ?a == !b -> (a != b) && (a != null) - return leftNegated == rightNegated - ? ExpandNullableEqualNonNullable(newLeft, newRight, leftIsNull) - : ExpandNegatedNullableEqualNonNullable(newLeft, newRight, leftIsNull); - } + var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), left, leftNullable); + var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), right, rightNullable); - if (rightNullable && !leftNullable) - { - // a == ?b <=> !a == !(?b) -> (a == b) && (b != null) - // !a == ?b <=> a == !(?b) -> (a != b) && (b != null) - return leftNegated == rightNegated - ? ExpandNullableEqualNonNullable(newLeft, newRight, rightIsNull) - : ExpandNegatedNullableEqualNonNullable(newLeft, newRight, rightIsNull); - } + // optimized expansion which doesn't distinguish between null and false + if (canOptimize + && sqlBinaryExpression.OperatorType == ExpressionType.Equal + && !leftNegated + && !rightNegated) + { + // when we use optimized form, the result can still be nullable + if (leftNullable && rightNullable) + { + _isNullable = true; + _canOptimize = canOptimize; + + return _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.Equal(left, right), + _sqlExpressionFactory.AndAlso(leftIsNull, rightIsNull)); } - if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + if ((leftNullable && !rightNullable) + || (!leftNullable && rightNullable)) { - if (!leftNullable - && !rightNullable) - { - // a != b <=> !a != !b -> a != b - // !a != b <=> a != !b -> a == b - return leftNegated == rightNegated - ? _sqlExpressionFactory.NotEqual(newLeft, newRight) - : _sqlExpressionFactory.Equal(newLeft, newRight); - } + _isNullable = true; + _canOptimize = canOptimize; - if (leftNullable && rightNullable) - { - // ?a != ?b <=> !(?a) != !(?b) -> [(a != b) || (a == null || b == null)] && (a != null || b != null) - // !(?a) != ?b <=> ?a != !(?b) -> [(a == b) || (a == null || b == null)] && (a != null || b != null) - return leftNegated == rightNegated - ? ExpandNullableNotEqualNullable(newLeft, newRight, leftIsNull, rightIsNull) - : ExpandNegatedNullableNotEqualNullable(newLeft, newRight, leftIsNull, rightIsNull); - } + return _sqlExpressionFactory.Equal(left, right); + } + } - if (leftNullable) - { - // ?a != b <=> !(?a) != !b -> (a != b) || (a == null) - // !(?a) != b <=> ?a != !b -> (a == b) || (a == null) - return leftNegated == rightNegated - ? ExpandNullableNotEqualNonNullable(newLeft, newRight, leftIsNull) - : ExpandNegatedNullableNotEqualNonNullable(newLeft, newRight, leftIsNull); - } + // doing a full null semantics rewrite - removing all nulls from truth table + // this will NOT be correct once we introduce simplified null semantics + _isNullable = false; + _canOptimize = canOptimize; - if (rightNullable) - { - // a != ?b <=> !a != !(?b) -> (a != b) || (b == null) - // !a != ?b <=> a != !(?b) -> (a == b) || (b == null) - return leftNegated == rightNegated - ? ExpandNullableNotEqualNonNullable(newLeft, newRight, rightIsNull) - : ExpandNegatedNullableNotEqualNonNullable(newLeft, newRight, rightIsNull); - } + if (sqlBinaryExpression.OperatorType == ExpressionType.Equal) + { + if (leftNullable && rightNullable) + { + // ?a == ?b <=> !(?a) == !(?b) -> [(a == b) && (a != null && b != null)] || (a == null && b == null)) + // !(?a) == ?b <=> ?a == !(?b) -> [(a != b) && (a != null && b != null)] || (a == null && b == null) + return leftNegated == rightNegated + ? ExpandNullableEqualNullable(left, right, leftIsNull, rightIsNull) + : ExpandNegatedNullableEqualNullable(left, right, leftIsNull, rightIsNull); + } + + if (leftNullable && !rightNullable) + { + // ?a == b <=> !(?a) == !b -> (a == b) && (a != null) + // !(?a) == b <=> ?a == !b -> (a != b) && (a != null) + return leftNegated == rightNegated + ? ExpandNullableEqualNonNullable(left, right, leftIsNull) + : ExpandNegatedNullableEqualNonNullable(left, right, leftIsNull); + } + + if (rightNullable && !leftNullable) + { + // a == ?b <=> !a == !(?b) -> (a == b) && (b != null) + // !a == ?b <=> a == !(?b) -> (a != b) && (b != null) + return leftNegated == rightNegated + ? ExpandNullableEqualNonNullable(left, right, rightIsNull) + : ExpandNegatedNullableEqualNonNullable(left, right, rightIsNull); } } - _isNullable = leftNullable || rightNullable; - _canOptimize = canOptimize; + if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + { + if (leftNullable && rightNullable) + { + // ?a != ?b <=> !(?a) != !(?b) -> [(a != b) || (a == null || b == null)] && (a != null || b != null) + // !(?a) != ?b <=> ?a != !(?b) -> [(a == b) || (a == null || b == null)] && (a != null || b != null) + return leftNegated == rightNegated + ? ExpandNullableNotEqualNullable(left, right, leftIsNull, rightIsNull) + : ExpandNegatedNullableNotEqualNullable(left, right, leftIsNull, rightIsNull); + } - return sqlBinaryExpression.Update(newLeft, newRight); + if (leftNullable) + { + // ?a != b <=> !(?a) != !b -> (a != b) || (a == null) + // !(?a) != b <=> ?a != !b -> (a == b) || (a == null) + return leftNegated == rightNegated + ? ExpandNullableNotEqualNonNullable(left, right, leftIsNull) + : ExpandNegatedNullableNotEqualNonNullable(left, right, leftIsNull); + } + + if (rightNullable) + { + // a != ?b <=> !a != !(?b) -> (a != b) || (b == null) + // !a != ?b <=> a != !(?b) -> (a == b) || (b == null) + return leftNegated == rightNegated + ? ExpandNullableNotEqualNonNullable(left, right, rightIsNull) + : ExpandNegatedNullableNotEqualNonNullable(left, right, rightIsNull); + } + } + + return sqlBinaryExpression.Update(left, right); } protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression) @@ -567,31 +854,206 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction protected override Expression VisitSqlParameter(SqlParameterExpression sqlParameterExpression) { - // at this point we assume every parameter is nullable, we will filter out the non-nullable ones once we know the actual values - _isNullable = true; + _isNullable = _parameterValues[sqlParameterExpression.Name] == null; return sqlParameterExpression; } - protected override Expression VisitSqlUnary(SqlUnaryExpression sqlCastExpression) + protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpression) { _isNullable = false; - var canOptimize = _canOptimize; _canOptimize = false; - var newOperand = (SqlExpression)Visit(sqlCastExpression.Operand); + var operand = (SqlExpression)Visit(sqlUnaryExpression.Operand); + + _canOptimize = canOptimize; - // result of IsNull/IsNotNull can never be null - if (sqlCastExpression.OperatorType == ExpressionType.Equal - || sqlCastExpression.OperatorType == ExpressionType.NotEqual) + if (sqlUnaryExpression.OperatorType == ExpressionType.Equal + || sqlUnaryExpression.OperatorType == ExpressionType.NotEqual) { + // result of IsNull/IsNotNull can never be null + var isNullable = _isNullable; _isNullable = false; + + return ProcessNullNotNull(sqlUnaryExpression, operand, isNullable); } - _canOptimize = canOptimize; + if (operand is SqlBinaryExpression sqlBinaryOperand) + { + // only correct in 2-value logic + // !(a == b) -> a != b + // !(a != b) -> a == b + // !(a > b) -> a <= b + // !(a >= b) -> a < b + // !(a < b) -> a >= b + // !(a <= b) -> a > b + if (!_isNullable + && TryNegate(sqlBinaryOperand.OperatorType, out var negated)) + { + return _sqlExpressionFactory.MakeBinary( + negated, + sqlBinaryOperand.Left, + sqlBinaryOperand.Right, + sqlBinaryOperand.TypeMapping); + } + } + + return sqlUnaryExpression.Update(operand); + + static bool TryNegate(ExpressionType expressionType, out ExpressionType result) + { + var negated = expressionType switch + { + ExpressionType.Equal => ExpressionType.NotEqual, + ExpressionType.NotEqual => ExpressionType.Equal, + ExpressionType.GreaterThan => ExpressionType.LessThanOrEqual, + ExpressionType.GreaterThanOrEqual => ExpressionType.LessThan, + ExpressionType.LessThan => ExpressionType.GreaterThanOrEqual, + ExpressionType.LessThanOrEqual => ExpressionType.GreaterThan, + _ => (ExpressionType?)null + }; + + result = negated ?? default; + + return negated.HasValue; + } + } + + private SqlExpression ProcessNullNotNull( + SqlUnaryExpression sqlUnaryExpression, + SqlExpression operand, + bool? operandNullable) + { + if (operandNullable == false) + { + // when we know that operand is non-nullable: + // not_null_operand is null-> false + // not_null_operand is not null -> true + return _sqlExpressionFactory.Constant( + sqlUnaryExpression.OperatorType == ExpressionType.NotEqual, + sqlUnaryExpression.TypeMapping); + } + + switch (operand) + { + case SqlConstantExpression sqlConstantOperand: + // null_value_constant is null -> true + // null_value_constant is not null -> false + // not_null_value_constant is null -> false + // not_null_value_constant is not null -> true + return sqlConstantOperand.Value == null + ? _sqlExpressionFactory.Constant( + sqlUnaryExpression.OperatorType == ExpressionType.Equal, + sqlUnaryExpression.TypeMapping) + : _sqlExpressionFactory.Constant( + sqlUnaryExpression.OperatorType == ExpressionType.NotEqual, + sqlUnaryExpression.TypeMapping); + + case SqlParameterExpression sqlParameterOperand: + // null_value_parameter is null -> true + // null_value_parameter is not null -> false + // not_null_value_parameter is null -> false + // not_null_value_parameter is not null -> true + return _parameterValues[sqlParameterOperand.Name] == null + ? _sqlExpressionFactory.Constant( + sqlUnaryExpression.OperatorType == ExpressionType.Equal, + sqlUnaryExpression.TypeMapping) + : _sqlExpressionFactory.Constant( + sqlUnaryExpression.OperatorType == ExpressionType.NotEqual, + sqlUnaryExpression.TypeMapping); + + case ColumnExpression columnOperand + when !columnOperand.IsNullable || _nonNullableColumns.Contains(columnOperand): + { + // IsNull(non_nullable_column) -> false + // IsNotNull(non_nullable_column) -> true + return _sqlExpressionFactory.Constant( + sqlUnaryExpression.OperatorType == ExpressionType.NotEqual, + sqlUnaryExpression.TypeMapping); + } + + case SqlUnaryExpression sqlUnaryOperand: + switch (sqlUnaryOperand.OperatorType) + { + case ExpressionType.Convert: + case ExpressionType.Not: + case ExpressionType.Negate: + // op(a) is null -> a is null + // op(a) is not null -> a is not null + return ProcessNullNotNull( + _sqlExpressionFactory.MakeUnary( + sqlUnaryExpression.OperatorType, + sqlUnaryOperand.Operand, + sqlUnaryExpression.Type, + sqlUnaryExpression.TypeMapping), + sqlUnaryOperand.Operand, + operandNullable); + + case ExpressionType.Equal: + case ExpressionType.NotEqual: + // (a is null) is null -> false + // (a is not null) is null -> false + // (a is null) is not null -> true + // (a is not null) is not null -> true + return _sqlExpressionFactory.Constant( + sqlUnaryOperand.OperatorType == ExpressionType.NotEqual, + sqlUnaryOperand.TypeMapping); + } + break; + + case SqlBinaryExpression sqlBinaryOperand + when sqlBinaryOperand.OperatorType != ExpressionType.AndAlso + && sqlBinaryOperand.OperatorType != ExpressionType.OrElse: + { + // in general: + // binaryOp(a, b) == null -> a == null || b == null + // binaryOp(a, b) != null -> a != null && b != null + // for coalesce: + // (a ?? b) == null -> a == null && b == null + // (a ?? b) != null -> a != null || b != null + // for AndAlso, OrElse we can't do this optimization + // we could do something like this, but it seems too complicated: + // (a && b) == null -> a == null && b != 0 || a != 0 && b == null + // NOTE: we don't preserve nullabilities of left/right individually so we are using nullability binary expression as a whole + // this may lead to missing some optimizations, where one of the operands (left or right) is not nullable and the other one is + var left = ProcessNullNotNull( + _sqlExpressionFactory.MakeUnary( + sqlUnaryExpression.OperatorType, + sqlBinaryOperand.Left, + typeof(bool), + sqlUnaryExpression.TypeMapping), + sqlBinaryOperand.Left, + operandNullable: null); + + var right = ProcessNullNotNull( + _sqlExpressionFactory.MakeUnary( + sqlUnaryExpression.OperatorType, + sqlBinaryOperand.Right, + typeof(bool), + sqlUnaryExpression.TypeMapping), + sqlBinaryOperand.Right, + operandNullable: null); + + return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce + ? _sqlExpressionFactory.MakeBinary( + sqlUnaryExpression.OperatorType == ExpressionType.Equal + ? ExpressionType.AndAlso + : ExpressionType.OrElse, + left, + right, + sqlUnaryExpression.TypeMapping) + : _sqlExpressionFactory.MakeBinary( + sqlUnaryExpression.OperatorType == ExpressionType.Equal + ? ExpressionType.OrElse + : ExpressionType.AndAlso, + left, + right, + sqlUnaryExpression.TypeMapping); + } + } - return sqlCastExpression.Update(newOperand); + return sqlUnaryExpression.Update(operand); } protected override Expression VisitTable(TableExpression tableExpression) diff --git a/src/EFCore.Relational/Query/Internal/SqlExpressionOptimizingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/SqlExpressionOptimizingExpressionVisitor.cs index 5e72b69104c..54f8a798b35 100644 --- a/src/EFCore.Relational/Query/Internal/SqlExpressionOptimizingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/SqlExpressionOptimizingExpressionVisitor.cs @@ -148,7 +148,11 @@ private SqlExpression SimplifyUnaryExpression( break; - case SqlBinaryExpression binaryOperand: + // these optimizations are only valid in 2-value logic + // NullSemantics removes all nulls from expressions wrapped around Not + // so the optimizations are safe to do as long as UseRelationalNulls = false + case SqlBinaryExpression binaryOperand + when !_useRelationalNulls: { // De Morgan's if (binaryOperand.OperatorType == ExpressionType.AndAlso @@ -166,11 +170,7 @@ private SqlExpression SimplifyUnaryExpression( binaryOperand.TypeMapping); } - // those optimizations are only valid in 2-value logic - // they are safe to do here because if we apply null semantics - // because null semantics removes possibility of nulls in the tree when the comparison is wrapped around NOT - if (!_useRelationalNulls - && TryNegate(binaryOperand.OperatorType, out var negated)) + if (TryNegate(binaryOperand.OperatorType, out var negated)) { return SimplifyBinaryExpression( negated, @@ -183,98 +183,6 @@ private SqlExpression SimplifyUnaryExpression( } break; } - - case ExpressionType.Equal: - case ExpressionType.NotEqual: - return SimplifyNullNotNullExpression( - operatorType, - operand, - type, - typeMapping); - } - - return SqlExpressionFactory.MakeUnary(operatorType, operand, type, typeMapping); - } - - private SqlExpression SimplifyNullNotNullExpression( - ExpressionType operatorType, - SqlExpression operand, - Type type, - RelationalTypeMapping typeMapping) - { - switch (operatorType) - { - case ExpressionType.Equal: - case ExpressionType.NotEqual: - switch (operand) - { - case SqlConstantExpression constantOperand: - return SqlExpressionFactory.Constant( - operatorType == ExpressionType.Equal - ? constantOperand.Value == null - : constantOperand.Value != null, - typeMapping); - - case ColumnExpression columnOperand - when !columnOperand.IsNullable: - return SqlExpressionFactory.Constant(operatorType == ExpressionType.NotEqual, typeMapping); - - case SqlUnaryExpression sqlUnaryOperand: - if (sqlUnaryOperand.OperatorType == ExpressionType.Convert - || sqlUnaryOperand.OperatorType == ExpressionType.Not - || sqlUnaryOperand.OperatorType == ExpressionType.Negate) - { - // op(a) is null -> a is null - // op(a) is not null -> a is not null - return SimplifyNullNotNullExpression(operatorType, sqlUnaryOperand.Operand, type, typeMapping); - } - - if (sqlUnaryOperand.OperatorType == ExpressionType.Equal - || sqlUnaryOperand.OperatorType == ExpressionType.NotEqual) - { - // (a is null) is null -> false - // (a is not null) is null -> false - // (a is null) is not null -> true - // (a is not null) is not null -> true - return SqlExpressionFactory.Constant(operatorType == ExpressionType.NotEqual, typeMapping); - } - break; - - case SqlBinaryExpression sqlBinaryOperand: - // in general: - // binaryOp(a, b) == null -> a == null || b == null - // binaryOp(a, b) != null -> a != null && b != null - // for coalesce: - // (a ?? b) == null -> a == null && b == null - // (a ?? b) != null -> a != null || b != null - // for AndAlso, OrElse we can't do this optimization - // we could do something like this, but it seems too complicated: - // (a && b) == null -> a == null && b != 0 || a != 0 && b == null - if (sqlBinaryOperand.OperatorType != ExpressionType.AndAlso - && sqlBinaryOperand.OperatorType != ExpressionType.OrElse) - { - var newLeft = SimplifyNullNotNullExpression(operatorType, sqlBinaryOperand.Left, typeof(bool), typeMapping); - var newRight = SimplifyNullNotNullExpression(operatorType, sqlBinaryOperand.Right, typeof(bool), typeMapping); - - return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce - ? SimplifyLogicalSqlBinaryExpression( - operatorType == ExpressionType.Equal - ? ExpressionType.AndAlso - : ExpressionType.OrElse, - newLeft, - newRight, - typeMapping) - : SimplifyLogicalSqlBinaryExpression( - operatorType == ExpressionType.Equal - ? ExpressionType.OrElse - : ExpressionType.AndAlso, - newLeft, - newRight, - typeMapping); - } - break; - } - break; } return SqlExpressionFactory.MakeUnary(operatorType, operand, type, typeMapping); @@ -326,143 +234,11 @@ private SqlExpression SimplifyBinaryExpression( left, right, typeMapping); - - case ExpressionType.Equal: - case ExpressionType.NotEqual: - var leftConstant = left as SqlConstantExpression; - var rightConstant = right as SqlConstantExpression; - var leftNullConstant = leftConstant != null && leftConstant.Value == null; - var rightNullConstant = rightConstant != null && rightConstant.Value == null; - if (leftNullConstant || rightNullConstant) - { - return SimplifyNullComparisonExpression( - operatorType, - left, - right, - leftNullConstant, - rightNullConstant, - typeMapping); - } - - var leftBoolConstant = left.Type == typeof(bool) ? leftConstant : null; - var rightBoolConstant = right.Type == typeof(bool) ? rightConstant : null; - if (leftBoolConstant != null || rightBoolConstant != null) - { - return SimplifyBoolConstantComparisonExpression( - operatorType, - left, - right, - leftBoolConstant, - rightBoolConstant, - typeMapping); - } - - // only works when a is not nullable - // a == a -> true - // a != a -> false - if ((left is LikeExpression - || left is ColumnExpression columnExpression && !columnExpression.IsNullable) - && left.Equals(right)) - { - return SqlExpressionFactory.Constant(operatorType == ExpressionType.Equal, typeMapping); - } - - break; } return SqlExpressionFactory.MakeBinary(operatorType, left, right, typeMapping); } - protected virtual SqlExpression SimplifyNullComparisonExpression( - ExpressionType operatorType, - SqlExpression left, - SqlExpression right, - bool leftNull, - bool rightNull, - RelationalTypeMapping typeMapping) - { - if ((operatorType == ExpressionType.Equal || operatorType == ExpressionType.NotEqual) - && (leftNull || rightNull)) - { - if (leftNull && rightNull) - { - return SqlExpressionFactory.Constant(operatorType == ExpressionType.Equal, typeMapping); - } - - if (leftNull) - { - return SimplifyNullNotNullExpression(operatorType, right, typeof(bool), typeMapping); - } - - if (rightNull) - { - return SimplifyNullNotNullExpression(operatorType, left, typeof(bool), typeMapping); - } - } - - return SqlExpressionFactory.MakeBinary(operatorType, left, right, typeMapping); - } - - private SqlExpression SimplifyBoolConstantComparisonExpression( - ExpressionType operatorType, - SqlExpression left, - SqlExpression right, - SqlConstantExpression leftBoolConstant, - SqlConstantExpression rightBoolConstant, - RelationalTypeMapping typeMapping) - { - if (leftBoolConstant != null && rightBoolConstant != null) - { - return operatorType == ExpressionType.Equal - ? SqlExpressionFactory.Constant((bool)leftBoolConstant.Value == (bool)rightBoolConstant.Value, typeMapping) - : SqlExpressionFactory.Constant((bool)leftBoolConstant.Value != (bool)rightBoolConstant.Value, typeMapping); - } - - if (rightBoolConstant != null - && CanOptimize(left)) - { - // a == true -> a - // a == false -> !a - // a != true -> !a - // a != false -> a - // only correct when f(x) can't be null - return operatorType == ExpressionType.Equal - ? (bool)rightBoolConstant.Value - ? left - : SimplifyUnaryExpression(ExpressionType.Not, left, typeof(bool), typeMapping) - : (bool)rightBoolConstant.Value - ? SimplifyUnaryExpression(ExpressionType.Not, left, typeof(bool), typeMapping) - : left; - } - - if (leftBoolConstant != null - && CanOptimize(right)) - { - // true == a -> a - // false == a -> !a - // true != a -> !a - // false != a -> a - // only correct when a can't be null - return operatorType == ExpressionType.Equal - ? (bool)leftBoolConstant.Value - ? right - : SimplifyUnaryExpression(ExpressionType.Not, right, typeof(bool), typeMapping) - : (bool)leftBoolConstant.Value - ? SimplifyUnaryExpression(ExpressionType.Not, right, typeof(bool), typeMapping) - : right; - } - - return SqlExpressionFactory.MakeBinary(operatorType, left, right, typeMapping); - - static bool CanOptimize(SqlExpression operand) - => operand is LikeExpression - || (operand is SqlUnaryExpression sqlUnary - && (sqlUnary.OperatorType == ExpressionType.Equal - || sqlUnary.OperatorType == ExpressionType.NotEqual - // TODO: #18689 - /*|| sqlUnary.OperatorType == ExpressionType.Not*/)); - } - private SqlExpression SimplifyLogicalSqlBinaryExpression( ExpressionType operatorType, SqlExpression left, diff --git a/src/EFCore.Relational/Query/RelationalParameterBasedQueryTranslationPostprocessor.cs b/src/EFCore.Relational/Query/RelationalParameterBasedQueryTranslationPostprocessor.cs index 95c7b171976..517a51f53f4 100644 --- a/src/EFCore.Relational/Query/RelationalParameterBasedQueryTranslationPostprocessor.cs +++ b/src/EFCore.Relational/Query/RelationalParameterBasedQueryTranslationPostprocessor.cs @@ -38,17 +38,19 @@ public virtual (SelectExpression selectExpression, bool canCache) Optimize( SelectExpression selectExpression, IReadOnlyDictionary parametersValues) { var canCache = true; + var nullSemanticsRewritingExpressionVisitor = new NullSemanticsRewritingExpressionVisitor( + UseRelationalNulls, + Dependencies.SqlExpressionFactory, + parametersValues); - var inExpressionOptimized = new InExpressionValuesExpandingExpressionVisitor( - Dependencies.SqlExpressionFactory, parametersValues).Visit(selectExpression); - - if (!ReferenceEquals(selectExpression, inExpressionOptimized)) + var nullSemanticsOptimized = nullSemanticsRewritingExpressionVisitor.Visit(selectExpression); + if (!nullSemanticsRewritingExpressionVisitor.CanCache) { canCache = false; } var nullParametersOptimized = new ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor( - Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(inExpressionOptimized); + Dependencies.SqlExpressionFactory, UseRelationalNulls, parametersValues).Visit(nullSemanticsOptimized); var fromSqlParameterOptimized = new FromSqlParameterApplyingExpressionVisitor( Dependencies.SqlExpressionFactory, @@ -75,149 +77,6 @@ public ParameterNullabilityBasedSqlExpressionOptimizingExpressionVisitor( { _parametersValues = parametersValues; } - - protected override Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression) - { - var result = base.VisitSqlUnaryExpression(sqlUnaryExpression); - if (result is SqlUnaryExpression newUnaryExpression - && newUnaryExpression.Operand is SqlParameterExpression parameterOperand) - { - var parameterValue = _parametersValues[parameterOperand.Name]; - if (sqlUnaryExpression.OperatorType == ExpressionType.Equal) - { - return SqlExpressionFactory.Constant(parameterValue == null, sqlUnaryExpression.TypeMapping); - } - - if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual) - { - return SqlExpressionFactory.Constant(parameterValue != null, sqlUnaryExpression.TypeMapping); - } - } - - return result; - } - - protected override Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression) - { - var result = base.VisitSqlBinaryExpression(sqlBinaryExpression); - if (result is SqlBinaryExpression sqlBinaryResult) - { - var leftNullParameter = sqlBinaryResult.Left is SqlParameterExpression leftParameter - && _parametersValues[leftParameter.Name] == null; - - var rightNullParameter = sqlBinaryResult.Right is SqlParameterExpression rightParameter - && _parametersValues[rightParameter.Name] == null; - - if ((sqlBinaryResult.OperatorType == ExpressionType.Equal || sqlBinaryResult.OperatorType == ExpressionType.NotEqual) - && (leftNullParameter || rightNullParameter)) - { - return SimplifyNullComparisonExpression( - sqlBinaryResult.OperatorType, - sqlBinaryResult.Left, - sqlBinaryResult.Right, - leftNullParameter, - rightNullParameter, - sqlBinaryResult.TypeMapping); - } - } - - return result; - } - } - - private sealed class InExpressionValuesExpandingExpressionVisitor : ExpressionVisitor - { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - private readonly IReadOnlyDictionary _parametersValues; - - public InExpressionValuesExpandingExpressionVisitor( - ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary parametersValues) - { - _sqlExpressionFactory = sqlExpressionFactory; - _parametersValues = parametersValues; - } - - public override Expression Visit(Expression expression) - { - if (expression is InExpression inExpression - && inExpression.Values != null) - { - var inValues = new List(); - var hasNullValue = false; - RelationalTypeMapping typeMapping = null; - - switch (inExpression.Values) - { - case SqlConstantExpression sqlConstant: - { - typeMapping = sqlConstant.TypeMapping; - var values = (IEnumerable)sqlConstant.Value; - foreach (var value in values) - { - if (value == null) - { - hasNullValue = true; - continue; - } - - inValues.Add(value); - } - - break; - } - - case SqlParameterExpression sqlParameter: - { - typeMapping = sqlParameter.TypeMapping; - var values = (IEnumerable)_parametersValues[sqlParameter.Name]; - foreach (var value in values) - { - if (value == null) - { - hasNullValue = true; - continue; - } - - inValues.Add(value); - } - - break; - } - } - - var updatedInExpression = inValues.Count > 0 - ? _sqlExpressionFactory.In( - (SqlExpression)Visit(inExpression.Item), - _sqlExpressionFactory.Constant(inValues, typeMapping), - inExpression.IsNegated) - : null; - - var nullCheckExpression = hasNullValue - ? inExpression.IsNegated - ? _sqlExpressionFactory.IsNotNull(inExpression.Item) - : _sqlExpressionFactory.IsNull(inExpression.Item) - : null; - - if (updatedInExpression != null - && nullCheckExpression != null) - { - return inExpression.IsNegated - ? _sqlExpressionFactory.AndAlso(updatedInExpression, nullCheckExpression) - : _sqlExpressionFactory.OrElse(updatedInExpression, nullCheckExpression); - } - - if (updatedInExpression == null - && nullCheckExpression == null) - { - return _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Constant(true), _sqlExpressionFactory.Constant(inExpression.IsNegated)); - } - - return (SqlExpression)updatedInExpression ?? nullCheckExpression; - } - - return base.Visit(expression); - } } private sealed class FromSqlParameterApplyingExpressionVisitor : ExpressionVisitor diff --git a/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs b/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs index 1c51cf88d51..928b073fe83 100644 --- a/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs @@ -9,8 +9,6 @@ namespace Microsoft.EntityFrameworkCore.Query { public class RelationalQueryTranslationPostprocessor : QueryTranslationPostprocessor { - private readonly SqlExpressionOptimizingExpressionVisitor _sqlExpressionOptimizingExpressionVisitor; - public RelationalQueryTranslationPostprocessor( QueryTranslationPostprocessorDependencies dependencies, RelationalQueryTranslationPostprocessorDependencies relationalDependencies, @@ -20,8 +18,6 @@ public RelationalQueryTranslationPostprocessor( RelationalDependencies = relationalDependencies; UseRelationalNulls = RelationalOptionsExtension.Extract(queryCompilationContext.ContextOptions).UseRelationalNulls; SqlExpressionFactory = relationalDependencies.SqlExpressionFactory; - _sqlExpressionOptimizingExpressionVisitor - = new SqlExpressionOptimizingExpressionVisitor(SqlExpressionFactory, UseRelationalNulls); } protected virtual RelationalQueryTranslationPostprocessorDependencies RelationalDependencies { get; } @@ -37,17 +33,12 @@ public override Expression Process(Expression query) query = new CollectionJoinApplyingExpressionVisitor().Visit(query); query = new TableAliasUniquifyingExpressionVisitor().Visit(query); query = new CaseWhenFlatteningExpressionVisitor(SqlExpressionFactory).Visit(query); - - if (!UseRelationalNulls) - { - query = new NullSemanticsRewritingExpressionVisitor(SqlExpressionFactory).Visit(query); - } - query = OptimizeSqlExpression(query); return query; } - protected virtual Expression OptimizeSqlExpression(Expression query) => _sqlExpressionOptimizingExpressionVisitor.Visit(query); + protected virtual Expression OptimizeSqlExpression(Expression query) + => query; } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerParameterBasedQueryTranslationPostprocessor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerParameterBasedQueryTranslationPostprocessor.cs index 62d277f4e9e..31257c9d2a4 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerParameterBasedQueryTranslationPostprocessor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerParameterBasedQueryTranslationPostprocessor.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal @@ -25,7 +26,10 @@ public override (SelectExpression selectExpression, bool canCache) Optimize( var searchConditionOptimized = (SelectExpression)new SearchConditionConvertingExpressionVisitor(Dependencies.SqlExpressionFactory) .Visit(optimizedSelectExpression); - return (searchConditionOptimized, canCache); + var optimized = (SelectExpression)new SqlExpressionOptimizingExpressionVisitor( + Dependencies.SqlExpressionFactory, UseRelationalNulls).Visit(searchConditionOptimized); + + return (optimized, canCache); } } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryTranslationPostprocessor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryTranslationPostprocessor.cs index 500cf13d63a..6a90b8e4399 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryTranslationPostprocessor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryTranslationPostprocessor.cs @@ -1,7 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Query; namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal @@ -15,13 +14,5 @@ public SqlServerQueryTranslationPostprocessor( : base(dependencies, relationalDependencies, queryCompilationContext) { } - - public override Expression Process(Expression query) - { - query = base.Process(query); - query = new SearchConditionConvertingExpressionVisitor(SqlExpressionFactory).Visit(query); - - return query; - } } } diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index 830dd8f0f51..57615c5b04e 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -316,7 +316,7 @@ public virtual void Contains_with_local_array_closure_false_with_null() AssertQuery(es => es.Where(e => !ids.Contains(e.NullableStringA))); } - [ConditionalFact(Skip = "issue #14171")] + [ConditionalFact] public virtual void Contains_with_local_nullable_array_closure_negated() { string[] ids = { "Foo" }; @@ -946,40 +946,58 @@ join e2 in _clientData._entities2 } } - [ConditionalFact(Skip = "issue #14171")] + [ConditionalFact] public virtual void Null_semantics_contains() { - using var ctx = CreateContext(); var ids = new List { 1, 2 }; - var query1 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA)); - var result1 = query1.ToList(); + AssertQuery(es => es.Where(e => ids.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !ids.Contains(e.NullableIntA))); - var query2 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA)); - var result2 = query2.ToList(); + var ids2 = new List { 1, 2, null }; + AssertQuery(es => es.Where(e => ids2.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !ids2.Contains(e.NullableIntA))); - var ids2 = new List - { - 1, - 2, - null - }; - var query3 = ctx.Entities1.Where(e => ids.Contains(e.NullableIntA)); - var result3 = query3.ToList(); + AssertQuery(es => es.Where(e => new List { 1, 2 }.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !new List { 1, 2 }.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => new List { 1, 2, null }.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !new List { 1, 2, null }.Contains(e.NullableIntA))); + } - var query4 = ctx.Entities1.Where(e => !ids.Contains(e.NullableIntA)); - var result4 = query4.ToList(); + [ConditionalFact] + public virtual void Null_semantics_contains_array_with_no_values() + { + var ids = new List(); + AssertQuery(es => es.Where(e => ids.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !ids.Contains(e.NullableIntA))); - var query5 = ctx.Entities1.Where(e => !new List { 1, 2 }.Contains(e.NullableIntA)); - var result5 = query5.ToList(); + var ids2 = new List { null }; + AssertQuery(es => es.Where(e => ids2.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !ids2.Contains(e.NullableIntA))); - var query6 = ctx.Entities1.Where( - e => !new List - { - 1, - 2, - null - }.Contains(e.NullableIntA)); - var result6 = query6.ToList(); + AssertQuery(es => es.Where(e => new List().Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !new List().Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => new List { null }.Contains(e.NullableIntA))); + AssertQuery(es => es.Where(e => !new List { null }.Contains(e.NullableIntA))); + } + + [ConditionalFact] + public virtual void Null_semantics_contains_non_nullable_argument() + { + var ids = new List { 1, 2, null }; + AssertQuery(es => es.Where(e => ids.Contains(e.IntA))); + AssertQuery(es => es.Where(e => !ids.Contains(e.IntA))); + + var ids2 = new List { 1, 2, }; + AssertQuery(es => es.Where(e => ids2.Contains(e.IntA))); + AssertQuery(es => es.Where(e => !ids2.Contains(e.IntA))); + + var ids3 = new List(); + AssertQuery(es => es.Where(e => ids3.Contains(e.IntA))); + AssertQuery(es => es.Where(e => !ids3.Contains(e.IntA))); + + var ids4 = new List { null }; + AssertQuery(es => es.Where(e => ids4.Contains(e.IntA))); + AssertQuery(es => es.Where(e => !ids4.Contains(e.IntA))); } [ConditionalFact] diff --git a/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs index 6a35bf8024b..7eb05a94f6c 100644 --- a/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/ComplexNavigationsQueryTestBase.cs @@ -4567,7 +4567,7 @@ public virtual Task Nav_rewrite_doesnt_apply_null_protection_for_function_argume .Select(l1 => Math.Max(l1.OneToOne_Optional_PK1.Level1_Required_Id, 7))); } - [ConditionalTheory(Skip = "See issue#11464")] + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Accessing_optional_property_inside_result_operator_subquery(bool async) { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs index cef34b937ca..808b6dc05e7 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsQuerySqlServerTest.cs @@ -326,7 +326,7 @@ public override async Task Method_call_on_optional_navigation_translates_to_null @"SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id] FROM [LevelOne] AS [l] LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id] -WHERE ([l0].[Name] = N'') OR ([l0].[Name] IS NOT NULL AND ([l0].[Name] IS NOT NULL AND (LEFT([l0].[Name], LEN([l0].[Name])) = [l0].[Name])))"); +WHERE ([l0].[Name] = N'') OR ([l0].[Name] IS NOT NULL AND (LEFT([l0].[Name], LEN([l0].[Name])) = [l0].[Name]))"); } public override async Task Optional_navigation_inside_method_call_translated_to_join_keeps_original_nullability(bool async) @@ -3305,9 +3305,10 @@ public override async Task Accessing_optional_property_inside_result_operator_su await base.Accessing_optional_property_inside_result_operator_subquery(async); AssertSql( - @"SELECT [l1].[Id], [l1].[Date], [l1].[Name], [l1].[OneToMany_Optional_Self_Inverse1Id], [l1].[OneToMany_Required_Self_Inverse1Id], [l1].[OneToOne_Optional_Self1Id], [l1.OneToOne_Optional_FK1].[Name] -FROM [LevelOne] AS [l1] -LEFT JOIN [LevelTwo] AS [l1.OneToOne_Optional_FK1] ON [l1].[Id] = [l1.OneToOne_Optional_FK1].[Level1_Optional_Id]"); + @"SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id] +FROM [LevelOne] AS [l] +LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id] +WHERE [l0].[Name] NOT IN (N'Name1', N'Name2') OR [l0].[Name] IS NULL"); } public override async Task Include_after_SelectMany_and_reference_navigation(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs index 2c99513c003..dcfb0fc8047 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs @@ -869,10 +869,7 @@ public override async Task Null_propagation_optimization1(bool async) AssertSql( @"SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank] FROM [Gears] AS [g] -WHERE [g].[Discriminator] IN (N'Gear', N'Officer') AND (CASE - WHEN ([g].[LeaderNickname] = N'Marcus') AND [g].[LeaderNickname] IS NOT NULL THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) -END = CAST(1 AS bit))"); +WHERE [g].[Discriminator] IN (N'Gear', N'Officer') AND (([g].[LeaderNickname] = N'Marcus') AND [g].[LeaderNickname] IS NOT NULL)"); } public override async Task Null_propagation_optimization2(bool async) @@ -6652,7 +6649,7 @@ LEFT JOIN ( SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank] FROM [Gears] AS [g] WHERE [g].[Discriminator] IN (N'Gear', N'Officer') -) AS [t0] ON ((([t].[GearNickName] = [t0].[Nickname]) AND ([t].[GearSquadId] = [t0].[SquadId])) AND [t].[Note] IS NOT NULL) AND [t].[Note] IS NOT NULL +) AS [t0] ON (([t].[GearNickName] = [t0].[Nickname]) AND ([t].[GearSquadId] = [t0].[SquadId])) AND [t].[Note] IS NOT NULL ORDER BY [t].[Id], [t0].[Nickname], [t0].[SquadId]"); } @@ -7063,10 +7060,7 @@ public override async Task Select_StartsWith_with_null_parameter_as_argument(boo await base.Select_StartsWith_with_null_parameter_as_argument(async); AssertSql( - @"SELECT CASE - WHEN CAST(0 AS bit) = CAST(1 AS bit) THEN CAST(1 AS bit) - ELSE CAST(0 AS bit) -END + @"SELECT CAST(0 AS bit) FROM [Gears] AS [g] WHERE [g].[Discriminator] IN (N'Gear', N'Officer')"); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs index 71fdcd5636e..51cff236f7a 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs @@ -1320,7 +1320,7 @@ public override async Task DefaultIfEmpty_selects_only_required_columns(bool asy FROM ( SELECT NULL AS [empty] ) AS [empty] -LEFT JOIN [Products] AS [p] ON 1 = 1"); +LEFT JOIN [Products] AS [p] ON CAST(1 AS bit) = CAST(1 AS bit)"); } public override async Task Collection_Last_member_access_in_projection_translated(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs index 46c50276930..61d6ba11d02 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs @@ -36,7 +36,7 @@ public override async Task String_StartsWith_Identity(bool async) AssertSql( @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND ([c].[ContactName] IS NOT NULL AND (LEFT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName])))"); +WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND (LEFT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName]))"); } public override async Task String_StartsWith_Column(bool async) @@ -46,7 +46,7 @@ public override async Task String_StartsWith_Column(bool async) AssertSql( @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND ([c].[ContactName] IS NOT NULL AND (LEFT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName])))"); +WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND (LEFT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName]))"); } public override async Task String_StartsWith_MethodCall(bool async) @@ -76,7 +76,7 @@ public override async Task String_EndsWith_Identity(bool async) AssertSql( @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND ([c].[ContactName] IS NOT NULL AND (RIGHT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName])))"); +WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND (RIGHT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName]))"); } public override async Task String_EndsWith_Column(bool async) @@ -86,7 +86,7 @@ public override async Task String_EndsWith_Column(bool async) AssertSql( @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND ([c].[ContactName] IS NOT NULL AND (RIGHT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName])))"); +WHERE ([c].[ContactName] = N'') OR ([c].[ContactName] IS NOT NULL AND (RIGHT([c].[ContactName], LEN([c].[ContactName])) = [c].[ContactName]))"); } public override async Task String_EndsWith_MethodCall(bool async) @@ -1179,7 +1179,7 @@ public override async Task Indexof_with_emptystring(bool async) AssertSql( @"SELECT CASE - WHEN N'' = N'' THEN 0 + WHEN CAST(1 AS bit) = CAST(1 AS bit) THEN 0 ELSE CHARINDEX(N'', [c].[ContactName]) - 1 END FROM [Customers] AS [c] diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs index e7047d92c54..8f88af7d304 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs @@ -343,7 +343,7 @@ LEFT JOIN ( SELECT [e].[EmployeeID], [e].[City], [e].[Country], [e].[FirstName], [e].[ReportsTo], [e].[Title] FROM [Employees] AS [e] WHERE [e].[EmployeeID] = -1 -) AS [t] ON 1 = 1"); +) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit)"); } public override async Task Join_with_default_if_empty_on_both_sources(bool async) @@ -359,7 +359,7 @@ LEFT JOIN ( SELECT [e].[EmployeeID], [e].[City], [e].[Country], [e].[FirstName], [e].[ReportsTo], [e].[Title] FROM [Employees] AS [e] WHERE [e].[EmployeeID] = -1 -) AS [t] ON 1 = 1 +) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit) INNER JOIN ( SELECT [t0].[EmployeeID], [t0].[City], [t0].[Country], [t0].[FirstName], [t0].[ReportsTo], [t0].[Title] FROM ( @@ -369,7 +369,7 @@ LEFT JOIN ( SELECT [e0].[EmployeeID], [e0].[City], [e0].[Country], [e0].[FirstName], [e0].[ReportsTo], [e0].[Title] FROM [Employees] AS [e0] WHERE [e0].[EmployeeID] = -1 - ) AS [t0] ON 1 = 1 + ) AS [t0] ON CAST(1 AS bit) = CAST(1 AS bit) ) AS [t1] ON [t].[EmployeeID] = [t1].[EmployeeID]"); } @@ -386,7 +386,7 @@ LEFT JOIN ( SELECT [e].[EmployeeID], [e].[City], [e].[Country], [e].[FirstName], [e].[ReportsTo], [e].[Title] FROM [Employees] AS [e] WHERE [e].[EmployeeID] = -1 -) AS [t] ON 1 = 1"); +) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit)"); } public override async Task Default_if_empty_top_level_positive(bool async) @@ -402,7 +402,7 @@ LEFT JOIN ( SELECT [e].[EmployeeID], [e].[City], [e].[Country], [e].[FirstName], [e].[ReportsTo], [e].[Title] FROM [Employees] AS [e] WHERE [e].[EmployeeID] > 0 -) AS [t] ON 1 = 1"); +) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit)"); } public override async Task Default_if_empty_top_level_projection(bool async) @@ -418,7 +418,7 @@ LEFT JOIN ( SELECT [e].[EmployeeID], [e].[City], [e].[Country], [e].[FirstName], [e].[ReportsTo], [e].[Title] FROM [Employees] AS [e] WHERE [e].[EmployeeID] = -1 -) AS [t] ON 1 = 1"); +) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit)"); } public override async Task Where_query_composition(bool async) @@ -1460,7 +1460,7 @@ public override async Task All_top_level_column(bool async) WHEN NOT EXISTS ( SELECT 1 FROM [Customers] AS [c] - WHERE (([c].[ContactName] <> N'') OR [c].[ContactName] IS NULL) AND ([c].[ContactName] IS NULL OR ([c].[ContactName] IS NULL OR ((LEFT([c].[ContactName], LEN([c].[ContactName])) <> [c].[ContactName]) OR LEFT([c].[ContactName], LEN([c].[ContactName])) IS NULL)))) THEN CAST(1 AS bit) + WHERE (([c].[ContactName] <> N'') OR [c].[ContactName] IS NULL) AND ([c].[ContactName] IS NULL OR ((LEFT([c].[ContactName], LEN([c].[ContactName])) <> [c].[ContactName]) OR LEFT([c].[ContactName], LEN([c].[ContactName])) IS NULL))) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END"); } @@ -3242,7 +3242,7 @@ LEFT JOIN ( SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] WHERE [c].[City] = N'London' -) AS [t] ON 1 = 1 +) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit) WHERE [t].[CustomerID] IS NOT NULL"); } @@ -3273,7 +3273,7 @@ LEFT JOIN ( SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] FROM [Orders] AS [o] WHERE [o].[OrderID] > 15000 - ) AS [t] ON 1 = 1 + ) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit) ) AS [t0]"); } @@ -3293,7 +3293,7 @@ LEFT JOIN ( SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] FROM [Orders] AS [o] WHERE [o].[OrderID] > 15000 - ) AS [t] ON 1 = 1 + ) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit) ) AS [t0] LEFT JOIN [Orders] AS [o0] ON [c].[CustomerID] = [o0].[CustomerID] WHERE ([c].[City] = N'Seattle') AND ([t0].[OrderID] IS NOT NULL AND [o0].[OrderID] IS NOT NULL) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs index 4ca761331dd..ecef80e3b3b 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs @@ -1167,7 +1167,7 @@ public override async Task Select_with_complex_expression_that_can_be_funcletize AssertSql( @"SELECT CASE - WHEN N'' = N'' THEN 0 + WHEN CAST(1 AS bit) = CAST(1 AS bit) THEN 0 ELSE CHARINDEX(N'', [c].[ContactName]) - 1 END FROM [Customers] AS [c] diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs index a9c72b33b48..bc1eda003f6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs @@ -1111,10 +1111,10 @@ public override async Task Where_negated_boolean_expression_compared_to_another_ @"SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock] FROM [Products] AS [p] WHERE CASE - WHEN [p].[ProductID] > 50 THEN CAST(1 AS bit) + WHEN [p].[ProductID] <= 50 THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END = CASE - WHEN [p].[ProductID] > 20 THEN CAST(1 AS bit) + WHEN [p].[ProductID] <= 20 THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END"); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 4180492bea1..ac1593660d8 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -730,7 +730,9 @@ public override void Contains_with_local_nullable_array_closure_negated() base.Contains_with_local_nullable_array_closure_negated(); AssertSql( - @""); + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableStringA] NOT IN (N'Foo') OR [e].[NullableStringA] IS NULL"); } public override void Contains_with_local_array_closure_with_multiple_nulls() @@ -1447,7 +1449,109 @@ public override void Null_semantics_contains() base.Null_semantics_contains(); AssertSql( - @""); + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) OR [e].[NullableIntA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2) OR [e].[NullableIntA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) AND [e].[NullableIntA] IS NOT NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) OR [e].[NullableIntA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2) OR [e].[NullableIntA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) AND [e].[NullableIntA] IS NOT NULL"); + } + + public override void Null_semantics_contains_array_with_no_values() + { + base.Null_semantics_contains_array_with_no_values(); + + AssertSql( + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE CAST(0 AS bit) = CAST(1 AS bit)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e]", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IS NOT NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE CAST(0 AS bit) = CAST(1 AS bit)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e]", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IS NULL", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IS NOT NULL"); + } + + public override void Null_semantics_contains_non_nullable_argument() + { + base.Null_semantics_contains_non_nullable_argument(); + + AssertSql( + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] IN (1, 2)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] NOT IN (1, 2)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] IN (1, 2)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] NOT IN (1, 2)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE CAST(0 AS bit) = CAST(1 AS bit)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e]", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE CAST(0 AS bit) = CAST(1 AS bit)", + // + @"SELECT [e].[Id] +FROM [Entities1] AS [e]"); } public override void Null_semantics_with_null_check_simple() diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index 2494b5fa968..4a9dd1e9c40 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -6108,7 +6108,7 @@ WHERE EXISTS ( SELECT 1 FROM [OrganisationUser7973] AS [o0] WHERE [o].[Id] = [o0].[OrganisationId]) - ) AS [t] ON 1 = 1 + ) AS [t] ON CAST(1 AS bit) = CAST(1 AS bit) ) AS [t0]"); } } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindFunctionsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindFunctionsQuerySqliteTest.cs index 4b91cd4f9fa..440140585e2 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindFunctionsQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindFunctionsQuerySqliteTest.cs @@ -134,7 +134,7 @@ public override async Task String_StartsWith_Identity(bool async) AssertSql( @"SELECT ""c"".""CustomerID"", ""c"".""Address"", ""c"".""City"", ""c"".""CompanyName"", ""c"".""ContactName"", ""c"".""ContactTitle"", ""c"".""Country"", ""c"".""Fax"", ""c"".""Phone"", ""c"".""PostalCode"", ""c"".""Region"" FROM ""Customers"" AS ""c"" -WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND (""c"".""ContactName"" IS NOT NULL AND (((""c"".""ContactName"" LIKE ""c"".""ContactName"" || '%') AND (substr(""c"".""ContactName"", 1, length(""c"".""ContactName"")) = ""c"".""ContactName"")) OR (""c"".""ContactName"" = ''))))"); +WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND (((""c"".""ContactName"" LIKE ""c"".""ContactName"" || '%') AND (substr(""c"".""ContactName"", 1, length(""c"".""ContactName"")) = ""c"".""ContactName"")) OR (""c"".""ContactName"" = '')))"); } public override async Task String_StartsWith_Column(bool async) @@ -144,7 +144,7 @@ public override async Task String_StartsWith_Column(bool async) AssertSql( @"SELECT ""c"".""CustomerID"", ""c"".""Address"", ""c"".""City"", ""c"".""CompanyName"", ""c"".""ContactName"", ""c"".""ContactTitle"", ""c"".""Country"", ""c"".""Fax"", ""c"".""Phone"", ""c"".""PostalCode"", ""c"".""Region"" FROM ""Customers"" AS ""c"" -WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND (""c"".""ContactName"" IS NOT NULL AND (((""c"".""ContactName"" LIKE ""c"".""ContactName"" || '%') AND (substr(""c"".""ContactName"", 1, length(""c"".""ContactName"")) = ""c"".""ContactName"")) OR (""c"".""ContactName"" = ''))))"); +WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND (((""c"".""ContactName"" LIKE ""c"".""ContactName"" || '%') AND (substr(""c"".""ContactName"", 1, length(""c"".""ContactName"")) = ""c"".""ContactName"")) OR (""c"".""ContactName"" = '')))"); } public override async Task String_StartsWith_MethodCall(bool async) @@ -174,7 +174,7 @@ public override async Task String_EndsWith_Identity(bool async) AssertSql( @"SELECT ""c"".""CustomerID"", ""c"".""Address"", ""c"".""City"", ""c"".""CompanyName"", ""c"".""ContactName"", ""c"".""ContactTitle"", ""c"".""Country"", ""c"".""Fax"", ""c"".""Phone"", ""c"".""PostalCode"", ""c"".""Region"" FROM ""Customers"" AS ""c"" -WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND (""c"".""ContactName"" IS NOT NULL AND ((substr(""c"".""ContactName"", -length(""c"".""ContactName"")) = ""c"".""ContactName"") OR (""c"".""ContactName"" = ''))))"); +WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND ((substr(""c"".""ContactName"", -length(""c"".""ContactName"")) = ""c"".""ContactName"") OR (""c"".""ContactName"" = '')))"); } public override async Task String_EndsWith_Column(bool async) @@ -184,7 +184,7 @@ public override async Task String_EndsWith_Column(bool async) AssertSql( @"SELECT ""c"".""CustomerID"", ""c"".""Address"", ""c"".""City"", ""c"".""CompanyName"", ""c"".""ContactName"", ""c"".""ContactTitle"", ""c"".""Country"", ""c"".""Fax"", ""c"".""Phone"", ""c"".""PostalCode"", ""c"".""Region"" FROM ""Customers"" AS ""c"" -WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND (""c"".""ContactName"" IS NOT NULL AND ((substr(""c"".""ContactName"", -length(""c"".""ContactName"")) = ""c"".""ContactName"") OR (""c"".""ContactName"" = ''))))"); +WHERE (""c"".""ContactName"" = '') OR (""c"".""ContactName"" IS NOT NULL AND ((substr(""c"".""ContactName"", -length(""c"".""ContactName"")) = ""c"".""ContactName"") OR (""c"".""ContactName"" = '')))"); } public override async Task String_EndsWith_MethodCall(bool async)