Skip to content

Commit

Permalink
[SPARK-10065] [SQL] avoid the extra copy when generate unsafe array
Browse files Browse the repository at this point in the history
The reason for this extra copy is that we iterate the array twice: calculate elements data size and copy elements to array buffer.

A simple solution is to follow `createCodeForStruct`, we can dynamically grow the buffer when needed and thus don't need to know the data size ahead.

This PR also include some typo and style fixes, and did some minor refactor to make sure `input.primitive` is always variable name not code when generate unsafe code.

Author: Wenchen Fan <[email protected]>

Closes #8496 from cloud-fan/avoid-copy.
  • Loading branch information
cloud-fan authored and davies committed Sep 10, 2015
1 parent 48817cc commit 4f1daa1
Showing 1 changed file with 24 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val tmpBuffer = ctx.freshName("tmpBuffer")
val outputIsNull = ctx.freshName("isNull")
val numElements = ctx.freshName("numElements")
val fixedSize = ctx.freshName("fixedSize")
val numBytes = ctx.freshName("numBytes")
val elements = ctx.freshName("elements")
val cursor = ctx.freshName("cursor")
val index = ctx.freshName("index")
val elementName = ctx.freshName("elementName")
Expand All @@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

val convertedElement = createConvertCode(ctx, element, elementType)

// go through the input array to calculate how many bytes we need.
val calculateNumBytes = elementType match {
case _ if ctx.isPrimitiveType(elementType) =>
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
$numBytes += $elementSize * $numElements;
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
$numBytes += 8 * $numElements;
"""
case _ =>
val writer = getWriter(elementType)
val elementSize = s"$writer.getSize($elements[$index])"
// TODO(davies): avoid the copy
val unsafeType = elementType match {
case _: StructType => "UnsafeRow"
case _: ArrayType => "UnsafeArrayData"
case _: MapType => "UnsafeMapData"
case _ => ctx.javaType(elementType)
}
val copy = elementType match {
// We reuse the buffer during conversion, need copy it before process next element.
case _: StructType | _: ArrayType | _: MapType => ".copy()"
case _ => ""
}

val newElements = if (elementType == BinaryType) {
s"new byte[$numElements][]"
} else {
s"new $unsafeType[$numElements]"
}
s"""
final $unsafeType[] $elements = $newElements;
for (int $index = 0; $index < $numElements; $index++) {
${convertedElement.code}
if (!${convertedElement.isNull}) {
$elements[$index] = ${convertedElement.primitive}$copy;
$numBytes += $elementSize;
}
}
"""
}

val writeElement = elementType match {
case _ if ctx.isPrimitiveType(elementType) =>
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
${convertedElement.code}
Platform.put${ctx.primitiveTypeName(elementType)}(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
Expand All @@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
${convertedElement.code}
Platform.putLong(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
Expand All @@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$cursor += $writer.write(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
$elements[$index]);
${convertedElement.primitive});
"""
}

val checkNull = elementType match {
case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}"
case t: DecimalType => s"$elements[$index] == null" +
s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})"
case _ => s"$elements[$index] == null"
val checkNull = convertedElement.isNull + (elementType match {
case t: DecimalType =>
s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})"
case _ => ""
})

val elementSize = elementType match {
// Should we do word align for primitive types?
case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8"
case _ =>
val writer = getWriter(elementType)
s"$writer.getSize(${convertedElement.primitive})"
}

val code = s"""
Expand All @@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $fixedSize = 4 * $numElements;
int $numBytes = $fixedSize;

$calculateNumBytes

if ($numBytes > $buffer.length) {
$buffer = new byte[$numBytes];
}

int $cursor = $fixedSize;
for (int $index = 0; $index < $numElements; $index++) {
${convertedElement.code}
if ($checkNull) {
// If element is null, write the negative value address into offset region.
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor);
} else {
$numBytes += $elementSize;
if ($buffer.length < $numBytes) {
// This will not happen frequently, because the buffer is re-used.
byte[] $tmpBuffer = new byte[$numBytes * 2];
Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
$tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
$buffer = $tmpBuffer;
}
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor);
$writeElement
}
Expand Down

0 comments on commit 4f1daa1

Please sign in to comment.