From 08aebff1ff48d051941559be7a922b95c1079105 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 27 Jul 2022 10:54:30 +0200 Subject: [PATCH] Implement joins via USING --- .../NpgsqlServiceCollectionExtensions.cs | 1 + .../Query/Internal/NpgsqlQuerySqlGenerator.cs | 92 ++++++++++++++++--- ...yableMethodTranslatingExpressionVisitor.cs | 56 +++++++++++ ...thodTranslatingExpressionVisitorFactory.cs | 19 ++++ ...FiltersInheritanceBulkUpdatesNpgsqlTest.cs | 19 ++++ .../InheritanceBulkUpdatesNpgsqlTest.cs | 19 ++++ .../NorthwindBulkUpdatesNpgsqlTest.cs | 16 +--- 7 files changed, 198 insertions(+), 24 deletions(-) create mode 100644 src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs create mode 100644 src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitorFactory.cs create mode 100644 test/EFCore.PG.FunctionalTests/BulkUpdates/FiltersInheritanceBulkUpdatesNpgsqlTest.cs create mode 100644 test/EFCore.PG.FunctionalTests/BulkUpdates/InheritanceBulkUpdatesNpgsqlTest.cs diff --git a/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs b/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs index 4483d6cbbb..0c0e42f6f7 100644 --- a/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs +++ b/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs @@ -105,6 +105,7 @@ public static IServiceCollection AddEntityFrameworkNpgsql(this IServiceCollectio .TryAdd() .TryAdd() .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd() diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs b/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs index 1581ea14ce..794304a4b5 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs @@ -206,6 +206,85 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression binary) } } + protected override Expression VisitDelete(DeleteExpression deleteExpression) + { + var selectExpression = deleteExpression.SelectExpression; + + if (selectExpression.Offset == null + && selectExpression.Limit == null + && selectExpression.Having == null + && selectExpression.Orderings.Count == 0 + && selectExpression.GroupBy.Count == 0 + && selectExpression.Projection.Count == 0) + { + Sql.Append("DELETE FROM "); + Visit(deleteExpression.Table); + + var predicate = selectExpression.Predicate; + + // The SelectExpression also contains the target table being modified (same as deleteExpression.Table). + // If it has additional inner joins, use the PostgreSQL-specific USING syntax to express the join. + if (selectExpression.Tables.Count > 1) + { + Sql.AppendLine().Append("USING "); + + var first = true; + + for (var i = 0; i < selectExpression.Tables.Count; i++) + { + switch (selectExpression.Tables[i]) + { + case InnerJoinExpression { Table: TableExpression tableExpression } innerJoinExpression: + // Add the table name and alias to the USING list, and add the join condition into the predicate + AppendToUsingList(tableExpression); + + predicate = predicate is null + ? innerJoinExpression.JoinPredicate + : new SqlBinaryExpression( + ExpressionType.AndAlso, innerJoinExpression.JoinPredicate, predicate, typeof(bool), null); + break; + + case TableExpression tableExpression: + AppendToUsingList(tableExpression); + break; + + default: + throw new InvalidOperationException(RelationalStrings.BulkOperationWithUnsupportedOperatorInSqlGeneration); + + void AppendToUsingList(TableExpression tableExpression) + { + if (tableExpression == deleteExpression.Table) + { + return; + } + + if (first) + { + first = false; + } + else + { + Sql.Append(", "); + } + Visit(tableExpression); + } + } + } + } + + if (predicate is not null) + { + Sql.AppendLine().Append("WHERE "); + + Visit(predicate); + } + + return deleteExpression; + } + + throw new InvalidOperationException(RelationalStrings.BulkOperationWithUnsupportedOperatorInSqlGeneration); + } + protected virtual Expression VisitPostgresNewArray(PostgresNewArrayExpression postgresNewArrayExpression) { Debug.Assert(postgresNewArrayExpression.TypeMapping is not null); @@ -426,19 +505,6 @@ protected override void GenerateSetOperationOperand(SetOperationBase setOperatio base.GenerateSetOperationOperand(setOperation, operand); } - protected override Expression VisitCollate(CollateExpression collateExpresion) - { - Check.NotNull(collateExpresion, nameof(collateExpresion)); - - Visit(collateExpresion.Operand); - - Sql - .Append(" COLLATE ") - .Append(_sqlGenerationHelper.DelimitIdentifier(collateExpresion.Collation)); - - return collateExpresion; - } - public virtual Expression VisitArrayAll(PostgresAllExpression expression) { Visit(expression.Item); diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs new file mode 100644 index 0000000000..a4ff3aa33b --- /dev/null +++ b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs @@ -0,0 +1,56 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal; + +public class NpgsqlQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor +{ + public NpgsqlQueryableMethodTranslatingExpressionVisitor( + QueryableMethodTranslatingExpressionVisitorDependencies dependencies, + RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies, + QueryCompilationContext queryCompilationContext) + : base(dependencies, relationalDependencies, queryCompilationContext) + { + } + + protected override bool IsValidSelectExpressionForBulkDelete( + SelectExpression selectExpression, + EntityShaperExpression entityShaperExpression, + [NotNullWhen(true)] out TableExpression? tableExpression) + { + // The default relational behavior is to allow only single-table expressions, and the only permitted feature is a predicate. + // Here we extend this to also inner joins to tables, which we generate via the PostgreSQL-specific USING construct. + if (selectExpression.Offset == null + && selectExpression.Limit == null + && (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null) + && selectExpression.GroupBy.Count == 0 + && selectExpression.Having == null + && selectExpression.Orderings.Count == 0) + { + TableExpressionBase? table = null; + if (selectExpression.Tables.Count == 1) + { + table = selectExpression.Tables[0]; + } + else if (selectExpression.Tables.All(t => t is TableExpression or InnerJoinExpression { Table: TableExpression })) + { + var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression; + var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression); + var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First()); + table = column.Table; + if (table is JoinExpressionBase joinExpressionBase) + { + table = joinExpressionBase.Table; + } + } + + if (table is TableExpression te) + { + tableExpression = te; + return true; + } + } + + tableExpression = null; + return false; + } +} diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitorFactory.cs b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitorFactory.cs new file mode 100644 index 0000000000..64ebaa2f0b --- /dev/null +++ b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitorFactory.cs @@ -0,0 +1,19 @@ +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal; + +public class NpgsqlQueryableMethodTranslatingExpressionVisitorFactory : IQueryableMethodTranslatingExpressionVisitorFactory +{ + public NpgsqlQueryableMethodTranslatingExpressionVisitorFactory( + QueryableMethodTranslatingExpressionVisitorDependencies dependencies, + RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies) + { + Dependencies = dependencies; + RelationalDependencies = relationalDependencies; + } + + protected virtual QueryableMethodTranslatingExpressionVisitorDependencies Dependencies { get; } + + protected virtual RelationalQueryableMethodTranslatingExpressionVisitorDependencies RelationalDependencies { get; } + + public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext) + => new NpgsqlQueryableMethodTranslatingExpressionVisitor(Dependencies, RelationalDependencies, queryCompilationContext); +} diff --git a/test/EFCore.PG.FunctionalTests/BulkUpdates/FiltersInheritanceBulkUpdatesNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/BulkUpdates/FiltersInheritanceBulkUpdatesNpgsqlTest.cs new file mode 100644 index 0000000000..7549586d2f --- /dev/null +++ b/test/EFCore.PG.FunctionalTests/BulkUpdates/FiltersInheritanceBulkUpdatesNpgsqlTest.cs @@ -0,0 +1,19 @@ +using Microsoft.EntityFrameworkCore.BulkUpdates; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.BulkUpdates; + +public class FiltersInheritanceBulkUpdatesNpgsqlTest : FiltersInheritanceBulkUpdatesTestBase +{ + public FiltersInheritanceBulkUpdatesNpgsqlTest(FiltersInheritanceQueryNpgsqlFixture fixture) + : base(fixture) + { + } + + [ConditionalFact] + public virtual void Check_all_tests_overridden() + => TestHelpers.AssertAllMethodsOverridden(GetType()); + + private void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); +} diff --git a/test/EFCore.PG.FunctionalTests/BulkUpdates/InheritanceBulkUpdatesNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/BulkUpdates/InheritanceBulkUpdatesNpgsqlTest.cs new file mode 100644 index 0000000000..31b807f5fe --- /dev/null +++ b/test/EFCore.PG.FunctionalTests/BulkUpdates/InheritanceBulkUpdatesNpgsqlTest.cs @@ -0,0 +1,19 @@ +using Microsoft.EntityFrameworkCore.BulkUpdates; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.BulkUpdates; + +public class InheritanceBulkUpdatesNpgsqlTest : InheritanceBulkUpdatesTestBase +{ + public InheritanceBulkUpdatesNpgsqlTest(InheritanceQueryNpgsqlFixture fixture) + : base(fixture) + { + } + + [ConditionalFact] + public virtual void Check_all_tests_overridden() + => TestHelpers.AssertAllMethodsOverridden(GetType()); + + private void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); +} diff --git a/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs index 40d155fa73..f18d1ecee3 100644 --- a/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs @@ -239,12 +239,9 @@ public override async Task Delete_SelectMany(bool async) await base.Delete_SelectMany(async); AssertSql( - @"DELETE FROM ""Order Details"" AS o -WHERE EXISTS ( - SELECT 1 - FROM ""Orders"" AS o0 - INNER JOIN ""Order Details"" AS o1 ON o0.""OrderID"" = o1.""OrderID"" - WHERE o0.""OrderID"" < 10250 AND o1.""OrderID"" = o.""OrderID"" AND o1.""ProductID"" = o.""ProductID"")"); + @"DELETE FROM ""Order Details"" AS o0 +USING ""Orders"" AS o +WHERE o.""OrderID"" = o0.""OrderID"" AND o.""OrderID"" < 10250"); } public override async Task Delete_SelectMany_subquery(bool async) @@ -270,11 +267,8 @@ public override async Task Delete_Where_using_navigation(bool async) AssertSql( @"DELETE FROM ""Order Details"" AS o -WHERE EXISTS ( - SELECT 1 - FROM ""Order Details"" AS o0 - INNER JOIN ""Orders"" AS o1 ON o0.""OrderID"" = o1.""OrderID"" - WHERE date_part('year', o1.""OrderDate"")::int = 2000 AND o0.""OrderID"" = o.""OrderID"" AND o0.""ProductID"" = o.""ProductID"")"); +USING ""Orders"" AS o0 +WHERE o.""OrderID"" = o0.""OrderID"" AND date_part('year', o0.""OrderDate"")::int = 2000"); } public override async Task Delete_Where_using_navigation_2(bool async)