Skip to content

Commit

Permalink
Query: Introduce EnumerableExpression which is not SQL token
Browse files Browse the repository at this point in the history
This iterates over design in #27931

- Bring back DistinctExpression (yes the token in invalid in some places and incorrect tree will throw invalid SQL error in database
- Introduce EnumerableExpression which is a holder/parameter object which contains facets of grouping element chain including Distinct/Predicate/Selector/Ordering
- Translators are responsible for putting together pieces from EnumerableExpression to generate a SqlExpression as a result
- Remove GroupByAggregateChainProcessor which failed to avoid double visitation. We need to refactor this code in future to avoid it when we implement public API for aggregate functions

Resolves #27948
Resolves #27935
  • Loading branch information
smitpatel committed May 10, 2022
1 parent dd7c329 commit c9c18bb
Show file tree
Hide file tree
Showing 13 changed files with 593 additions and 550 deletions.
158 changes: 158 additions & 0 deletions src/EFCore.Relational/Query/EnumerableExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.CompilerServices;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;

/// <summary>
/// <para>
/// An expression that represents an enumerable or group translated from chain over a grouping element.
/// </para>
/// <para>
/// This type is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// </summary>
public class EnumerableExpression : Expression, IPrintableExpression
{
private readonly List<OrderingExpression> _orderings = new();

/// <summary>
/// Creates a new instance of the <see cref="EnumerableExpression" /> class.
/// </summary>
/// <param name="selector">The underlying sql expression being enumerated.</param>
public EnumerableExpression(Expression selector)
{
Selector = selector;
}

/// <summary>
/// The underlying expression being enumerated.
/// </summary>
public virtual Expression Selector { get; private set; }

/// <summary>
/// The value indicating if distinct operator is applied on the enumerable or not.
/// </summary>
public virtual bool IsDistinct { get; private set; }

/// <summary>
/// The value indicating any predicate applied on the enumerable.
/// </summary>
public virtual SqlExpression? Predicate { get; private set; }

/// <summary>
/// The list of orderings to be applied to the enumerable.
/// </summary>
public virtual IReadOnlyList<OrderingExpression> Orderings => _orderings;


/// <summary>
/// Applies new selector to the <see cref="EnumerableExpression" />.
/// </summary>
public virtual void ApplySelector(Expression expression)
{
Selector = expression;
}

/// <summary>
/// Applies DISTINCT operator to the selector of the <see cref="EnumerableExpression" />.
/// </summary>
public virtual void ApplyDistinct()
{
IsDistinct = true;
}

/// <summary>
/// Applies filter predicate to the <see cref="EnumerableExpression" />.
/// </summary>
/// <param name="sqlExpression">An expression to use for filtering.</param>
public virtual void ApplyPredicate(SqlExpression sqlExpression)
{
if (sqlExpression is SqlConstantExpression sqlConstant
&& sqlConstant.Value is bool boolValue
&& boolValue)
{
return;
}

Predicate = Predicate == null
? sqlExpression
: new SqlBinaryExpression(
ExpressionType.AndAlso,
Predicate,
sqlExpression,
typeof(bool),
sqlExpression.TypeMapping);
}

/// <summary>
/// Applies ordering to the <see cref="EnumerableExpression" />. This overwrites any previous ordering specified.
/// </summary>
/// <param name="orderingExpression">An ordering expression to use for ordering.</param>
public virtual void ApplyOrdering(OrderingExpression orderingExpression)
{
_orderings.Clear();
AppendOrdering(orderingExpression);
}

/// <summary>
/// Appends ordering to the existing orderings of the <see cref="EnumerableExpression" />.
/// </summary>
/// <param name="orderingExpression">An ordering expression to use for ordering.</param>
public virtual void AppendOrdering(OrderingExpression orderingExpression)
{
if (!_orderings.Any(o => o.Expression.Equals(orderingExpression.Expression)))
{
_orderings.Add(orderingExpression.Update(orderingExpression.Expression));
}
}

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> throw new InvalidOperationException(
CoreStrings.VisitIsNotAllowed($"{nameof(EnumerableExpression)}.{nameof(VisitChildren)}"));

/// <inheritdoc />
public override ExpressionType NodeType => ExpressionType.Extension;

/// <inheritdoc />
public override Type Type => typeof(IEnumerable<>).MakeGenericType(Selector.Type);

/// <inheritdoc />
public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(EnumerableExpression) + ":");
using (expressionPrinter.Indent())
{
expressionPrinter.Append("Selector: ");
expressionPrinter.Visit(Selector);
expressionPrinter.AppendLine();
if (IsDistinct)
{
expressionPrinter.AppendLine($"IsDistinct: {IsDistinct}");
}

if (Predicate != null)
{
expressionPrinter.Append("Predicate: ");
expressionPrinter.Visit(Predicate);
expressionPrinter.AppendLine();
}

if (Orderings.Count > 0)
{
expressionPrinter.Append("Orderings: ");
expressionPrinter.VisitCollection(Orderings);
expressionPrinter.AppendLine();
}
}
}

