diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
index 145c4687168..25957b2b815 100644
--- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
+++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
@@ -240,19 +240,73 @@ static Expression RemoveConvert(Expression e)
var property = FindProperty(newLeft) ?? FindProperty(newRight);
var comparer = property?.GetValueComparer();
- if (comparer != null
- && comparer.Type.IsAssignableFrom(newLeft.Type)
- && comparer.Type.IsAssignableFrom(newRight.Type))
+ if (comparer != null)
{
- if (binaryExpression.NodeType == ExpressionType.Equal)
+ MethodInfo? objectEquals = null;
+ MethodInfo? exactMatch = null;
+
+ var converter = property?.GetValueConverter();
+ foreach (var candidate in comparer
+ .GetType()
+ .GetMethods(BindingFlags.Public | BindingFlags.Instance)
+ .Where(
+ m => m.Name == "Equals" && m.GetParameters().Length == 2)
+ .ToList())
{
- return comparer.ExtractEqualsBody(newLeft, newRight);
- }
+ var parameters = candidate.GetParameters();
+ var leftType = parameters[0].ParameterType;
+ var rightType = parameters[1].ParameterType;
- if (binaryExpression.NodeType == ExpressionType.NotEqual)
- {
- return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight));
+ if (leftType == typeof(object)
+ && rightType == typeof(object))
+ {
+ objectEquals = candidate;
+ continue;
+ }
+
+ var matchingLeft = leftType.IsAssignableFrom(newLeft.Type)
+ ? newLeft
+ : converter != null && leftType.IsAssignableFrom(converter.ModelClrType)
+ ? ReplacingExpressionVisitor.Replace(
+ converter.ConvertFromProviderExpression.Parameters.Single(),
+ newLeft,
+ converter.ConvertFromProviderExpression.Body)
+ : null;
+
+ var matchingRight = rightType.IsAssignableFrom(newRight.Type)
+ ? newRight
+ : converter != null && rightType.IsAssignableFrom(converter.ModelClrType)
+ ? ReplacingExpressionVisitor.Replace(
+ converter.ConvertFromProviderExpression.Parameters.Single(),
+ newRight,
+ converter.ConvertFromProviderExpression.Body)
+ : null;
+
+ if (matchingLeft != null && matchingRight != null)
+ {
+ exactMatch = candidate;
+ newLeft = matchingLeft;
+ newRight = matchingRight;
+ break;
+ }
}
+
+ var equalsExpression =
+ exactMatch != null
+ ? Expression.Call(
+ Expression.Constant(comparer, comparer.GetType()),
+ exactMatch,
+ newLeft,
+ newRight)
+ : Expression.Call(
+ Expression.Constant(comparer, comparer.GetType()),
+ objectEquals!,
+ Expression.Convert(newLeft, typeof(object)),
+ Expression.Convert(newRight, typeof(object)));
+
+ return binaryExpression.NodeType == ExpressionType.NotEqual
+ ? Expression.IsFalse(equalsExpression)
+ : equalsExpression;
}
}
diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs
index 00f06a3641f..ebfe4334114 100644
--- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs
+++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs
@@ -594,10 +594,11 @@ private static bool ProcessJoinCondition(
}
if (joinCondition is MethodCallExpression methodCallExpression
- && methodCallExpression.Method.IsStatic
- && methodCallExpression.Method.DeclaringType == typeof(object)
&& methodCallExpression.Method.Name == nameof(object.Equals)
- && methodCallExpression.Arguments.Count == 2)
+ && methodCallExpression.Arguments.Count == 2
+ && ((methodCallExpression.Method.IsStatic
+ && methodCallExpression.Method.DeclaringType == typeof(object))
+ || typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType)))
{
leftExpressions.Add(methodCallExpression.Arguments[0]);
rightExpressions.Add(methodCallExpression.Arguments[1]);
diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs
index fe8b95eb478..173f010df74 100644
--- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs
+++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs
@@ -42,12 +42,12 @@ public static Expression MakeHasDefaultValue(
}
var property = propertyBase as IReadOnlyProperty;
- var clrType = propertyBase?.ClrType ?? currentValueExpression.Type;
var comparer = property?.GetValueComparer()
- ?? ValueComparer.CreateDefault(clrType, favorStructuralComparisons: false);
+ ?? ValueComparer.CreateDefault(
+ propertyBase?.ClrType ?? currentValueExpression.Type, favorStructuralComparisons: false);
return comparer.ExtractEqualsBody(
- comparer.Type != clrType
+ comparer.Type != currentValueExpression.Type
? Expression.Convert(currentValueExpression, comparer.Type)
: currentValueExpression,
Expression.Default(comparer.Type));
diff --git a/src/EFCore/Metadata/Internal/Property.cs b/src/EFCore/Metadata/Internal/Property.cs
index ba6e9478699..28396b2d384 100644
--- a/src/EFCore/Metadata/Internal/Property.cs
+++ b/src/EFCore/Metadata/Internal/Property.cs
@@ -811,8 +811,8 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
public virtual ValueComparer? GetValueComparer()
- => GetValueComparer(null)
- ?? TypeMapping?.Comparer;
+ => ToNullableComparer(GetValueComparer(null)
+ ?? TypeMapping?.Comparer);
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -821,8 +821,62 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
///
public virtual ValueComparer? GetKeyValueComparer()
- => GetValueComparer(null)
- ?? TypeMapping?.KeyComparer;
+ => ToNullableComparer(GetValueComparer(null)
+ ?? TypeMapping?.KeyComparer);
+
+ private ValueComparer? ToNullableComparer(ValueComparer? valueComparer)
+ {
+ if (valueComparer == null
+ || !ClrType.IsNullableValueType()
+ || valueComparer.Type.IsNullableValueType())
+ {
+ return valueComparer;
+ }
+
+ var newEqualsParam1 = Expression.Parameter(ClrType, "v1");
+ var newEqualsParam2 = Expression.Parameter(ClrType, "v2");
+ var newHashCodeParam = Expression.Parameter(ClrType, "v");
+ var newSnapshotParam = Expression.Parameter(ClrType, "v");
+ var hasValueMethod = ClrType.GetMethod("get_HasValue")!;
+ var v1HasValue = Expression.Parameter(typeof(bool), "v1HasValue");
+ var v2HasValue = Expression.Parameter(typeof(bool), "v2HasValue");
+
+ return (ValueComparer)Activator.CreateInstance(
+ typeof(ValueComparer<>).MakeGenericType(ClrType),
+ Expression.Lambda(
+ Expression.Block(
+ typeof(bool),
+ new[] { v1HasValue, v2HasValue },
+ Expression.Assign(v1HasValue, Expression.Call(newEqualsParam1, hasValueMethod)),
+ Expression.Assign(v2HasValue, Expression.Call(newEqualsParam2, hasValueMethod)),
+ Expression.OrElse(
+ Expression.AndAlso(
+ v1HasValue,
+ Expression.AndAlso(
+ v2HasValue,
+ valueComparer.ExtractEqualsBody(
+ Expression.Convert(newEqualsParam1, valueComparer.Type),
+ Expression.Convert(newEqualsParam2, valueComparer.Type)))),
+ Expression.AndAlso(
+ Expression.Not(v1HasValue),
+ Expression.Not(v2HasValue)))),
+ newEqualsParam1, newEqualsParam2),
+ Expression.Lambda(
+ Expression.Condition(
+ Expression.Call(newHashCodeParam, hasValueMethod),
+ valueComparer.ExtractHashCodeBody(
+ Expression.Convert(newHashCodeParam, valueComparer.Type)),
+ Expression.Constant(0, typeof(int))),
+ newHashCodeParam),
+ Expression.Lambda(
+ Expression.Condition(
+ Expression.Call(newSnapshotParam, hasValueMethod),
+ Expression.Convert(
+ valueComparer.ExtractSnapshotBody(
+ Expression.Convert(newSnapshotParam, valueComparer.Type)), ClrType),
+ Expression.Default(ClrType)),
+ newSnapshotParam))!;
+ }
private ValueComparer? GetValueComparer(HashSet? checkedProperties)
{
diff --git a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
index 46dd9c8011d..e901b0b2b77 100644
--- a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
+++ b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
@@ -1686,8 +1686,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
TestNullableDouble = -1.23456789,
TestNullableDecimal = -1234567890.01M,
TestNullableDateTime = DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(),
- TestNullableDateTimeOffset =
- new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
+ TestNullableDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestNullableSingle = -1.234F,
TestNullableBoolean = false,
@@ -1723,8 +1722,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
AssertEqualIfMapped(entityType, -1.23456789, () => dt.TestNullableDouble);
AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.TestNullableDecimal);
AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), () => dt.TestNullableDateTime);
- AssertEqualIfMapped(
- entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
+ AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
() => dt.TestNullableDateTimeOffset);
AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TestNullableTimeSpan);
AssertEqualIfMapped(entityType, -1.234F, () => dt.TestNullableSingle);
diff --git a/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs b/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
index 853809a2ba0..870a0086f1e 100644
--- a/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
+++ b/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
@@ -204,6 +204,12 @@ private Email(string value)
_value = value;
}
+ public override bool Equals(object obj)
+ => _value == ((Email)obj)?._value;
+
+ public override int GetHashCode()
+ => _value.GetHashCode();
+
public static Email Create(string value)
=> new(value);
@@ -1069,7 +1075,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDateTimeOffset)).HasConversion(
new ValueConverter(
v => v.Value.ToUnixTimeMilliseconds(),
- v => DateTimeOffset.FromUnixTimeMilliseconds(v)));
+ v => DateTimeOffset.FromUnixTimeMilliseconds(v).ToOffset(TimeSpan.FromHours(-8.0))));
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDouble)).HasConversion(
new ValueConverter(
diff --git a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs
index e3997b74ed8..fa76f62917f 100644
--- a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs
+++ b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs
@@ -1569,7 +1569,7 @@ public int NullableAsNonNullable
public int? NonNullableAsNullable
{
get => _nonNullableAsNullable;
- set => _nonNullableAsNullable = (int)value;
+ set => _nonNullableAsNullable = value ?? 0;
}
}
@@ -1930,7 +1930,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.Id).HasField("_id");
b.Property(e => e.NullableAsNonNullable).HasField("_nullableAsNonNullable").ValueGeneratedOnAddOrUpdate();
- b.Property(e => e.NonNullableAsNullable).HasField("_nonNullableAsNullable").ValueGeneratedOnAddOrUpdate();
+ b.Property(e => e.NonNullableAsNullable)
+ .HasField("_nonNullableAsNullable")
+ .ValueGeneratedOnAddOrUpdate()
+ .UsePropertyAccessMode(PropertyAccessMode.Property);
});
modelBuilder.Entity();
diff --git a/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs b/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs
index d69d92909eb..a7603a2a5ce 100644
--- a/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs
+++ b/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs
@@ -991,7 +991,7 @@ public virtual void Value_converter_configured_on_non_nullable_type_is_applied()
var wierd = entityType.FindProperty("Wierd");
Assert.IsType>(wierd.GetValueConverter());
- Assert.IsType>(wierd.GetValueComparer());
+ Assert.IsType>(wierd.GetValueComparer());
}
[ConditionalFact]