[SPARK-48675][SQL] Fix cache table with collated column
### What changes were proposed in this pull request?

The following sequence of queries produces an error:
```
> cache lazy table t as select col from values ('a' collate utf8_lcase) as (col);
> select col from t;
org.apache.spark.SparkException: not support type: org.apache.spark.sql.types.StringType1.
        at org.apache.spark.sql.errors.QueryExecutionErrors$.notSupportTypeError(QueryExecutionErrors.scala:1069)
        at org.apache.spark.sql.execution.columnar.ColumnBuilder$.apply(ColumnBuilder.scala:200)
        at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.$anonfun$next$1(InMemoryRelation.scala:85)
        at scala.collection.immutable.List.map(List.scala:247)
        at scala.collection.immutable.List.map(List.scala:79)
        at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:84)
        at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:82)
        at org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:296)
        at org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:293)
...
```
The same problem also occurs with non-lazy cached tables.
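For example, the non-lazy variant hits the same code path; because the cache is materialized eagerly, the failure surfaces at `CACHE TABLE` time (illustrative):
```
> cache table t as select col from values ('a' collate utf8_lcase) as (col);
org.apache.spark.SparkException: not support type: org.apache.spark.sql.types.StringType1.
...
```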

The problem occurs during execution of `InMemoryTableScanExec`: the columnar cache machinery pattern-matches on the default `StringType` object only, so `ColumnAccessor`, `ColumnBuilder`, `ColumnType` and `ColumnStats` all need updating to handle collated string types.
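At its core, the fix parameterizes the columnar `STRING` type by collation ID. A simplified sketch of the pattern (excerpted from the diff below, not the full patch):

```scala
// The STRING column type carries the collation ID instead of being a
// singleton hard-wired to the default collation, so collated strings
// round-trip through the cache with their collation intact.
private[columnar] case class STRING(collationId: Int)
  extends NativeColumnType(PhysicalStringType(collationId), 8)
  with DirectCopyColumnType[UTF8String]

// Dispatch sites match any StringType instance instead of only the default object:
//   before: case StringType    => STRING
//   after:  case s: StringType => STRING(s)
```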

### Why are the changes needed?

To fix the described error.

### Does this PR introduce _any_ user-facing change?

Yes. After these changes, the described sequence of queries produces valid results instead of throwing an error.
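With the fix applied, the repro from above behaves as expected (illustrative output, matching what the new tests check):
```
> cache lazy table t as select col from values ('a' collate utf8_lcase) as (col);
> select col from t;
a
> select col from t where col = 'b';
-- no rows, and no SparkException
```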

### How was this patch tested?

Added checks for the mentioned classes to the columnar test suites, and an integration test to `CollationSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47045 from nikolamand-db/SPARK-48675.

Authored-by: Nikola Mandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
nikolamand-db authored and cloud-fan committed Jun 21, 2024
1 parent f0563ef commit 0bc38ac
Showing 15 changed files with 205 additions and 43 deletions.
@@ -100,8 +100,8 @@ private[columnar] class FloatColumnAccessor(buffer: ByteBuffer)
 private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, DOUBLE)
 
-private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, STRING)
+private[columnar] class StringColumnAccessor(buffer: ByteBuffer, dataType: StringType)
+  extends NativeColumnAccessor(buffer, STRING(dataType))
 
 private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
   extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
@@ -147,7 +147,7 @@ private[sql] object ColumnAccessor {
       new LongColumnAccessor(buf)
     case FloatType => new FloatColumnAccessor(buf)
     case DoubleType => new DoubleColumnAccessor(buf)
-    case StringType => new StringColumnAccessor(buf)
+    case s: StringType => new StringColumnAccessor(buf, s)
     case BinaryType => new BinaryColumnAccessor(buf)
     case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
       new CompactDecimalColumnAccessor(buf, dt)
@@ -122,7 +122,8 @@ private[columnar]
 class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
 
 private[columnar]
-class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+class StringColumnBuilder(dataType: StringType)
+  extends NativeColumnBuilder(new StringColumnStats(dataType), STRING(dataType))
 
 private[columnar]
 class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
@@ -185,7 +186,7 @@ private[columnar] object ColumnBuilder {
       new LongColumnBuilder
     case FloatType => new FloatColumnBuilder
     case DoubleType => new DoubleColumnBuilder
-    case StringType => new StringColumnBuilder
+    case s: StringType => new StringColumnBuilder(s)
     case BinaryType => new BinaryColumnBuilder
     case CalendarIntervalType => new IntervalColumnBuilder
     case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
@@ -255,23 +255,25 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats(collationId: Int) extends ColumnStats {
+  def this(dt: StringType) = this(dt.collationId)
+
   protected var upper: UTF8String = null
   protected var lower: UTF8String = null
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getUTF8String(ordinal)
-      val size = STRING.actualSize(row, ordinal)
+      val size = STRING(collationId).actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
       gatherNullStats()
     }
   }
 
   def gatherValueStats(value: UTF8String, size: Int): Unit = {
-    if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone()
-    if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone()
+    if (upper == null || value.semanticCompare(upper, collationId) > 0) upper = value.clone()
+    if (lower == null || value.semanticCompare(lower, collationId) < 0) lower = value.clone()
     sizeInBytes += size
     count += 1
   }
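The switch from `binaryCompare` to `semanticCompare` above is the correctness-critical part: min/max bounds for a collated column must follow the collation's ordering, otherwise stats-based batch pruning could skip batches that actually contain matching rows. A small standalone sketch of the difference (not part of the patch; uses the same APIs as the diff, e.g. in a Spark REPL):

```scala
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

val a = UTF8String.fromString("a")
val c = UTF8String.fromString("C")
val lcaseId = StringType("UTF8_LCASE").collationId

a.binaryCompare(c)            // positive: byte-wise, 'a' (0x61) sorts after 'C' (0x43)
a.semanticCompare(c, lcaseId) // negative: case-insensitively, "a" < "c"
```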
@@ -491,8 +491,8 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType]
   }
 }
 
-private[columnar] object STRING
-  extends NativeColumnType(PhysicalStringType(StringType.collationId), 8)
+private[columnar] case class STRING(collationId: Int)
+  extends NativeColumnType(PhysicalStringType(collationId), 8)
   with DirectCopyColumnType[UTF8String] {
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
@@ -532,6 +532,12 @@ private[columnar] object STRING
   override def clone(v: UTF8String): UTF8String = v.clone()
 }
 
+private[columnar] object STRING {
+  def apply(dt: StringType): STRING = {
+    STRING(dt.collationId)
+  }
+}
+
 private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
   extends NativeColumnType(PhysicalDecimalType(precision, scale), 8) {
 
@@ -821,7 +827,7 @@ private[columnar] object ColumnType {
     case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => LONG
     case FloatType => FLOAT
     case DoubleType => DOUBLE
-    case StringType => STRING
+    case s: StringType => STRING(s)
     case BinaryType => BINARY
     case i: CalendarIntervalType => CALENDAR_INTERVAL
     case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
@@ -86,7 +86,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator]
         classOf[LongColumnAccessor].getName
       case FloatType => classOf[FloatColumnAccessor].getName
       case DoubleType => classOf[DoubleColumnAccessor].getName
-      case StringType => classOf[StringColumnAccessor].getName
+      case _: StringType => classOf[StringColumnAccessor].getName
      case BinaryType => classOf[BinaryColumnAccessor].getName
       case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
@@ -101,7 +101,7 @@
       val createCode = dt match {
         case t if CodeGenerator.isPrimitiveType(dt) =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | StringType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
         case other =>
           s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
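Note that `StringType` moves out of the buffer-only constructor case above: the generated string accessor now falls through to the generic `other` case, whose constructor call (truncated in this view) also passes the data type, carrying the collation through to `StringColumnAccessor`.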
@@ -176,7 +176,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme {
   }
 
   override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+    case INT | LONG | SHORT | BYTE | _: STRING | BOOLEAN => true
     case _ => false
   }
 
@@ -373,7 +373,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme {
   }
 
   override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | STRING => true
+    case INT | LONG | _: STRING => true
     case _ => false
   }
 
34 changes: 34 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
@@ -1431,4 +1432,37 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     })
   }
 
+  test("cache table with collated columns") {
+    val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")
+    val lazyOptions = Seq(false, true)
+
+    for (
+      collation <- collations;
+      lazyTable <- lazyOptions
+    ) {
+      val lazyStr = if (lazyTable) "LAZY" else ""
+
+      def checkCacheTable(values: String): Unit = {
+        sql(s"CACHE $lazyStr TABLE tbl AS SELECT col FROM VALUES ($values) AS (col)")
+        // Checks in-memory fetching code path.
+        val all = sql("SELECT col FROM tbl")
+        assert(all.queryExecution.executedPlan.collectFirst {
+          case _: InMemoryTableScanExec => true
+        }.nonEmpty)
+        checkAnswer(all, Row("a"))
+        // Checks column stats code path.
+        checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Row("a"))
+        checkAnswer(sql("SELECT col FROM tbl WHERE col = 'b'"), Seq.empty)
+      }
+
+      withTable("tbl") {
+        checkCacheTable(s"'a' COLLATE $collation")
+      }
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
+        withTable("tbl") {
+          checkCacheTable("'a'")
+        }
+      }
+    }
+  }
 }
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.types.StringType
 
 class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
@@ -28,9 +29,9 @@ class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
   testDecimalColumnStats(Array(null, null, 0))
   testIntervalColumnStats(Array(null, null, 0))
+  testStringColumnStats(Array(null, null, 0))
 
   def testColumnStats[T <: PhysicalDataType, U <: ColumnStats](
       columnStatsClass: Class[U],
@@ -141,4 +142,60 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats](
+      initialStatistics: Array[Any]): Unit = {
+
+    Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collation => {
+      val columnType = STRING(StringType(collation))
+
+      test(s"STRING($collation): empty") {
+        val columnStats = new StringColumnStats(StringType(collation).collationId)
+        columnStats.collectedStatistics.zip(initialStatistics).foreach {
+          case (actual, expected) => assert(actual === expected)
+        }
+      }
+
+      test(s"STRING($collation): non-empty") {
+        import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+
+        val columnStats = new StringColumnStats(StringType(collation).collationId)
+        val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+        rows.foreach(columnStats.gatherStats(_, 0))
+
+        val values = rows.take(10).map(_.get(0,
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
+        val ordering = PhysicalDataType.ordering(
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
+        val stats = columnStats.collectedStatistics
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+        assertResult(10, "Wrong null count")(stats(2))
+        assertResult(20, "Wrong row count")(stats(3))
+        assertResult(stats(4), "Wrong size in bytes") {
+          rows.map { row =>
+            if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+          }.sum
+        }
+      }
+    })
+
+    test("STRING(UTF8_LCASE): collation-defined ordering") {
+      import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+      import org.apache.spark.unsafe.types.UTF8String
+
+      val columnStats = new StringColumnStats(StringType("UTF8_LCASE").collationId)
+      val rows = Seq("b", "a", "C", "A").map(str => {
+        val row = new GenericInternalRow(1)
+        row(0) = UTF8String.fromString(str)
+        row
+      })
+      rows.foreach(columnStats.gatherStats(_, 0))
+
+      val stats = columnStats.collectedStatistics
+      assertResult(UTF8String.fromString("a"), "Wrong lower bound")(stats(0))
+      assertResult(UTF8String.fromString("C"), "Wrong upper bound")(stats(1))
+    }
+  }
 }
@@ -26,6 +26,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalDataType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -40,7 +41,9 @@ class ColumnTypeSuite extends SparkFunSuite {
   val checks = Map(
     NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
     FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
-    STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
+    STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
+    STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 8,
+    BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
     CALENDAR_INTERVAL -> 16)
 
   checks.foreach { case (columnType, expectedSize) =>
@@ -73,7 +76,12 @@
     checkActualSize(LONG, Long.MaxValue, 8)
     checkActualSize(FLOAT, Float.MaxValue, 4)
     checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(STRING, "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+    Seq(
+      "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+    ).foreach(collation => {
+      checkActualSize(STRING(StringType(collation)),
+        "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+    })
     checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
     checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
@@ -93,7 +101,10 @@
   testNativeColumnType(FLOAT)
   testNativeColumnType(DOUBLE)
   testNativeColumnType(COMPACT_DECIMAL(15, 10))
-  testNativeColumnType(STRING)
+  testNativeColumnType(STRING(StringType)) // UTF8_BINARY
+  testNativeColumnType(STRING(StringType("UTF8_LCASE")))
+  testNativeColumnType(STRING(StringType("UNICODE")))
+  testNativeColumnType(STRING(StringType("UNICODE_CI")))
 
   testColumnType(NULL)
   testColumnType(BINARY)
@@ -104,20 +115,28 @@
   testColumnType(CALENDAR_INTERVAL)
 
   def testNativeColumnType[T <: PhysicalDataType](columnType: NativeColumnType[T]): Unit = {
-    testColumnType[T#InternalType](columnType)
+    val typeName = columnType match {
+      case s: STRING =>
+        val collation = CollationFactory.fetchCollation(s.collationId).collationName
+        Some(if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)")
+      case _ => None
+    }
+    testColumnType[T#InternalType](columnType, typeName)
   }
 
-  def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
-
+  def testColumnType[JvmType](
+      columnType: ColumnType[JvmType],
+      typeName: Option[String] = None): Unit = {
     val proj = UnsafeProjection.create(
       Array[DataType](ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
     val converter = CatalystTypeConverters.createToScalaConverter(
       ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
     val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
     val totalSize = seq.map(_.getSizeInBytes).sum
     val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize)
+    val testName = typeName.getOrElse(columnType.toString)
 
-    test(s"$columnType append/extract") {
+    test(s"$testName append/extract") {
       val buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder())
       seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer))
 
@@ -50,7 +50,7 @@ object ColumnarTestUtils {
       case LONG => Random.nextLong()
       case FLOAT => Random.nextFloat()
       case DOUBLE => Random.nextDouble()
-      case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
+      case _: STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
      case BINARY => randomBytes(Random.nextInt(32))
       case CALENDAR_INTERVAL =>
         new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong())
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
 
 class TestNullableColumnAccessor[JvmType](
@@ -41,21 +42,33 @@ object TestNullableColumnAccessor {
 class NullableColumnAccessorSuite extends SparkFunSuite {
   import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 
-  Seq(
+  val stringTypes = Seq(
+    STRING(StringType), // UTF8_BINARY
+    STRING(StringType("UTF8_LCASE")),
+    STRING(StringType("UNICODE")),
+    STRING(StringType("UNICODE_CI")))
+  val otherTypes = Seq(
     NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
-    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+    BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
     STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
     ARRAY(PhysicalArrayType(IntegerType, true)),
     MAP(PhysicalMapType(IntegerType, StringType, true)),
     CALENDAR_INTERVAL)
-    .foreach {
+
+  stringTypes.foreach(s => {
+    val collation = CollationFactory.fetchCollation(s.collationId).collationName
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)"
+    testNullableColumnAccessor(s, Some(typeName))
+  })
+  otherTypes.foreach {
     testNullableColumnAccessor(_)
   }
 
   def testNullableColumnAccessor[JvmType](
-      columnType: ColumnType[JvmType]): Unit = {
+      columnType: ColumnType[JvmType],
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
     val nullRow = makeNullRow(1)
 
   test(s"Nullable $typeName column accessor: empty column") {