Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-16922] [SPARK-17211] [SQL] make the address of values portable in LongToUnsafeRowMap #14927

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,20 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
private def nextSlot(pos: Int): Int = (pos + 2) & mask

/**
 * Packs a page offset and a row size into a single long "address".
 *
 * The offset is stored relative to `Platform.LONG_ARRAY_OFFSET` (upper bits,
 * shifted by SIZE_BITS) so the encoded value does not depend on the JVM's
 * absolute array base; the size occupies the low SIZE_BITS bits.
 */
private[this] def toAddress(offset: Long, size: Int): Long = {
  val relativeOffset = offset - Platform.LONG_ARRAY_OFFSET
  (relativeOffset << SIZE_BITS) | size
}

/**
 * Recovers the absolute page offset from a packed address produced by
 * `toAddress`: the upper bits hold the offset relative to
 * `Platform.LONG_ARRAY_OFFSET`, so the base is added back here.
 */
private[this] def toOffset(address: Long): Long = {
  val relativeOffset = address >>> SIZE_BITS
  Platform.LONG_ARRAY_OFFSET + relativeOffset
}

/** Recovers the row size in bytes from the low SIZE_BITS bits of a packed address. */
private[this] def toSize(address: Long): Int = {
  val sizeBits = SIZE_MASK & address
  sizeBits.toInt
}

/**
 * Points `resultRow` at the row stored at the given packed address and
 * returns it.
 *
 * Decoding is delegated to `toOffset`/`toSize` so it stays consistent with
 * the encoding in `toAddress`.
 *
 * Fix: the diff extraction had left both the pre-change body (which called
 * `pointTo` with the raw shifted offset, missing `LONG_ARRAY_OFFSET`) and
 * the post-change body in place; only the corrected call is kept.
 */
private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
  resultRow.pointTo(page, toOffset(address), toSize(address))
  resultRow
}

Expand Down Expand Up @@ -485,9 +495,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
var addr = address
override def hasNext: Boolean = addr != 0
/**
 * Returns the next row in this key's value chain.
 *
 * Fix: the diff extraction had kept both the old and new decodings, which
 * declared `offset`/`size` twice in the same scope (a compile error) and
 * called `pointTo` twice; only the corrected decoding via
 * `toOffset`/`toSize` is kept.
 */
override def next(): UnsafeRow = {
  // Decode the packed address into an absolute page offset and a row size.
  val offset = toOffset(addr)
  val size = toSize(addr)
  resultRow.pointTo(page, offset, size)
  // The 8 bytes immediately after the row hold the packed address of the
  // next row for this key; 0 terminates the chain (checked by hasNext).
  addr = Platform.getLong(page, offset + size)
  resultRow
}
Expand Down Expand Up @@ -554,14 +564,15 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
Platform.putLong(page, cursor, 0)
cursor += 8
numValues += 1
updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
updateIndex(key, toAddress(offset, row.getSizeInBytes))
}

/**
* Update the address in array for given key.
*/
private def updateIndex(key: Long, address: Long): Unit = {
var pos = firstSlot(key)
assert(numKeys < array.length / 2)
while (array(pos) != key && array(pos + 1) != 0) {
pos = nextSlot(pos)
}
Expand All @@ -582,7 +593,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
// there are some values for this key, put the address in the front of them.
val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
val pointer = toOffset(address) + toSize(address)
Platform.putLong(page, pointer, array(pos + 1))
array(pos + 1) = address
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import scala.util.Random

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.serializer.KryoSerializer
Expand Down Expand Up @@ -197,6 +199,60 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
}
}

test("LongToUnsafeRowMap with random keys") {
  // Memory manager backing both the original map and the deserialized copy.
  val taskMemoryManager = new TaskMemoryManager(
    new StaticMemoryManager(
      new SparkConf().set("spark.memory.offHeap.enabled", "false"),
      Long.MaxValue,
      Long.MaxValue,
      1),
    0)
  val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))

  val N = 1000000
  val rand = new Random
  // N + 1 random long keys; duplicates are possible and exercise per-key
  // value chains in the map.
  val keys = (0 to N).map(x => rand.nextLong()).toArray

  val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
  keys.foreach { k =>
    map.append(k, unsafeProj(InternalRow(k)))
  }
  map.optimize()

  // Round-trip through Java serialization to verify the packed addresses
  // remain valid after the page is rebuilt on deserialization.
  val os = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(os)
  map.writeExternal(out)
  out.flush()
  val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
  val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
  map2.readExternal(in)

  try {
    val row = unsafeProj(InternalRow(0L)).copy()
    // Every inserted key must be found, and every row in its chain must
    // carry the key's own value.
    keys.foreach { k =>
      val r = map2.get(k, row)
      assert(r.hasNext)
      var c = 0
      while (r.hasNext) {
        val rr = r.next()
        assert(rr.getLong(0) === k)
        c += 1
      }
      assert(c >= 1)
    }
    // Probe with fresh random keys: any hit must return only matching rows;
    // misses return null.
    var i = 0
    while (i < N * 10) {
      val k = rand.nextLong()
      val r = map2.get(k, row)
      if (r != null) {
        assert(r.hasNext)
        while (r.hasNext) {
          assert(r.next().getLong(0) === k)
        }
      }
      i += 1
    }
  } finally {
    // Fix: the original freed only `map`, leaking map2's pages from the
    // TaskMemoryManager (and leaked both on assertion failure). Release
    // both maps unconditionally.
    map.free()
    map2.free()
  }
}

test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
Expand Down