Skip to content

Commit

Permalink
ExecuteDelete: Always generate Any form
Browse files Browse the repository at this point in the history
Visit retrived ProjectionBinding from select expression as projection binding can bind to non-SqlExpression too.

Resolves #28524
Resolves #28745
Resolves #28752
  • Loading branch information
smitpatel committed Aug 19, 2022
1 parent eb6087b commit cc97ab2
Show file tree
Hide file tree
Showing 25 changed files with 660 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1079,21 +1079,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var clrType = entityType.ClrType;
var entityParameter = Expression.Parameter(clrType);
Expression predicateBody;
//if (pk.Properties.Count == 1)
//{
// predicateBody = Expression.Call(
// EnumerableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter);
//}
//else
//{
var innerParameter = Expression.Parameter(clrType);
predicateBody = Expression.Call(
QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
source,
Expression.Quote(Expression.Lambda(
Infrastructure.ExpressionExtensions.CreateEqualsExpression(innerParameter, entityParameter),
innerParameter)));
//}

var newSource = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(clrType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

Expand Down Expand Up @@ -606,8 +605,82 @@ protected override Expression VisitExtension(Expression extensionExpression)
return new EntityReferenceExpression(entityShaperExpression);

case ProjectionBindingExpression projectionBindingExpression:
return ((SelectExpression)projectionBindingExpression.QueryExpression)
.GetProjection(projectionBindingExpression);
return Visit(((SelectExpression)projectionBindingExpression.QueryExpression)
.GetProjection(projectionBindingExpression));

case ShapedQueryExpression shapedQueryExpression:
if (shapedQueryExpression.ResultCardinality == ResultCardinality.Enumerable)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var shaperExpression = shapedQueryExpression.ShaperExpression;
ProjectionBindingExpression? mappedProjectionBindingExpression = null;

var innerExpression = shaperExpression;
Type? convertedType = null;
if (shaperExpression is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
convertedType = unaryExpression.Type;
innerExpression = unaryExpression.Operand;
}

if (innerExpression is EntityShaperExpression ese
&& (convertedType == null
|| convertedType.IsAssignableFrom(ese.Type)))
{
return new EntityReferenceExpression(shapedQueryExpression.UpdateShaperExpression(innerExpression));
}

if (innerExpression is ProjectionBindingExpression pbe
&& (convertedType == null
|| convertedType.MakeNullable() == innerExpression.Type))
{
mappedProjectionBindingExpression = pbe;
}

if (mappedProjectionBindingExpression == null
&& shaperExpression is BlockExpression blockExpression
&& blockExpression.Expressions.Count == 2
&& blockExpression.Expressions[0] is BinaryExpression binaryExpression
&& binaryExpression.NodeType == ExpressionType.Assign
&& binaryExpression.Right is ProjectionBindingExpression pbe2)
{
mappedProjectionBindingExpression = pbe2;
}

if (mappedProjectionBindingExpression == null)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var subquery = (SelectExpression)shapedQueryExpression.QueryExpression;
var projection = subquery.GetProjection(mappedProjectionBindingExpression);
if (projection is not SqlExpression sqlExpression)
{
return QueryCompilationContext.NotTranslatedExpression;
}

if (subquery.Tables.Count == 0)
{
return sqlExpression;
}

subquery.ReplaceProjection(new List<Expression> { sqlExpression });
subquery.ApplyProjection();

SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery);

if (shapedQueryExpression.ResultCardinality == ResultCardinality.SingleOrDefault
&& !shaperExpression.Type.IsNullableType())
{
scalarSubqueryExpression = _sqlExpressionFactory.Coalesce(
scalarSubqueryExpression,
(SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant()));
}

return scalarSubqueryExpression;

default:
return QueryCompilationContext.NotTranslatedExpression;
Expand All @@ -632,7 +705,7 @@ protected override Expression VisitMember(MemberExpression memberExpression)
var innerExpression = Visit(memberExpression.Expression);

return TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member))
?? (TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
?? (TranslationFailed(memberExpression.Expression, innerExpression, out var sqlInnerExpression)
? QueryCompilationContext.NotTranslatedExpression
: Dependencies.MemberTranslatorProvider.Translate(
sqlInnerExpression, memberExpression.Member, memberExpression.Type, _queryCompilationContext.Logger))
Expand Down Expand Up @@ -792,9 +865,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
: method;

var enumerableSource = Visit(arguments[0]);
if (enumerableSource is EnumerableExpression)
if (enumerableSource is EnumerableExpression ee)
{
enumerableExpression = (EnumerableExpression)enumerableSource;
enumerableExpression = ee;
switch (method.Name)
{
case nameof(Queryable.AsQueryable)
Expand Down Expand Up @@ -928,10 +1001,10 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
&& !skipVisitChildren)
{
var @object = Visit(methodCallExpression.Object);
if (@object is EnumerableExpression)
if (@object is EnumerableExpression eeo)
{
// This is safe since if enumerableExpression is non-null then it was static method
enumerableExpression = (EnumerableExpression)@object;
enumerableExpression = eeo;
}
else if (TranslationFailed(methodCallExpression.Object, @object, out sqlObject))
{
Expand All @@ -944,15 +1017,15 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
{
var argument = arguments[i];
var visitedArgument = Visit(argument);
if (visitedArgument is EnumerableExpression)
if (visitedArgument is EnumerableExpression eea)
{
if (enumerableExpression != null)
{
abortTranslation = true;
break;
}

enumerableExpression = (EnumerableExpression)visitedArgument;
enumerableExpression = eea;
continue;
}

Expand Down Expand Up @@ -1009,83 +1082,10 @@ when QueryableMethods.IsSumWithSelector(genericMethod):

// Subquery case
var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression);
if (subqueryTranslation != null)
{
if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var shaperExpression = subqueryTranslation.ShaperExpression;
ProjectionBindingExpression? mappedProjectionBindingExpression = null;

var innerExpression = shaperExpression;
Type? convertedType = null;
if (shaperExpression is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
convertedType = unaryExpression.Type;
innerExpression = unaryExpression.Operand;
}

if (innerExpression is EntityShaperExpression ese
&& (convertedType == null
|| convertedType.IsAssignableFrom(ese.Type)))
{
return new EntityReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression));
}

if (innerExpression is ProjectionBindingExpression pbe
&& (convertedType == null
|| convertedType.MakeNullable() == innerExpression.Type))
{
mappedProjectionBindingExpression = pbe;
}

if (mappedProjectionBindingExpression == null
&& shaperExpression is BlockExpression blockExpression
&& blockExpression.Expressions.Count == 2
&& blockExpression.Expressions[0] is BinaryExpression binaryExpression
&& binaryExpression.NodeType == ExpressionType.Assign
&& binaryExpression.Right is ProjectionBindingExpression pbe2)
{
mappedProjectionBindingExpression = pbe2;
}

if (mappedProjectionBindingExpression == null)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var subquery = (SelectExpression)subqueryTranslation.QueryExpression;
var projection = subquery.GetProjection(mappedProjectionBindingExpression);
if (projection is not SqlExpression sqlExpression)
{
return QueryCompilationContext.NotTranslatedExpression;
}

if (subquery.Tables.Count == 0)
{
return sqlExpression;
}

subquery.ReplaceProjection(new List<Expression> { sqlExpression });
subquery.ApplyProjection();

SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery);

if (subqueryTranslation.ResultCardinality == ResultCardinality.SingleOrDefault
&& !shaperExpression.Type.IsNullableType())
{
scalarSubqueryExpression = _sqlExpressionFactory.Coalesce(
scalarSubqueryExpression,
(SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant()));
}

return scalarSubqueryExpression;
}

return QueryCompilationContext.NotTranslatedExpression;
return subqueryTranslation == null
? QueryCompilationContext.NotTranslatedExpression
: Visit(subqueryTranslation);
}

/// <inheritdoc />
Expand Down Expand Up @@ -1394,12 +1394,7 @@ private static EnumerableExpression ProcessSelector(EnumerableExpression enumera
{
var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression);
var predicate = TranslateInternal(lambdaBody);
if (predicate == null)
{
return null;
}

return enumerableExpression.ApplyPredicate(predicate);
return predicate == null ? null : enumerableExpression.ApplyPredicate(predicate);
}

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down
10 changes: 2 additions & 8 deletions src/EFCore/Query/QueryableMethodTranslatingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,7 @@ when QueryableMethods.IsAverageWithSelector(method):
var source2 = Visit(methodCallExpression.Arguments[1]);
if (source2 is ShapedQueryExpression innerShapedQueryExpression)
{
return CheckTranslated(
TranslateConcat(
shapedQueryExpression,
innerShapedQueryExpression));
return CheckTranslated(TranslateConcat(shapedQueryExpression, innerShapedQueryExpression));
}

