[SPARK-15888] [SQL] fix Python UDF with aggregate #13682

Closed · wants to merge 1 commit
10 changes: 9 additions & 1 deletion python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ def test_broadcast_in_udf(self):

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import udf, col
        from pyspark.sql.functions import udf, col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
        self.assertEqual(sel.collect(), [Row(key=1)])

        my_copy = udf(lambda x: x, IntegerType())
        my_add = udf(lambda a, b: int(a + b), IntegerType())
        my_strlen = udf(lambda x: len(x), IntegerType())
        sel = df.groupBy(my_copy(col("key")).alias("k"))\
            .agg(sum(my_strlen(col("value"))).alias("s"))\
            .select(my_add(col("k"), col("s")).alias("t"))
        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
    experimentalMethods: ExperimentalMethods)
  extends Optimizer(catalog, conf) {

  override def batches: Seq[Batch] = super.batches :+ Batch(
    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
  override def batches: Seq[Batch] = super.batches :+
    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
}
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi

  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
Contributor Author:
remove the ! for BatchEvalPythonExec
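
The override matters for how the plan is printed: QueryPlan counts as "missing input" any attribute that a node references but that neither its children's output nor producedAttributes provide, and such nodes get a leading "!" in the plan string. The UDF result attributes of BatchEvalPythonExec are part of its output (and therefore its references) but come from no child, so without the override the node is flagged. A minimal sketch of that set arithmetic, assuming spark-catalyst on the classpath (MissingInputSketch and its attribute names are made up for illustration):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.types.IntegerType

object MissingInputSketch extends App {
  // "key" flows in from the child plan; "pythonUDF" is produced by BatchEvalPythonExec itself.
  val key = AttributeReference("key", IntegerType)()
  val udfResult = AttributeReference("pythonUDF", IntegerType)()

  val output = Seq(key, udfResult)
  val childOutput = Seq(key)

  // Roughly missingInput = references -- inputSet -- producedAttributes, with output
  // standing in for the references that matter here.
  val withoutOverride = AttributeSet(output) -- AttributeSet(childOutput)
  val withOverride = withoutOverride -- AttributeSet(output.drop(childOutput.length))

  println(withoutOverride.nonEmpty)  // true: the node would be printed with a "!" prefix
  println(withOverride.nonEmpty)     // false: no "!" once producedAttributes covers the UDF outputs
}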


  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
@@ -18,12 +18,68 @@
package org.apache.spark.sql.execution.python

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan


/**
 * Extracts Python UDFs in a logical Aggregate that depend on an aggregate expression or a
 * grouping key, so that they are evaluated after the aggregate.
 */
private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {

  /**
   * Returns whether the expression can only be evaluated within the aggregate.
   */
  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
    e.isInstanceOf[AggregateExpression] ||
      agg.groupingExpressions.exists(_.semanticEquals(e))
  }

  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
    expr.find {
      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
    }.isDefined
  }

  private def extract(agg: Aggregate): LogicalPlan = {
    val projList = new ArrayBuffer[NamedExpression]()
    val aggExpr = new ArrayBuffer[NamedExpression]()
    agg.aggregateExpressions.foreach { expr =>
      if (hasPythonUdfOverAggregate(expr, agg)) {
        // A Python UDF over an aggregate expression or grouping key can only be evaluated
        // after the aggregate, so keep the aggregate part here and pull the UDF out.
        val newE = expr transformDown {
          case e: Expression if belongAggregate(e, agg) =>
            val alias = e match {
              case a: NamedExpression => a
              case o => Alias(e, "agg")()
            }
            aggExpr += alias
            alias.toAttribute
        }
        projList += newE.asInstanceOf[NamedExpression]
      } else {
        aggExpr += expr
        projList += expr.toAttribute
      }
    }
    // The rewritten Aggregate no longer contains any Python UDF over an aggregate expression;
    // those are now evaluated in the Project placed on top of it.
    Project(projList, agg.copy(aggregateExpressions = aggExpr))
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
      extract(agg)
  }
}
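
Concretely, the rewrite keeps the aggregate part of each output expression inside the Aggregate under an Alias and moves the Python UDF that wraps it into a Project on top, which then refers to the alias only by attribute. A stripped-down, hypothetical sketch of that split, assuming spark-catalyst on the classpath (SplitAggregateExprSketch is made up, and Add stands in for the Python UDF call):

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Sum}
import org.apache.spark.sql.types.IntegerType

object SplitAggregateExprSketch extends App {
  val value = AttributeReference("value", IntegerType)()
  // Pretend Add(...) is a Python UDF applied on top of an aggregate result.
  val outer: NamedExpression = Alias(Add(Sum(value).toAggregateExpression(), Literal(1)), "t")()

  val aggExpr = new ArrayBuffer[NamedExpression]()
  val projected = outer.transformDown {
    case e: AggregateExpression =>   // the part that "belongs to" the Aggregate
      val alias = Alias(e, "agg")()
      aggExpr += alias               // computed by the Aggregate
      alias.toAttribute              // referenced from the Project on top
  }

  println(aggExpr)     // roughly: [sum(value) AS agg]
  println(projected)   // roughly: (agg + 1) AS t
}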


/**
 * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
 * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
}

  /**
   * Extract all the PythonUDFs from the current operator.
   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
  def extract(plan: SparkPlan): SparkPlan = {
  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
      // Ignore Python UDFs that come from a second/third aggregate; their inputs are not
      // part of this plan's inputSet and they are not used here.
      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
      // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
      udfs.filterNot(attributeMap.contains).foreach { udf =>
        if (udf.references.subsetOf(plan.inputSet)) {
          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
        } else {
          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
        }
        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }

      val rewritten = plan.transformExpressions {
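
One more note on the filter added to extract() above: a Python UDF is only extracted at the physical level if everything it references is available from the operator's children; a UDF over an aggregate result fails that check and is skipped here, since it has already been lifted above the Aggregate by ExtractPythonUDFFromAggregate (or is simply unused). A small hypothetical sketch of the guard, assuming spark-catalyst on the classpath (SubsetGuardSketch and its attributes are made up):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.types.{IntegerType, LongType}

object SubsetGuardSketch extends App {
  // "key" is provided by the operator's children; "sum_v" only exists after the Aggregate.
  val key = AttributeReference("key", IntegerType)()
  val sumV = AttributeReference("sum_v", LongType)()

  val inputSet = AttributeSet(Seq(key))

  println(AttributeSet(Seq(key)).subsetOf(inputSet))   // true: a UDF over the grouping key is extracted here
  println(AttributeSet(Seq(sumV)).subsetOf(inputSet))  // false: a UDF over sum(value) is skipped here
}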