[SPARK-34981][SQL] Implement V2 function resolution and evaluation #32082

Closed · wants to merge 21 commits
AggregateFunction.java
@@ -25,13 +25,12 @@
/**
* Interface for a function that produces a result value by aggregating over multiple input rows.
* <p>
* For each input row, Spark will call an update method that corresponds to the
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
* update with {@link InternalRow}.
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* For each input row, Spark will call the {@link #update} method which should evaluate the row
* and update the aggregation state. The JVM type of result values produced by
* {@link #produceResult} must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* Please refer to class documentation of {@link ScalarFunction} for the mapping between
* {@link DataType} and the JVM type.
* <p>
* All implementations must support partial aggregation by implementing merge so that Spark can
* partially aggregate and shuffle intermediate results, instead of shuffling all rows for an
@@ -68,9 +67,7 @@ public interface AggregateFunction<S extends Serializable, R> extends BoundFunction {
* @param input an input row
* @return updated aggregation state
*/
default S update(S state, InternalRow input) {
throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update");
}
S update(S state, InternalRow input);

/**
* Merge two partial aggregation states.
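As a minimal sketch (not part of this PR) of what an implementation of the updated interface could look like: the class name IntegerSum and its behavior are illustrative assumptions, and it relies on the BoundFunction methods (inputTypes, resultType, name) documented elsewhere in this change.

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.AggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical example: SUM over a single INT column, with Integer as the
// serializable intermediate aggregation state.
public class IntegerSum implements AggregateFunction<Integer, Integer> {

  @Override
  public Integer newAggregationState() {
    return 0;
  }

  @Override
  public Integer update(Integer state, InternalRow input) {
    // Skip NULL inputs; otherwise add the int value at ordinal 0.
    return input.isNullAt(0) ? state : state + input.getInt(0);
  }

  @Override
  public Integer merge(Integer leftState, Integer rightState) {
    // Required for partial aggregation, as described in the Javadoc above.
    return leftState + rightState;
  }

  @Override
  public Integer produceResult(Integer state) {
    return state;
  }

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.IntegerType };
  }

  @Override
  public DataType resultType() {
    return DataTypes.IntegerType;
  }

  @Override
  public String name() {
    return "integer_sum";
  }
}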
ScalarFunction.java
@@ -23,17 +23,67 @@
/**
* Interface for a function that produces a result value for each input row.
* <p>
* For each input row, Spark will call a produceResult method that corresponds to the
* {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
* Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
* {@link #produceResult(InternalRow)}.
 * To evaluate each input row, Spark will first try to look up and use a "magic method" (described
* below) through Java reflection. If the method is not found, Spark will call
* {@link #produceResult(InternalRow)} as a fallback approach.
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* <p>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
* {@link UnsupportedOperationException}. Users can choose to override this method, or implement
 * a "magic method" with name {@link #MAGIC_METHOD_NAME} which takes individual parameters
 * instead of an {@link InternalRow}. The magic method will be loaded by Spark through Java
 * reflection and will generally also provide better performance, due to optimizations such as
 * codegen and the removal of Java boxing.
 *
 * For example, a scalar UDF for adding two integers can be defined as follows using the magic
 * method approach:
*
 * <pre>
 *   public class IntegerAdd implements {@code ScalarFunction<Integer>} {
 *     public int invoke(int left, int right) {
 *       return left + right;
 *     }
 *   }
 * </pre>

@sunchao (Member, Author) commented on Apr 27, 2021:

@cloud-fan I think we can also consider adding another "static invoke" API for those stateless UDFs. From the benchmark you did some time back, it seems this can give a decent performance improvement. WDYT?

Java HotSpot(TM) 64-Bit Server VM 1.8.0_161-b12 on Mac OS X 10.14.6
Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
UDF perf:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
native add                                        14206          14516         535         70.4          14.2       1.0X
udf add                                           24609          25271         898         40.6          24.6       0.6X
new udf add                                       18657          19096         726         53.6          18.7       0.8X
new row udf add                                   21128          22343        1478         47.3          21.1       0.7X
static udf add                                    16678          16887         278         60.0          16.7       0.9X

Contributor:

@sunchao can you spend some time on the API design? I'd love to see this feature!

@sunchao (Member, Author):

Sure, will do. It should be similar to the current invoke, and we can leverage StaticInvoke for the purpose. Do you think we can do this in a separate PR?

Contributor:

yea
 * In this case, since {@link #MAGIC_METHOD_NAME} is defined, Spark will use it over
 * {@link #produceResult} to evaluate the inputs. In general, Spark looks up the magic method by
 * first converting the actual input SQL data types to their corresponding Java types following
 * the mapping defined below, and then checking if there is a matching method among all the
 * declared methods in the UDF class, using the method name (i.e., {@link #MAGIC_METHOD_NAME}) and
 * the Java types. If no magic method is found, Spark falls back to {@link #produceResult}.
 * <p>
 * The following is the mapping from {@link DataType SQL data type} to Java type used by
 * the magic method approach:
* <ul>
* <li>{@link org.apache.spark.sql.types.BooleanType}: {@code boolean}</li>
* <li>{@link org.apache.spark.sql.types.ByteType}: {@code byte}</li>
* <li>{@link org.apache.spark.sql.types.ShortType}: {@code short}</li>
* <li>{@link org.apache.spark.sql.types.IntegerType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.LongType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.FloatType}: {@code float}</li>
* <li>{@link org.apache.spark.sql.types.DoubleType}: {@code double}</li>
* <li>{@link org.apache.spark.sql.types.StringType}:
* {@link org.apache.spark.unsafe.types.UTF8String}</li>
* <li>{@link org.apache.spark.sql.types.DateType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.TimestampType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.BinaryType}: {@code byte[]}</li>
* <li>{@link org.apache.spark.sql.types.DayTimeIntervalType}: {@code long}</li>
* <li>{@link org.apache.spark.sql.types.YearMonthIntervalType}: {@code int}</li>
* <li>{@link org.apache.spark.sql.types.DecimalType}:
* {@link org.apache.spark.sql.types.Decimal}</li>
* <li>{@link org.apache.spark.sql.types.StructType}: {@link InternalRow}</li>
* <li>{@link org.apache.spark.sql.types.ArrayType}:
* {@link org.apache.spark.sql.catalyst.util.ArrayData}</li>
* <li>{@link org.apache.spark.sql.types.MapType}:
* {@link org.apache.spark.sql.catalyst.util.MapData}</li>
* </ul>
*
* @param <R> the JVM type of result values
*/
public interface ScalarFunction<R> extends BoundFunction {
String MAGIC_METHOD_NAME = "invoke";
Contributor:

Why is this needed? I think that the magic name should be "produceResult", just like the InternalRow version, so that it is clear what the method is supposed to do.

@sunchao (Member, Author):

I thought about that initially, but since StructType maps to InternalRow, we need a way to differentiate (a) a magic method with a single parameter of StructType from (b) the default non-magic method. Users can only define the latter in this situation, but Spark looks up the magic method first since it has higher priority, which may cause issues.
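To make that ambiguity concrete, here is a minimal sketch (not from this PR); the class name StructFirstField and its two-field struct schema are illustrative assumptions.

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

// Hypothetical UDF taking a single STRUCT<a: INT, b: INT> argument.
public class StructFirstField implements ScalarFunction<Integer> {

  @Override
  public DataType[] inputTypes() {
    return new DataType[] {
      DataTypes.createStructType(new StructField[] {
        DataTypes.createStructField("a", DataTypes.IntegerType, false),
        DataTypes.createStructField("b", DataTypes.IntegerType, false)
      })
    };
  }

  @Override
  public DataType resultType() { return DataTypes.IntegerType; }

  @Override
  public String name() { return "struct_first_field"; }

  // (a) Magic method: the single StructType argument maps to InternalRow.
  public int invoke(InternalRow struct) {
    return struct.getInt(0);
  }

  // (b) Row-based fallback: this InternalRow wraps *all* arguments, so the
  // struct argument sits at ordinal 0. If the magic name were "produceResult",
  // (a) and (b) would share the same parameter signature and the reflection
  // lookup could not tell them apart.
  @Override
  public Integer produceResult(InternalRow args) {
    return args.getStruct(0, 2).getInt(0);
  }
}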


/**
* Applies the function to an input row to produce a value.
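To round out this file's changes, a minimal sketch (not part of this PR) of a ScalarFunction that exercises both the magic-method fast path and the produceResult fallback described above; per the mapping table, its STRING argument arrives as a UTF8String. The class name StringLength is an illustrative assumption.

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;

// Hypothetical UDF returning the number of characters of a STRING input.
public class StringLength implements ScalarFunction<Integer> {

  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.StringType };
  }

  @Override
  public DataType resultType() { return DataTypes.IntegerType; }

  @Override
  public String name() { return "string_length"; }

  // Fast path: found via reflection because StringType maps to UTF8String.
  public int invoke(UTF8String str) {
    return str.numChars();
  }

  // Fallback used when no matching magic method is found.
  @Override
  public Integer produceResult(InternalRow input) {
    return input.getUTF8String(0).numChars();
  }
}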
Analyzer.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Method
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -29,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions.{FrameLessOffsetWindowFunction, _}
import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -44,6 +45,8 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -281,7 +284,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveHigherOrderFunctions(catalogManager) ::
ResolveLambdaVariables ::
ResolveTimeZone ::
ResolveRandomSeed ::
@@ -895,9 +898,10 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

// If we are resolving relations insides views, we need to expand single-part relation names with
// the current catalog and namespace of when the view was created.
private def expandRelationName(nameParts: Seq[String]): Seq[String] = {
// If we are resolving database objects (relations, functions, etc.) inside views, we may need to
// expand single or multi-part identifiers with the current catalog and namespace from when the
// view was created.
private def expandIdentifier(nameParts: Seq[String]): Seq[String] = {
if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts

if (nameParts.length == 1) {
@@ -1040,7 +1044,7 @@ class Analyzer(override val catalogManager: CatalogManager)
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] =
expandRelationName(identifier) match {
expandIdentifier(identifier) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) =>
@@ -1153,7 +1157,7 @@ class Analyzer(override val catalogManager: CatalogManager)
}

private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = {
expandRelationName(identifier) match {
expandIdentifier(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW =>
@@ -1173,7 +1177,7 @@ class Analyzer(override val catalogManager: CatalogManager)
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] = {
expandRelationName(identifier) match {
expandIdentifier(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table =>
@@ -1569,8 +1573,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// results and confuse users if there is any null values. For count(t1.*, t2.*), it is
// still allowed, since it's well-defined in spark.
if (!conf.allowStarWithSingleTableIdentifierInCount &&
f1.name.database.isEmpty &&
f1.name.funcName == "count" &&
f1.nameParts == Seq("count") &&
f1.arguments.length == 1) {
f1.arguments.foreach {
case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) =>
@@ -1958,17 +1961,19 @@ class Analyzer(override val catalogManager: CatalogManager)
override def apply(plan: LogicalPlan): LogicalPlan = {
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
case f @ UnresolvedFunction(AsFunctionIdentifier(ident), _, _, _, _) =>
if (externalFunctionNameSet.contains(normalizeFuncName(ident)) ||
v1SessionCatalog.isRegisteredFunction(ident)) {
f
} else if (v1SessionCatalog.isPersistentFunction(ident)) {
externalFunctionNameSet.add(normalizeFuncName(ident))
f
} else {
withPosition(f) {
throw new NoSuchFunctionException(
ident.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
ident.funcName)
}
}
}
}
@@ -2016,9 +2021,10 @@ class Analyzer(override val catalogManager: CatalogManager)
name, other.getClass.getCanonicalName)
}
}
case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, ignoreNulls) =>
withPosition(u) {
v1SessionCatalog.lookupFunction(funcId, arguments) match {

case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments, isDistinct, filter,
ignoreNulls) => withPosition(u) {
v1SessionCatalog.lookupFunction(ident, arguments) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@@ -2095,9 +2101,123 @@ class Analyzer(override val catalogManager: CatalogManager)
case other =>
other
}
}

case u @ UnresolvedFunction(nameParts, arguments, isDistinct, filter, ignoreNulls) =>
withPosition(u) {
expandIdentifier(nameParts) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
if (!catalog.isFunctionCatalog) {
throw new AnalysisException(s"Trying to lookup function '$ident' in " +
s"catalog '${catalog.name()}', but it is not a FunctionCatalog.")
}

val unbound = catalog.asFunctionCatalog.loadFunction(ident)
val inputType = StructType(arguments.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
val bound = try {
unbound.bind(inputType)
} catch {
case unsupported: UnsupportedOperationException =>
throw new AnalysisException(s"Function '${unbound.name}' cannot process " +
s"input: (${arguments.map(_.dataType.simpleString).mkString(", ")}): " +
unsupported.getMessage, cause = Some(unsupported))
}

bound match {
case scalarFunc: ScalarFunction[_] =>
processV2ScalarFunction(scalarFunc, inputType, arguments, isDistinct,
filter, ignoreNulls)
case aggFunc: V2AggregateFunction[_, _] =>
Contributor: ditto, put into a new method.
processV2AggregateFunction(aggFunc, arguments, isDistinct, filter,
ignoreNulls)
case _ =>
failAnalysis(s"Function '${bound.name()}' does not implement ScalarFunction" +
s" or AggregateFunction")
}

case _ => u
}
}
}
}

private def processV2ScalarFunction(
scalarFunc: ScalarFunction[_],
inputType: StructType,
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression],
ignoreNulls: Boolean): Expression = {
if (isDistinct) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "DISTINCT")
} else if (filter.isDefined) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "FILTER clause")
} else if (ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "IGNORE NULLS")
} else {
// TODO: implement type coercion by looking at input type from the UDF. We
// may also want to check if the parameter types from the magic method
// match the input type through `BoundFunction.inputTypes`.
val argClasses = inputType.fields.map(_.dataType)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, returnNullable = scalarFunc.isResultNullable)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass does not override the default method in a parent interface
// defined in Java, the method can still be found from
// `getDeclaredMethod`.
// Since `inputType` is a `StructType`, it is mapped to an `InternalRow`
// which we can use to look up the `produceResult` method.
findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
case Some(_) =>
ApplyFunctionExpression(scalarFunc, arguments)
case None =>
failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implements the" +
s" magic method nor overrides 'produceResult'")
}
}
}
}

private def processV2AggregateFunction(
aggFunc: V2AggregateFunction[_, _],
arguments: Seq[Expression],
isDistinct: Boolean,
filter: Option[Expression],
ignoreNulls: Boolean): Expression = {
if (ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
aggFunc.name(), "IGNORE NULLS")
}
val aggregator = V2Aggregator(aggFunc, arguments)
AggregateExpression(aggregator, Complete, isDistinct, filter)
}

/**
* Check if the input `fn` implements the given `methodName` with parameter types specified
* via `inputType`.
*/
private def findMethod(
fn: BoundFunction,
methodName: String,
inputType: Seq[DataType]): Option[Method] = {
val cls = fn.getClass
try {
val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass)
Some(cls.getDeclaredMethod(methodName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
None
}
}
}

/**
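As a rough sketch of the resolution path added above (not code from this PR): the FunctionCatalog returns an UnboundFunction, which the analyzer binds to the argument types before dispatching to processV2ScalarFunction or processV2AggregateFunction. The class UnboundStringLength and the StringLength function it returns are illustrative assumptions from the earlier sketches.

import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical UnboundFunction that a FunctionCatalog could return from loadFunction().
public class UnboundStringLength implements UnboundFunction {

  @Override
  public String name() { return "string_length"; }

  @Override
  public BoundFunction bind(StructType inputType) {
    // The analyzer wraps this exception in an AnalysisException that reports the input types.
    if (inputType.fields().length != 1 ||
        !DataTypes.StringType.equals(inputType.fields()[0].dataType())) {
      throw new UnsupportedOperationException("Expected exactly one STRING argument");
    }
    return new StringLength();  // the hypothetical ScalarFunction sketched earlier
  }

  @Override
  public String description() {
    return "string_length(str) - returns the number of characters in str";
  }
}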
higherOrderFunctions.scala
@@ -17,10 +17,10 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.DataType

@@ -30,13 +30,14 @@ import org.apache.spark.sql.types.DataType
* so we need to resolve higher order function when all children are either resolved or a lambda
* function.
*/
case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
case class ResolveHigherOrderFunctions(catalogManager: CatalogManager)
extends Rule[LogicalPlan] with LookupCatalog {

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
case u @ UnresolvedFunction(fn, children, false, filter, ignoreNulls)
case u @ UnresolvedFunction(AsFunctionIdentifier(ident), children, false, filter, ignoreNulls)
if hasLambdaAndResolvedArguments(children) =>
withPosition(u) {
catalog.lookupFunction(fn, children) match {
catalogManager.v1SessionCatalog.lookupFunction(ident, children) match {
case func: HigherOrderFunction =>
filter.foreach(_.failAnalysis("FILTER predicate specified, " +
s"but ${func.prettyName} is not an aggregate function"))