Skip to content

Commit

Permalink
InMemory: Some refactorings
Browse files Browse the repository at this point in the history
Part of #16963
  • Loading branch information
smitpatel committed Sep 3, 2019
1 parent 1fb1373 commit 5d60090
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return null;
}

if (TypeNullabilityChanged(newLeft.Type, binaryExpression.Left.Type)
|| TypeNullabilityChanged(newRight.Type, binaryExpression.Right.Type))
if (IsConvertedToNullable(newLeft, binaryExpression.Left)
|| IsConvertedToNullable(newRight, binaryExpression.Right))
{
newLeft = MakeNullable(newLeft);
newRight = MakeNullable(newRight);
newLeft = ConvertToNullable(newLeft);
newRight = ConvertToNullable(newRight);
}

return Expression.MakeBinary(
Expand All @@ -94,11 +94,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
binaryExpression.Conversion);
}

private static Expression MakeNullable(Expression expression)
=> !expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.MakeNullable())
: expression;

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
var test = Visit(conditionalExpression.Test);
Expand All @@ -115,11 +110,11 @@ protected override Expression VisitConditional(ConditionalExpression conditional
test = Expression.Equal(test, Expression.Constant(true, typeof(bool?)));
}

if (TypeNullabilityChanged(ifTrue.Type, conditionalExpression.IfTrue.Type)
|| TypeNullabilityChanged(ifFalse.Type, conditionalExpression.IfFalse.Type))
if (IsConvertedToNullable(ifTrue, conditionalExpression.IfTrue)
|| IsConvertedToNullable(ifFalse, conditionalExpression.IfFalse))
{
ifTrue = MakeNullable(ifTrue);
ifFalse = MakeNullable(ifFalse);
ifTrue = ConvertToNullable(ifTrue);
ifFalse = ConvertToNullable(ifFalse);
}

return Expression.Condition(test, ifTrue, ifFalse);
Expand All @@ -142,12 +137,17 @@ protected override Expression VisitMember(MemberExpression memberExpression)
return result;
}

static bool shouldApplyNullProtectionForMemberAccess(Type callerType, string memberName)
=> !(callerType.IsGenericType
&& callerType.GetGenericTypeDefinition() == typeof(Nullable<>)
&& (memberName == nameof(Nullable<int>.Value) || memberName == nameof(Nullable<int>.HasValue)));

var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression);
if (innerExpression != null
&& innerExpression.Type.IsNullableType()
&& ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name))
&& shouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name))
{
updatedMemberExpression = MakeNullable(updatedMemberExpression);
updatedMemberExpression = ConvertToNullable(updatedMemberExpression);

return Expression.Condition(
Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)),
Expand All @@ -158,11 +158,6 @@ protected override Expression VisitMember(MemberExpression memberExpression)
return updatedMemberExpression;
}

private bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName)
=> !(callerType.IsGenericType
&& callerType.GetGenericTypeDefinition() == typeof(Nullable<>)
&& (memberName == nameof(Nullable<int>.Value) || memberName == nameof(Nullable<int>.HasValue)));

private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result)
{
result = null;
Expand Down Expand Up @@ -204,7 +199,10 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
result = BindProperty(entityProjection, property);

// if the result type change was just nullability change e.g from int to int? we want to preserve the new type for null propagation
if (result.Type != type && !TypeNullabilityChanged(result.Type, type))
if (result.Type != type
&& !(result.Type.IsNullableType()
&& !type.IsNullableType()
&& result.Type.UnwrapNullableType() == type))
{
result = Expression.Convert(result, type);
}
Expand All @@ -215,13 +213,23 @@ private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Typ
return false;
}

private bool TypeNullabilityChanged(Type maybeNullableType, Type nonNullableType)
=> maybeNullableType.IsNullableType() && !nonNullableType.IsNullableType() && maybeNullableType.UnwrapNullableType() == nonNullableType;
private static bool IsConvertedToNullable(Expression result, Expression original)
=> result.Type.IsNullableType()
&& !original.Type.IsNullableType()
&& result.Type.UnwrapNullableType() == original.Type;

private static Expression ConvertToNullable(Expression expression)
=> !expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.MakeNullable())
: expression;

private static Expression ConvertToNonNullable(Expression expression)
=> expression.Type.IsNullableType()
? Expression.Convert(expression, expression.Type.UnwrapNullableType())
: expression;

private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property)
{
return entityProjectionExpression.BindProperty(property);
}
=> entityProjectionExpression.BindProperty(property);

private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
Expand Down Expand Up @@ -387,7 +395,7 @@ MethodInfo getMethod()
}

