diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java index 6982ebb329ff3..4181feafed101 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java @@ -25,13 +25,12 @@ /** * Interface for a function that produces a result value by aggregating over multiple input rows. *

- * For each input row, Spark will call an update method that corresponds to the - * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by - * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call - * update with {@link InternalRow}. - *

- * The JVM type of result values produced by this function must be the type used by Spark's + * For each input row, Spark will call the {@link #update} method which should evaluate the row + * and update the aggregation state. The JVM type of result values produced by + * {@link #produceResult} must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + * Please refer to the class documentation of {@link ScalarFunction} for the mapping between + * {@link DataType} and the JVM type. + *

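As a rough illustration of the contract above (and of the merge requirement described next), here is a minimal Scala sketch of a complete implementation; the name LongSum and its semantics are hypothetical, not part of this patch:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical aggregate summing non-null BIGINT inputs. The intermediate
// state must be java.io.Serializable, hence the boxed java.lang.Long.
object LongSum extends AggregateFunction[java.lang.Long, java.lang.Long] {
  override def name(): String = "lsum"
  override def inputTypes(): Array[DataType] = Array(LongType)
  override def resultType(): DataType = LongType

  override def newAggregationState(): java.lang.Long = 0L

  // Called once per input row; must tolerate null inputs.
  override def update(state: java.lang.Long, input: InternalRow): java.lang.Long =
    if (input.isNullAt(0)) state else state + input.getLong(0)

  // Combines two partial states, which is what enables partial aggregation.
  override def merge(left: java.lang.Long, right: java.lang.Long): java.lang.Long =
    left + right

  // The JVM result type matches LongType's InternalRow representation.
  override def produceResult(state: java.lang.Long): java.lang.Long = state
}
```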
* All implementations must support partial aggregation by implementing merge so that Spark can * partially aggregate and shuffle intermediate results, instead of shuffling all rows for an @@ -68,9 +67,7 @@ public interface AggregateFunction<S extends Serializable, R> extends BoundFunction * @param input an input row * @return updated aggregation state */ - default S update(S state, InternalRow input) { - throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update"); - } + S update(S state, InternalRow input); /** * Merge two partial aggregation states. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index c2106a21c4a8f..ef755aae3fb07 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -23,17 +23,67 @@ /** * Interface for a function that produces a result value for each input row. *

- * For each input row, Spark will call a produceResult method that corresponds to the - * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by - * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call - * {@link #produceResult(InternalRow)}. + * To evaluate each input row, Spark will first try to look up and use a "magic method" (described + * below) through Java reflection. If the method is not found, Spark will call + * {@link #produceResult(InternalRow)} as a fallback approach. *

* The JVM type of result values produced by this function must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + *

+ * IMPORTANT: the default implementation of {@link #produceResult} throws + * {@link UnsupportedOperationException}. Users can choose to override this method, or implement + * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters + * instead of an {@link InternalRow}. The magic method will be loaded by Spark through Java + * reflection and will also provide better performance in general, due to optimizations such as + * codegen and the removal of Java boxing. + * + * For example, a scalar UDF for adding two integers can be defined as follows with the magic + * method approach: + * + * <pre>

+ *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
+ *     public int invoke(int left, int right) {
+ *       return left + right;
+ *     }
+ *   }
+ * </pre>
+ * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over + * {@link #produceResult} to evaluate the inputs. In general, Spark looks up the magic method by + * first converting the actual input SQL data types to their corresponding Java types following + * the mapping defined below, and then checking if there is a matching method among all the + * declared methods in the UDF class, using the method name (i.e., {@link #MAGIC_METHOD_NAME}) and + * the Java types. If no magic method is found, Spark falls back to {@link #produceResult}. + *

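To make this lookup-or-fallback behavior concrete, here is a Scala sketch of a UDF that provides both paths (the name StrLenBoth is hypothetical; the test suite later in this patch exercises the same pattern):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, DataTypes}
import org.apache.spark.unsafe.types.UTF8String

object StrLenBoth extends ScalarFunction[Integer] {
  override def name(): String = "strlen_both"
  override def inputTypes(): Array[DataType] = Array(DataTypes.StringType)
  override def resultType(): DataType = DataTypes.IntegerType

  // Magic method: a StringType argument arrives as UTF8String, its
  // InternalRow representation, so this signature is found by reflection.
  def invoke(str: UTF8String): Int = str.toString.length

  // Fallback used when no matching magic method is found (or codegen is off).
  override def produceResult(input: InternalRow): Integer =
    input.getString(0).length
}
```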
+ * The following is the mapping from {@link DataType SQL data type} to Java type for + * the magic method approach: + *

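For orientation, a few representative entries of that mapping, following the JVM types Spark's InternalRow API uses (a sketch, not the complete table):

```scala
// SQL data type   -> JVM type expected by the magic method
//   BooleanType   -> boolean
//   IntegerType   -> int
//   LongType      -> long
//   DoubleType    -> double
//   DateType      -> int  (days since the epoch)
//   TimestampType -> long (microseconds since the epoch)
//   StringType    -> org.apache.spark.unsafe.types.UTF8String
//   BinaryType    -> byte[]
//   ArrayType     -> org.apache.spark.sql.catalyst.util.ArrayData
//   MapType       -> org.apache.spark.sql.catalyst.util.MapData
//   StructType    -> org.apache.spark.sql.catalyst.InternalRow
```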
* * @param <R> the JVM type of result values */ public interface ScalarFunction<R> extends BoundFunction { + String MAGIC_METHOD_NAME = "invoke"; /** * Applies the function to an input row to produce a value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 87b8d52ac277f..1bca8f24538b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Method import java.util import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean @@ -29,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes -import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _} +import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -44,6 +45,8 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} +import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction} +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -281,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables :: - ResolveHigherOrderFunctions(v1SessionCatalog) :: + ResolveHigherOrderFunctions(catalogManager) :: ResolveLambdaVariables :: ResolveTimeZone :: ResolveRandomSeed :: @@ -895,9 +898,10 @@ class Analyzer(override val catalogManager: CatalogManager) } } - // If we are resolving relations insides views, we need to expand single-part relation names with - // the current catalog and namespace of when the view was created. - private def expandRelationName(nameParts: Seq[String]): Seq[String] = { + // If we are resolving database objects (relations, functions, etc.) inside views, we may need to + // expand single or multi-part identifiers with the current catalog and namespace from when the + // view was created.
+ private def expandIdentifier(nameParts: Seq[String]): Seq[String] = { if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts if (nameParts.length == 1) { @@ -1040,7 +1044,7 @@ class Analyzer(override val catalogManager: CatalogManager) identifier: Seq[String], options: CaseInsensitiveStringMap, isStreaming: Boolean): Option[LogicalPlan] = - expandRelationName(identifier) match { + expandIdentifier(identifier) match { case NonSessionCatalogAndIdentifier(catalog, ident) => CatalogV2Util.loadTable(catalog, ident) match { case Some(table) => @@ -1153,7 +1157,7 @@ class Analyzer(override val catalogManager: CatalogManager) } private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = { - expandRelationName(identifier) match { + expandIdentifier(identifier) match { case SessionCatalogAndIdentifier(catalog, ident) => CatalogV2Util.loadTable(catalog, ident).map { case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW => @@ -1173,7 +1177,7 @@ class Analyzer(override val catalogManager: CatalogManager) identifier: Seq[String], options: CaseInsensitiveStringMap, isStreaming: Boolean): Option[LogicalPlan] = { - expandRelationName(identifier) match { + expandIdentifier(identifier) match { case SessionCatalogAndIdentifier(catalog, ident) => lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map { case v1Table: V1Table => @@ -1569,8 +1573,7 @@ class Analyzer(override val catalogManager: CatalogManager) // results and confuse users if there is any null values. For count(t1.*, t2.*), it is // still allowed, since it's well-defined in spark. if (!conf.allowStarWithSingleTableIdentifierInCount && - f1.name.database.isEmpty && - f1.name.funcName == "count" && + f1.nameParts == Seq("count") && f1.arguments.length == 1) { f1.arguments.foreach { case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) => @@ -1958,17 +1961,19 @@ class Analyzer(override val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() plan.resolveExpressions { - case f: UnresolvedFunction - if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f - case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f - case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) => - externalFunctionNameSet.add(normalizeFuncName(f.name)) - f - case f: UnresolvedFunction => - withPosition(f) { - throw new NoSuchFunctionException( - f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase), - f.name.funcName) + case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) => + if (externalFunctionNameSet.contains(normalizeFuncName(ident)) || + v1SessionCatalog.isRegisteredFunction(ident)) { + f + } else if (v1SessionCatalog.isPersistentFunction(ident)) { + externalFunctionNameSet.add(normalizeFuncName(ident)) + f + } else { + withPosition(f) { + throw new NoSuchFunctionException( + ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase), + ident.funcName) + } } } } @@ -2016,9 +2021,10 @@ class Analyzer(override val catalogManager: CatalogManager) name, other.getClass.getCanonicalName) } } - case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) => - withPosition(u) { - v1SessionCatalog.lookupFunction(funcId, arguments) match { + + case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter, + ignoreNulls) => withPosition(u) { + 
v1SessionCatalog.lookupFunction(ident, arguments) match { // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. @@ -2095,9 +2101,123 @@ class Analyzer(override val catalogManager: CatalogManager) case other => other } + } + + case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) => + withPosition(u) { + expandIdentifier(nameParts) match { + case NonSessionCatalogAndIdentifier(catalog, ident) => + if (!catalog.isFunctionCatalog) { + throw new AnalysisException(s"Trying to look up function '$ident' in " + + s"catalog '${catalog.name()}', but it is not a FunctionCatalog.") + } + + val unbound = catalog.asFunctionCatalog.loadFunction(ident) + val inputType = StructType(arguments.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + val bound = try { + unbound.bind(inputType) + } catch { + case unsupported: UnsupportedOperationException => + throw new AnalysisException(s"Function '${unbound.name}' cannot process " + + s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " + + unsupported.getMessage, cause = Some(unsupported)) + } + + bound match { + case scalarFunc: ScalarFunction[_] => + processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct, + filter, ignoreNulls) + case aggFunc: V2AggregateFunction[_, _] => + processV2AggregateFunction(aggFunc, arguments, isDistinct, filter, + ignoreNulls) + case _ => + failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" + + s" or AggregateFunction") + } + + case _ => u + } } } } + + private def processV2ScalarFunction( + scalarFunc: ScalarFunction[_], + inputType: StructType, + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression], + ignoreNulls: Boolean): Expression = { + if (isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "DISTINCT") + } else if (filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "FILTER clause") + } else if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), "IGNORE NULLS") + } else { + // TODO: implement type coercion by looking at input type from the UDF. We + // may also want to check if the parameter types from the magic method + // match the input type through `BoundFunction.inputTypes`. + val argClasses = inputType.fields.map(_.dataType) + findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { + case Some(_) => + val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) + Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), + arguments, returnNullable = scalarFunc.isResultNullable) + case _ => + // TODO: handle functions defined in Scala too - in Scala, even if a + // subclass does not override the default method in parent interface + // defined in Java, the method can still be found from + // `getDeclaredMethod`. + // since `inputType` is a `StructType`, it is mapped to an `InternalRow` + // which we can use to look up the `produceResult` method.
+ findMethod(scalarFunc, "produceResult", Seq(inputType)) match { + case Some(_) => + ApplyFunctionExpression(scalarFunc, arguments) + case None => + failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implements" + + s" the magic method nor overrides 'produceResult'") + } + } + } + + private def processV2AggregateFunction( + aggFunc: V2AggregateFunction[_, _], + arguments: Seq[Expression], + isDistinct: Boolean, + filter: Option[Expression], + ignoreNulls: Boolean): Expression = { + if (ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + aggFunc.name(), "IGNORE NULLS") + } + val aggregator = V2Aggregator(aggFunc, arguments) + AggregateExpression(aggregator, Complete, isDistinct, filter) + } + + /** + * Check if the input `fn` implements the given `methodName` with parameter types specified + * via `inputType`. + */ + private def findMethod( + fn: BoundFunction, + methodName: String, + inputType: Seq[DataType]): Option[Method] = { + val cls = fn.getClass + try { + val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass) + Some(cls.getDeclaredMethod(methodName, argClasses: _*)) + } catch { + case _: NoSuchMethodException => + None + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 7d74c0d1cd14f..b86d5d75a4550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType @@ -30,13 +30,14 @@ import org.apache.spark.sql.types.DataType * so we need to resolve higher order function when all children are either resolved or a lambda * function.
*/ -case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { +case class ResolveHigherOrderFunctions(catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case u @ UnresolvedFunction(fn, children, false, filter, ignoreNulls) + case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls) if hasLambdaAndResolvedArguments(children) => withPosition(u) { - catalog.lookupFunction(fn, children) match { + catalogManager.v1SessionCatalog.lookupFunction(ident, children) match { case func: HigherOrderFunction => filter.foreach(_.failAnalysis("FILTER predicate specified, " + s"but ${func.prettyName} is not an aggregate function")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 5001e2ea88ac7..6fcde63ca225c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -269,12 +269,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio } case class UnresolvedFunction( - name: FunctionIdentifier, + nameParts: Seq[String], arguments: Seq[Expression], isDistinct: Boolean, filter: Option[Expression] = None, ignoreNulls: Boolean = false) extends Expression with Unevaluable { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def children: Seq[Expression] = arguments ++ filter.toSeq @@ -282,10 +283,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false - override def prettyName: String = name.unquotedString + override def prettyName: String = nameParts.quoted override def toString: String = { val distinct = if (isDistinct) "distinct " else "" - s"'$name($distinct${children.mkString(", ")})" + s"'${nameParts.quoted}($distinct${children.mkString(", ")})" } override protected def withNewChildrenInternal( @@ -299,8 +300,17 @@ case class UnresolvedFunction( } object UnresolvedFunction { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def apply( + name: FunctionIdentifier, + arguments: Seq[Expression], + isDistinct: Boolean): UnresolvedFunction = { + UnresolvedFunction(name.asMultipart, arguments, isDistinct) + } + def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = { - UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct) + UnresolvedFunction(Seq(name), arguments, isDistinct) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index d259d6a706d72..0813d41af1617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1576,6 +1576,8 @@ class SessionCatalog( name: FunctionIdentifier, children: Seq[Expression], registry: FunctionRegistryBase[T]): T = synchronized { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + // Note: the implementation of this function is a little bit convoluted. 
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). @@ -1598,7 +1600,8 @@ class SessionCatalog( case Seq() => getCurrentDatabase case Seq(_, db) => db case Seq(catalog, namespace @ _*) => - throw QueryCompilationErrors.v2CatalogNotSupportFunctionError(catalog, namespace) + throw new IllegalStateException(s"[BUG] unexpected v2 catalog: $catalog, and " + + s"namespace: ${namespace.quoted} in v1 function lookup") } // If the name itself is not qualified, add the current database to it. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala new file mode 100644 index 0000000000000..ce9c933902942 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction +import org.apache.spark.sql.types.DataType + +case class ApplyFunctionExpression( + function: ScalarFunction[_], + children: Seq[Expression]) extends Expression with UserDefinedExpression with CodegenFallback { + override def nullable: Boolean = function.isResultNullable + override def name: String = function.name() + override def dataType: DataType = function.resultType() + + private lazy val reusedRow = new GenericInternalRow(children.size) + + /** Returns the result of evaluating this expression on a given input Row */ + override def eval(input: InternalRow): Any = { + children.zipWithIndex.foreach { + case (expr, pos) => + reusedRow.update(pos, expr.eval(input)) + } + + function.produceResult(reusedRow) + } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala new file mode 100644 index 0000000000000..55e3f504ae2e6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection} +import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction} +import org.apache.spark.sql.types.DataType + +case class V2Aggregator[BUF <: java.io.Serializable, OUT]( + aggrFunc: V2AggregateFunction[BUF, OUT], + children: Seq[Expression], + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] { + private[this] lazy val inputProjection = UnsafeProjection.create(children) + + override def nullable: Boolean = aggrFunc.isResultNullable + override def dataType: DataType = aggrFunc.resultType() + override def createAggregationBuffer(): BUF = aggrFunc.newAggregationState() + + override def update(buffer: BUF, input: InternalRow): BUF = { + aggrFunc.update(buffer, inputProjection(input)) + } + + override def merge(buffer: BUF, input: BUF): BUF = aggrFunc.merge(buffer, input) + + override def eval(buffer: BUF): Any = { + aggrFunc.produceResult(buffer) + } + + override def serialize(buffer: BUF): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val out = new ObjectOutputStream(bos) + out.writeObject(buffer) + out.close() + bos.toByteArray + } + + override def deserialize(bytes: Array[Byte]): BUF = { + val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) + in.readObject().asInstanceOf[BUF] + } + + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): V2Aggregator[BUF, OUT] = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): V2Aggregator[BUF, OUT] = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd21beafbd85f..da7c5c0221bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1818,7 +1818,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val ignoreNulls = Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false) val function = UnresolvedFunction( - getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter, ignoreNulls) + getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls) // Check if the function is evaluated in a windowed context. 
ctx.windowSpec match { @@ -1830,15 +1830,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } - - /** - * Create a function database (optional) and name pair, for multipartIdentifier. - * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. - */ - protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = { - visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq) - } - /** * Create a function database (optional) and name pair. */ @@ -1869,6 +1860,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } + protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = { + if (ctx.qualifiedName != null) { + ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq + } else { + Seq(ctx.getText) + } + } + /** * Create an [[LambdaFunction]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index be9b94c606196..cc41d8ca9007f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -83,12 +83,36 @@ private[sql] object CatalogV2Implicits { throw new AnalysisException( s"Cannot use catalog ${plugin.name}: does not support namespaces") } + + def isFunctionCatalog: Boolean = plugin match { + case _: FunctionCatalog => true + case _ => false + } + + def asFunctionCatalog: FunctionCatalog = plugin match { + case functionCatalog: FunctionCatalog => + functionCatalog + case _ => + throw new AnalysisException( + s"Cannot use catalog '${plugin.name}': not a FunctionCatalog") + } } implicit class NamespaceHelper(namespace: Array[String]) { def quoted: String = namespace.map(quoteIfNeeded).mkString(".") } + implicit class FunctionIdentifierHelper(ident: FunctionIdentifier) { + def asMultipart: Seq[String] = { + ident.database match { + case Some(db) => + Seq(db, ident.funcName) + case _ => + Seq(ident.funcName) + } + } + } + implicit class IdentifierHelper(ident: Identifier) { def quoted: String = { if (ident.namespace.nonEmpty) { @@ -132,6 +156,14 @@ private[sql] object CatalogV2Implicits { s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.") } + def asFunctionIdentifier: FunctionIdentifier = parts match { + case Seq(funcName) => FunctionIdentifier(funcName) + case Seq(dbName, funcName) => FunctionIdentifier(funcName, Some(dbName)) + case _ => + throw new AnalysisException( + s"$quoted is not a valid FunctionIdentifier as it has more than 2 name parts.") + } + def quoted: String = parts.map(quoteIfNeeded).mkString(".") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index af951a0e7aa66..dcd352267a178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -154,6 +154,28 @@ private[sql] trait LookupCatalog extends Logging { } } + object AsFunctionIdentifier { + def unapply(parts: Seq[String]): Option[FunctionIdentifier] = { + def namesToFunctionIdentifier(names: Seq[String]): Option[FunctionIdentifier] = names match { + case Seq(name) => Some(FunctionIdentifier(name)) + case 
Seq(database, name) => Some(FunctionIdentifier(name, Some(database))) + case _ => None + } + parts match { + case Seq(name) + if catalogManager.v1SessionCatalog.isRegisteredFunction(FunctionIdentifier(name)) => + Some(FunctionIdentifier(name)) + case CatalogAndMultipartIdentifier(None, names) + if CatalogV2Util.isSessionCatalog(currentCatalog) => + namesToFunctionIdentifier(names) + case CatalogAndMultipartIdentifier(Some(catalog), names) + if CatalogV2Util.isSessionCatalog(catalog) => + namesToFunctionIdentifier(names) + case _ => None + } + } + } + def parseSessionCatalogFunctionIdentifier(nameParts: Seq[String]): FunctionIdentifier = { if (nameParts.length == 1 && catalogManager.v1SessionCatalog.isTempFunction(nameParts.head)) { return FunctionIdentifier(nameParts.head) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index a3fbe4c742160..cf065125c4f0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -617,12 +617,6 @@ private[spark] object QueryCompilationErrors { s"the function '$func', please make sure it is on the classpath") } - def v2CatalogNotSupportFunctionError( - catalog: String, namespace: Seq[String]): Throwable = { - new AnalysisException("V2 catalog does not support functions yet. " + - s"catalog: $catalog, namespace: '${namespace.quoted}'") - } - def resourceTypeNotSupportedError(resourceType: String): Throwable = { new AnalysisException(s"Resource Type '$resourceType' is not supported.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala index e0f3c9a835b6e..85e0b1062c81f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ class LookupFunctionsSuite extends PlanTest { @@ -49,7 +50,7 @@ class LookupFunctionsSuite extends PlanTest { assert(externalCatalog.getFunctionExistsCalledTimes == 1) assert(analyzer.LookupFunctions.normalizeFuncName - (unresolvedPersistentFunc.name).database == Some("default")) + (unresolvedPersistentFunc.nameParts.asFunctionIdentifier).database == Some("default")) } test("SPARK-23486: the functionExists for the Registered function check") { @@ -70,9 +71,9 @@ class LookupFunctionsSuite extends PlanTest { table("TaBlE")) analyzer.LookupFunctions.apply(plan) - assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 4) assert(analyzer.LookupFunctions.normalizeFuncName - (unresolvedRegisteredFunc.name).database == Some("default")) + (unresolvedRegisteredFunc.nameParts.asFunctionIdentifier).database == Some("default")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala index eb35dd47a508f..bfff3ee855e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog => V1InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class CatalogManagerSuite extends SparkFunSuite with SQLHelper { private def createSessionCatalog(): SessionCatalog = { - val catalog = new InMemoryCatalog() + val catalog = new V1InMemoryCatalog() catalog.createDatabase( CatalogDatabase(SessionCatalog.DEFAULT_DATABASE, "", new URI("fake"), Map.empty), ignoreIfExists = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala similarity index 93% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index 5560bda928232..0cca1cc9bebf2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -24,14 +24,15 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} import org.apache.spark.sql.connector.expressions.LogicalExpressions import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap -class TableCatalogSuite extends SparkFunSuite { +class CatalogSuite extends SparkFunSuite { import CatalogV2Implicits._ private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] @@ -39,8 +40,8 @@ class TableCatalogSuite extends SparkFunSuite { .add("id", IntegerType) .add("data", StringType) - private def newCatalog(): TableCatalog with SupportsNamespaces = { - val newCatalog = new InMemoryTableCatalog + private def newCatalog(): InMemoryCatalog = { + val newCatalog 
= new InMemoryCatalog newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) newCatalog } @@ -927,4 +928,43 @@ class TableCatalogSuite extends SparkFunSuite { assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2) assert(partTable.rows.isEmpty) } + + val function: UnboundFunction = new UnboundFunction { + override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "my_bound_function" + } + override def description(): String = "my_function" + override def name(): String = "my_function" + } + + test("list functions") { + val catalog = newCatalog() + val ident1 = Identifier.of(Array("ns1", "ns2"), "func1") + val ident2 = Identifier.of(Array("ns1", "ns2"), "func2") + val ident3 = Identifier.of(Array("ns1", "ns3"), "func3") + + catalog.createNamespace(Array("ns1", "ns2"), emptyProps) + catalog.createNamespace(Array("ns1", "ns3"), emptyProps) + catalog.createFunction(ident1, function) + catalog.createFunction(ident2, function) + catalog.createFunction(ident3, function) + + assert(catalog.listFunctions(Array("ns1", "ns2")).toSet === Set(ident1, ident2)) + assert(catalog.listFunctions(Array("ns1", "ns3")).toSet === Set(ident3)) + assert(catalog.listFunctions(Array("ns1")).toSet == Set()) + intercept[NoSuchNamespaceException](catalog.listFunctions(Array("ns2"))) + } + + test("lookup function") { + val catalog = newCatalog() + val ident = Identifier.of(Array("ns"), "func") + catalog.createNamespace(Array("ns"), emptyProps) + catalog.createFunction(ident, function) + + assert(catalog.loadFunction(ident) == function) + intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns"), "func1"))) + intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns1"), "func"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala new file mode 100644 index 0000000000000..202b03f28f082 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException} +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction + +class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { + protected val functions: util.Map[Identifier, UnboundFunction] = + new ConcurrentHashMap[Identifier, UnboundFunction]() + + override protected def allNamespaces: Seq[Seq[String]] = { + (tables.keySet.asScala.map(_.namespace.toSeq) ++ + functions.keySet.asScala.map(_.namespace.toSeq) ++ + namespaces.keySet.asScala).toSeq.distinct + } + + override def listFunctions(namespace: Array[String]): Array[Identifier] = { + if (namespace.isEmpty || namespaceExists(namespace)) { + functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } else { + throw new NoSuchNamespaceException(namespace) + } + } + + override def loadFunction(ident: Identifier): UnboundFunction = { + Option(functions.get(ident)) match { + case Some(func) => + func + case _ => + throw new NoSuchFunctionException(ident) + } + } + + def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = { + functions.put(ident, fn) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 38113f9ea1902..0c403baca2113 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -138,7 +138,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { } class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces { - private def allNamespaces: Seq[Seq[String]] = { + protected def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java new file mode 100644 index 0000000000000..4e783fdd439b6 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.AggregateFunction; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.StructType; + +import java.io.Serializable; + +public class JavaAverage implements UnboundFunction { + @Override + public String name() { + return "avg"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + if (inputType.fields()[0].dataType() instanceof DoubleType) { + return new JavaDoubleAverage(); + } + throw new UnsupportedOperationException("Unsupported non-integral type: " + + inputType.fields()[0].dataType()); + } + + @Override + public String description() { + return null; + } + + public static class JavaDoubleAverage implements AggregateFunction<State<Double>, Double> { + @Override + public State<Double> newAggregationState() { + return new State<>(0.0, 0.0); + } + + @Override + public State<Double> update(State<Double> state, InternalRow input) { + if (input.isNullAt(0)) { + return state; + } + return new State<>(state.sum + input.getDouble(0), state.count + 1); + } + + @Override + public Double produceResult(State<Double> state) { + return state.sum / state.count; + } + + @Override + public State<Double> merge(State<Double> leftState, State<Double> rightState) { + return new State<>(leftState.sum + rightState.sum, leftState.count + rightState.count); + } + + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.DoubleType }; + } + + @Override + public DataType resultType() { + return DataTypes.DoubleType; + } + + @Override + public String name() { + return "davg"; + } + } + + public static class State<T> implements Serializable { + T sum, count; + + State(T sum, T count) { + this.sum = sum; + this.count = count; + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java new file mode 100644 index 0000000000000..8b2d883a3703f --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package test.org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.catalog.functions.BoundFunction; +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; + +public class JavaStrLen implements UnboundFunction { + private final BoundFunction fn; + + public JavaStrLen(BoundFunction fn) { + this.fn = fn; + } + + @Override + public String name() { + return "strlen"; + } + + @Override + public BoundFunction bind(StructType inputType) { + if (inputType.fields().length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + + if (inputType.fields()[0].dataType() instanceof StringType) { + return fn; + } + + throw new UnsupportedOperationException("Expect StringType"); + } + + @Override + public String description() { + return "strlen: returns the length of the input string\n" + + " strlen(string) -> int"; + } + + public static class JavaStrLenDefault implements ScalarFunction<Integer> { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.StringType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "strlen"; + } + + @Override + public Integer produceResult(InternalRow input) { + String str = input.getString(0); + return str.length(); + } + } + + public static class JavaStrLenMagic implements ScalarFunction<Integer> { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.StringType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "strlen"; + } + + public int invoke(UTF8String str) { + return str.toString().length(); + } + } + + public static class JavaStrLenNoImpl implements ScalarFunction<Integer> { + @Override + public DataType[] inputTypes() { + return new DataType[] { DataTypes.StringType }; + } + + @Override + public DataType resultType() { + return DataTypes.IntegerType; + } + + @Override + public String name() { + return "strlen"; + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala new file mode 100644 index 0000000000000..fe856ffecb84a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util +import java.util.Collections + +import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen} +import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces} +import org.apache.spark.sql.connector.catalog.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + + private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { + catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) + } + + test("undefined function") { + assert(intercept[AnalysisException]( + sql("SELECT testcat.non_exist('abc')").collect() + ).getMessage.contains("Undefined function")) + } + + test("non-function catalog") { + withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + assert(intercept[AnalysisException]( + sql("SELECT testcat.strlen('abc')").collect() + ).getMessage.contains("is not a FunctionCatalog")) + } + } + + test("built-in with non-function catalog should still work") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat", + "spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + } + + test("built-in with default v2 function catalog") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + } + + test("looking up higher-order function with non-session catalog") { + checkAnswer(sql("SELECT transform(array(1, 2, 3), x -> x + 1)"), + Row(Array(2, 3, 4)) :: Nil) + } + + test("built-in override with default v2 function catalog") { + // a built-in function with the same name should take higher priority + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + } + + test("built-in override with non-session catalog") { + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + + test("temp function override with default v2 function catalog") { + val className = "test.org.apache.spark.sql.JavaStringLength" + sql(s"CREATE FUNCTION length AS '$className'") + + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl)) + checkAnswer(sql("SELECT length('abc')"), Row(3)) + } + } + + test("view should use captured catalog and namespace for function lookup") { + val viewName = "my_view" + withView(viewName) { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage) + sql("USE ns") + sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT 
my_avg(col1) FROM values (1), (2), (3)") + } + + // change default catalog and namespace and add a function with the same name but with no + // implementation + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat2") { + catalog("testcat2").asInstanceOf[SupportsNamespaces] + .createNamespace(Array("ns2"), emptyProps) + addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage) + sql("USE ns2") + checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil) + } + } + } + + test("scalar function: with default produceResult method") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: with default produceResult method w/ expression") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil) + } + + test("scalar function: lookup magic method") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: lookup magic method w/ expression") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic)) + checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil) + } + + test("scalar function: bad magic method") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic)) + assert(intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect()) + .getMessage.contains("Cannot find a compatible")) + } + + test("scalar function: bad magic method with default impl") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagicWithDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar function: no implementation found") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl)) + intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect()) + } + + test("scalar function: invalid parameter type or length") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault)) + + assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)")) + .getMessage.contains("Expect StringType")) + assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')")) + .getMessage.contains("Expect exactly one argument")) + } + + test("scalar function: default produceResult in Java") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), + new JavaStrLen(new JavaStrLenDefault)) + checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil) + } + + test("scalar 
+  test("scalar function: default produceResult in Java") {
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    addFunction(Identifier.of(Array("ns"), "strlen"),
+      new JavaStrLen(new JavaStrLenDefault))
+    checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+  }
+
+  test("scalar function: magic method in Java") {
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    addFunction(Identifier.of(Array("ns"), "strlen"),
+      new JavaStrLen(new JavaStrLenMagic))
+    checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+  }
+
+  test("scalar function: no implementation found in Java") {
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    addFunction(Identifier.of(Array("ns"), "strlen"),
+      new JavaStrLen(new JavaStrLenNoImpl))
+    assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect())
+      .getMessage.contains("neither implement magic method nor override 'produceResult'"))
+  }
+
+  test("bad bound function (neither scalar nor aggregate)") {
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
+
+    assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')"))
+      .getMessage.contains("does not implement ScalarFunction or AggregateFunction"))
+  }
+
+  test("aggregate function: lookup int average") {
+    import testImplicits._
+    val t = "testcat.ns.t"
+    withTable(t) {
+      addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+      (1 to 100).toDF("i").write.saveAsTable(t)
+      checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
+    }
+  }
+
+  test("aggregate function: lookup long average") {
+    import testImplicits._
+    val t = "testcat.ns.t"
+    withTable(t) {
+      addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+      (1L to 100L).toDF("i").write.saveAsTable(t)
+      checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
+    }
+  }
+
+  test("aggregate function: lookup double average in Java") {
+    import testImplicits._
+    val t = "testcat.ns.t"
+    withTable(t) {
+      addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage)
+
+      Seq(1.toDouble, 2.toDouble, 3.toDouble).toDF("i").write.saveAsTable(t)
+      checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(2.0) :: Nil)
+    }
+  }
+
+  test("aggregate function: lookup int average w/ expression") {
+    import testImplicits._
+    val t = "testcat.ns.t"
+    withTable(t) {
+      addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+      (1 to 100).toDF("i").write.saveAsTable(t)
+      checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil)
+    }
+  }
+
+  test("aggregate function: unsupported input type") {
+    import testImplicits._
+    val t = "testcat.ns.t"
+    withTable(t) {
+      addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+      Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t)
+      assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t"))
+        .getMessage.contains("Unsupported non-integral type: ShortType"))
+    }
+  }
+
+  private case class StrLen(impl: BoundFunction) extends UnboundFunction {
+    override def description(): String =
+      """strlen: returns the length of the input string
+        |  strlen(string) -> int""".stripMargin
+    override def name(): String = "strlen"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length != 1) {
+        throw new UnsupportedOperationException("Expect exactly one argument")
+      }
+      inputType.fields(0).dataType match {
+        case StringType => impl
+        case _ =>
+          throw new UnsupportedOperationException("Expect StringType")
+      }
+    }
+  }
+
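+  // ScalarFunction test doubles, one per resolution path exercised above:
+  //   StrLenDefault             - overrides produceResult(InternalRow) only
+  //   StrLenMagic               - defines invoke(UTF8String), matching Spark's
+  //                               internal representation of StringType
+  //   StrLenBadMagic            - invoke(String) has a non-matching signature, so
+  //                               lookup fails ("Cannot find a compatible ...")
+  //   StrLenBadMagicWithDefault - same bad signature, but the overridden
+  //                               produceResult provides a working fallback
+  //   StrLenNoImpl              - provides neither, so evaluation fails at runtime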
+  private case object StrLenDefault extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_default"
+
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
+  }
+
+  private case object StrLenMagic extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_magic"
+
+    def invoke(input: UTF8String): Int = {
+      input.toString.length
+    }
+  }
+
+  private case object StrLenBadMagic extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_bad_magic"
+
+    def invoke(input: String): Int = {
+      input.length
+    }
+  }
+
+  private case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_bad_magic_with_default"
+
+    def invoke(input: String): Int = {
+      input.length
+    }
+
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
+  }
+
+  private case object StrLenNoImpl extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "strlen_noimpl"
+  }
+
+  private case object BadBoundFunction extends BoundFunction {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "bad_bound_func"
+  }
+
+  object IntegralAverage extends UnboundFunction {
+    override def name(): String = "iavg"
+
+    override def bind(inputType: StructType): BoundFunction = {
+      if (inputType.fields.length != 1) {
+        throw new UnsupportedOperationException("Expect exactly one argument")
+      }
+
+      inputType.fields(0).dataType match {
+        case _: IntegerType => IntAverage
+        case _: LongType => LongAverage
+        case dataType =>
+          throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
+      }
+    }
+
+    override def description(): String =
+      """iavg: produces an average using integer division, ignoring nulls
+        |  iavg(int) -> int
+        |  iavg(bigint) -> bigint""".stripMargin
+  }
+
+  object IntAverage extends AggregateFunction[(Int, Int), Int] {
+    override def name(): String = "iavg"
+    override def inputTypes(): Array[DataType] = Array(IntegerType)
+    override def resultType(): DataType = IntegerType
+
+    override def newAggregationState(): (Int, Int) = (0, 0)
+
+    override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
+      if (input.isNullAt(0)) {
+        state
+      } else {
+        val i = input.getInt(0)
+        state match {
+          case (_, 0) =>
+            (i, 1)
+          case (total, count) =>
+            (total + i, count + 1)
+        }
+      }
+    }
+
+    override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
+      (leftState._1 + rightState._1, leftState._2 + rightState._2)
+    }
+
+    override def produceResult(state: (Int, Int)): Int = state._1 / state._2
+  }
+
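+  // LongAverage mirrors IntAverage with the (sum, count) state widened to
+  // (Long, Long): update folds non-null inputs into the running state, merge adds
+  // two partial states so per-partition results can be combined, and produceResult
+  // performs the final integer division.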
+  object LongAverage extends AggregateFunction[(Long, Long), Long] {
+    override def name(): String = "iavg"
+    override def inputTypes(): Array[DataType] = Array(LongType)
+    override def resultType(): DataType = LongType
+
+    override def newAggregationState(): (Long, Long) = (0L, 0L)
+
+    override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
+      if (input.isNullAt(0)) {
+        state
+      } else {
+        val l = input.getLong(0)
+        state match {
+          case (_, 0L) =>
+            (l, 1L)
+          case (total, count) =>
+            (total + l, count + 1L)
+        }
+      }
+    }
+
+    override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
+      (leftState._1 + rightState._1, leftState._2 + rightState._2)
+    }
+
+    override def produceResult(state: (Long, Long)): Long = state._1 / state._2
+  }
+
+  object NoImplAverage extends UnboundFunction {
+    override def name(): String = "no_impl_avg"
+    override def description(): String = name()
+
+    override def bind(inputType: StructType): BoundFunction = {
+      throw new UnsupportedOperationException("Not implemented")
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
index 3ef242f90f7e7..77a515b55ce76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.connector
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, StagingInMemoryTableCatalog}
 import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -32,11 +32,11 @@ trait DatasourceV2SQLBase
   }
 
   before {
-    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
     spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName)
     spark.conf.set(
       "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
+    spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName)
     spark.conf.set(
       V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)