break;
Expand Down Expand Up @@ -207,10 +204,7 @@ when QueryableMethods.IsAverageWithSelector(method):
var source2 = Visit(methodCallExpression.Arguments[1]);
if (source2 is ShapedQueryExpression innerShapedQueryExpression)
{
return CheckTranslated(
TranslateExcept(
shapedQueryExpression,
innerShapedQueryExpression));
return CheckTranslated(TranslateExcept(shapedQueryExpression, innerShapedQueryExpression));
}

break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public virtual Task Delete_where_hierarchy(bool async)
ss => ss.Set<Animal>().Where(e => e.Name == "Great spotted kiwi"),
rowsAffectedCount: 1);

[ConditionalTheory(Skip = "Issue#28524")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_where_hierarchy_subquery(bool async)
=> AssertDelete(
Expand Down Expand Up @@ -53,6 +53,35 @@ public virtual Task Delete_where_using_hierarchy_derived(bool async)
ss => ss.Set<Country>().Where(e => e.Animals.OfType<Kiwi>().Where(a => a.CountryId > 0).Count() > 0),
rowsAffectedCount: 1);

[ConditionalTheory(Skip = "Issue#28525")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_GroupBy_Where_Select_First(bool async)
=> AssertDelete(
async,
ss => ss.Set<Animal>()
.GroupBy(e => e.CountryId)
.Where(g => g.Count() < 3)
.Select(g => g.First()),
rowsAffectedCount: 1);

[ConditionalTheory(Skip = "Issue#26753")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_GroupBy_Where_Select_First_2(bool async)
=> AssertDelete(
async,
ss => ss.Set<Animal>().Where(e => e == ss.Set<Animal>().GroupBy(e => e.CountryId)
.Where(g => g.Count() < 3).Select(g => g.First()).FirstOrDefault()),
rowsAffectedCount: 1);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_GroupBy_Where_Select_First_3(bool async)
=> AssertDelete(
async,
ss => ss.Set<Animal>().Where(e => ss.Set<Animal>().GroupBy(e => e.CountryId)
.Where(g => g.Count() < 3).Select(g => g.First()).Any(i => i == e)),
rowsAffectedCount: 1);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_where_keyless_entity_mapped_to_sql_query(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ public abstract class InheritanceBulkUpdatesFixtureBase : InheritanceQueryFixtur
{
protected override string StoreName => "InheritanceBulkUpdatesTest";

public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder)
=> base.AddOptions(builder).ConfigureWarnings(w => w.Log(CoreEventId.FirstWithoutOrderByAndFilterWarning));

public void UseTransaction(DatabaseFacade facade, IDbContextTransaction transaction)
=> facade.UseTransaction(transaction.GetDbTransaction());
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public virtual Task Delete_where_hierarchy(bool async)
ss => ss.Set<Animal>().Where(e => e.Name == "Great spotted kiwi"),
rowsAffectedCount: 1);

[ConditionalTheory(Skip = "Issue#28524")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_where_hierarchy_subquery(bool async)
=> AssertDelete(
Expand Down Expand Up @@ -53,6 +53,35 @@ public virtual Task Delete_where_using_hierarchy_derived(bool async)
ss => ss.Set<Country>().Where(e => e.Animals.OfType<Kiwi>().Where(a => a.CountryId > 0).Count() > 0),
rowsAffectedCount: 1);

[ConditionalTheory(Skip = "Issue#28525")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_GroupBy_Where_Select_First(bool async)
=> AssertDelete(
async,
ss => ss.Set<Animal>()
.GroupBy(e => e.CountryId)
.Where( g=> g.Count() < 3)
.Select(g => g.First()),
rowsAffectedCount: 2);

[ConditionalTheory(Skip = "Issue#26753")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_GroupBy_Where_Select_First_2(bool async)
=> AssertDelete(
async,
ss => ss.Set<Animal>().Where(e => e == ss.Set<Animal>().GroupBy(e => e.CountryId)
.Where(g => g.Count() < 3).Select(g => g.First()).FirstOrDefault()),
rowsAffectedCount: 2);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_GroupBy_Where_Select_First_3(bool async)
=> AssertDelete(
async,
ss => ss.Set<Animal>().Where(e => ss.Set<Animal>().GroupBy(e => e.CountryId)
.Where(g => g.Count() < 3).Select(g => g.First()).Any(i => i == e)),
rowsAffectedCount: 2);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_where_keyless_entity_mapped_to_sql_query(bool async)
Expand Down
Loading

0 comments on commit cc97ab2

Please sign in to comment.