[SPARK-27049][SQL] Create util class to support handling partition values in file source V2

## What changes were proposed in this pull request?

While migrating other data sources, I found that we should abstract the following logic:
1. converting safe `InternalRow`s into `UnsafeRow`s
2. appending partition values to the end of the result row, if any exist

This PR proposes to support handling partition values in the file source V2 abstraction by adding a util class, `PartitionReaderWithPartitionValues`.
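
For illustration, a rough usage sketch of the new wrapper (not part of this patch; the stub reader, schemas, and partition value below are made up):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.PartitionReader
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical inner reader that yields a single data row.
val dataRows = Iterator(InternalRow(1L))
val baseReader = new PartitionReader[InternalRow] {
  private var current: InternalRow = _
  override def next(): Boolean =
    if (dataRows.hasNext) { current = dataRows.next(); true } else false
  override def get(): InternalRow = current
  override def close(): Unit = ()
}

// Wrapping the reader: every row now comes back as an UnsafeRow with the
// partition value ("date", taken from the directory path) appended after "id".
val reader = new PartitionReaderWithPartitionValues(
  baseReader,
  StructType(Seq(StructField("id", LongType))),      // readDataSchema
  StructType(Seq(StructField("date", StringType))),  // partitionSchema
  InternalRow(UTF8String.fromString("2019-03-07")))  // partitionValues
while (reader.next()) {
  println(reader.get())  // UnsafeRow holding (1, "2019-03-07")
}
```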

## How was this patch tested?

Existing unit tests

Closes #23987 from gengliangwang/SPARK-27049.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
gengliangwang authored and cloud-fan committed Mar 7, 2019
1 parent 340c8b8 commit a543f91
Showing 4 changed files with 77 additions and 35 deletions.
FilePartitionReaderFactory.scala
@@ -17,8 +17,9 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, PartitioningUtils}
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
@@ -45,6 +46,19 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
   def buildColumnarReader(partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
     throw new UnsupportedOperationException("Cannot create columnar reader.")
   }
+
+  protected def getReadDataSchema(
+      readSchema: StructType,
+      partitionSchema: StructType,
+      isCaseSensitive: Boolean): StructType = {
+    val partitionNameSet =
+      partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
+    val fields = readSchema.fields.filterNot { field =>
+      partitionNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
+    }
+
+    StructType(fields)
+  }
 }
 
 // A compound class for combining file and its corresponding reader.
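
For context, a minimal sketch of the filtering `getReadDataSchema` performs (not part of this commit; column names are made up, and the filter is inlined here because the method is protected):

```scala
import org.apache.spark.sql.types._

// The read schema requested by a scan may include partition columns; they are
// dropped so the file reader only decodes columns physically stored in the file.
val readSchema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType),
  StructField("date", StringType)))  // "date" also comes from the directory path
val partitionSchema = StructType(Seq(StructField("date", StringType)))

// With isCaseSensitive = false, getColName lower-cases names before comparing.
val partitionNames = partitionSchema.fieldNames.map(_.toLowerCase).toSet
val readDataSchema = StructType(
  readSchema.fields.filterNot(f => partitionNames.contains(f.name.toLowerCase)))
// readDataSchema.fieldNames == Array("id", "name")
```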
PartitionReaderWithPartitionValues.scala (new file)
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.sources.v2.reader.PartitionReader
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A wrapper reader that always appends partition values to [[InternalRow]]s produced by the input
+ * reader [[fileReader]].
+ */
+class PartitionReaderWithPartitionValues(
+    fileReader: PartitionReader[InternalRow],
+    readDataSchema: StructType,
+    partitionSchema: StructType,
+    partitionValues: InternalRow) extends PartitionReader[InternalRow] {
+  private val fullSchema = readDataSchema.toAttributes ++ partitionSchema.toAttributes
+  private val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+  // Note that we have to apply the converter even if `file.partitionValues` is empty.
+  // This is because the converter is also responsible for converting safe `InternalRow`s
+  // into `UnsafeRow`s.
+  private val rowConverter = {
+    if (partitionSchema.isEmpty) {
+      () => unsafeProjection(fileReader.get())
+    } else {
+      val joinedRow = new JoinedRow()
+      () => unsafeProjection(joinedRow(fileReader.get(), partitionValues))
+    }
+  }
+
+  override def next(): Boolean = fileReader.next()
+
+  override def get(): InternalRow = rowConverter()
+
+  override def close(): Unit = fileReader.close()
+}
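
A minimal sketch of the `JoinedRow` plus `GenerateUnsafeProjection` pattern the class relies on (not part of this commit; schemas and values are made up):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

val dataSchema = StructType(Seq(StructField("id", IntegerType)))
val partSchema = StructType(Seq(StructField("date", StringType)))

// Generate a projection from (data columns ++ partition columns) to UnsafeRow.
val fullSchema = dataSchema.toAttributes ++ partSchema.toAttributes
val toUnsafe = GenerateUnsafeProjection.generate(fullSchema, fullSchema)

// JoinedRow presents the two rows as one row without copying them.
val dataRow = InternalRow(1)
val partValues = InternalRow(UTF8String.fromString("2019-03-07"))
val unsafeRow = toUnsafe(new JoinedRow(dataRow, partValues))
// unsafeRow now holds (id = 1, date = "2019-03-07") as a single UnsafeRow.
```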
PartitionRecordReader.scala
@@ -29,13 +29,3 @@ class PartitionRecordReader[T](

   override def close(): Unit = rowReader.close()
 }
-
-class PartitionRecordReaderWithProject[X, T](
-    private[this] var rowReader: RecordReader[_, X],
-    project: X => T) extends PartitionReader[T] {
-  override def next(): Boolean = rowReader.nextKeyValue()
-
-  override def get(): T = project(rowReader.getCurrentValue)
-
-  override def close(): Unit = rowReader.close()
-}
OrcPartitionReaderFactory.scala
@@ -90,22 +90,20 @@ case class OrcPartitionReaderFactory(
     val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
     val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
-    val requiredDataSchema = subtractSchema(readSchema, partitionSchema)
+    val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
     val orcRecordReader = new OrcInputFormat[OrcStruct]
       .createRecordReader(fileSplit, taskAttemptContext)
+    val deserializer = new OrcDeserializer(dataSchema, readDataSchema, requestedColIds)
+    val fileReader = new PartitionReader[InternalRow] {
+      override def next(): Boolean = orcRecordReader.nextKeyValue()
+
-    val fullSchema = requiredDataSchema.toAttributes ++ partitionSchema.toAttributes
-    val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
-    val deserializer = new OrcDeserializer(dataSchema, requiredDataSchema, requestedColIds)
+      override def get(): InternalRow = deserializer.deserialize(orcRecordReader.getCurrentValue)
+
-    val projection = if (partitionSchema.length == 0) {
-      (value: OrcStruct) => unsafeProjection(deserializer.deserialize(value))
-    } else {
-      val joinedRow = new JoinedRow()
-      (value: OrcStruct) =>
-        unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))
+      override def close(): Unit = orcRecordReader.close()
     }
-    new PartitionRecordReaderWithProject(orcRecordReader, projection)
+
+    new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
+      partitionSchema, file.partitionValues)
   }
}

@@ -153,17 +151,4 @@
     }
   }
 
-  /**
-   * Returns a new StructType that is a copy of the original StructType, removing any items that
-   * also appear in other StructType. The order is preserved from the original StructType.
-   */
-  private def subtractSchema(original: StructType, other: StructType): StructType = {
-    val otherNameSet = other.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
-    val fields = original.fields.filterNot { field =>
-      otherNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
-    }
-
-    StructType(fields)
-  }
-
 }
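
For reference, a generic sketch of the adapter idiom the ORC factory now uses inline (hypothetical helper, not part of this commit): a Hadoop `RecordReader` is exposed as a plain `PartitionReader`, and partition-value handling is layered on top by `PartitionReaderWithPartitionValues` rather than re-implemented per format.

```scala
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.PartitionReader

// Hypothetical helper: adapt any Hadoop RecordReader into a PartitionReader,
// applying a per-format deserializer in get(), as the ORC reader above does.
def adapt[V](rr: RecordReader[_, V], deserialize: V => InternalRow): PartitionReader[InternalRow] =
  new PartitionReader[InternalRow] {
    override def next(): Boolean = rr.nextKeyValue()
    override def get(): InternalRow = deserialize(rr.getCurrentValue)
    override def close(): Unit = rr.close()
  }
```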
