Skip to content

Commit

Permalink
Query: add translation for string.IndexOf(string, int) dotnet#25396
Browse files Browse the repository at this point in the history
  • Loading branch information
yosoyhabacuc committed Sep 6, 2021
1 parent e8a56f5 commit 268bd95
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 41 deletions.
102 changes: 63 additions & 39 deletions src/EFCore.SqlServer/Query/Internal/SqlServerStringMethodTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public class SqlServerStringMethodTranslator : IMethodCallTranslator
private static readonly MethodInfo _indexOfMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string) });

private static readonly MethodInfo _indexOfMethodInfoWithStartingPosition
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string), typeof(int) });

private static readonly MethodInfo _replaceMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Replace), new[] { typeof(string), typeof(string) });

Expand Down Expand Up @@ -120,46 +123,12 @@ public SqlServerStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
{
if (_indexOfMethodInfo.Equals(method))
{
var argument = arguments[0];
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument)!;
argument = _sqlExpressionFactory.ApplyTypeMapping(argument, stringTypeMapping);

SqlExpression charIndexExpression;
var storeType = stringTypeMapping.StoreType;
if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase)
|| string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase))
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
}
else
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
method.ReturnType);
}

charIndexExpression = _sqlExpressionFactory.Subtract(charIndexExpression, _sqlExpressionFactory.Constant(1));
return TranslateIndexOf(instance, method, arguments[0], null);
}

return _sqlExpressionFactory.Case(
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(
argument,
_sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)),
_sqlExpressionFactory.Constant(0))
},
charIndexExpression);
if (_indexOfMethodInfoWithStartingPosition.Equals(method))
{
return TranslateIndexOf(instance, method, arguments[0], arguments[1]);
}

if (_replaceMethodInfo.Equals(method))
Expand Down Expand Up @@ -470,6 +439,61 @@ private SqlExpression TranslateStartsEndsWith(SqlExpression instance, SqlExpress
pattern);
}

private SqlExpression TranslateIndexOf(SqlExpression instance, MethodInfo method, SqlExpression searchExpression, SqlExpression? startIndex)
{
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, searchExpression)!;
searchExpression = _sqlExpressionFactory.ApplyTypeMapping(searchExpression, stringTypeMapping);

SqlExpression[] charIndexArguments;
if (startIndex == null)
{
charIndexArguments = new[] { searchExpression, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) };
}
else
{
var startIndexSql = _sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1));
charIndexArguments = new[] { searchExpression, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), startIndexSql };
}

SqlExpression charIndexExpression;
var storeType = stringTypeMapping.StoreType;
if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase)
|| string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase))
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
charIndexArguments,
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
}
else
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
charIndexArguments,
nullable: true,
argumentsPropagateNullability: new[] { true, true },
method.ReturnType);
}

charIndexExpression = _sqlExpressionFactory.Subtract(charIndexExpression, _sqlExpressionFactory.Constant(1));

return _sqlExpressionFactory.Case(
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(
searchExpression,
_sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)),
_sqlExpressionFactory.Constant(0))
},
charIndexExpression);
}


// See https://docs.microsoft.com/en-us/sql/t-sql/language-elements/like-transact-sql
private bool IsLikeWildChar(char c)
=> c == '%' || c == '_' || c == '[';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,15 @@ public virtual Task Indexof_with_emptystring(bool async)
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf(string.Empty)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_one_arg(bool async)
{
return AssertQueryScalar(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf("a")));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_starting_position(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1498,13 +1498,28 @@ FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}

[ConditionalTheory(Skip = "issue #25396")]
public override async Task Indexof_with_one_arg(bool async)
{
await base.Indexof_with_one_arg(async);

AssertSql(
@"SELECT CASE
WHEN N'a' = N'' THEN 0
ELSE CAST(CHARINDEX(N'a', [c].[ContactName]) AS int) - 1
END
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}

public override async Task Indexof_with_starting_position(bool async)
{
await base.Indexof_with_starting_position(async);

AssertSql(
@"SELECT [c].[ContactName]
@"SELECT CASE
WHEN N'a' = N'' THEN 0
ELSE CAST(CHARINDEX(N'a', [c].[ContactName], 3 + 1) AS int) - 1
END
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}
Expand Down

0 comments on commit 268bd95

Please sign in to comment.