Skip to content

Commit

Permalink
Adding translations for cosmos string methods
Browse files Browse the repository at this point in the history
INDEX_OF, REPLACE, case insensitive STRINGEQUALS
  • Loading branch information
maumar committed Aug 12, 2021
1 parent 913e649 commit 6162369
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
Expand Down
60 changes: 53 additions & 7 deletions src/EFCore.Cosmos/Query/Internal/StringMethodTranslator.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Utilities;
Expand All @@ -18,6 +15,15 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
/// </summary>
public class StringMethodTranslator : IMethodCallTranslator
{
private static readonly MethodInfo _indexOfMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string) });

private static readonly MethodInfo _indexOfMethodInfoWithStartingPosition
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string), typeof(int) });

private static readonly MethodInfo _replaceMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Replace), new[] { typeof(string), typeof(string) });

private static readonly MethodInfo _containsMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Contains), new[] { typeof(string) });

Expand Down Expand Up @@ -68,17 +74,23 @@ private static readonly MethodInfo _lastOrDefaultMethodInfoWithoutArgs
&& m.GetParameters().Length == 1).MakeGenericMethod(typeof(char));

private static readonly MethodInfo _stringConcatWithTwoArguments =
typeof(String).GetRequiredRuntimeMethod(nameof(string.Concat),
typeof(string).GetRequiredRuntimeMethod(nameof(string.Concat),
new[] { typeof(string), typeof(string) });

private static readonly MethodInfo _stringConcatWithThreeArguments =
typeof(String).GetRequiredRuntimeMethod(nameof(string.Concat),
typeof(string).GetRequiredRuntimeMethod(nameof(string.Concat),
new[] { typeof(string), typeof(string), typeof(string) });

private static readonly MethodInfo _stringConcatWithFourArguments =
typeof(String).GetRequiredRuntimeMethod(nameof(string.Concat),
typeof(string).GetRequiredRuntimeMethod(nameof(string.Concat),
new[] { typeof(string), typeof(string), typeof(string), typeof(string) });

private static readonly MethodInfo _stringComparisonWithComparisonTypeArgumentInstance
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Equals), typeof(string), typeof(StringComparison));

private static readonly MethodInfo _stringComparisonWithComparisonTypeArgumentStatic
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Equals), typeof(string), typeof(string), typeof(StringComparison));

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
Expand Down Expand Up @@ -110,6 +122,21 @@ public StringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)

if (instance != null)
{
if (_indexOfMethodInfo.Equals(method))
{
return TranslateSystemFunction("INDEX_OF", typeof(int), instance, arguments[0]);
}

if (_indexOfMethodInfoWithStartingPosition.Equals(method))
{
return TranslateSystemFunction("INDEX_OF", typeof(int), instance, arguments[0], arguments[1]);
}

if (_replaceMethodInfo.Equals(method))
{
return TranslateSystemFunction("REPLACE", method.ReturnType, instance, arguments[0], arguments[1]);
}

if (_containsMethodInfo.Equals(method))
{
return TranslateSystemFunction("CONTAINS", typeof(bool), instance, arguments[0]);
Expand Down Expand Up @@ -171,7 +198,11 @@ public StringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)

if (_substringMethodInfoWithTwoArgs.Equals(method))
{
return TranslateSystemFunction("SUBSTRING", method.ReturnType, instance, arguments[0], arguments[1]);
return arguments[0] is SqlConstantExpression constant
&& constant.Value is int intValue
&& intValue == 0
? TranslateSystemFunction("LEFT", method.ReturnType, instance, arguments[1])
: TranslateSystemFunction("SUBSTRING", method.ReturnType, instance, arguments[0], arguments[1]);
}
}

Expand Down Expand Up @@ -212,6 +243,21 @@ public StringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
arguments[3])));
}

if (_stringComparisonWithComparisonTypeArgumentInstance.Equals(method)
|| _stringComparisonWithComparisonTypeArgumentStatic.Equals(method))
{
var comparisonTypeArgument = arguments[^1];
if (comparisonTypeArgument is SqlConstantExpression constantComparisonTypeArgument
&& constantComparisonTypeArgument.Value is StringComparison comparisonTypeArgumentValue
&& comparisonTypeArgumentValue == StringComparison.OrdinalIgnoreCase)
{

return _stringComparisonWithComparisonTypeArgumentInstance.Equals(method)
? TranslateSystemFunction("STRINGEQUALS", typeof(bool), instance!, arguments[0], _sqlExpressionFactory.Constant(true))
: TranslateSystemFunction("STRINGEQUALS", typeof(bool), arguments[0], arguments[1], _sqlExpressionFactory.Constant(true));
}
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.EntityFrameworkCore.Diagnostics;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
using Xunit.Abstractions;
Expand Down Expand Up @@ -362,7 +363,7 @@ FROM root c
WHERE (((c[""Discriminator""] = ""OrderDetail"") AND (c[""Quantity""] < 5)) AND (FLOOR(c[""UnitPrice""]) > 10.0))");
}

[ConditionalTheory(Skip = "Issue #17246")]
[ConditionalTheory(Skip = "Issue #25120")]
public override async Task Where_math_power(bool async)
{
await base.Where_math_power(async);
Expand All @@ -373,6 +374,17 @@ FROM root c
WHERE (c[""Discriminator""] = ""OrderDetail"")");
}

[ConditionalTheory(Skip = "Issue #25120")]
public override async Task Where_math_square(bool async)
{
await base.Where_math_square(async);

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] = ""OrderDetail"")");
}

