Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Provide a way to translate aggregate methods on top level #28102

Merged
1 commit merged into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,53 +79,50 @@ public QueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFac
|| methodInfo == QueryableMethods.CountWithPredicate:
var countSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*");
countSqlExpression = CombineTerms(source, countSqlExpression);
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { countSqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));
return _sqlExpressionFactory.Function(
"COUNT",
new[] { countSqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int));

case nameof(Queryable.LongCount)
when methodInfo == QueryableMethods.LongCountWithoutPredicate
|| methodInfo == QueryableMethods.LongCountWithPredicate:
var longCountSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*");
longCountSqlExpression = CombineTerms(source, longCountSqlExpression);

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { longCountSqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
return _sqlExpressionFactory.Function(
"COUNT",
new[] { longCountSqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long));

case nameof(Queryable.Max)
when (methodInfo == QueryableMethods.MaxWithoutSelector
|| methodInfo == QueryableMethods.MaxWithSelector)
&& source.Selector is SqlExpression maxSqlExpression:
maxSqlExpression = CombineTerms(source, maxSqlExpression);
return _sqlExpressionFactory.Function(
"MAX",
new[] { maxSqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);
"MAX",
new[] { maxSqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);

case nameof(Queryable.Min)
when (methodInfo == QueryableMethods.MinWithoutSelector
|| methodInfo == QueryableMethods.MinWithSelector)
&& source.Selector is SqlExpression minSqlExpression:
minSqlExpression = CombineTerms(source, minSqlExpression);
return _sqlExpressionFactory.Function(
"MIN",
new[] { minSqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
minSqlExpression.Type,
minSqlExpression.TypeMapping);
"MIN",
new[] { minSqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
minSqlExpression.Type,
minSqlExpression.TypeMapping);

case nameof(Queryable.Sum)
when (QueryableMethods.IsSumWithoutSelector(methodInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
ShapedQueryExpression source,
LambdaExpression? selector,
Type resultType)
=> TranslateAggregateWithSelector(
source, selector, e => TranslateAverage(e), throwWhenEmpty: true, resultType);
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateCast(ShapedQueryExpression source, Type resultType)
Expand Down Expand Up @@ -301,7 +300,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate)
=> TranslateAggregateWithPredicate(source, predicate, e => TranslateCount(e), typeof(int));
=> TranslateAggregateWithPredicate(source, predicate, QueryableMethods.CountWithoutPredicate);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateDefaultIfEmpty(ShapedQueryExpression source, Expression? defaultValue)
Expand Down Expand Up @@ -386,16 +385,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
LambdaExpression? resultSelector)
{
var selectExpression = (SelectExpression)source.QueryExpression;
// This has it's own set of condition since it is different scenario from below.
// Aggregate operators need pushdown for skip/limit/offset covered by selectExpression.PrepareForAggregate.
// Aggregate operators need special processing beyond pushdown when applying over group by for client eval.
if (selectExpression.Limit != null
|| selectExpression.Offset != null
|| selectExpression.IsDistinct
|| selectExpression.GroupBy.Count > 0)
{
selectExpression.PushdownIntoSubquery();
}
selectExpression.PrepareForAggregate();

var remappedKeySelector = RemapLambdaBody(source, keySelector);
var translatedKey = TranslateGroupingKey(remappedKeySelector);
Expand Down Expand Up @@ -623,15 +613,17 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate)
=> TranslateAggregateWithPredicate(source, predicate, e => TranslateLongCount(e), typeof(long));
=> TranslateAggregateWithPredicate(source, predicate, QueryableMethods.LongCountWithoutPredicate);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TranslateAggregateWithSelector(source, selector, e => TranslateMax(e), throwWhenEmpty: true, resultType);
=> TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TranslateAggregateWithSelector(source, selector, e => TranslateMin(e), throwWhenEmpty: true, resultType);
=> TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType)
Expand Down Expand Up @@ -888,7 +880,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TranslateAggregateWithSelector(source, selector, e => TranslateSum(e), throwWhenEmpty: false, resultType);
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count)
Expand Down Expand Up @@ -954,7 +946,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return source;
}

