Skip to content

Commit

Permalink
Switch to helpers for creating aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed May 27, 2022
1 parent 84398b6 commit f5d339b
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public QueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFac
averageSqlExpression.Type,
averageSqlExpression.TypeMapping);

// Count/LongCount are special since if the argument is a star fragment, it needs to be transformed to any non-null constant
// when a predicate is applied.
case nameof(Queryable.Count)
when methodInfo == QueryableMethods.CountWithoutPredicate
|| methodInfo == QueryableMethods.CountWithPredicate:
Expand Down
104 changes: 104 additions & 0 deletions src/EFCore.SqlServer/Query/Internal/SqlServerExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static class SqlServerExpression
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static SqlFunctionExpression AggregateFunction(
ISqlExpressionFactory sqlExpressionFactory,
string name,
IEnumerable<SqlExpression> arguments,
EnumerableExpression enumerableExpression,
int enumerableArgumentIndex,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(
name,
ProcessAggregateFunctionArguments(sqlExpressionFactory, arguments, enumerableExpression, enumerableArgumentIndex),
nullable,
argumentsPropagateNullability,
returnType,
typeMapping);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static SqlFunctionExpression AggregateFunctionWithOrdering(
ISqlExpressionFactory sqlExpressionFactory,
string name,
IEnumerable<SqlExpression> arguments,
EnumerableExpression enumerableExpression,
int enumerableArgumentIndex,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> enumerableExpression.Orderings.Count == 0
? AggregateFunction(sqlExpressionFactory, name, arguments, enumerableExpression, enumerableArgumentIndex, nullable, argumentsPropagateNullability, returnType, typeMapping)
: new SqlServerSqlFunctionExpression(
name,
ProcessAggregateFunctionArguments(sqlExpressionFactory, arguments, enumerableExpression, enumerableArgumentIndex),
enumerableExpression.Orderings,
nullable,
argumentsPropagateNullability,
returnType,
typeMapping);

private static IReadOnlyList<SqlExpression> ProcessAggregateFunctionArguments(
ISqlExpressionFactory sqlExpressionFactory,
IEnumerable<SqlExpression> arguments,
EnumerableExpression enumerableExpression,
int enumerableArgumentIndex)
{
var argIndex = 0;
var typeMappedArguments = new List<SqlExpression>();

foreach (var argument in arguments)
{
var modifiedArgument = sqlExpressionFactory.ApplyDefaultTypeMapping(argument);

if (argIndex == enumerableArgumentIndex)
{
// This is the argument representing the enumerable inputs to be aggregated.
// Wrap it with a CASE/WHEN for the predicate and with DISTINCT, if necessary.
if (enumerableExpression.Predicate != null)
{
modifiedArgument = sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(enumerableExpression.Predicate, modifiedArgument) },
elseResult: null);
}

if (enumerableExpression.IsDistinct)
{
modifiedArgument = new DistinctExpression(modifiedArgument);
}
}

typeMappedArguments.Add(modifiedArgument);

argIndex++;
}

return typeMappedArguments;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,6 @@ public SqlServerSqlExpressionFactory(SqlExpressionFactoryDependencies dependenci
: base(dependencies)
=> _typeMappingSource = dependencies.TypeMappingSource;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override SqlFunctionExpression AggregateFunction(
string name,
IEnumerable<SqlExpression> arguments,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
EnumerableExpression enumerableExpression,
int enumerableArgumentIndex,
Type returnType,
RelationalTypeMapping? typeMapping = null)
{
// SQL Server supports ordering on some functions
var baseFunction = base.AggregateFunction(
name, arguments, nullable, argumentsPropagateNullability, enumerableExpression, enumerableArgumentIndex, returnType,
typeMapping);

return enumerableExpression.Orderings.Count == 0
? baseFunction
: new SqlServerSqlFunctionExpression(
baseFunction.Name,
baseFunction.Arguments!,
baseFunction.IsNullable,
baseFunction.ArgumentsPropagateNullability!,
enumerableExpression.Orderings,
baseFunction.Type,
baseFunction.TypeMapping);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ public class SqlServerSqlFunctionExpression : SqlFunctionExpression, IEquatable<
public SqlServerSqlFunctionExpression(
string functionName,
IEnumerable<SqlExpression> arguments,
IReadOnlyList<OrderingExpression> aggregateOrderings,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
IReadOnlyList<OrderingExpression> aggregateOrderings,
Type type,
RelationalTypeMapping? typeMapping)
: base(functionName, arguments, nullable, argumentsPropagateNullability, type, typeMapping)
Expand Down Expand Up @@ -73,9 +73,9 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
? new SqlServerSqlFunctionExpression(
Name,
visitedBase.Arguments!,
visitedAggregateOrderings ?? AggregateOrderings,
IsNullable,
ArgumentsPropagateNullability!,
visitedAggregateOrderings ?? AggregateOrderings,
Type,
TypeMapping)
: this;
Expand All @@ -91,9 +91,9 @@ public override SqlServerSqlFunctionExpression ApplyTypeMapping(RelationalTypeMa
=> new(
Name,
Arguments!,
AggregateOrderings,
IsNullable,
ArgumentsPropagateNullability!,
AggregateOrderings,
Type,
typeMapping ?? TypeMapping);

Expand All @@ -111,7 +111,7 @@ public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyL
return arguments.SequenceEqual(Arguments!)
? this
: new SqlServerSqlFunctionExpression(
Name, arguments, IsNullable, ArgumentsPropagateNullability!, AggregateOrderings, Type, TypeMapping);
Name, arguments, AggregateOrderings, IsNullable, ArgumentsPropagateNullability!, Type, TypeMapping);
}

/// <summary>
Expand All @@ -124,7 +124,7 @@ public virtual SqlFunctionExpression UpdateAggregateOrderings(IReadOnlyList<Orde
=> aggregateOrderings.SequenceEqual(AggregateOrderings)
? this
: new SqlServerSqlFunctionExpression(
Name, Arguments!, IsNullable, ArgumentsPropagateNullability!, aggregateOrderings, Type, TypeMapping);
Name, Arguments!, aggregateOrderings, IsNullable, ArgumentsPropagateNullability!, Type, TypeMapping);

/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,12 @@ public SqlServerStatisticsAggregateMethodTranslator(
return null;
}

if (source.Predicate != null)
{
if (sqlExpression is SqlFragmentExpression)
{
sqlExpression = _sqlExpressionFactory.Constant(1);
}

sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(source.Predicate, sqlExpression) },
elseResult: null);
}

if (source.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

return _sqlExpressionFactory.Function(
return SqlServerExpression.AggregateFunction(
_sqlExpressionFactory,
functionName,
new[] { sqlExpression },
source,
enumerableArgumentIndex: 0,
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,11 @@ public SqlServerStringAggregateMethodTranslator(
_sqlExpressionFactory.Constant(string.Empty, typeof(string)));
}

if (source.Predicate != null)
{
if (sqlExpression is SqlFragmentExpression)
{
sqlExpression = _sqlExpressionFactory.Constant(1);
}

sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(source.Predicate, sqlExpression) },
elseResult: null);
}

if (source.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

// STRING_AGG returns null when there are no rows (or non-null values), but string.Join returns an empty string.
return
_sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Function(
SqlServerExpression.AggregateFunctionWithOrdering(
_sqlExpressionFactory,
"STRING_AGG",
new[]
{
Expand All @@ -115,6 +99,8 @@ public SqlServerStringAggregateMethodTranslator(
method == StringJoinMethod ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)),
sqlExpression.TypeMapping)
},
source,
enumerableArgumentIndex: 0,
nullable: true,
argumentsPropagateNullability: new[] { false, true },
typeof(string)),
Expand Down

0 comments on commit f5d339b

Please sign in to comment.