Skip to content

Commit

Permalink
Allow set operations over other mapping types
Browse files Browse the repository at this point in the history
Including different mapping types on the two sides, as long as the
result is compatible.

Fixes #16725
  • Loading branch information
roji committed Jul 31, 2019
1 parent 742180b commit b6930c4
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 21 deletions.
34 changes: 14 additions & 20 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,17 +452,24 @@ public Expression ApplySetOperation(
continue;
}

if (joinedMapping.Value1 is ColumnExpression && joinedMapping.Value2 is ColumnExpression
|| joinedMapping.Value1 is ScalarSubqueryExpression && joinedMapping.Value2 is ScalarSubqueryExpression)
if (joinedMapping.Value1 is SqlExpression innerColumn1
&& joinedMapping.Value2 is SqlExpression innerColumn2)
{
handleColumnMapping(
joinedMapping.Key,
select1, (SqlExpression)joinedMapping.Value1,
select2, (SqlExpression)joinedMapping.Value2);
// For now, make sure that both sides output the same store type, otherwise the query may fail.
// TODO: with #15586 we'll be able to also allow different store types which are implicitly convertible to one another.
if (innerColumn1.TypeMapping.StoreType != innerColumn2.TypeMapping.StoreType)
{
throw new InvalidOperationException("Set operations over different store types are currently unsupported");
}

var alias = joinedMapping.Key.Last?.Name;
select1.AddToProjection(innerColumn1, alias);
select2.AddToProjection(innerColumn2, alias);
_projectionMapping[joinedMapping.Key] = innerColumn1;
continue;
}

throw new InvalidOperationException("Non-matching or unknown projection mapping type in set operation");
throw new InvalidOperationException($"Non-matching or unknown projection mapping type in set operation ({joinedMapping.Value1.GetType().Name} and {joinedMapping.Value2.GetType().Name})");
}
}

Expand Down Expand Up @@ -520,19 +527,6 @@ ColumnExpression addSetOperationColumnProjections(

return column1;
}

void handleColumnMapping(
ProjectionMember projectionMember,
SelectExpression select1, SqlExpression innerColumn1,
SelectExpression select2, SqlExpression innerColumn2)
{
// The actual columns may actually be different, but we don't care as long as the type and alias
// coming out of the two operands are the same
var alias = projectionMember.Last?.Name;
select1.AddToProjection(innerColumn1, alias);
select2.AddToProjection(innerColumn2, alias);
_projectionMapping[projectionMember] = innerColumn1;
}
}

public IDictionary<SqlExpression, ColumnExpression> PushdownIntoSubquery()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,7 @@ public override void Include_Union_only_on_one_side_throws() {}
public override void Include_Union_different_includes_throws() {}
public override Task SubSelect_Union(bool isAsync) => Task.CompletedTask;
public override Task Client_eval_Union_FirstOrDefault(bool isAsync) => Task.CompletedTask;
public override Task GroupBy_Select_Union(bool isAsync) => Task.CompletedTask;
public override Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType) => Task.CompletedTask;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,16 @@ public override Task Select_Except_reference_projection(bool isAsync)
return Task.CompletedTask; //base.Select_Except_reference_projection(isAsync);
}

public override Task GroupBy_Select_Union(bool isAsync)
{
return Task.CompletedTask; //base.GroupBy_Select_Union(isAsync);
}

public override Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType)
{
return Task.CompletedTask; //base.Union_over_different_projection_types(isAsync);
}

#endregion

[ConditionalFact(Skip = "Issue#16564")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
Expand Down Expand Up @@ -339,5 +340,60 @@ public virtual Task Client_eval_Union_FirstOrDefault(bool isAsync)
.Union(cs));

private static Customer ClientSideMethod(Customer c) => c;

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Select_Union(bool isAsync)
=> AssertQuery<Customer>(isAsync, cs => cs
.Where(c => c.City == "Berlin")
.GroupBy(c => c.CustomerID)
.Select(g => new { CustomerID = g.Key, Count = g.Count() })
.Union(cs
.Where(c => c.City == "London")
.GroupBy(c => c.CustomerID)
.Select(g => new { CustomerID = g.Key, Count = g.Count() })));

