From a2376a711c449ebee75bfbdd9efc8e723e6250e3 Mon Sep 17 00:00:00 2001
From: Anton Okolnychyi
Date: Thu, 11 Apr 2024 14:04:13 -0700
Subject: [PATCH] Spark 3.4: Fix system function pushdown in CoW row-level
 commands (#10119)

---
 .../optimizer/ReplaceStaticInvoke.scala       |  49 ++-
 .../spark/extensions/SparkPlanUtil.java       |  48 +++
 ...mFunctionPushDownInRowLevelOperations.java | 354 ++++++++++++++++++
 .../spark/functions/BaseScalarFunction.java   |  40 ++
 .../spark/functions/BucketFunction.java       |   3 +-
 .../iceberg/spark/functions/DaysFunction.java |   3 +-
 .../spark/functions/HoursFunction.java        |   5 +-
 .../functions/IcebergVersionFunction.java     |   3 +-
 .../spark/functions/MonthsFunction.java       |   3 +-
 .../spark/functions/TruncateFunction.java     |   3 +-
 .../spark/functions/YearsFunction.java        |   3 +-
 11 files changed, 486 insertions(+), 28 deletions(-)
 create mode 100644 spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
 create mode 100644 spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java

diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
index 1f0e164d8467..655a93a7db8b 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
@@ -22,12 +22,20 @@ import org.apache.iceberg.spark.functions.SparkFunctions
 import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
 import org.apache.spark.sql.catalyst.expressions.BinaryComparison
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.In
+import org.apache.spark.sql.catalyst.expressions.InSet
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
 import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
@@ -40,21 +48,36 @@ import org.apache.spark.sql.types.StructType
 object ReplaceStaticInvoke extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
-      case filter @ Filter(condition, _) =>
-        val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
-          case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
-            c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+    plan.transformWithPruning (_.containsAnyPattern(COMMAND, FILTER, JOIN)) {
+      case join @ Join(_, _, _, Some(cond), _) =>
+        replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond)))
 
