Skip to content

Commit

Permalink
Ensure FK properties have nullable-appropriate value comparers
Browse files Browse the repository at this point in the history
Part of #11597

This change takes the ValueComparer defined for the principal key and uses it for the foreign key, but also accommodating for nulls appropriately. As part of this, we started getting some more complex expressions in value comparers used in the in-memory database. These expressions became part of the query, which then meant they needed to be translated. Therefore, this logic has been changed to call the value comparer as a method when using the in-memory database, and this method is then detected. This incidentally fixes #27495, which was also a case of a value comparer expression that could not be translated, and any other case where a value comparer could not be translated in in-memory queries.
  • Loading branch information
ajcvickers committed Mar 16, 2022
1 parent bbc4c7b commit 8e23bfb
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,73 @@ static Expression RemoveConvert(Expression e)
var property = FindProperty(newLeft) ?? FindProperty(newRight);
var comparer = property?.GetValueComparer();

if (comparer != null
&& comparer.Type.IsAssignableFrom(newLeft.Type)
&& comparer.Type.IsAssignableFrom(newRight.Type))
if (comparer != null)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
MethodInfo? objectEquals = null;
MethodInfo? exactMatch = null;

var converter = property?.GetValueConverter();
foreach (var candidate in comparer
.GetType()
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(
m => m.Name == "Equals" && m.GetParameters().Length == 2)
.ToList())
{
return comparer.ExtractEqualsBody(newLeft, newRight);
}
var parameters = candidate.GetParameters();
var leftType = parameters[0].ParameterType;
var rightType = parameters[1].ParameterType;

if (binaryExpression.NodeType == ExpressionType.NotEqual)
{
return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight));
if (leftType == typeof(object)
&& rightType == typeof(object))
{
objectEquals = candidate;
continue;
}

var matchingLeft = leftType.IsAssignableFrom(newLeft.Type)
? newLeft
: converter != null && leftType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newLeft,
converter.ConvertFromProviderExpression.Body)
: null;

var matchingRight = rightType.IsAssignableFrom(newRight.Type)
? newRight
: converter != null && rightType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newRight,
converter.ConvertFromProviderExpression.Body)
: null;

if (matchingLeft != null && matchingRight != null)
{
exactMatch = candidate;
newLeft = matchingLeft;
newRight = matchingRight;
break;
}
}

var equalsExpression =
exactMatch != null
? Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
exactMatch,
newLeft,
newRight)
: Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
objectEquals!,
Expression.Convert(newLeft, typeof(object)),
Expression.Convert(newRight, typeof(object)));

return binaryExpression.NodeType == ExpressionType.NotEqual
? Expression.IsFalse(equalsExpression)
: equalsExpression;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,11 @@ private static bool ProcessJoinCondition(
}