[ConditionalTheory]
[MemberData(nameof(GetSetOperandTestCases))]
public virtual Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType)
{
var (left, right) = (ExpressionGenerator(leftType), ExpressionGenerator(rightType));
return AssertQuery<Order>(isAsync, os => left(os).Union(right(os)));

static Func<IQueryable<Order>, IQueryable<object>> ExpressionGenerator(string expressionType)
{
switch (expressionType)
{
case "Column":
return os => os.Select(o => (object)o.OrderID);
case "Function":
return os => os
.GroupBy(o => o.OrderID)
.Select(g => (object)g.Count());
case "Constant":
return os => os.Select(o => (object)8);
case "Unary":
return os => os.Select(o => (object)-o.OrderID);
case "Binary":
return os => os.Select(o => (object)(o.OrderID + 1));
case "ScalarSubquery":
return os => os.Select(o => (object)o.OrderDetails.Count());
default:
throw new NotSupportedException();
}
}
}

private static IEnumerable<object[]> GetSetOperandTestCases()
=> from async in new[] { true, false }
from leftType in SupportedOperandExpressionType
from rightType in SupportedOperandExpressionType
select new object[] { async, leftType, rightType };

// ReSharper disable once StaticMemberInGenericType
private static readonly string[] SupportedOperandExpressionType =
{
"Column", "Function", "Constant", "Unary", "Binary", "ScalarSubquery"
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -334,5 +334,67 @@ FROM [Orders] AS [o0]
WHERE ([c0].[CustomerID] = [o0].[CustomerID]) AND [o0].[CustomerID] IS NOT NULL) AS [Orders]
FROM [Customers] AS [c0]");
}

public override async Task GroupBy_Select_Union(bool isAsync)
{
await base.GroupBy_Select_Union(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], COUNT(*) AS [Count]
FROM [Customers] AS [c]
WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
GROUP BY [c].[CustomerID]
UNION
SELECT [c0].[CustomerID], COUNT(*) AS [Count]
FROM [Customers] AS [c0]
WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
GROUP BY [c0].[CustomerID]");
}

public override async Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType)
{
await base.Union_over_different_projection_types(isAsync, leftType, rightType);

var leftSql = GenerateSql(leftType);
var rightSql = GenerateSql(rightType);

// Fix up right-side SQL as table aliases shift
rightSql = leftType == "ScalarSubquery"
? rightSql.Replace("[o]", "[o1]").Replace("[o0]", "[o2]")
: rightSql.Replace("[o0]", "[o1]").Replace("[o]", "[o0]");

AssertSql(leftSql + Environment.NewLine + "UNION" + Environment.NewLine + rightSql);

static string GenerateSql(string expressionType)
{
switch (expressionType)
{
case "Column":
return @"SELECT [o].[OrderID]
FROM [Orders] AS [o]";
case "Function":
return @"SELECT COUNT(*)
FROM [Orders] AS [o]
GROUP BY [o].[OrderID]";
case "Constant":
return @"SELECT 8
FROM [Orders] AS [o]";
case "Unary":
return @"SELECT -[o].[OrderID]
FROM [Orders] AS [o]";
case "Binary":
return @"SELECT [o].[OrderID] + 1
FROM [Orders] AS [o]";
case "ScalarSubquery":
return @"SELECT (
SELECT COUNT(*)
FROM [Order Details] AS [o]
WHERE [o0].[OrderID] = [o].[OrderID])
FROM [Orders] AS [o0]";
default:
throw new NotSupportedException();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public SimpleQuerySqlServerTest(NorthwindQuerySqlServerFixture<NoopModelCustomiz
: base(fixture)
{
ClearLog();
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

public override void Shaper_command_caching_when_parameter_names_different()
Expand Down

0 comments on commit b6930c4

Please sign in to comment.