/// <inheritdoc />
public override bool Equals(object? obj) => ReferenceEquals(this, obj);

/// <inheritdoc />
public override int GetHashCode() => RuntimeHelpers.GetHashCode(this);
}
35 changes: 10 additions & 25 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -506,31 +506,6 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres
return sqlBinaryExpression;
}

/// <inheritdoc />
protected override Expression VisitSqlEnumerable(SqlEnumerableExpression sqlEnumerableExpression)
{
if (sqlEnumerableExpression.Orderings.Count != 0)
{
// TODO: Throw error here because we don't know how to print orderings.
// Though providers can override this method and generate orderings if they have a way to print it.
throw new InvalidOperationException();
}

if (sqlEnumerableExpression.IsDistinct)
{
_relationalCommandBuilder.Append("DISTINCT (");
}

Visit(sqlEnumerableExpression.SqlExpression);

if (sqlEnumerableExpression.IsDistinct)
{
_relationalCommandBuilder.Append(")");
}

return sqlEnumerableExpression;
}

/// <inheritdoc />
protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression)
{
Expand Down Expand Up @@ -634,6 +609,16 @@ protected override Expression VisitCollate(CollateExpression collateExpression)
return collateExpression;
}

/// <inheritdoc />
protected override Expression VisitDistinct(DistinctExpression distinctExpression)
{
_relationalCommandBuilder.Append("DISTINCT (");
Visit(distinctExpression.Operand);
_relationalCommandBuilder.Append(")");

return distinctExpression;
}

/// <inheritdoc />
protected override Expression VisitCase(CaseExpression caseExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
private ShapedQueryExpression? TranslateAggregateWithPredicate(
ShapedQueryExpression source,
LambdaExpression? predicate,
Func<SqlEnumerableExpression, SqlExpression?> aggregateTranslator,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
Expand All @@ -1480,7 +1480,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape

HandleGroupByForAggregate(selectExpression, eraseProjection: true);

var translation = aggregateTranslator(new SqlEnumerableExpression(_sqlExpressionFactory.Fragment("*"), distinct: false, null));
var translation = aggregateTranslator(_sqlExpressionFactory.Fragment("*"));
if (translation == null)
{
return null;
Expand All @@ -1500,7 +1500,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
private ShapedQueryExpression? TranslateAggregateWithSelector(
ShapedQueryExpression source,
LambdaExpression? selector,
Func<SqlEnumerableExpression, SqlExpression?> aggregateTranslator,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
bool throwWhenEmpty,
Type resultType)
{
Expand Down Expand Up @@ -1541,7 +1541,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
}
}

var projection = aggregateTranslator(new SqlEnumerableExpression(translatedSelector, distinct: false, null));
var projection = aggregateTranslator(translatedSelector);
if (projection == null)
{
return null;
Expand Down
Loading

0 comments on commit c9c18bb

Please sign in to comment.