
Commit

Merge remote-tracking branch 'origin/master' into detect-unused-error-params
MaxGekk committed Sep 9, 2024
2 parents bc72974 + 20643bb commit 07df49a
Showing 23 changed files with 661 additions and 92 deletions.
@@ -145,6 +145,8 @@ List<String> buildClassPath(String appClassPath) throws IOException {

boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES"));
boolean isTesting = "1".equals(getenv("SPARK_TESTING"));
boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING"));
String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql);
if (prependClasses || isTesting) {
String scala = getScalaVersion();
List<String> projects = Arrays.asList(
@@ -176,6 +178,9 @@ List<String> buildClassPath(String appClassPath) throws IOException {
"NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " +
"assembly.");
}
boolean shouldPrePendSparkHive = isJarAvailable(jarsDir, "spark-hive_");
boolean shouldPrePendSparkHiveThriftServer =
shouldPrePendSparkHive && isJarAvailable(jarsDir, "spark-hive-thriftserver_");
for (String project : projects) {
// Do not use locally compiled class files for Spark server because it should use shaded
// dependencies.
@@ -185,6 +190,24 @@ List<String> buildClassPath(String appClassPath) throws IOException {
if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL")) && project.equals("sql/core")) {
continue;
}
// SPARK-49534: The assumption here is that if `spark-hive_xxx.jar` is not in the
// classpath, then the `-Phive` profile was not used during package, and therefore
// the Hive-related jars should also not be in the classpath. To avoid failure in
// loading the SPI in `DataSourceRegister` under `sql/hive`, no longer prepend `sql/hive`.
if (!shouldPrePendSparkHive && project.equals("sql/hive")) {
continue;
}
// SPARK-49534: Meanwhile, due to the strong dependency of `sql/hive-thriftserver`
// on `sql/hive`, the prepend for `sql/hive-thriftserver` will also be excluded
// if `spark-hive_xxx.jar` is not in the classpath. On the other hand, if
// `spark-hive-thriftserver_xxx.jar` is not in the classpath, then the
// `-Phive-thriftserver` profile was not used during package, and therefore,
// jars such as hive-cli and hive-beeline should also not be included in the classpath.
// To avoid the inelegant startup failures of tools such as spark-sql, in this scenario,
// `sql/hive-thriftserver` will no longer be prepended to the classpath.
if (!shouldPrePendSparkHiveThriftServer && project.equals("sql/hive-thriftserver")) {
continue;
}
addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project,
scala));
}
@@ -205,8 +228,6 @@ List<String> buildClassPath(String appClassPath) throws IOException {
// Add Spark jars to the classpath. For the testing case, we rely on the test code to set and
// propagate the test classpath appropriately. For normal invocation, look for the jars
// directory under SPARK_HOME.
boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING"));
String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql);
if (jarsDir != null) {
// Place slf4j-api-* jar first to be robust
for (File f: new File(jarsDir).listFiles()) {
@@ -265,6 +286,24 @@ private void addToClassPath(Set<String> cp, String entries) {
}
}

/**
* Checks if a JAR file with a specific prefix is available in the given directory.
*
* @param jarsDir the directory to search for JAR files
* @param jarNamePrefix the prefix of the JAR file name to look for
* @return true if a JAR file with the specified prefix is found, false otherwise
*/
private boolean isJarAvailable(String jarsDir, String jarNamePrefix) {
if (jarsDir != null) {
for (File f : new File(jarsDir).listFiles()) {
if (f.getName().startsWith(jarNamePrefix)) {
return true;
}
}
}
return false;
}

String getScalaVersion() {
String scala = getenv("SPARK_SCALA_VERSION");
if (scala != null) {
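For reference, a minimal standalone sketch of the SPARK-49534 gating introduced above: scan a jars directory for a `spark-hive_`-prefixed jar and only then allow the locally compiled `sql/hive` (and, transitively, `sql/hive-thriftserver`) classes to be prepended. The class name `HivePrependCheck` and the command-line handling are illustrative only, not part of the launcher.

import java.io.File;

// Illustrative sketch; mirrors the isJarAvailable helper and the
// shouldPrePendSparkHive / shouldPrePendSparkHiveThriftServer decisions above.
public class HivePrependCheck {

  static boolean isJarAvailable(String jarsDir, String jarNamePrefix) {
    if (jarsDir != null) {
      File[] files = new File(jarsDir).listFiles();
      if (files != null) {
        for (File f : files) {
          if (f.getName().startsWith(jarNamePrefix)) {
            return true;
          }
        }
      }
    }
    return false;
  }

  public static void main(String[] args) {
    String jarsDir = args.length > 0 ? args[0] : null;  // e.g. $SPARK_HOME/jars
    boolean prependHive = isJarAvailable(jarsDir, "spark-hive_");
    boolean prependThriftServer =
      prependHive && isJarAvailable(jarsDir, "spark-hive-thriftserver_");
    System.out.println("prepend sql/hive classes: " + prependHive);
    System.out.println("prepend sql/hive-thriftserver classes: " + prependThriftServer);
  }
}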
18 changes: 5 additions & 13 deletions python/pyspark/pandas/internal.py
@@ -43,6 +43,7 @@
)
from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
from pyspark import pandas as ps
from pyspark.pandas.spark import functions as SF
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -938,19 +939,10 @@ def attach_distributed_sequence_column(
+--------+---+
"""
if len(sdf.columns) > 0:
if is_remote():
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.connect.expressions import DistributedSequenceID

return sdf.select(
ConnectColumn(DistributedSequenceID()).alias(column_name),
"*",
)
else:
return PySparkDataFrame(
sdf._jdf.toDF().withSequenceColumn(column_name),
sdf.sparkSession,
)
return sdf.select(
SF.distributed_sequence_id().alias(column_name),
"*",
)
else:
cnt = sdf.count()
if cnt > 0:
12 changes: 12 additions & 0 deletions python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


def distributed_sequence_id() -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function

return _invoke_function("distributed_sequence_id")
else:
from pyspark import SparkContext

sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())


def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
60 changes: 30 additions & 30 deletions python/pyspark/sql/functions/builtin.py
@@ -7305,36 +7305,36 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) ->
| b| 2|
+---+---+
>>> w = Window.partitionBy("c1").orderBy("c2")
>>> df.withColumn("previos_value", lag("c2").over(w)).show()
+---+---+-------------+
| c1| c2|previos_value|
+---+---+-------------+
| a| 1| NULL|
| a| 2| 1|
| a| 3| 2|
| b| 2| NULL|
| b| 8| 2|
+---+---+-------------+
>>> df.withColumn("previos_value", lag("c2", 1, 0).over(w)).show()
+---+---+-------------+
| c1| c2|previos_value|
+---+---+-------------+
| a| 1| 0|
| a| 2| 1|
| a| 3| 2|
| b| 2| 0|
| b| 8| 2|
+---+---+-------------+
>>> df.withColumn("previos_value", lag("c2", 2, -1).over(w)).show()
+---+---+-------------+
| c1| c2|previos_value|
+---+---+-------------+
| a| 1| -1|
| a| 2| -1|
| a| 3| 1|
| b| 2| -1|
| b| 8| -1|
+---+---+-------------+
>>> df.withColumn("previous_value", lag("c2").over(w)).show()
+---+---+--------------+
| c1| c2|previous_value|
+---+---+--------------+
| a| 1| NULL|
| a| 2| 1|
| a| 3| 2|
| b| 2| NULL|
| b| 8| 2|
+---+---+--------------+
>>> df.withColumn("previous_value", lag("c2", 1, 0).over(w)).show()
+---+---+--------------+
| c1| c2|previous_value|
+---+---+--------------+
| a| 1| 0|
| a| 2| 1|
| a| 3| 2|
| b| 2| 0|
| b| 8| 2|
+---+---+--------------+
>>> df.withColumn("previous_value", lag("c2", 2, -1).over(w)).show()
+---+---+--------------+
| c1| c2|previous_value|
+---+---+--------------+
| a| 1| -1|
| a| 2| -1|
| a| 3| 1|
| b| 2| -1|
| b| 8| -1|
+---+---+--------------+
"""
from pyspark.sql.classic.column import _to_java_column

8 changes: 0 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2010,14 +2010,6 @@ class Dataset[T] private[sql](
// For Python API
////////////////////////////////////////////////////////////////////////////

/**
* It adds a new long column with the name `name` that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
select(column(DistributedSequenceID()).alias(name), col("*"))
}

/**
* Converts a JavaRDD to a PythonRDD.
*/
@@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging {
def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
Column.internalFn("pandas_covar", col1, col2, lit(ddof))

/**
* A long column that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
def distributed_sequence_id(): Column =
Column.internalFn("distributed_sequence_id")

def unresolvedNamedLambdaVariable(name: String): Column =
Column(internal.UnresolvedNamedLambdaVariable.apply(name))

@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.ThreadUtils

@@ -40,8 +41,14 @@ trait AsyncLogPurge extends Logging {

private val purgeRunning = new AtomicBoolean(false)

private val statefulMetadataPurgeRunning = new AtomicBoolean(false)

protected def purge(threshold: Long): Unit

// This method is used to purge the oldest OperatorStateMetadata and StateSchema files
// which are written per run.
protected def purgeStatefulMetadata(plan: SparkPlan): Unit

protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)

protected def purgeAsync(batchId: Long): Unit = {
@@ -62,6 +69,24 @@
}
}

protected def purgeStatefulMetadataAsync(plan: SparkPlan): Unit = {
if (statefulMetadataPurgeRunning.compareAndSet(false, true)) {
asyncPurgeExecutorService.execute(() => {
try {
purgeStatefulMetadata(plan)
} catch {
case throwable: Throwable =>
logError("Encountered error while performing async log purge", throwable)
errorNotifier.markError(throwable)
} finally {
statefulMetadataPurgeRunning.set(false)
}
})
} else {
log.debug("Skipped log purging since there is already one in progress.")
}
}

protected def asyncLogPurgeShutdown(): Unit = {
ThreadUtils.shutdown(asyncPurgeExecutorService)
}
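The new `purgeStatefulMetadataAsync` reuses the same single-flight guard as `purgeAsync`: an `AtomicBoolean` claimed with `compareAndSet` plus a purge executor, so overlapping purge requests are dropped rather than queued. A minimal sketch of that pattern in plain Java follows; the class and method names here are illustrative, not Spark APIs.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

// Illustrative single-flight pattern: at most one purge runs at a time.
public class SingleFlightPurge {
  private final AtomicBoolean running = new AtomicBoolean(false);
  private final ExecutorService executor = Executors.newSingleThreadExecutor();

  public void purgeAsync(Runnable purgeWork) {
    if (running.compareAndSet(false, true)) {      // claim the guard
      executor.execute(() -> {
        try {
          purgeWork.run();
        } finally {
          running.set(false);                      // release even if the purge fails
        }
      });
    }
    // else: a purge is already in flight, so this request is skipped
  }
}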
@@ -842,6 +842,10 @@ class MicroBatchExecution(

markMicroBatchExecutionStart(execCtx)

if (execCtx.previousContext.isEmpty) {
purgeStatefulMetadataAsync(execCtx.executionPlan.executedPlan)
}

val nextBatch =
new Dataset(execCtx.executionPlan, ExpressionEncoder(execCtx.executionPlan.analyzed.schema))

@@ -41,8 +41,10 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncException, ForeachUserFuncException}
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2FileManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.streaming._
@@ -686,6 +688,31 @@ abstract class StreamExecution(
offsetLog.purge(threshold)
commitLog.purge(threshold)
}

protected def purgeStatefulMetadata(plan: SparkPlan): Unit = {
plan.collect { case statefulOperator: StatefulOperator =>
statefulOperator match {
case ssw: StateStoreWriter =>
ssw.operatorStateMetadataVersion match {
case 2 =>
// checkpointLocation of the operator is runId/state, and commitLog path is
// runId/commits, so we want the parent of the checkpointLocation to get the
// commit log path.
val parentCheckpointLocation =
new Path(statefulOperator.getStateInfo.checkpointLocation).getParent

val fileManager = new OperatorStateMetadataV2FileManager(
parentCheckpointLocation,
sparkSession,
ssw
)
fileManager.purgeMetadataFiles()
case _ =>
}
case _ =>
}
}
}
}

object StreamExecution {
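A tiny sketch of the path arithmetic described in the `purgeStatefulMetadata` comment above: the operator's checkpoint location ends in `/state`, so its parent is the per-run checkpoint root, which also holds the `commits` log. It assumes `hadoop-common` on the classpath; the concrete paths and class name are illustrative.

import org.apache.hadoop.fs.Path;

public class CheckpointParentSketch {
  public static void main(String[] args) {
    // Illustrative layout: <checkpoint root>/state for operator state,
    // <checkpoint root>/commits for the commit log.
    Path stateDir = new Path("/tmp/checkpoints/my-query/state");
    Path checkpointRoot = stateDir.getParent();             // /tmp/checkpoints/my-query
    Path commitLog = new Path(checkpointRoot, "commits");   // /tmp/checkpoints/my-query/commits
    System.out.println(checkpointRoot);
    System.out.println(commitLog);
  }
}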