[Spark] Pass table identifiers/table catalogs to DeltaLog API call sites #3788

Closed
@@ -41,6 +41,7 @@ import org.apache.hudi.storage.hadoop.HadoopStorageConfiguration

import org.apache.spark.internal.MDC
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTable

object HudiConverter {
@@ -177,7 +178,7 @@ class HudiConverter(spark: SparkSession)
if (!UniversalFormat.hudiEnabled(snapshotToConvert.metadata)) {
return None
}
convertSnapshot(snapshotToConvert, None, Option.apply(catalogTable.identifier.table))
convertSnapshot(snapshotToConvert, None, Some(catalogTable))
}

/**
@@ -193,7 +194,7 @@ class HudiConverter(spark: SparkSession)
if (!UniversalFormat.hudiEnabled(snapshotToConvert.metadata)) {
return None
}
convertSnapshot(snapshotToConvert, Some(txn), txn.catalogTable.map(_.identifier.table))
convertSnapshot(snapshotToConvert, Some(txn), txn.catalogTable)
}

/**
@@ -208,11 +209,13 @@ class HudiConverter(spark: SparkSession)
private def convertSnapshot(
snapshotToConvert: Snapshot,
txnOpt: Option[OptimisticTransactionImpl],
tableName: Option[String]): Option[(Long, Long)] =
catalogTable: Option[CatalogTable]): Option[(Long, Long)] =
recordFrameProfile("Delta", "HudiConverter.convertSnapshot") {
val log = snapshotToConvert.deltaLog
val metaClient = loadTableMetaClient(snapshotToConvert.deltaLog.dataPath.toString,
tableName, snapshotToConvert.metadata.partitionColumns,
val metaClient = loadTableMetaClient(
snapshotToConvert.deltaLog.dataPath.toString,
catalogTable.flatMap(ct => Option(ct.identifier.table)),
snapshotToConvert.metadata.partitionColumns,
new HadoopStorageConfiguration(log.newDeltaHadoopConf()))
val lastDeltaVersionConverted: Option[Long] = loadLastDeltaVersionConverted(metaClient)
val maxCommitsToConvert =
@@ -233,7 +236,7 @@ class HudiConverter(spark: SparkSession)
try {
// TODO: We can optimize this by providing a checkpointHint to getSnapshotAt. Check if
// txn.snapshot.version < version. If true, use txn.snapshot's checkpoint as a hint.
Some(log.getSnapshotAt(version))
Some(log.getSnapshotAt(version, catalogTableOpt = catalogTable))
} catch {
// If we can't load the file since the last time Hudi was converted, it's likely that
// the commit file expired. Treat this like a new Hudi table conversion.
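For orientation, a rough caller-side sketch of the catalog-aware Hudi conversion path (the SparkSession, database, and table names below are assumptions for illustration, not part of this diff): the caller resolves the CatalogTable once and hands it to convertSnapshot, which now threads it through to loadTableMetaClient and getSnapshotAt.

// Sketch only: resolve the catalog entry once, then pass the whole CatalogTable through.
import org.apache.spark.sql.catalyst.TableIdentifier

val tableId = TableIdentifier("events", Some("default"))
val catalogTable = spark.sessionState.catalog.getTableMetadata(tableId)
val (log, snapshot) = DeltaLog.forTableWithSnapshot(spark, catalogTable, Map.empty[String, String])
new HudiConverter(spark).convertSnapshot(snapshot, catalogTable) // returns Option[(Long, Long)]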
@@ -267,7 +267,7 @@ class IcebergConverter(spark: SparkSession)
try {
// TODO: We can optimize this by providing a checkpointHint to getSnapshotAt. Check if
// txn.snapshot.version < version. If true, use txn.snapshot's checkpoint as a hint.
Some(log.getSnapshotAt(version))
Some(log.getSnapshotAt(version, catalogTableOpt = Some(catalogTable)))
} catch {
// If we can't load the file since the last time Iceberg was converted, it's likely that
// the commit file expired. Treat this like a new Iceberg table conversion.
@@ -1145,7 +1145,10 @@ class DeltaAnalysis(session: SparkSession)
session, dataSourceV1.options
).foreach { rootSchemaTrackingLocation =>
assert(dataSourceV1.options.contains("path"), "Path for Delta table must be defined")
val log = DeltaLog.forTable(session, new Path(dataSourceV1.options("path")))
val log = dataSourceV1.catalogTable match {
case Some(catalogTable) => DeltaLog.forTable(session, catalogTable)
case None => DeltaLog.forTable(session, new Path(dataSourceV1.options("path")))
}
val sourceIdOpt = dataSourceV1.options.get(DeltaOptions.STREAMING_SOURCE_TRACKING_ID)
val schemaTrackingLocation =
DeltaSourceMetadataTrackingLog.fullMetadataTrackingLocation(
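The DeltaAnalysis change above is one instance of the pattern this PR applies across call sites: build the DeltaLog from the CatalogTable whenever one is available, and only fall back to the raw path. A condensed sketch of that pattern (the helper name is illustrative, not something this PR adds):

// Hypothetical helper showing the recurring resolution pattern.
def resolveDeltaLog(
    session: SparkSession,
    catalogTableOpt: Option[CatalogTable],
    path: Path): DeltaLog =
  catalogTableOpt
    .map(catalogTable => DeltaLog.forTable(session, catalogTable))
    .getOrElse(DeltaLog.forTable(session, path))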
57 changes: 48 additions & 9 deletions spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala
@@ -794,6 +794,16 @@ object DeltaLog extends DeltaLogging {
}
}

/** Helper for creating a log for the table. */
def forTable(spark: SparkSession, table: CatalogTable, options: Map[String, String]): DeltaLog = {
apply(
spark,
logPathFor(new Path(table.location)),
options,
Some(table.identifier),
new SystemClock)
}

/** Helper for creating a log for the table. */
def forTable(spark: SparkSession, table: CatalogTable, clock: Clock): DeltaLog = {
apply(spark, logPathFor(new Path(table.location)), Some(table.identifier), clock)
@@ -809,37 +819,66 @@ object DeltaLog extends DeltaLogging {

/** Helper for getting a log, as well as the latest snapshot, of the table */
def forTableWithSnapshot(spark: SparkSession, dataPath: String): (DeltaLog, Snapshot) =
withFreshSnapshot { forTable(spark, new Path(dataPath), _) }
withFreshSnapshot { clock =>
(forTable(spark, new Path(dataPath), clock), None)
}

/** Helper for getting a log, as well as the latest snapshot, of the table */
def forTableWithSnapshot(spark: SparkSession, dataPath: Path): (DeltaLog, Snapshot) =
withFreshSnapshot { forTable(spark, dataPath, _) }
withFreshSnapshot { clock =>
(forTable(spark, dataPath, clock), None)
}

/** Helper for getting a log, as well as the latest snapshot, of the table */
def forTableWithSnapshot(
spark: SparkSession,
tableName: TableIdentifier): (DeltaLog, Snapshot) =
withFreshSnapshot { forTable(spark, tableName, _) }
tableName: TableIdentifier): (DeltaLog, Snapshot) = {
withFreshSnapshot { clock =>
if (DeltaTableIdentifier.isDeltaPath(spark, tableName)) {
(forTable(spark, new Path(tableName.table)), None)
} else {
val catalogTable = spark.sessionState.catalog.getTableMetadata(tableName)
(forTable(spark, catalogTable, clock), Some(catalogTable))
}
}
}

/** Helper for getting a log, as well as the latest snapshot, of the table */
def forTableWithSnapshot(
spark: SparkSession,
dataPath: Path,
options: Map[String, String]): (DeltaLog, Snapshot) =
withFreshSnapshot {
apply(spark, logPathFor(dataPath), options, initialTableIdentifier = None, _)
withFreshSnapshot { clock =>
(
apply(spark, logPathFor(dataPath), options, initialTableIdentifier = None, clock),
None
)
}

/** Helper for getting a log, as well as the latest snapshot, of the table */
def forTableWithSnapshot(
spark: SparkSession,
table: CatalogTable,
options: Map[String, String]): (DeltaLog, Snapshot) =
withFreshSnapshot { clock =>
(
apply(spark, logPathFor(new Path(table.location)), options, Some(table.identifier), clock),
Some(table)
)
}

/**
* Helper function to be used with the forTableWithSnapshot calls. Thunk is a
* partially applied DeltaLog.forTable call, which we can then wrap around with a
* snapshot update. We use the system clock to avoid back-to-back updates.
*/
private[delta] def withFreshSnapshot(thunk: Clock => DeltaLog): (DeltaLog, Snapshot) = {
private[delta] def withFreshSnapshot(
thunk: Clock => (DeltaLog, Option[CatalogTable])): (DeltaLog, Snapshot) = {
val clock = new SystemClock
val ts = clock.getTimeMillis()
val deltaLog = thunk(clock)
val snapshot = deltaLog.update(checkIfUpdatedSinceTs = Some(ts))
val (deltaLog, catalogTableOpt) = thunk(clock)
val snapshot =
deltaLog.update(checkIfUpdatedSinceTs = Some(ts), catalogTableOpt = catalogTableOpt)
(deltaLog, snapshot)
}

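The withFreshSnapshot thunk now returns the catalog table alongside the log so that the follow-up update() call can carry it. As a sketch of the new thunk shape, a catalog-table overload would plug in like this (this particular overload is hypothetical and shown only for illustration):

// Illustrative only: mirrors the path-based overloads above for a resolved CatalogTable.
def forTableWithSnapshot(spark: SparkSession, table: CatalogTable): (DeltaLog, Snapshot) =
  withFreshSnapshot { clock =>
    (forTable(spark, table, clock), Some(table))
  }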
@@ -43,6 +43,7 @@ import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.internal.MDC
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.util.{ThreadUtils, Utils}

/**
@@ -1016,7 +1017,7 @@ trait SnapshotManagement { self: DeltaLog =>
def update(
stalenessAcceptable: Boolean = false,
checkIfUpdatedSinceTs: Option[Long] = None,
tableIdentifierOpt: Option[TableIdentifier] = None): Snapshot = {
catalogTableOpt: Option[CatalogTable] = None): Snapshot = {
val startTimeMs = System.currentTimeMillis()
// currentSnapshot is volatile. Make a local copy of it at the start of the update call, so
// that there's no chance of a race condition changing the snapshot partway through the update.
@@ -1049,7 +1050,7 @@ trait SnapshotManagement { self: DeltaLog =>
withSnapshotLockInterruptibly {
val newSnapshot = updateInternal(
isAsync = false,
tableIdentifierOpt)
catalogTableOpt.map(_.identifier))
sendEvent(newSnapshot = capturedSnapshot.snapshot)
newSnapshot
}
@@ -1067,7 +1068,7 @@ trait SnapshotManagement { self: DeltaLog =>
interruptOnCancel = true)
tryUpdate(
isAsync = true,
tableIdentifierOpt)
catalogTableOpt.map(_.identifier))
}
} catch {
case NonFatal(e) if !Utils.isTesting =>
@@ -1313,12 +1314,12 @@ trait SnapshotManagement { self: DeltaLog =>
def getSnapshotAt(
version: Long,
lastCheckpointHint: Option[CheckpointInstance] = None,
tableIdentifierOpt: Option[TableIdentifier] = None): Snapshot = {
catalogTableOpt: Option[CatalogTable] = None): Snapshot = {
getSnapshotAt(
version,
lastCheckpointHint,
lastCheckpointProvider = None,
tableIdentifierOpt)
catalogTableOpt)
}

/**
@@ -1329,7 +1330,7 @@ trait SnapshotManagement { self: DeltaLog =>
version: Long,
lastCheckpointHint: Option[CheckpointInstance],
lastCheckpointProvider: Option[CheckpointProvider],
tableIdentifierOpt: Option[TableIdentifier]): Snapshot = {
catalogTableOpt: Option[CatalogTable]): Snapshot = {

// See if the version currently cached on the cluster satisfies the requirement
val currentSnapshot = unsafeVolatileSnapshot
@@ -1338,7 +1339,7 @@ trait SnapshotManagement { self: DeltaLog =>
// upper bound.
currentSnapshot
} else {
val latestSnapshot = update(tableIdentifierOpt = tableIdentifierOpt)
val latestSnapshot = update(catalogTableOpt = catalogTableOpt)
if (latestSnapshot.version < version) {
throwNonExistentVersionError(version)
}
@@ -1360,6 +1361,7 @@ trait SnapshotManagement { self: DeltaLog =>
.map(manuallyLoadCheckpoint)
lastCheckpointInfoForListing -> None
}
val tableIdentifierOpt = catalogTableOpt.map(_.identifier)
val logSegmentOpt = createLogSegment(
versionToLoad = Some(version),
oldCheckpointProviderOpt = lastCheckpointProviderOpt,
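With update and getSnapshotAt now accepting the full CatalogTable instead of just a TableIdentifier, call sites follow this shape (a minimal sketch; deltaLog and catalogTable are assumed to already be in scope):

// Callers holding catalog metadata pass it through; path-only callers rely on the None default.
val latestSnapshot = deltaLog.update(catalogTableOpt = Some(catalogTable))
val snapshotAtV42 = deltaLog.getSnapshotAt(42L, catalogTableOpt = Some(catalogTable))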
@@ -156,11 +156,12 @@ case class DeltaTableV2(
"queriedVersion" -> version,
"accessType" -> accessType
))
deltaLog.getSnapshotAt(version)
deltaLog.getSnapshotAt(version, catalogTableOpt = catalogTable)
}.getOrElse(
deltaLog.update(
stalenessAcceptable = true,
checkIfUpdatedSinceTs = Some(creationTimeMs)
checkIfUpdatedSinceTs = Some(creationTimeMs),
catalogTableOpt = catalogTable
)
)
}
@@ -224,7 +224,7 @@ trait DeltaCommand extends DeltaLogging {

/**
* Utility method to return the [[DeltaLog]] of an existing Delta table referred
* by either the given [[path]] or [[tableIdentifier].
* by either the given [[path]] or [[tableIdentifier]].
*
* @param spark [[SparkSession]] reference to use.
* @param path Table location. Expects a non-empty [[tableIdentifier]] or [[path]].
@@ -241,18 +241,18 @@
tableIdentifier: Option[TableIdentifier],
operationName: String,
hadoopConf: Map[String, String] = Map.empty): DeltaLog = {
val tablePath =
val (deltaLog, catalogTable) =
if (path.nonEmpty) {
new Path(path.get)
(DeltaLog.forTable(spark, new Path(path.get), hadoopConf), None)
} else if (tableIdentifier.nonEmpty) {
val sessionCatalog = spark.sessionState.catalog
lazy val metadata = sessionCatalog.getTableMetadata(tableIdentifier.get)

DeltaTableIdentifier(spark, tableIdentifier.get) match {
case Some(id) if id.path.nonEmpty =>
new Path(id.path.get)
(DeltaLog.forTable(spark, new Path(id.path.get), hadoopConf), None)
case Some(id) if id.table.nonEmpty =>
new Path(metadata.location)
(DeltaLog.forTable(spark, metadata, hadoopConf), Some(metadata))
case _ =>
if (metadata.tableType == CatalogTableType.VIEW) {
throw DeltaErrors.viewNotSupported(operationName)
@@ -264,8 +264,9 @@
}

val startTime = Some(System.currentTimeMillis)
val deltaLog = DeltaLog.forTable(spark, tablePath, hadoopConf)
if (deltaLog.update(checkIfUpdatedSinceTs = startTime).version < 0) {
if (deltaLog
.update(checkIfUpdatedSinceTs = startTime, catalogTableOpt = catalogTable)
.version < 0) {
throw DeltaErrors.notADeltaTableException(
operationName,
DeltaTableIdentifier(path, tableIdentifier))
@@ -39,14 +39,16 @@ case class DeltaGenerateCommand(
throw DeltaErrors.unsupportedGenerateModeException(modeName)
}

val tablePath = DeltaTableIdentifier(sparkSession, tableId) match {
val deltaLog = DeltaTableIdentifier(sparkSession, tableId) match {
case Some(id) if id.path.isDefined =>
new Path(id.path.get)
DeltaLog.forTable(sparkSession, new Path(id.path.get), options)
case _ =>
new Path(sparkSession.sessionState.catalog.getTableMetadata(tableId).location)
DeltaLog.forTable(
sparkSession,
sparkSession.sessionState.catalog.getTableMetadata(tableId),
options)
}

val deltaLog = DeltaLog.forTable(sparkSession, tablePath, options)
if (!deltaLog.tableExists) {
throw DeltaErrors.notADeltaTableException("GENERATE")
}
@@ -84,15 +84,14 @@ case class DescribeDeltaDetailCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val tableMetadata = getTableCatalogTable(child, DescribeDeltaDetailCommand.CMD_NAME)
val (_, path) = getTablePathOrIdentifier(child, DescribeDeltaDetailCommand.CMD_NAME)
val basePath = tableMetadata match {
case Some(metadata) => new Path(metadata.location)
case _ if path.isDefined => new Path(path.get)
val deltaLog = tableMetadata match {
case Some(metadata) => DeltaLog.forTable(sparkSession, metadata, hadoopConf)
case _ if path.isDefined => DeltaLog.forTable(sparkSession, new Path(path.get), hadoopConf)
case _ =>
throw DeltaErrors.missingTableIdentifierException(DescribeDeltaDetailCommand.CMD_NAME)
}
val deltaLog = DeltaLog.forTable(sparkSession, basePath, hadoopConf)
recordDeltaOperation(deltaLog, "delta.ddl.describeDetails") {
val snapshot = deltaLog.update()
val snapshot = deltaLog.update(catalogTableOpt = tableMetadata)
if (snapshot.version == -1) {
if (path.nonEmpty) {
val fs = new Path(path.get).getFileSystem(deltaLog.newDeltaHadoopConf())
@@ -92,6 +92,7 @@ case class RestoreTableCommand(sourceTable: DeltaTableV2)

override def run(spark: SparkSession): Seq[Row] = {
val deltaLog = sourceTable.deltaLog
val catalogTableOpt = sourceTable.catalogTable
val version = sourceTable.timeTravelOpt.get.version
val timestamp = getTimestamp()
recordDeltaOperation(deltaLog, "delta.restore") {
@@ -105,14 +106,18 @@
.version
}

val latestVersion = deltaLog.update().version
val latestVersion = deltaLog
.update(catalogTableOpt = catalogTableOpt)
.version

require(versionToRestore < latestVersion, s"Version to restore ($versionToRestore)" +
s"should be less then last available version ($latestVersion)")

deltaLog.withNewTransaction(sourceTable.catalogTable) { txn =>
deltaLog.withNewTransaction(catalogTableOpt) { txn =>
val latestSnapshot = txn.snapshot
val snapshotToRestore = deltaLog.getSnapshotAt(versionToRestore)
val snapshotToRestore = deltaLog.getSnapshotAt(
versionToRestore,
catalogTableOpt = catalogTableOpt)
val latestSnapshotFiles = latestSnapshot.allFiles
val snapshotToRestoreFiles = snapshotToRestore.allFiles

@@ -40,7 +40,7 @@ object CheckpointHook extends PostCommitHook {
committedVersion,
lastCheckpointHint = None,
lastCheckpointProvider = Some(cp),
tableIdentifierOpt = txn.catalogTable.map(_.identifier))
catalogTableOpt = txn.catalogTable)
txn.deltaLog.checkpoint(snapshotToCheckpoint, txn.catalogTable.map(_.identifier))
}
}
@@ -594,7 +594,7 @@ class DeltaLogSuite extends QueryTest
Iterator(JsonUtils.toJson(add.wrap)),
overwrite = false,
deltaLog.newDeltaHadoopConf())
deltaLog
(deltaLog, None)
}
assert(snapshot.version === 0)
