Skip to content

Commit

Permalink
Support encryption and compression in disk store (#9454)
Browse files Browse the repository at this point in the history
* Support encryption and compression in disk store

Signed-off-by: Ferdinand Xu <[email protected]>

* Address some comments

* Address comments

* Fix issue

* Address discussion

* Minor fix

---------

Signed-off-by: Ferdinand Xu <[email protected]>
  • Loading branch information
winningsix authored Nov 1, 2023
1 parent 56d1be1 commit 71505ba
Show file tree
Hide file tree
Showing 6 changed files with 388 additions and 46 deletions.
18 changes: 12 additions & 6 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1643,12 +1643,12 @@ object RapidsConf {
.createWithDefault(500 * 1024)

val SHUFFLE_COMPRESSION_CODEC = conf("spark.rapids.shuffle.compression.codec")
.doc("The GPU codec used to compress shuffle data when using RAPIDS shuffle. " +
"Supported codecs: lz4, copy, none")
.internal()
.startupOnly()
.stringConf
.createWithDefault("none")
.doc("The GPU codec used to compress shuffle data when using RAPIDS shuffle. " +
"Supported codecs: lz4, copy, none")
.internal()
.startupOnly()
.stringConf
.createWithDefault("none")

val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.lz4.chunkSize")
.doc("A configurable chunk size to use when compressing with LZ4.")
Expand Down Expand Up @@ -2049,6 +2049,12 @@ object RapidsConf {
.longConf
.createOptional

val TEST_IO_ENCRYPTION = conf("spark.rapids.test.io.encryption")
.doc("Only for tests: verify for IO encryption")
.internal()
.booleanConf
.createOptional

private def printSectionHeader(category: String): Unit =
println(s"\n### $category")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@

package com.nvidia.spark.rapids

import java.io.{File, FileInputStream, FileOutputStream}
import java.io.{File, FileInputStream}
import java.nio.channels.{Channels, FileChannel}
import java.nio.channels.FileChannel.MapMode
import java.nio.file.StandardOpenOption
import java.util.concurrent.ConcurrentHashMap

import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
import org.apache.commons.io.IOUtils

import org.apache.spark.sql.rapids.RapidsDiskBlockManager
import org.apache.spark.sql.rapids.execution.SerializedHostTableUtils
Expand All @@ -47,7 +50,7 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
id.getDiskPath(diskBlockManager)
}

val (fileOffset, diskLength) = if (id.canShareDiskPaths) {
val (fileOffset, uncompressedSize, diskLength) = if (id.canShareDiskPaths) {
// only one writer at a time for now when using shared files
path.synchronized {
writeToFile(incoming, path, append = true, stream)
Expand All @@ -62,6 +65,7 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
new RapidsDiskColumnarBatch(
id,
fileOffset,
uncompressedSize,
diskLength,
incoming.meta,
incoming.getSpillPriority)
Expand All @@ -70,32 +74,50 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
new RapidsDiskBuffer(
id,
fileOffset,
uncompressedSize,
diskLength,
incoming.meta,
incoming.getSpillPriority)
}
Some(buff)
}

/** Copy a host buffer to a file, returning the file offset at which the data was written. */
/**
* Copy a host buffer to a file. It leverages [[RapidsSerializerManager]] from
* [[RapidsDiskBlockManager]] to do compression or encryption if needed.
*
* @param incoming the rapid buffer to be written into a file
* @param path file path
* @param append whether to append or written into the beginning of the file
* @param stream cuda stream
* @return a tuple of file offset, memory byte size and written size on disk. File offset is where
* buffer starts in the targeted file path. Memory byte size is the size of byte buffer
* occupied in memory before writing to disk. Written size on disk is actual byte size
* written to disk.
*/
private def writeToFile(
incoming: RapidsBuffer,
path: File,
append: Boolean,
stream: Cuda.Stream): (Long, Long) = {
stream: Cuda.Stream): (Long, Long, Long) = {
incoming match {
case fileWritable: RapidsBufferChannelWritable =>
withResource(new FileOutputStream(path, append)) { fos =>
withResource(fos.getChannel) { outputChannel =>
val startOffset = outputChannel.position()
val writtenBytes = fileWritable.writeToChannel(outputChannel, stream)
if (writtenBytes == 0) {
throw new IllegalStateException(
s"Buffer ${fileWritable} wrote 0 bytes disk on spill. This is not supported!"
)
val option = if (append) {
Array(StandardOpenOption.CREATE, StandardOpenOption.APPEND)
} else {
Array(StandardOpenOption.CREATE, StandardOpenOption.WRITE)
}
var currentPos, writtenBytes = 0L
withResource(FileChannel.open(path.toPath, option: _*)) { fc =>
currentPos = fc.position()
withResource(Channels.newOutputStream(fc)) { os =>
withResource(diskBlockManager.getSerializerManager()
.wrapStream(incoming.id, os)) { cos =>
val outputChannel = Channels.newChannel(cos)
writtenBytes = fileWritable.writeToChannel(outputChannel, stream)
}
(startOffset, writtenBytes)
}
(currentPos, writtenBytes, path.length() - currentPos)
}
case other =>
throw new IllegalStateException(
Expand All @@ -110,25 +132,45 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
class RapidsDiskBuffer(
id: RapidsBufferId,
fileOffset: Long,
size: Long,
uncompressedSize: Long,
onDiskSizeInBytes: Long,
meta: TableMeta,
spillPriority: Long)
extends RapidsBufferBase(
id, meta, spillPriority) {
extends RapidsBufferBase(id, meta, spillPriority) {
private[this] var hostBuffer: Option[HostMemoryBuffer] = None

override val memoryUsedBytes: Long = size
// FIXME: Need to be clean up. Tracked in https://github.com/NVIDIA/spark-rapids/issues/9496
override val memoryUsedBytes: Long = uncompressedSize

override val storageTier: StorageTier = StorageTier.DISK

override def getMemoryBuffer: MemoryBuffer = synchronized {
if (hostBuffer.isEmpty) {
require(size > 0,
require(onDiskSizeInBytes > 0,
s"$this attempted an invalid 0-byte mmap of a file")
val path = id.getDiskPath(diskBlockManager)
val mappedBuffer = HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE,
fileOffset, size)
hostBuffer = Some(mappedBuffer)
val serializerManager = diskBlockManager.getSerializerManager()
val memBuffer = if (serializerManager.isRapidsSpill(id)) {
// Only go through serializerManager's stream wrapper for spill case
closeOnExcept(HostMemoryBuffer.allocate(uncompressedSize)) { decompressed =>
withResource(FileChannel.open(path.toPath, StandardOpenOption.READ)) { c =>
c.position(fileOffset)
withResource(Channels.newInputStream(c)) { compressed =>
withResource(serializerManager.wrapStream(id, compressed)) { in =>
withResource(new HostMemoryOutputStream(decompressed)) { out =>
IOUtils.copy(in, out)
}
decompressed
}
}
}
}
} else {
// Reserved mmap read fashion for UCX shuffle path. Also it's skipping encryption and
// compression.
HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, fileOffset, onDiskSizeInBytes)
}
hostBuffer = Some(memBuffer)
}
hostBuffer.foreach(_.incRefCount())
hostBuffer.get
Expand Down Expand Up @@ -170,11 +212,12 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
id: RapidsBufferId,
fileOffset: Long,
size: Long,
uncompressedSize: Long,
// TODO: remove meta
meta: TableMeta,
spillPriority: Long)
extends RapidsDiskBuffer(
id, fileOffset, size, meta, spillPriority)
id, fileOffset, size, uncompressedSize, meta, spillPriority)
with RapidsHostBatchBuffer {

override def getMemoryBuffer: MemoryBuffer =
Expand All @@ -191,11 +234,14 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
"paths on disk")
val path = id.getDiskPath(diskBlockManager)
withResource(new FileInputStream(path)) { fis =>
val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis)
val hostCols = withResource(hostBuffer) { _ =>
SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes)
withResource(diskBlockManager.getSerializerManager()
.wrapStream(id, fis)) { fs =>
val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fs)
val hostCols = withResource(hostBuffer) { _ =>
SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes)
}
new ColumnarBatch(hostCols.toArray, header.getNumRows)
}
new ColumnarBatch(hostCols.toArray, header.getNumRows)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.io.{InputStream, OutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.rapids.TempSpillBufferId
import org.apache.spark.sql.rapids.execution.TrampolineUtil


/**
* It's a wrapper of Spark's SerializerManager, which supports compression and encryption
* on data streams.
* For compression, it's turned on/off via seperated Rapids configurations and the underlying
* compression codec uses existing Spark's.
* For encryption, it's controlled by Spark's configuration to turn on/off.
* @param conf
*/
class RapidsSerializerManager (conf: SparkConf) {
private lazy val compressSpill = TrampolineUtil.isCompressSpill(conf)

private lazy val serializerManager = if (conf
.getBoolean(RapidsConf.TEST_IO_ENCRYPTION.key,false)) {
TrampolineUtil.createSerializerManager(conf)
} else {
TrampolineUtil.getSerializerManager
}

private lazy val compressionCodec: CompressionCodec = TrampolineUtil.createCodec(conf)

// Whether it really goes through crypto streams replies on Spark configuration
// (e.g., `` `spark.io.encryption.enabled` ``) and the existence of crypto keys.
def wrapStream(bufferId: RapidsBufferId, s: OutputStream): OutputStream = {
if(isRapidsSpill(bufferId)) wrapForCompression(bufferId, wrapForEncryption(s)) else s
}

def wrapStream(bufferId: RapidsBufferId, s: InputStream): InputStream = {
if(isRapidsSpill(bufferId)) wrapForCompression(bufferId, wrapForEncryption(s)) else s
}

private[this] def wrapForCompression(bufferId: RapidsBufferId, s: InputStream): InputStream = {
if (shouldCompress(bufferId)) compressionCodec.compressedInputStream(s) else s
}

private[this] def wrapForCompression(bufferId: RapidsBufferId, s: OutputStream): OutputStream = {
if (shouldCompress(bufferId)) compressionCodec.compressedOutputStream(s) else s
}

private[this] def wrapForEncryption(s: InputStream): InputStream = {
if (serializerManager != null) serializerManager.wrapForEncryption(s) else s
}

private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
if (serializerManager != null) serializerManager.wrapForEncryption(s) else s
}

def isRapidsSpill(bufferId: RapidsBufferId): Boolean = {
bufferId match {
case _: TempSpillBufferId => true
case _ => false
}
}

private[this] def shouldCompress(bufferId: RapidsBufferId): Boolean = {
bufferId match {
case _: TempSpillBufferId => compressSpill
case _: ShuffleBufferId | _: ShuffleReceivedBufferId => false
case _ => false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@ package org.apache.spark.sql.rapids

import java.io.File

import com.nvidia.spark.rapids.RapidsSerializerManager

import org.apache.spark.SparkConf
import org.apache.spark.rapids.shims.storage.ShimDiskBlockManager
import org.apache.spark.storage.BlockId

/** Maps logical blocks to local disk locations. */
class RapidsDiskBlockManager(conf: SparkConf) {
private[this] val blockManager = new ShimDiskBlockManager(conf, true)
private[this] val serializerManager = new RapidsSerializerManager(conf)

def getFile(blockId: BlockId): File = blockManager.getFile(blockId)

def getFile(file: String): File = blockManager.getFile(file)

def getSerializerManager(): RapidsSerializerManager = serializerManager
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkMasterRegex, Sp
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.config
import org.apache.spark.internal.config.EXECUTOR_ID
import org.apache.spark.io.CompressionCodec
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
Expand Down Expand Up @@ -183,8 +186,25 @@ object TrampolineUtil {
CompressionCodec.createCodec(conf, codecName)
}

def createCodec(conf: SparkConf): CompressionCodec = {
CompressionCodec.createCodec(conf)
}

def getCodecShortName(codecName: String): String = CompressionCodec.getShortName(codecName)

def getSerializerManager(): SerializerManager = {
if (SparkEnv.get != null) SparkEnv.get.serializerManager else null
}

// For test only
def createSerializerManager(conf: SparkConf): SerializerManager = {
new SerializerManager(new JavaSerializer(conf), conf, Some(CryptoStreamUtils.createKey(conf)))
}

def isCompressSpill(conf: SparkConf): Boolean = {
conf.get(config.SHUFFLE_SPILL_COMPRESS)
}

// If the master is a local mode (local or local-cluster), return the number
// of cores per executor it is going to use, otherwise return 1.
def getCoresInLocalMode(master: String, conf: SparkConf): Int = {
Expand Down
Loading

0 comments on commit 71505ba

Please sign in to comment.