diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
index 6982ebb329ff3..4181feafed101 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
@@ -25,13 +25,12 @@
/**
* Interface for a function that produces a result value by aggregating over multiple input rows.
*
- * For each input row, Spark will call an update method that corresponds to the
- * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
- * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
- * update with {@link InternalRow}.
- *
- * The JVM type of result values produced by this function must be the type used by Spark's
+ * For each input row, Spark will call the {@link #update} method, which should evaluate the row
+ * and update the aggregation state. The JVM type of result values produced by
+ * {@link #produceResult} must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
+ * Please refer to the class documentation of {@link ScalarFunction} for the mapping between
+ * {@link DataType} and the JVM type.
*
* All implementations must support partial aggregation by implementing merge so that Spark can
* partially aggregate and shuffle intermediate results, instead of shuffling all rows for an
@@ -68,9 +67,7 @@ public interface AggregateFunction<S extends Serializable, R> extends BoundFunction {
* @param input an input row
* @return updated aggregation state
*/
- default S update(S state, InternalRow input) {
- throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update");
- }
+ S update(S state, InternalRow input);
/**
* Merge two partial aggregation states.
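
For illustration, a minimal aggregate under this contract could look like the following Scala sketch (a hypothetical LongSum class, not part of this patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.catalog.functions.AggregateFunction
    import org.apache.spark.sql.types.{DataType, LongType}

    // Hypothetical SUM over long inputs; java.lang.Long doubles as the
    // serializable aggregation state the interface requires.
    class LongSum extends AggregateFunction[java.lang.Long, java.lang.Long] {
      override def name(): String = "long_sum"
      override def inputTypes(): Array[DataType] = Array(LongType)
      override def resultType(): DataType = LongType
      override def newAggregationState(): java.lang.Long = 0L
      override def update(state: java.lang.Long, input: InternalRow): java.lang.Long =
        if (input.isNullAt(0)) state else state + input.getLong(0)
      override def merge(left: java.lang.Long, right: java.lang.Long): java.lang.Long =
        left + right
      override def produceResult(state: java.lang.Long): java.lang.Long = state
    }
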
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
index c2106a21c4a8f..ef755aae3fb07 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
@@ -23,17 +23,67 @@
/**
* Interface for a function that produces a result value for each input row.
*
- * For each input row, Spark will call a produceResult method that corresponds to the
- * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
- * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
- * {@link #produceResult(InternalRow)}.
+ * To evaluate each input row, Spark will first try to look up and use a "magic method"
+ * (described below) through Java reflection. If the method is not found, Spark will call
+ * {@link #produceResult(InternalRow)} as a fallback approach.
*
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
+ *
+ * IMPORTANT: the default implementation of {@link #produceResult} throws
+ * {@link UnsupportedOperationException}. Users can choose to override this method, or implement
+ * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
+ * instead of an {@link InternalRow}. The magic method will be loaded by Spark through Java
+ * reflection and will also provide better performance in general, due to optimizations such as
+ * codegen, removal of Java boxing, etc.
+ *
+ * For example, a scalar UDF for adding two integers can be defined as follows with the magic
+ * method approach:
+ *
+ * <pre>
+ *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
+ *     public int invoke(int left, int right) {
+ *       return left + right;
+ *     }
+ *   }
+ * </pre>
+ * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
+ * {@link #produceResult} to evaluate the inputs. In general, Spark looks up the magic method by
+ * first converting the actual input SQL data types to their corresponding Java types following
+ * the mapping defined below, and then checking if there is a matching method from all the
+ * declared methods in the UDF class, using the method name (i.e., {@link #MAGIC_METHOD_NAME}) and
+ * the Java types. If no magic method is found, Spark falls back to {@link #produceResult}.
+ *
+ * The following is the mapping from {@link DataType SQL data type} to Java type for
+ * the magic method approach:
+ *
+ * - {@link org.apache.spark.sql.types.BooleanType}: {@code boolean}
+ * - {@link org.apache.spark.sql.types.ByteType}: {@code byte}
+ * - {@link org.apache.spark.sql.types.ShortType}: {@code short}
+ * - {@link org.apache.spark.sql.types.IntegerType}: {@code int}
+ * - {@link org.apache.spark.sql.types.LongType}: {@code long}
+ * - {@link org.apache.spark.sql.types.FloatType}: {@code float}
+ * - {@link org.apache.spark.sql.types.DoubleType}: {@code double}
+ * - {@link org.apache.spark.sql.types.StringType}:
+ * {@link org.apache.spark.unsafe.types.UTF8String}
+ * - {@link org.apache.spark.sql.types.DateType}: {@code int}
+ * - {@link org.apache.spark.sql.types.TimestampType}: {@code long}
+ * - {@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}
+ * - {@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}
+ * - {@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}
+ * - {@link org.apache.spark.sql.types.DecimalType}:
+ * {@link org.apache.spark.sql.types.Decimal}
+ * - {@link org.apache.spark.sql.types.StructType}: {@link InternalRow}
+ * - {@link org.apache.spark.sql.types.ArrayType}:
+ * {@link org.apache.spark.sql.catalyst.util.ArrayData}
+ * - {@link org.apache.spark.sql.types.MapType}:
+ * {@link org.apache.spark.sql.catalyst.util.MapData}
+ *
* @param <R> the JVM type of result values
*/
public interface ScalarFunction<R> extends BoundFunction {
+ String MAGIC_METHOD_NAME = "invoke";
/**
* Applies the function to an input row to produce a value.
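
To make the contract above concrete, a Scala analogue of the magic-method pattern could look like this sketch (hypothetical StrLen class, relying on the StringType-to-UTF8String mapping listed above):

    import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
    import org.apache.spark.sql.types.{DataType, DataTypes}
    import org.apache.spark.unsafe.types.UTF8String

    class StrLen extends ScalarFunction[Integer] {
      override def name(): String = "strlen"
      override def inputTypes(): Array[DataType] = Array(DataTypes.StringType)
      override def resultType(): DataType = DataTypes.IntegerType

      // Magic method: StringType maps to UTF8String, so this is the signature
      // the analyzer probes for via reflection; produceResult stays unimplemented.
      def invoke(str: UTF8String): Int = str.numChars()
    }
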
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 87b8d52ac277f..1bca8f24538b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import java.lang.reflect.Method
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -29,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
-import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _}
+import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -44,6 +45,8 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -281,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
- ResolveHigherOrderFunctions(v1SessionCatalog) ::
+ ResolveHigherOrderFunctions(catalogManager) ::
ResolveLambdaVariables ::
ResolveTimeZone ::
ResolveRandomSeed ::
@@ -895,9 +898,10 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
- // If we are resolving relations insides views, we need to expand single-part relation names with
- // the current catalog and namespace of when the view was created.
- private def expandRelationName(nameParts: Seq[String]): Seq[String] = {
+ // If we are resolving database objects (relations, functions, etc.) inside views, we may need
+ // to expand single or multi-part identifiers with the current catalog and namespace as of when
+ // the view was created.
+ private def expandIdentifier(nameParts: Seq[String]): Seq[String] = {
if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts
if (nameParts.length == 1) {
@@ -1040,7 +1044,7 @@ class Analyzer(override val catalogManager: CatalogManager)
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] =
- expandRelationName(identifier) match {
+ expandIdentifier(identifier) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) =>
@@ -1153,7 +1157,7 @@ class Analyzer(override val catalogManager: CatalogManager)
}
private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = {
- expandRelationName(identifier) match {
+ expandIdentifier(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW =>
@@ -1173,7 +1177,7 @@ class Analyzer(override val catalogManager: CatalogManager)
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] = {
- expandRelationName(identifier) match {
+ expandIdentifier(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table =>
@@ -1569,8 +1573,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// results and confuse users if there is any null values. For count(t1.*, t2.*), it is
// still allowed, since it's well-defined in spark.
if (!conf.allowStarWithSingleTableIdentifierInCount &&
- f1.name.database.isEmpty &&
- f1.name.funcName == "count" &&
+ f1.nameParts == Seq("count") &&
f1.arguments.length == 1) {
f1.arguments.foreach {
case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) =>
@@ -1958,17 +1961,19 @@ class Analyzer(override val catalogManager: CatalogManager)
override def apply(plan: LogicalPlan): LogicalPlan = {
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
plan.resolveExpressions {
- case f: UnresolvedFunction
- if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
- case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
- case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
- externalFunctionNameSet.add(normalizeFuncName(f.name))
- f
- case f: UnresolvedFunction =>
- withPosition(f) {
- throw new NoSuchFunctionException(
- f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
- f.name.funcName)
+ case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) =>
+ if (externalFunctionNameSet.contains(normalizeFuncName(ident)) ||
+ v1SessionCatalog.isRegisteredFunction(ident)) {
+ f
+ } else if (v1SessionCatalog.isPersistentFunction(ident)) {
+ externalFunctionNameSet.add(normalizeFuncName(ident))
+ f
+ } else {
+ withPosition(f) {
+ throw new NoSuchFunctionException(
+ ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
+ ident.funcName)
+ }
}
}
}
@@ -2016,9 +2021,10 @@ class Analyzer(override val catalogManager: CatalogManager)
name, other.getClass.getCanonicalName)
}
}
- case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) =>
- withPosition(u) {
- v1SessionCatalog.lookupFunction(funcId, arguments) match {
+
+ case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter,
+ ignoreNulls) => withPosition(u) {
+ v1SessionCatalog.lookupFunction(ident, arguments) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@@ -2095,9 +2101,123 @@ class Analyzer(override val catalogManager: CatalogManager)
case other =>
other
}
+ }
+
+ case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) =>
+ withPosition(u) {
+ expandIdentifier(nameParts) match {
+ case NonSessionCatalogAndIdentifier(catalog, ident) =>
+ if (!catalog.isFunctionCatalog) {
+ throw new AnalysisException(s"Trying to lookup function '$ident' in " +
+ s"catalog '${catalog.name()}', but it is not a FunctionCatalog.")
+ }
+
+ val unbound = catalog.asFunctionCatalog.loadFunction(ident)
+ val inputType = StructType(arguments.zipWithIndex.map {
+ case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+ })
+ val bound = try {
+ unbound.bind(inputType)
+ } catch {
+ case unsupported: UnsupportedOperationException =>
+ throw new AnalysisException(s"Function '${unbound.name}' cannot process " +
+ s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " +
+ unsupported.getMessage, cause = Some(unsupported))
+ }
+
+ bound match {
+ case scalarFunc: ScalarFunction[_] =>
+ processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct,
+ filter, ignoreNulls)
+ case aggFunc: V2AggregateFunction[_, _] =>
+ processV2AggregateFunction(aggFunc, arguments, isDistinct, filter,
+ ignoreNulls)
+ case _ =>
+ failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" +
+ s" or AggregateFunction")
+ }
+
+ case _ => u
+ }
}
}
}
+
+ private def processV2ScalarFunction(
+ scalarFunc: ScalarFunction[_],
+ inputType: StructType,
+ arguments: Seq[Expression],
+ isDistinct: Boolean,
+ filter: Option[Expression],
+ ignoreNulls: Boolean): Expression = {
+ if (isDistinct) {
+ throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+ scalarFunc.name(), "DISTINCT")
+ } else if (filter.isDefined) {
+ throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+ scalarFunc.name(), "FILTER clause")
+ } else if (ignoreNulls) {
+ throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+ scalarFunc.name(), "IGNORE NULLS")
+ } else {
+ // TODO: implement type coercion by looking at input type from the UDF. We
+ // may also want to check if the parameter types from the magic method
+ // match the input type through `BoundFunction.inputTypes`.
+ val argClasses = inputType.fields.map(_.dataType)
+ findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
+ case Some(_) =>
+ val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
+ Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
+ arguments, returnNullable = scalarFunc.isResultNullable)
+ case _ =>
+ // TODO: handle functions defined in Scala too - in Scala, even if a
+ // subclass does not override the default method in a parent interface
+ // defined in Java, the method can still be found from
+ // `getDeclaredMethod`.
+ // Since `inputType` is a `StructType`, it is mapped to an `InternalRow`,
+ // which we can use to look up the `produceResult` method.
+ findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
+ case Some(_) =>
+ ApplyFunctionExpression(scalarFunc, arguments)
+ case None =>
+ failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
+ s" magic method nor override 'produceResult'")
+ }
+ }
+ }
+ }
+
+ private def processV2AggregateFunction(
+ aggFunc: V2AggregateFunction[_, _],
+ arguments: Seq[Expression],
+ isDistinct: Boolean,
+ filter: Option[Expression],
+ ignoreNulls: Boolean): Expression = {
+ if (ignoreNulls) {
+ throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+ aggFunc.name(), "IGNORE NULLS")
+ }
+ val aggregator = V2Aggregator(aggFunc, arguments)
+ AggregateExpression(aggregator, Complete, isDistinct, filter)
+ }
+
+ /**
+ * Check if the input `fn` implements the given `methodName` with parameter types specified
+ * via `inputType`.
+ */
+ private def findMethod(
+ fn: BoundFunction,
+ methodName: String,
+ inputType: Seq[DataType]): Option[Method] = {
+ val cls = fn.getClass
+ try {
+ val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass)
+ Some(cls.getDeclaredMethod(methodName, argClasses: _*))
+ } catch {
+ case _: NoSuchMethodException =>
+ None
+ }
+ }
}
/**
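
A standalone sketch of the reflective probe that findMethod performs above (assumed helper, mirroring rather than defining the rule's behavior):

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types.DataType

    // Map each SQL input type to the JVM class used by the InternalRow API,
    // then look for a declared `invoke` method with exactly those parameters.
    def hasMagicMethod(udf: AnyRef, inputTypes: Seq[DataType]): Boolean = {
      val argClasses = inputTypes.map(ScalaReflection.dataTypeJavaClass)
      try {
        udf.getClass.getDeclaredMethod("invoke", argClasses: _*)
        true
      } catch {
        case _: NoSuchMethodException => false
      }
    }
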
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index 7d74c0d1cd14f..b86d5d75a4550 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.DataType
@@ -30,13 +30,14 @@ import org.apache.spark.sql.types.DataType
* so we need to resolve higher order function when all children are either resolved or a lambda
* function.
*/
-case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
+case class ResolveHigherOrderFunctions(catalogManager: CatalogManager)
+ extends Rule[LogicalPlan] with LookupCatalog {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
- case u @ UnresolvedFunction(fn, children, false, filter, ignoreNulls)
+ case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls)
if hasLambdaAndResolvedArguments(children) =>
withPosition(u) {
- catalog.lookupFunction(fn, children) match {
+ catalogManager.v1SessionCatalog.lookupFunction(ident, children) match {
case func: HigherOrderFunction =>
filter.foreach(_.failAnalysis("FILTER predicate specified, " +
s"but ${func.prettyName} is not an aggregate function"))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 5001e2ea88ac7..6fcde63ca225c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -269,12 +269,13 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio
}
case class UnresolvedFunction(
- name: FunctionIdentifier,
+ nameParts: Seq[String],
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression] = None,
ignoreNulls: Boolean = false)
extends Expression with Unevaluable {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
override def children: Seq[Expression] = arguments ++ filter.toSeq
@@ -282,10 +283,10 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false
- override def prettyName: String = name.unquotedString
+ override def prettyName: String = nameParts.quoted
override def toString: String = {
val distinct = if (isDistinct) "distinct " else ""
- s"'$name($distinct${children.mkString(", ")})"
+ s"'${nameParts.quoted}($distinct${children.mkString(", ")})"
}
override protected def withNewChildrenInternal(
@@ -299,8 +300,17 @@ case class UnresolvedFunction(
}
object UnresolvedFunction {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ def apply(
+ name: FunctionIdentifier,
+ arguments: Seq[Expression],
+ isDistinct: Boolean): UnresolvedFunction = {
+ UnresolvedFunction(name.asMultipart, arguments, isDistinct)
+ }
+
def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = {
- UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct)
+ UnresolvedFunction(Seq(name), arguments, isDistinct)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index d259d6a706d72..0813d41af1617 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1576,6 +1576,8 @@ class SessionCatalog(
name: FunctionIdentifier,
children: Seq[Expression],
registry: FunctionRegistryBase[T]): T = synchronized {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+
// Note: the implementation of this function is a little bit convoluted.
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
// (built-in, temp, and external).
@@ -1598,7 +1600,8 @@ class SessionCatalog(
case Seq() => getCurrentDatabase
case Seq(_, db) => db
case Seq(catalog, namespace @ _*) =>
- throw QueryCompilationErrors.v2CatalogNotSupportFunctionError(catalog, namespace)
+ throw new IllegalStateException(s"[BUG] unexpected v2 catalog: $catalog, and " +
+ s"namespace: ${namespace.quoted} in v1 function lookup")
}
// If the name itself is not qualified, add the current database to it.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
new file mode 100644
index 0000000000000..ce9c933902942
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
+import org.apache.spark.sql.types.DataType
+
+case class ApplyFunctionExpression(
+ function: ScalarFunction[_],
+ children: Seq[Expression]) extends Expression with UserDefinedExpression with CodegenFallback {
+ override def nullable: Boolean = function.isResultNullable
+ override def name: String = function.name()
+ override def dataType: DataType = function.resultType()
+
+ private lazy val reusedRow = new GenericInternalRow(children.size)
+
+ /** Returns the result of evaluating this expression on a given input Row */
+ override def eval(input: InternalRow): Any = {
+ children.zipWithIndex.foreach {
+ case (expr, pos) =>
+ reusedRow.update(pos, expr.eval(input))
+ }
+
+ function.produceResult(reusedRow)
+ }
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+ copy(children = newChildren)
+}
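
A usage-style sketch of this fallback path (strLen is an assumed ScalarFunction[Integer] instance that overrides produceResult):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Literal}

    // Children are evaluated into a reused row and handed to the UDF, which is
    // why the expression extends CodegenFallback instead of generating code.
    val expr = ApplyFunctionExpression(strLen, Seq(Literal("abc")))
    val result = expr.eval(InternalRow.empty) // 3, via strLen.produceResult
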
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala
new file mode 100644
index 0000000000000..55e3f504ae2e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection}
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction}
+import org.apache.spark.sql.types.DataType
+
+case class V2Aggregator[BUF <: java.io.Serializable, OUT](
+ aggrFunc: V2AggregateFunction[BUF, OUT],
+ children: Seq[Expression],
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] {
+ private[this] lazy val inputProjection = UnsafeProjection.create(children)
+
+ override def nullable: Boolean = aggrFunc.isResultNullable
+ override def dataType: DataType = aggrFunc.resultType()
+ override def createAggregationBuffer(): BUF = aggrFunc.newAggregationState()
+
+ override def update(buffer: BUF, input: InternalRow): BUF = {
+ aggrFunc.update(buffer, inputProjection(input))
+ }
+
+ override def merge(buffer: BUF, input: BUF): BUF = aggrFunc.merge(buffer, input)
+
+ override def eval(buffer: BUF): Any = {
+ aggrFunc.produceResult(buffer)
+ }
+
+ override def serialize(buffer: BUF): Array[Byte] = {
+ val bos = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(bos)
+ out.writeObject(buffer)
+ out.close()
+ bos.toByteArray
+ }
+
+ override def deserialize(bytes: Array[Byte]): BUF = {
+ val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
+ in.readObject().asInstanceOf[BUF]
+ }
+
+ def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): V2Aggregator[BUF, OUT] =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): V2Aggregator[BUF, OUT] =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+ copy(children = newChildren)
+}
+
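
A driver-style sketch of how this TypedImperativeAggregate contract is exercised during partial aggregation (agg is an assumed V2Aggregator instance; simplified to plain Scala):

    import org.apache.spark.sql.catalyst.InternalRow

    // Map side: fold rows into a buffer, then serialize it for the shuffle.
    def partial(rows: Iterator[InternalRow]): Array[Byte] = {
      val buf = rows.foldLeft(agg.createAggregationBuffer())(agg.update)
      agg.serialize(buf)
    }

    // Reduce side: deserialize partial buffers, merge them, produce the result.
    def finalResult(partials: Seq[Array[Byte]]): Any = {
      val merged = partials.map(agg.deserialize).reduce(agg.merge)
      agg.eval(merged)
    }
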
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index cd21beafbd85f..da7c5c0221bd2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1818,7 +1818,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
val ignoreNulls =
Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false)
val function = UnresolvedFunction(
- getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
+ getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
// Check if the function is evaluated in a windowed context.
ctx.windowSpec match {
@@ -1830,15 +1830,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}
}
-
- /**
- * Create a function database (optional) and name pair, for multipartIdentifier.
- * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS.
- */
- protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = {
- visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq)
- }
-
/**
* Create a function database (optional) and name pair.
*/
@@ -1869,6 +1860,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}
}
+ protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
+ if (ctx.qualifiedName != null) {
+ ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
+ } else {
+ Seq(ctx.getText)
+ }
+ }
+
/**
* Create an [[LambdaFunction]].
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index be9b94c606196..cc41d8ca9007f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -83,12 +83,36 @@ private[sql] object CatalogV2Implicits {
throw new AnalysisException(
s"Cannot use catalog ${plugin.name}: does not support namespaces")
}
+
+ def isFunctionCatalog: Boolean = plugin match {
+ case _: FunctionCatalog => true
+ case _ => false
+ }
+
+ def asFunctionCatalog: FunctionCatalog = plugin match {
+ case functionCatalog: FunctionCatalog =>
+ functionCatalog
+ case _ =>
+ throw new AnalysisException(
+ s"Cannot use catalog '${plugin.name}': not a FunctionCatalog")
+ }
}
implicit class NamespaceHelper(namespace: Array[String]) {
def quoted: String = namespace.map(quoteIfNeeded).mkString(".")
}
+ implicit class FunctionIdentifierHelper(ident: FunctionIdentifier) {
+ def asMultipart: Seq[String] = {
+ ident.database match {
+ case Some(db) =>
+ Seq(db, ident.funcName)
+ case _ =>
+ Seq(ident.funcName)
+ }
+ }
+ }
+
implicit class IdentifierHelper(ident: Identifier) {
def quoted: String = {
if (ident.namespace.nonEmpty) {
@@ -132,6 +156,14 @@ private[sql] object CatalogV2Implicits {
s"$quoted is not a valid TableIdentifier as it has more than 2 name parts.")
}
+ def asFunctionIdentifier: FunctionIdentifier = parts match {
+ case Seq(funcName) => FunctionIdentifier(funcName)
+ case Seq(dbName, funcName) => FunctionIdentifier(funcName, Some(dbName))
+ case _ =>
+ throw new AnalysisException(
+ s"$quoted is not a valid FunctionIdentifier as it has more than 2 name parts.")
+ }
+
def quoted: String = parts.map(quoteIfNeeded).mkString(".")
}
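
For reference, the two helpers are inverses for one- and two-part names; a quick sketch under that assumption:

    import org.apache.spark.sql.catalyst.FunctionIdentifier
    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

    // Round trip between a v1 FunctionIdentifier and its multi-part form.
    val ident = FunctionIdentifier("fn", Some("db"))
    assert(ident.asMultipart == Seq("db", "fn"))
    assert(ident.asMultipart.asFunctionIdentifier == ident)
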
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
index af951a0e7aa66..dcd352267a178 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
@@ -154,6 +154,28 @@ private[sql] trait LookupCatalog extends Logging {
}
}
+ object AsFunctionIdentifier {
+ def unapply(parts: Seq[String]): Option[FunctionIdentifier] = {
+ def namesToFunctionIdentifier(names: Seq[String]): Option[FunctionIdentifier] = names match {
+ case Seq(name) => Some(FunctionIdentifier(name))
+ case Seq(database, name) => Some(FunctionIdentifier(name, Some(database)))
+ case _ => None
+ }
+ parts match {
+ case Seq(name)
+ if catalogManager.v1SessionCatalog.isRegisteredFunction(FunctionIdentifier(name)) =>
+ Some(FunctionIdentifier(name))
+ case CatalogAndMultipartIdentifier(None, names)
+ if CatalogV2Util.isSessionCatalog(currentCatalog) =>
+ namesToFunctionIdentifier(names)
+ case CatalogAndMultipartIdentifier(Some(catalog), names)
+ if CatalogV2Util.isSessionCatalog(catalog) =>
+ namesToFunctionIdentifier(names)
+ case _ => None
+ }
+ }
+ }
+
def parseSessionCatalogFunctionIdentifier(nameParts: Seq[String]): FunctionIdentifier = {
if (nameParts.length == 1 && catalogManager.v1SessionCatalog.isTempFunction(nameParts.head)) {
return FunctionIdentifier(nameParts.head)
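
The name-part mapping used by the extractor, restated as a standalone sketch with its expected outcomes (hypothetical toV1Identifier helper):

    import org.apache.spark.sql.catalyst.FunctionIdentifier

    // Only one- or two-part names resolve to a v1 FunctionIdentifier; longer
    // names are left for v2 catalog resolution.
    def toV1Identifier(names: Seq[String]): Option[FunctionIdentifier] = names match {
      case Seq(name) => Some(FunctionIdentifier(name))
      case Seq(database, name) => Some(FunctionIdentifier(name, Some(database)))
      case _ => None
    }

    assert(toV1Identifier(Seq("fn")).contains(FunctionIdentifier("fn")))
    assert(toV1Identifier(Seq("db", "fn")).contains(FunctionIdentifier("fn", Some("db"))))
    assert(toV1Identifier(Seq("cat", "db", "fn")).isEmpty)
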
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index a3fbe4c742160..cf065125c4f0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -617,12 +617,6 @@ private[spark] object QueryCompilationErrors {
s"the function '$func', please make sure it is on the classpath")
}
- def v2CatalogNotSupportFunctionError(
- catalog: String, namespace: Seq[String]): Throwable = {
- new AnalysisException("V2 catalog does not support functions yet. " +
- s"catalog: $catalog, namespace: '${namespace.quoted}'")
- }
-
def resourceTypeNotSupportedError(resourceType: String): Throwable = {
new AnalysisException(s"Resource Type '$resourceType' is not supported.")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala
index e0f3c9a835b6e..85e0b1062c81f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog,
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
class LookupFunctionsSuite extends PlanTest {
@@ -49,7 +50,7 @@ class LookupFunctionsSuite extends PlanTest {
assert(externalCatalog.getFunctionExistsCalledTimes == 1)
assert(analyzer.LookupFunctions.normalizeFuncName
- (unresolvedPersistentFunc.name).database == Some("default"))
+ (unresolvedPersistentFunc.nameParts.asFunctionIdentifier).database == Some("default"))
}
test("SPARK-23486: the functionExists for the Registered function check") {
@@ -70,9 +71,9 @@ class LookupFunctionsSuite extends PlanTest {
table("TaBlE"))
analyzer.LookupFunctions.apply(plan)
- assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
+ assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 4)
assert(analyzer.LookupFunctions.normalizeFuncName
- (unresolvedRegisteredFunc.name).database == Some("default"))
+ (unresolvedRegisteredFunc.nameParts.asFunctionIdentifier).database == Some("default"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
index eb35dd47a508f..bfff3ee855e6d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException}
-import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog => V1InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
class CatalogManagerSuite extends SparkFunSuite with SQLHelper {
private def createSessionCatalog(): SessionCatalog = {
- val catalog = new InMemoryCatalog()
+ val catalog = new V1InMemoryCatalog()
catalog.createDatabase(
CatalogDatabase(SessionCatalog.DEFAULT_DATABASE, "", new URI("fake"), Map.empty),
ignoreIfExists = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
similarity index 93%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
index 5560bda928232..0cca1cc9bebf2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
@@ -24,14 +24,15 @@ import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.LogicalExpressions
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-class TableCatalogSuite extends SparkFunSuite {
+class CatalogSuite extends SparkFunSuite {
import CatalogV2Implicits._
private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
@@ -39,8 +40,8 @@ class TableCatalogSuite extends SparkFunSuite {
.add("id", IntegerType)
.add("data", StringType)
- private def newCatalog(): TableCatalog with SupportsNamespaces = {
- val newCatalog = new InMemoryTableCatalog
+ private def newCatalog(): InMemoryCatalog = {
+ val newCatalog = new InMemoryCatalog
newCatalog.initialize("test", CaseInsensitiveStringMap.empty())
newCatalog
}
@@ -927,4 +928,43 @@ class TableCatalogSuite extends SparkFunSuite {
assert(partTable.listPartitionIdentifiers(Array.empty, InternalRow.empty).length == 2)
assert(partTable.rows.isEmpty)
}
+
+ val function: UnboundFunction = new UnboundFunction {
+ override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "my_bound_function"
+ }
+ override def description(): String = "my_function"
+ override def name(): String = "my_function"
+ }
+
+ test("list functions") {
+ val catalog = newCatalog()
+ val ident1 = Identifier.of(Array("ns1", "ns2"), "func1")
+ val ident2 = Identifier.of(Array("ns1", "ns2"), "func2")
+ val ident3 = Identifier.of(Array("ns1", "ns3"), "func3")
+
+ catalog.createNamespace(Array("ns1", "ns2"), emptyProps)
+ catalog.createNamespace(Array("ns1", "ns3"), emptyProps)
+ catalog.createFunction(ident1, function)
+ catalog.createFunction(ident2, function)
+ catalog.createFunction(ident3, function)
+
+ assert(catalog.listFunctions(Array("ns1", "ns2")).toSet === Set(ident1, ident2))
+ assert(catalog.listFunctions(Array("ns1", "ns3")).toSet === Set(ident3))
+ assert(catalog.listFunctions(Array("ns1")).toSet == Set())
+ intercept[NoSuchNamespaceException](catalog.listFunctions(Array("ns2")))
+ }
+
+ test("lookup function") {
+ val catalog = newCatalog()
+ val ident = Identifier.of(Array("ns"), "func")
+ catalog.createNamespace(Array("ns"), emptyProps)
+ catalog.createFunction(ident, function)
+
+ assert(catalog.loadFunction(ident) == function)
+ intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns"), "func1")))
+ intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns1"), "func")))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
new file mode 100644
index 0000000000000..202b03f28f082
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException}
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
+
+class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog {
+ protected val functions: util.Map[Identifier, UnboundFunction] =
+ new ConcurrentHashMap[Identifier, UnboundFunction]()
+
+ override protected def allNamespaces: Seq[Seq[String]] = {
+ (tables.keySet.asScala.map(_.namespace.toSeq) ++
+ functions.keySet.asScala.map(_.namespace.toSeq) ++
+ namespaces.keySet.asScala).toSeq.distinct
+ }
+
+ override def listFunctions(namespace: Array[String]): Array[Identifier] = {
+ if (namespace.isEmpty || namespaceExists(namespace)) {
+ functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
+ } else {
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+
+ override def loadFunction(ident: Identifier): UnboundFunction = {
+ Option(functions.get(ident)) match {
+ case Some(func) =>
+ func
+ case _ =>
+ throw new NoSuchFunctionException(ident)
+ }
+ }
+
+ def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = {
+ functions.put(ident, fn)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index 38113f9ea1902..0c403baca2113 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -138,7 +138,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
}
class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces {
- private def allNamespaces: Seq[Seq[String]] = {
+ protected def allNamespaces: Seq[Seq[String]] = {
(tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java
new file mode 100644
index 0000000000000..4e783fdd439b6
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaAverage.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.StructType;
+
+import java.io.Serializable;
+
+public class JavaAverage implements UnboundFunction {
+ @Override
+ public String name() {
+ return "avg";
+ }
+
+ @Override
+ public BoundFunction bind(StructType inputType) {
+ if (inputType.fields().length != 1) {
+ throw new UnsupportedOperationException("Expect exactly one argument");
+ }
+ if (inputType.fields()[0].dataType() instanceof DoubleType) {
+ return new JavaDoubleAverage();
+ }
+ throw new UnsupportedOperationException("Unsupported non-integral type: " +
+ inputType.fields()[0].dataType());
+ }
+
+ @Override
+ public String description() {
+ return null;
+ }
+
+ public static class JavaDoubleAverage implements AggregateFunction<State<Double>, Double> {
+ @Override
+ public State<Double> newAggregationState() {
+ return new State<>(0.0, 0.0);
+ }
+
+ @Override
+ public State<Double> update(State<Double> state, InternalRow input) {
+ if (input.isNullAt(0)) {
+ return state;
+ }
+ return new State<>(state.sum + input.getDouble(0), state.count + 1);
+ }
+
+ @Override
+ public Double produceResult(State<Double> state) {
+ return state.sum / state.count;
+ }
+
+ @Override
+ public State<Double> merge(State<Double> leftState, State<Double> rightState) {
+ return new State<>(leftState.sum + rightState.sum, leftState.count + rightState.count);
+ }
+
+ @Override
+ public DataType[] inputTypes() {
+ return new DataType[] { DataTypes.DoubleType };
+ }
+
+ @Override
+ public DataType resultType() {
+ return DataTypes.DoubleType;
+ }
+
+ @Override
+ public String name() {
+ return "davg";
+ }
+ }
+
+ public static class State<T> implements Serializable {
+ T sum, count;
+
+ State(T sum, T count) {
+ this.sum = sum;
+ this.count = count;
+ }
+ }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
new file mode 100644
index 0000000000000..8b2d883a3703f
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class JavaStrLen implements UnboundFunction {
+ private final BoundFunction fn;
+
+ public JavaStrLen(BoundFunction fn) {
+ this.fn = fn;
+ }
+
+ @Override
+ public String name() {
+ return "strlen";
+ }
+
+ @Override
+ public BoundFunction bind(StructType inputType) {
+ if (inputType.fields().length != 1) {
+ throw new UnsupportedOperationException("Expect exactly one argument");
+ }
+
+ if (inputType.fields()[0].dataType() instanceof StringType) {
+ return fn;
+ }
+
+ throw new UnsupportedOperationException("Except StringType");
+ }
+
+ @Override
+ public String description() {
+ return "strlen: returns the length of the input string\n" +
+ " strlen(string) -> int";
+ }
+
+ public static class JavaStrLenDefault implements ScalarFunction<Integer> {
+ @Override
+ public DataType[] inputTypes() {
+ return new DataType[] { DataTypes.StringType };
+ }
+
+ @Override
+ public DataType resultType() {
+ return DataTypes.IntegerType;
+ }
+
+ @Override
+ public String name() {
+ return "strlen";
+ }
+
+ @Override
+ public Integer produceResult(InternalRow input) {
+ String str = input.getString(0);
+ return str.length();
+ }
+ }
+
+ public static class JavaStrLenMagic implements ScalarFunction<Integer> {
+ @Override
+ public DataType[] inputTypes() {
+ return new DataType[] { DataTypes.StringType };
+ }
+
+ @Override
+ public DataType resultType() {
+ return DataTypes.IntegerType;
+ }
+
+ @Override
+ public String name() {
+ return "strlen";
+ }
+
+ public int invoke(UTF8String str) {
+ return str.toString().length();
+ }
+ }
+
+ public static class JavaStrLenNoImpl implements ScalarFunction<Integer> {
+ @Override
+ public DataType[] inputTypes() {
+ return new DataType[] { DataTypes.StringType };
+ }
+
+ @Override
+ public DataType resultType() {
+ return DataTypes.IntegerType;
+ }
+
+ @Override
+ public String name() {
+ return "strlen";
+ }
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
new file mode 100644
index 0000000000000..fe856ffecb84a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util
+import java.util.Collections
+
+import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen}
+import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
+import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
+ private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
+
+ private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = {
+ catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
+ }
+
+ test("undefined function") {
+ assert(intercept[AnalysisException](
+ sql("SELECT testcat.non_exist('abc')").collect()
+ ).getMessage.contains("Undefined function"))
+ }
+
+ test("non-function catalog") {
+ withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
+ assert(intercept[AnalysisException](
+ sql("SELECT testcat.strlen('abc')").collect()
+ ).getMessage.contains("is not a FunctionCatalog"))
+ }
+ }
+
+ test("built-in with non-function catalog should still work") {
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat",
+ "spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
+ checkAnswer(sql("SELECT length('abc')"), Row(3))
+ }
+ }
+
+ test("built-in with default v2 function catalog") {
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
+ checkAnswer(sql("SELECT length('abc')"), Row(3))
+ }
+ }
+
+ test("looking up higher-order function with non-session catalog") {
+ checkAnswer(sql("SELECT transform(array(1, 2, 3), x -> x + 1)"),
+ Row(Array(2, 3, 4)) :: Nil)
+ }
+
+ test("built-in override with default v2 function catalog") {
+ // a built-in function with the same name should take higher priority
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
+ addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
+ checkAnswer(sql("SELECT length('abc')"), Row(3))
+ }
+ }
+
+ test("built-in override with non-session catalog") {
+ addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
+ checkAnswer(sql("SELECT length('abc')"), Row(3))
+ }
+
+ test("temp function override with default v2 function catalog") {
+ val className = "test.org.apache.spark.sql.JavaStringLength"
+ sql(s"CREATE FUNCTION length AS '$className'")
+
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
+ addFunction(Identifier.of(Array.empty, "length"), new JavaStrLen(new JavaStrLenNoImpl))
+ checkAnswer(sql("SELECT length('abc')"), Row(3))
+ }
+ }
+
+ test("view should use captured catalog and namespace for function lookup") {
+ val viewName = "my_view"
+ withView(viewName) {
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "my_avg"), IntegralAverage)
+ sql("USE ns")
+ sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT my_avg(col1) FROM values (1), (2), (3)")
+ }
+
+ // change default catalog and namespace and add a function with the same name but with no
+ // implementation
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat2") {
+ catalog("testcat2").asInstanceOf[SupportsNamespaces]
+ .createNamespace(Array("ns2"), emptyProps)
+ addFunction(Identifier.of(Array("ns2"), "my_avg"), NoImplAverage)
+ sql("USE ns2")
+ checkAnswer(sql(s"SELECT * FROM $viewName"), Row(2.0) :: Nil)
+ }
+ }
+ }
+
+ test("scalar function: with default produceResult method") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+ }
+
+ test("scalar function: with default produceResult method w/ expression") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+ checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil)
+ }
+
+ test("scalar function: lookup magic method") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+ }
+
+ test("scalar function: lookup magic method w/ expression") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenMagic))
+ checkAnswer(sql("SELECT testcat.ns.strlen(substr('abcde', 3))"), Row(3) :: Nil)
+ }
+
+ test("scalar function: bad magic method") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic))
+ assert(intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
+ .getMessage.contains("Cannot find a compatible"))
+ }
+
+ test("scalar function: bad magic method with default impl") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagicWithDefault))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+ }
+
+ test("scalar function: no implementation found") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl))
+ intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
+ }
+
+ test("scalar function: invalid parameter type or length") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+
+ assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)"))
+ .getMessage.contains("Expect StringType"))
+ assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')"))
+ .getMessage.contains("Expect exactly one argument"))
+ }
+
+ test("scalar function: default produceResult in Java") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"),
+ new JavaStrLen(new JavaStrLenDefault))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+ }
+
+ test("scalar function: magic method in Java") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"),
+ new JavaStrLen(new JavaStrLenMagic))
+ checkAnswer(sql("SELECT testcat.ns.strlen('abc')"), Row(3) :: Nil)
+ }
+
+ test("scalar function: no implementation found in Java") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"),
+ new JavaStrLen(new JavaStrLenNoImpl))
+ assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect())
+ .getMessage.contains("neither implement magic method nor override 'produceResult'"))
+ }
+
+ test("bad bound function (neither scalar nor aggregate)") {
+ catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
+
+ assert(intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')"))
+ .getMessage.contains("does not implement ScalarFunction or AggregateFunction"))
+ }
+
+ test("aggregate function: lookup int average") {
+ import testImplicits._
+ val t = "testcat.ns.t"
+ withTable(t) {
+ addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+ (1 to 100).toDF("i").write.saveAsTable(t)
+ checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
+ }
+ }
+
+ test("aggregate function: lookup long average") {
+ import testImplicits._
+ val t = "testcat.ns.t"
+ withTable(t) {
+ addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+ (1L to 100L).toDF("i").write.saveAsTable(t)
+ checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(50) :: Nil)
+ }
+ }
+
+ test("aggregate function: lookup double average in Java") {
+ import testImplicits._
+ val t = "testcat.ns.t"
+ withTable(t) {
+ addFunction(Identifier.of(Array("ns"), "avg"), new JavaAverage)
+
+ Seq(1.0, 2.0, 3.0).toDF("i").write.saveAsTable(t)
+ checkAnswer(sql(s"SELECT testcat.ns.avg(i) from $t"), Row(2.0) :: Nil)
+ }
+ }
+
+ test("aggregate function: lookup int average w/ expression") {
+ import testImplicits._
+ val t = "testcat.ns.t"
+ withTable(t) {
+ addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+ (1 to 100).toDF("i").write.saveAsTable(t)
+ checkAnswer(sql(s"SELECT testcat.ns.avg(i * 10) from $t"), Row(505) :: Nil)
+ }
+ }
+
+ test("aggregate function: unsupported input type") {
+ import testImplicits._
+ val t = "testcat.ns.t"
+ withTable(t) {
+ addFunction(Identifier.of(Array("ns"), "avg"), IntegralAverage)
+
+ Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t)
+ assert(intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t"))
+ .getMessage.contains("Unsupported input type: ShortType"))
+ }
+ }
+
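+ // Test helper: the unbound form of `strlen`. It validates the argument list at bind time
+ // and delegates execution to whichever BoundFunction implementation it wraps.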
+ private case class StrLen(impl: BoundFunction) extends UnboundFunction {
+ override def description(): String =
+ """strlen: returns the length of the input string
+ | strlen(string) -> int""".stripMargin
+ override def name(): String = "strlen"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length != 1) {
+ throw new UnsupportedOperationException("Expect exactly one argument")
+ }
+ inputType.fields(0).dataType match {
+ case StringType => impl
+ case _ =>
+ throw new UnsupportedOperationException("Expect StringType")
+ }
+ }
+ }
+
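+ // Relies solely on the row-based produceResult(InternalRow) fallback; no magic method.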
+ private case object StrLenDefault extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_default"
+
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
+ }
+ }
+
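+ // Declares the reflective "magic method" `invoke` instead of overriding produceResult.
+ // The parameter type must be UTF8String, the JVM type Spark uses for StringType values.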
+ private case object StrLenMagic extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_magic"
+
+ def invoke(input: UTF8String): Int = {
+ input.toString.length
+ }
+ }
+
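+ // The `invoke` signature uses java.lang.String rather than UTF8String, so the reflective
+ // lookup fails; with no produceResult override as a fallback, evaluation fails at runtime.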
+ private case object StrLenBadMagic extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_magic"
+
+ def invoke(input: String): Int = {
+ input.length
+ }
+ }
+
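+ // Same unusable `invoke` signature as above, but here the produceResult override gives
+ // Spark a working fallback, so the query still succeeds.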
+ private case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_magic_with_default"
+
+ def invoke(input: String): Int = {
+ input.length
+ }
+
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
+ }
+ }
+
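+ // Provides neither a magic method nor a produceResult override, so any invocation fails.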
+ private case object StrLenNoImpl extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_noimpl"
+ }
+
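+ // Implements BoundFunction directly, without being a ScalarFunction or AggregateFunction,
+ // to exercise the analyzer error for unsupported bound function types.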
+ private case object BadBoundFunction extends BoundFunction {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "bad_bound_func"
+ }
+
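+ // Unbound aggregate that dispatches on the argument type at bind time: IntegerType and
+ // LongType bind to the concrete implementations below; other types are rejected.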
+ object IntegralAverage extends UnboundFunction {
+ override def name(): String = "iavg"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length != 1) {
+ throw new UnsupportedOperationException("Expect exactly one argument")
+ }
+
+ inputType.fields(0).dataType match {
+ case _: IntegerType => IntAverage
+ case _: LongType => LongAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+ }
+ }
+
+ override def description(): String =
+ """iavg: produces an average using integer division, ignoring nulls
+ | iavg(int) -> int
+ | iavg(bigint) -> bigint""".stripMargin
+ }
+
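+ // Aggregation state is a (sum, count) pair. produceResult uses integer division, which is
+ // why avg over 1..100 yields 50 rather than 50.5 in the tests above.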
+ object IntAverage extends AggregateFunction[(Int, Int), Int] {
+ override def name(): String = "iavg"
+ override def inputTypes(): Array[DataType] = Array(IntegerType)
+ override def resultType(): DataType = IntegerType
+
+ override def newAggregationState(): (Int, Int) = (0, 0)
+
+ override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
+ if (input.isNullAt(0)) {
+ state
+ } else {
+ val i = input.getInt(0)
+ state match {
+ case (_, 0) =>
+ (i, 1)
+ case (total, count) =>
+ (total + i, count + 1)
+ }
+ }
+ }
+
+ override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
+ }
+
+ override def produceResult(state: (Int, Int)): Int = state._1 / state._2
+ }
+
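+ // Long variant of IntAverage. merge sums the partial (sum, count) states, letting Spark
+ // combine partially aggregated results from different partitions.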
+ object LongAverage extends AggregateFunction[(Long, Long), Long] {
+ override def name(): String = "iavg"
+ override def inputTypes(): Array[DataType] = Array(LongType)
+ override def resultType(): DataType = LongType
+
+ override def newAggregationState(): (Long, Long) = (0L, 0L)
+
+ override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
+ if (input.isNullAt(0)) {
+ state
+ } else {
+ val l = input.getLong(0)
+ state match {
+ case (_, 0L) =>
+ (l, 1L)
+ case (total, count) =>
+ (total + l, count + 1L)
+ }
+ }
+ }
+
+ override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
+ }
+
+ override def produceResult(state: (Long, Long)): Long = state._1 / state._2
+ }
+
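+ // Never binds successfully. Used to verify that a view resolves functions against the
+ // catalog and namespace captured at creation time, not the current session defaults.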
+ object NoImplAverage extends UnboundFunction {
+ override def name(): String = "no_impl_avg"
+ override def description(): String = name()
+
+ override def bind(inputType: StructType): BoundFunction = {
+ throw new UnsupportedOperationException("Not implemented")
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
index 3ef242f90f7e7..77a515b55ce76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.connector
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, StagingInMemoryTableCatalog}
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
@@ -32,11 +32,11 @@ trait DatasourceV2SQLBase
}
before {
- spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName)
spark.conf.set(
"spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName)
- spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
+ spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName)
spark.conf.set(
V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)