private SqlExpression? TranslateExpression(Expression expression)
/// <summary>
/// Translates the given expression into equivalent SQL representation.
/// </summary>
/// <param name="expression">An expression to translate.</param>
/// <returns>A <see cref="SqlExpression"/> which is translation of given expression or <see langword="null"/>.</returns>
protected virtual SqlExpression? TranslateExpression(Expression expression)
{
var translation = _sqlTranslator.Translate(expression);
if (translation == null && _sqlTranslator.TranslationErrorDetails != null)
Expand All @@ -965,7 +962,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return translation;
}

private SqlExpression? TranslateLambdaExpression(
/// <summary>
/// Translates the given lambda expression for the <see cref="ShapedQueryExpression"/> source into equivalent SQL representation.
/// </summary>
/// <param name="shapedQueryExpression">A <see cref="ShapedQueryExpression"/> on which the lambda expression is being applied.</param>
/// <param name="lambdaExpression">A <see cref="LambdaExpression"/> to translate into SQL.</param>
/// <returns>A <see cref="SqlExpression"/> which is translation of given lambda expression or <see langword="null"/>.</returns>
protected virtual SqlExpression? TranslateLambdaExpression(
ShapedQueryExpression shapedQueryExpression,
LambdaExpression lambdaExpression)
=> TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression));
Expand Down Expand Up @@ -1465,45 +1468,11 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
}
}

private SqlExpression? TranslateAverage(SqlExpression sqlExpression)
=> RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate(
RelationalDependencies.Model, QueryableMethods.GetAverageWithoutSelector(sqlExpression.Type),
new EnumerableExpression(sqlExpression), Array.Empty<SqlExpression>(), _queryCompilationContext.Logger);

private SqlExpression? TranslateCount(SqlExpression sqlExpression)
=> RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate(
RelationalDependencies.Model, QueryableMethods.CountWithoutPredicate,
new EnumerableExpression(sqlExpression), Array.Empty<SqlExpression>(), _queryCompilationContext.Logger);

private SqlExpression? TranslateLongCount(SqlExpression sqlExpression)
=> RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate(
RelationalDependencies.Model, QueryableMethods.LongCountWithoutPredicate,
new EnumerableExpression(sqlExpression), Array.Empty<SqlExpression>(), _queryCompilationContext.Logger);

private SqlExpression? TranslateMax(SqlExpression sqlExpression)
=> RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate(
RelationalDependencies.Model, QueryableMethods.MaxWithoutSelector,
new EnumerableExpression(sqlExpression), Array.Empty<SqlExpression>(), _queryCompilationContext.Logger);

private SqlExpression? TranslateMin(SqlExpression sqlExpression)
=> RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate(
RelationalDependencies.Model, QueryableMethods.MinWithoutSelector,
new EnumerableExpression(sqlExpression), Array.Empty<SqlExpression>(), _queryCompilationContext.Logger);

private SqlExpression? TranslateSum(SqlExpression sqlExpression)
=> RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate(
RelationalDependencies.Model, QueryableMethods.GetSumWithoutSelector(sqlExpression.Type),
new EnumerableExpression(sqlExpression), Array.Empty<SqlExpression>(), _queryCompilationContext.Logger);

private ShapedQueryExpression? TranslateAggregateWithPredicate(
ShapedQueryExpression source,
LambdaExpression? predicate,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
Type resultType)
MethodInfo predicateLessMethodInfo)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (predicate != null)
{
var translatedSource = TranslateWhere(source, predicate);
Expand All @@ -1515,9 +1484,19 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
source = translatedSource;
}

HandleGroupByForAggregate(selectExpression, eraseProjection: true);
var selectExpression = (SelectExpression)source.QueryExpression;
if (!selectExpression.IsDistinct)
{
selectExpression.ReplaceProjection(new List<Expression>());
}

