Commit
Move API back to Dataset
szehon-ho committed Aug 8, 2024
1 parent 056492b commit a56228c
Showing 5 changed files with 41 additions and 69 deletions.
@@ -304,11 +304,8 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.PartitionTransform$ExtractTransform"),

// Update Writer
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SparkSession.update"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition$")) ++
mergeIntoWriterExcludeRules
23 changes: 23 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -4136,6 +4136,29 @@ class Dataset[T] private[sql](
new MergeIntoWriter[T](table, this, condition)
}

/**
* Update rows in a table that match a condition.
*
* Scala Example:
* {{{
* spark.table("source").update(Map("salary" -> lit(200)))
* .where($"salary" === 100)
* .execute()
*
* }}}
* @param assignments A Map of column names to Column expressions representing the updates
* to be applied.
* @since 4.0.0
*/
def update(assignments: Map[String, Column]): UpdateWriter[T] = {
if (isStreaming) {
throw new AnalysisException(
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("update")))
}
new UpdateWriter[T](this, assignments)
}

/**
* Interface for saving the content of the streaming Dataset out into external storage.
*
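For orientation, here is a hedged usage sketch of the API after this change: the update entry point now lives on Dataset, so callers reach it through spark.table(...). The table and column names below are illustrative, not taken from the commit.

  import org.apache.spark.sql.functions.lit
  import spark.implicits._  // for the $"..." column syntax

  // Conditional update: pass the assignments, narrow with where(), then execute().
  spark.table("people")
    .update(Map("salary" -> lit(200)))
    .where($"salary" === 100)
    .execute()

  // Unconditional update: call execute() directly on the returned UpdateWriter.
  spark.table("people")
    .update(Map("active" -> lit(true)))
    .execute()
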
31 changes: 0 additions & 31 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -43,7 +43,6 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution._
@@ -833,36 +832,6 @@ class SparkSession private(
ret
}

/**
* Update rows in a table that match a condition.
*
* Scala Example:
* {{{
* spark.update("source")
* .set(
* Map("salary" -> lit(200))
* )
* .where($"salary" === 100)
* .execute()
*
* }}}
* @param tableName is either a qualified or unqualified name that designates a table or view.
* If a database is specified, it identifies the table/view from the database.
* Otherwise, it first attempts to find a temporary view with the given name
* and then match the table/view from the current database.
* Note that, the global temporary view database is also valid here.
* @since 4.0.0
*/
def update(tableName: String): UpdateWriter = {
val tableDF = table(tableName)
if (tableDF.isStreaming) {
throw new AnalysisException(
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("update")))
}
new UpdateWriter(tableDF)
}

// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
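As a migration sketch (assuming user code was written against the SparkSession.update entry point removed above), the old and new call sites compare roughly as follows; "source" matches the example in the removed scaladoc:

  // Before this commit: entry point on SparkSession, with a separate set() step.
  // spark.update("source")
  //   .set(Map("salary" -> lit(200)))
  //   .where($"salary" === 100)
  //   .execute()

  // After this commit: entry point on Dataset, assignments passed to update() directly.
  spark.table("source")
    .update(Map("salary" -> lit(200)))
    .where($"salary" === 100)
    .execute()
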
45 changes: 14 additions & 31 deletions sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -22,71 +22,54 @@ import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
import org.apache.spark.sql.functions.expr

/**
* `UpdateWriter` provides methods to define and execute an update action on a target table.
* This class defines methods to specify a condition on an update operation
* or to execute it directly.
*
* @param tableDF DataFrame representing table to update.
*
* @since 4.0.0
*/
@Experimental
class UpdateWriter (tableDF: DataFrame) {

/**
* @param assignments A Map of column names to Column expressions representing the updates
* to be applied.
*/
def set(assignments: Map[String, Column]): UpdateWithAssignment = {
new UpdateWithAssignment(tableDF, assignments)
}
}

/**
* A class for defining a condition on an update operation or directly executing it.
*
* @param tableDF DataFrame representing table to update.
* @param assignment A Map of column names to Column expressions representing the updates
* @param dataset Dataset representing the table to update.
* @param assignments A Map of column names to Column expressions representing the updates
* to be applied.
* @tparam T the type of the Dataset.
*
* @since 4.0.0
*/
@Experimental
class UpdateWithAssignment(tableDF: DataFrame, assignment: Map[String, Column]) {
class UpdateWriter[T](dataset: Dataset[T], assignments: Map[String, Column]) {

/**
* Limits the update to rows matching the specified condition.
*
* @param condition the update condition
* @return an `UpdateWithCondition` that applies the assignments only to rows matching the condition
*/
def where(condition: Column): UpdateWithCondition = {
new UpdateWithCondition(tableDF, assignment, Some(condition))
def where(condition: Column): UpdateWithCondition[T] = {
new UpdateWithCondition(dataset, assignments, Some(condition))
}

/**
* Executes the update operation.
*/
def execute(): Unit = {
new UpdateWithCondition(tableDF, assignment, None)
new UpdateWithCondition(dataset, assignments, None).execute()
}
}

/**
* A class for executing an update operation.
*
* @param tableDF DataFrame representing table to update.
* @param dataset Dataset representing the table to update.
* @param assignments A Map of column names to Column expressions representing the updates
* to be applied.
* @param condition the update condition
* @since 4.0.0
*/
@Experimental
class UpdateWithCondition(
tableDF: DataFrame,
class UpdateWithCondition[T](
dataset: Dataset[T],
assignments: Map[String, Column],
condition: Option[Column]) {

private val sparkSession = tableDF.sparkSession
private val logicalPlan = tableDF.queryExecution.logical
private val sparkSession = dataset.sparkSession
private val logicalPlan = dataset.queryExecution.logical

/**
* Executes the update operation.
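To make the new builder's shape easier to follow, here is a simplified, self-contained Scala sketch of the two-step flow introduced above (UpdateWriter[T] -> UpdateWithCondition[T]). This is not the Spark source: Column and Dataset are stand-ins, and the println stands in for the UpdateTable logical plan that the real execute() builds from the imported Assignment/UpdateTable nodes.

  // Minimal stand-ins for Spark's Column and Dataset, for illustration only.
  final case class Column(sql: String)
  final case class Dataset[T](tableName: String)

  // Step 1: holds the target and the assignments; where() adds an optional condition.
  class UpdateWriter[T](ds: Dataset[T], assignments: Map[String, Column]) {
    def where(condition: Column): UpdateWithCondition[T] =
      new UpdateWithCondition(ds, assignments, Some(condition))

    def execute(): Unit =
      new UpdateWithCondition(ds, assignments, None).execute()
  }

  // Step 2: performs the update, with or without a condition.
  class UpdateWithCondition[T](
      ds: Dataset[T],
      assignments: Map[String, Column],
      condition: Option[Column]) {
    def execute(): Unit = {
      val setClause = assignments.map { case (col, value) => s"$col = ${value.sql}" }.mkString(", ")
      val whereClause = condition.map(c => s" WHERE ${c.sql}").getOrElse("")
      // The real implementation builds and runs an UpdateTable logical plan instead of printing.
      println(s"UPDATE ${ds.tableName} SET $setClause$whereClause")
    }
  }

For example, new UpdateWriter(Dataset[Int]("people"), Map("salary" -> Column("200"))).where(Column("salary = 100")).execute() would print UPDATE people SET salary = 200 WHERE salary = 100.
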
@@ -31,8 +31,8 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
|{ "pk": 3, "salary": 120, "dep": 'hr' }
|""".stripMargin)

spark.update(tableNameAsString)
.set(Map("salary" -> lit(-1)))
spark.table(tableNameAsString)
.update(Map("salary" -> lit(-1)))
.where($"pk" >= 2)
.execute()

@@ -51,8 +51,8 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
|{ "pk": 3, "salary": 120, "dep": 'hr' }
|""".stripMargin)

spark.update(tableNameAsString)
.set(Map("dep" -> lit("software")))
spark.table(tableNameAsString)
.update(Map("dep" -> lit("software")))
.execute()

checkAnswer(