if (joinCondition is MethodCallExpression methodCallExpression
&& methodCallExpression.Method.IsStatic
&& methodCallExpression.Method.DeclaringType == typeof(object)
&& methodCallExpression.Method.Name == nameof(object.Equals)
&& methodCallExpression.Arguments.Count == 2)
&& methodCallExpression.Arguments.Count == 2
&& ((methodCallExpression.Method.IsStatic
&& methodCallExpression.Method.DeclaringType == typeof(object))
|| typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType)))
{
leftExpressions.Add(methodCallExpression.Arguments[0]);
rightExpressions.Add(methodCallExpression.Arguments[1]);
Expand Down
6 changes: 3 additions & 3 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ public static Expression MakeHasDefaultValue(
}

var property = propertyBase as IReadOnlyProperty;
var clrType = propertyBase?.ClrType ?? currentValueExpression.Type;
var comparer = property?.GetValueComparer()
?? ValueComparer.CreateDefault(clrType, favorStructuralComparisons: false);
?? ValueComparer.CreateDefault(
propertyBase?.ClrType ?? currentValueExpression.Type, favorStructuralComparisons: false);

return comparer.ExtractEqualsBody(
comparer.Type != clrType
comparer.Type != currentValueExpression.Type
? Expression.Convert(currentValueExpression, comparer.Type)
: currentValueExpression,
Expression.Default(comparer.Type));
Expand Down
62 changes: 58 additions & 4 deletions src/EFCore/Metadata/Internal/Property.cs
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ValueComparer? GetValueComparer()
=> GetValueComparer(null)
?? TypeMapping?.Comparer;
=> ToNullableComparer(GetValueComparer(null)
?? TypeMapping?.Comparer);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -821,8 +821,62 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ValueComparer? GetKeyValueComparer()
=> GetValueComparer(null)
?? TypeMapping?.KeyComparer;
=> ToNullableComparer(GetValueComparer(null)
?? TypeMapping?.KeyComparer);

private ValueComparer? ToNullableComparer(ValueComparer? valueComparer)
{
if (valueComparer == null
|| !ClrType.IsNullableValueType()
|| valueComparer.Type.IsNullableValueType())
{
return valueComparer;
}

var newEqualsParam1 = Expression.Parameter(ClrType, "v1");
var newEqualsParam2 = Expression.Parameter(ClrType, "v2");
var newHashCodeParam = Expression.Parameter(ClrType, "v");
var newSnapshotParam = Expression.Parameter(ClrType, "v");
var hasValueMethod = ClrType.GetMethod("get_HasValue")!;
var v1HasValue = Expression.Parameter(typeof(bool), "v1HasValue");
var v2HasValue = Expression.Parameter(typeof(bool), "v2HasValue");

return (ValueComparer)Activator.CreateInstance(
typeof(ValueComparer<>).MakeGenericType(ClrType),
Expression.Lambda(
Expression.Block(
typeof(bool),
new[] { v1HasValue, v2HasValue },
Expression.Assign(v1HasValue, Expression.Call(newEqualsParam1, hasValueMethod)),
Expression.Assign(v2HasValue, Expression.Call(newEqualsParam2, hasValueMethod)),
Expression.OrElse(
Expression.AndAlso(
v1HasValue,
Expression.AndAlso(
v2HasValue,
valueComparer.ExtractEqualsBody(
Expression.Convert(newEqualsParam1, valueComparer.Type),
Expression.Convert(newEqualsParam2, valueComparer.Type)))),
Expression.AndAlso(
Expression.Not(v1HasValue),
Expression.Not(v2HasValue)))),
newEqualsParam1, newEqualsParam2),
Expression.Lambda(
Expression.Condition(
Expression.Call(newHashCodeParam, hasValueMethod),
valueComparer.ExtractHashCodeBody(
Expression.Convert(newHashCodeParam, valueComparer.Type)),
Expression.Constant(0, typeof(int))),
newHashCodeParam),
Expression.Lambda(
Expression.Condition(
Expression.Call(newSnapshotParam, hasValueMethod),
Expression.Convert(
valueComparer.ExtractSnapshotBody(
Expression.Convert(newSnapshotParam, valueComparer.Type)), ClrType),
Expression.Default(ClrType)),
newSnapshotParam))!;
}

private ValueComparer? GetValueComparer(HashSet<IProperty>? checkedProperties)
{
Expand Down
6 changes: 2 additions & 4 deletions test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1686,8 +1686,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
TestNullableDouble = -1.23456789,
TestNullableDecimal = -1234567890.01M,
TestNullableDateTime = DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(),
TestNullableDateTimeOffset =
new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
TestNullableDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestNullableSingle = -1.234F,
TestNullableBoolean = false,
Expand Down Expand Up @@ -1723,8 +1722,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
AssertEqualIfMapped(entityType, -1.23456789, () => dt.TestNullableDouble);
AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.TestNullableDecimal);
AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), () => dt.TestNullableDateTime);
AssertEqualIfMapped(
entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
() => dt.TestNullableDateTimeOffset);
AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TestNullableTimeSpan);
AssertEqualIfMapped(entityType, -1.234F, () => dt.TestNullableSingle);
Expand Down
8 changes: 7 additions & 1 deletion test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ private Email(string value)
_value = value;
}

public override bool Equals(object obj)
=> _value == ((Email)obj)?._value;

public override int GetHashCode()
=> _value.GetHashCode();

public static Email Create(string value)
=> new(value);

Expand Down Expand Up @@ -1069,7 +1075,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDateTimeOffset)).HasConversion(
new ValueConverter<DateTimeOffset?, long>(
v => v.Value.ToUnixTimeMilliseconds(),
v => DateTimeOffset.FromUnixTimeMilliseconds(v)));
v => DateTimeOffset.FromUnixTimeMilliseconds(v).ToOffset(TimeSpan.FromHours(-8.0))));
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDouble)).HasConversion(
new ValueConverter<double?, decimal?>(
Expand Down
7 changes: 5 additions & 2 deletions test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ public int NullableAsNonNullable
public int? NonNullableAsNullable
{
get => _nonNullableAsNullable;
set => _nonNullableAsNullable = (int)value;
set => _nonNullableAsNullable = value ?? 0;
}
}

Expand Down Expand Up @@ -1930,7 +1930,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.Id).HasField("_id");
b.Property(e => e.NullableAsNonNullable).HasField("_nullableAsNonNullable").ValueGeneratedOnAddOrUpdate();
b.Property(e => e.NonNullableAsNullable).HasField("_nonNullableAsNullable").ValueGeneratedOnAddOrUpdate();
b.Property(e => e.NonNullableAsNullable)
.HasField("_nonNullableAsNullable")
.ValueGeneratedOnAddOrUpdate()
.UsePropertyAccessMode(PropertyAccessMode.Property);
});

modelBuilder.Entity<OptionalProduct>();
Expand Down
2 changes: 1 addition & 1 deletion test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ public virtual void Value_converter_configured_on_non_nullable_type_is_applied()

var wierd = entityType.FindProperty("Wierd");
Assert.IsType<NumberToStringConverter<int>>(wierd.GetValueConverter());
Assert.IsType<CustomValueComparer<int>>(wierd.GetValueComparer());
Assert.IsType<ValueComparer<int?>>(wierd.GetValueComparer());
}

[ConditionalFact]
Expand Down

0 comments on commit 8e23bfb

Please sign in to comment.