Skip to content

Commit

Permalink
Query: SqlServer: Add translation for string.IndexOf(string, int)
Browse files Browse the repository at this point in the history
Resolves #25396
  • Loading branch information
yosoyhabacuc authored and smitpatel committed Nov 11, 2021
1 parent 3f9a137 commit 813dafd
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CosmosMethodCallTranslatorProvider(
new IMethodCallTranslator[]
{
new EqualsTranslator(sqlExpressionFactory),
new StringMethodTranslator(sqlExpressionFactory),
new CosmosStringMethodTranslator(sqlExpressionFactory),
new ContainsTranslator(sqlExpressionFactory),
new RandomTranslator(sqlExpressionFactory),
new MathTranslator(sqlExpressionFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class StringMethodTranslator : IMethodCallTranslator
public class CosmosStringMethodTranslator : IMethodCallTranslator
{
private static readonly MethodInfo _indexOfMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), typeof(string));
Expand Down Expand Up @@ -98,7 +98,7 @@ private static readonly MethodInfo _stringComparisonWithComparisonTypeArgumentSt
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public StringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
public CosmosStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand Down
104 changes: 65 additions & 39 deletions src/EFCore.SqlServer/Query/Internal/SqlServerStringMethodTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ public class SqlServerStringMethodTranslator : IMethodCallTranslator
private static readonly MethodInfo _indexOfMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), typeof(string));

private static readonly MethodInfo _indexOfMethodInfoWithStartingPosition
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string), typeof(int) });

private static readonly MethodInfo _indexOfMethodInfoWithStartingPosition
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string), typeof(int) });

private static readonly MethodInfo _replaceMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Replace), typeof(string), typeof(string));

Expand Down Expand Up @@ -115,46 +121,12 @@ public SqlServerStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
{
if (_indexOfMethodInfo.Equals(method))
{
var argument = arguments[0];
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument)!;
argument = _sqlExpressionFactory.ApplyTypeMapping(argument, stringTypeMapping);

SqlExpression charIndexExpression;
var storeType = stringTypeMapping.StoreType;
if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase)
|| string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase))
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
}
else
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
method.ReturnType);
}

charIndexExpression = _sqlExpressionFactory.Subtract(charIndexExpression, _sqlExpressionFactory.Constant(1));
return TranslateIndexOf(instance, method, arguments[0], null);
}

return _sqlExpressionFactory.Case(
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(
argument,
_sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)),
_sqlExpressionFactory.Constant(0))
},
charIndexExpression);
if (_indexOfMethodInfoWithStartingPosition.Equals(method))
{
return TranslateIndexOf(instance, method, arguments[0], arguments[1]);
}

if (_replaceMethodInfo.Equals(method))
Expand Down Expand Up @@ -465,6 +437,60 @@ private SqlExpression TranslateStartsEndsWith(SqlExpression instance, SqlExpress
pattern);
}

private SqlExpression TranslateIndexOf(SqlExpression instance, MethodInfo method, SqlExpression searchExpression, SqlExpression? startIndex)
{
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, searchExpression)!;
searchExpression = _sqlExpressionFactory.ApplyTypeMapping(searchExpression, stringTypeMapping);
instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping);

var charIndexArguments = new List<SqlExpression> { searchExpression, instance };

if (startIndex is not null)
{
charIndexArguments.Add(_sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1)));
}

var argumentsPropagateNullability = Enumerable.Repeat(true, charIndexArguments.Count);

SqlExpression charIndexExpression;
var storeType = stringTypeMapping.StoreType;
if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase)
|| string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase))
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
charIndexArguments,
nullable: true,
argumentsPropagateNullability,
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
}
else
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
charIndexArguments,
nullable: true,
argumentsPropagateNullability,
method.ReturnType);
}

charIndexExpression = _sqlExpressionFactory.Subtract(charIndexExpression, _sqlExpressionFactory.Constant(1));

return _sqlExpressionFactory.Case(
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(
searchExpression,
_sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)),
_sqlExpressionFactory.Constant(0))
},
charIndexExpression);
}


// See https://docs.microsoft.com/en-us/sql/t-sql/language-elements/like-transact-sql
private bool IsLikeWildChar(char c)
=> c == '%' || c == '_' || c == '[';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,15 @@ public virtual Task Indexof_with_emptystring(bool async)
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf(string.Empty)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_one_arg(bool async)
{
return AssertQueryScalar(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf("a")));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_starting_position(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,13 +1528,28 @@ FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}

[ConditionalTheory(Skip = "issue #25396")]
public override async Task Indexof_with_one_arg(bool async)
{
await base.Indexof_with_one_arg(async);

AssertSql(
@"SELECT CASE
WHEN N'a' = N'' THEN 0
ELSE CAST(CHARINDEX(N'a', [c].[ContactName]) AS int) - 1
END
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}

public override async Task Indexof_with_starting_position(bool async)
{
await base.Indexof_with_starting_position(async);

AssertSql(
@"SELECT [c].[ContactName]
@"SELECT CASE
WHEN N'a' = N'' THEN 0
ELSE CAST(CHARINDEX(N'a', [c].[ContactName], 3 + 1) AS int) - 1
END
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}
Expand Down

0 comments on commit 813dafd

Please sign in to comment.