Skip to content

Commit

Permalink
Query: Add regression tests for correlated collection and GroupBy/Dis…
Browse files Browse the repository at this point in the history
…tinct (#24477)

Some cleanup for #17337
  • Loading branch information
smitpatel authored Mar 25, 2021
1 parent b63d0a6 commit 088c849
Show file tree
Hide file tree
Showing 14 changed files with 361 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,14 @@ protected override Expression VisitExtension(Expression extensionExpression)
case EntityShaperExpression entityShaperExpression:
return new EntityReferenceExpression(entityShaperExpression);

case ProjectionBindingExpression projectionBindingExpression:
return projectionBindingExpression.ProjectionMember != null
? ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: QueryCompilationContext.NotTranslatedExpression;
case ProjectionBindingExpression projectionBindingExpression
when projectionBindingExpression.ProjectionMember != null:
return ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember);

//case ProjectionBindingExpression projectionBindingExpression
// when projectionBindingExpression.Index is int index:
// return ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression).Projection[index];

case InMemoryGroupByShaperExpression inMemoryGroupByShaperExpression:
return new GroupingElementExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1441,8 +1441,13 @@ outerKey is NewArrayExpression newArrayExpression
|| (entityType.FindDiscriminatorProperty() == null
&& navigation.DeclaringEntityType.IsStrictlyDerivedFrom(entityShaperExpression.EntityType));

innerShaper = _selectExpression.GenerateWeakEntityShaper(
var entityProjection = _selectExpression.GenerateWeakEntityProjectionExpression(
targetEntityType, table, identifyingColumn.Name, identifyingColumn.Table, principalNullable);

if (entityProjection != null)
{
innerShaper = new RelationalEntityShaperExpression(targetEntityType, entityProjection, principalNullable);
}
}

if (innerShaper == null)
Expand Down Expand Up @@ -1475,8 +1480,11 @@ outerKey is NewArrayExpression newArrayExpression
_selectExpression.AddLeftJoin(innerSelectExpression, joinPredicate);
var leftJoinTable = ((LeftJoinExpression)_selectExpression.Tables.Last()).Table;

innerShaper = _selectExpression.GenerateWeakEntityShaper(
targetEntityType, table, null, leftJoinTable, makeNullable: true)!;
innerShaper = new RelationalEntityShaperExpression(
targetEntityType,
_selectExpression.GenerateWeakEntityProjectionExpression(
targetEntityType, table, null, leftJoinTable, nullable: true)!,
nullable: true);
}

entityProjectionExpression.AddNavigationBinding(navigation, innerShaper);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,14 @@ protected override Expression VisitExtension(Expression extensionExpression)
case EntityShaperExpression entityShaperExpression:
return new EntityReferenceExpression(entityShaperExpression);

