Skip to content

Commit

Permalink
Cosmos: Fixes around array projection
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jun 22, 2024
1 parent c659d24 commit b99a0b1
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 174 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ public static bool TryConvertToArray(
{
subquery.ApplyProjection();

// TODO: Should the type be an array, or enumerable/queryable?
var arrayClrType = projection.Type.MakeArrayType();
var arrayClrType = typeof(IEnumerable<>).MakeGenericType(projection.Type);

switch (projection)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand All @@ -21,7 +23,9 @@ private static readonly MethodInfo GetParameterValueMethodInfo
= typeof(CosmosProjectionBindingExpressionVisitor)
.GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue))!;

private readonly CosmosQueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor;
private readonly CosmosSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ITypeMappingSource _typeMappingSource;
private readonly IModel _model;
private SelectExpression _selectExpression;
private bool _clientEval;
Expand All @@ -39,10 +43,14 @@ private static readonly MethodInfo GetParameterValueMethodInfo
/// </summary>
public CosmosProjectionBindingExpressionVisitor(
IModel model,
CosmosSqlTranslatingExpressionVisitor sqlTranslator)
CosmosQueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
CosmosSqlTranslatingExpressionVisitor sqlTranslator,
ITypeMappingSource typeMappingSource)
{
_model = model;
_queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor;
_sqlTranslator = sqlTranslator;
_typeMappingSource = typeMappingSource;
_selectExpression = null!;
}

Expand Down Expand Up @@ -570,6 +578,50 @@ UnaryExpression unaryExpression
lambda);
}
}
else if (method is { Name: nameof(Enumerable.ToList), IsGenericMethod: true }
&& method.DeclaringType == typeof(Enumerable)
&& methodCallExpression.Arguments is [var argument]
&& argument.Type.TryGetElementType(typeof(IQueryable<>)) != null)
{
if (_queryableMethodTranslatingExpressionVisitor.TranslateSubquery(argument) is not ShapedQueryExpression subquery
|| !subquery.TryConvertToArray(_typeMappingSource, out var array))
{
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

// If ToList() was composed over a subquery with operators, the result here is an ArrayExpression (ARRAY(SELECT ...)), whose
// CLR Type is IEnumerable<T>. This can be directly used in the resulting ProjectingBindingExpression - the shaper will
// simply read the JSON results out successfully.
// But if ToList() is composed directly over an array property, that property could have type e.g. T[], which will be read
// in the shaper, and then the cast from T[] to List<T> will fail. As a result, wrap the array in an additional
// "reprojection" subquery, effectively to change the CLR type.
if (array is SqlExpression scalarArray
&& !(array.Type.IsGenericType && array.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>)))
{
Check.DebugAssert(
array is not ScalarArrayExpression and not ObjectArrayExpression, "ArrayExpression should be IEnumerable");

if (scalarArray is not { TypeMapping.ElementTypeMapping: CosmosTypeMapping elementTypeMapping })
{
throw new UnreachableException("Scalar array with no element type mapping");
}

// TODO: Proper alias management (#33894).
var arrayReprojectionSubquery = SelectExpression.CreateForCollection(
array, "i", new ScalarReferenceExpression("i", elementTypeMapping.ClrType, elementTypeMapping));
arrayReprojectionSubquery.ApplyProjection();

array = new ScalarArrayExpression(
arrayReprojectionSubquery,
methodCallExpression.Type, // List<>
_typeMappingSource.FindMapping(methodCallExpression.Type, _model, elementTypeMapping));
}

return new ProjectionBindingExpression(
_selectExpression,
_selectExpression.AddToProjection(array),
methodCallExpression.Type);
}
}

var @object = Visit(methodCallExpression.Object);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand Down Expand Up @@ -54,7 +55,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor(
_methodCallTranslatorProvider,
this);
_projectionBindingExpressionVisitor =
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, _sqlTranslator);
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, this, _sqlTranslator, _typeMappingSource);
_subquery = false;
}

Expand All @@ -81,7 +82,7 @@ protected CosmosQueryableMethodTranslatingExpressionVisitor(
_methodCallTranslatorProvider,
parentVisitor);
_projectionBindingExpressionVisitor =
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, _sqlTranslator);
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, this, _sqlTranslator, _typeMappingSource);
_subquery = true;
}

