Add support for AtomicCreateTableAsSelect with Delta Lake [databricks] (#9425)

* Add support for AtomicCreateTableAsSelect with Delta Lake

Signed-off-by: Jason Lowe <[email protected]>

* Avoid showing a GpuColumnarToRow transition in a plan that does not actually execute

* Fix 3.5 build

---------

Signed-off-by: Jason Lowe <[email protected]>
jlowe authored Oct 17, 2023
1 parent de032b3 commit f2dbb23
Showing 58 changed files with 6,158 additions and 56 deletions.

Large diffs are not rendered by default; only a subset of the 58 changed files is shown below.
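For context, a minimal sketch of the kind of statement this commit moves onto the GPU. The session setup, config values, table name, and data are illustrative (not taken from the commit) and assume a spark-shell with the RAPIDS Accelerator and Delta Lake on the classpath:

// Enable the plugin and Delta write acceleration (config keys assumed from the plugin
// documentation; verify them for your release).
spark.conf.set("spark.rapids.sql.enabled", "true")
spark.conf.set("spark.rapids.sql.format.delta.write.enabled", "true")

// An atomic CREATE TABLE ... AS SELECT against a Delta catalog can now be planned as
// GpuAtomicCreateTableAsSelectExec instead of the CPU AtomicCreateTableAsSelectExec.
spark.sql(
  """CREATE TABLE demo_ctas
    |USING delta
    |AS SELECT id, id * 2 AS doubled FROM range(1000)
    |""".stripMargin)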

@@ -18,6 +18,7 @@ package com.databricks.sql.transaction.tahoe.rapids

import com.databricks.sql.transaction.tahoe.{DeltaLog, OptimisticTransaction}
import com.nvidia.spark.rapids.RapidsConf
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Clock
@@ -70,4 +71,12 @@ object GpuDeltaLog {
    val deltaLog = DeltaLog.forTable(spark, dataPath, options)
    new GpuDeltaLog(deltaLog, rapidsConf)
  }

  def forTable(
      spark: SparkSession,
      tableLocation: Path,
      rapidsConf: RapidsConf): GpuDeltaLog = {
    val deltaLog = DeltaLog.forTable(spark, tableLocation)
    new GpuDeltaLog(deltaLog, rapidsConf)
  }
}
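The new overload mirrors DeltaLog.forTable(SparkSession, Path) and is useful when a caller already holds the table location as a Hadoop Path (for example, a catalog creating a new table) rather than a path string plus options. A hedged call-site sketch; the location, the `spark` session, and the `rapidsConf` value are illustrative and assumed to be in scope:

import org.apache.hadoop.fs.Path

// Illustrative only: build a location and open a GPU-aware Delta log for it.
val tableLocation = new Path("/tmp/delta/demo_table")
val gpuDeltaLog = GpuDeltaLog.forTable(spark, tableLocation, rapidsConf)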
@@ -16,22 +16,31 @@

package com.nvidia.spark.rapids.delta

import scala.collection.JavaConverters.mapAsScalaMapConverter

import com.databricks.sql.managedcatalog.UnityCatalogV2Proxy
import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaParquetFileFormat}
import com.databricks.sql.transaction.tahoe.catalog.DeltaCatalog
import com.databricks.sql.transaction.tahoe.commands.{DeleteCommand, DeleteCommandEdge, MergeIntoCommand, MergeIntoCommandEdge, UpdateCommand, UpdateCommandEdge}
import com.databricks.sql.transaction.tahoe.sources.DeltaDataSource
import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog
import com.databricks.sql.transaction.tahoe.sources.{DeltaDataSource, DeltaSourceUtils}
import com.nvidia.spark.rapids._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.sources.CreatableRelationProvider

/**
* Implements the DeltaProvider interface for Databricks Delta Lake.
* Common implementation of the DeltaProvider interface for all Databricks versions.
*/
object DeltaProviderImpl extends DeltaProviderImplBase {
object DatabricksDeltaProvider extends DeltaProviderImplBase {
override def getCreatableRelationRules: Map[Class[_ <: CreatableRelationProvider],
CreatableRelationProviderRule[_ <: CreatableRelationProvider]] = {
Seq(
@@ -92,6 +101,43 @@ object DeltaProviderImpl extends DeltaProviderImplBase {
val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat]
GpuDeltaParquetFileFormat.convertToGpu(cpuFormat)
}

  override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = {
    catalogClass == classOf[DeltaCatalog] || catalogClass == classOf[UnityCatalogV2Proxy]
  }

  override def tagForGpu(
      cpuExec: AtomicCreateTableAsSelectExec,
      meta: AtomicCreateTableAsSelectExecMeta): Unit = {
    require(isSupportedCatalog(cpuExec.catalog.getClass))
    if (!meta.conf.isDeltaWriteEnabled) {
      meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
        s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
    }
    val properties = cpuExec.properties
    val provider = properties.getOrElse("provider",
      cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
    if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
      meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
    }
    RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
      cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
  }

  override def convertToGpu(
      cpuExec: AtomicCreateTableAsSelectExec,
      meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
    GpuAtomicCreateTableAsSelectExec(
      cpuExec.output,
      new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
      cpuExec.ident,
      cpuExec.partitioning,
      cpuExec.plan,
      meta.childPlans.head.convertIfNeeded(),
      cpuExec.tableSpec,
      cpuExec.writeOptions,
      cpuExec.ifNotExists)
  }
}

class DeltaCreatableRelationProviderMeta(
@@ -115,8 +161,8 @@ class DeltaCreatableRelationProviderMeta(
val path = saveCmd.options.get("path")
if (path.isDefined) {
val deltaLog = DeltaLog.forTable(SparkSession.active, path.get, saveCmd.options)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, deltaLog, saveCmd.options,
SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, Some(deltaLog),
saveCmd.options, SparkSession.active)
} else {
willNotWorkOnGpu("no path specified for Delta Lake table")
}
@@ -131,5 +177,5 @@ class DeltaCreatableRelationProviderMeta(
*/
class DeltaProbeImpl extends DeltaProbe {
// Delta Lake is built-in for Databricks instances, so no probing is necessary.
override def getDeltaProvider: DeltaProvider = DeltaProviderImpl
override def getDeltaProvider: DeltaProvider = DatabricksDeltaProvider
}
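When Delta write acceleration is disabled, the tagForGpu above calls willNotWorkOnGpu and the CTAS stays on the CPU AtomicCreateTableAsSelectExec. A hedged sketch of observing that fallback; the config key strings are assumed to be what RapidsConf.ENABLE_DELTA_WRITE and the explain setting resolve to, so verify them against the configs doc for your release:

// Keep Delta writes on the CPU while leaving other acceleration in place.
spark.conf.set("spark.rapids.sql.format.delta.write.enabled", "false")

// Ask the plugin to report why operators stay on the CPU; the reason string from
// tagForGpu above ("Delta Lake output acceleration has been disabled. To enable set
// ... to true") should appear in the explain output.
spark.conf.set("spark.rapids.sql.explain", "NOT_ON_GPU")
spark.sql("CREATE TABLE demo_cpu_ctas USING delta AS SELECT id FROM range(10)")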
@@ -37,7 +37,7 @@ class DeleteCommandMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
DeleteCommandMetaShim.tagForGpu(this)
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog,
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog),
Map.empty, SparkSession.active)
}

@@ -62,7 +62,7 @@ class DeleteCommandEdgeMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
DeleteCommandMetaShim.tagForGpu(this)
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog,
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog),
Map.empty, SparkSession.active)
}

@@ -38,7 +38,8 @@ class MergeIntoCommandMeta(
MergeIntoCommandMetaShim.tagForGpu(this, mergeCmd)
val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema)
val deltaLog = mergeCmd.targetFileIndex.deltaLog
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty,
SparkSession.active)
}

override def convertToGpu(): RunnableCommand =
@@ -60,7 +61,8 @@ class MergeIntoCommandEdgeMeta(
MergeIntoCommandMetaShim.tagForGpu(this, mergeCmd)
val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema)
val deltaLog = mergeCmd.targetFileIndex.deltaLog
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty,
SparkSession.active)
}

override def convertToGpu(): RunnableCommand =
@@ -29,12 +29,13 @@ object RapidsDeltaUtils {
def tagForDeltaWrite(
meta: RapidsMeta[_, _, _],
schema: StructType,
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
options: Map[String, String],
spark: SparkSession): Unit = {
FileFormatChecks.tag(meta, schema, DeltaFormatType, WriteFileOp)
val format = DeltaLogShim.fileFormat(deltaLog)
if (format.getClass == classOf[DeltaParquetFileFormat]) {
val format = deltaLog.map(log => DeltaLogShim.fileFormat(log).getClass)
.getOrElse(classOf[DeltaParquetFileFormat])
if (format == classOf[DeltaParquetFileFormat]) {
GpuParquetFileFormat.tagGpuSupport(meta, spark, options, schema)
} else {
meta.willNotWorkOnGpu(s"file format $format is not supported")
@@ -45,7 +46,7 @@
private def checkIncompatibleConfs(
meta: RapidsMeta[_, _, _],
schema: StructType,
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
sqlConf: SQLConf,
options: Map[String, String]): Unit = {
def getSQLConf(key: String): Option[String] = {
@@ -65,19 +66,21 @@
orderableTypeSig.isSupportedByPlugin(t)
}
if (unorderableTypes.nonEmpty) {
val metadata = DeltaLogShim.getMetadata(deltaLog)
val hasPartitioning = metadata.partitionColumns.nonEmpty ||
val metadata = deltaLog.map(log => DeltaLogShim.getMetadata(log))
val hasPartitioning = metadata.exists(_.partitionColumns.nonEmpty) ||
options.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).exists(_.nonEmpty)
if (!hasPartitioning) {
val optimizeWriteEnabled = {
val deltaOptions = new DeltaOptions(options, sqlConf)
deltaOptions.optimizeWrite.orElse {
getSQLConf("spark.databricks.delta.optimizeWrite.enabled").map(_.toBoolean).orElse {
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse {
metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf(
"spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
metadata.flatMap { m =>
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(m).orElse {
m.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf(
"spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
}
}
}
}.getOrElse(false)
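The deltaLog parameter becomes Option[DeltaLog] because the new CTAS path tags a table that does not exist yet and passes None, so there is no log or metadata to consult; metadata-derived settings then fall back to defaults. A standalone sketch of that fallback pattern, with an illustrative helper that is not part of the patch:

// Resolve a boolean table property when metadata is available, otherwise use a default
// suitable for a table that is about to be created.
def optimizeWriteFromMetadata(
    configuration: Option[Map[String, String]],
    default: Boolean = false): Boolean =
  configuration
    .flatMap(_.get("delta.autoOptimize.optimizeWrite"))
    .map(_.toBoolean)
    .getOrElse(default)

// New table (no metadata): falls back to the default.
assert(!optimizeWriteFromMetadata(None))
// Existing table with the property set: the metadata value wins.
assert(optimizeWriteFromMetadata(Some(Map("delta.autoOptimize.optimizeWrite" -> "true"))))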
@@ -35,7 +35,7 @@ class UpdateCommandMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema,
updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark)
Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark)
}

override def convertToGpu(): RunnableCommand = {
@@ -62,7 +62,7 @@ class UpdateCommandEdgeMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema,
updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark)
Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark)
}

override def convertToGpu(): RunnableCommand = {
@@ -16,15 +16,20 @@

package com.nvidia.spark.rapids.delta

import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.util.Try

import com.nvidia.spark.rapids._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.delta.{DeltaLog, DeltaParquetFileFormat}
import org.apache.spark.sql.delta.catalog.DeltaCatalog
import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.delta.sources.DeltaDataSource
import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSourceUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil
import org.apache.spark.sql.sources.CreatableRelationProvider
@@ -48,6 +53,28 @@ abstract class DeltaIOProvider extends DeltaProviderImplBase {
override def isSupportedFormat(format: Class[_ <: FileFormat]): Boolean = {
format == classOf[DeltaParquetFileFormat]
}

  override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = {
    catalogClass == classOf[DeltaCatalog]
  }

  override def tagForGpu(
      cpuExec: AtomicCreateTableAsSelectExec,
      meta: AtomicCreateTableAsSelectExecMeta): Unit = {
    require(isSupportedCatalog(cpuExec.catalog.getClass))
    if (!meta.conf.isDeltaWriteEnabled) {
      meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
        s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
    }
    val properties = cpuExec.properties
    val provider = properties.getOrElse("provider",
      cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
    if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
      meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
    }
    RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
      cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
  }
}
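The "provider" table property checked in tagForGpu comes from the USING clause of the CTAS; when it is absent, the check falls back to spark.sql.sources.default (SQLConf.DEFAULT_DATA_SOURCE_NAME). A hedged SQL-level illustration; the table names are made up, and whether the second statement actually routes through AtomicCreateTableAsSelectExec depends on the active catalog and session settings:

// Explicit provider: tagged as a Delta write.
spark.sql("CREATE TABLE demo_explicit USING delta AS SELECT 1 AS id")

// No USING clause: the provider falls back to spark.sql.sources.default, so with that set
// to "delta" this CTAS may also be tagged as a Delta write.
spark.conf.set("spark.sql.sources.default", "delta")
spark.sql("CREATE TABLE demo_default AS SELECT 1 AS id")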

class DeltaCreatableRelationProviderMeta(
@@ -71,8 +98,8 @@ class DeltaCreatableRelationProviderMeta(
val path = saveCmd.options.get("path")
if (path.isDefined) {
val deltaLog = DeltaLog.forTable(SparkSession.active, path.get, saveCmd.options)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, deltaLog, saveCmd.options,
SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, Some(deltaLog),
saveCmd.options, SparkSession.active)
} else {
willNotWorkOnGpu("no path specified for Delta Lake table")
}
@@ -28,12 +28,13 @@ object RapidsDeltaUtils {
def tagForDeltaWrite(
meta: RapidsMeta[_, _, _],
schema: StructType,
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
options: Map[String, String],
spark: SparkSession): Unit = {
FileFormatChecks.tag(meta, schema, DeltaFormatType, WriteFileOp)
val format = DeltaRuntimeShim.fileFormatFromLog(deltaLog)
if (format.getClass == classOf[DeltaParquetFileFormat]) {
val format = deltaLog.map(log => DeltaRuntimeShim.fileFormatFromLog(log).getClass)
.getOrElse(classOf[DeltaParquetFileFormat])
if (format == classOf[DeltaParquetFileFormat]) {
GpuParquetFileFormat.tagGpuSupport(meta, spark, options, schema)
} else {
meta.willNotWorkOnGpu(s"file format $format is not supported")
@@ -43,7 +44,7 @@

private def checkIncompatibleConfs(
meta: RapidsMeta[_, _, _],
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
sqlConf: SQLConf,
options: Map[String, String]): Unit = {
def getSQLConf(key: String): Option[String] = {
@@ -58,11 +59,13 @@
val deltaOptions = new DeltaOptions(options, sqlConf)
deltaOptions.optimizeWrite.orElse {
getSQLConf("spark.databricks.delta.optimizeWrite.enabled").map(_.toBoolean).orElse {
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(deltaLog).metadata
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse {
metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
deltaLog.flatMap { log =>
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(log).metadata
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse {
metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
}
}
}
}.getOrElse(false)
@@ -73,9 +76,11 @@

val autoCompactEnabled =
getSQLConf("spark.databricks.delta.autoCompact.enabled").orElse {
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(deltaLog).metadata
metadata.configuration.get("delta.autoOptimize.autoCompact").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact")
deltaLog.flatMap { log =>
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(log).metadata
metadata.configuration.get("delta.autoOptimize.autoCompact").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact")
}
}
}.exists(_.toBoolean)
if (autoCompactEnabled) {
@@ -22,7 +22,9 @@ import com.nvidia.spark.rapids.{RapidsConf, ShimReflectionUtils, VersionUtils}
import com.nvidia.spark.rapids.delta.DeltaProvider

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot}
import org.apache.spark.sql.delta.catalog.DeltaCatalog
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.util.Clock
@@ -35,6 +37,8 @@ trait DeltaRuntimeShim {
def fileFormatFromLog(deltaLog: DeltaLog): FileFormat

def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean

def getGpuDeltaCatalog(cpuCatalog: DeltaCatalog, rapidsConf: RapidsConf): StagingTableCatalog
}

object DeltaRuntimeShim {
@@ -81,4 +85,8 @@ object DeltaRuntimeShim {

def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean =
shimInstance.getTightBoundColumnOnFileInitDisabled(spark)

def getGpuDeltaCatalog(cpuCatalog: DeltaCatalog, rapidsConf: RapidsConf): StagingTableCatalog = {
shimInstance.getGpuDeltaCatalog(cpuCatalog, rapidsConf)
}
}