Skip to content

Commit

Permalink
Query: Convert SubSelectExpression to search condition as necessary
Browse files Browse the repository at this point in the history
Fix for dotnet#14900
  • Loading branch information
smitpatel committed Jun 24, 2019
1 parent d7e85ed commit e649a72
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline;
Expand Down Expand Up @@ -202,7 +203,6 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual;


return ApplyConversion(sqlBinaryExpression, condition);
}

Expand All @@ -223,10 +223,14 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio
resultCondition = false;
break;

default:
case ExpressionType.Equal:
case ExpressionType.NotEqual:
_isSearchCondition = false;
resultCondition = true;
break;

default:
throw new InvalidOperationException("Unknown operator type encountered in SqlUnaryExpression.");
}

var operand = (SqlExpression)Visit(sqlUnaryExpression.Operand);
Expand Down Expand Up @@ -325,9 +329,11 @@ protected override Expression VisitLeftJoin(LeftJoinExpression leftJoinExpressio

protected override Expression VisitSubSelect(SubSelectExpression subSelectExpression)
{
var parentSearchCondition = _isSearchCondition;
var subquery = (SelectExpression)Visit(subSelectExpression.Subquery);
_isSearchCondition = parentSearchCondition;

return subSelectExpression.Update(subquery);
return ApplyConversion(subSelectExpression.Update(subquery), condition: false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7504,6 +7504,15 @@ public virtual Task Anonymous_projection_take_followed_by_projecting_single_elem
gs => gs.Select(g => new { Gear = g }).Take(25).Select(e => e.Gear.Weapons.OrderBy(w => w.Id).FirstOrDefault()));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Bool_projection_from_subquery_treated_appropriately_in_where(bool isAsync)
{
return AssertQuery<City, Gear>(
isAsync,
(cs, gs) => cs.Where(c => gs.OrderBy(g => g.Nickname).ThenBy(g => g.SquadId).FirstOrDefault().HasSoulPatch));
}

protected GearsOfWarContext CreateContext() => Fixture.CreateContext();

protected virtual void ClearLog()
Expand Down
8 changes: 4 additions & 4 deletions test/EFCore.Specification.Tests/TestUtilities/TestHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ public static int AssertResults<T>(
}
}

elementSorter = elementSorter ?? (e => e);
elementAsserter = elementAsserter ?? Assert.Equal;
elementSorter ??= (e => e);
elementAsserter ??= Assert.Equal;
if (!verifyOrdered)
{
expected = expected.OrderBy(elementSorter).ToList();
Expand Down Expand Up @@ -340,7 +340,7 @@ public static int AssertResults<T>(
}
}

elementAsserter = elementAsserter ?? Assert.Equal;
elementAsserter ??= Assert.Equal;
if (!verifyOrdered)
{
expected = expected.OrderBy(elementSorter).ToList();
Expand Down Expand Up @@ -380,7 +380,7 @@ public static int AssertResultsNullable<T>(
}
}

elementAsserter = elementAsserter ?? Assert.Equal;
elementAsserter ??= Assert.Equal;
if (!verifyOrdered)
{
expected = expected.OrderBy(elementSorter).ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8565,6 +8565,20 @@ public override async Task Anonymous_projection_take_followed_by_projecting_sing
@"");
}

public override async Task Bool_projection_from_subquery_treated_appropriately_in_where(bool isAsync)
{
await base.Bool_projection_from_subquery_treated_appropriately_in_where(isAsync);

AssertSql(
@"SELECT [c].[Name], [c].[Location], [c].[Nation]
FROM [Cities] AS [c]
WHERE (
SELECT TOP(1) [g].[HasSoulPatch]
FROM [Gears] AS [g]
WHERE [g].[Discriminator] IN (N'Gear', N'Officer')
ORDER BY [g].[Nickname], [g].[SquadId]) = CAST(1 AS bit)");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit e649a72

Please sign in to comment.