Expand Down Expand Up @@ -1125,8 +1126,10 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
// ElementAtOrDefault over an array of scalars
case SqlExpression scalarArray when projection is SqlExpression element:
{
var slice = _sqlExpressionFactory.Function(
"ARRAY_SLICE", [scalarArray, translatedCount], scalarArray.Type, scalarArray.TypeMapping);
var arrayType = typeof(IEnumerable<>).MakeGenericType(projection.Type);
var arrayTypeMapping = _typeMappingSource.FindMapping(arrayType, _queryCompilationContext.Model, element.TypeMapping);

var slice = _sqlExpressionFactory.Function("ARRAY_SLICE", [scalarArray, translatedCount], arrayType, arrayTypeMapping);

// TODO: Proper alias management (#33894). Ideally reach into the source of the original SelectExpression and use that alias.
var translatedSelect = SelectExpression.CreateForCollection(
Expand All @@ -1139,8 +1142,10 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
// ElementAtOrDefault over an array os structural types
case not null when projectedStructuralTypeShaper is not null:
{
var arrayType = typeof(IEnumerable<>).MakeGenericType(projectedStructuralTypeShaper.Type);

// TODO: Proper alias management (#33894).
var slice = new ObjectFunctionExpression("ARRAY_SLICE", [array, translatedCount], projectedStructuralTypeShaper.Type);
var slice = new ObjectFunctionExpression("ARRAY_SLICE", [array, translatedCount], arrayType);
var translatedSelect = SelectExpression.CreateForCollection(
slice,
"i",
Expand Down Expand Up @@ -1585,7 +1590,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
// value conversion). #34026.
var elementClrType = inlineQueryRootExpression.ElementType;
var elementTypeMapping = _typeMappingSource.FindMapping(elementClrType)!;
var arrayTypeMapping = _typeMappingSource.FindMapping(elementClrType.MakeArrayType()); // TODO: IEnumerable?
var arrayTypeMapping = _typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(elementClrType));
var inlineArray = new ArrayConstantExpression(elementClrType, translatedItems, arrayTypeMapping);

// TODO: Do proper alias management: #33894
Expand Down Expand Up @@ -1614,7 +1619,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
// TODO: Temporary hack - need to perform proper derivation of the array type mapping from the element (e.g. for
// value conversion). #34026.
var elementClrType = parameterQueryRootExpression.ElementType;
var arrayTypeMapping = _typeMappingSource.FindMapping(elementClrType.MakeArrayType()); // TODO: IEnumerable?
var arrayTypeMapping = _typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(elementClrType));
var elementTypeMapping = _typeMappingSource.FindMapping(elementClrType)!;
var sqlParameterExpression = new SqlParameterExpression(parameterQueryRootExpression.ParameterExpression, arrayTypeMapping);

Expand Down Expand Up @@ -1683,13 +1688,17 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
&& source2.TryConvertToArray(_typeMappingSource, out var array2, out var projection2, ignoreOrderings)
&& projection1.Type == projection2.Type)
{
var arrayType = typeof(IEnumerable<>).MakeGenericType(projection1.Type);

// Set operation over arrays of scalars
if (projection1 is SqlExpression sqlProjection1
&& projection2 is SqlExpression sqlProjection2
&& (sqlProjection1.TypeMapping ?? sqlProjection2.TypeMapping) is CoreTypeMapping typeMapping)
&& (sqlProjection1.TypeMapping ?? sqlProjection2.TypeMapping) is CosmosTypeMapping typeMapping)
{
var arrayTypeMapping = _typeMappingSource.FindMapping(arrayType, _queryCompilationContext.Model, typeMapping);

// TODO: Proper alias management (#33894).
var translation = _sqlExpressionFactory.Function(functionName, [array1, array2], projection1.Type, typeMapping);
var translation = _sqlExpressionFactory.Function(functionName, [array1, array2], arrayType, arrayTypeMapping);
var select = SelectExpression.CreateForCollection(
translation, "i", new ScalarReferenceExpression("i", projection1.Type, typeMapping));
return source1.UpdateQueryExpression(select);
Expand All @@ -1701,7 +1710,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
&& structuralType1 == structuralType2)
{
// TODO: Proper alias management (#33894).
var translation = new ObjectFunctionExpression(functionName, [array1, array2], projection1.Type);
var translation = new ObjectFunctionExpression(functionName, [array1, array2], arrayType);
var select = SelectExpression.CreateForCollection(
translation, "i", new ObjectReferenceExpression((IEntityType)structuralType1, "i"));
return CreateShapedQueryExpression(select, structuralType1.ClrType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
[DebuggerDisplay("{Microsoft.EntityFrameworkCore.Query.ExpressionPrinter.Print(this), nq}")]
public class ScalarAccessExpression(Expression @object, string propertyName, Type clrType, CoreTypeMapping? typeMapping)
: SqlExpression(clrType, typeMapping), IAccessExpression
{
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
// TODO: This infers based on the CLR type; need to properly infer based on the element type mapping
// TODO: being applied here (e.g. WHERE @p[1] = c.PropertyWithValueConverter). #34026
var arrayTypeMapping = left.TypeMapping
?? (typeMapping is null ? null : typeMappingSource.FindMapping(typeMapping.ClrType.MakeArrayType()));
?? (typeMapping is null ? null : typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(typeMapping.ClrType)));
return new SqlBinaryExpression(
ExpressionType.ArrayIndex,
ApplyTypeMapping(left, arrayTypeMapping),
Expand Down Expand Up @@ -290,7 +290,7 @@ private InExpression ApplyTypeMappingOnIn(InExpression inExpression)
var arrayClrType = arrayExpression.Type switch
{
var t when t.TryGetSequenceType() != typeof(object) => t,
{ IsArray: true } => itemExpression.Type.MakeArrayType(),
{ IsArray: true } => typeof(IEnumerable<>).MakeGenericType(itemExpression.Type),
{ IsConstructedGenericType: true, GenericTypeArguments.Length: 1 } t
=> t.GetGenericTypeDefinition().MakeGenericType(itemExpression.Type),
_ => throw new InvalidOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3248,57 +3248,69 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
var entityProjection = selectExpression.GetProjection(entityProjectionBindingExpression).GetConstantValue<object>();

if (entityProjection is QueryableJsonProjectionInfo || (_insideCollection && entityProjection is JsonProjectionInfo))
switch (entityProjection)
{
throw new InvalidOperationException(
RelationalStrings.JsonProjectingQueryableOperationNoTrackingWithIdentityResolution(nameof(QueryTrackingBehavior.NoTrackingWithIdentityResolution)));
}
case QueryableJsonProjectionInfo:
case JsonProjectionInfo when _insideCollection:
throw new InvalidOperationException(
RelationalStrings.JsonProjectingQueryableOperationNoTrackingWithIdentityResolution(
nameof(QueryTrackingBehavior.NoTrackingWithIdentityResolution)));

if (entityProjection is JsonProjectionInfo jsonEntityProjectionInfo)
{
var jsonEntityType = (IEntityType)entityShaperExpression.StructuralType;
if (_insideInclude)
case JsonProjectionInfo jsonEntityProjectionInfo:
{
if (!_includedJsonEntityTypes.Contains(jsonEntityType))
var jsonEntityType = (IEntityType)entityShaperExpression.StructuralType;
if (_insideInclude)
{
_includedJsonEntityTypes.Add(jsonEntityType);
if (!_includedJsonEntityTypes.Contains(jsonEntityType))
{
_includedJsonEntityTypes.Add(jsonEntityType);
}
}
else
{
_projectedKeyAccessInfos.Add((jsonEntityType, jsonEntityProjectionInfo.KeyAccessInfo));
}

break;
}
else
{
_projectedKeyAccessInfos.Add((jsonEntityType, jsonEntityProjectionInfo.KeyAccessInfo));
}
}

return extensionExpression;
default:
return extensionExpression;
}
}

if (extensionExpression is CollectionResultExpression { ProjectionBindingExpression: ProjectionBindingExpression collectionProjectionBindingExpression } collectionResultExpression)
{
var collectionProjection = selectExpression.GetProjection(collectionProjectionBindingExpression).GetConstantValue<object>();
if (collectionProjection is QueryableJsonProjectionInfo || (_insideCollection && collectionProjection is JsonProjectionInfo))
{
throw new InvalidOperationException(
RelationalStrings.JsonProjectingQueryableOperationNoTrackingWithIdentityResolution(nameof(QueryTrackingBehavior.NoTrackingWithIdentityResolution)));
}

if (collectionProjection is JsonProjectionInfo jsonCollectionProjectionInfo)
switch (collectionProjection)
{
var jsonEntityType = collectionResultExpression.Navigation!.TargetEntityType;
if (_insideInclude)
case QueryableJsonProjectionInfo:
case JsonProjectionInfo when _insideCollection:
throw new InvalidOperationException(
RelationalStrings.JsonProjectingQueryableOperationNoTrackingWithIdentityResolution(nameof(QueryTrackingBehavior.NoTrackingWithIdentityResolution)));

case JsonProjectionInfo jsonCollectionProjectionInfo:
{
if (!_includedJsonEntityTypes.Contains(jsonEntityType))
var jsonEntityType = collectionResultExpression.Navigation!.TargetEntityType;
if (_insideInclude)
{
_includedJsonEntityTypes.Add(jsonEntityType);
if (!_includedJsonEntityTypes.Contains(jsonEntityType))
{
_includedJsonEntityTypes.Add(jsonEntityType);
}
}
else
{
_projectedKeyAccessInfos.Add((jsonEntityType, jsonCollectionProjectionInfo.KeyAccessInfo));
}

break;
}
else
{
_projectedKeyAccessInfos.Add((jsonEntityType, jsonCollectionProjectionInfo.KeyAccessInfo));
}
}

return extensionExpression;
default:
return extensionExpression;
}
}

return base.VisitExtension(extensionExpression);
Expand Down
Loading

0 comments on commit b99a0b1

Please sign in to comment.