public override async Task Where_math_round(bool async)
{
await base.Where_math_round(async);
Expand Down Expand Up @@ -621,15 +633,26 @@ FROM [Order Details] AS [o]
WHERE ([o].[Quantity] < CAST(5 AS smallint)) AND (FLOOR(CAST([o].[UnitPrice] AS real)) > CAST(10 AS real))");
}

[ConditionalTheory(Skip = "Issue #17246")]
[ConditionalTheory]
public override async Task Where_mathf_power(bool async)
{
await base.Where_mathf_power(async);

AssertSql(
@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE POWER([o].[Discount], CAST(2 AS real)) > CAST(0.05 AS real)");
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""OrderDetail"") AND (POWER(c[""Discount""], 3.0) > 0.005))");
}

[ConditionalTheory]
public override async Task Where_mathf_square(bool async)
{
await base.Where_mathf_square(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""OrderDetail"") AND (POWER(c[""Discount""], 2.0) > 0.05))");
}

[ConditionalTheory(Skip = "Issue #17246")]
Expand Down Expand Up @@ -932,7 +955,17 @@ public override async Task Indexof_with_emptystring(bool async)
await base.Indexof_with_emptystring(async);

AssertSql(
@"SELECT c[""ContactName""]
@"SELECT INDEX_OF(c[""ContactName""], """") AS c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))");
}

public override async Task Indexof_with_starting_position(bool async)
{
await base.Indexof_with_starting_position(async);

AssertSql(
@"SELECT INDEX_OF(c[""ContactName""], ""a"", 3) AS c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))");
}
Expand All @@ -942,7 +975,17 @@ public override async Task Replace_with_emptystring(bool async)
await base.Replace_with_emptystring(async);

AssertSql(
@"SELECT c[""ContactName""]
@"SELECT REPLACE(c[""ContactName""], ""ari"", """") AS c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))");
}

public override async Task Replace_using_property_arguments(bool async)
{
await base.Replace_using_property_arguments(async);

AssertSql(
@"SELECT REPLACE(c[""ContactName""], c[""ContactName""], c[""CustomerID""]) AS c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))");
}
Expand Down Expand Up @@ -984,7 +1027,7 @@ public override async Task Substring_with_two_args_with_zero_startindex(bool asy
await base.Substring_with_two_args_with_zero_startindex(async);

AssertSql(
@"SELECT SUBSTRING(c[""ContactName""], 0, 3) AS c
@"SELECT LEFT(c[""ContactName""], 3) AS c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))");
}
Expand Down Expand Up @@ -1026,7 +1069,7 @@ public override async Task Substring_with_two_args_with_Index_of(bool async)
await base.Substring_with_two_args_with_Index_of(async);

AssertSql(
@"SELECT c[""ContactName""]
@"SELECT SUBSTRING(c[""ContactName""], INDEX_OF(c[""ContactName""], ""a""), 3) AS c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ALFKI""))");
}
Expand Down Expand Up @@ -1290,6 +1333,36 @@ public override Task Regex_IsMatch_MethodCall_constant_input(bool async)
return AssertTranslationFailed(() => base.Regex_IsMatch_MethodCall_constant_input(async));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Case_insensitive_string_comparison_instance(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.Equals("alFkI", StringComparison.OrdinalIgnoreCase)),
entryCount: 1);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND STRINGEQUALS(c[""CustomerID""], ""alFkI"", true))");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Case_insensitive_string_comparison_static(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => string.Equals(c.CustomerID, "alFkI", StringComparison.OrdinalIgnoreCase)),
entryCount: 1);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND STRINGEQUALS(c[""CustomerID""], ""alFkI"", true))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2173,9 +2173,14 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

public override Task Where_equals_method_string_with_ignore_case(bool async)
public override async Task Where_equals_method_string_with_ignore_case(bool async)
{
return AssertTranslationFailed(() => base.Where_equals_method_string_with_ignore_case(async));
await base.Where_equals_method_string_with_ignore_case(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND STRINGEQUALS(c[""City""], ""London"", true))");
}

public override async Task Filter_with_EF_Property_using_closure_for_property_name(bool async)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
Expand Down Expand Up @@ -829,6 +825,16 @@ public virtual Task Where_math_floor(bool async)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_math_power(bool async)
{
return AssertQuery(
async,
ss => ss.Set<OrderDetail>().Where(od => Math.Pow(od.Discount, 3) > 0.005f),
entryCount: 315);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_math_square(bool async)
{
return AssertQuery(
async,
Expand Down Expand Up @@ -1083,13 +1089,23 @@ public virtual Task Where_mathf_floor(bool async)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_mathf_power(bool async)
{
return AssertQuery(
async,
ss => ss.Set<OrderDetail>().Where(od => MathF.Pow(od.Discount, 3) > 0.005f),
entryCount: 315);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_mathf_square(bool async)
{
return AssertQuery(
async,
ss => ss.Set<OrderDetail>().Where(od => MathF.Pow(od.Discount, 2) > 0.05f),
entryCount: 154);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_mathf_round2(bool async)
Expand Down Expand Up @@ -1508,6 +1524,15 @@ public virtual Task Indexof_with_emptystring(bool async)
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf(string.Empty)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_starting_position(bool async)
{
return AssertQueryScalar(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf("a", 3)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Replace_with_emptystring(bool async)
Expand All @@ -1517,6 +1542,15 @@ public virtual Task Replace_with_emptystring(bool async)
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.Replace("ari", string.Empty)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Replace_using_property_arguments(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.Replace(c.ContactName, c.CustomerID)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Substring_with_one_arg_with_zero_startindex(bool async)
Expand Down
Loading

0 comments on commit 6162369

Please sign in to comment.