-          case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
-            c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
-        }
+      case filter @ Filter(cond, _) =>
+        replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond))
+    }
+
+  private def replaceStaticInvoke[T <: LogicalPlan](
+      node: T,
+      condition: Expression,
+      copy: Expression => T): T = {
+    val newCondition = replaceStaticInvoke(condition)
+    if (newCondition fastEquals condition) node else copy(newCondition)
+  }
+
+  private def replaceStaticInvoke(condition: Expression): Expression = {
+    condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
+      case in @ In(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(value = replaceStaticInvoke(value))
 
-        if (newCondition fastEquals condition) {
-          filter
-        } else {
-          filter.copy(condition = newCondition)
-        }
+      case in @ InSet(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(child = replaceStaticInvoke(value))
+
+      case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+        c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+      case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+        c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+    }
   }
 
   private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
index 4f7c3ebadbc5..830d07d86eab 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
@@ -20,12 +20,17 @@
 import static scala.collection.JavaConverters.seqAsJavaListConverter;
 
+import java.util.Collection;
 import java.util.List;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.execution.CommandResultExec;
 import org.apache.spark.sql.execution.SparkPlan;
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec;
+import scala.PartialFunction;
 import scala.collection.Seq;
 
 public class SparkPlanUtil {
 
@@ -53,6 +58,49 @@ private static SparkPlan actualPlan(SparkPlan plan) {
     }
   }
 
+  public static List<Expression> collectExprs(
+      SparkPlan sparkPlan, Predicate<Expression> predicate) {
+    Seq<List<Expression>> seq =
+        SPARK_HELPER.collect(
+            sparkPlan,
+            new PartialFunction<SparkPlan, List<Expression>>() {
+              @Override
+              public List<Expression> apply(SparkPlan plan) {
+                List<Expression> exprs = Lists.newArrayList();
+
+                for (Expression expr : toJavaList(plan.expressions())) {
+                  exprs.addAll(collectExprs(expr, predicate));
+                }
+
+                return exprs;
+              }
+
+              @Override
+              public boolean isDefinedAt(SparkPlan plan) {
+                return true;
+              }
+            });
+    return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList());
+  }
+
+  private static List<Expression> collectExprs(
+      Expression expression, Predicate<Expression> predicate) {
+    Seq<Expression> seq =
+        expression.collect(
+            new PartialFunction<Expression, Expression>() {
+              @Override
+              public Expression apply(Expression expr) {
+                return expr;
+              }
+
+              @Override
+              public boolean isDefinedAt(Expression expr) {
+                return predicate.test(expr);
+              }
+            });
+    return toJavaList(seq);
+  }
+
   private static <T> List<T> toJavaList(Seq<T> seq) {
     return seqAsJavaListConverter(seq).asJava();
   }
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
new file mode 100644
index 000000000000..db4d10645b99
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
+import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.RowLevelOperationMode;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.execution.CommandResultExec;
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runners.Parameterized.Parameters;
+
+public class TestSystemFunctionPushDownInRowLevelOperations extends SparkExtensionsTestBase {
+
+  private static final String CHANGES_TABLE_NAME = "changes";
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HIVE.catalogName(),
+        SparkCatalogConfig.HIVE.implementation(),
+        SparkCatalogConfig.HIVE.properties()
+      }
+    };
+  }
+
+  public TestSystemFunctionPushDownInRowLevelOperations(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  public void beforeEach() {
+    sql("USE %s", catalogName);
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s PURGE", tableName);
+    sql("DROP TABLE IF EXISTS %s PURGE", tableName(CHANGES_TABLE_NAME));
+  }
+
+  @Test
+  public void
testCopyOnWriteDeleteBucketTransformInPredicate() { + initTable("bucket(4, dep)"); + checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); + } + + @Test + public void testMergeOnReadDeleteBucketTransformInPredicate() { + initTable("bucket(4, dep)"); + checkDelete(MERGE_ON_READ, "system.bucket(4, dep) IN (2, 3)"); + } + + @Test + public void testCopyOnWriteDeleteBucketTransformEqPredicate() { + initTable("bucket(4, dep)"); + checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) = 2"); + } + + @Test + public void testMergeOnReadDeleteBucketTransformEqPredicate() { + initTable("bucket(4, dep)"); + checkDelete(MERGE_ON_READ, "system.bucket(4, dep) = 2"); + } + + @Test + public void testCopyOnWriteDeleteYearsTransform() { + initTable("years(ts)"); + checkDelete(COPY_ON_WRITE, "system.years(ts) > 30"); + } + + @Test + public void testMergeOnReadDeleteYearsTransform() { + initTable("years(ts)"); + checkDelete(MERGE_ON_READ, "system.years(ts) <= 30"); + } + + @Test + public void testCopyOnWriteDeleteMonthsTransform() { + initTable("months(ts)"); + checkDelete(COPY_ON_WRITE, "system.months(ts) <= 250"); + } + + @Test + public void testMergeOnReadDeleteMonthsTransform() { + initTable("months(ts)"); + checkDelete(MERGE_ON_READ, "system.months(ts) > 250"); + } + + @Test + public void testCopyOnWriteDeleteDaysTransform() { + initTable("days(ts)"); + checkDelete(COPY_ON_WRITE, "system.days(ts) <= date('2000-01-03 00:00:00')"); + } + + @Test + public void testMergeOnReadDeleteDaysTransform() { + initTable("days(ts)"); + checkDelete(MERGE_ON_READ, "system.days(ts) > date('2000-01-03 00:00:00')"); + } + + @Test + public void testCopyOnWriteDeleteHoursTransform() { + initTable("hours(ts)"); + checkDelete(COPY_ON_WRITE, "system.hours(ts) <= 100000"); + } + + @Test + public void testMergeOnReadDeleteHoursTransform() { + initTable("hours(ts)"); + checkDelete(MERGE_ON_READ, "system.hours(ts) > 100000"); + } + + @Test + public void testCopyOnWriteDeleteTruncateTransform() { + initTable("truncate(1, dep)"); + checkDelete(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'"); + } + + @Test + public void testMergeOnReadDeleteTruncateTransform() { + initTable("truncate(1, dep)"); + checkDelete(MERGE_ON_READ, "system.truncate(1, dep) = 'i'"); + } + + @Test + public void testCopyOnWriteUpdateBucketTransform() { + initTable("bucket(4, dep)"); + checkUpdate(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); + } + + @Test + public void testMergeOnReadUpdateBucketTransform() { + initTable("bucket(4, dep)"); + checkUpdate(MERGE_ON_READ, "system.bucket(4, dep) = 2"); + } + + @Test + public void testCopyOnWriteUpdateYearsTransform() { + initTable("years(ts)"); + checkUpdate(COPY_ON_WRITE, "system.years(ts) > 30"); + } + + @Test + public void testMergeOnReadUpdateYearsTransform() { + initTable("years(ts)"); + checkUpdate(MERGE_ON_READ, "system.years(ts) <= 30"); + } + + @Test + public void testCopyOnWriteMergeBucketTransform() { + initTable("bucket(4, dep)"); + checkMerge(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); + } + + @Test + public void testMergeOnReadMergeBucketTransform() { + initTable("bucket(4, dep)"); + checkMerge(MERGE_ON_READ, "system.bucket(4, dep) = 2"); + } + + @Test + public void testCopyOnWriteMergeYearsTransform() { + initTable("years(ts)"); + checkMerge(COPY_ON_WRITE, "system.years(ts) > 30"); + } + + @Test + public void testMergeOnReadMergeYearsTransform() { + initTable("years(ts)"); + checkMerge(MERGE_ON_READ, "system.years(ts) <= 30"); + } + + @Test + public void 
testCopyOnWriteMergeTruncateTransform() {
+    initTable("truncate(1, dep)");
+    checkMerge(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
+  }
+
+  @Test
+  public void testMergeOnReadMergeTruncateTransform() {
+    initTable("truncate(1, dep)");
+    checkMerge(MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
+  }
+
+  private void checkDelete(RowLevelOperationMode mode, String cond) {
+    withUnavailableLocations(
+        findIrrelevantFileLocations(cond),
+        () -> {
+          sql(
+              "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+              tableName,
+              TableProperties.DELETE_MODE,
+              mode.modeName(),
+              TableProperties.DELETE_DISTRIBUTION_MODE,
+              DistributionMode.NONE.modeName());
+
+          Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
+          changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+          List<Expression> calls =
+              executeAndCollectFunctionCalls(
+                  "DELETE FROM %s t WHERE %s AND t.id IN (SELECT id FROM %s)",
+                  tableName, cond, tableName(CHANGES_TABLE_NAME));
+          // CoW planning currently does not optimize post-scan filters in DELETE
+          int expectedCallCount = mode == COPY_ON_WRITE ? 1 : 0;
+          assertThat(calls).hasSize(expectedCallCount);
+
+          assertEquals(
+              "Should have no matching rows",
+              ImmutableList.of(),
+              sql(
+                  "SELECT * FROM %s WHERE %s AND id IN (SELECT * FROM %s)",
+                  tableName, cond, tableName(CHANGES_TABLE_NAME)));
+        });
+  }
+
+  private void checkUpdate(RowLevelOperationMode mode, String cond) {
+    withUnavailableLocations(
+        findIrrelevantFileLocations(cond),
+        () -> {
+          sql(
+              "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+              tableName,
+              TableProperties.UPDATE_MODE,
+              mode.modeName(),
+              TableProperties.UPDATE_DISTRIBUTION_MODE,
+              DistributionMode.NONE.modeName());
+
+          Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
+          changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+          List<Expression> calls =
+              executeAndCollectFunctionCalls(
+                  "UPDATE %s t SET t.salary = -1 WHERE %s AND t.id IN (SELECT id FROM %s)",
+                  tableName, cond, tableName(CHANGES_TABLE_NAME));
+          // CoW planning currently does not optimize post-scan filters in UPDATE
+          int expectedCallCount = mode == COPY_ON_WRITE ? 2 : 0;
+          assertThat(calls).hasSize(expectedCallCount);
+
+          assertEquals(
+              "Should have correct updates",
+              sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
+              sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
+        });
+  }
+
+  private void checkMerge(RowLevelOperationMode mode, String cond) {
+    withUnavailableLocations(
+        findIrrelevantFileLocations(cond),
+        () -> {
+          sql(
+              "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+              tableName,
+              TableProperties.MERGE_MODE,
+              mode.modeName(),
+              TableProperties.MERGE_DISTRIBUTION_MODE,
+              DistributionMode.NONE.modeName());
+
+          Dataset<Row> changeDF =
+              spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id");
+          changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+          List<Expression> calls =
+              executeAndCollectFunctionCalls(
+                  "MERGE INTO %s t USING %s s "
+                      + "ON t.id == s.id AND %s "
+                      + "WHEN MATCHED THEN "
+                      + " UPDATE SET salary = -1 "
+                      + "WHEN NOT MATCHED AND s.id = 2 THEN "
+                      + " INSERT (id, salary, dep, ts) VALUES (100, -1, 'hr', null)",
+                  tableName, tableName(CHANGES_TABLE_NAME), cond);
+          assertThat(calls).isEmpty();
+
+          assertEquals(
+              "Should have correct updates",
+              sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
+              sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
+        });
+  }
+
+  private List<Expression> executeAndCollectFunctionCalls(String query, Object... args) {
+    CommandResultExec command = (CommandResultExec) executeAndKeepPlan(query, args);
+    V2TableWriteExec write = (V2TableWriteExec) command.commandPhysicalPlan();
+    System.out.println("!!! WRITE PLAN !!!");
+    System.out.println(write.toString());
+    return SparkPlanUtil.collectExprs(
+        write.query(),
+        expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression);
+  }
+
+  private List<String> findIrrelevantFileLocations(String cond) {
+    return spark
+        .table(tableName)
+        .where("NOT " + cond)
+        .select(MetadataColumns.FILE_PATH.name())
+        .distinct()
+        .as(Encoders.STRING())
+        .collectAsList();
+  }
+
+  private void initTable(String transform) {
+    sql(
+        "CREATE TABLE %s (id BIGINT, salary INT, dep STRING, ts TIMESTAMP)"
+            + "USING iceberg "
+            + "PARTITIONED BY (%s) "
+            + "TBLPROPERTIES ('%s' 'true')",
+        tableName, transform, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    append(
+        tableName,
+        "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+        "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+        "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+        "{ \"id\": 4, \"salary\": 400, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
+        "{ \"id\": 5, \"salary\": 500, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
+        "{ \"id\": 6, \"salary\": 600, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }");
+  }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java
new file mode 100644
index 000000000000..5ec44f314180
--- /dev/null
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.functions;
+
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+
+abstract class BaseScalarFunction<R> implements ScalarFunction<R> {
+  @Override
+  public int hashCode() {
+    return canonicalName().hashCode();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) {
+      return true;
+    } else if (!(other instanceof ScalarFunction)) {
+      return false;
+    }
+
+    ScalarFunction<?> that = (ScalarFunction<?>) other;
+    return canonicalName().equals(that.canonicalName());
+  }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
index af3c67a4bb63..c3de3d48dbcc 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
@@ -25,7 +25,6 @@
 import org.apache.iceberg.util.BucketUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.BinaryType;
 import org.apache.spark.sql.types.ByteType;
@@ -115,7 +114,7 @@ public String name() {
     return "bucket";
   }
 
-  public abstract static class BucketBase implements ScalarFunction<Integer> {
+  public abstract static class BucketBase extends BaseScalarFunction<Integer> {
     public static int apply(int numBuckets, int hashedValue) {
       return (hashedValue & Integer.MAX_VALUE) % numBuckets;
     }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
index b8d28b73f42f..f52edd9b208f 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@ public String name() {
     return "days";
   }
 
-  private abstract static class BaseToDaysFunction implements ScalarFunction<Integer> {
+  private abstract static class BaseToDaysFunction extends BaseScalarFunction<Integer> {
     @Override
     public String name() {
       return "days";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
index 18697e1c16fb..660a182f0e78 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.TimestampNTZType;
@@ -58,7 +57,7 @@ public String name() {
     return "hours";
   }
 
-  public static class TimestampToHoursFunction implements ScalarFunction<Integer> {
+  public static class TimestampToHoursFunction extends BaseScalarFunction<Integer> {
     // magic method used in codegen
     public static int invoke(long micros) {
       return DateTimeUtil.microsToHours(micros);
@@ -91,7 +90,7 @@ public Integer produceResult(InternalRow input) {
     }
   }
 
-  public static class TimestampNtzToHoursFunction implements ScalarFunction<Integer> {
+  public static class TimestampNtzToHoursFunction extends BaseScalarFunction<Integer> {
     // magic method used in codegen
     public static int invoke(long micros) {
       return DateTimeUtil.microsToHours(micros);
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
index 9cd059377ce3..689a0f4cb4df 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.IcebergBuild;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
@@ -55,7 +54,7 @@ public String name() {
 
   // Implementing class cannot be private, otherwise Spark is unable to access the static invoke
   // function during code-gen and calling the function fails
-  static class IcebergVersionFunctionImpl implements ScalarFunction<UTF8String> {
+  static class IcebergVersionFunctionImpl extends BaseScalarFunction<UTF8String> {
     private static final UTF8String VERSION = UTF8String.fromString(IcebergBuild.version());
 
     // magic function used in code-gen. must be named `invoke`.
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
index 1d38014461c1..353d850f86e2 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@ public String name() {
     return "months";
   }
 
-  private abstract static class BaseToMonthsFunction implements ScalarFunction<Integer> {
+  private abstract static class BaseToMonthsFunction extends BaseScalarFunction<Integer> {
     @Override
     public String name() {
       return "months";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
index 8cfb529e1028..fac90c9efee6 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
@@ -27,7 +27,6 @@
 import org.apache.iceberg.util.TruncateUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.BinaryType;
 import org.apache.spark.sql.types.ByteType;
@@ -108,7 +107,7 @@ public String name() {
     return "truncate";
   }
 
-  public abstract static class TruncateBase<T> implements ScalarFunction<T> {
+  public abstract static class TruncateBase<T> extends BaseScalarFunction<T> {
     @Override
     public String name() {
       return "truncate";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
index 02642e657d76..cfd1b0e8d002 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@ public String name() {
     return "years";
   }
 
-  private abstract static class BaseToYearsFunction implements ScalarFunction<Integer> {
+  private abstract static class BaseToYearsFunction extends BaseScalarFunction<Integer> {
     @Override
     public String name() {
       return "years";