Skip to content

Commit

Permalink
support array map struct
Browse files · Browse the repository at this point in the history
  • Loading branch information
AngersZhuuuu committed Jul 21, 2020
1 parent 858f4e5 commit cfecc90
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
import java.nio.charset.StandardCharsets
import java.util.Map.Entry
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
Expand All @@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

trait BaseScriptTransformationExec extends UnaryExecNode {
Expand All @@ -47,7 +46,13 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
// Row-format / serde configuration used when piping rows to and from the
// user's transform script (delimiters such as TOK_TABLEROWFORMATFIELD, etc.).
def ioschema: ScriptTransformationIOSchema

protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
  // When no serde is configured, each input column is written to the script
  // as text. Primitive columns are cast to string with the session time zone
  // (so date/timestamp rendering is deterministic); complex columns
  // (array/map/struct) are passed through unchanged and serialized later by
  // the writer thread (see `buildString` / SparkInspectors wrappers).
  input.map { in: Expression =>
    in.dataType match {
      case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in
      // `withTimeZone` is declared on TimeZoneAwareExpression, hence the
      // cast back to Expression.
      case _ =>
        Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone).asInstanceOf[Expression]
    }
  }
}

// Attributes this node produces itself: every output attribute that is not
// simply forwarded from the child's input.
override def producedAttributes: AttributeSet = outputSet -- inputSet
Expand Down Expand Up @@ -177,55 +182,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}

private lazy val fieldWriters: Seq[String => Any] = output.map { attr =>
  // One converter per output column: it parses the script's textual output
  // for that column and produces the Catalyst value. Per-type handling
  // (primitives, date/time, and the newly supported array/map/struct types)
  // is delegated to SparkInspectors.
  // NOTE(review): SparkInspectors.unwrapper is defined elsewhere in this
  // change — confirm it preserves the old inline converters'
  // null-on-parse-error behavior (Hive LazySimpleSerde compatibility).
  SparkInspectors.unwrapper(attr.dataType, conf, ioschema)
}

// Mirror Hive's `LazySimpleSerde`: when a field fails to parse as the target
// type, produce null instead of failing the query.
private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
  (parse: String => Any, toCatalyst: Any => Any) =>
    (raw: String) => {
      // Parse first; any parse failure collapses to null. The result
      // (parsed value or null) is then converted to its Catalyst form.
      val parsed =
        try parse(raw)
        catch { case _: Exception => null }
      toCatalyst(parsed)
    }
}

abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
Expand All @@ -248,18 +206,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging

protected def processRows(): Unit

// One SparkInspectors wrapper per input column; processRowsWithoutSerde uses
// them to convert Catalyst row values into Java-side objects before writing.
val wrappers = inputSchema.map(SparkInspectors.wrapper)

protected def processRowsWithoutSerde(): Unit = {
val len = inputSchema.length
iter.foreach { row =>
val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map {
case (value, wrapper) => wrapper(value)
}
val data = if (len == 0) {
ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
} else {
val sb = new StringBuilder
sb.append(row.get(0, inputSchema(0)))
buildString(sb, values(0), inputSchema(0))
var i = 1
while (i < len) {
sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
sb.append(row.get(i, inputSchema(i)))
buildString(sb, values(i), inputSchema(i))
i += 1
}
sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
Expand All @@ -269,6 +232,38 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging
}
}

/**
 * Recursively serializes a wrapped (Java-side) value into `sb` using the
 * delimiters configured in `ioSchema`:
 *  - TOK_TABLEROWFORMATSTRUCTFIELD between struct fields,
 *  - TOK_TABLEROWFORMATCOLLITEMS between array elements / map entries,
 *  - TOK_TABLEROWFORMATMAPKEYS between a map key and its value.
 * Any other value is appended via `toString`.
 */
private def buildString(sb: StringBuilder, obj: Any, dataType: DataType): Unit = {
  (obj, dataType) match {
    // Structs arrive as java.util.ArrayList. This case must stay above the
    // generic java.util.List case below, since ArrayList is also a List.
    case (struct: java.util.ArrayList[_], StructType(fields)) =>
      var i = 0
      while (i < struct.size) {
        if (i > 0) sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATSTRUCTFIELD"))
        buildString(sb, struct.get(i), fields(i).dataType)
        i += 1
      }
    case (items: java.util.List[_], ArrayType(elementType, _)) =>
      var i = 0
      while (i < items.size) {
        if (i > 0) sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"))
        buildString(sb, items.get(i), elementType)
        i += 1
      }
    case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
      // Iterate the entry set directly; no Entry cast needed.
      val entries = map.entrySet().iterator()
      var first = true
      while (entries.hasNext) {
        if (!first) sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"))
        val entry = entries.next()
        buildString(sb, entry.getKey, keyType)
        sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"))
        buildString(sb, entry.getValue, valueType)
        first = false
      }
    case (other, _) =>
      sb.append(other.toString)
  }
}

override def run(): Unit = Utils.logUncaughtExceptions {
TaskContext.setTaskContext(taskContext)

Expand Down Expand Up @@ -328,7 +323,10 @@ case class ScriptTransformationIOSchema(
object ScriptTransformationIOSchema {
// Default row-format delimiters used when the TRANSFORM clause does not
// specify its own. The \u0001..\u0003 separators for nested types appear to
// follow Hive's LazySimpleSerDe nested-delimiter convention — TODO confirm.
val defaultFormat = Map(
  ("TOK_TABLEROWFORMATFIELD", "\t"),
  ("TOK_TABLEROWFORMATLINES", "\n"),
  ("TOK_TABLEROWFORMATSTRUCTFIELD", "\u0001"),
  ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
  ("TOK_TABLEROWFORMATMAPKEYS", "\u0003")
)

val defaultIOSchema = ScriptTransformationIOSchema(
Expand Down
Loading

0 comments on commit cfecc90

Please sign in to comment.