var translation = aggregateTranslator(_sqlExpressionFactory.Fragment("*"));
selectExpression.PrepareForAggregate();
var selector = _sqlExpressionFactory.Fragment("*");
var methodCall = Expression.Call(
predicateLessMethodInfo.MakeGenericMethod(selector.Type),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(selector.Type), new EnumerableExpression(selector)));
var translation = TranslateExpression(methodCall);
if (translation == null)
{
return null;
Expand All @@ -1527,6 +1506,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape

selectExpression.ClearOrdering();
selectExpression.ReplaceProjection(projectionMapping);
var resultType = predicateLessMethodInfo.ReturnType;

return source.UpdateShaperExpression(
Expression.Convert(
Expand All @@ -1536,18 +1516,17 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape

private ShapedQueryExpression? TranslateAggregateWithSelector(
ShapedQueryExpression source,
LambdaExpression? selector,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
LambdaExpression? selectorLambda,
Func<Type, MethodInfo> methodGenerator,
smitpatel marked this conversation as resolved.
Show resolved Hide resolved
bool throwWhenEmpty,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
HandleGroupByForAggregate(selectExpression);

SqlExpression translatedSelector;
if (selector == null
|| selector.Body == selector.Parameters[0])
Expression? selector = null;
if (selectorLambda == null
|| selectorLambda.Body == selectorLambda.Parameters[0])
{
var shaperExpression = source.ShaperExpression;
if (shaperExpression is UnaryExpression unaryExpression
Expand All @@ -1558,34 +1537,32 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression)
{
translatedSelector = (SqlExpression)selectExpression.GetProjection(projectionBindingExpression);
}
else
{
return null;
selector = selectExpression.GetProjection(projectionBindingExpression);
}
}
else
{
var newSelector = RemapLambdaBody(source, selector);
if (TranslateExpression(newSelector) is SqlExpression sqlExpression)
{
translatedSelector = sqlExpression;
}
else
{
return null;
}
selector = RemapLambdaBody(source, selectorLambda);
}

if (selector == null
|| TranslateExpression(selector) is not SqlExpression translatedSelector)
{
return null;
}

var projection = aggregateTranslator(translatedSelector);
if (projection == null)
var methodCall = Expression.Call(
methodGenerator(translatedSelector.Type),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(translatedSelector.Type), new EnumerableExpression(translatedSelector)));
var translation = _sqlTranslator.Translate(methodCall);
if (translation == null)
{
return null;
}

selectExpression.ReplaceProjection(
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), projection } });
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), translation } });

selectExpression.ClearOrdering();
Expression shaper;
Expand All @@ -1602,7 +1579,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: projection.Type.IsNullableType()
: translation.Type.IsNullableType()
? Expression.Default(resultType)
: Expression.Throw(
Expression.New(
Expand All @@ -1624,7 +1601,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
else
{
// Sum case. Projection is always non-null. We read nullable value.
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type.MakeNullable());
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());

if (resultType != shaper.Type)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,10 @@ public sealed record RelationalQueryableMethodTranslatingExpressionVisitorDepend
[EntityFrameworkInternal]
public RelationalQueryableMethodTranslatingExpressionVisitorDependencies(
IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory,
ISqlExpressionFactory sqlExpressionFactory,
IModel model,
IAggregateMethodCallTranslatorProvider aggregateMethodCallTranslatorProvider)
ISqlExpressionFactory sqlExpressionFactory)
{
RelationalSqlTranslatingExpressionVisitorFactory = relationalSqlTranslatingExpressionVisitorFactory;
SqlExpressionFactory = sqlExpressionFactory;
Model = model;
AggregateMethodCallTranslatorProvider = aggregateMethodCallTranslatorProvider;
}

/// <summary>
Expand All @@ -66,14 +62,4 @@ public RelationalQueryableMethodTranslatingExpressionVisitorDependencies(
/// The SQL expression factory.
/// </summary>
public ISqlExpressionFactory SqlExpressionFactory { get; init; }

/// <summary>
/// The model.
/// </summary>
public IModel Model { get; init; }

/// <summary>
/// The aggregate method-call translation provider.
/// </summary>
public IAggregateMethodCallTranslatorProvider AggregateMethodCallTranslatorProvider { get; }
}
Loading