From 68fa1a9f4a41eeb46ab2e2fc1d6220d7966b3f72 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 15 Mar 2021 11:54:40 -0700 Subject: [PATCH 01/21] [SPARK-34981][SQL] Implement V2 function resolution and evaluation Co-Authored-By: Chao Sun Co-Authored-By: Ryan Blue --- .../catalog/functions/ScalarFunction.java | 32 ++ .../sql/catalyst/analysis/Analyzer.scala | 94 ++++- .../analysis/higherOrderFunctions.scala | 5 +- .../sql/catalyst/analysis/unresolved.scala | 20 +- .../expressions/ApplyFunctionExpression.scala | 43 +++ .../expressions/aggregate/V2Aggregator.scala | 65 ++++ .../sql/catalyst/parser/AstBuilder.scala | 35 +- .../catalog/CatalogV2Implicits.scala | 32 ++ ...eateTablePartitioningValidationSuite.scala | 4 +- .../analysis/TableLookupCacheSuite.scala | 4 +- .../catalog/CatalogManagerSuite.scala | 4 +- ...eCatalogSuite.scala => CatalogSuite.scala} | 52 ++- .../catalog/StagingInMemoryTableCatalog.scala | 4 +- ...pportsAtomicPartitionManagementSuite.scala | 4 +- .../SupportsPartitionManagementSuite.scala | 6 +- ...eCatalog.scala => V2InMemoryCatalog.scala} | 124 +------ ...scala => V2InMemoryPartitionCatalog.scala} | 4 +- .../spark/sql/JavaDataFrameWriterV2Suite.java | 4 +- .../catalog/functions/JavaAverage.java | 93 +++++ .../catalog/functions/JavaStrLen.java | 119 ++++++ .../spark/sql/CharVarcharTestSuite.scala | 4 +- .../spark/sql/DataFrameWriterV2Suite.scala | 4 +- .../apache/spark/sql/SQLInsertTestSuite.scala | 4 +- .../DataSourceV2DataFrameSuite.scala | 4 +- .../connector/DataSourceV2FunctionSuite.scala | 339 ++++++++++++++++++ .../sql/connector/DataSourceV2SQLSuite.scala | 20 +- .../sql/connector/DatasourceV2SQLBase.scala | 8 +- .../SupportsCatalogOptionsSuite.scala | 2 +- .../sql/connector/V1ReadFallbackSuite.scala | 2 +- .../WriteDistributionAndOrderingSuite.scala | 6 +- .../spark/sql/execution/HiveResultSuite.scala | 6 +- .../command/CharVarcharDDLTestBase.scala | 4 +- .../command/v2/CommandSuiteBase.scala | 6 +- .../command/v2/ShowNamespacesSuite.scala | 4 +- .../test/DataStreamTableAPISuite.scala | 10 +- 35 files changed, 972 insertions(+), 199 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/{TableCatalogSuite.scala => CatalogSuite.scala} (92%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/{InMemoryTableCatalog.scala => V2InMemoryCatalog.scala} (53%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/{InMemoryPartitionTableCatalog.scala => V2InMemoryPartitionCatalog.scala} (91%) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index c2106a21c4a8f..49999fb1fa24d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ 
-30,10 +30,42 @@ *

* The JVM type of result values produced by this function must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + *

+ * IMPORTANT: the default implementation of {@link #produceResult} throws + * {@link UnsupportedOperationException}. Users can choose to override this method, or implement + * a "magic method" named {@link #MAGIC_METHOD_NAME} which takes individual parameters + * instead of an {@link InternalRow}. The magic method will be loaded by Spark through Java + * reflection and will generally provide better performance, due to optimizations such as + * codegen and the avoidance of Java boxing. + * + * For example, a scalar UDF for adding two integers can be defined as follows with the magic + * method approach: + * + *

+ *   {@code
+ *     public class IntegerAdd implements ScalarFunction<Integer> {
+ *       public int invoke(int left, int right) {
+ *         return left + right;
+ *       }
+ *
+ *       @Override
+ *       public Integer produceResult(InternalRow input) {
+ *         int left = input.getInt(0);
+ *         int right = input.getInt(1);
+ *         return left + right;
+ *       }
+ *     }
+ *   }
+ * 
+ * In this case, both {@link #MAGIC_METHOD_NAME} and {@link #produceResult} are defined, and Spark will + * first look up the {@code invoke} method during query analysis. It checks whether the method + * parameters have valid types that are supported by Spark. If the check fails, it falls back + * to using {@link #produceResult}. * * @param <R> the JVM type of result values */ public interface ScalarFunction<R> extends BoundFunction { + String MAGIC_METHOD_NAME = "invoke"; /** * Applies the function to an input row to produce a value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 87b8d52ac277f..61f9cb498126f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Method import java.util import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean @@ -44,6 +45,7 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} +import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -1958,6 +1960,9 @@ class Analyzer(override val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() plan.resolveExpressions { + case f @ UnresolvedFunction(NonSessionCatalogAndIdentifier(_, _), _, _, _, _) => + // no-op if this is from a v2 catalog + f case f: UnresolvedFunction if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f @@ -2016,9 +2021,71 @@ class Analyzer(override val catalogManager: CatalogManager) name, other.getClass.getCanonicalName) } } - case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) => + case UnresolvedFunction(NonSessionCatalogAndIdentifier(v2Catalog, ident), arguments, + isDistinct, filter, ignoreNulls) if v2Catalog.isFunctionCatalog => + val unbound = v2Catalog.asFunctionCatalog.loadFunction(ident) + + val inputType = StructType(arguments.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + + val bound = try { + unbound.bind(inputType) + } catch { + case unsupported: UnsupportedOperationException => + failAnalysis(s"Function ${unbound.name} cannot process input: " + + s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " + + unsupported.getMessage) + } + + bound match { + case scalarFunc: ScalarFunction[_] => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "DISTINCT") + } else if (filter.isDefined) { + throw
QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "FILTER clause") + } else { + val argClasses = inputType.fields.map(_.dataType) + findMethod(scalarFunc, ScalarFunction.MAGIC_METHOD_NAME, Some(argClasses)) match { + case Some(_) => + val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) + Invoke(caller, ScalarFunction.MAGIC_METHOD_NAME, scalarFunc.resultType(), + arguments, returnNullable = scalarFunc.isResultNullable) + case _ => + // TODO: handle functions defined in Scala too - in Scala, even if a + // subclass does not override the default method defined in a parent + // Java interface, the method can still be found from `getDeclaredMethod`. + findMethod(scalarFunc, "produceResult", Some(Seq(inputType))) match { + case Some(_) => + ApplyFunctionExpression(scalarFunc, arguments) + case None => + failAnalysis(s"ScalarFunction '${bound.name()}' neither implement " + + s"magic method nor override 'produceResult'") + } + } + } + case aggFunc: V2AggregateFunction[_, _] => + // due to type erasure we can't match by parameter types here, so this check will + // succeed even if the class doesn't override `update` but implements another + // method with the same name. + findMethod(aggFunc, "update") match { + case Some(_) => + val aggregator = V2Aggregator(aggFunc, arguments) + AggregateExpression(aggregator, Complete, isDistinct, filter) + case None => + failAnalysis(s"AggregateFunction '${bound.name()}' neither implement magic " + + s"method nor override 'update'") + } + case _ => + failAnalysis(s"Function ${bound.name()} does not implement ScalarFunction or " + + s"AggregateFunction") + } + + case u @ UnresolvedFunction(parts, arguments, isDistinct, filter, ignoreNulls) => + withPosition(u) { - v1SessionCatalog.lookupFunction(funcId, arguments) match { + v1SessionCatalog.lookupFunction(parts.asFunctionIdentifier, arguments) match { // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. @@ -2098,6 +2165,29 @@ class Analyzer(override val catalogManager: CatalogManager) } } } + + /** + * Check if the input `fn` implements the given `methodName`. If `inputTypeOpt` is set, it also + * tries to match it against the declared parameter types.
+ */ + private def findMethod( + fn: BoundFunction, + methodName: String, + inputTypeOpt: Option[Seq[DataType]] = None): Option[Method] = { + val cls = fn.getClass + inputTypeOpt match { + case Some(inputType) => + try { + val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass) + Some(cls.getDeclaredMethod(methodName, argClasses: _*)) + } catch { + case _: NoSuchMethodException => + None + } + case None => + cls.getDeclaredMethods.find(_.getName == methodName) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 7d74c0d1cd14f..3d6b4e97c8acd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType @@ -33,10 +34,10 @@ import org.apache.spark.sql.types.DataType case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case u @ UnresolvedFunction(fn, children, false, filter, ignoreNulls) + case u @ UnresolvedFunction(parts, children, false, filter, ignoreNulls) if hasLambdaAndResolvedArguments(children) => withPosition(u) { - catalog.lookupFunction(fn, children) match { + catalog.lookupFunction(parts.asFunctionIdentifier, children) match { case func: HigherOrderFunction => filter.foreach(_.failAnalysis("FILTER predicate specified, " + s"but ${func.prettyName} is not an aggregate function")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 5001e2ea88ac7..f8a10c4afc9c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -269,12 +269,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio } case class UnresolvedFunction( - name: FunctionIdentifier, + multipartIdentifier: Seq[String], arguments: Seq[Expression], isDistinct: Boolean, filter: Option[Expression] = None, ignoreNulls: Boolean = false) extends Expression with Unevaluable { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def children: Seq[Expression] = arguments ++ filter.toSeq @@ -282,10 +283,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false - override def prettyName: String = name.unquotedString + override def prettyName: String = multipartIdentifier.quoted override def toString: String = { val distinct = if (isDistinct) "distinct " else "" - s"'$name($distinct${children.mkString(", ")})" + s"'${multipartIdentifier.quoted}($distinct${children.mkString(", ")})" } override protected def withNewChildrenInternal( @@ -296,11 +297,22 @@ case class UnresolvedFunction( 
copy(arguments = newChildren) } } + + def name: FunctionIdentifier = multipartIdentifier.asFunctionIdentifier } object UnresolvedFunction { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def apply( + name: FunctionIdentifier, + arguments: Seq[Expression], + isDistinct: Boolean): UnresolvedFunction = { + UnresolvedFunction(name.asMultipart, arguments, isDistinct) + } + def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = { - UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct) + UnresolvedFunction(Seq(name), arguments, isDistinct) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala new file mode 100644 index 0000000000000..25fb4a0731b35 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction +import org.apache.spark.sql.types.DataType + +case class ApplyFunctionExpression( + function: ScalarFunction[_], + children: Seq[Expression]) extends Expression with UserDefinedExpression with CodegenFallback { + override def nullable: Boolean = function.isResultNullable + override def name: String = function.name() + override def dataType: DataType = function.resultType() + + private lazy val reusedRow = new GenericInternalRow(children.size) + + /** Returns the result of evaluating this expression on a given input Row */ + override def eval(input: InternalRow): Any = { + children.zipWithIndex.foreach { + case (expr, pos) => + reusedRow.update(pos, expr.eval(input)) + } + + function.produceResult(reusedRow) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala new file mode 100644 index 0000000000000..27568e6f1b119 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction} +import org.apache.spark.sql.types.DataType + +case class V2Aggregator[BUF <: java.io.Serializable, OUT]( + aggrFunc: V2AggregateFunction[BUF, OUT], + children: Seq[Expression], + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] { + override def createAggregationBuffer(): BUF = aggrFunc.newAggregationState() + + override def update(buffer: BUF, input: InternalRow): BUF = aggrFunc.update(buffer, input) + + override def merge(buffer: BUF, input: BUF): BUF = aggrFunc.merge(buffer, input) + + override def eval(buffer: BUF): Any = { + aggrFunc.produceResult(buffer) + } + + override def serialize(buffer: BUF): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val out = new ObjectOutputStream(bos) + out.writeObject(buffer) + out.close() + bos.toByteArray + } + + override def deserialize(bytes: Array[Byte]): BUF = { + val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) + in.readObject().asInstanceOf[BUF] + } + + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): V2Aggregator[BUF, OUT] = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): V2Aggregator[BUF, OUT] = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = aggrFunc.isResultNullable + + override def dataType: DataType = aggrFunc.resultType() +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd21beafbd85f..8f9faf67f2cf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1818,7 +1818,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val ignoreNulls = Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false) val function = UnresolvedFunction( - getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter, ignoreNulls) + getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls) // Check if the function is evaluated in a windowed context. ctx.windowSpec match { @@ -1830,7 +1830,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } - /** * Create a function database (optional) and name pair, for multipartIdentifier. * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. 
@@ -1846,18 +1845,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) } - /** - * Create a function database (optional) and name pair. - */ - private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = { - texts match { - case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) - case Seq(fn) => FunctionIdentifier(fn, None) - case other => - throw QueryParsingErrors.functionNameUnsupportedError(texts.mkString("."), ctx) - } - } - /** * Get a function identifier consist by database (optional) and name. */ @@ -1869,7 +1856,27 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } + protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = { + if (ctx.qualifiedName != null) { + ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq + } else { + Seq(ctx.getText) + } + } + /** + * Create a function database (optional) and name pair. + */ + private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = { + texts match { + case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) + case Seq(fn) => FunctionIdentifier(fn, None) + case other => + throw QueryParsingErrors.functionNameUnsupportedError(texts.mkString("."), ctx) + } + } + + /** * Create an [[LambdaFunction]]. */ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index be9b94c606196..31edb1838ef19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -83,12 +83,36 @@ private[sql] object CatalogV2Implicits { throw new AnalysisException( s"Cannot use catalog ${plugin.name}: does not support namespaces") } + + def isFunctionCatalog: Boolean = plugin match { + case _: FunctionCatalog => true + case _ => false + } + + def asFunctionCatalog: FunctionCatalog = plugin match { + case functionCatalog: FunctionCatalog => + functionCatalog + case _ => + throw new UnsupportedOperationException( + s"Cannot use catalog ${plugin.name}: not a FunctionCatalog") + } } implicit class NamespaceHelper(namespace: Array[String]) { def quoted: String = namespace.map(quoteIfNeeded).mkString(".") } + implicit class FunctionIdentifierHelper(ident: FunctionIdentifier) { + def asMultipart: Seq[String] = { + ident.database match { + case Some(db) => + Seq(db, ident.funcName) + case _ => + Seq(ident.funcName) + } + } + } + implicit class IdentifierHelper(ident: Identifier) { def quoted: String = { if (ident.namespace.nonEmpty) { @@ -132,6 +156,14 @@ private[sql] object CatalogV2Implicits { s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.") } + def asFunctionIdentifier: FunctionIdentifier = parts match { + case Seq(funcName) => FunctionIdentifier(funcName) + case Seq(dbName, funcName) => FunctionIdentifier(funcName, Some(dbName)) + case _ => + throw new AnalysisException( + s"$quoted is not a valid FunctionIdentifier as it has more than 2 name parts.") + } + def quoted: String = parts.map(quoteIfNeeded).mkString(".") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index f7e57e3b27b21..fde969108a49c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, V2InMemoryCatalog} import org.apache.spark.sql.connector.expressions.Expressions import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -134,7 +134,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { private[sql] object CreateTablePartitioningValidationSuite { val catalog: TableCatalog = { - val cat = new InMemoryTableCatalog() + val cat = new V2InMemoryCatalog() cat.initialize("test", CaseInsensitiveStringMap.empty()) cat } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala index 7d6ad3bc60902..ddd0ab983b684 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, ExternalCatalog, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, InMemoryTable, InMemoryTableCatalog, Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, InMemoryTable, Table, V2InMemoryCatalog} import org.apache.spark.sql.types._ class TableLookupCacheSuite extends AnalysisTest with Matchers { @@ -45,7 +45,7 @@ class TableLookupCacheSuite extends AnalysisTest with Matchers { CatalogStorageFormat.empty, StructType(Seq(StructField("a", IntegerType)))), ignoreIfExists = false) - val v2Catalog = new InMemoryTableCatalog { + val v2Catalog = new V2InMemoryCatalog { override def loadTable(ident: Identifier): Table = { val catalogTable = externalCatalog.getTable("default", ident.name) new InMemoryTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala index eb35dd47a508f..a8fbbb9c06d98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala @@ -113,9 +113,9 @@ class CatalogManagerSuite extends SparkFunSuite with SQLHelper { assert(v1SessionCatalog.getCurrentDatabase == "default") // Check namespace existence if currentCatalog implements 
SupportsNamespaces. - withSQLConf("spark.sql.catalog.testCatalog" -> classOf[InMemoryTableCatalog].getName) { + withSQLConf("spark.sql.catalog.testCatalog" -> classOf[V2InMemoryCatalog].getName) { catalogManager.setCurrentCatalog("testCatalog") - catalogManager.currentCatalog.asInstanceOf[InMemoryTableCatalog] + catalogManager.currentCatalog.asInstanceOf[V2InMemoryCatalog] .createNamespace(Array("test3"), Map.empty[String, String].asJava) assert(v1SessionCatalog.getCurrentDatabase == "default") catalogManager.setCurrentNamespace(Array("test3")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala similarity index 92% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index 5560bda928232..a93c9c1bb4687 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -24,14 +24,15 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} import org.apache.spark.sql.connector.expressions.LogicalExpressions import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap -class TableCatalogSuite extends SparkFunSuite { +class CatalogSuite extends SparkFunSuite { import CatalogV2Implicits._ private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] @@ -39,8 +40,8 @@ class TableCatalogSuite extends SparkFunSuite { .add("id", IntegerType) .add("data", StringType) - private def newCatalog(): TableCatalog with SupportsNamespaces = { - val newCatalog = new InMemoryTableCatalog + private def newCatalog(): V2InMemoryCatalog = { + val newCatalog = new V2InMemoryCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog } @@ -902,7 +903,7 @@ class TableCatalogSuite extends SparkFunSuite { } test("truncate partitioned table") { - val partCatalog = new InMemoryPartitionTableCatalog + val partCatalog = new V2InMemoryPartitionCatalog partCatalog.initialize("test", CaseInsensitiveStringMap.empty()) val table = partCatalog.createTable( @@ -927,4 +928,43 @@ class TableCatalogSuite extends SparkFunSuite { assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2) assert(partTable.rows.isEmpty) } + + val function: UnboundFunction = new UnboundFunction { + override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] { + override def 
inputTypes(): Array[DataType] = Array(IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "my_bound_function" + } + override def description(): String = "my_function" + override def name(): String = "my_function" + } + + test("list functions") { + val catalog = newCatalog() + val ident1 = Identifier.of(Array("ns1", "ns2"), "func1") + val ident2 = Identifier.of(Array("ns1", "ns2"), "func2") + val ident3 = Identifier.of(Array("ns1", "ns3"), "func3") + + catalog.createNamespace(Array("ns1", "ns2"), emptyProps) + catalog.createNamespace(Array("ns1", "ns3"), emptyProps) + catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident1, function) + catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident2, function) + catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident3, function) + + assert(catalog.listFunctions(Array("ns1", "ns2")).toSet === Set(ident1, ident2)) + assert(catalog.listFunctions(Array("ns1", "ns3")).toSet === Set(ident3)) + assert(catalog.listFunctions(Array("ns1")).toSet == Set()) + intercept[NoSuchNamespaceException](catalog.listFunctions(Array("ns2"))) + } + + test("lookup function") { + val catalog = newCatalog() + val ident = Identifier.of(Array("ns"), "func") + catalog.createNamespace(Array("ns"), emptyProps) + catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident, function) + + assert(catalog.loadFunction(ident) == function) + intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns"), "func1"))) + intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns1"), "func"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala index 954650ae0eebd..54724d6129fa3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTableCatalog { - import InMemoryTableCatalog._ +class StagingInMemoryCatalog extends V2InMemoryCatalog with StagingTableCatalog { + import V2InMemoryCatalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def stageCreate( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala index df2fbd6d179bb..00e9e2ff26cb5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala @@ -32,8 +32,8 @@ class SupportsAtomicPartitionManagementSuite extends SparkFunSuite { def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - private val catalog: InMemoryTableCatalog = { - val newCatalog = new InMemoryTableCatalog + private val catalog: V2InMemoryCatalog = { + val newCatalog = new V2InMemoryCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog.createTable( 
ident, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala index e5aeb90b841a6..332b0975c8a34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala @@ -34,8 +34,8 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - private val catalog: InMemoryTableCatalog = { - val newCatalog = new InMemoryTableCatalog + private val catalog: V2InMemoryCatalog = { + val newCatalog = new V2InMemoryCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog.createTable( ident, @@ -156,7 +156,7 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { } private def createMultiPartTable(): InMemoryPartitionTable = { - val partCatalog = new InMemoryPartitionTableCatalog + val partCatalog = new V2InMemoryPartitionCatalog partCatalog.initialize("test", CaseInsensitiveStringMap.empty()) val table = partCatalog.createTable( ident, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala similarity index 53% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala index 38113f9ea1902..b0bc6d054721b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala @@ -22,122 +22,14 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -class BasicInMemoryTableCatalog extends TableCatalog { - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - - protected val namespaces: util.Map[List[String], Map[String, String]] = - new ConcurrentHashMap[List[String], Map[String, String]]() - - protected val tables: util.Map[Identifier, Table] = - new ConcurrentHashMap[Identifier, Table]() - - private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() - - private var _name: Option[String] = None - - override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { - _name = Some(name) - } - - override def name: String = _name.get - - override def listTables(namespace: Array[String]): Array[Identifier] = { - tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray - } - - 
override def loadTable(ident: Identifier): Table = { - Option(tables.get(ident)) match { - case Some(table) => - table - case _ => - throw new NoSuchTableException(ident) - } - } - - override def invalidateTable(ident: Identifier): Unit = { - invalidatedTables.add(ident) - } - - override def createTable( - ident: Identifier, - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { - createTable(ident, schema, partitions, properties, Distributions.unspecified(), - Array.empty, None) - } - - def createTable( - ident: Identifier, - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String], - distribution: Distribution, - ordering: Array[SortOrder], - requiredNumPartitions: Option[Int]): Table = { - if (tables.containsKey(ident)) { - throw new TableAlreadyExistsException(ident) - } - - InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) - - val tableName = s"$name.${ident.quoted}" - val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, - ordering, requiredNumPartitions) - tables.put(ident, table) - namespaces.putIfAbsent(ident.namespace.toList, Map()) - table - } - - override def alterTable(ident: Identifier, changes: TableChange*): Table = { - val table = loadTable(ident).asInstanceOf[InMemoryTable] - val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) - val schema = CatalogV2Util.applySchemaChanges(table.schema, changes) - - // fail if the last column in the schema was dropped - if (schema.fields.isEmpty) { - throw new IllegalArgumentException(s"Cannot drop all fields") - } - - val newTable = new InMemoryTable(table.name, schema, table.partitioning, properties) - .withData(table.data) - - tables.put(ident, newTable) - - newTable - } - - override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined - - override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { - if (tables.containsKey(newIdent)) { - throw new TableAlreadyExistsException(newIdent) - } - - Option(tables.remove(oldIdent)) match { - case Some(table) => - tables.put(newIdent, table) - case _ => - throw new NoSuchTableException(oldIdent) - } - } - - def isTableInvalidated(ident: Identifier): Boolean = { - invalidatedTables.contains(ident) - } - - def clearTables(): Unit = { - tables.clear() - } -} - -class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces { +class V2InMemoryCatalog extends BasicInMemoryCatalog with SupportsNamespaces { private def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct } @@ -210,9 +102,17 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp throw new NoSuchNamespaceException(namespace) } } + + override def listFunctions(namespace: Array[String]): Array[Identifier] = { + if (namespace.isEmpty || namespaceExists(namespace)) { + super.listFunctions(namespace) + } else { + throw new NoSuchNamespaceException(namespace) + } + } } -object InMemoryTableCatalog { +object V2InMemoryCatalog { val SIMULATE_FAILED_CREATE_PROPERTY = "spark.sql.test.simulateFailedCreate" val SIMULATE_DROP_BEFORE_REPLACE_PROPERTY = "spark.sql.test.simulateDropBeforeReplace" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryPartitionCatalog.scala similarity 
index 91% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryPartitionCatalog.scala index a24f5c9a0c463..76c1010524683 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryPartitionCatalog.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType -class InMemoryPartitionTableCatalog extends InMemoryTableCatalog { +class V2InMemoryPartitionCatalog extends V2InMemoryCatalog { import CatalogV2Implicits._ override def createTable( @@ -35,7 +35,7 @@ class InMemoryPartitionTableCatalog extends InMemoryTableCatalog { throw new TableAlreadyExistsException(ident) } - InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + V2InMemoryCatalog.maybeSimulateFailedTableCreation(properties) val table = new InMemoryAtomicPartitionTable( s"$name.${ident.quoted}", schema, partitions, properties) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java index 59c5263563b27..b1ddeb1b88864 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; -import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog; +import org.apache.spark.sql.connector.catalog.V2InMemoryCatalog; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; import org.junit.After; @@ -43,7 +43,7 @@ public Dataset df() { @Before public void createTestTable() { this.spark = new TestSparkSession(); - spark.conf().set("spark.sql.catalog.testcat", InMemoryTableCatalog.class.getName()); + spark.conf().set("spark.sql.catalog.testcat", V2InMemoryCatalog.class.getName()); spark.sql("CREATE TABLE testcat.t (s string) USING foo"); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java new file mode 100644 index 0000000000000..3c98b7c520fae --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.connector.catalog.functions.AggregateFunction; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.StructType; + +import java.io.Serializable; + +public class JavaAverage implements UnboundFunction { + @Override + public String name() { + return "iavg"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + if (inputType.fields()[0].dataType() instanceof IntegerType) { + return new JavaAverageNoImpl(); + } + throw new UnsupportedOperationException("Unsupported non-integral type: " + + inputType.fields()[0].dataType()); + } + + @Override + public String description() { + return null; + } + + public static class JavaAverageNoImpl implements AggregateFunction { + @Override + public State newAggregationState() { + return new State(0, 0); + } + + @Override + public Integer produceResult(State state) { + return state.sum / state.count; + } + + @Override + public State merge(State leftState, State rightState) { + return new State(leftState.sum + rightState.sum, leftState.count + rightState.count); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.LongType }; + } + + @Override + public DataType resultType() { + return DataTypes.LongType; + } + + @Override + public String name() { + return "iavg"; + } + } + + public static class State implements Serializable { + int sum, count; + + State(int left, int count) { + this.sum = left; + this.count = count; + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java new file mode 100644 index 0000000000000..93f969358c366 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; + +public class JavaStrLen implements UnboundFunction { + private final BoundFunction fn; + + public JavaStrLen(BoundFunction fn) { + this.fn = fn; + } + + @Override + public String name() { + return "strlen"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + + if (inputType.fields()[0].dataType() instanceof StringType) { + return fn; + } + + throw new UnsupportedOperationException("Expect StringType"); + } + + @Override + public String description() { + return "strlen: returns the length of the input string\n" + + " strlen(string) -> int"; + } + + public static class JavaStrLenDefault implements ScalarFunction<Integer> { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.StringType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "strlen"; + } + + @Override + public Integer produceResult(InternalRow input) { + String str = input.getString(0); + return str.length(); + } + } + + public static class JavaStrLenMagic implements ScalarFunction<Integer> { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.StringType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "strlen"; + } + + public int invoke(UTF8String str) { + return str.toString().length(); + } + } + + public static class JavaStrLenNoImpl implements ScalarFunction<Integer> { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.StringType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "strlen"; + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index c06544ee00621..15be034bd57ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.SchemaRequiredDataSource -import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog +import org.apache.spark.sql.connector.catalog.V2InMemoryPartitionCatalog import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf @@ -870,7 +870,7 @@ class DSV2CharVarcharTestSuite extends CharVarcharTestSuite override def format: String = "foo" protected override def sparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) + .set("spark.sql.catalog.testcat",
classOf[V2InMemoryPartitionCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index 8aef27a1b6692..20c7ae5947d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, TableCatalog, V2InMemoryCatalog} import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -48,7 +48,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo private val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName()) before { - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") df.createOrReplaceTempView("source") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 2f56fbaf7f821..d11b922ea67bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.Hex -import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog +import org.apache.spark.sql.connector.catalog.V2InMemoryPartitionCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.unsafe.types.UTF8String @@ -320,7 +320,7 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession protected override def sparkConf: SparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) + .set("spark.sql.catalog.testcat", classOf[V2InMemoryPartitionCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index d83d1a2755928..8ca900c0abff4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -34,8 +34,8 @@ class DataSourceV2DataFrameSuite import testImplicits._ before { - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) - 
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat2", classOf[V2InMemoryCatalog].getName) } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala new file mode 100644 index 0000000000000..76c396f4d784a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util +import java.util.Collections + +import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen} +import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsNamespaces, V2InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + + private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { + catalog("testcat").asInstanceOf[V2InMemoryCatalog].createFunction(ident, fn) + } + + test("scalar function: with default produceResult method") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: lookup magic method") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: bad magic method") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic)) + assert(intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect()) + .getMessage.contains("Cannot find a compatible")) + } + + test("scalar function: bad magic method with default impl") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + 
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagicWithDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: no implementation found") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl)) + intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect()) + } + + test("scalar function: invalid parameter type or length") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) + + assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)")) + .getMessage.contains("cannot process input")) + assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')")) + .getMessage.contains("cannot process input")) + } + + test("scalar function: default produceResult in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: magic method in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: no implementation found in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenNoImpl)) + assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect()) + .getMessage.contains("neither implement magic method nor override 'produceResult'")) + } + + test("bad bound function (neither scalar nor aggregate)") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction)) + + assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')")) + .getMessage.contains("does not implement ScalarFunction or AggregateFunction")) + } + + test("aggregate function: lookup int average") { + import testImplicits._ + val t = "testcat.ns.t" + withTable(t) { + addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) + + (1 to 100).toDF("i").write.saveAsTable(t) + checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil) + } + } + + test("aggregate function: lookup long average") { + import testImplicits._ + val t = "testcat.ns.t" + withTable(t) { + addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) + + (1L to 100L).toDF("i").write.saveAsTable(t) + checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil) + } + } + + test("aggregate function: unsupported input type") { + import testImplicits._ + val t = "testcat.ns.t" + withTable(t) { + addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) + + Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t) + assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t")) + .getMessage.contains("Unsupported non-integral type: ShortType")) + } + } + + test("aggregate function: doesn't implement update should throw runtime error") { + import 
testImplicits._ + val t = "testcat.ns.t" + withTable(t) { + addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) + + Seq(1.toByte, 2.toByte).toDF("i").write.saveAsTable(t) + assert(intercept[SparkException](sql(s"SELECT testcat.ns.avg(i) from $t").collect()) + .getMessage.contains("Cannot find a compatible AggregateFunction")) + } + } + + test("aggregate function: doesn't implement update in Java should throw analysis error") { + import testImplicits._ + val t = "testcat.ns.t" + withTable(t) { + addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage) + + (1 to 100).toDF("i").write.saveAsTable(t) + assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t").collect()) + .getMessage.contains("neither implement magic method nor override 'update'")) + } + } + + private case class StrLen(impl: BoundFunction) extends UnboundFunction { + override def description(): String = + """strlen: returns the length of the input string + | strlen(string) -> int""".stripMargin + override def name(): String = "strlen" + + override def bind(inputType: StructType): BoundFunction = { + if (inputType.fields.length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + inputType.fields(0).dataType match { + case StringType => impl + case _ => + throw new UnsupportedOperationException("Expect StringType") + } + } + } + + private case object StrLenDefault extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "strlen_default" + + override def produceResult(input: InternalRow): Int = { + val s = input.getString(0) + s.length + } + } + + private case object StrLenMagic extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "strlen_magic" + + def invoke(input: UTF8String): Int = { + input.toString.length + } + } + + private case object StrLenBadMagic extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "strlen_bad_magic" + + def invoke(input: String): Int = { + input.length + } + } + + private case object StrLenBadMagicWithDefault extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "strlen_bad_magic" + + def invoke(input: String): Int = { + input.length + } + + override def produceResult(input: InternalRow): Int = { + val s = input.getString(0) + s.length + } + } + + private case object StrLenNoImpl extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "strlen_noimpl" + } + + private case object BadBoundFunction extends BoundFunction { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "bad_bound_func" + } + + object IntegralAverage extends UnboundFunction { + override def name(): String = "iavg" + + override def bind(inputType: StructType): BoundFunction = { + if (inputType.fields.length > 1) { + throw new UnsupportedOperationException("Too many arguments") + } + + inputType.fields(0).dataType match { + case _: ByteType => ByteAverage + 
case _: IntegerType => IntAverage + case _: LongType => LongAverage + case dataType => + throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType") + } + } + + override def description(): String = + """iavg: produces an average using integer division, ignoring nulls + | iavg(int) -> int + | iavg(bigint) -> bigint""".stripMargin + } + + object IntAverage extends AggregateFunction[(Int, Int), Int] { + override def name(): String = "iavg" + override def inputTypes(): Array[DataType] = Array(IntegerType) + override def resultType(): DataType = IntegerType + + override def newAggregationState(): (Int, Int) = (0, 0) + + override def update(state: (Int, Int), input: InternalRow): (Int, Int) = { + if (input.isNullAt(0)) { + state + } else { + val i = input.getInt(0) + state match { + case (_, 0) => + (i, 1) + case (total, count) => + (total + i, count + 1) + } + } + } + + override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = { + (leftState._1 + rightState._1, leftState._2 + rightState._2) + } + + override def produceResult(state: (Int, Int)): Int = state._1 / state._2 + } + + object LongAverage extends AggregateFunction[(Long, Long), Long] { + override def name(): String = "iavg" + override def inputTypes(): Array[DataType] = Array(LongType) + override def resultType(): DataType = LongType + + override def newAggregationState(): (Long, Long) = (0L, 0L) + + override def update(state: (Long, Long), input: InternalRow): (Long, Long) = { + if (input.isNullAt(0)) { + state + } else { + val l = input.getLong(0) + state match { + case (_, 0L) => + (l, 1) + case (total, count) => + (total + l, count + 1L) + } + } + } + + override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = { + (leftState._1 + rightState._1, leftState._2 + rightState._2) + } + + override def produceResult(state: (Long, Long)): Long = state._1 / state._2 + } + + /** Bad implementation which doesn't override `produceResult` */ + object ByteAverage extends AggregateFunction[(Long, Long), Long] { + override def name(): String = "iavg" + override def inputTypes(): Array[DataType] = Array(LongType) + override def resultType(): DataType = LongType + + override def newAggregationState(): (Long, Long) = (0L, 0L) + + override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = { + (leftState._1 + rightState._1, leftState._2 + rightState._2) + } + + override def produceResult(state: (Long, Long)): Long = state._1 / state._2 + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 13facc36876b2..9bc4a96afaccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -494,7 +494,7 @@ class DataSourceV2SQLSuite intercept[Exception] { spark.sql("REPLACE TABLE testcat.table_name" + s" USING foo" + - s" TBLPROPERTIES (`${InMemoryTableCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + + s" TBLPROPERTIES (`${V2InMemoryCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + s" AS SELECT id FROM source") } @@ -519,7 +519,7 @@ class DataSourceV2SQLSuite intercept[Exception] { spark.sql("REPLACE TABLE testcat_atomic.table_name" + s" USING foo" + - s" TBLPROPERTIES (`${InMemoryTableCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + + s" TBLPROPERTIES 
(`${V2InMemoryCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + s" AS SELECT id FROM source") } @@ -578,7 +578,7 @@ class DataSourceV2SQLSuite } test("ReplaceTableAsSelect: REPLACE TABLE throws exception if table is dropped before commit.") { - import InMemoryTableCatalog._ + import V2InMemoryCatalog._ spark.sql(s"CREATE TABLE testcat_atomic.created USING $v2Source AS SELECT id, data FROM source") intercept[CannotReplaceMissingTableException] { spark.sql(s"REPLACE TABLE testcat_atomic.replaced" + @@ -1390,7 +1390,7 @@ class DataSourceV2SQLSuite "and namespace does not exist") { // Namespaces are not required to exist for v2 catalogs // that does not implement SupportsNamespaces. - withSQLConf("spark.sql.catalog.dummy" -> classOf[BasicInMemoryTableCatalog].getName) { + withSQLConf("spark.sql.catalog.dummy" -> classOf[BasicInMemoryCatalog].getName) { val catalogManager = spark.sessionState.catalogManager sql("USE dummy.ns1") @@ -1547,7 +1547,7 @@ class DataSourceV2SQLSuite |CLUSTERED BY (`a.b`) INTO 4 BUCKETS """.stripMargin) - val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[V2InMemoryCatalog] val table = testCatalog.loadTable(Identifier.of(Array.empty, "t")) val partitioning = table.partitioning() assert(partitioning.length == 1 && partitioning.head.name() == "bucket") @@ -1614,7 +1614,7 @@ class DataSourceV2SQLSuite withTable(t) { sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[V2InMemoryCatalog] val identifier = Identifier.of(Array("ns1", "ns2"), "tbl") assert(!testCatalog.isTableInvalidated(identifier)) @@ -1630,7 +1630,7 @@ class DataSourceV2SQLSuite sql("CREATE TEMPORARY VIEW t AS SELECT 2") sql("USE testcat.ns") - val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[V2InMemoryCatalog] val identifier = Identifier.of(Array("ns"), "t") assert(!testCatalog.isTableInvalidated(identifier)) @@ -2142,7 +2142,7 @@ class DataSourceV2SQLSuite test("global temp view should not be masked by v2 catalog") { val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) - spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName) + spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[V2InMemoryCatalog].getName) try { sql("create global temp view v as select 1") @@ -2167,7 +2167,7 @@ class DataSourceV2SQLSuite test("SPARK-30104: v2 catalog named global_temp will be masked") { val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) - spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName) + spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[V2InMemoryCatalog].getName) val e = intercept[AnalysisException] { // Since the following multi-part name starts with `globalTempDB`, it is resolved to @@ -2366,7 +2366,7 @@ class DataSourceV2SQLSuite intercept[AnalysisException](sql("COMMENT ON TABLE testcat.abc IS NULL")) val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) - spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName) + spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[V2InMemoryCatalog].getName) withTempView("v") { 
sql("create global temp view v as select 1") val e = intercept[AnalysisException](sql("COMMENT ON TABLE global_temp.v IS NULL")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala index 3ef242f90f7e7..723d9148eb60d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala @@ -32,11 +32,11 @@ trait DatasourceV2SQLBase } before { - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) - spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testpart", classOf[V2InMemoryPartitionCatalog].getName) spark.conf.set( - "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName) - spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName) + "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat2", classOf[V2InMemoryCatalog].getName) spark.conf.set( V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 076dad7530807..a2566beffc7ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -55,7 +55,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with spark.conf.set( V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) spark.conf.set( - s"spark.sql.catalog.$catalogName", classOf[InMemoryTableCatalog].getName) + s"spark.sql.catalog.$catalogName", classOf[V2InMemoryCatalog].getName) } override def afterEach(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 847953e09cef7..1328f6f61d764 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -101,7 +101,7 @@ class V1ReadFallbackWithCatalogSuite extends V1ReadFallbackSuite { } } -class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { +class V1ReadFallbackCatalog extends BasicInMemoryCatalog { override def createTable( ident: Identifier, schema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index db4a9c153c0ff..6c17a7dcd347d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -44,7 +44,7 @@ class WriteDistributionAndOrderingSuite import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ before { - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + 
spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) } after { @@ -756,9 +756,9 @@ class WriteDistributionAndOrderingSuite UnresolvedAttribute(name) } - private def catalog: InMemoryTableCatalog = { + private def catalog: V2InMemoryCatalog = { val catalog = spark.sessionState.catalogManager.catalog("testcat") - catalog.asTableCatalog.asInstanceOf[InMemoryTableCatalog] + catalog.asTableCatalog.asInstanceOf[V2InMemoryCatalog] } // executes a write operation and keeps the executed physical plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala index f8366b3f7c5fa..b60b390a75ff9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.time.{Duration, Period} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils -import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.V2InMemoryCatalog import org.apache.spark.sql.execution.HiveResult._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} @@ -80,7 +80,7 @@ class HiveResultSuite extends SharedSparkSession { } test("SHOW TABLES in hive result") { - withSQLConf("spark.sql.catalog.testcat" -> classOf[InMemoryTableCatalog].getName) { + withSQLConf("spark.sql.catalog.testcat" -> classOf[V2InMemoryCatalog].getName) { Seq(("testcat.ns", "tbl", "foo"), ("spark_catalog.default", "tbl", "csv")).foreach { case (ns, tbl, source) => withTable(s"$ns.$tbl") { @@ -94,7 +94,7 @@ class HiveResultSuite extends SharedSparkSession { } test("DESCRIBE TABLE in hive result") { - withSQLConf("spark.sql.catalog.testcat" -> classOf[InMemoryTableCatalog].getName) { + withSQLConf("spark.sql.catalog.testcat" -> classOf[V2InMemoryCatalog].getName) { Seq(("testcat.ns", "tbl", "foo"), ("spark_catalog.default", "tbl", "csv")).foreach { case (ns, tbl, source) => withTable(s"$ns.$tbl") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala index ba683c049a631..5e2e6cc592727 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog +import org.apache.spark.sql.connector.catalog.V2InMemoryPartitionCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types._ @@ -151,7 +151,7 @@ class DSV2CharVarcharDDLTestSuite extends CharVarcharDDLTestBase override def format: String = "foo" protected override def sparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) + .set("spark.sql.catalog.testcat", classOf[V2InMemoryPartitionCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index bed04f4f2659b..7d2df64583e09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryPartitionTable, V2InMemoryCatalog, V2InMemoryPartitionCatalog} import org.apache.spark.sql.test.SharedSparkSession /** @@ -36,8 +36,8 @@ trait CommandSuiteBase extends SharedSparkSession { // V2 catalogs created and used especially for testing override def sparkConf: SparkConf = super.sparkConf - .set(s"spark.sql.catalog.$catalog", classOf[InMemoryPartitionTableCatalog].getName) - .set(s"spark.sql.catalog.non_part_$catalog", classOf[InMemoryTableCatalog].getName) + .set(s"spark.sql.catalog.$catalog", classOf[V2InMemoryPartitionCatalog].getName) + .set(s"spark.sql.catalog.non_part_$catalog", classOf[V2InMemoryCatalog].getName) def checkLocation( t: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala index bafb6608c8e6c..fb7ffc9967b2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.BasicInMemoryCatalog import org.apache.spark.sql.execution.command import org.apache.spark.sql.internal.SQLConf @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf */ class ShowNamespacesSuite extends command.ShowNamespacesSuiteBase with CommandSuiteBase { override def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.catalog.testcat_no_namespace", classOf[BasicInMemoryTableCatalog].getName) + .set("spark.sql.catalog.testcat_no_namespace", classOf[BasicInMemoryCatalog].getName) test("IN namespace doesn't exist") { withSQLConf(SQLConf.DEFAULT_CATALOG.key -> catalog) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index 49e5218ea3352..de67c54332c24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 
import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableSessionCatalog} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsRead, Table, TableCapability, V2TableWithV1Fallback} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, V2InMemoryCatalog, V2TableWithV1Fallback} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.streaming.{MemoryStream, MemoryStreamScanBuilder} @@ -46,8 +46,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ before { - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) - spark.conf.set("spark.sql.catalog.teststream", classOf[InMemoryStreamTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.teststream", classOf[V2InMemoryStreamCatalog].getName) } after { @@ -157,7 +157,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { test("read: fallback to V1 relation") { val tblName = DataStreamTableAPISuite.V1FallbackTestTableName spark.conf.set(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION.key, - classOf[InMemoryStreamTableCatalog].getName) + classOf[V2InMemoryStreamCatalog].getName) val v2Source = classOf[FakeV2Provider].getName withTempDir { tempDir => withTable(tblName) { @@ -439,7 +439,7 @@ class NonStreamV2Table(override val name: String) } -class InMemoryStreamTableCatalog extends InMemoryTableCatalog { +class V2InMemoryStreamCatalog extends V2InMemoryCatalog { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def createTable( From 27d5a20d30ba7a056732d08626cabb4c64b85828 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 7 Apr 2021 17:14:49 -0700 Subject: [PATCH 02/21] fix Java code style --- .../spark/sql/connector/catalog/functions/ScalarFunction.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index 49999fb1fa24d..eac060396b32e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -57,8 +57,8 @@ * } * } * - * In this case, both {@link #MAGIC_METHOD_NAME} and {@link #produceResult} are defined, and Spark will - * first lookup the {@code invoke} method during query analysis. It checks whether the method + * In this case, both {@link #MAGIC_METHOD_NAME} and {@link #produceResult} are defined, and Spark + * will first lookup the {@code invoke} method during query analysis. It checks whether the method * parameters have the valid types that are supported by Spark. If the check fails it falls back * to use {@link #produceResult}. 
* From 8fa2b7b9879b2651589a16f4476987d9cb23b322 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 7 Apr 2021 17:54:59 -0700 Subject: [PATCH 03/21] more docs for ScalarFunction --- .../catalog/functions/ScalarFunction.java | 37 +++++++++++++++++-- .../sql/catalyst/analysis/Analyzer.scala | 9 +++-- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index eac060396b32e..88781ee4f7349 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -58,9 +58,40 @@ * } * * In this case, both {@link #MAGIC_METHOD_NAME} and {@link #produceResult} are defined, and Spark - * will first lookup the {@code invoke} method during query analysis. It checks whether the method - * parameters have the valid types that are supported by Spark. If the check fails it falls back - * to use {@link #produceResult}. + * will first lookup the {@link #MAGIC_METHOD_NAME} method during query analysis. This is done by + * first converting the actual input SQL data types to their corresponding Java types following the + * mapping defined below, and then checking if there is a matching method from all the declared + * methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and the Java + * types. If no magic method is found, Spark will falls back to use {@link #produceResult}. + *

+ * The following are the mapping from {@link DataType SQL data type} to Java type through
+ * the magic method approach:
+ * <ul>
+ *   <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
+ *   <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
+ *   <li>{@link org.apache.spark.sql.types.ShortType}: {@code short}</li>
+ *   <li>{@link org.apache.spark.sql.types.IntegerType}: {@code int}</li>
+ *   <li>{@link org.apache.spark.sql.types.LongType}: {@code long}</li>
+ *   <li>{@link org.apache.spark.sql.types.FloatType}: {@code float}</li>
+ *   <li>{@link org.apache.spark.sql.types.DoubleType}: {@code double}</li>
+ *   <li>{@link org.apache.spark.sql.types.StringType}:
+ *       {@link org.apache.spark.unsafe.types.UTF8String}</li>
+ *   <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
+ *   <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
+ *   <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
+ *   <li>{@link org.apache.spark.sql.types.CalendarIntervalType}:
+ *       {@link org.apache.spark.unsafe.types.CalendarInterval}</li>
+ *   <li>{@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}</li>
+ *   <li>{@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}</li>
+ *   <li>{@link org.apache.spark.sql.types.DecimalType}:
+ *       {@link org.apache.spark.sql.types.Decimal}</li>
+ *   <li>{@link org.apache.spark.sql.types.StructType}: {@link InternalRow}</li>
+ *   <li>{@link org.apache.spark.sql.types.ArrayType}:
+ *       {@link org.apache.spark.sql.catalyst.util.ArrayData}</li>
+ *   <li>{@link org.apache.spark.sql.types.MapType}:
+ *       {@link org.apache.spark.sql.catalyst.util.MapData}</li>
+ *   <li>any other type: {@code Object}</li>
+ * </ul>
* * @param the JVM type of result values */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 61f9cb498126f..b4c5908009ebe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2022,7 +2022,7 @@ class Analyzer(override val catalogManager: CatalogManager) } } case UnresolvedFunction(NonSessionCatalogAndIdentifier(v2Catalog, ident), arguments, - isDistinct, filter, ignoreNulls) if v2Catalog.isFunctionCatalog => + isDistinct, filter, _) if v2Catalog.isFunctionCatalog => val unbound = v2Catalog.asFunctionCatalog.loadFunction(ident) val inputType = StructType(arguments.zipWithIndex.map { @@ -2033,7 +2033,7 @@ class Analyzer(override val catalogManager: CatalogManager) unbound.bind(inputType) } catch { case unsupported: UnsupportedOperationException => - failAnalysis(s"Function ${unbound.name} cannot process input: " + + failAnalysis(s"Function '${unbound.name}' cannot process input: " + s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " + unsupported.getMessage) } @@ -2047,6 +2047,9 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( scalarFunc.name(), "FILTER clause") } else { + // TODO: implement type coercion by looking at input type from the UDF. We may + // also want to check if the parameter types from the magic method match the + // input type through `BoundFunction.inputTypes`. val argClasses = inputType.fields.map(_.dataType) findMethod(scalarFunc, ScalarFunction.MAGIC_METHOD_NAME, Some(argClasses)) match { case Some(_) => @@ -2079,7 +2082,7 @@ class Analyzer(override val catalogManager: CatalogManager) s"method nor override 'update'") } case _ => - failAnalysis(s"Function ${bound.name()} does not implement ScalarFunction or " + + failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction or " + s"AggregateFunction") } From 1edca4a384d491aef8bd6ef087db61da79e7ef0d Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 7 Apr 2021 23:43:58 -0700 Subject: [PATCH 04/21] fix Java doc --- .../catalog/functions/ScalarFunction.java | 20 +++++++++---------- .../catalog/functions/JavaAverage.java | 4 ++-- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index 88781ee4f7349..dad45dc345283 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -42,18 +42,16 @@ * method approach: * *
- *   {@code
- *     public class IntegerAdd implements ScalarFunction {
- *       public int invoke(int left, int right) {
- *         return left + right;
- *       }
+ *   public class IntegerAdd implements{@code ScalarFunction<Integer>} {
+ *     public int invoke(int left, int right) {
+ *       return left + right;
+ *     }
  *
- *       @Overrides
- *       public produceResult(InternalRow input) {
- *         int left = input.getInt(0);
- *         int right = input.getInt(1);
- *         return left + right;
- *       }
+ *    {@literal @}Override
+ *     public produceResult(InternalRow input) {
+ *       int left = input.getInt(0);
+ *       int right = input.getInt(1);
+ *       return left + right;
  *     }
  *   }
  * 
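For illustration only (this sketch is not part of the patch): the IntegerAdd javadoc example above can be written in Scala against the new interfaces roughly as follows. The names IntegerAddFunc and IntegerAddImpl are made up for this sketch; the interfaces (UnboundFunction, ScalarFunction) and the magic-method lookup are the ones this change introduces, and the parameter types follow the SQL-to-Java mapping documented earlier (IntegerType maps to a primitive int).

// Illustrative sketch, assuming only the interfaces added by this patch.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StructType}

object IntegerAddFunc extends UnboundFunction {
  override def name(): String = "integer_add"
  override def description(): String = "integer_add(int, int) -> int"

  override def bind(inputType: StructType): BoundFunction = {
    if (inputType.fields.length != 2 ||
        !inputType.fields.forall(_.dataType == IntegerType)) {
      throw new UnsupportedOperationException("Expect exactly two integer arguments")
    }
    IntegerAddImpl
  }
}

object IntegerAddImpl extends ScalarFunction[Int] {
  override def name(): String = "integer_add"
  override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
  override def resultType(): DataType = IntegerType

  // Magic method: IntegerType maps to a primitive int per the table above, so the
  // analyzer can bind this via Invoke instead of going through an InternalRow.
  def invoke(left: Int, right: Int): Int = left + right

  // Fallback used when no matching magic method is found during analysis.
  override def produceResult(input: InternalRow): Int = input.getInt(0) + input.getInt(1)
}

Registered through a FunctionCatalog (for example via createFunction on the in-memory test catalog), such a function would be invoked as SELECT testcat.ns.integer_add(1, 2); the analyzer looks up the magic invoke(int, int) method first and falls back to produceResult(InternalRow) only when no matching method is found.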
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java index 3c98b7c520fae..5cc3cf58750f4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java @@ -85,8 +85,8 @@ public String name() { public static class State implements Serializable { int sum, count; - State(int left, int count) { - this.sum = left; + State(int sum, int count) { + this.sum = sum; this.count = count; } } From 1607d0e1c3d18d0f599b9dd2dee0fa411b4fa22a Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 12 Apr 2021 12:32:34 -0700 Subject: [PATCH 05/21] implement withNewChildrenInternal --- .../sql/catalyst/expressions/ApplyFunctionExpression.scala | 3 +++ .../sql/catalyst/expressions/aggregate/V2Aggregator.scala | 3 +++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala index 25fb4a0731b35..ce9c933902942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala @@ -40,4 +40,7 @@ case class ApplyFunctionExpression( function.produceResult(reusedRow) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala index 27568e6f1b119..7598e9dd6e7ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala @@ -61,5 +61,8 @@ case class V2Aggregator[BUF <: java.io.Serializable, OUT]( override def nullable: Boolean = aggrFunc.isResultNullable override def dataType: DataType = aggrFunc.resultType() + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } From 412f19112a25981ef480e27bfbe8bfda0fcb5b20 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 15 Apr 2021 22:58:50 -0700 Subject: [PATCH 06/21] address comments --- .../catalog/functions/ScalarFunction.java | 26 +- .../sql/catalyst/analysis/Analyzer.scala | 330 ++++++++++-------- .../sql/catalyst/catalog/SessionCatalog.scala | 55 ++- .../expressions/aggregate/V2Aggregator.scala | 14 +- .../sql/connector/catalog/LookupCatalog.scala | 20 ++ .../sql/errors/QueryCompilationErrors.scala | 6 - .../connector/catalog/V2InMemoryCatalog.scala | 131 +++++++ .../connector/DataSourceV2FunctionSuite.scala | 90 ++++- 8 files changed, 487 insertions(+), 185 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index dad45dc345283..86137b8e3dfcc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -23,9 +23,8 @@ /** * Interface for a function that produces a result value for each input row. *

- * For each input row, Spark will call a produceResult method that corresponds to the - * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by - * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call + * To evaluate each input row, Spark will try to first lookup and use a "magic method" (described + * below) through Java reflection. If the method is not found, Spark will call * {@link #produceResult(InternalRow)}. *

* The JVM type of result values produced by this function must be the type used by Spark's @@ -46,21 +45,14 @@ * public int invoke(int left, int right) { * return left + right; * } - * - * {@literal @}Override - * public produceResult(InternalRow input) { - * int left = input.getInt(0); - * int right = input.getInt(1); - * return left + right; - * } * } * - * In this case, both {@link #MAGIC_METHOD_NAME} and {@link #produceResult} are defined, and Spark - * will first lookup the {@link #MAGIC_METHOD_NAME} method during query analysis. This is done by - * first converting the actual input SQL data types to their corresponding Java types following the - * mapping defined below, and then checking if there is a matching method from all the declared - * methods in the UDF class, using method name (i.e., {@link #MAGIC_METHOD_NAME}) and the Java - * types. If no magic method is found, Spark will falls back to use {@link #produceResult}. + * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will first lookup it during + * query analysis. This is done by first converting the actual input SQL data types to their + * corresponding Java types following the mapping defined below, and then checking if there is a + * matching method from all the declared methods in the UDF class, using method name (i.e., + * {@link #MAGIC_METHOD_NAME}) and the Java types. If no magic method is found, Spark will falls + * back to use {@link #produceResult}. *

* The following are the mapping from {@link DataType SQL data type} to Java type through * the magic method approach: @@ -77,8 +69,6 @@ *

 *   <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
 *   <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
 *   <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
- *   <li>{@link org.apache.spark.sql.types.CalendarIntervalType}:
- *       {@link org.apache.spark.unsafe.types.CalendarInterval}</li>
 *   <li>{@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}</li>
 *   <li>{@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}</li>
  • {@link org.apache.spark.sql.types.DecimalType}: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b4c5908009ebe..28c1d98ccf849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes -import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _} +import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -46,6 +46,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction} +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -1960,8 +1961,13 @@ class Analyzer(override val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() plan.resolveExpressions { - case f @ UnresolvedFunction(NonSessionCatalogAndIdentifier(_, _), _, _, _, _) => - // no-op if this is from a v2 catalog + case f @ UnresolvedFunction(NonSessionCatalogAndIdentifier(catalog, name), _, _, _, _) => + if (!catalog.isFunctionCatalog) { + withPosition(f) { + throw new AnalysisException(s"Trying to lookup function '$name' in catalog" + + s" '${catalog.name()}', but '${catalog.name()}' is not a FunctionCatalog.") + } + } f case f: UnresolvedFunction if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f @@ -2021,150 +2027,32 @@ class Analyzer(override val catalogManager: CatalogManager) name, other.getClass.getCanonicalName) } } - case UnresolvedFunction(NonSessionCatalogAndIdentifier(v2Catalog, ident), arguments, - isDistinct, filter, _) if v2Catalog.isFunctionCatalog => - val unbound = v2Catalog.asFunctionCatalog.loadFunction(ident) - - val inputType = StructType(arguments.zipWithIndex.map { - case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) - }) - - val bound = try { - unbound.bind(inputType) - } catch { - case unsupported: UnsupportedOperationException => - failAnalysis(s"Function '${unbound.name}' cannot process input: " + - s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " + - unsupported.getMessage) - } - bound match { - case scalarFunc: ScalarFunction[_] => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "DISTINCT") - } 
else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "FILTER clause") - } else { - // TODO: implement type coercion by looking at input type from the UDF. We may - // also want to check if the parameter types from the magic method match the - // input type through `BoundFunction.inputTypes`. - val argClasses = inputType.fields.map(_.dataType) - findMethod(scalarFunc, ScalarFunction.MAGIC_METHOD_NAME, Some(argClasses)) match { - case Some(_) => - val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) - Invoke(caller, ScalarFunction.MAGIC_METHOD_NAME, scalarFunc.resultType(), - arguments, returnNullable = scalarFunc.isResultNullable) - case _ => - // TODO: handle functions defined in Scala too - in Scala, even if a - // subclass do not override the default method in parent interface defined - // in Java, the method can still be found from `getDeclaredMethod`. - findMethod(scalarFunc, "produceResult", Some(Seq(inputType))) match { - case Some(_) => - ApplyFunctionExpression(scalarFunc, arguments) - case None => - failAnalysis(s"ScalarFunction '${bound.name()}' neither implement " + - s"magic method nor override 'produceResult'") - } - } - } - case aggFunc: V2AggregateFunction[_, _] => - // due to type erasure we can't match by parameter types here, so this check will - // succeed even if the class doesn't override `update` but implements another - // method with the same name. - findMethod(aggFunc, "update") match { - case Some(_) => - val aggregator = V2Aggregator(aggFunc, arguments) - AggregateExpression(aggregator, Complete, isDistinct, filter) - case None => - failAnalysis(s"AggregateFunction '${bound.name()}' neither implement magic " + - s"method nor override 'update'") - } - case _ => - failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction or " + - s"AggregateFunction") + case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, + isDistinct, filter, ignoreNulls) => + withPosition(u) { + processFunctionExpr(v1SessionCatalog.lookupFunction(ident, arguments), + arguments, isDistinct, filter, ignoreNulls) } case u @ UnresolvedFunction(parts, arguments, isDistinct, filter, ignoreNulls) => withPosition(u) { - v1SessionCatalog.lookupFunction(parts.asFunctionIdentifier, arguments) match { - // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within - // the context of a Window clause. They do not need to be wrapped in an - // AggregateExpression. 
- case wf: AggregateWindowFunction => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "DISTINCT") - } else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "FILTER clause") - } else if (ignoreNulls) { - wf match { - case nthValue: NthValue => - nthValue.copy(ignoreNulls = ignoreNulls) - case _ => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "IGNORE NULLS") - } - } else { - wf - } - case owf: FrameLessOffsetWindowFunction => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - owf.prettyName, "DISTINCT") - } else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - owf.prettyName, "FILTER clause") - } else if (ignoreNulls) { - owf match { - case lead: Lead => - lead.copy(ignoreNulls = ignoreNulls) - case lag: Lag => - lag.copy(ignoreNulls = ignoreNulls) - } - } else { - owf - } - // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg: AggregateFunction => - if (filter.isDefined && !filter.get.deterministic) { - throw QueryCompilationErrors.nonDeterministicFilterInAggregateError - } - if (ignoreNulls) { - val aggFunc = agg match { - case first: First => first.copy(ignoreNulls = ignoreNulls) - case last: Last => last.copy(ignoreNulls = ignoreNulls) - case _ => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - agg.prettyName, "IGNORE NULLS") - } - AggregateExpression(aggFunc, Complete, isDistinct, filter) - } else { - AggregateExpression(agg, Complete, isDistinct, filter) - } - // This function is not an aggregate function, just return the resolved one. - case other if isDistinct => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - other.prettyName, "DISTINCT") - case other if filter.isDefined => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - other.prettyName, "FILTER clause") - case other if ignoreNulls => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - other.prettyName, "IGNORE NULLS") - case e: String2TrimExpression if arguments.size == 2 => - if (trimWarningEnabled.get) { - log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + - " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM str)`" + - " instead.") - trimWarningEnabled.set(false) - } - e - case other => - other + // resolve built-in or temporary functions with v2 catalog + val resultExpression = if (parts.length == 1) { + v1SessionCatalog.lookupBuiltinOrTempFunction(parts.head, arguments).map( + processFunctionExpr(_, arguments, isDistinct, filter, ignoreNulls) + ) + } else { + None } + + resultExpression.getOrElse( + expandRelationName(parts) match { + case NonSessionCatalogAndIdentifier(catalog: FunctionCatalog, ident) => + lookupV2Function(catalog, ident, arguments, isDistinct, filter, ignoreNulls) + case _ => u + } + ) } } } @@ -2191,6 +2079,166 @@ class Analyzer(override val catalogManager: CatalogManager) cls.getDeclaredMethods.find(_.getName == methodName) } } + + private def processFunctionExpr( + expr: Expression, + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression], + ignoreNulls: Boolean): Expression = expr match { + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. 
+ case wf: AggregateWindowFunction => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, "FILTER clause") + } else if (ignoreNulls) { + wf match { + case nthValue: NthValue => + nthValue.copy(ignoreNulls = ignoreNulls) + case _ => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, "IGNORE NULLS") + } + } else { + wf + } + case owf: FrameLessOffsetWindowFunction => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + owf.prettyName, "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + owf.prettyName, "FILTER clause") + } else if (ignoreNulls) { + owf match { + case lead: Lead => + lead.copy(ignoreNulls = ignoreNulls) + case lag: Lag => + lag.copy(ignoreNulls = ignoreNulls) + } + } else { + owf + } + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => + if (filter.isDefined && !filter.get.deterministic) { + throw QueryCompilationErrors.nonDeterministicFilterInAggregateError + } + if (ignoreNulls) { + val aggFunc = agg match { + case first: First => first.copy(ignoreNulls = ignoreNulls) + case last: Last => last.copy(ignoreNulls = ignoreNulls) + case _ => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + agg.prettyName, "IGNORE NULLS") + } + AggregateExpression(aggFunc, Complete, isDistinct, filter) + } else { + AggregateExpression(agg, Complete, isDistinct, filter) + } + // This function is not an aggregate function, just return the resolved one. + case other if isDistinct => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + other.prettyName, "DISTINCT") + case other if filter.isDefined => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + other.prettyName, "FILTER clause") + case other if ignoreNulls => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + other.prettyName, "IGNORE NULLS") + case e: String2TrimExpression if arguments.size == 2 => + if (trimWarningEnabled.get) { + log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + + " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? 
trimStr FROM str)`" + + " instead.") + trimWarningEnabled.set(false) + } + e + case other => + other + } + + private def lookupV2Function( + catalog: FunctionCatalog, + ident: Identifier, + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression], + ignoreNulls: Boolean): Expression = { + val unbound = catalog.loadFunction(ident) + val inputType = StructType(arguments.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + val bound = try { + unbound.bind(inputType) + } catch { + case unsupported: UnsupportedOperationException => + failAnalysis(s"Function '${unbound.name}' cannot process input: " + + s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " + + unsupported.getMessage) + } + + bound match { + case scalarFunc: ScalarFunction[_] => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "FILTER clause") + } else if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "IGNORE NULLS") + } else { + // TODO: implement type coercion by looking at input type from the UDF. We may + // also want to check if the parameter types from the magic method match the + // input type through `BoundFunction.inputTypes`. + val argClasses = inputType.fields.map(_.dataType) + findMethod(scalarFunc, MAGIC_METHOD_NAME, Some(argClasses)) match { + case Some(_) => + val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) + Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), + arguments, returnNullable = scalarFunc.isResultNullable) + case _ => + // TODO: handle functions defined in Scala too - in Scala, even if a + // subclass do not override the default method in parent interface + // defined in Java, the method can still be found from + // `getDeclaredMethod`. + findMethod(scalarFunc, "produceResult", Some(Seq(inputType))) match { + case Some(_) => + ApplyFunctionExpression(scalarFunc, arguments) + case None => + failAnalysis(s"ScalarFunction '${bound.name()}' neither implement " + + s"magic method nor override 'produceResult'") + } + } + } + case aggFunc: V2AggregateFunction[_, _] => + if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + aggFunc.name(), "IGNORE NULLS") + } + // due to type erasure we can't match by parameter types here, so this check + // will succeed even if the class doesn't override `update` but implements + // another method with the same name. 
+ findMethod(aggFunc, "update") match { + case Some(_) => + val aggregator = V2Aggregator(aggFunc, arguments) + AggregateExpression(aggregator, Complete, isDistinct, filter) + case None => + failAnalysis(s"AggregateFunction '${bound.name()}' neither implement " + + s"magic method nor override 'update'") + } + case _ => + failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction " + + s"or AggregateFunction") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index d259d6a706d72..edd42486dff5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1559,6 +1559,34 @@ class SessionCatalog( } } + /** + * Lookup `registry` and check if a built-in or temporary function is defined for the input + * `name`. None if no such function exists. + * + * This is currently used by both V1 function lookup (in `lookupFunction`), and V2 + * function lookup (in `Analyzer`). + */ + private def lookupBuiltinOrTempFunctionInfo[T]( + name: String, + children: Seq[Expression], + registry: FunctionRegistryBase[T]): Option[T] = synchronized { + val ident = FunctionIdentifier(name) + if (registry.functionExists(ident)) { + val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames + val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty + // Lookup the function as a temporary or a built-in function (i.e. without database) and + // 1. if we are not resolving view, we don't care about the function type and just return it. + // 2. if we are resolving view, only return a temp function if it's referred by this view. + if (!isResolvingView || + !isTemporaryFunction(ident) || + referredTempFunctionNames.contains(ident.funcName)) { + // This function has been already loaded into the function registry. + return Some(registry.lookupFunction(ident, children)) + } + } + None + } + /** * Look up a specific function, assuming it exists. * @@ -1576,20 +1604,15 @@ class SessionCatalog( name: FunctionIdentifier, children: Seq[Expression], registry: FunctionRegistryBase[T]): T = synchronized { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). - if (name.database.isEmpty && registry.functionExists(name)) { - val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames - val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty - // Lookup the function as a temporary or a built-in function (i.e. without database) and - // 1. if we are not resolving view, we don't care about the function type and just return it. - // 2. if we are resolving view, only return a temp function if it's referred by this view. - if (!isResolvingView || - !isTemporaryFunction(name) || - referredTempFunctionNames.contains(name.funcName)) { - // This function has been already loaded into the function registry. 
- return registry.lookupFunction(name, children) + if (name.database.isEmpty) { + val funcInfo = lookupBuiltinOrTempFunctionInfo(name.funcName, children, registry) + if (funcInfo.isDefined) { + return funcInfo.get } } @@ -1598,7 +1621,8 @@ class SessionCatalog( case Seq() => getCurrentDatabase case Seq(_, db) => db case Seq(catalog, namespace @ _*) => - throw QueryCompilationErrors.v2CatalogNotSupportFunctionError(catalog, namespace) + throw new IllegalStateException(s"[BUG] unexpected v2 catalog: $catalog, and " + + s"namespace: ${namespace.quoted} in v1 function lookup") } // If the name itself is not qualified, add the current database to it. @@ -1645,6 +1669,13 @@ class SessionCatalog( lookupFunction[LogicalPlan](name, children, tableFunctionRegistry) } + /** + * Return a optional [[Expression]] for the input built-in or temporary function with name + * `name`. None if the function doesn't exist. + */ + def lookupBuiltinOrTempFunction(name: String, children: Seq[Expression]): Option[Expression] = { + lookupBuiltinOrTempFunctionInfo[Expression](name, children, functionRegistry) + } /** * List all functions in the specified database, including temporary functions. This * returns the function identifier and the scope in which it was defined (system or user diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala index 7598e9dd6e7ad..55e3f504ae2e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction} import org.apache.spark.sql.types.DataType @@ -29,9 +29,15 @@ case class V2Aggregator[BUF <: java.io.Serializable, OUT]( children: Seq[Expression], mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] { + private[this] lazy val inputProjection = UnsafeProjection.create(children) + + override def nullable: Boolean = aggrFunc.isResultNullable + override def dataType: DataType = aggrFunc.resultType() override def createAggregationBuffer(): BUF = aggrFunc.newAggregationState() - override def update(buffer: BUF, input: InternalRow): BUF = aggrFunc.update(buffer, input) + override def update(buffer: BUF, input: InternalRow): BUF = { + aggrFunc.update(buffer, inputProjection(input)) + } override def merge(buffer: BUF, input: BUF): BUF = aggrFunc.merge(buffer, input) @@ -58,10 +64,6 @@ case class V2Aggregator[BUF <: java.io.Serializable, OUT]( def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): V2Aggregator[BUF, OUT] = copy(inputAggBufferOffset = newInputAggBufferOffset) - override def nullable: Boolean = aggrFunc.isResultNullable - - override def dataType: DataType = aggrFunc.resultType() - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index af951a0e7aa66..d0d226e27f925 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -154,6 +154,26 @@ private[sql] trait LookupCatalog extends Logging { } } + object AsFunctionIdentifier { + def unapply(parts: Seq[String]): Option[FunctionIdentifier] = { + def namesToFunctionIdentifier(names: Seq[String]): Option[FunctionIdentifier] = names match { + case Seq(name) => Some(FunctionIdentifier(name)) + case Seq(database, name) => Some(FunctionIdentifier(name, Some(database))) + case _ => None + } + parts match { + case CatalogAndMultipartIdentifier(None, names) + if CatalogV2Util.isSessionCatalog(currentCatalog) => + namesToFunctionIdentifier(names) + case CatalogAndMultipartIdentifier(Some(catalog), names) + if CatalogV2Util.isSessionCatalog(catalog) && + CatalogV2Util.isSessionCatalog(currentCatalog) => + namesToFunctionIdentifier(names) + case _ => None + } + } + } + def parseSessionCatalogFunctionIdentifier(nameParts: Seq[String]): FunctionIdentifier = { if (nameParts.length == 1 && catalogManager.v1SessionCatalog.isTempFunction(nameParts.head)) { return FunctionIdentifier(nameParts.head) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index a3fbe4c742160..cf065125c4f0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -617,12 +617,6 @@ private[spark] object QueryCompilationErrors { s"the function '$func', please make sure it is on the classpath") } - def v2CatalogNotSupportFunctionError( - catalog: String, namespace: Seq[String]): Throwable = { - new AnalysisException("V2 catalog does not support functions yet. 
" + - s"catalog: $catalog, namespace: '${namespace.quoted}'") - } - def resourceTypeNotSupportedError(resourceType: String): Throwable = { new AnalysisException(s"Resource Type '$resourceType' is not supported.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala index b0bc6d054721b..68b210ef36cc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala @@ -29,6 +29,137 @@ import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +class BasicInMemoryTableCatalog extends TableCatalog { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + protected val namespaces: util.Map[List[String], Map[String, String]] = + new ConcurrentHashMap[List[String], Map[String, String]]() + + protected val tables: util.Map[Identifier, Table] = + new ConcurrentHashMap[Identifier, Table]() + + protected val functions: util.Map[Identifier, UnboundFunction] = + new ConcurrentHashMap[Identifier, UnboundFunction]() + + private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() + + private var _name: Option[String] = None + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + _name = Some(name) + } + + override def name: String = _name.get + + override def listTables(namespace: Array[String]): Array[Identifier] = { + tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } + + override def loadTable(ident: Identifier): Table = { + Option(tables.get(ident)) match { + case Some(table) => + table + case _ => + throw new NoSuchTableException(ident) + } + } + + override def invalidateTable(ident: Identifier): Unit = { + invalidatedTables.add(ident) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + createTable(ident, schema, partitions, properties, Distributions.unspecified(), + Array.empty, None) + } + + def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + distribution: Distribution, + ordering: Array[SortOrder], + requiredNumPartitions: Option[Int]): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + V2InMemoryCatalog.maybeSimulateFailedTableCreation(properties) + + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, + ordering, requiredNumPartitions) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + val table = loadTable(ident).asInstanceOf[InMemoryTable] + val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) + val schema = CatalogV2Util.applySchemaChanges(table.schema, changes) + + // fail if the last column in the schema was dropped + if (schema.fields.isEmpty) { + throw new IllegalArgumentException(s"Cannot drop all fields") + } + + val newTable = new InMemoryTable(table.name, schema, table.partitioning, properties) + .withData(table.data) + + 
tables.put(ident, newTable) + + newTable + } + + override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + if (tables.containsKey(newIdent)) { + throw new TableAlreadyExistsException(newIdent) + } + + Option(tables.remove(oldIdent)) match { + case Some(table) => + tables.put(newIdent, table) + case _ => + throw new NoSuchTableException(oldIdent) + } + } + + def isTableInvalidated(ident: Identifier): Boolean = { + invalidatedTables.contains(ident) + } + + def clearTables(): Unit = { + tables.clear() + } +} + +class BasicInMemoryCatalog extends BasicInMemoryTableCatalog with FunctionCatalog { + override def listFunctions(namespace: Array[String]): Array[Identifier] = { + functions.keySet().asScala.filter(_.namespace().sameElements(namespace)).toArray + } + + override def loadFunction(ident: Identifier): UnboundFunction = { + Option(functions.get(ident)) match { + case Some(func) => + func + case _ => + throw new NoSuchFunctionException(ident) + } + } + + def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = { + functions.put(ident, fn) + } +} + class V2InMemoryCatalog extends BasicInMemoryCatalog with SupportsNamespaces { private def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 76c396f4d784a..e36c0c59f02f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{Identifier, SupportsNamespaces, V2InMemoryCatalog} import org.apache.spark.sql.connector.catalog.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -38,18 +39,83 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { catalog("testcat").asInstanceOf[V2InMemoryCatalog].createFunction(ident, fn) } + test("undefined function") { + assert(intercept[AnalysisException]( + sql("SELECT testcat.non_exist('abc')").collect() + ).getMessage.contains("Undefined function")) + } + + test("non-function catalog") { + spark.conf.set("spark.sql.catalog.testcat", classOf[BasicInMemoryTableCatalog].getName) + assert(intercept[AnalysisException]( + sql("SELECT testcat.strlen('abc')").collect() + ).getMessage.contains("is not a FunctionCatalog")) + } + + test("built-in with default v2 function catalog") { + spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + + test("built-in override with default v2 function catalog") { + // a built-in function with the same name should take higher priority + spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + + test("temp function override with default v2 function catalog") { + val className = "test.org.apache.spark.sql.JavaStringLength" + sql(s"CREATE FUNCTION length AS '$className'") + + 
spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + + test("view should use captured catalog and namespace for function lookup") { + val viewName = "my_view" + withView(viewName) { + spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage) + sql("USE ns") + sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT my_avg(col1) FROM values (1), (2), (3)") + + // change default catalog and namespace and add a function with the same name but with no + // implementation + spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat2") + catalog("testcat2").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns2"), emptyProps) + addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage) + sql("USE ns2") + checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil) + } + } + test("scalar function: with default produceResult method") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) } + test("scalar function: with default produceResult method w/ expression") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil) + } + test("scalar function: lookup magic method") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic)) checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) } + test("scalar function: lookup magic method w/ expression") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil) + } + test("scalar function: bad magic method") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic)) @@ -74,9 +140,9 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)")) - .getMessage.contains("cannot process input")) + .getMessage.contains("Expect StringType")) assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')")) - .getMessage.contains("cannot process input")) + .getMessage.contains("Expect exactly one argument")) } test("scalar function: default produceResult in Java") { @@ -131,6 +197,17 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } } + test("aggregate function: lookup int average w/ expression") { + import testImplicits._ + val t = "testcat.ns.t" + withTable(t) { + addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) + + (1 to 100).toDF("i").write.saveAsTable(t) + checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil) + } + } + test("aggregate function: unsupported input 
type") { import testImplicits._ val t = "testcat.ns.t" @@ -336,4 +413,13 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { override def produceResult(state: (Long, Long)): Long = state._1 / state._2 } + + object NoImplAverage extends UnboundFunction { + override def name(): String = "no_impl_avg" + override def description(): String = name() + + override def bind(inputType: StructType): BoundFunction = { + throw new UnsupportedOperationException(s"Not implemented") + } + } } From f4a3f327179c865387497c7b433d95eb5eba21b1 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 20 Apr 2021 10:02:36 -0700 Subject: [PATCH 07/21] minor changes --- .../catalog/functions/ScalarFunction.java | 20 +++++++++---------- .../sql/catalyst/analysis/Analyzer.scala | 14 ++++++------- .../sql/catalyst/catalog/SessionCatalog.scala | 4 ++-- .../catalog/CatalogV2Implicits.scala | 2 +- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index 86137b8e3dfcc..a1526a46899d2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -23,9 +23,9 @@ /** * Interface for a function that produces a result value for each input row. *

- * To evaluate each input row, Spark will try to first lookup and use a "magic method" (described
+ * To evaluate each input row, Spark will first try to look up and use a "magic method" (described
 * below) through Java reflection. If the method is not found, Spark will call
- * {@link #produceResult(InternalRow)}.
+ * {@link #produceResult(InternalRow)} as a fallback.
 *

 * The JVM type of result values produced by this function must be the type used by Spark's
 * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
@@ -34,8 +34,8 @@
 * {@link UnsupportedOperationException}. Users can choose to override this method, or implement
 * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
 * instead of a {@link InternalRow}. The magic method will be loaded by Spark through Java
- * reflection and also will provide better performance in general, due to optimizations such as
- * codegen, Java boxing and so on.
+ * reflection and will also provide better performance in general, due to optimizations such as
+ * codegen, removal of Java boxing, etc.
 *
 * For example, a scalar UDF for adding two integers can be defined as follow with the magic
 * method approach:
@@ -47,12 +47,12 @@
 *   }
 * }
 *
- * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will first lookup it during
- * query analysis. This is done by first converting the actual input SQL data types to their
- * corresponding Java types following the mapping defined below, and then checking if there is a
- * matching method from all the declared methods in the UDF class, using method name (i.e.,
- * {@link #MAGIC_METHOD_NAME}) and the Java types. If no magic method is found, Spark will falls
- * back to use {@link #produceResult}.
+ * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
+ * {@link #produceResult} to evaluate the inputs. In general, Spark looks up the magic method by
+ * first converting the actual input SQL data types to their corresponding Java types following
+ * the mapping defined below, and then checking if there is a matching method from all the
+ * declared methods in the UDF class, using the method name (i.e., {@link #MAGIC_METHOD_NAME}) and
+ * the Java types. If no magic method is found, Spark will fall back to use {@link #produceResult}.
 *
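(Editor's illustration, not part of this patch: a minimal sketch of a scalar UDF that provides both the magic method and the row-based fallback, assuming the ScalarFunction API described above and that MAGIC_METHOD_NAME is "invoke"; the class name IntegerAdd and function name int_add are hypothetical.)

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical scalar UDF: adds two integers.
public class IntegerAdd implements ScalarFunction<Integer> {
  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.IntegerType;
  }

  @Override
  public String name() {
    return "int_add";
  }

  // Magic method: found via reflection using the Java types mapped from the SQL input types,
  // so the analyzer can call it directly through an Invoke expression.
  public int invoke(int left, int right) {
    return left + right;
  }

  // Row-based fallback used when no matching magic method is found.
  @Override
  public Integer produceResult(InternalRow input) {
    return input.getInt(0) + input.getInt(1);
  }
}

With such a class, resolution would bind the function and call invoke(int, int) on the unpacked arguments rather than packing them into an InternalRow, which is where the performance benefit of the magic method comes from.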

    * The following are the mapping from {@link DataType SQL data type} to Java type through * the magic method approach: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 28c1d98ccf849..b3c3fa82ea889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -898,9 +898,9 @@ class Analyzer(override val catalogManager: CatalogManager) } } - // If we are resolving relations insides views, we need to expand single-part relation names with - // the current catalog and namespace of when the view was created. - private def expandRelationName(nameParts: Seq[String]): Seq[String] = { + // If we are resolving relations insides views, we may need to expand single or multi-part + // identifiers with the current catalog and namespace of when the view was created. + private def expandIdentifier(nameParts: Seq[String]): Seq[String] = { if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts if (nameParts.length == 1) { @@ -1043,7 +1043,7 @@ class Analyzer(override val catalogManager: CatalogManager) identifier: Seq[String], options: CaseInsensitiveStringMap, isStreaming: Boolean): Option[LogicalPlan] = - expandRelationName(identifier) match { + expandIdentifier(identifier) match { case NonSessionCatalogAndIdentifier(catalog, ident) => CatalogV2Util.loadTable(catalog, ident) match { case Some(table) => @@ -1156,7 +1156,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = { - expandRelationName(identifier) match { + expandIdentifier(identifier) match { case SessionCatalogAndIdentifier(catalog, ident) => CatalogV2Util.loadTable(catalog, ident).map { case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW => @@ -1176,7 +1176,7 @@ class Analyzer(override val catalogManager: CatalogManager) identifier: Seq[String], options: CaseInsensitiveStringMap, isStreaming: Boolean): Option[LogicalPlan] = { - expandRelationName(identifier) match { + expandIdentifier(identifier) match { case SessionCatalogAndIdentifier(catalog, ident) => lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map { case v1Table: V1Table => @@ -2047,7 +2047,7 @@ class Analyzer(override val catalogManager: CatalogManager) } resultExpression.getOrElse( - expandRelationName(parts) match { + expandIdentifier(parts) match { case NonSessionCatalogAndIdentifier(catalog: FunctionCatalog, ident) => lookupV2Function(catalog, ident, arguments, isDistinct, filter, ignoreNulls) case _ => u diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index edd42486dff5d..94bd540613ce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1670,8 +1670,8 @@ class SessionCatalog( } /** - * Return a optional [[Expression]] for the input built-in or temporary function with name - * `name`. None if the function doesn't exist. + * Looks up a built-in or temporary function with the given `name`. Returns `None` if the + * function doesn't exist. 
*/ def lookupBuiltinOrTempFunction(name: String, children: Seq[Expression]): Option[Expression] = { lookupBuiltinOrTempFunctionInfo[Expression](name, children, functionRegistry) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 31edb1838ef19..56a11cd60a0ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -94,7 +94,7 @@ private[sql] object CatalogV2Implicits { functionCatalog case _ => throw new UnsupportedOperationException( - s"Cannot use catalog ${plugin.name}: not a FunctionCatalog") + s"Cannot use catalog '${plugin.name}': not a FunctionCatalog") } } From 1b94e655783b99ac795863cdab5552f5b7886f83 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 00:06:54 -0700 Subject: [PATCH 08/21] address more comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 +++++++++------- .../catalyst/analysis/higherOrderFunctions.scala | 10 +++++----- .../spark/sql/catalyst/parser/AstBuilder.scala | 14 +++++++------- .../connector/catalog/functions/JavaAverage.java | 4 ++-- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b3c3fa82ea889..1bf2d2fbccb85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -284,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables :: - ResolveHigherOrderFunctions(v1SessionCatalog) :: + ResolveHigherOrderFunctions(catalogManager) :: ResolveLambdaVariables :: ResolveTimeZone :: ResolveRandomSeed :: @@ -898,8 +898,9 @@ class Analyzer(override val catalogManager: CatalogManager) } } - // If we are resolving relations insides views, we may need to expand single or multi-part - // identifiers with the current catalog and namespace of when the view was created. + // If we are resolving database objects (relations, functions, etc.) insides views, we may need to + // expand single or multi-part identifiers with the current catalog and namespace of when the + // view was created. 
private def expandIdentifier(nameParts: Seq[String]): Seq[String] = { if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts @@ -2029,8 +2030,7 @@ class Analyzer(override val catalogManager: CatalogManager) } case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, - isDistinct, filter, ignoreNulls) => - withPosition(u) { + isDistinct, filter, ignoreNulls) => withPosition(u) { processFunctionExpr(v1SessionCatalog.lookupFunction(ident, arguments), arguments, isDistinct, filter, ignoreNulls) } @@ -2178,9 +2178,9 @@ class Analyzer(override val catalogManager: CatalogManager) unbound.bind(inputType) } catch { case unsupported: UnsupportedOperationException => - failAnalysis(s"Function '${unbound.name}' cannot process input: " + + throw new AnalysisException(s"Function '${unbound.name}' cannot process input: " + s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " + - unsupported.getMessage) + unsupported.getMessage, cause = Some(unsupported)) } bound match { @@ -2209,6 +2209,8 @@ class Analyzer(override val catalogManager: CatalogManager) // subclass do not override the default method in parent interface // defined in Java, the method can still be found from // `getDeclaredMethod`. + // since `inputType` is a `StructType`, it is mapped to a `InternalRow` which we + // can use to lookup the `produceResult` method. findMethod(scalarFunc, "produceResult", Some(Seq(inputType))) match { case Some(_) => ApplyFunctionExpression(scalarFunc, arguments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 3d6b4e97c8acd..b86d5d75a4550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType @@ -31,13 +30,14 @@ import org.apache.spark.sql.types.DataType * so we need to resolve higher order function when all children are either resolved or a lambda * function. 
*/ -case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { +case class ResolveHigherOrderFunctions(catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case u @ UnresolvedFunction(parts, children, false, filter, ignoreNulls) + case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls) if hasLambdaAndResolvedArguments(children) => withPosition(u) { - catalog.lookupFunction(parts.asFunctionIdentifier, children) match { + catalogManager.v1SessionCatalog.lookupFunction(ident, children) match { case func: HigherOrderFunction => filter.foreach(_.failAnalysis("FILTER predicate specified, " + s"but ${func.prettyName} is not an aggregate function")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8f9faf67f2cf4..76ef7aeefed0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1831,18 +1831,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a function database (optional) and name pair, for multipartIdentifier. - * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. + * Create a function database (optional) and name pair. */ - protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = { - visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq) + protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { + visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) } /** - * Create a function database (optional) and name pair. + * Create a function database (optional) and name pair, for multipartIdentifier. + * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. 
*/ - protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { - visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) + protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = { + visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java index 5cc3cf58750f4..041a6b61d694c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java @@ -39,7 +39,7 @@ public BoundFunction bind(StructType inputType) { throw new UnsupportedOperationException("Expect exactly one argument"); } if (inputType.fields()[0].dataType() instanceof IntegerType) { - return new JavaAverageNoImpl(); + return new BoundJavaAverageNoImpl(); } throw new UnsupportedOperationException("Unsupported non-integral type: " + inputType.fields()[0].dataType()); @@ -50,7 +50,7 @@ public String description() { return null; } - public static class JavaAverageNoImpl implements AggregateFunction { + public static class BoundJavaAverageNoImpl implements AggregateFunction { @Override public State newAggregationState() { return new State(0, 0); From 310cdca637a53abfd0a0eea8299b8731eb307d92 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 00:30:57 -0700 Subject: [PATCH 09/21] remove magic method from AggregateFunction --- .../catalog/functions/AggregateFunction.java | 13 ++--- .../sql/catalyst/analysis/Analyzer.scala | 40 +++++---------- .../sql/connector/catalog/LookupCatalog.scala | 3 +- .../catalog/functions/JavaAverage.java | 41 +++++++++------ .../connector/DataSourceV2FunctionSuite.scala | 51 ++++--------------- 5 files changed, 54 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java index 6982ebb329ff3..d4af2f547d300 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java @@ -25,12 +25,9 @@ /** * Interface for a function that produces a result value by aggregating over multiple input rows. *

    - * For each input row, Spark will call an update method that corresponds to the - * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by - * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call - * update with {@link InternalRow}. - *

    - * The JVM type of result values produced by this function must be the type used by Spark's + * For each input row, Spark will call the {@link #update} method which should evaluate the row + * and update the aggregation state. The JVM type of result values produced by + * {@link #produceResult} must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. *
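(Editor's illustration, not part of this patch: a minimal sketch of an aggregate function following the update/merge/produceResult contract described above; the class name LongSum and function name long_sum are hypothetical, and the state is kept Serializable to satisfy the bound used by V2Aggregator.)

import java.io.Serializable;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical aggregate: sums BIGINT values, skipping NULL inputs.
public class LongSum implements AggregateFunction<LongSum.State, Long> {

  // Partial aggregation state; must be serializable so Spark can ship it between executors.
  public static class State implements Serializable {
    long sum = 0L;
  }

  @Override
  public State newAggregationState() {
    return new State();
  }

  @Override
  public State update(State state, InternalRow input) {
    if (!input.isNullAt(0)) {
      state.sum += input.getLong(0);
    }
    return state;
  }

  @Override
  public State merge(State left, State right) {
    left.sum += right.sum;
    return left;
  }

  @Override
  public Long produceResult(State state) {
    return state.sum;
  }

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.LongType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.LongType;
  }

  @Override
  public String name() {
    return "long_sum";
  }
}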

    * All implementations must support partial aggregation by implementing merge so that Spark can @@ -68,9 +65,7 @@ public interface AggregateFunction extends BoundFunct * @param input an input row * @return updated aggregation state */ - default S update(S state, InternalRow input) { - throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update"); - } + S update(S state, InternalRow input); /** * Merge two partial aggregation states. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1bf2d2fbccb85..42e3679a0d4be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2058,25 +2058,20 @@ class Analyzer(override val catalogManager: CatalogManager) } /** - * Check if the input `fn` implements the given `methodName`. If `inputType` is set, it also - * tries to match it against the declared parameter types. + * Check if the input `fn` implements the given `methodName` with parameter types specified + * via `inputType`. */ private def findMethod( fn: BoundFunction, methodName: String, - inputTypeOpt: Option[Seq[DataType]] = None): Option[Method] = { + inputType: Seq[DataType]): Option[Method] = { val cls = fn.getClass - inputTypeOpt match { - case Some(inputType) => - try { - val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass) - Some(cls.getDeclaredMethod(methodName, argClasses: _*)) - } catch { - case _: NoSuchMethodException => - None - } - case None => - cls.getDeclaredMethods.find(_.getName == methodName) + try { + val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass) + Some(cls.getDeclaredMethod(methodName, argClasses: _*)) + } catch { + case _: NoSuchMethodException => + None } } @@ -2199,7 +2194,7 @@ class Analyzer(override val catalogManager: CatalogManager) // also want to check if the parameter types from the magic method match the // input type through `BoundFunction.inputTypes`. val argClasses = inputType.fields.map(_.dataType) - findMethod(scalarFunc, MAGIC_METHOD_NAME, Some(argClasses)) match { + findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { case Some(_) => val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), @@ -2211,7 +2206,7 @@ class Analyzer(override val catalogManager: CatalogManager) // `getDeclaredMethod`. // since `inputType` is a `StructType`, it is mapped to a `InternalRow` which we // can use to lookup the `produceResult` method. - findMethod(scalarFunc, "produceResult", Some(Seq(inputType))) match { + findMethod(scalarFunc, "produceResult", Seq(inputType)) match { case Some(_) => ApplyFunctionExpression(scalarFunc, arguments) case None => @@ -2225,17 +2220,8 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( aggFunc.name(), "IGNORE NULLS") } - // due to type erasure we can't match by parameter types here, so this check - // will succeed even if the class doesn't override `update` but implements - // another method with the same name. 
- findMethod(aggFunc, "update") match { - case Some(_) => - val aggregator = V2Aggregator(aggFunc, arguments) - AggregateExpression(aggregator, Complete, isDistinct, filter) - case None => - failAnalysis(s"AggregateFunction '${bound.name()}' neither implement " + - s"magic method nor override 'update'") - } + val aggregator = V2Aggregator(aggFunc, arguments) + AggregateExpression(aggregator, Complete, isDistinct, filter) case _ => failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction " + s"or AggregateFunction") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index d0d226e27f925..461d5864cf1d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -166,8 +166,7 @@ private[sql] trait LookupCatalog extends Logging { if CatalogV2Util.isSessionCatalog(currentCatalog) => namesToFunctionIdentifier(names) case CatalogAndMultipartIdentifier(Some(catalog), names) - if CatalogV2Util.isSessionCatalog(catalog) && - CatalogV2Util.isSessionCatalog(currentCatalog) => + if CatalogV2Util.isSessionCatalog(catalog) => namesToFunctionIdentifier(names) case _ => None } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java index 041a6b61d694c..4e783fdd439b6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java @@ -17,12 +17,13 @@ package test.org.apache.spark.sql.connector.catalog.functions; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.functions.AggregateFunction; import org.apache.spark.sql.connector.catalog.functions.BoundFunction; import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.DoubleType; import org.apache.spark.sql.types.StructType; import java.io.Serializable; @@ -30,7 +31,7 @@ public class JavaAverage implements UnboundFunction { @Override public String name() { - return "iavg"; + return "avg"; } @Override @@ -38,8 +39,8 @@ public BoundFunction bind(StructType inputType) { if (inputType.fields().length != 1) { throw new UnsupportedOperationException("Expect exactly one argument"); } - if (inputType.fields()[0].dataType() instanceof IntegerType) { - return new BoundJavaAverageNoImpl(); + if (inputType.fields()[0].dataType() instanceof DoubleType) { + return new JavaDoubleAverage(); } throw new UnsupportedOperationException("Unsupported non-integral type: " + inputType.fields()[0].dataType()); @@ -50,42 +51,50 @@ public String description() { return null; } - public static class BoundJavaAverageNoImpl implements AggregateFunction { + public static class JavaDoubleAverage implements AggregateFunction, Double> { @Override - public State newAggregationState() { - return new State(0, 0); + public State newAggregationState() { + return new State<>(0.0, 0.0); } @Override - public Integer produceResult(State state) { + public State update(State state, InternalRow input) { + 
if (input.isNullAt(0)) { + return state; + } + return new State<>(state.sum + input.getDouble(0), state.count + 1); + } + + @Override + public Double produceResult(State state) { return state.sum / state.count; } @Override - public State merge(State leftState, State rightState) { - return new State(leftState.sum + rightState.sum, leftState.count + rightState.count); + public State merge(State leftState, State rightState) { + return new State<>(leftState.sum + rightState.sum, leftState.count + rightState.count); } @Override public DataType[] inputTypes() { - return new DataType[] { DataTypes.LongType }; + return new DataType[] { DataTypes.DoubleType }; } @Override public DataType resultType() { - return DataTypes.LongType; + return DataTypes.DoubleType; } @Override public String name() { - return "iavg"; + return "davg"; } } - public static class State implements Serializable { - int sum, count; + public static class State implements Serializable { + T sum, count; - State(int sum, int count) { + State(T sum, T count) { this.sum = sum; this.count = count; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index e36c0c59f02f9..4a60516628da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -197,50 +197,37 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } } - test("aggregate function: lookup int average w/ expression") { + test("aggregate function: lookup double average in Java") { import testImplicits._ val t = "testcat.ns.t" withTable(t) { - addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) + addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage) - (1 to 100).toDF("i").write.saveAsTable(t) - checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil) + Seq(1.toDouble, 2.toDouble, 3.toDouble).toDF("i").write.saveAsTable(t) + checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(2.0) :: Nil) } } - test("aggregate function: unsupported input type") { + test("aggregate function: lookup int average w/ expression") { import testImplicits._ val t = "testcat.ns.t" withTable(t) { addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) - Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t) - assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t")) - .getMessage.contains("Unsupported non-integral type: ShortType")) + (1 to 100).toDF("i").write.saveAsTable(t) + checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil) } } - test("aggregate function: doesn't implement update should throw runtime error") { + test("aggregate function: unsupported input type") { import testImplicits._ val t = "testcat.ns.t" withTable(t) { addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage) - Seq(1.toByte, 2.toByte).toDF("i").write.saveAsTable(t) - assert(intercept[SparkException](sql(s"SELECT testcat.ns.avg(i) from $t").collect()) - .getMessage.contains("Cannot find a compatible AggregateFunction")) - } - } - - test("aggregate function: doesn't implement update in Java should throw analysis error") { - import testImplicits._ - val t = "testcat.ns.t" - withTable(t) { - addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage) - - (1 to 100).toDF("i").write.saveAsTable(t) - 
assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t").collect()) - .getMessage.contains("neither implement magic method nor override 'update'")) + Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t) + assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t")) + .getMessage.contains("Unsupported non-integral type: ShortType")) } } @@ -329,7 +316,6 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } inputType.fields(0).dataType match { - case _: ByteType => ByteAverage case _: IntegerType => IntAverage case _: LongType => LongAverage case dataType => @@ -399,21 +385,6 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { override def produceResult(state: (Long, Long)): Long = state._1 / state._2 } - /** Bad implementation which doesn't override `produceResult` */ - object ByteAverage extends AggregateFunction[(Long, Long), Long] { - override def name(): String = "iavg" - override def inputTypes(): Array[DataType] = Array(LongType) - override def resultType(): DataType = LongType - - override def newAggregationState(): (Long, Long) = (0L, 0L) - - override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = { - (leftState._1 + rightState._1, leftState._2 + rightState._2) - } - - override def produceResult(state: (Long, Long)): Long = state._1 / state._2 - } - object NoImplAverage extends UnboundFunction { override def name(): String = "no_impl_avg" override def description(): String = name() From 57b3c253da516c639b5a1a5a701c6ea37064aa00 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 01:27:03 -0700 Subject: [PATCH 10/21] use withConf --- .../connector/DataSourceV2FunctionSuite.scala | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 4a60516628da6..0af2dc8e76d4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -46,49 +46,56 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } test("non-function catalog") { - spark.conf.set("spark.sql.catalog.testcat", classOf[BasicInMemoryTableCatalog].getName) - assert(intercept[AnalysisException]( - sql("SELECT testcat.strlen('abc')").collect() - ).getMessage.contains("is not a FunctionCatalog")) + withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + assert(intercept[AnalysisException]( + sql("SELECT testcat.strlen('abc')").collect() + ).getMessage.contains("is not a FunctionCatalog")) + } } test("built-in with default v2 function catalog") { - spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") - checkAnswer(sql("SELECT length('abc')"), Row(3)) + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } } test("built-in override with default v2 function catalog") { // a built-in function with the same name should take higher priority - spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") - addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) - checkAnswer(sql("SELECT length('abc')"), Row(3)) + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + 
checkAnswer(sql("SELECT length('abc')"), Row(3)) + } } test("temp function override with default v2 function catalog") { val className = "test.org.apache.spark.sql.JavaStringLength" sql(s"CREATE FUNCTION length AS '$className'") - spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") - addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) - checkAnswer(sql("SELECT length('abc')"), Row(3)) + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } } test("view should use captured catalog and namespace for function lookup") { val viewName = "my_view" withView(viewName) { - spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat") - catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) - addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage) - sql("USE ns") - sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT my_avg(col1) FROM values (1), (2), (3)") + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage) + sql("USE ns") + sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT my_avg(col1) FROM values (1), (2), (3)") + } // change default catalog and namespace and add a function with the same name but with no // implementation - spark.conf.set(SQLConf.DEFAULT_CATALOG.key, "testcat2") - catalog("testcat2").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns2"), emptyProps) - addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage) - sql("USE ns2") - checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil) + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat2") { + catalog("testcat2").asInstanceOf[SupportsNamespaces] + .createNamespace(Array("ns2"), emptyProps) + addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage) + sql("USE ns2") + checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil) + } } } From de726d001c6122a18b4e6adced514f7cc98ee564 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 11:13:05 -0700 Subject: [PATCH 11/21] reduce changes on InMemoryTableCatalog --- ...eateTablePartitioningValidationSuite.scala | 4 +- .../analysis/TableLookupCacheSuite.scala | 4 +- .../catalog/CatalogManagerSuite.scala | 8 +-- .../sql/connector/catalog/CatalogSuite.scala | 14 ++--- .../connector/catalog/InMemoryCatalog.scala | 58 +++++++++++++++++++ ...la => InMemoryPartitionTableCatalog.scala} | 4 +- ...talog.scala => InMemoryTableCatalog.scala} | 41 ++----------- .../catalog/StagingInMemoryTableCatalog.scala | 4 +- ...pportsAtomicPartitionManagementSuite.scala | 4 +- .../SupportsPartitionManagementSuite.scala | 6 +- .../spark/sql/JavaDataFrameWriterV2Suite.java | 4 +- .../spark/sql/CharVarcharTestSuite.scala | 4 +- .../spark/sql/DataFrameWriterV2Suite.scala | 4 +- .../apache/spark/sql/SQLInsertTestSuite.scala | 4 +- .../DataSourceV2DataFrameSuite.scala | 4 +- .../connector/DataSourceV2FunctionSuite.scala | 4 +- .../sql/connector/DataSourceV2SQLSuite.scala | 20 +++---- .../sql/connector/DatasourceV2SQLBase.scala | 10 ++-- .../SupportsCatalogOptionsSuite.scala | 2 +- .../sql/connector/V1ReadFallbackSuite.scala | 2 +- .../WriteDistributionAndOrderingSuite.scala | 6 +- .../spark/sql/execution/HiveResultSuite.scala | 6 +- .../command/CharVarcharDDLTestBase.scala | 4 +- 
.../command/v2/CommandSuiteBase.scala | 6 +- .../command/v2/ShowNamespacesSuite.scala | 4 +- .../test/DataStreamTableAPISuite.scala | 10 ++-- 26 files changed, 134 insertions(+), 107 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/{V2InMemoryPartitionCatalog.scala => InMemoryPartitionTableCatalog.scala} (91%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/{V2InMemoryCatalog.scala => InMemoryTableCatalog.scala} (84%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index fde969108a49c..f7e57e3b27b21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode} -import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, V2InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog} import org.apache.spark.sql.connector.expressions.Expressions import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -134,7 +134,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { private[sql] object CreateTablePartitioningValidationSuite { val catalog: TableCatalog = { - val cat = new V2InMemoryCatalog() + val cat = new InMemoryTableCatalog() cat.initialize("test", CaseInsensitiveStringMap.empty()) cat } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala index ddd0ab983b684..7d6ad3bc60902 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, ExternalCatalog, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, InMemoryTable, Table, V2InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, InMemoryTable, InMemoryTableCatalog, Table} import org.apache.spark.sql.types._ class TableLookupCacheSuite extends AnalysisTest with Matchers { @@ -45,7 +45,7 @@ class TableLookupCacheSuite extends AnalysisTest with Matchers { CatalogStorageFormat.empty, StructType(Seq(StructField("a", IntegerType)))), ignoreIfExists = false) - val v2Catalog = new V2InMemoryCatalog { + val v2Catalog = new InMemoryTableCatalog { override def loadTable(ident: Identifier): Table = { val catalogTable = 
externalCatalog.getTable("default", ident.name) new InMemoryTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala index a8fbbb9c06d98..bfff3ee855e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog => V1InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class CatalogManagerSuite extends SparkFunSuite with SQLHelper { private def createSessionCatalog(): SessionCatalog = { - val catalog = new InMemoryCatalog() + val catalog = new V1InMemoryCatalog() catalog.createDatabase( CatalogDatabase(SessionCatalog.DEFAULT_DATABASE, "", new URI("fake"), Map.empty), ignoreIfExists = true) @@ -113,9 +113,9 @@ class CatalogManagerSuite extends SparkFunSuite with SQLHelper { assert(v1SessionCatalog.getCurrentDatabase == "default") // Check namespace existence if currentCatalog implements SupportsNamespaces. - withSQLConf("spark.sql.catalog.testCatalog" -> classOf[V2InMemoryCatalog].getName) { + withSQLConf("spark.sql.catalog.testCatalog" -> classOf[InMemoryTableCatalog].getName) { catalogManager.setCurrentCatalog("testCatalog") - catalogManager.currentCatalog.asInstanceOf[V2InMemoryCatalog] + catalogManager.currentCatalog.asInstanceOf[InMemoryTableCatalog] .createNamespace(Array("test3"), Map.empty[String, String].asJava) assert(v1SessionCatalog.getCurrentDatabase == "default") catalogManager.setCurrentNamespace(Array("test3")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index a93c9c1bb4687..0cca1cc9bebf2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -40,8 +40,8 @@ class CatalogSuite extends SparkFunSuite { .add("id", IntegerType) .add("data", StringType) - private def newCatalog(): V2InMemoryCatalog = { - val newCatalog = new V2InMemoryCatalog + private def newCatalog(): InMemoryCatalog = { + val newCatalog = new InMemoryCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog } @@ -903,7 +903,7 @@ class CatalogSuite extends SparkFunSuite { } test("truncate partitioned table") { - val partCatalog = new V2InMemoryPartitionCatalog + val partCatalog = new InMemoryPartitionTableCatalog partCatalog.initialize("test", CaseInsensitiveStringMap.empty()) val table = partCatalog.createTable( @@ -947,9 +947,9 @@ class CatalogSuite extends SparkFunSuite { catalog.createNamespace(Array("ns1", "ns2"), emptyProps) catalog.createNamespace(Array("ns1", "ns3"), emptyProps) - 
catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident1, function) - catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident2, function) - catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident3, function) + catalog.createFunction(ident1, function) + catalog.createFunction(ident2, function) + catalog.createFunction(ident3, function) assert(catalog.listFunctions(Array("ns1", "ns2")).toSet === Set(ident1, ident2)) assert(catalog.listFunctions(Array("ns1", "ns3")).toSet === Set(ident3)) @@ -961,7 +961,7 @@ class CatalogSuite extends SparkFunSuite { val catalog = newCatalog() val ident = Identifier.of(Array("ns"), "func") catalog.createNamespace(Array("ns"), emptyProps) - catalog.asInstanceOf[V2InMemoryCatalog].createFunction(ident, function) + catalog.createFunction(ident, function) assert(catalog.loadFunction(ident) == function) intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns"), "func1"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala new file mode 100644 index 0000000000000..202b03f28f082 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException} +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction + +class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { + protected val functions: util.Map[Identifier, UnboundFunction] = + new ConcurrentHashMap[Identifier, UnboundFunction]() + + override protected def allNamespaces: Seq[Seq[String]] = { + (tables.keySet.asScala.map(_.namespace.toSeq) ++ + functions.keySet.asScala.map(_.namespace.toSeq) ++ + namespaces.keySet.asScala).toSeq.distinct + } + + override def listFunctions(namespace: Array[String]): Array[Identifier] = { + if (namespace.isEmpty || namespaceExists(namespace)) { + functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } else { + throw new NoSuchNamespaceException(namespace) + } + } + + override def loadFunction(ident: Identifier): UnboundFunction = { + Option(functions.get(ident)) match { + case Some(func) => + func + case _ => + throw new NoSuchFunctionException(ident) + } + } + + def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = { + functions.put(ident, fn) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryPartitionCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala similarity index 91% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryPartitionCatalog.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala index 76c1010524683..a24f5c9a0c463 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryPartitionCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType -class V2InMemoryPartitionCatalog extends V2InMemoryCatalog { +class InMemoryPartitionTableCatalog extends InMemoryTableCatalog { import CatalogV2Implicits._ override def createTable( @@ -35,7 +35,7 @@ class V2InMemoryPartitionCatalog extends V2InMemoryCatalog { throw new TableAlreadyExistsException(ident) } - V2InMemoryCatalog.maybeSimulateFailedTableCreation(properties) + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) val table = new InMemoryAtomicPartitionTable( s"$name.${ident.quoted}", schema, partitions, properties) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala similarity index 84% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 68b210ef36cc7..0c403baca2113 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2InMemoryCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -22,8 +22,7 @@ import java.util.concurrent.ConcurrentHashMap 
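// A minimal usage sketch of the InMemoryCatalog function-catalog test helper
// defined above; it is illustrative only, not part of the patch. The namespace
// and function names are made up, and `fn` stands in for any UnboundFunction
// (for example the JavaStrLen helper used by DataSourceV2FunctionSuite later
// in this series).
import java.util.Collections

import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object InMemoryCatalogSketch {
  def registerAndLookup(fn: UnboundFunction): Unit = {
    // Catalog plugins are initialized with a name and (possibly empty) options.
    val catalog = new InMemoryCatalog
    catalog.initialize("test", CaseInsensitiveStringMap.empty())
    catalog.createNamespace(Array("ns"), Collections.emptyMap[String, String])

    // Register the function, then confirm it is visible via the FunctionCatalog API.
    val ident = Identifier.of(Array("ns"), "strlen")
    catalog.createFunction(ident, fn)
    assert(catalog.listFunctions(Array("ns")).sameElements(Array(ident)))
    assert(catalog.loadFunction(ident) == fn)
  }
}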
import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} -import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType @@ -38,9 +37,6 @@ class BasicInMemoryTableCatalog extends TableCatalog { protected val tables: util.Map[Identifier, Table] = new ConcurrentHashMap[Identifier, Table]() - protected val functions: util.Map[Identifier, UnboundFunction] = - new ConcurrentHashMap[Identifier, UnboundFunction]() - private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() private var _name: Option[String] = None @@ -89,7 +85,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { throw new TableAlreadyExistsException(ident) } - V2InMemoryCatalog.maybeSimulateFailedTableCreation(properties) + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) val tableName = s"$name.${ident.quoted}" val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, @@ -141,27 +137,8 @@ class BasicInMemoryTableCatalog extends TableCatalog { } } -class BasicInMemoryCatalog extends BasicInMemoryTableCatalog with FunctionCatalog { - override def listFunctions(namespace: Array[String]): Array[Identifier] = { - functions.keySet().asScala.filter(_.namespace().sameElements(namespace)).toArray - } - - override def loadFunction(ident: Identifier): UnboundFunction = { - Option(functions.get(ident)) match { - case Some(func) => - func - case _ => - throw new NoSuchFunctionException(ident) - } - } - - def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = { - functions.put(ident, fn) - } -} - -class V2InMemoryCatalog extends BasicInMemoryCatalog with SupportsNamespaces { - private def allNamespaces: Seq[Seq[String]] = { +class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces { + protected def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct } @@ -233,17 +210,9 @@ class V2InMemoryCatalog extends BasicInMemoryCatalog with SupportsNamespaces { throw new NoSuchNamespaceException(namespace) } } - - override def listFunctions(namespace: Array[String]): Array[Identifier] = { - if (namespace.isEmpty || namespaceExists(namespace)) { - super.listFunctions(namespace) - } else { - throw new NoSuchNamespaceException(namespace) - } - } } -object V2InMemoryCatalog { +object InMemoryTableCatalog { val SIMULATE_FAILED_CREATE_PROPERTY = "spark.sql.test.simulateFailedCreate" val SIMULATE_DROP_BEFORE_REPLACE_PROPERTY = "spark.sql.test.simulateDropBeforeReplace" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala index 54724d6129fa3..954650ae0eebd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala @@ 
-27,8 +27,8 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -class StagingInMemoryCatalog extends V2InMemoryCatalog with StagingTableCatalog { - import V2InMemoryCatalog._ +class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTableCatalog { + import InMemoryTableCatalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def stageCreate( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala index 00e9e2ff26cb5..df2fbd6d179bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala @@ -32,8 +32,8 @@ class SupportsAtomicPartitionManagementSuite extends SparkFunSuite { def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - private val catalog: V2InMemoryCatalog = { - val newCatalog = new V2InMemoryCatalog + private val catalog: InMemoryTableCatalog = { + val newCatalog = new InMemoryTableCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog.createTable( ident, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala index 332b0975c8a34..e5aeb90b841a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala @@ -34,8 +34,8 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - private val catalog: V2InMemoryCatalog = { - val newCatalog = new V2InMemoryCatalog + private val catalog: InMemoryTableCatalog = { + val newCatalog = new InMemoryTableCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog.createTable( ident, @@ -156,7 +156,7 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { } private def createMultiPartTable(): InMemoryPartitionTable = { - val partCatalog = new V2InMemoryPartitionCatalog + val partCatalog = new InMemoryPartitionTableCatalog partCatalog.initialize("test", CaseInsensitiveStringMap.empty()) val table = partCatalog.createTable( ident, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java index b1ddeb1b88864..59c5263563b27 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; -import org.apache.spark.sql.connector.catalog.V2InMemoryCatalog; +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog; import 
org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; import org.junit.After; @@ -43,7 +43,7 @@ public Dataset df() { @Before public void createTestTable() { this.spark = new TestSparkSession(); - spark.conf().set("spark.sql.catalog.testcat", V2InMemoryCatalog.class.getName()); + spark.conf().set("spark.sql.catalog.testcat", InMemoryTableCatalog.class.getName()); spark.sql("CREATE TABLE testcat.t (s string) USING foo"); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 15be034bd57ba..c06544ee00621 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.SchemaRequiredDataSource -import org.apache.spark.sql.connector.catalog.V2InMemoryPartitionCatalog +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf @@ -870,7 +870,7 @@ class DSV2CharVarcharTestSuite extends CharVarcharTestSuite override def format: String = "foo" protected override def sparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[V2InMemoryPartitionCatalog].getName) + .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index 20c7ae5947d35..8aef27a1b6692 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, TableCatalog, V2InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog} import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -48,7 +48,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo private val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName()) before { - spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") df.createOrReplaceTempView("source") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index d11b922ea67bb..2f56fbaf7f821 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.Hex -import org.apache.spark.sql.connector.catalog.V2InMemoryPartitionCatalog +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.unsafe.types.UTF8String @@ -320,7 +320,7 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession protected override def sparkConf: SparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[V2InMemoryPartitionCatalog].getName) + .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index 8ca900c0abff4..d83d1a2755928 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -34,8 +34,8 @@ class DataSourceV2DataFrameSuite import testImplicits._ before { - spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) - spark.conf.set("spark.sql.catalog.testcat2", classOf[V2InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName) } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 0af2dc8e76d4c..882890cc51516 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -26,7 +26,7 @@ import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsNamespaces, V2InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -36,7 +36,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { - catalog("testcat").asInstanceOf[V2InMemoryCatalog].createFunction(ident, fn) + catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) } test("undefined function") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 9bc4a96afaccc..13facc36876b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -494,7 +494,7 @@ class DataSourceV2SQLSuite intercept[Exception] { spark.sql("REPLACE TABLE testcat.table_name" + s" USING foo" + - s" TBLPROPERTIES (`${V2InMemoryCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + + s" TBLPROPERTIES (`${InMemoryTableCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + s" AS SELECT id FROM source") } @@ -519,7 +519,7 @@ class DataSourceV2SQLSuite intercept[Exception] { spark.sql("REPLACE TABLE testcat_atomic.table_name" + s" USING foo" + - s" TBLPROPERTIES (`${V2InMemoryCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + + s" TBLPROPERTIES (`${InMemoryTableCatalog.SIMULATE_FAILED_CREATE_PROPERTY}`=true)" + s" AS SELECT id FROM source") } @@ -578,7 +578,7 @@ class DataSourceV2SQLSuite } test("ReplaceTableAsSelect: REPLACE TABLE throws exception if table is dropped before commit.") { - import V2InMemoryCatalog._ + import InMemoryTableCatalog._ spark.sql(s"CREATE TABLE testcat_atomic.created USING $v2Source AS SELECT id, data FROM source") intercept[CannotReplaceMissingTableException] { spark.sql(s"REPLACE TABLE testcat_atomic.replaced" + @@ -1390,7 +1390,7 @@ class DataSourceV2SQLSuite "and namespace does not exist") { // Namespaces are not required to exist for v2 catalogs // that does not implement SupportsNamespaces. - withSQLConf("spark.sql.catalog.dummy" -> classOf[BasicInMemoryCatalog].getName) { + withSQLConf("spark.sql.catalog.dummy" -> classOf[BasicInMemoryTableCatalog].getName) { val catalogManager = spark.sessionState.catalogManager sql("USE dummy.ns1") @@ -1547,7 +1547,7 @@ class DataSourceV2SQLSuite |CLUSTERED BY (`a.b`) INTO 4 BUCKETS """.stripMargin) - val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[V2InMemoryCatalog] + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] val table = testCatalog.loadTable(Identifier.of(Array.empty, "t")) val partitioning = table.partitioning() assert(partitioning.length == 1 && partitioning.head.name() == "bucket") @@ -1614,7 +1614,7 @@ class DataSourceV2SQLSuite withTable(t) { sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[V2InMemoryCatalog] + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] val identifier = Identifier.of(Array("ns1", "ns2"), "tbl") assert(!testCatalog.isTableInvalidated(identifier)) @@ -1630,7 +1630,7 @@ class DataSourceV2SQLSuite sql("CREATE TEMPORARY VIEW t AS SELECT 2") sql("USE testcat.ns") - val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[V2InMemoryCatalog] + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] val identifier = Identifier.of(Array("ns"), "t") assert(!testCatalog.isTableInvalidated(identifier)) @@ -2142,7 +2142,7 @@ class DataSourceV2SQLSuite test("global temp view should not be masked by v2 catalog") { val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) - spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[V2InMemoryCatalog].getName) + spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName) try { sql("create global temp view v as select 1") @@ -2167,7 +2167,7 @@ class 
DataSourceV2SQLSuite test("SPARK-30104: v2 catalog named global_temp will be masked") { val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) - spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[V2InMemoryCatalog].getName) + spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName) val e = intercept[AnalysisException] { // Since the following multi-part name starts with `globalTempDB`, it is resolved to @@ -2366,7 +2366,7 @@ class DataSourceV2SQLSuite intercept[AnalysisException](sql("COMMENT ON TABLE testcat.abc IS NULL")) val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) - spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[V2InMemoryCatalog].getName) + spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName) withTempView("v") { sql("create global temp view v as select 1") val e = intercept[AnalysisException](sql("COMMENT ON TABLE global_temp.v IS NULL")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala index 723d9148eb60d..77a515b55ce76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, StagingInMemoryTableCatalog} import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSparkSession @@ -32,11 +32,11 @@ trait DatasourceV2SQLBase } before { - spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) - spark.conf.set("spark.sql.catalog.testpart", classOf[V2InMemoryPartitionCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName) spark.conf.set( - "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryCatalog].getName) - spark.conf.set("spark.sql.catalog.testcat2", classOf[V2InMemoryCatalog].getName) + "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName) spark.conf.set( V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index a2566beffc7ba..076dad7530807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -55,7 +55,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with spark.conf.set( V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) spark.conf.set( - s"spark.sql.catalog.$catalogName", classOf[V2InMemoryCatalog].getName) + 
s"spark.sql.catalog.$catalogName", classOf[InMemoryTableCatalog].getName) } override def afterEach(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 1328f6f61d764..847953e09cef7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -101,7 +101,7 @@ class V1ReadFallbackWithCatalogSuite extends V1ReadFallbackSuite { } } -class V1ReadFallbackCatalog extends BasicInMemoryCatalog { +class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { override def createTable( ident: Identifier, schema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 6c17a7dcd347d..db4a9c153c0ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -44,7 +44,7 @@ class WriteDistributionAndOrderingSuite import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ before { - spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) } after { @@ -756,9 +756,9 @@ class WriteDistributionAndOrderingSuite UnresolvedAttribute(name) } - private def catalog: V2InMemoryCatalog = { + private def catalog: InMemoryTableCatalog = { val catalog = spark.sessionState.catalogManager.catalog("testcat") - catalog.asTableCatalog.asInstanceOf[V2InMemoryCatalog] + catalog.asTableCatalog.asInstanceOf[InMemoryTableCatalog] } // executes a write operation and keeps the executed physical plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala index b60b390a75ff9..f8366b3f7c5fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.time.{Duration, Period} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils -import org.apache.spark.sql.connector.catalog.V2InMemoryCatalog +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog import org.apache.spark.sql.execution.HiveResult._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} @@ -80,7 +80,7 @@ class HiveResultSuite extends SharedSparkSession { } test("SHOW TABLES in hive result") { - withSQLConf("spark.sql.catalog.testcat" -> classOf[V2InMemoryCatalog].getName) { + withSQLConf("spark.sql.catalog.testcat" -> classOf[InMemoryTableCatalog].getName) { Seq(("testcat.ns", "tbl", "foo"), ("spark_catalog.default", "tbl", "csv")).foreach { case (ns, tbl, source) => withTable(s"$ns.$tbl") { @@ -94,7 +94,7 @@ class HiveResultSuite extends SharedSparkSession { } test("DESCRIBE TABLE in hive result") { - withSQLConf("spark.sql.catalog.testcat" -> classOf[V2InMemoryCatalog].getName) { + withSQLConf("spark.sql.catalog.testcat" -> classOf[InMemoryTableCatalog].getName) { Seq(("testcat.ns", "tbl", "foo"), 
("spark_catalog.default", "tbl", "csv")).foreach { case (ns, tbl, source) => withTable(s"$ns.$tbl") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala index 5e2e6cc592727..ba683c049a631 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.V2InMemoryPartitionCatalog +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types._ @@ -151,7 +151,7 @@ class DSV2CharVarcharDDLTestSuite extends CharVarcharDDLTestBase override def format: String = "foo" protected override def sparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[V2InMemoryPartitionCatalog].getName) + .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index 7d2df64583e09..bed04f4f2659b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryPartitionTable, V2InMemoryCatalog, V2InMemoryPartitionCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} import org.apache.spark.sql.test.SharedSparkSession /** @@ -36,8 +36,8 @@ trait CommandSuiteBase extends SharedSparkSession { // V2 catalogs created and used especially for testing override def sparkConf: SparkConf = super.sparkConf - .set(s"spark.sql.catalog.$catalog", classOf[V2InMemoryPartitionCatalog].getName) - .set(s"spark.sql.catalog.non_part_$catalog", classOf[V2InMemoryCatalog].getName) + .set(s"spark.sql.catalog.$catalog", classOf[InMemoryPartitionTableCatalog].getName) + .set(s"spark.sql.catalog.non_part_$catalog", classOf[InMemoryTableCatalog].getName) def checkLocation( t: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala index fb7ffc9967b2b..bafb6608c8e6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import 
org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.BasicInMemoryCatalog +import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog import org.apache.spark.sql.execution.command import org.apache.spark.sql.internal.SQLConf @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf */ class ShowNamespacesSuite extends command.ShowNamespacesSuiteBase with CommandSuiteBase { override def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.catalog.testcat_no_namespace", classOf[BasicInMemoryCatalog].getName) + .set("spark.sql.catalog.testcat_no_namespace", classOf[BasicInMemoryTableCatalog].getName) test("IN namespace doesn't exist") { withSQLConf(SQLConf.DEFAULT_CATALOG.key -> catalog) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index de67c54332c24..49e5218ea3352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableSessionCatalog} -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, V2InMemoryCatalog, V2TableWithV1Fallback} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsRead, Table, TableCapability, V2TableWithV1Fallback} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.streaming.{MemoryStream, MemoryStreamScanBuilder} @@ -46,8 +46,8 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ before { - spark.conf.set("spark.sql.catalog.testcat", classOf[V2InMemoryCatalog].getName) - spark.conf.set("spark.sql.catalog.teststream", classOf[V2InMemoryStreamCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + spark.conf.set("spark.sql.catalog.teststream", classOf[InMemoryStreamTableCatalog].getName) } after { @@ -157,7 +157,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { test("read: fallback to V1 relation") { val tblName = DataStreamTableAPISuite.V1FallbackTestTableName spark.conf.set(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION.key, - classOf[V2InMemoryStreamCatalog].getName) + classOf[InMemoryStreamTableCatalog].getName) val v2Source = classOf[FakeV2Provider].getName withTempDir { tempDir => withTable(tblName) { @@ -439,7 +439,7 @@ class NonStreamV2Table(override val name: String) } -class V2InMemoryStreamCatalog extends V2InMemoryCatalog { +class InMemoryStreamTableCatalog extends InMemoryTableCatalog { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def createTable( From ee56ea76f53b4a8c71f1873fb935a8575a488065 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 11:30:51 -0700 Subject: [PATCH 12/21] remove UnresolvedFunction.name --- .../sql/catalyst/analysis/Analyzer.scala | 21 ++++++++++--------- 
.../sql/catalyst/analysis/unresolved.scala | 2 -- .../analysis/LookupFunctionsSuite.scala | 5 +++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 42e3679a0d4be..6ea32f9bc9acd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1573,8 +1573,8 @@ class Analyzer(override val catalogManager: CatalogManager) // results and confuse users if there is any null values. For count(t1.*, t2.*), it is // still allowed, since it's well-defined in spark. if (!conf.allowStarWithSingleTableIdentifierInCount && - f1.name.database.isEmpty && - f1.name.funcName == "count" && + f1.multipartIdentifier.length == 1 && + f1.multipartIdentifier.head == "count" && f1.arguments.length == 1) { f1.arguments.foreach { case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) => @@ -1959,6 +1959,7 @@ class Analyzer(override val catalogManager: CatalogManager) * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { + import CatalogV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() plan.resolveExpressions { @@ -1970,17 +1971,17 @@ class Analyzer(override val catalogManager: CatalogManager) } } f - case f: UnresolvedFunction - if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f - case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f - case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) => - externalFunctionNameSet.add(normalizeFuncName(f.name)) + case f @ UnresolvedFunction(AsFunctionIdentifier(name), _, _, _, _) + if externalFunctionNameSet.contains(normalizeFuncName(name)) => f + case f @ UnresolvedFunction(AsFunctionIdentifier(name), _, _, _, _) + if v1SessionCatalog.isRegisteredFunction(name) => f + case f @ UnresolvedFunction(AsFunctionIdentifier(name), _, _, _, _) + if v1SessionCatalog.isPersistentFunction(name) => + externalFunctionNameSet.add(normalizeFuncName(name)) f case f: UnresolvedFunction => withPosition(f) { - throw new NoSuchFunctionException( - f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase), - f.name.funcName) + throw new NoSuchFunctionException(f.multipartIdentifier.asIdentifier) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f8a10c4afc9c2..7c6f550597bff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -297,8 +297,6 @@ case class UnresolvedFunction( copy(arguments = newChildren) } } - - def name: FunctionIdentifier = multipartIdentifier.asFunctionIdentifier } object UnresolvedFunction { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala index e0f3c9a835b6e..92d3b517d0282 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ class LookupFunctionsSuite extends PlanTest { @@ -49,7 +50,7 @@ class LookupFunctionsSuite extends PlanTest { assert(externalCatalog.getFunctionExistsCalledTimes == 1) assert(analyzer.LookupFunctions.normalizeFuncName - (unresolvedPersistentFunc.name).database == Some("default")) + (unresolvedPersistentFunc.nameParts.asFunctionIdentifier).database == Some("default")) } test("SPARK-23486: the functionExists for the Registered function check") { @@ -72,7 +73,7 @@ class LookupFunctionsSuite extends PlanTest { assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) assert(analyzer.LookupFunctions.normalizeFuncName - (unresolvedRegisteredFunc.name).database == Some("default")) + (unresolvedRegisteredFunc.nameParts.asFunctionIdentifier).database == Some("default")) } } From 91038d170a1c884c035edf567d659f59d219e895 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 11:31:49 -0700 Subject: [PATCH 13/21] rename multipartIdentifier to nameParts --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6ea32f9bc9acd..add0a20f7e33b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1573,8 +1573,8 @@ class Analyzer(override val catalogManager: CatalogManager) // results and confuse users if there is any null values. For count(t1.*, t2.*), it is // still allowed, since it's well-defined in spark. 
if (!conf.allowStarWithSingleTableIdentifierInCount && - f1.multipartIdentifier.length == 1 && - f1.multipartIdentifier.head == "count" && + f1.nameParts.length == 1 && + f1.nameParts.head == "count" && f1.arguments.length == 1) { f1.arguments.foreach { case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) => @@ -1981,7 +1981,7 @@ class Analyzer(override val catalogManager: CatalogManager) f case f: UnresolvedFunction => withPosition(f) { - throw new NoSuchFunctionException(f.multipartIdentifier.asIdentifier) + throw new NoSuchFunctionException(f.nameParts.asIdentifier) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7c6f550597bff..6fcde63ca225c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -269,7 +269,7 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio } case class UnresolvedFunction( - multipartIdentifier: Seq[String], + nameParts: Seq[String], arguments: Seq[Expression], isDistinct: Boolean, filter: Option[Expression] = None, @@ -283,10 +283,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false - override def prettyName: String = multipartIdentifier.quoted + override def prettyName: String = nameParts.quoted override def toString: String = { val distinct = if (isDistinct) "distinct " else "" - s"'${multipartIdentifier.quoted}($distinct${children.mkString(", ")})" + s"'${nameParts.quoted}($distinct${children.mkString(", ")})" } override protected def withNewChildrenInternal( From eeccf6b1faaacb7755137b02306df3ad86ba819e Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 11:42:32 -0700 Subject: [PATCH 14/21] remove unnecessary changes --- .../sql/catalyst/parser/AstBuilder.scala | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 76ef7aeefed0b..da7c5c0221bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1838,11 +1838,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a function database (optional) and name pair, for multipartIdentifier. - * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. + * Create a function database (optional) and name pair. */ - protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = { - visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq) + private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = { + texts match { + case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) + case Seq(fn) => FunctionIdentifier(fn, None) + case other => + throw QueryParsingErrors.functionNameUnsupportedError(texts.mkString("."), ctx) + } } /** @@ -1865,18 +1869,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a function database (optional) and name pair. 
- */ - private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = { - texts match { - case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) - case Seq(fn) => FunctionIdentifier(fn, None) - case other => - throw QueryParsingErrors.functionNameUnsupportedError(texts.mkString("."), ctx) - } - } - - /** * Create an [[LambdaFunction]]. */ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { From 0ca1ca58483ee6f3ed25bff3e23cef3c6a5d15fc Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 12:14:34 -0700 Subject: [PATCH 15/21] remove star import --- .../spark/sql/connector/catalog/functions/JavaStrLen.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java index 93f969358c366..8b2d883a3703f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -21,7 +21,10 @@ import org.apache.spark.sql.connector.catalog.functions.BoundFunction; import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; public class JavaStrLen implements UnboundFunction { From 465737c1979b6bf03a2a93f361071e6bcfd116dd Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 16:31:43 -0700 Subject: [PATCH 16/21] keep old error message for V1 unresolved functions --- .../sql/catalyst/analysis/Analyzer.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index add0a20f7e33b..20b6fc9ebcdd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1971,14 +1971,20 @@ class Analyzer(override val catalogManager: CatalogManager) } } f - case f @ UnresolvedFunction(AsFunctionIdentifier(name), _, _, _, _) - if externalFunctionNameSet.contains(normalizeFuncName(name)) => f - case f @ UnresolvedFunction(AsFunctionIdentifier(name), _, _, _, _) - if v1SessionCatalog.isRegisteredFunction(name) => f - case f @ UnresolvedFunction(AsFunctionIdentifier(name), _, _, _, _) - if v1SessionCatalog.isPersistentFunction(name) => - externalFunctionNameSet.add(normalizeFuncName(name)) + case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) + if externalFunctionNameSet.contains(normalizeFuncName(ident)) => f + case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) + if v1SessionCatalog.isRegisteredFunction(ident) => f + case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) + if v1SessionCatalog.isPersistentFunction(ident) => + externalFunctionNameSet.add(normalizeFuncName(ident)) f + case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) => + withPosition(f) { + throw new 
NoSuchFunctionException( + ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase), + ident.funcName) + } case f: UnresolvedFunction => withPosition(f) { throw new NoSuchFunctionException(f.nameParts.asIdentifier) From 68e10014d2793d0bbe69f93b263e9817b0fa2913 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 22 Apr 2021 22:58:56 -0700 Subject: [PATCH 17/21] check FunctionCatalog in ResolveFunctions instead --- .../spark/sql/catalyst/analysis/Analyzer.scala | 17 ++++++++--------- .../connector/catalog/CatalogV2Implicits.scala | 2 +- .../connector/DataSourceV2FunctionSuite.scala | 7 +++++++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 20b6fc9ebcdd8..89d39e11c0082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1963,13 +1963,7 @@ class Analyzer(override val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() plan.resolveExpressions { - case f @ UnresolvedFunction(NonSessionCatalogAndIdentifier(catalog, name), _, _, _, _) => - if (!catalog.isFunctionCatalog) { - withPosition(f) { - throw new AnalysisException(s"Trying to lookup function '$name' in catalog" + - s" '${catalog.name()}', but '${catalog.name()}' is not a FunctionCatalog.") - } - } + case f @ UnresolvedFunction(NonSessionCatalogAndIdentifier(_, _), _, _, _, _) => f case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) if externalFunctionNameSet.contains(normalizeFuncName(ident)) => f @@ -2055,8 +2049,13 @@ class Analyzer(override val catalogManager: CatalogManager) resultExpression.getOrElse( expandIdentifier(parts) match { - case NonSessionCatalogAndIdentifier(catalog: FunctionCatalog, ident) => - lookupV2Function(catalog, ident, arguments, isDistinct, filter, ignoreNulls) + case NonSessionCatalogAndIdentifier(catalog, ident) => + if (!catalog.isFunctionCatalog) { + throw new AnalysisException(s"Trying to lookup function '$ident' in catalog" + + s" '${catalog.name()}', but it is not a FunctionCatalog.") + } + lookupV2Function(catalog.asFunctionCatalog, ident, arguments, isDistinct, + filter, ignoreNulls) case _ => u } ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 56a11cd60a0ab..cc41d8ca9007f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -93,7 +93,7 @@ private[sql] object CatalogV2Implicits { case functionCatalog: FunctionCatalog => functionCatalog case _ => - throw new UnsupportedOperationException( + throw new AnalysisException( s"Cannot use catalog '${plugin.name}': not a FunctionCatalog") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 882890cc51516..ceac38330e18f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -53,6 +53,13 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } } + test("built-in with non-function catalog should still work") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat", + "spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + } + test("built-in with default v2 function catalog") { withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { checkAnswer(sql("SELECT length('abc')"), Row(3)) From f25b5e683f796c353296adb519b4882e37fbc213 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 26 Apr 2021 16:55:50 -0700 Subject: [PATCH 18/21] refactoring & fix higher-order functions --- .../catalog/functions/ScalarFunction.java | 1 - .../sql/catalyst/analysis/Analyzer.scala | 76 ++++++++----------- .../sql/connector/catalog/LookupCatalog.scala | 3 + .../connector/DataSourceV2FunctionSuite.scala | 10 +++ 4 files changed, 43 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index a1526a46899d2..ef755aae3fb07 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -78,7 +78,6 @@ * {@link org.apache.spark.sql.catalyst.util.ArrayData}

 *   <li>{@link org.apache.spark.sql.types.MapType}:
 *   {@link org.apache.spark.sql.catalyst.util.MapData}</li>
- *   <li>any other type: {@code Object}</li>
  • * * * @param the JVM type of result values diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 89d39e11c0082..2dc0209514341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1573,8 +1573,7 @@ class Analyzer(override val catalogManager: CatalogManager) // results and confuse users if there is any null values. For count(t1.*, t2.*), it is // still allowed, since it's well-defined in spark. if (!conf.allowStarWithSingleTableIdentifierInCount && - f1.nameParts.length == 1 && - f1.nameParts.head == "count" && + f1.nameParts == Seq("count") && f1.arguments.length == 1) { f1.arguments.foreach { case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) => @@ -1959,30 +1958,26 @@ class Analyzer(override val catalogManager: CatalogManager) * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { - import CatalogV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() plan.resolveExpressions { - case f @ UnresolvedFunction(NonSessionCatalogAndIdentifier(_, _), _, _, _, _) => - f - case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) - if externalFunctionNameSet.contains(normalizeFuncName(ident)) => f - case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) - if v1SessionCatalog.isRegisteredFunction(ident) => f - case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) - if v1SessionCatalog.isPersistentFunction(ident) => - externalFunctionNameSet.add(normalizeFuncName(ident)) - f case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) => - withPosition(f) { - throw new NoSuchFunctionException( - ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase), - ident.funcName) + if (externalFunctionNameSet.contains(normalizeFuncName(ident)) || + v1SessionCatalog.isRegisteredFunction(ident)) { + f + } else if (v1SessionCatalog.isPersistentFunction(ident)) { + externalFunctionNameSet.add(normalizeFuncName(ident)) + f + } else { + withPosition(f) { + throw new NoSuchFunctionException( + ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase), + ident.funcName) + } } case f: UnresolvedFunction => - withPosition(f) { - throw new NoSuchFunctionException(f.nameParts.asIdentifier) - } + // v2 functions - do nothing for now + f } } @@ -2030,35 +2025,24 @@ class Analyzer(override val catalogManager: CatalogManager) } } - case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, - isDistinct, filter, ignoreNulls) => withPosition(u) { - processFunctionExpr(v1SessionCatalog.lookupFunction(ident, arguments), - arguments, isDistinct, filter, ignoreNulls) - } + case u @ UnresolvedFunction(AsFunctionIdentifier(ident), + arguments, isDistinct, filter, ignoreNulls) => withPosition(u) { + processFunctionExpr(v1SessionCatalog.lookupFunction(ident, arguments), + arguments, isDistinct, filter, ignoreNulls) + } - case u @ UnresolvedFunction(parts, arguments, isDistinct, filter, ignoreNulls) => + case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) => withPosition(u) { - // resolve built-in or temporary functions with v2 catalog - val resultExpression = if (parts.length == 1) { - 
v1SessionCatalog.lookupBuiltinOrTempFunction(parts.head, arguments).map( - processFunctionExpr(_, arguments, isDistinct, filter, ignoreNulls) - ) - } else { - None + expandIdentifier(nameParts) match { + case NonSessionCatalogAndIdentifier(catalog, ident) => + if (!catalog.isFunctionCatalog) { + throw new AnalysisException(s"Trying to lookup function '$ident' in " + + s"catalog '${catalog.name()}', but it is not a FunctionCatalog.") + } + lookupV2Function(catalog.asFunctionCatalog, ident, arguments, isDistinct, + filter, ignoreNulls) + case _ => u } - - resultExpression.getOrElse( - expandIdentifier(parts) match { - case NonSessionCatalogAndIdentifier(catalog, ident) => - if (!catalog.isFunctionCatalog) { - throw new AnalysisException(s"Trying to lookup function '$ident' in catalog" + - s" '${catalog.name()}', but it is not a FunctionCatalog.") - } - lookupV2Function(catalog.asFunctionCatalog, ident, arguments, isDistinct, - filter, ignoreNulls) - case _ => u - } - ) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index 461d5864cf1d0..dcd352267a178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -162,6 +162,9 @@ private[sql] trait LookupCatalog extends Logging { case _ => None } parts match { + case Seq(name) + if catalogManager.v1SessionCatalog.isRegisteredFunction(FunctionIdentifier(name)) => + Some(FunctionIdentifier(name)) case CatalogAndMultipartIdentifier(None, names) if CatalogV2Util.isSessionCatalog(currentCatalog) => namesToFunctionIdentifier(names) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index ceac38330e18f..fe856ffecb84a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -66,6 +66,11 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } } + test("looking up higher-order function with non-session catalog") { + checkAnswer(sql("SELECT transform(array(1, 2, 3), x -> x + 1)"), + Row(Array(2, 3, 4)) :: Nil) + } + test("built-in override with default v2 function catalog") { // a built-in function with the same name should take higher priority withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { @@ -74,6 +79,11 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { } } + test("built-in override with non-session catalog") { + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + test("temp function override with default v2 function catalog") { val className = "test.org.apache.spark.sql.JavaStringLength" sql(s"CREATE FUNCTION length AS '$className'") From c453b64ae669c2be28a9467be8b5111a7563f2ab Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 26 Apr 2021 23:38:53 -0700 Subject: [PATCH 19/21] remove unnecessary changes --- .../sql/catalyst/analysis/Analyzer.scala | 300 ++++++++---------- .../sql/catalyst/catalog/SessionCatalog.scala | 50 +-- 2 files changed, 152 insertions(+), 198 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2dc0209514341..8c4a43c73b5d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2025,10 +2025,85 @@ class Analyzer(override val catalogManager: CatalogManager) } } - case u @ UnresolvedFunction(AsFunctionIdentifier(ident), - arguments, isDistinct, filter, ignoreNulls) => withPosition(u) { - processFunctionExpr(v1SessionCatalog.lookupFunction(ident, arguments), - arguments, isDistinct, filter, ignoreNulls) + case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter, + ignoreNulls) => withPosition(u) { + v1SessionCatalog.lookupFunction(ident, arguments) match { + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, "FILTER clause") + } else if (ignoreNulls) { + wf match { + case nthValue: NthValue => + nthValue.copy(ignoreNulls = ignoreNulls) + case _ => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, "IGNORE NULLS") + } + } else { + wf + } + case owf: FrameLessOffsetWindowFunction => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + owf.prettyName, "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + owf.prettyName, "FILTER clause") + } else if (ignoreNulls) { + owf match { + case lead: Lead => + lead.copy(ignoreNulls = ignoreNulls) + case lag: Lag => + lag.copy(ignoreNulls = ignoreNulls) + } + } else { + owf + } + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => + if (filter.isDefined && !filter.get.deterministic) { + throw QueryCompilationErrors.nonDeterministicFilterInAggregateError + } + if (ignoreNulls) { + val aggFunc = agg match { + case first: First => first.copy(ignoreNulls = ignoreNulls) + case last: Last => last.copy(ignoreNulls = ignoreNulls) + case _ => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + agg.prettyName, "IGNORE NULLS") + } + AggregateExpression(aggFunc, Complete, isDistinct, filter) + } else { + AggregateExpression(agg, Complete, isDistinct, filter) + } + // This function is not an aggregate function, just return the resolved one. + case other if isDistinct => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + other.prettyName, "DISTINCT") + case other if filter.isDefined => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + other.prettyName, "FILTER clause") + case other if ignoreNulls => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + other.prettyName, "IGNORE NULLS") + case e: String2TrimExpression if arguments.size == 2 => + if (trimWarningEnabled.get) { + log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + + " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? 
trimStr FROM str)`" + + " instead.") + trimWarningEnabled.set(false) + } + e + case other => + other + } } case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) => @@ -2039,8 +2114,68 @@ class Analyzer(override val catalogManager: CatalogManager) throw new AnalysisException(s"Trying to lookup function '$ident' in " + s"catalog '${catalog.name()}', but it is not a FunctionCatalog.") } - lookupV2Function(catalog.asFunctionCatalog, ident, arguments, isDistinct, - filter, ignoreNulls) + + val unbound = catalog.asFunctionCatalog.loadFunction(ident) + val inputType = StructType(arguments.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + val bound = try { + unbound.bind(inputType) + } catch { + case unsupported: UnsupportedOperationException => + throw new AnalysisException(s"Function '${unbound.name}' cannot process " + + s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " + + unsupported.getMessage, cause = Some(unsupported)) + } + + bound match { + case scalarFunc: ScalarFunction[_] => + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "FILTER clause") + } else if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "IGNORE NULLS") + } else { + // TODO: implement type coercion by looking at input type from the UDF. We + // may also want to check if the parameter types from the magic method + // match the input type through `BoundFunction.inputTypes`. + val argClasses = inputType.fields.map(_.dataType) + findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { + case Some(_) => + val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) + Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), + arguments, returnNullable = scalarFunc.isResultNullable) + case _ => + // TODO: handle functions defined in Scala too - in Scala, even if a + // subclass do not override the default method in parent interface + // defined in Java, the method can still be found from + // `getDeclaredMethod`. + // since `inputType` is a `StructType`, it is mapped to a `InternalRow` + // which we can use to lookup the `produceResult` method. + findMethod(scalarFunc, "produceResult", Seq(inputType)) match { + case Some(_) => + ApplyFunctionExpression(scalarFunc, arguments) + case None => + failAnalysis(s"ScalarFunction '${bound.name()}' neither implement" + + s" magic method nor override 'produceResult'") + } + } + } + case aggFunc: V2AggregateFunction[_, _] => + if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + aggFunc.name(), "IGNORE NULLS") + } + val aggregator = V2Aggregator(aggFunc, arguments) + AggregateExpression(aggregator, Complete, isDistinct, filter) + case _ => + failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" + + s" or AggregateFunction") + } case _ => u } } @@ -2064,159 +2199,6 @@ class Analyzer(override val catalogManager: CatalogManager) None } } - - private def processFunctionExpr( - expr: Expression, - arguments: Seq[Expression], - isDistinct: Boolean, - filter: Option[Expression], - ignoreNulls: Boolean): Expression = expr match { - // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within - // the context of a Window clause. 
They do not need to be wrapped in an - // AggregateExpression. - case wf: AggregateWindowFunction => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "DISTINCT") - } else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "FILTER clause") - } else if (ignoreNulls) { - wf match { - case nthValue: NthValue => - nthValue.copy(ignoreNulls = ignoreNulls) - case _ => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "IGNORE NULLS") - } - } else { - wf - } - case owf: FrameLessOffsetWindowFunction => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - owf.prettyName, "DISTINCT") - } else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - owf.prettyName, "FILTER clause") - } else if (ignoreNulls) { - owf match { - case lead: Lead => - lead.copy(ignoreNulls = ignoreNulls) - case lag: Lag => - lag.copy(ignoreNulls = ignoreNulls) - } - } else { - owf - } - // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg: AggregateFunction => - if (filter.isDefined && !filter.get.deterministic) { - throw QueryCompilationErrors.nonDeterministicFilterInAggregateError - } - if (ignoreNulls) { - val aggFunc = agg match { - case first: First => first.copy(ignoreNulls = ignoreNulls) - case last: Last => last.copy(ignoreNulls = ignoreNulls) - case _ => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - agg.prettyName, "IGNORE NULLS") - } - AggregateExpression(aggFunc, Complete, isDistinct, filter) - } else { - AggregateExpression(agg, Complete, isDistinct, filter) - } - // This function is not an aggregate function, just return the resolved one. - case other if isDistinct => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - other.prettyName, "DISTINCT") - case other if filter.isDefined => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - other.prettyName, "FILTER clause") - case other if ignoreNulls => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - other.prettyName, "IGNORE NULLS") - case e: String2TrimExpression if arguments.size == 2 => - if (trimWarningEnabled.get) { - log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + - " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? 
trimStr FROM str)`" + - " instead.") - trimWarningEnabled.set(false) - } - e - case other => - other - } - - private def lookupV2Function( - catalog: FunctionCatalog, - ident: Identifier, - arguments: Seq[Expression], - isDistinct: Boolean, - filter: Option[Expression], - ignoreNulls: Boolean): Expression = { - val unbound = catalog.loadFunction(ident) - val inputType = StructType(arguments.zipWithIndex.map { - case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) - }) - val bound = try { - unbound.bind(inputType) - } catch { - case unsupported: UnsupportedOperationException => - throw new AnalysisException(s"Function '${unbound.name}' cannot process input: " + - s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " + - unsupported.getMessage, cause = Some(unsupported)) - } - - bound match { - case scalarFunc: ScalarFunction[_] => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "DISTINCT") - } else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "FILTER clause") - } else if (ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "IGNORE NULLS") - } else { - // TODO: implement type coercion by looking at input type from the UDF. We may - // also want to check if the parameter types from the magic method match the - // input type through `BoundFunction.inputTypes`. - val argClasses = inputType.fields.map(_.dataType) - findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { - case Some(_) => - val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) - Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), - arguments, returnNullable = scalarFunc.isResultNullable) - case _ => - // TODO: handle functions defined in Scala too - in Scala, even if a - // subclass do not override the default method in parent interface - // defined in Java, the method can still be found from - // `getDeclaredMethod`. - // since `inputType` is a `StructType`, it is mapped to a `InternalRow` which we - // can use to lookup the `produceResult` method. - findMethod(scalarFunc, "produceResult", Seq(inputType)) match { - case Some(_) => - ApplyFunctionExpression(scalarFunc, arguments) - case None => - failAnalysis(s"ScalarFunction '${bound.name()}' neither implement " + - s"magic method nor override 'produceResult'") - } - } - } - case aggFunc: V2AggregateFunction[_, _] => - if (ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - aggFunc.name(), "IGNORE NULLS") - } - val aggregator = V2Aggregator(aggFunc, arguments) - AggregateExpression(aggregator, Complete, isDistinct, filter) - case _ => - failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction " + - s"or AggregateFunction") - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 94bd540613ce8..0813d41af1617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1559,34 +1559,6 @@ class SessionCatalog( } } - /** - * Lookup `registry` and check if a built-in or temporary function is defined for the input - * `name`. None if no such function exists. 
- * - * This is currently used by both V1 function lookup (in `lookupFunction`), and V2 - * function lookup (in `Analyzer`). - */ - private def lookupBuiltinOrTempFunctionInfo[T]( - name: String, - children: Seq[Expression], - registry: FunctionRegistryBase[T]): Option[T] = synchronized { - val ident = FunctionIdentifier(name) - if (registry.functionExists(ident)) { - val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames - val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty - // Lookup the function as a temporary or a built-in function (i.e. without database) and - // 1. if we are not resolving view, we don't care about the function type and just return it. - // 2. if we are resolving view, only return a temp function if it's referred by this view. - if (!isResolvingView || - !isTemporaryFunction(ident) || - referredTempFunctionNames.contains(ident.funcName)) { - // This function has been already loaded into the function registry. - return Some(registry.lookupFunction(ident, children)) - } - } - None - } - /** * Look up a specific function, assuming it exists. * @@ -1609,10 +1581,17 @@ class SessionCatalog( // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). - if (name.database.isEmpty) { - val funcInfo = lookupBuiltinOrTempFunctionInfo(name.funcName, children, registry) - if (funcInfo.isDefined) { - return funcInfo.get + if (name.database.isEmpty && registry.functionExists(name)) { + val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames + val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty + // Lookup the function as a temporary or a built-in function (i.e. without database) and + // 1. if we are not resolving view, we don't care about the function type and just return it. + // 2. if we are resolving view, only return a temp function if it's referred by this view. + if (!isResolvingView || + !isTemporaryFunction(name) || + referredTempFunctionNames.contains(name.funcName)) { + // This function has been already loaded into the function registry. + return registry.lookupFunction(name, children) } } @@ -1669,13 +1648,6 @@ class SessionCatalog( lookupFunction[LogicalPlan](name, children, tableFunctionRegistry) } - /** - * Looks up a built-in or temporary function with the given `name`. Returns `None` if the - * function doesn't exist. - */ - def lookupBuiltinOrTempFunction(name: String, children: Seq[Expression]): Option[Expression] = { - lookupBuiltinOrTempFunctionInfo[Expression](name, children, functionRegistry) - } /** * List all functions in the specified database, including temporary functions. 
This * returns the function identifier and the scope in which it was defined (system or user From 790d27f900c0ed456c198a5175410961223990f8 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 27 Apr 2021 00:28:52 -0700 Subject: [PATCH 20/21] fix test failure --- .../spark/sql/catalyst/analysis/LookupFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala index 92d3b517d0282..85e0b1062c81f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -71,7 +71,7 @@ class LookupFunctionsSuite extends PlanTest { table("TaBlE")) analyzer.LookupFunctions.apply(plan) - assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 4) assert(analyzer.LookupFunctions.normalizeFuncName (unresolvedRegisteredFunc.nameParts.asFunctionIdentifier).database == Some("default")) } From c18715fcc304f4a805d420e193d4a3884b3f7e4f Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 27 Apr 2021 12:23:19 -0700 Subject: [PATCH 21/21] address more comments --- .../catalog/functions/AggregateFunction.java | 2 + .../sql/catalyst/analysis/Analyzer.scala | 107 +++++++++++------- 2 files changed, 65 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java index d4af2f547d300..4181feafed101 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java @@ -29,6 +29,8 @@ * and update the aggregation state. The JVM type of result values produced by * {@link #produceResult} must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + * Please refer to class documentation of {@link ScalarFunction} for the mapping between + * {@link DataType} and the JVM type. *

    * All implementations must support partial aggregation by implementing merge so that Spark can * partially aggregate and shuffle intermediate results, instead of shuffling all rows for an diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8c4a43c73b5d0..1bca8f24538b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1975,9 +1975,6 @@ class Analyzer(override val catalogManager: CatalogManager) ident.funcName) } } - case f: UnresolvedFunction => - // v2 functions - do nothing for now - f } } @@ -2130,58 +2127,80 @@ class Analyzer(override val catalogManager: CatalogManager) bound match { case scalarFunc: ScalarFunction[_] => - if (isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "DISTINCT") - } else if (filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "FILTER clause") - } else if (ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "IGNORE NULLS") - } else { - // TODO: implement type coercion by looking at input type from the UDF. We - // may also want to check if the parameter types from the magic method - // match the input type through `BoundFunction.inputTypes`. - val argClasses = inputType.fields.map(_.dataType) - findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { - case Some(_) => - val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) - Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), - arguments, returnNullable = scalarFunc.isResultNullable) - case _ => - // TODO: handle functions defined in Scala too - in Scala, even if a - // subclass do not override the default method in parent interface - // defined in Java, the method can still be found from - // `getDeclaredMethod`. - // since `inputType` is a `StructType`, it is mapped to a `InternalRow` - // which we can use to lookup the `produceResult` method. 
- findMethod(scalarFunc, "produceResult", Seq(inputType)) match { - case Some(_) => - ApplyFunctionExpression(scalarFunc, arguments) - case None => - failAnalysis(s"ScalarFunction '${bound.name()}' neither implement" + - s" magic method nor override 'produceResult'") - } - } - } + processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct, + filter, ignoreNulls) case aggFunc: V2AggregateFunction[_, _] => - if (ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - aggFunc.name(), "IGNORE NULLS") - } - val aggregator = V2Aggregator(aggFunc, arguments) - AggregateExpression(aggregator, Complete, isDistinct, filter) + processV2AggregateFunction(aggFunc, arguments, isDistinct, filter, + ignoreNulls) case _ => failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" + s" or AggregateFunction") } + case _ => u } } } } + private def processV2ScalarFunction( + scalarFunc: ScalarFunction[_], + inputType: StructType, + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression], + ignoreNulls: Boolean): Expression = { + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "FILTER clause") + } else if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "IGNORE NULLS") + } else { + // TODO: implement type coercion by looking at input type from the UDF. We + // may also want to check if the parameter types from the magic method + // match the input type through `BoundFunction.inputTypes`. + val argClasses = inputType.fields.map(_.dataType) + findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { + case Some(_) => + val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) + Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), + arguments, returnNullable = scalarFunc.isResultNullable) + case _ => + // TODO: handle functions defined in Scala too - in Scala, even if a + // subclass do not override the default method in parent interface + // defined in Java, the method can still be found from + // `getDeclaredMethod`. + // since `inputType` is a `StructType`, it is mapped to a `InternalRow` + // which we can use to lookup the `produceResult` method. + findMethod(scalarFunc, "produceResult", Seq(inputType)) match { + case Some(_) => + ApplyFunctionExpression(scalarFunc, arguments) + case None => + failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" + + s" magic method nor override 'produceResult'") + } + } + } + } + + private def processV2AggregateFunction( + aggFunc: V2AggregateFunction[_, _], + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression], + ignoreNulls: Boolean): Expression = { + if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + aggFunc.name(), "IGNORE NULLS") + } + val aggregator = V2Aggregator(aggFunc, arguments) + AggregateExpression(aggregator, Complete, isDistinct, filter) + } + /** * Check if the input `fn` implements the given `methodName` with parameter types specified * via `inputType`.
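
For context on the scalar path above: the analyzer loads the UnboundFunction from the catalog, binds it to a StructType built from the argument types, and then prefers the magic method (MAGIC_METHOD_NAME, i.e. "invoke") through an Invoke expression, falling back to ApplyFunctionExpression over produceResult. The sketch below is illustrative only and is not part of this patch; the StrLen/BoundStrLen names are made up (the JavaStrLen test helper added by the patch plays the same role).

// Illustrative sketch only -- not part of this patch; class names are hypothetical.
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

public class StrLen implements UnboundFunction {
  @Override
  public BoundFunction bind(StructType inputType) {
    // A failed bind() is what the analyzer's UnsupportedOperationException handling reports.
    if (inputType.fields().length != 1 ||
        !(inputType.fields()[0].dataType() instanceof StringType)) {
      throw new UnsupportedOperationException("strlen expects exactly one string argument");
    }
    return new BoundStrLen();
  }

  @Override
  public String description() {
    return "strlen(string) -> int";
  }

  @Override
  public String name() {
    return "strlen";
  }

  static class BoundStrLen implements ScalarFunction<Integer> {
    @Override
    public DataType[] inputTypes() {
      return new DataType[] { DataTypes.StringType };
    }

    @Override
    public DataType resultType() {
      return DataTypes.IntegerType;
    }

    @Override
    public String name() {
      return "strlen";
    }

    // Magic method: findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) locates this,
    // so the analyzer emits Invoke(...) and calls it directly, one argument per column.
    public int invoke(UTF8String str) {
      return str.numChars();
    }

    // Row-based fallback used by ApplyFunctionExpression when no magic method is found.
    @Override
    public Integer produceResult(InternalRow input) {
      return input.getUTF8String(0).numChars();
    }
  }
}

Registered under a FunctionCatalog (say, identifier ns.strlen in catalog testcat), a query such as SELECT testcat.ns.strlen('abc') would then resolve through the NonSessionCatalogAndIdentifier case shown above.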
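The aggregate path is thinner: processV2AggregateFunction wraps the bound V2 AggregateFunction in a V2Aggregator and then an AggregateExpression, so the connector function itself supplies the state handling that the AggregateFunction javadoc amended earlier in this series describes (per-row update, partition-level merge, final produceResult). A rough sketch follows, again illustrative only and with made-up names (the JavaAverage test helper added by the patch is the real analogue), assuming the AggregateFunction<S, R> shape with newAggregationState/update/merge/produceResult:

// Illustrative sketch only -- not part of this patch; class names are hypothetical.
// In practice this bound function would be returned from an UnboundFunction's bind(),
// mirroring the scalar sketch above.
import java.io.Serializable;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

public class IntAverage implements AggregateFunction<IntAverage.State, Double> {

  // Aggregation buffer carried between update/merge calls; must be Serializable.
  public static class State implements Serializable {
    long sum = 0L;
    long count = 0L;
  }

  @Override
  public String name() {
    return "iavg";
  }

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.IntegerType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.DoubleType;
  }

  @Override
  public State newAggregationState() {
    return new State();
  }

  // Called once per input row to fold the row into the partial state.
  @Override
  public State update(State state, InternalRow input) {
    if (!input.isNullAt(0)) {
      state.sum += input.getInt(0);
      state.count += 1;
    }
    return state;
  }

  // Required so Spark can aggregate per partition and combine partial results.
  @Override
  public State merge(State left, State right) {
    left.sum += right.sum;
    left.count += right.count;
    return left;
  }

  @Override
  public Double produceResult(State state) {
    return state.count == 0 ? null : (double) state.sum / state.count;
  }
}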