var arguments = new Expression[methodCallExpression.Arguments.Count];
var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray();
var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray();
for (var i = 0; i < arguments.Length; i++)
{
var argument = Visit(methodCallExpression.Arguments[i]);
Expand All @@ -398,10 +406,10 @@ MethodInfo getMethod()

// if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type,
// so we are forced to cast back to the original type
if (argument.Type != methodCallExpression.Arguments[i].Type
if (IsConvertedToNullable(argument, methodCallExpression.Arguments[i])
&& !parameterTypes[i].IsAssignableFrom(argument.Type))
{
argument = Expression.Convert(argument, methodCallExpression.Arguments[i].Type);
argument = ConvertToNonNullable(argument);
}

arguments[i] = argument;
Expand All @@ -413,11 +421,11 @@ MethodInfo getMethod()
&& @object.Type.IsNullableType()
&& !(methodCallExpression.Method.Name == nameof(Nullable<int>.GetValueOrDefault)))
{
var result = (Expression)methodCallExpression.Update(
var result = (Expression)methodCallExpression.Update(
Expression.Convert(@object, methodCallExpression.Object.Type),
arguments);

result = MakeNullable(result);
result = ConvertToNullable(result);
result = Expression.Condition(
Expression.Equal(@object, Expression.Constant(null, @object.Type)),
Expression.Constant(null, result.Type),
Expand Down Expand Up @@ -475,9 +483,9 @@ protected override Expression VisitNew(NewExpression newExpression)
foreach (var argument in newExpression.Arguments)
{
var newArgument = Visit(argument);
if (newArgument.Type != argument.Type)
if (IsConvertedToNullable(newArgument, argument))
{
newArgument = Expression.Convert(newArgument, argument.Type);
newArgument = ConvertToNonNullable(newArgument);
}

newArguments.Add(newArgument);
Expand All @@ -492,9 +500,9 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio
foreach (var expression in newArrayExpression.Expressions)
{
var newExpression = Visit(expression);
if (newExpression.Type != expression.Type)
if (IsConvertedToNullable(newExpression, expression))
{
newExpression = Expression.Convert(newExpression, expression.Type);
newExpression = ConvertToNonNullable(newExpression);
}

newExpressions.Add(newExpression);
Expand All @@ -503,32 +511,15 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio
return newArrayExpression.Update(newExpressions);
}

protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression)
protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment)
{
var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression);
var bindings = new List<MemberBinding>();
foreach (var binding in memberInitExpression.Bindings)
var expression = Visit(memberAssignment.Expression);
if (IsConvertedToNullable(expression, memberAssignment.Expression))
{
switch (binding)
{
case MemberAssignment memberAssignment:
var expression = Visit(memberAssignment.Expression);
if (expression.Type != memberAssignment.Expression.Type)
{
expression = Expression.Convert(expression, memberAssignment.Expression.Type);
}

bindings.Add(Expression.Bind(memberAssignment.Member, expression));
break;

default:
// TODO: MemberMemberBinding and MemberListBinding
bindings.Add(binding);
break;
}
expression = ConvertToNonNullable(expression);
}

return memberInitExpression.Update(newExpression, bindings);
return memberAssignment.Update(expression);
}

protected override Expression VisitExtension(Expression extensionExpression)
Expand Down
15 changes: 13 additions & 2 deletions src/EFCore.InMemory/Query/Internal/InMemoryLinqOperatorProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,20 @@ public static MethodInfo GetMinWithSelector(Type type)
private static Dictionary<Type, MethodInfo> SumWithoutSelectorMethods { get; }
private static Dictionary<Type, MethodInfo> SumWithSelectorMethods { get; }

private static bool IsFunc(Type type, int funcGenericArgs = 2)
private static Type GetFuncType(int funcGenericArguments)
{
return funcGenericArguments switch
{
1 => typeof(Func<>),
2 => typeof(Func<,>),
3 => typeof(Func<,,>),
4 => typeof(Func<,,,>),
_ => throw new InvalidOperationException("Invalid number of arguments for Func"),
};
}
private static bool IsFunc(Type type, int funcGenericArguments = 2)
=> type.IsGenericType
&& type.GetGenericArguments().Length == funcGenericArgs;
&& type.GetGenericTypeDefinition() == GetFuncType(funcGenericArguments);

static InMemoryLinqOperatorProvider()
{
Expand Down
6 changes: 3 additions & 3 deletions src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public InMemoryQueryExpression(IEntityType entityType)
Constant(property.GetIndex()),
MakeMemberAccess(_valueBufferParameter,
_valueBufferCountMemberInfo)),
CreateReadValueExpression(typeof(object), property.GetIndex(), property),
Default(typeof(object)));
CreateReadValueExpression(property.ClrType, property.GetIndex(), property),
Default(property.ClrType));
}

var entityProjection = new EntityProjectionExpression(entityType, readExpressionMap);
Expand Down Expand Up @@ -256,7 +256,7 @@ public virtual void ApplyDefaultIfEmpty()
{
if (_valueBufferSlots.Count != 0)
{
throw new InvalidOperationException("Cannot Apply DefaultIfEmpty after ClientProjection.");
throw new InvalidOperationException("Cannot apply DefaultIfEmpty after a client-evaluated projection.");
}

var result = new Dictionary<ProjectionMember, Expression>();
Expand Down
Loading

0 comments on commit 5d60090

Please sign in to comment.