diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
index 101a70ee92ce5..1daf8ae72b639 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
@@ -17,8 +17,9 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, PartitioningUtils}
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
@@ -45,6 +46,19 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
   def buildColumnarReader(partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
     throw new UnsupportedOperationException("Cannot create columnar reader.")
   }
+
+  protected def getReadDataSchema(
+      readSchema: StructType,
+      partitionSchema: StructType,
+      isCaseSensitive: Boolean): StructType = {
+    val partitionNameSet =
+      partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
+    val fields = readSchema.fields.filterNot { field =>
+      partitionNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
+    }
+
+    StructType(fields)
+  }
 }
 
 // A compound class for combining file and its corresponding reader.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala
new file mode 100644
index 0000000000000..072465b56857d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.sources.v2.reader.PartitionReader
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A wrapper reader that always appends partition values to [[InternalRow]]s produced by the input
+ * reader [[fileReader]].
+ */
+class PartitionReaderWithPartitionValues(
+    fileReader: PartitionReader[InternalRow],
+    readDataSchema: StructType,
+    partitionSchema: StructType,
+    partitionValues: InternalRow) extends PartitionReader[InternalRow] {
+  private val fullSchema = readDataSchema.toAttributes ++ partitionSchema.toAttributes
+  private val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+  // Note that we have to apply the converter even though `file.partitionValues` is empty.
+  // This is because the converter is also responsible for converting safe `InternalRow`s into
+  // `UnsafeRow`s
+  private val rowConverter = {
+    if (partitionSchema.isEmpty) {
+      () => unsafeProjection(fileReader.get())
+    } else {
+      val joinedRow = new JoinedRow()
+      () => unsafeProjection(joinedRow(fileReader.get(), partitionValues))
+    }
+  }
+
+  override def next(): Boolean = fileReader.next()
+
+  override def get(): InternalRow = rowConverter()
+
+  override def close(): Unit = fileReader.close()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala
index ff78ef3220c17..baa8cb6b24659 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionRecordReader.scala
@@ -29,13 +29,3 @@ class PartitionRecordReader[T](
 
   override def close(): Unit = rowReader.close()
 }
-
-class PartitionRecordReaderWithProject[X, T](
-    private[this] var rowReader: RecordReader[_, X],
-    project: X => T) extends PartitionReader[T] {
-  override def next(): Boolean = rowReader.nextKeyValue()
-
-  override def get(): T = project(rowReader.getCurrentValue)
-
-  override def close(): Unit = rowReader.close()
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index f6fc0ca839081..4ae10a656e5e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -90,22 +90,20 @@ case class OrcPartitionReaderFactory(
     val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
     val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
-    val requiredDataSchema = subtractSchema(readSchema, partitionSchema)
+    val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
     val orcRecordReader = new OrcInputFormat[OrcStruct]
       .createRecordReader(fileSplit, taskAttemptContext)
+    val deserializer = new OrcDeserializer(dataSchema, readDataSchema, requestedColIds)
+    val fileReader = new PartitionReader[InternalRow] {
+      override def next(): Boolean = orcRecordReader.nextKeyValue()
 
-    val fullSchema = requiredDataSchema.toAttributes ++ partitionSchema.toAttributes
-    val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
-    val deserializer = new OrcDeserializer(dataSchema, requiredDataSchema, requestedColIds)
+      override def get(): InternalRow = deserializer.deserialize(orcRecordReader.getCurrentValue)
 
-    val projection = if (partitionSchema.length == 0) {
-      (value: OrcStruct) => unsafeProjection(deserializer.deserialize(value))
-    } else {
-      val joinedRow = new JoinedRow()
-      (value: OrcStruct) =>
-        unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))
+      override def close(): Unit = orcRecordReader.close()
     }
-    new PartitionRecordReaderWithProject(orcRecordReader, projection)
+
+    new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
+      partitionSchema, file.partitionValues)
   }
 }
 
@@ -153,17 +151,4 @@ case class OrcPartitionReaderFactory(
     }
   }
 
-  /**
-   * Returns a new StructType that is a copy of the original StructType, removing any items that
-   * also appear in other StructType. The order is preserved from the original StructType.
-   */
-  private def subtractSchema(original: StructType, other: StructType): StructType = {
-    val otherNameSet = other.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
-    val fields = original.fields.filterNot { field =>
-      otherNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
-    }
-
-    StructType(fields)
-  }
-
 }
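
For reviewers, a minimal sketch (not part of the patch) of how another file format's factory would be expected to compose the two new helpers: getReadDataSchema from FilePartitionReaderFactory and the PartitionReaderWithPartitionValues wrapper. ExamplePartitionReaderFactory, its constructor parameters, and the placeholder reader body are hypothetical; the buildReader(PartitionedFile) hook is assumed to be the same one the ORC factory above overrides.

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.sources.v2.reader.PartitionReader
import org.apache.spark.sql.types.StructType

// Hypothetical factory, for illustration only.
case class ExamplePartitionReaderFactory(
    readSchema: StructType,
    partitionSchema: StructType,
    isCaseSensitive: Boolean) extends FilePartitionReaderFactory {

  override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
    // Partition columns are not stored in the data file, so drop them from the read schema
    // before building the per-file reader.
    val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)

    // Placeholder reader over the data columns only; a real implementation would deserialize
    // rows from `file` according to `readDataSchema` (as the ORC factory does above).
    val fileReader = new PartitionReader[InternalRow] {
      override def next(): Boolean = false
      override def get(): InternalRow =
        throw new UnsupportedOperationException("placeholder reader yields no rows")
      override def close(): Unit = ()
    }

    // The wrapper appends `file.partitionValues` to every produced row and emits UnsafeRows.
    new PartitionReaderWithPartitionValues(
      fileReader, readDataSchema, partitionSchema, file.partitionValues)
  }
}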