[SPARK-49026][CONNECT] Add ColumnNode to Proto conversion
### What changes were proposed in this pull request?
This PR adds a converter that converts ColumnNodes into Connect proto.Expression.
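In rough terms, the converter is a plain `ColumnNode => proto.Expression` function, so a client-side Column tree can be translated node by node. A minimal usage sketch (assuming `Literal`'s optional dataType and origin parameters have defaults, as in the internal column-node API; not taken verbatim from this commit):

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter
import org.apache.spark.sql.internal.Literal

// Convert a single literal ColumnNode into its Connect protocol form.
val node = Literal("hello")
val expr: proto.Expression = ColumnNodeToProtoConverter(node)
// expr now carries a proto literal with the string value "hello".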

### Why are the changes needed?
TBD

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added a test suite
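For illustration only, a conversion test could look roughly like the following; the suite name, base class, and assertions here are stand-ins rather than the actual suite added in this commit.

import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter
import org.apache.spark.sql.internal.Literal

// Hypothetical shape of a conversion test.
class ColumnNodeToProtoConverterExampleSuite extends AnyFunSuite {
  test("literal is converted to a proto literal") {
    val expr = ColumnNodeToProtoConverter(Literal(1))
    assert(expr.hasLiteral)
    assert(expr.getLiteral.getInteger == 1)
  }
}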

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47812 from hvanhovell/SPARK-49026.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
hvanhovell authored and HyukjinKwon committed Aug 21, 2024
1 parent bc7bfbc commit 3305939
Showing 7 changed files with 604 additions and 11 deletions.
@@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect

import scala.jdk.CollectionConverters._

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST}
import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING}
import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
import org.apache.spark.sql.expressions.ScalaUserDefinedFunction
import org.apache.spark.sql.internal._

/**
 * Converter for [[ColumnNode]] to [[proto.Expression]] conversions.
 */
object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
  override def apply(node: ColumnNode): proto.Expression = {
    val builder = proto.Expression.newBuilder()
    // TODO(SPARK-49273) support Origin in Connect Scala Client.
    node match {
      case Literal(value, None, _) =>
        builder.setLiteral(toLiteralProtoBuilder(value))

      case Literal(value, Some(dataType), _) =>
        builder.setLiteral(toLiteralProtoBuilder(value, dataType))

      case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
        val b = builder.getUnresolvedAttributeBuilder
          .setUnparsedIdentifier(unparsedIdentifier)
          .setIsMetadataColumn(isMetadataColumn)
        planId.foreach(b.setPlanId)

      case UnresolvedStar(unparsedTarget, planId, _) =>
        val b = builder.getUnresolvedStarBuilder
        unparsedTarget.foreach(b.setUnparsedTarget)
        planId.foreach(b.setPlanId)

      case UnresolvedRegex(regex, planId, _) =>
        val b = builder.getUnresolvedRegexBuilder
          .setColName(regex)
        planId.foreach(b.setPlanId)

      case UnresolvedFunction(functionName, arguments, isDistinct, isUserDefinedFunction, _, _) =>
        // TODO(SPARK-49087) use internal namespace.
        builder.getUnresolvedFunctionBuilder
          .setFunctionName(functionName)
          .setIsUserDefinedFunction(isUserDefinedFunction)
          .setIsDistinct(isDistinct)
          .addAllArguments(arguments.map(apply).asJava)

      case Alias(child, name, metadata, _) =>
        val b = builder.getAliasBuilder.setExpr(apply(child))
        name.foreach(b.addName)
        metadata.foreach(m => b.setMetadata(m.json))

      case Cast(child, dataType, evalMode, _) =>
        val b = builder.getCastBuilder
          .setExpr(apply(child))
          .setType(DataTypeProtoConverter.toConnectProtoType(dataType))
        evalMode.foreach { mode =>
          val convertedMode = mode match {
            case Cast.Try => proto.Expression.Cast.EvalMode.EVAL_MODE_TRY
            case Cast.Ansi => proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI
            case Cast.Legacy => proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY
          }
          b.setEvalMode(convertedMode)
        }

      case SqlExpression(expression, _) =>
        builder.getExpressionStringBuilder.setExpression(expression)

      case s: SortOrder =>
        builder.setSortOrder(convertSortOrder(s))

      case Window(windowFunction, windowSpec, _) =>
        val b = builder.getWindowBuilder
          .setWindowFunction(apply(windowFunction))
          .addAllPartitionSpec(windowSpec.partitionColumns.map(apply).asJava)
          .addAllOrderSpec(windowSpec.sortColumns.map(convertSortOrder).asJava)
        windowSpec.frame.foreach { frame =>
          b.getFrameSpecBuilder
            .setFrameType(frame.frameType match {
              case WindowFrame.Row => FrameType.FRAME_TYPE_ROW
              case WindowFrame.Range => FrameType.FRAME_TYPE_RANGE
            })
            .setLower(convertFrameBoundary(frame.lower))
            .setUpper(convertFrameBoundary(frame.upper))
        }

      case UnresolvedExtractValue(child, extraction, _) =>
        builder.getUnresolvedExtractValueBuilder
          .setChild(apply(child))
          .setExtraction(apply(extraction))

      case UpdateFields(structExpression, fieldName, valueExpression, _) =>
        val b = builder.getUpdateFieldsBuilder
          .setStructExpression(apply(structExpression))
          .setFieldName(fieldName)
        valueExpression.foreach(v => b.setValueExpression(apply(v)))

      case v: UnresolvedNamedLambdaVariable =>
        builder.setUnresolvedNamedLambdaVariable(convertNamedLambdaVariable(v))

      case LambdaFunction(function, arguments, _) =>
        builder.getLambdaFunctionBuilder
          .setFunction(apply(function))
          .addAllArguments(arguments.map(convertNamedLambdaVariable).asJava)

      case InvokeInlineUserDefinedFunction(udf: ScalaUserDefinedFunction, arguments, false, _) =>
        val b = builder.getCommonInlineUserDefinedFunctionBuilder
          .setScalarScalaUdf(udf.udf)
          .setDeterministic(udf.deterministic)
          .addAllArguments(arguments.map(apply).asJava)
        udf.givenName.foreach(b.setFunctionName)

      case CaseWhenOtherwise(branches, otherwise, _) =>
        val b = builder.getUnresolvedFunctionBuilder
          .setFunctionName("when")
        branches.foreach { case (condition, value) =>
          b.addArguments(apply(condition))
          b.addArguments(apply(value))
        }
        otherwise.foreach { value =>
          b.addArguments(apply(value))
        }

      case ProtoColumnNode(e, _) =>
        return e

      case node =>
        throw SparkException.internalError("Unsupported ColumnNode: " + node)
    }
    builder.build()
  }

  private def convertSortOrder(s: SortOrder): proto.Expression.SortOrder = {
    proto.Expression.SortOrder
      .newBuilder()
      .setChild(apply(s.child))
      .setDirection(s.sortDirection match {
        case SortOrder.Ascending => SORT_DIRECTION_ASCENDING
        case SortOrder.Descending => SORT_DIRECTION_DESCENDING
      })
      .setNullOrdering(s.nullOrdering match {
        case SortOrder.NullsFirst => SORT_NULLS_FIRST
        case SortOrder.NullsLast => SORT_NULLS_LAST
      })
      .build()
  }

  private def convertFrameBoundary(boundary: WindowFrame.FrameBoundary): FrameBoundary = {
    val builder = FrameBoundary.newBuilder()
    boundary match {
      case WindowFrame.UnboundedPreceding => builder.setUnbounded(true)
      case WindowFrame.UnboundedFollowing => builder.setUnbounded(true)
      case WindowFrame.CurrentRow => builder.setCurrentRow(true)
      case WindowFrame.Value(value) => builder.setValue(apply(value))
    }
    builder.build()
  }

  private def convertNamedLambdaVariable(
      v: UnresolvedNamedLambdaVariable): proto.Expression.UnresolvedNamedLambdaVariable = {
    proto.Expression.UnresolvedNamedLambdaVariable.newBuilder().addNameParts(v.name).build()
  }
}

case class ProtoColumnNode(
    expr: proto.Expression,
    override val origin: Origin = CurrentOrigin.get)
    extends ColumnNode {
  override def sql: String = expr.toString
}
@@ -29,6 +29,7 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
+import org.apache.spark.sql.internal.UserDefinedFunctionLike
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils}

@@ -101,13 +102,14 @@ case class ScalaUserDefinedFunction private[sql] (
     serializedUdfPacket: Array[Byte],
     inputTypes: Seq[proto.DataType],
     outputType: proto.DataType,
-    name: Option[String],
+    givenName: Option[String],
     override val nullable: Boolean,
     override val deterministic: Boolean,
     aggregate: Boolean)
-  extends UserDefinedFunction {
+  extends UserDefinedFunction
+  with UserDefinedFunctionLike {

-  private[expressions] lazy val udf = {
+  private[sql] lazy val udf = {
     val scalaUdfBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serializedUdfPacket))
@@ -128,10 +130,10 @@ case class ScalaUserDefinedFunction private[sql] (
       .setScalarScalaUdf(udf)
       .addAllArguments(exprs.map(_.expr).asJava)

-    name.foreach(udfBuilder.setFunctionName)
+    givenName.foreach(udfBuilder.setFunctionName)
   }

-  override def withName(name: String): ScalaUserDefinedFunction = copy(name = Option(name))
+  override def withName(name: String): ScalaUserDefinedFunction = copy(givenName = Option(name))

   override def asNonNullable(): ScalaUserDefinedFunction = copy(nullable = false)

@@ -143,9 +145,11 @@ case class ScalaUserDefinedFunction private[sql] (
       .setDeterministic(deterministic)
       .setScalarScalaUdf(udf)

-    name.foreach(builder.setFunctionName)
+    givenName.foreach(builder.setFunctionName)
     builder.build()
   }
+
+  override def name: String = givenName.getOrElse("UDF")
 }

 object ScalaUserDefinedFunction {
@@ -195,7 +199,7 @@ object ScalaUserDefinedFunction {
       serializedUdfPacket = udfPacketBytes,
       inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType),
       outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType),
-      name = None,
+      givenName = None,
       nullable = true,
       deterministic = true,
       aggregate = aggregate)
@@ -43,7 +43,7 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession {
       serializedUdfPacket = udfByteArray,
       inputTypes = Seq(ProtoDataTypes.IntegerType),
       outputType = ProtoDataTypes.IntegerType,
-      name = Some("dummyUdf"),
+      givenName = Some("dummyUdf"),
       nullable = true,
       deterministic = true,
       aggregate = false)
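A closing note on the name → givenName rename: ScalaUserDefinedFunction now mixes in UserDefinedFunctionLike and exposes a public `name` that falls back to "UDF" when no name was given, while `givenName` is what withName sets and what the proto builders attach as the function name. A hedged sketch of the resulting behavior (not from this commit's tests):

import org.apache.spark.sql.functions.udf

// An unnamed UDF has givenName = None, so its name would report "UDF";
// withName fills givenName, which both the proto builder and name then use.
val plusOne = udf((i: Int) => i + 1)        // name would report "UDF"
val named = plusOne.withName("plus_one")    // name would report "plus_one"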