case ProjectionBindingExpression projectionBindingExpression:
return projectionBindingExpression.ProjectionMember != null
? ((SelectExpression)projectionBindingExpression.QueryExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: QueryCompilationContext.NotTranslatedExpression;
case ProjectionBindingExpression projectionBindingExpression
when projectionBindingExpression.ProjectionMember != null:
return ((SelectExpression)projectionBindingExpression.QueryExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember);

//case ProjectionBindingExpression projectionBindingExpression
// when projectionBindingExpression.Index is int index:
// return ((SelectExpression)projectionBindingExpression.QueryExpression).Projection[index].Expression;

case GroupByShaperExpression groupByShaperExpression:
return new GroupingElementExpression(groupByShaperExpression.ElementSelector);
Expand Down
101 changes: 13 additions & 88 deletions src/EFCore.Relational/Query/SqlExpressions/ColumnExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

using System;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

Expand All @@ -20,101 +16,47 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
/// This type is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// <para>
/// This class is not publicly constructable. If this is a problem for your application or provider, then please file
/// an issue at https://github.com/dotnet/efcore.
/// </para>
/// </summary>
[DebuggerDisplay("{DebuggerDisplay(),nq}")]
// Class is sealed because there are no public/protected constructors. Can be unsealed if this is changed.
public sealed class ColumnExpression : SqlExpression
public abstract class ColumnExpression : SqlExpression
{
private readonly TableReferenceExpression _table;

internal ColumnExpression(IProperty property, IColumnBase column, TableReferenceExpression table, bool nullable)
: this(
column.Name,
table,
property.ClrType.UnwrapNullableType(),
column.PropertyMappings.First(m => m.Property == property).TypeMapping,
nullable || column.IsNullable)
{
}

internal ColumnExpression(ProjectionExpression subqueryProjection, TableReferenceExpression table)
: this(
subqueryProjection.Alias, table,
subqueryProjection.Type, subqueryProjection.Expression.TypeMapping!,
IsNullableProjection(subqueryProjection))
{
}

private static bool IsNullableProjection(ProjectionExpression projectionExpression)
=> projectionExpression.Expression switch
{
ColumnExpression columnExpression => columnExpression.IsNullable,
SqlConstantExpression sqlConstantExpression => sqlConstantExpression.Value == null,
_ => true,
};

private ColumnExpression(string name, TableReferenceExpression table, Type type, RelationalTypeMapping typeMapping, bool nullable)
/// <summary>
/// Creates a new instance of the <see cref="ColumnExpression" /> class.
/// </summary>
/// <param name="type"> The <see cref="System.Type" /> of the expression. </param>
/// <param name="typeMapping"> The <see cref="RelationalTypeMapping" /> associated with the expression. </param>
protected ColumnExpression(Type type, RelationalTypeMapping? typeMapping)
: base(type, typeMapping)
{
Check.NotEmpty(name, nameof(name));
Check.NotNull(table, nameof(table));
Check.NotEmpty(table.Alias, $"{nameof(table)}.{nameof(table.Alias)}");

Name = name;
_table = table;
IsNullable = nullable;
}

/// <summary>
/// The name of the column.
/// </summary>
public string Name { get; }
public abstract string Name { get; }

/// <summary>
/// The table from which column is being referenced.
/// </summary>
public TableExpressionBase Table => _table.Table;
public abstract TableExpressionBase Table { get; }

/// <summary>
/// The alias of the table from which column is being referenced.
/// </summary>
public string TableAlias => _table.Alias;
public abstract string TableAlias { get; }

/// <summary>
/// The bool value indicating if this column can have null values.
/// </summary>
public bool IsNullable { get; }

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
Check.NotNull(visitor, nameof(visitor));

return this;
}
public abstract bool IsNullable { get; }

/// <summary>
/// Makes this column nullable.
/// </summary>
/// <returns> A new expression which has <see cref="IsNullable" /> property set to true. </returns>
public ColumnExpression MakeNullable()
=> new(Name, _table, Type, TypeMapping!, true);
public abstract ColumnExpression MakeNullable();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[EntityFrameworkInternal]
public void UpdateTableReference(SelectExpression oldSelect, SelectExpression newSelect)
=> _table.UpdateTableReference(oldSelect, newSelect);

/// <inheritdoc />
/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
{
Check.NotNull(expressionPrinter, nameof(expressionPrinter));
Expand All @@ -123,23 +65,6 @@ protected override void Print(ExpressionPrinter expressionPrinter)
expressionPrinter.Append(Name);
}

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is ColumnExpression columnExpression
&& Equals(columnExpression));

private bool Equals(ColumnExpression columnExpression)
=> base.Equals(columnExpression)
&& Name == columnExpression.Name
&& _table.Equals(columnExpression._table)
&& IsNullable == columnExpression.IsNullable;

/// <inheritdoc />
public override int GetHashCode()
=> HashCode.Combine(base.GetHashCode(), Name, _table, IsNullable);

private string DebuggerDisplay()
=> $"{TableAlias}.{Name}";
}
Expand Down
135 changes: 133 additions & 2 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
Expand Down Expand Up @@ -194,7 +196,7 @@ when _mappings.TryGetValue(sqlExpression, out var outer):
when _subquery.ContainsTableReference(columnExpression):
var index = _subquery.AddToProjection(columnExpression);
var projectionExpression = _subquery._projection[index];
return new ColumnExpression(projectionExpression, _tableReferenceExpression);
return new ConcreteColumnExpression(projectionExpression, _tableReferenceExpression);

default:
return base.Visit(expression);
Expand Down Expand Up @@ -296,7 +298,7 @@ public TableReferenceUpdatingExpressionVisitor(SelectExpression oldSelect, Selec
[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (expression is ColumnExpression columnExpression)
if (expression is ConcreteColumnExpression columnExpression)
{
columnExpression.UpdateTableReference(_oldSelect, _newSelect);
}
Expand Down Expand Up @@ -338,5 +340,134 @@ public AliasUniquefier(HashSet<string> usedAliases)
return base.Visit(expression);
}
}

private sealed class TableReferenceExpression : Expression
{
private SelectExpression _selectExpression;

public TableReferenceExpression(SelectExpression selectExpression, string alias)
{
_selectExpression = selectExpression;
Alias = alias;
}

public TableExpressionBase Table
=> _selectExpression.Tables.Single(
e => string.Equals((e as JoinExpressionBase)?.Table.Alias ?? e.Alias, Alias, StringComparison.OrdinalIgnoreCase));

public string Alias { get; internal set; }

public override Type Type => typeof(object);

public override ExpressionType NodeType => ExpressionType.Extension;
public void UpdateTableReference(SelectExpression oldSelect, SelectExpression newSelect)
{
if (ReferenceEquals(oldSelect, _selectExpression))
{
_selectExpression = newSelect;
}
}

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is TableReferenceExpression tableReferenceExpression
&& Equals(tableReferenceExpression));

// Since table reference is owned by SelectExpression, the select expression should be the same reference if they are matching.
// That means we also don't need to compute the hashcode for it.
// This allows us to break the cycle in computation when traversing this graph.
private bool Equals(TableReferenceExpression tableReferenceExpression)
=> string.Equals(Alias, tableReferenceExpression.Alias, StringComparison.OrdinalIgnoreCase)
&& ReferenceEquals(_selectExpression, tableReferenceExpression._selectExpression);

/// <inheritdoc />
public override int GetHashCode()
=> Alias.GetHashCode();
}

private sealed class ConcreteColumnExpression : ColumnExpression
{
private readonly TableReferenceExpression _table;

public ConcreteColumnExpression(IProperty property, IColumnBase column, TableReferenceExpression table, bool nullable)
: this(
column.Name,
table,
property.ClrType.UnwrapNullableType(),
column.PropertyMappings.First(m => m.Property == property).TypeMapping,
nullable || column.IsNullable)
{
}

public ConcreteColumnExpression(ProjectionExpression subqueryProjection, TableReferenceExpression table)
: this(
subqueryProjection.Alias, table,
subqueryProjection.Type, subqueryProjection.Expression.TypeMapping!,
IsNullableProjection(subqueryProjection))
{
}

private static bool IsNullableProjection(ProjectionExpression projectionExpression)
=> projectionExpression.Expression switch
{
ColumnExpression columnExpression => columnExpression.IsNullable,
SqlConstantExpression sqlConstantExpression => sqlConstantExpression.Value == null,
_ => true,
};

private ConcreteColumnExpression(
string name, TableReferenceExpression table, Type type, RelationalTypeMapping typeMapping, bool nullable)
: base(type, typeMapping)
{
Check.NotEmpty(name, nameof(name));
Check.NotNull(table, nameof(table));
Check.NotEmpty(table.Alias, $"{nameof(table)}.{nameof(table.Alias)}");

Name = name;
_table = table;
IsNullable = nullable;
}

public override string Name { get; }

public override TableExpressionBase Table => _table.Table;

public override string TableAlias => _table.Alias;

public override bool IsNullable { get; }

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
Check.NotNull(visitor, nameof(visitor));

return this;
}

public override ConcreteColumnExpression MakeNullable()
=> new(Name, _table, Type, TypeMapping!, true);

public void UpdateTableReference(SelectExpression oldSelect, SelectExpression newSelect)
=> _table.UpdateTableReference(oldSelect, newSelect);

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is ConcreteColumnExpression concreteColumnExpression
&& Equals(concreteColumnExpression));

private bool Equals(ConcreteColumnExpression concreteColumnExpression)
=> base.Equals(concreteColumnExpression)
&& Name == concreteColumnExpression.Name
&& _table.Equals(concreteColumnExpression._table)
&& IsNullable == concreteColumnExpression.IsNullable;

/// <inheritdoc />
public override int GetHashCode()
=> HashCode.Combine(base.GetHashCode(), Name, _table, IsNullable);
}
}
}
Loading

0 comments on commit 088c849

Please sign in to comment.