diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 45fa449b58ed7..34a8a91a0ddf8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -52,9 +52,10 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { case Literal(value, Some(dataType), _) => builder.setLiteral(toLiteralProtoBuilder(value, dataType)) - case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + val escapedName = u.sql val b = builder.getUnresolvedAttributeBuilder - .setUnparsedIdentifier(unparsedIdentifier) + .setUnparsedIdentifier(escapedName) if (isMetadataColumn) { // We only set this field when it is needed. If we would always set it, // too many of the verbatims we use for testing would have to be regenerated. 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index c37100b729029..86c7a20136851 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -173,8 +173,8 @@ class ColumnTestSuite extends ConnectFunSuite { assert(explain1 != explain2) assert(explain1.strip() == "+(a, b)") assert(explain2.contains("UnresolvedFunction(+")) - assert(explain2.contains("UnresolvedAttribute(a")) - assert(explain2.contains("UnresolvedAttribute(b")) + assert(explain2.contains("UnresolvedAttribute(List(a")) + assert(explain2.contains("UnresolvedAttribute(List(b")) } private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 51b26a1fa2435..979baf12be614 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import ColumnNode._ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.catalyst.util.AttributeNameParser import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata} import org.apache.spark.util.SparkClassUtils @@ -122,7 +123,7 @@ private[sql] case class Literal( /** * Reference to an attribute produced by one of the underlying DataFrames. * - * @param unparsedIdentifier + * @param nameParts * name of the attribute. * @param planId * id of the plan (Dataframe) that produces the attribute. 
@@ -130,14 +131,40 @@ private[sql] case class Literal( * whether this is a metadata column. */ private[sql] case class UnresolvedAttribute( - unparsedIdentifier: String, + nameParts: Seq[String], planId: Option[Long] = None, isMetadataColumn: Boolean = false, override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): UnresolvedAttribute = copy(planId = None, origin = NO_ORIGIN) - override def sql: String = unparsedIdentifier + + // Quote a name part whenever it contains a structural character ('.' or '`') and + // escape embedded backticks by doubling them, so the rendered identifier survives a + // round trip through AttributeNameParser.parseAttributeName on the server side. + override def sql: String = nameParts.map(n => if (n.contains(".") || n.contains("`")) "`" + n.replace("`", "``") + "`" else n).mkString(".") +} +private[sql] object UnresolvedAttribute { + def apply( + unparsedIdentifier: String, + planId: Option[Long], + isMetadataColumn: Boolean, + origin: Origin): UnresolvedAttribute = UnresolvedAttribute( + AttributeNameParser.parseAttributeName(unparsedIdentifier), + planId = planId, + isMetadataColumn = isMetadataColumn, + origin = origin) + + def apply( + unparsedIdentifier: String, + planId: Option[Long], + isMetadataColumn: Boolean): UnresolvedAttribute = + apply(unparsedIdentifier, planId, isMetadataColumn, CurrentOrigin.get) + + def apply(unparsedIdentifier: String, planId: Option[Long]): UnresolvedAttribute = + apply(unparsedIdentifier, planId, false, CurrentOrigin.get) + + def apply(unparsedIdentifier: String): UnresolvedAttribute = + apply(unparsedIdentifier, None, false, CurrentOrigin.get) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 920c0371292c9..476956e58e8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -54,8 +54,8 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case Literal(value, None, _) => expressions.Literal(value) - case UnresolvedAttribute(unparsedIdentifier, planId, 
isMetadataColumn, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) + case UnresolvedAttribute(nameParts, planId, isMetadataColumn, _) => + convertUnresolvedAttribute(nameParts, planId, isMetadataColumn) case UnresolvedStar(unparsedTarget, None, _) => val target = unparsedTarget.map { t => @@ -74,7 +74,7 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres analysis.UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis) case UnresolvedRegex(unparsedIdentifier, planId, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) + convertUnresolvedRegex(unparsedIdentifier, planId) case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, isInternal, _) => val nameParts = if (isUDF) { @@ -223,10 +223,10 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } private def convertUnresolvedAttribute( - unparsedIdentifier: String, + nameParts: Seq[String], planId: Option[Long], isMetadataColumn: Boolean): analysis.UnresolvedAttribute = { - val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + val attribute = analysis.UnresolvedAttribute(nameParts) if (planId.isDefined) { attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) } @@ -235,6 +235,16 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } attribute } + + private def convertUnresolvedRegex( + unparsedIdentifier: String, + planId: Option[Long]): analysis.UnresolvedAttribute = { + val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + if (planId.isDefined) { + attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) + } + attribute + } } private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter {