[SPARK-49022][CONNECT][SQL][FOLLOW-UP] Parse unresolved identifier to keep the behavior same

### What changes were proposed in this pull request?

This PR is a follow-up of apache#47688 that keeps `Column.toString` the same as before.
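
For example, the quoting in `Column.toString` should keep matching Spark Classic (a sketch of the preserved behaviour, assuming plain unresolved column references):

```scala
import org.apache.spark.sql.functions.col

// "a.b" parses into two name parts and renders without quoting.
col("a.b").toString      // expected: a.b
// "`a.b`" stays a single name part and is re-quoted because it contains a dot.
col("`a.b`").toString    // expected: `a.b`
```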

### Why are the changes needed?

To keep the behaviour the same between Spark Classic and Spark Connect.

### Does this PR introduce _any_ user-facing change?

No, the main change has not been released yet.

### How was this patch tested?

Tests will be added separately. I manually tested:

```scala
import org.apache.spark.sql.functions.col
val name = "with`!#$%dot".replace("`", "``") // double the backtick to escape it inside a quoted identifier
col(s"`${name}`").toString.equals("with`!#$%dot") // returns true: toString still yields the unescaped name
```
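
The round trip behind this relies on `AttributeNameParser.parseAttributeName` (used by the new `UnresolvedAttribute.apply` in the diff below) to unescape the quoted identifier, and on the new `UnresolvedAttribute.sql` to render it back, quoting only name parts that contain a dot. A minimal sketch of that round trip, assuming the parser's usual Catalyst semantics:

```scala
import org.apache.spark.sql.catalyst.util.AttributeNameParser

// Parsing strips the outer backticks and un-doubles the escaped backtick.
val parts = AttributeNameParser.parseAttributeName("`with``!#$%dot`")
// parts should equal Seq("with`!#$%dot")

// Rendering mirrors the new UnresolvedAttribute.sql: quote only parts containing a dot.
val rendered = parts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
// rendered should equal "with`!#$%dot", the same string Column.toString returned before.
```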

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48376 from HyukjinKwon/SPARK-49022-followup.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed Oct 9, 2024
1 parent 5f64e80 commit c1f18a0
Showing 4 changed files with 50 additions and 12 deletions.
@@ -52,9 +52,10 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
     case Literal(value, Some(dataType), _) =>
       builder.setLiteral(toLiteralProtoBuilder(value, dataType))
 
-    case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
+    case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
+      val escapedName = u.sql
       val b = builder.getUnresolvedAttributeBuilder
-        .setUnparsedIdentifier(unparsedIdentifier)
+        .setUnparsedIdentifier(escapedName)
       if (isMetadataColumn) {
         // We only set this field when it is needed. If we would always set it,
         // too many of the verbatims we use for testing would have to be regenerated.
@@ -173,8 +173,8 @@ class ColumnTestSuite extends ConnectFunSuite {
     assert(explain1 != explain2)
     assert(explain1.strip() == "+(a, b)")
     assert(explain2.contains("UnresolvedFunction(+"))
-    assert(explain2.contains("UnresolvedAttribute(a"))
-    assert(explain2.contains("UnresolvedAttribute(b"))
+    assert(explain2.contains("UnresolvedAttribute(List(a"))
+    assert(explain2.contains("UnresolvedAttribute(List(b"))
   }
 
   private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = {
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
 import ColumnNode._
 
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.catalyst.util.AttributeNameParser
 import org.apache.spark.sql.errors.DataTypeErrorsBase
 import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata}
 import org.apache.spark.util.SparkClassUtils
@@ -122,22 +123,48 @@ private[sql] case class Literal(
 /**
  * Reference to an attribute produced by one of the underlying DataFrames.
  *
- * @param unparsedIdentifier
+ * @param nameParts
  *   name of the attribute.
  * @param planId
  *   id of the plan (Dataframe) that produces the attribute.
  * @param isMetadataColumn
  *   whether this is a metadata column.
  */
 private[sql] case class UnresolvedAttribute(
-    unparsedIdentifier: String,
+    nameParts: Seq[String],
     planId: Option[Long] = None,
     isMetadataColumn: Boolean = false,
     override val origin: Origin = CurrentOrigin.get)
     extends ColumnNode {
 
   override private[internal] def normalize(): UnresolvedAttribute =
     copy(planId = None, origin = NO_ORIGIN)
-  override def sql: String = unparsedIdentifier
+
+  override def sql: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
 }
 
+private[sql] object UnresolvedAttribute {
+  def apply(
+      unparsedIdentifier: String,
+      planId: Option[Long],
+      isMetadataColumn: Boolean,
+      origin: Origin): UnresolvedAttribute = UnresolvedAttribute(
+    AttributeNameParser.parseAttributeName(unparsedIdentifier),
+    planId = planId,
+    isMetadataColumn = isMetadataColumn,
+    origin = origin)
+
+  def apply(
+      unparsedIdentifier: String,
+      planId: Option[Long],
+      isMetadataColumn: Boolean): UnresolvedAttribute =
+    apply(unparsedIdentifier, planId, isMetadataColumn, CurrentOrigin.get)
+
+  def apply(unparsedIdentifier: String, planId: Option[Long]): UnresolvedAttribute =
+    apply(unparsedIdentifier, planId, false, CurrentOrigin.get)
+
+  def apply(unparsedIdentifier: String): UnresolvedAttribute =
+    apply(unparsedIdentifier, None, false, CurrentOrigin.get)
+}
+
 /**
@@ -54,8 +54,8 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expression) {
     case Literal(value, None, _) =>
       expressions.Literal(value)
 
-    case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
-      convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn)
+    case UnresolvedAttribute(nameParts, planId, isMetadataColumn, _) =>
+      convertUnresolvedAttribute(nameParts, planId, isMetadataColumn)
 
     case UnresolvedStar(unparsedTarget, None, _) =>
       val target = unparsedTarget.map { t =>
@@ -74,7 +74,7 @@
       analysis.UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis)
 
     case UnresolvedRegex(unparsedIdentifier, planId, _) =>
-      convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false)
+      convertUnresolvedRegex(unparsedIdentifier, planId)
 
     case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, isInternal, _) =>
       val nameParts = if (isUDF) {
@@ -223,10 +223,10 @@
   }
 
   private def convertUnresolvedAttribute(
-      unparsedIdentifier: String,
+      nameParts: Seq[String],
       planId: Option[Long],
       isMetadataColumn: Boolean): analysis.UnresolvedAttribute = {
-    val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier)
+    val attribute = analysis.UnresolvedAttribute(nameParts)
     if (planId.isDefined) {
       attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get)
     }
@@ -235,6 +235,16 @@
     }
     attribute
   }
+
+  private def convertUnresolvedRegex(
+      unparsedIdentifier: String,
+      planId: Option[Long]): analysis.UnresolvedAttribute = {
+    val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier)
+    if (planId.isDefined) {
+      attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get)
+    }
+    attribute
+  }
 }
 
 private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter {
