Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Correctly specify types to match constraints in materializatio… #21881

Merged
merged 1 commit into from
Aug 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,17 @@ protected override Expression VisitExtension(Expression extensionExpression)

if (extensionExpression is CollectionShaperExpression collectionShaperExpression)
{
var navigation = collectionShaperExpression.Navigation;
var collectionAccessor = navigation?.GetCollectionAccessor();
var collectionType = collectionAccessor?.CollectionType ?? collectionShaperExpression.Type;
var elementType = collectionShaperExpression.ElementType;
var collectionType = collectionShaperExpression.Type;

return Expression.Call(
_materializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType),
QueryCompilationContext.QueryContextParameter,
collectionShaperExpression.Projection,
Expression.Constant(((LambdaExpression)Visit(collectionShaperExpression.InnerShaper)).Compile()),
Expression.Constant(
collectionShaperExpression.Navigation?.GetCollectionAccessor(),
typeof(IClrCollectionAccessor)));
Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor)));
}

if (extensionExpression is SingleResultShaperExpression singleResultShaperExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,11 @@ protected override Expression VisitExtension(Expression extensionExpression)
_readerColumns)
.ProcessShaper(relationalCollectionShaperExpression.InnerShaper, out _, out _);

var collectionType = relationalCollectionShaperExpression.Type;
var elementType = collectionType.TryGetSequenceType();
var relatedElementType = innerShaper.ReturnType;
var navigation = relationalCollectionShaperExpression.Navigation;
var collectionAccessor = navigation?.GetCollectionAccessor();
var collectionType = collectionAccessor?.CollectionType ?? relationalCollectionShaperExpression.Type;
var elementType = relationalCollectionShaperExpression.ElementType;
var relatedElementType = innerShaper.ReturnType;

_inline = true;

Expand Down Expand Up @@ -700,7 +701,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
_resultCoordinatorParameter,
Expression.Constant(parentIdentifierLambda.Compile()),
Expression.Constant(outerIdentifierLambda.Compile()),
Expression.Constant(navigation?.GetCollectionAccessor(), typeof(IClrCollectionAccessor)))));
Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor)))));

_valuesArrayInitializers.Add(collectionParameter);
accessor = Expression.Convert(
Expand Down Expand Up @@ -735,15 +736,16 @@ protected override Expression VisitExtension(Expression extensionExpression)
if (!_variableShaperMapping.TryGetValue(key, out var accessor))
{
var innerProcessor = new ShaperProcessingExpressionVisitor(_parentVisitor, _resultCoordinatorParameter,
_executionStrategyParameter, relationalSplitCollectionShaperExpression.SelectExpression, _tags);
_executionStrategyParameter, relationalSplitCollectionShaperExpression.SelectExpression, _tags);
var innerShaper = innerProcessor.ProcessShaper(relationalSplitCollectionShaperExpression.InnerShaper,
out var relationalCommandCache,
out var relatedDataLoaders);

var collectionType = relationalSplitCollectionShaperExpression.Type;
var elementType = collectionType.TryGetSequenceType();
var relatedElementType = innerShaper.ReturnType;
var navigation = relationalSplitCollectionShaperExpression.Navigation;
var collectionAccessor = navigation?.GetCollectionAccessor();
var collectionType = collectionAccessor?.CollectionType ?? relationalSplitCollectionShaperExpression.Type;
var elementType = relationalSplitCollectionShaperExpression.ElementType;
var relatedElementType = innerShaper.ReturnType;

_inline = true;

Expand Down Expand Up @@ -777,7 +779,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
_dataReaderParameter,
_resultCoordinatorParameter,
Expression.Constant(parentIdentifierLambda.Compile()),
Expression.Constant(navigation?.GetCollectionAccessor(), typeof(IClrCollectionAccessor)))));
Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor)))));

_valuesArrayInitializers.Add(collectionParameter);
accessor = Expression.Convert(
Expand Down Expand Up @@ -1053,7 +1055,9 @@ private static void IncludeReference<TEntity, TIncludingEntity, TIncludedEntity>
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : TEntity
where TEntity : class
where TIncludingEntity : class, TEntity
where TIncludedEntity : class
{
if (entity is TIncludingEntity includingEntity)
{
Expand Down Expand Up @@ -1093,7 +1097,8 @@ private static void InitializeIncludeCollection<TParent, TNavigationEntity>(
INavigationBase navigation,
IClrCollectionAccessor clrCollectionAccessor,
bool trackingQuery)
where TNavigationEntity : TParent
where TParent : class
where TNavigationEntity : class, TParent
{
object collection = null;
if (entity is TNavigationEntity)
Expand Down Expand Up @@ -1133,6 +1138,8 @@ private static void PopulateIncludeCollection<TIncludingEntity, TIncludedEntity>
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class
where TIncludedEntity : class
{
var collectionMaterializationContext = resultCoordinator.Collections[collectionId];
if (collectionMaterializationContext.Parent is TIncludingEntity entity)
Expand Down Expand Up @@ -1242,7 +1249,8 @@ private static void InitializeSplitIncludeCollection<TParent, TNavigationEntity>
INavigationBase navigation,
IClrCollectionAccessor clrCollectionAccessor,
bool trackingQuery)
where TNavigationEntity : TParent
where TParent : class
where TNavigationEntity : class, TParent
{
object collection = null;
if (entity is TNavigationEntity)
Expand Down Expand Up @@ -1279,6 +1287,8 @@ private static void PopulateSplitIncludeCollection<TIncludingEntity, TIncludedEn
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class
where TIncludedEntity : class
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down Expand Up @@ -1357,6 +1367,8 @@ private static async Task PopulateSplitIncludeCollectionAsync<TIncludingEntity,
INavigationBase inverseNavigation,
Action<TIncludingEntity, TIncludedEntity> fixup,
bool trackingQuery)
where TIncludingEntity : class
where TIncludedEntity : class
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down Expand Up @@ -1435,7 +1447,7 @@ private static TCollection InitializeCollection<TElement, TCollection>(
Func<QueryContext, DbDataReader, object[]> parentIdentifier,
Func<QueryContext, DbDataReader, object[]> outerIdentifier,
IClrCollectionAccessor clrCollectionAccessor)
where TCollection : class, IEnumerable<TElement>
where TCollection : class, ICollection<TElement>
{
var collection = clrCollectionAccessor?.Create() ?? new List<TElement>();

Expand Down Expand Up @@ -1566,7 +1578,7 @@ private static TCollection InitializeSplitCollection<TElement, TCollection>(
SplitQueryResultCoordinator resultCoordinator,
Func<QueryContext, DbDataReader, object[]> parentIdentifier,
IClrCollectionAccessor clrCollectionAccessor)
where TCollection : class, IEnumerable<TElement>
where TCollection : class, ICollection<TElement>
{
var collection = clrCollectionAccessor?.Create() ?? new List<TElement>();
var parentKey = parentIdentifier(queryContext, parentDataReader);
Expand All @@ -1587,8 +1599,8 @@ private static void PopulateSplitCollection<TCollection, TElement, TRelatedEntit
IReadOnlyList<ValueComparer> identifierValueComparers,
Func<QueryContext, DbDataReader, ResultContext, SplitQueryResultCoordinator, TRelatedEntity> innerShaper,
Action<QueryContext, IExecutionStrategy, SplitQueryResultCoordinator> relatedDataLoaders)
where TRelatedEntity : TElement
where TCollection : class, ICollection<TElement>
where TRelatedEntity : TElement
where TCollection : class, ICollection<TElement>
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down Expand Up @@ -1659,8 +1671,8 @@ private static async Task PopulateSplitCollectionAsync<TCollection, TElement, TR
IReadOnlyList<ValueComparer> identifierValueComparers,
Func<QueryContext, DbDataReader, ResultContext, SplitQueryResultCoordinator, TRelatedEntity> innerShaper,
Func<QueryContext, IExecutionStrategy, SplitQueryResultCoordinator, Task> relatedDataLoaders)
where TRelatedEntity : TElement
where TCollection : class, ICollection<TElement>
where TRelatedEntity : TElement
where TCollection : class, ICollection<TElement>
{
if (resultCoordinator.DataReaders.Count <= collectionId
|| resultCoordinator.DataReaders[collectionId] == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,61 @@ protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)

#endregion

#region Issue21803

[ConditionalFact]
public virtual void Select_enumerable_navigation_backed_by_collection()
{
using (CreateScratch<MyContext21803>(Seed21803, "21803"))
{
using var context = new MyContext21803();

var query = context.Set<AppEntity21803>().Select(appEntity => appEntity.OtherEntities);

query.ToList();
}
}

private static void Seed21803(MyContext21803 context)
{
var appEntity = new AppEntity21803();
context.AddRange(
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity });

context.SaveChanges();
}

public class AppEntity21803
{
private readonly List<OtherEntity21803> _otherEntities = new List<OtherEntity21803>();

public int Id { get; private set; }
public IEnumerable<OtherEntity21803> OtherEntities => _otherEntities;
}

public class OtherEntity21803
{
public int Id { get; private set; }
public AppEntity21803 AppEntity { get; set; }
}

private class MyContext21803 : DbContext
{
public DbSet<AppEntity21803> Entities { get; set; }

protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
optionsBuilder
.UseInternalServiceProvider(InMemoryFixture.DefaultServiceProvider)
.UseInMemoryDatabase("21803");
}
}

#endregion

#region SharedHelper

private static InMemoryTestStore CreateScratch<TContext>(Action<TContext> seed, string databaseName)
Expand Down
98 changes: 96 additions & 2 deletions test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8103,8 +8103,7 @@ private SqlServerTestStore CreateDatabase19206()

#endregion


#region Issue 18510
#region Issue18510

[ConditionalFact]
public virtual void Invoke_inside_query_filter_gets_correctly_evaluated_during_translation()
Expand Down Expand Up @@ -8198,6 +8197,101 @@ private SqlServerTestStore CreateDatabase18510()

#endregion

#region Issue21803

[ConditionalTheory]
[InlineData(true, true)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(false, false)]
public virtual async Task Select_enumerable_navigation_backed_by_collection(bool async, bool split)
{
using (CreateDatabase21803())
{
using var context = new MyContext21803(_options);

var query = context.Set<AppEntity21803>().Select(appEntity => appEntity.OtherEntities);

if (split)
{
query = query.AsSplitQuery();
}

if (async)
{
await query.ToListAsync();
}
else
{
query.ToList();
}

if (split)
{
AssertSql(
@"SELECT [e].[Id]
FROM [Entities] AS [e]
ORDER BY [e].[Id]",
//
@"SELECT [o].[Id], [o].[AppEntityId], [e].[Id]
FROM [Entities] AS [e]
INNER JOIN [OtherEntity21803] AS [o] ON [e].[Id] = [o].[AppEntityId]
ORDER BY [e].[Id]");
}
else
{
AssertSql(
@"SELECT [e].[Id], [o].[Id], [o].[AppEntityId]
FROM [Entities] AS [e]
LEFT JOIN [OtherEntity21803] AS [o] ON [e].[Id] = [o].[AppEntityId]
ORDER BY [e].[Id], [o].[Id]");
}
}
}

public class AppEntity21803
{
private readonly List<OtherEntity21803> _otherEntities = new List<OtherEntity21803>();

public int Id { get; private set; }
public IEnumerable<OtherEntity21803> OtherEntities => _otherEntities;
}

public class OtherEntity21803
{
public int Id { get; private set; }
public AppEntity21803 AppEntity { get; set; }
}

private class MyContext21803 : DbContext
{
public DbSet<AppEntity21803> Entities { get; set; }

public MyContext21803(DbContextOptions options)
: base(options)
{
}
}

private SqlServerTestStore CreateDatabase21803()
=> CreateTestStore(
() => new MyContext21803(_options),
context =>
{
var appEntity = new AppEntity21803();
context.AddRange(
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity },
new OtherEntity21803 { AppEntity = appEntity });

context.SaveChanges();

ClearLog();
});

#endregion

private DbContextOptions _options;

private SqlServerTestStore CreateTestStore<TContext>(
Expand Down