From fec1562b0ea03ff42d2468ea8ff7cbbc569336d8 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 23 Sep 2024 20:03:14 +0800 Subject: [PATCH] [SPARK-49755][CONNECT] Remove special casing for avro functions in Connect ### What changes were proposed in this pull request? apply the built-in registered functions ### Why are the changes needed? code simplification ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? updated tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48209 from zhengruifeng/connect_avro. Authored-by: Ruifeng Zheng Signed-off-by: yangjie01 --- .../expressions/toFromAvroSqlFunctions.scala | 3 ++ .../from_avro_with_options.explain | 2 +- .../from_avro_without_options.explain | 2 +- .../to_avro_with_schema.explain | 2 +- .../to_avro_without_schema.explain | 2 +- sql/connect/server/pom.xml | 2 +- .../connect/planner/SparkConnectPlanner.scala | 47 +------------------ 7 files changed, 9 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala index 58bddafac0882..457f469e0f687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala @@ -61,6 +61,9 @@ case class FromAvro(child: Expression, jsonFormatSchema: Expression, options: Ex override def second: Expression = jsonFormatSchema override def third: Expression = options + def this(child: Expression, jsonFormatSchema: Expression) = + this(child, jsonFormatSchema, Literal.create(null)) + override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = { copy(child = newFirst, jsonFormatSchema = newSecond, options = newThird) diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain index 1ef91ef8c36ac..f08c804d3b88a 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain @@ -1,2 +1,2 @@ -Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes)#0] +Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes, {"type": "int", "name": "id"}, map(mode, FAILFAST, compression, zstandard))#0] +- LocalRelation <empty>, [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain index 8fca0b5341694..6fe4a8babc689 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain @@ -1,2 +1,2 @@ -Project [from_avro(bytes#0, {"type": "string", "name": "name"}) AS from_avro(bytes)#0] +Project [from_avro(bytes#0, {"type": "string", "name": "name"}) AS from_avro(bytes, {"type": "string", "name": "name"}, NULL)#0] +- LocalRelation <empty>, [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain index cd2dc984e3ffa..8ba9248f844c7 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain +++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain @@ -1,2 +1,2 @@ -Project [to_avro(a#0, Some({"type": "int", "name": "id"})) AS to_avro(a)#0] +Project [to_avro(a#0, Some({"type": "int", "name": "id"})) AS to_avro(a, {"type": "int", "name": "id"})#0] +- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain index a5371c70ac78a..b2947334945e3 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain @@ -1,2 +1,2 @@ -Project [to_avro(id#0L, None) AS to_avro(id)#0] +Project [to_avro(id#0L, None) AS to_avro(id, NULL)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml index 3350c4261e9da..12e3ed9030437 100644 --- a/sql/connect/server/pom.xml +++ b/sql/connect/server/pom.xml @@ -105,7 +105,7 @@ <groupId>org.apache.spark</groupId> <artifactId>spark-avro_${scala.binary.version}</artifactId> <version>${project.version}</version> - <scope>provided</scope> + <scope>test</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 33c9edb1cd21a..231e54ff77d29 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -44,7 +44,6 @@ import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} import 
org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession} -import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTranspose} import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder} @@ -1523,8 +1522,7 @@ class SparkConnectPlanner( case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE => transformUnresolvedAttribute(exp.getUnresolvedAttribute) case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION => - transformUnregisteredFunction(exp.getUnresolvedFunction) - .getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction)) + transformUnresolvedFunction(exp.getUnresolvedFunction) case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias) case proto.Expression.ExprTypeCase.EXPRESSION_STRING => transformExpressionString(exp.getExpressionString) @@ -1844,49 +1842,6 @@ class SparkConnectPlanner( UnresolvedNamedLambdaVariable(variable.getNamePartsList.asScala.toSeq) } - /** - * For some reason, not all functions are registered in 'FunctionRegistry'. For a unregistered - * function, we can still wrap it under the proto 'UnresolvedFunction', and then resolve it in - * this method. 
- */ - private def transformUnregisteredFunction( - fun: proto.Expression.UnresolvedFunction): Option[Expression] = { - fun.getFunctionName match { - // Avro-specific functions - case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - val jsonFormatSchema = extractString(children(1), "jsonFormatSchema") - var options = Map.empty[String, String] - if (fun.getArgumentsCount == 3) { - options = extractMapData(children(2), "Options") - } - Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options)) - - case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - var jsonFormatSchema = Option.empty[String] - if (fun.getArgumentsCount == 2) { - jsonFormatSchema = Some(extractString(children(1), "jsonFormatSchema")) - } - Some(CatalystDataToAvro(children.head, jsonFormatSchema)) - - case _ => None - } - } - - private def extractString(expr: Expression, field: String): String = expr match { - case Literal(s, StringType) if s != null => s.toString - case other => throw InvalidPlanInput(s"$field should be a literal string, but got $other") - } - - @scala.annotation.tailrec - private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match { - case map: CreateMap => ExprUtils.convertToMapData(map) - case UnresolvedFunction(Seq("map"), args, _, _, _, _, _) => - extractMapData(CreateMap(args), field) - case other => throw InvalidPlanInput(s"$field should be created by map, but got $other") - } - private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { if (alias.getNameCount == 1) { val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {