diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 9191b3ec4b..9214f55130 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -222,6 +222,7 @@ class LinearSVC @Since("2.2.0") (
}
val featuresStd = summarizer.std.toArray
+ val featuresMean = summarizer.mean.toArray
val getFeaturesStd = (j: Int) => featuresStd(j)
val regularization = if ($(regParam) != 0.0) {
val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
@@ -239,7 +240,8 @@ class LinearSVC @Since("2.2.0") (
as a result, no scaling is needed.
*/
val (rawCoefficients, objectiveHistory) =
- trainImpl(instances, actualBlockSizeInMB, featuresStd, regularization, optimizer)
+ trainImpl(instances, actualBlockSizeInMB, featuresStd, featuresMean,
+ regularization, optimizer)
if (rawCoefficients == null) {
val msg = s"${optimizer.getClass.getName} failed."
@@ -277,16 +279,19 @@ class LinearSVC @Since("2.2.0") (
instances: RDD[Instance],
actualBlockSizeInMB: Double,
featuresStd: Array[Double],
+ featuresMean: Array[Double],
regularization: Option[L2Regularization],
optimizer: BreezeOWLQN[Int, BDV[Double]]): (Array[Double], Array[Double]) = {
val numFeatures = featuresStd.length
val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
- val bcFeaturesStd = instances.context.broadcast(featuresStd)
+ val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0)
+ val scaledMean = Array.tabulate(numFeatures)(i => inverseStd(i) * featuresMean(i))
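+ // scaledMean(i) = mean_i / std_i, i.e. the mean of feature i after standardization;
+ // when fitIntercept is true, the aggregator uses it to center the data virtually.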
+ val bcInverseStd = instances.context.broadcast(inverseStd)
+ val bcScaledMean = instances.context.broadcast(scaledMean)
val standardized = instances.mapPartitions { iter =>
- val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 }
- val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true)
+ val func = StandardScalerModel.getTransformFunc(Array.empty, bcInverseStd.value, false, true)
iter.map { case Instance(label, weight, vec) => Instance(label, weight, func(vec)) }
}
@@ -295,13 +300,24 @@ class LinearSVC @Since("2.2.0") (
.persist(StorageLevel.MEMORY_AND_DISK)
.setName(s"training blocks (blockSizeInMB=$actualBlockSizeInMB)")
- val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_)
+ val getAggregatorFunc = new HingeBlockAggregator(bcInverseStd, bcScaledMean,
+ $(fitIntercept))(_)
val costFun = new RDDLossFunction(blocks, getAggregatorFunc,
regularization, $(aggregationDepth))
- val states = optimizer.iterations(new CachedDiffFunction(costFun),
- Vectors.zeros(numFeaturesPlusIntercept).asBreeze.toDenseVector)
+ val initialSolution = Array.ofDim[Double](numFeaturesPlusIntercept)
+ if ($(fitIntercept)) {
+ // original `initialSolution` is for problem:
+ // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
+ // we should adjust it to the initial solution for problem:
+ // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
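+ // i.e. intercept' = intercept + sum_j w_j * avg_x_j / std_x_j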
+ // NOTE: this is a no-op until model initialization is supported
+ val adapt = BLAS.javaBLAS.ddot(numFeatures, initialSolution, 1, scaledMean, 1)
+ initialSolution(numFeatures) += adapt
+ }
+ val states = optimizer.iterations(new CachedDiffFunction(costFun),
+ new BDV[Double](initialSolution))
val arrayBuilder = mutable.ArrayBuilder.make[Double]
var state: optimizer.State = null
while (states.hasNext) {
@@ -309,9 +325,19 @@ class LinearSVC @Since("2.2.0") (
arrayBuilder += state.adjustedValue
}
blocks.unpersist()
- bcFeaturesStd.destroy()
-
- (if (state != null) state.x.toArray else null, arrayBuilder.result)
+ bcInverseStd.destroy()
+ bcScaledMean.destroy()
+
+ val solution = if (state == null) null else state.x.toArray
+ if ($(fitIntercept) && solution != null) {
+ // the final solution is for problem:
+ // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
+ // we should adjust it back for original problem:
+ // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
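+ // i.e. intercept = intercept' - sum_j w_j * avg_x_j / std_x_j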
+ val adapt = BLAS.javaBLAS.ddot(numFeatures, solution, 1, scaledMean, 1)
+ solution(numFeatures) -= adapt
+ }
+ (solution, arrayBuilder.result)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 57fb46b451..c3c54651ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -982,14 +982,14 @@ class LogisticRegression @Since("1.2.0") (
val adapt = Array.ofDim[Double](numClasses)
BLAS.javaBLAS.dgemv("N", numClasses, numFeatures, 1.0,
initialSolution, numClasses, scaledMean, 1, 0.0, adapt, 1)
- BLAS.getBLAS(numFeatures).daxpy(numClasses, 1.0, adapt, 0, 1,
+ BLAS.javaBLAS.daxpy(numClasses, 1.0, adapt, 0, 1,
initialSolution, numClasses * numFeatures, 1)
} else {
- // orginal `initialCoefWithInterceptArray` is for problem:
+ // original `initialSolution` is for problem:
// y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
// we should adjust it to the initial solution for problem:
// y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
- val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, initialSolution, 1, scaledMean, 1)
+ val adapt = BLAS.javaBLAS.ddot(numFeatures, initialSolution, 1, scaledMean, 1)
initialSolution(numFeatures) += adapt
}
}
@@ -1018,14 +1018,14 @@ class LogisticRegression @Since("1.2.0") (
val adapt = Array.ofDim[Double](numClasses)
BLAS.javaBLAS.dgemv("N", numClasses, numFeatures, 1.0,
solution, numClasses, scaledMean, 1, 0.0, adapt, 1)
- BLAS.getBLAS(numFeatures).daxpy(numClasses, -1.0, adapt, 0, 1,
+ BLAS.javaBLAS.daxpy(numClasses, -1.0, adapt, 0, 1,
solution, numClasses * numFeatures, 1)
} else {
// the final solution is for problem:
// y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
// we should adjust it back for original problem:
// y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
- val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, solution, 1, scaledMean, 1)
+ val adapt = BLAS.javaBLAS.ddot(numFeatures, solution, 1, scaledMean, 1)
solution(numFeatures) -= adapt
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala
index 091c885ca0..09a4335dad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala
@@ -72,7 +72,7 @@ private[ml] class BinaryLogisticBlockAggregator(
// deal with non-zero values in prediction.
private val marginOffset = if (fitWithMean) {
coefficientsArray.last -
- BLAS.getBLAS(numFeatures).ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1)
+ BLAS.javaBLAS.ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1)
} else {
Double.NaN
}
@@ -142,7 +142,7 @@ private[ml] class BinaryLogisticBlockAggregator(
case sm: SparseMatrix if fitIntercept =>
val linearGradSumVec = new DenseVector(Array.ofDim[Double](numFeatures))
BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
- BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
+ BLAS.javaBLAS.daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
gradientSumArray, 1)
case sm: SparseMatrix if !fitIntercept =>
@@ -156,7 +156,7 @@ private[ml] class BinaryLogisticBlockAggregator(
if (fitWithMean) {
// above update of the linear part of gradientSumArray does NOT take the centering
// into account, here we need to adjust this part.
- BLAS.getBLAS(numFeatures).daxpy(numFeatures, -multiplierSum, bcScaledMean.value, 1,
+ BLAS.javaBLAS.daxpy(numFeatures, -multiplierSum, bcScaledMean.value, 1,
gradientSumArray, 1)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregator.scala
new file mode 100644
index 0000000000..f99c531c96
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregator.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.InstanceBlock
+import org.apache.spark.ml.linalg._
+
+
+/**
+ * HingeBlockAggregator computes the gradient and loss for hinge loss function
+ * as used in binary classification for blocks in sparse or dense matrix in an online fashion.
+ *
+ * Two HingeBlockAggregators can be merged together to have a summary of loss and gradient
+ * of the corresponding joint dataset.
+ *
+ * NOTE: The feature values are expected to already have been scaled (multiplied by bcInverseStd,
+ * but NOT centered) before computation.
+ *
+ * @param bcCoefficients The coefficients corresponding to the features.
+ * @param fitIntercept Whether to fit an intercept term. When true, will perform data centering
+ * in a virtual way. Then we MUST adjust the intercept of both initial
+ * coefficients and final solution in the caller.
+ */
+private[ml] class HingeBlockAggregator(
+ bcInverseStd: Broadcast[Array[Double]],
+ bcScaledMean: Broadcast[Array[Double]],
+ fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
+ extends DifferentiableLossAggregator[InstanceBlock, HingeBlockAggregator]
+ with Logging {
+
+ if (fitIntercept) {
+ require(bcScaledMean != null && bcScaledMean.value.length == bcInverseStd.value.length,
+ "scaled means is required when center the vectors")
+ }
+
+ private val numFeatures = bcInverseStd.value.length
+ protected override val dim: Int = bcCoefficients.value.size
+
+ @transient private lazy val coefficientsArray = bcCoefficients.value match {
+ case DenseVector(values) => values
+ case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " +
+ s"got type ${bcCoefficients.value.getClass}.)")
+ }
+
+ @transient private lazy val linear = if (fitIntercept) {
+ new DenseVector(coefficientsArray.take(numFeatures))
+ } else {
+ new DenseVector(coefficientsArray)
+ }
+
+ // pre-computed margin of an empty vector.
+ // with this variable as an offset, for a sparse vector, we only need to
+ // deal with non-zero values in prediction.
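+ // since margin = linear . (x_scaled - scaledMean) + intercept, the constant part
+ // (intercept - linear . scaledMean) can be folded into this single offset.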
+ private val marginOffset = if (fitIntercept) {
+ coefficientsArray.last -
+ BLAS.javaBLAS.ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1)
+ } else {
+ Double.NaN
+ }
+
+ /**
+ * Add a new training instance block to this HingeBlockAggregator, and update the loss
+ * and gradient of the objective function.
+ *
+ * @param block The instance block of data point to be added.
+ * @return This HingeBlockAggregator object.
+ */
+ def add(block: InstanceBlock): this.type = {
+ require(block.matrix.isTransposed)
+ require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
+ s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
+ require(block.weightIter.forall(_ >= 0),
+ s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0")
+
+ if (block.weightIter.forall(_ == 0)) return this
+ val size = block.size
+
+ // vec/arr here represents margins
+ val vec = new DenseVector(Array.ofDim[Double](size))
+ val arr = vec.values
+ if (fitIntercept) java.util.Arrays.fill(arr, marginOffset)
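+ // margins = blockMatrix * linear + (marginOffset if fitIntercept else 0); the centering
+ // term is already folded into marginOffset, so only scaled features are multiplied here.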
+ BLAS.gemv(1.0, block.matrix, linear, 1.0, vec)
+
+ // in-place convert margins to multiplier
+ // then, vec/arr represents multiplier
+ var localLossSum = 0.0
+ var localWeightSum = 0.0
+ var multiplierSum = 0.0
+ var i = 0
+ while (i < size) {
+ val weight = block.getWeight(i)
+ localWeightSum += weight
+ if (weight > 0) {
+ // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
+ // Therefore the gradient is -(2y - 1)*x
+ val label = block.getLabel(i)
+ val labelScaled = label + label - 1.0
+ val loss = (1.0 - labelScaled * arr(i)) * weight
+ if (loss > 0) {
+ localLossSum += loss
+ val multiplier = -labelScaled * weight
+ arr(i) = multiplier
+ multiplierSum += multiplier
+ } else { arr(i) = 0.0 }
+ } else { arr(i) = 0.0 }
+ i += 1
+ }
+ lossSum += localLossSum
+ weightSum += localWeightSum
+
+ // predictions are all correct, no gradient signal
+ if (arr.forall(_ == 0)) return this
+
+ // update the linear part of gradientSumArray
+ block.matrix match {
+ case dm: DenseMatrix =>
+ BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols,
+ vec.values, 1, 1.0, gradientSumArray, 1)
+
+ case sm: SparseMatrix if fitIntercept =>
+ val linearGradSumVec = new DenseVector(Array.ofDim[Double](numFeatures))
+ BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
+ BLAS.javaBLAS.daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
+ gradientSumArray, 1)
+
+ case sm: SparseMatrix if !fitIntercept =>
+ val gradSumVec = new DenseVector(gradientSumArray)
+ BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec)
+
+ case m =>
+ throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.")
+ }
+
+ if (fitIntercept) {
+ // above update of the linear part of gradientSumArray does NOT take the centering
+ // into account, here we need to adjust this part.
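+ // d/d(linear) of sum_i multiplier_i * (x_i_scaled - scaledMean) leaves an extra term
+ // -multiplierSum * scaledMean, which is added here.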
+ BLAS.javaBLAS.daxpy(numFeatures, -multiplierSum, bcScaledMean.value, 1,
+ gradientSumArray, 1)
+
+ // update the intercept part of gradientSumArray
+ gradientSumArray(numFeatures) += multiplierSum
+ }
+
+ this
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala
index de64440843..0683cec628 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala
@@ -203,7 +203,7 @@ private[ml] class MultinomialLogisticBlockAggregator(
}
if (fitIntercept) {
- BLAS.getBLAS(numClasses).daxpy(numClasses, 1.0, multiplierSum, 0, 1,
+ BLAS.javaBLAS.daxpy(numClasses, 1.0, multiplierSum, 0, 1,
gradientSumArray, numClasses * numFeatures, 1)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index df64de4b10..837883e53d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
private[libsvm] class LibSVMOutputWriter(
- path: String,
+ val path: String,
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a2c376d80e..d2cfedcc33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -22,7 +22,6 @@ import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.google.common.collect.{Ordering => GuavaOrdering}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
@@ -34,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE
+import org.apache.spark.ml.linalg.BLAS
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
@@ -401,18 +401,18 @@ class Word2Vec extends Serializable with Logging {
val inner = bcVocab.value(word).point(d)
val l2 = inner * vectorSize
// Propagate hidden -> output
- var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
+ var f = BLAS.nativeBLAS.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
- blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
- blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+ BLAS.nativeBLAS.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
+ BLAS.nativeBLAS.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
syn1Modify(inner) += 1
}
d += 1
}
- blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+ BLAS.nativeBLAS.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
syn0Modify(lastWord) += 1
}
}
@@ -448,10 +448,10 @@ class Word2Vec extends Serializable with Logging {
(id, (vec, 1))
}
}.reduceByKey { (vc1, vc2) =>
- blas.saxpy(vectorSize, 1.0f, vc2._1, 1, vc1._1, 1)
+ BLAS.nativeBLAS.saxpy(vectorSize, 1.0f, vc2._1, 1, vc1._1, 1)
(vc1._1, vc1._2 + vc2._2)
}.map { case (id, (vec, count)) =>
- blas.sscal(vectorSize, 1.0f / count, vec, 1)
+ BLAS.nativeBLAS.sscal(vectorSize, 1.0f / count, vec, 1)
(id, vec)
}.collect()
var i = 0
@@ -511,7 +511,7 @@ class Word2VecModel private[spark] (
private lazy val wordVecInvNorms: Array[Float] = {
val size = vectorSize
Array.tabulate(numWords) { i =>
- val norm = blas.snrm2(size, wordVectors, i * size, 1)
+ val norm = BLAS.nativeBLAS.snrm2(size, wordVectors, i * size, 1)
if (norm != 0) 1 / norm else 0.0F
}
}
@@ -587,7 +587,7 @@ class Word2VecModel private[spark] (
val localVectorSize = vectorSize
val floatVec = vector.map(_.toFloat)
- val vecNorm = blas.snrm2(localVectorSize, floatVec, 1)
+ val vecNorm = BLAS.nativeBLAS.snrm2(localVectorSize, floatVec, 1)
val localWordList = wordList
val localNumWords = numWords
@@ -597,11 +597,11 @@ class Word2VecModel private[spark] (
.take(num)
.toArray
} else {
- // Normalize input vector before blas.sgemv to avoid Inf value
- blas.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1)
+ // Normalize input vector before BLAS.nativeBLAS.sgemv to avoid Inf value
+ BLAS.nativeBLAS.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1)
val cosineVec = Array.ofDim[Float](localNumWords)
- blas.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize,
+ BLAS.nativeBLAS.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize,
floatVec, 1, 0.0F, cosineVec, 1)
val localWordVecInvNorms = wordVecInvNorms
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala
new file mode 100644
index 0000000000..fb0f6ddd47
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import dev.ludovic.netlib.{ARPACK => NetlibARPACK,
+ JavaARPACK => NetlibJavaARPACK,
+ NativeARPACK => NetlibNativeARPACK}
+
+/**
+ * ARPACK routines for MLlib's vectors and matrices.
+ */
+private[spark] object ARPACK extends Serializable {
+
+ @transient private var _javaARPACK: NetlibARPACK = _
+ @transient private var _nativeARPACK: NetlibARPACK = _
+
+ private[spark] def javaARPACK: NetlibARPACK = {
+ if (_javaARPACK == null) {
+ _javaARPACK = NetlibJavaARPACK.getInstance
+ }
+ _javaARPACK
+ }
+
+ private[spark] def nativeARPACK: NetlibARPACK = {
+ if (_nativeARPACK == null) {
+ _nativeARPACK =
+ try { NetlibNativeARPACK.getInstance } catch { case _: Throwable => javaARPACK }
+ }
+ _nativeARPACK
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index bd60364326..e38cfe4e18 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -17,8 +17,9 @@
package org.apache.spark.mllib.linalg
-import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
-import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
+import dev.ludovic.netlib.{BLAS => NetlibBLAS,
+ JavaBLAS => NetlibJavaBLAS,
+ NativeBLAS => NetlibNativeBLAS}
import org.apache.spark.internal.Logging
@@ -27,21 +28,30 @@ import org.apache.spark.internal.Logging
*/
private[spark] object BLAS extends Serializable with Logging {
- @transient private var _f2jBLAS: NetlibBLAS = _
+ @transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
private val nativeL1Threshold: Int = 256
- // For level-1 function dspmv, use f2jBLAS for better performance.
- private[mllib] def f2jBLAS: NetlibBLAS = {
- if (_f2jBLAS == null) {
- _f2jBLAS = new F2jBLAS
+ // For level-1 function dspmv, use javaBLAS for better performance.
+ private[spark] def javaBLAS: NetlibBLAS = {
+ if (_javaBLAS == null) {
+ _javaBLAS = NetlibJavaBLAS.getInstance
}
- _f2jBLAS
+ _javaBLAS
}
- private[mllib] def getBLAS(vectorSize: Int): NetlibBLAS = {
+ // For level-3 routines, we use the native BLAS.
+ private[spark] def nativeBLAS: NetlibBLAS = {
+ if (_nativeBLAS == null) {
+ _nativeBLAS =
+ try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS }
+ }
+ _nativeBLAS
+ }
+
+ private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = {
if (vectorSize < nativeL1Threshold) {
- f2jBLAS
+ javaBLAS
} else {
nativeBLAS
}
@@ -237,14 +247,6 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
- // For level-3 routines, we use the native BLAS.
- private[mllib] def nativeBLAS: NetlibBLAS = {
- if (_nativeBLAS == null) {
- _nativeBLAS = NativeBLAS
- }
- _nativeBLAS
- }
-
/**
* Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
*
@@ -263,7 +265,7 @@ private[spark] object BLAS extends Serializable with Logging {
val n = v.size
v match {
case DenseVector(values) =>
- NativeBLAS.dspr("U", n, alpha, values, 1, U)
+ nativeBLAS.dspr("U", n, alpha, values, 1, U)
case SparseVector(size, indices, values) =>
val nnz = indices.length
var colStartIdx = 0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
index 68771f1afb..f06ea9418f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.linalg
-import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW
import org.apache.spark.ml.optim.SingularMatrixException
@@ -37,7 +36,7 @@ private[spark] object CholeskyDecomposition {
def solve(A: Array[Double], bx: Array[Double]): Array[Double] = {
val k = bx.length
val info = new intW(0)
- lapack.dppsv("U", k, 1, A, bx, k, info)
+ LAPACK.nativeLAPACK.dppsv("U", k, 1, A, bx, k, info)
checkReturnValue(info, "dppsv")
bx
}
@@ -52,7 +51,7 @@ private[spark] object CholeskyDecomposition {
*/
def inverse(UAi: Array[Double], k: Int): Array[Double] = {
val info = new intW(0)
- lapack.dpptri("U", k, UAi, info)
+ LAPACK.nativeLAPACK.dpptri("U", k, UAi, info)
checkReturnValue(info, "dpptri")
UAi
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
index 4c71cd6496..2cbf5d09dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -18,7 +18,6 @@
package org.apache.spark.mllib.linalg
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
-import com.github.fommil.netlib.ARPACK
import org.netlib.util.{doubleW, intW}
/**
@@ -51,8 +50,6 @@ private[mllib] object EigenValueDecomposition {
// TODO: remove this function and use eigs in breeze when switching breeze version
require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n")
- val arpack = ARPACK.getInstance()
-
// tolerance used in stopping criterion
val tolW = new doubleW(tol)
// number of desired eigenvalues, 0 < nev < n
@@ -87,8 +84,8 @@ private[mllib] object EigenValueDecomposition {
val ipntr = new Array[Int](11)
// call ARPACK's reverse communication, first iteration with ido = 0
- arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, workd,
- workl, workl.length, info)
+ ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv,
+ v, n, iparam, ipntr, workd, workl, workl.length, info)
val w = BDV(workd)
@@ -105,8 +102,8 @@ private[mllib] object EigenValueDecomposition {
val y = w.slice(outputOffset, outputOffset + n)
y := mul(x)
// call ARPACK's reverse communication
- arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr,
- workd, workl, workl.length, info)
+ ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv,
+ v, n, iparam, ipntr, workd, workl, workl.length, info)
}
if (info.`val` != 0) {
@@ -127,8 +124,8 @@ private[mllib] object EigenValueDecomposition {
val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n)
// call ARPACK's post-processing for eigenvectors
- arpack.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, ncv, v, n,
- iparam, ipntr, workd, workl, workl.length, info)
+ ARPACK.nativeARPACK.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid,
+ ncv, v, n, iparam, ipntr, workd, workl, workl.length, info)
// number of computed eigenvalues, might be smaller than k
val computed = iparam(4)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala
new file mode 100644
index 0000000000..4d25aed283
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import dev.ludovic.netlib.{JavaLAPACK => NetlibJavaLAPACK,
+ LAPACK => NetlibLAPACK,
+ NativeLAPACK => NetlibNativeLAPACK}
+
+/**
+ * LAPACK routines for MLlib's vectors and matrices.
+ */
+private[spark] object LAPACK extends Serializable {
+
+ @transient private var _javaLAPACK: NetlibLAPACK = _
+ @transient private var _nativeLAPACK: NetlibLAPACK = _
+
+ private[spark] def javaLAPACK: NetlibLAPACK = {
+ if (_javaLAPACK == null) {
+ _javaLAPACK = NetlibJavaLAPACK.getInstance
+ }
+ _javaLAPACK
+ }
+
+ private[spark] def nativeLAPACK: NetlibLAPACK = {
+ if (_nativeLAPACK == null) {
+ _nativeLAPACK =
+ try { NetlibNativeLAPACK.getInstance } catch { case _: Throwable => javaLAPACK }
+ }
+ _nativeLAPACK
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 57edc96511..e4f64b4e34 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, Has
import scala.language.implicitConversions
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{linalg => newlinalg}
@@ -427,7 +426,7 @@ class DenseMatrix @Since("1.3.0") (
if (isTransposed) {
Iterator.tabulate(numCols) { j =>
val col = new Array[Double](numRows)
- blas.dcopy(numRows, values, j, numCols, col, 0, 1)
+ BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1)
new DenseVector(col)
}
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
index 86632ae335..e070d605b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization
import java.{util => ju}
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.apache.spark.ml.linalg.BLAS
/**
* Object used to solve nonnegative least squares problems using a modified
@@ -75,10 +75,10 @@ private[spark] object NNLS {
// find the optimal unconstrained step
def steplen(dir: Array[Double], res: Array[Double]): Double = {
- val top = blas.ddot(n, dir, 1, res, 1)
- blas.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1)
+ val top = BLAS.nativeBLAS.ddot(n, dir, 1, res, 1)
+ BLAS.nativeBLAS.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1)
// Push the denominator upward very slightly to avoid infinities and silliness
- top / (blas.ddot(n, scratch, 1, dir, 1) + 1e-20)
+ top / (BLAS.nativeBLAS.ddot(n, scratch, 1, dir, 1) + 1e-20)
}
// stopping condition
@@ -103,9 +103,9 @@ private[spark] object NNLS {
var i = 0
while (iterno < iterMax) {
// find the residual
- blas.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1)
- blas.daxpy(n, -1.0, atb, 1, res, 1)
- blas.dcopy(n, res, 1, grad, 1)
+ BLAS.nativeBLAS.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1)
+ BLAS.nativeBLAS.daxpy(n, -1.0, atb, 1, res, 1)
+ BLAS.nativeBLAS.dcopy(n, res, 1, grad, 1)
// project the gradient
i = 0
@@ -115,28 +115,28 @@ private[spark] object NNLS {
}
i = i + 1
}
- val ngrad = blas.ddot(n, grad, 1, grad, 1)
+ val ngrad = BLAS.nativeBLAS.ddot(n, grad, 1, grad, 1)
- blas.dcopy(n, grad, 1, dir, 1)
+ BLAS.nativeBLAS.dcopy(n, grad, 1, dir, 1)
// use a CG direction under certain conditions
var step = steplen(grad, res)
var ndir = 0.0
- val nx = blas.ddot(n, x, 1, x, 1)
+ val nx = BLAS.nativeBLAS.ddot(n, x, 1, x, 1)
if (iterno > lastWall + 1) {
val alpha = ngrad / lastNorm
- blas.daxpy(n, alpha, lastDir, 1, dir, 1)
+ BLAS.nativeBLAS.daxpy(n, alpha, lastDir, 1, dir, 1)
val dstep = steplen(dir, res)
- ndir = blas.ddot(n, dir, 1, dir, 1)
+ ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1)
if (stop(dstep, ndir, nx)) {
// reject the CG step if it could lead to premature termination
- blas.dcopy(n, grad, 1, dir, 1)
- ndir = blas.ddot(n, dir, 1, dir, 1)
+ BLAS.nativeBLAS.dcopy(n, grad, 1, dir, 1)
+ ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1)
} else {
step = dstep
}
} else {
- ndir = blas.ddot(n, dir, 1, dir, 1)
+ ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1)
}
// terminate?
@@ -166,7 +166,7 @@ private[spark] object NNLS {
}
iterno = iterno + 1
- blas.dcopy(n, dir, 1, lastDir, 1)
+ BLAS.nativeBLAS.dcopy(n, dir, 1, lastDir, 1)
lastNorm = ngrad
}
x.clone
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index b1be5225ce..3276513213 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -21,7 +21,6 @@ import java.io.IOException
import java.lang.{Integer => JavaInteger}
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.google.common.collect.{Ordering => GuavaOrdering}
import org.apache.hadoop.fs.Path
import org.json4s._
@@ -85,7 +84,7 @@ class MatrixFactorizationModel @Since("0.8.0") (
val userVector = userFeatureSeq.head
val productVector = productFeatureSeq.head
- blas.ddot(rank, userVector, 1, productVector, 1)
+ BLAS.nativeBLAS.ddot(rank, userVector, 1, productVector, 1)
}
/**
@@ -136,7 +135,7 @@ class MatrixFactorizationModel @Since("0.8.0") (
}
users.join(productFeatures).map {
case (product, ((user, uFeatures), pFeatures)) =>
- Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+ Rating(user, product, BLAS.nativeBLAS.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
}
} else {
val products = productFeatures.join(usersProducts.map(_.swap)).map {
@@ -144,7 +143,7 @@ class MatrixFactorizationModel @Since("0.8.0") (
}
products.join(userFeatures).map {
case (user, ((product, pFeatures), uFeatures)) =>
- Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+ Rating(user, product, BLAS.nativeBLAS.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
}
}
}
@@ -263,7 +262,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
recommendableFeatures: RDD[(Int, Array[Double])],
num: Int): Array[(Int, Double)] = {
val scored = recommendableFeatures.map { case (id, features) =>
- (id, blas.ddot(features.length, recommendToFeatures, 1, features, 1))
+ (id, BLAS.nativeBLAS.ddot(features.length, recommendToFeatures, 1, features, 1))
}
scored.top(num)(Ordering.by(_._2))
}
@@ -320,7 +319,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
Iterator.range(0, m).flatMap { i =>
// scores = i-th vec in srcMat * dstMat
- BLAS.f2jBLAS.dgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+ BLAS.javaBLAS.dgemv("T", rank, n, 1.0F, dstMat, 0, rank,
srcMat, i * rank, 1, 0.0F, scores, 0, 1)
val srcId = srcIds(i)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
index f253963270..f0236f0528 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
@@ -17,10 +17,9 @@
package org.apache.spark.mllib.stat
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.ml.linalg.BLAS
import org.apache.spark.rdd.RDD
/**
@@ -99,10 +98,10 @@ class KernelDensity extends Serializable {
(x._1, x._2 + 1)
},
(x, y) => {
- blas.daxpy(n, 1.0, y._1, 1, x._1, 1)
+ BLAS.nativeBLAS.daxpy(n, 1.0, y._1, 1, x._1, 1)
(x._1, x._2 + y._2)
})
- blas.dscal(n, 1.0 / count, densities, 1)
+ BLAS.nativeBLAS.dscal(n, 1.0 / count, densities, 1)
densities
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index c5069277fa..1f879a4d9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree.model
import scala.collection.mutable
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -28,6 +27,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.BLAS
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo
@@ -280,7 +280,7 @@ private[tree] sealed class TreeEnsembleModel(
*/
private def predictBySumming(features: Vector): Double = {
val treePredictions = trees.map(_.predict(features))
- blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
+ BLAS.nativeBLAS.ddot(numTrees, treePredictions, 1, treeWeights, 1)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
index 9fffa508af..0f99cef665 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
@@ -19,10 +19,9 @@ package org.apache.spark.mllib.util
import scala.util.Random
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.BLAS
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -61,7 +60,8 @@ object SVMDataGenerator {
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
- val yD = blas.ddot(trueWeights.length, x, 1, trueWeights, 1) + rnd.nextGaussian() * 0.1
+ val yD = BLAS.nativeBLAS.ddot(trueWeights.length, x, 1, trueWeights, 1) +
+ rnd.nextGaussian() * 0.1
val y = if (yD < 0) 0.0 else 1.0
LabeledPoint(y, Vectors.dense(x))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index d8b9c6a606..d18a950a01 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -23,9 +23,8 @@ import breeze.linalg.{DenseVector => BDV}
import org.scalatest.Assertions._
import org.apache.spark.ml.classification.LinearSVCSuite._
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.optim.aggregator.HingeAggregator
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
@@ -176,28 +175,13 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept !== 0.0)
}
- test("sparse coefficients in HingeAggregator") {
- val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
- val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0))
- val agg = new HingeAggregator(bcFeaturesStd, true)(bcCoefficients)
- val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") {
- intercept[IllegalArgumentException] {
- agg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))
- }
- }
- assert(thrown.getMessage.contains("coefficients only supports dense"))
-
- bcCoefficients.destroy()
- bcFeaturesStd.destroy()
- }
-
test("linearSVC with sample weights") {
def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = {
- assert(m1.coefficients ~== m2.coefficients absTol 0.05)
+ assert(m1.coefficients ~== m2.coefficients relTol 0.05)
assert(m1.intercept ~== m2.intercept absTol 0.05)
}
- val estimator = new LinearSVC().setRegParam(0.01).setTol(0.01)
+ val estimator = new LinearSVC().setRegParam(0.01).setTol(0.001)
val dataset = smallBinaryDataset
MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals)
@@ -237,7 +221,7 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
val model1 = trainer1.fit(binaryDataset)
/*
- Use the following R code to load the data and train the model using glmnet package.
+ Use the following R code to load the data and train the model using e1071 package.
library(e1071)
data <- read.csv("path/target/tmp/LinearSVC/binaryDataset/part-00000", header=FALSE)
@@ -257,8 +241,8 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
*/
val coefficientsR = Vectors.dense(7.310338, 14.89741, 22.21005, 29.83508)
val interceptR = 7.440177
- assert(model1.intercept ~== interceptR relTol 1E-2)
- assert(model1.coefficients ~== coefficientsR relTol 1E-2)
+ assert(model1.intercept ~== interceptR relTol 1E-3)
+ assert(model1.coefficients ~== coefficientsR relTol 5E-3)
/*
Use the following python code to load the data and train the model using scikit-learn package.
@@ -280,8 +264,8 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
val coefficientsSK = Vectors.dense(7.24690165, 14.77029087, 21.99924004, 29.5575729)
val interceptSK = 7.36947518
- assert(model1.intercept ~== interceptSK relTol 1E-3)
- assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
+ assert(model1.intercept ~== interceptSK relTol 1E-2)
+ assert(model1.coefficients ~== coefficientsSK relTol 1E-2)
}
test("summary and training summary") {
@@ -379,8 +363,8 @@ object LinearSVCSuite {
}
def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = {
- assert(model1.intercept == model2.intercept)
- assert(model1.coefficients.equals(model2.coefficients))
+ assert(model1.intercept ~== model2.intercept relTol 1e-9)
+ assert(model1.coefficients ~== model2.coefficients relTol 1e-9)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala
new file mode 100644
index 0000000000..029911adb4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.{Instance, InstanceBlock}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.stat.Summarizer
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ @transient var instances: Array[Instance] = _
+ @transient var instancesConstantFeature: Array[Instance] = _
+ @transient var instancesConstantFeatureFiltered: Array[Instance] = _
+ @transient var scaledInstances: Array[Instance] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ instances = Array(
+ Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)),
+ Instance(0.0, 0.3, Vectors.dense(4.0, 0.5))
+ )
+ instancesConstantFeature = Array(
+ Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
+ Instance(1.0, 0.3, Vectors.dense(1.0, 0.5))
+ )
+ instancesConstantFeatureFiltered = Array(
+ Instance(0.0, 0.1, Vectors.dense(2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.0)),
+ Instance(1.0, 0.3, Vectors.dense(0.5))
+ )
+ scaledInstances = standardize(instances)
+ }
+
+
+ /** Get summary statistics for some data and create a new HingeBlockAggregator. */
+ private def getNewAggregator(
+ instances: Array[Instance],
+ coefficients: Vector,
+ fitIntercept: Boolean): HingeBlockAggregator = {
+ val (featuresSummarizer, _) =
+ Summarizer.getClassificationSummarizers(sc.parallelize(instances))
+ val featuresStd = featuresSummarizer.std.toArray
+ val featuresMean = featuresSummarizer.mean.toArray
+ val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0)
+ val scaledMean = inverseStd.zip(featuresMean).map(t => t._1 * t._2)
+ val bcInverseStd = sc.broadcast(inverseStd)
+ val bcScaledMean = sc.broadcast(scaledMean)
+ val bcCoefficients = sc.broadcast(coefficients)
+ new HingeBlockAggregator(bcInverseStd, bcScaledMean, fitIntercept)(bcCoefficients)
+ }
+
+ test("sparse coefficients") {
+ val bcInverseStd = sc.broadcast(Array(1.0))
+ val bcScaledMean = sc.broadcast(Array(2.0))
+ val bcCoefficients = sc.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
+ val binaryAgg = new HingeBlockAggregator(bcInverseStd, bcScaledMean,
+ fitIntercept = false)(bcCoefficients)
+ val block = InstanceBlock.fromInstances(Seq(Instance(1.0, 1.0, Vectors.dense(1.0))))
+ val thrownBinary = withClue("aggregator cannot handle sparse coefficients") {
+ intercept[IllegalArgumentException] {
+ binaryAgg.add(block)
+ }
+ }
+ assert(thrownBinary.getMessage.contains("coefficients only supports dense"))
+ }
+
+ test("aggregator add method input size") {
+ val coefArray = Array(1.0, 2.0)
+ val interceptValue = 4.0
+ val agg = getNewAggregator(instances, Vectors.dense(coefArray :+ interceptValue),
+ fitIntercept = true)
+ val block = InstanceBlock.fromInstances(Seq(Instance(1.0, 1.0, Vectors.dense(2.0))))
+ withClue("BinaryLogisticBlockAggregator features dimension must match coefficients dimension") {
+ intercept[IllegalArgumentException] {
+ agg.add(block)
+ }
+ }
+ }
+
+ test("negative weight") {
+ val coefArray = Array(1.0, 2.0)
+ val interceptValue = 4.0
+ val agg = getNewAggregator(instances, Vectors.dense(coefArray :+ interceptValue),
+ fitIntercept = true)
+ val block = InstanceBlock.fromInstances(Seq(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0))))
+ withClue("BinaryLogisticBlockAggregator does not support negative instance weights") {
+ intercept[IllegalArgumentException] {
+ agg.add(block)
+ }
+ }
+ }
+
+ test("check sizes") {
+ val rng = new scala.util.Random
+ val numFeatures = instances.head.features.size
+ val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble))
+ val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble))
+ val block = InstanceBlock.fromInstances(instances)
+
+ val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true)
+ aggIntercept.add(block)
+ assert(aggIntercept.gradient.size === numFeatures + 1)
+
+ val aggNoIntercept = getNewAggregator(instances, coefWithoutIntercept, fitIntercept = false)
+ aggNoIntercept.add(block)
+ assert(aggNoIntercept.gradient.size === numFeatures)
+ }
+
+ test("check correctness: fitIntercept = false") {
+ val coefVec = Vectors.dense(1.0, 2.0)
+ val numFeatures = instances.head.features.size
+ val (featuresSummarizer, _) =
+ Summarizer.getClassificationSummarizers(sc.parallelize(instances))
+ val featuresStd = featuresSummarizer.std
+ val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i)))
+ val weightSum = instances.map(_.weight).sum
+
+ // compute the loss and the gradients
+ var lossSum = 0.0
+ val gradientCoef = Array.ofDim[Double](numFeatures)
+ instances.foreach { case Instance(l, w, f) =>
+ val margin = BLAS.dot(stdCoefVec, f)
+ val labelScaled = 2 * l - 1.0
+ if (1.0 > labelScaled * margin) {
+ lossSum += (1.0 - labelScaled * margin) * w
+ gradientCoef.indices.foreach { i =>
+ gradientCoef(i) += f(i) * -(2 * l - 1.0) * w / featuresStd(i)
+ }
+ }
+ }
+ val loss = lossSum / weightSum
+ val gradient = Vectors.dense(gradientCoef.map(_ / weightSum))
+
+ Seq(1, 2, 4).foreach { blockSize =>
+ val blocks1 = scaledInstances
+ .grouped(blockSize)
+ .map(seq => InstanceBlock.fromInstances(seq))
+ .toArray
+ val blocks2 = blocks1.map { block =>
+ new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor)
+ }
+
+ Seq(blocks1, blocks2).foreach { blocks =>
+ val agg = getNewAggregator(instances, coefVec, fitIntercept = false)
+ blocks.foreach(agg.add)
+ assert(agg.loss ~== loss relTol 1e-9)
+ assert(agg.gradient ~== gradient relTol 1e-9)
+ }
+ }
+ }
+
+ test("check correctness: fitIntercept = true") {
+ val coefVec = Vectors.dense(1.0, 2.0)
+ val interceptValue = 1.0
+ val numFeatures = instances.head.features.size
+ val (featuresSummarizer, _) =
+ Summarizer.getClassificationSummarizers(sc.parallelize(instances))
+ val featuresStd = featuresSummarizer.std
+ val featuresMean = featuresSummarizer.mean
+ val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i)))
+ val weightSum = instances.map(_.weight).sum
+
+ // compute the loss and the gradients
+ var lossSum = 0.0
+ val gradientCoef = Array.ofDim[Double](numFeatures)
+ var gradientIntercept = 0.0
+ instances.foreach { case Instance(l, w, f) =>
+ val centered = f.toDense.copy
+ BLAS.axpy(-1.0, featuresMean, centered)
+ val margin = BLAS.dot(stdCoefVec, centered) + interceptValue
+ val labelScaled = 2 * l - 1.0
+ if (1.0 > labelScaled * margin) {
+ lossSum += (1.0 - labelScaled * margin) * w
+ gradientCoef.indices.foreach { i =>
+ gradientCoef(i) += (f(i) - featuresMean(i)) * -(2 * l - 1.0) * w / featuresStd(i)
+ }
+ gradientIntercept += -(2 * l - 1.0) * w
+ }
+ }
+ val loss = lossSum / weightSum
+ val gradient = Vectors.dense((gradientCoef :+ gradientIntercept).map(_ / weightSum))
+
+ Seq(1, 2, 4).foreach { blockSize =>
+ val blocks1 = scaledInstances
+ .grouped(blockSize)
+ .map(seq => InstanceBlock.fromInstances(seq))
+ .toArray
+ val blocks2 = blocks1.map { block =>
+ new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor)
+ }
+
+ Seq(blocks1, blocks2).foreach { blocks =>
+ val agg = getNewAggregator(instances, Vectors.dense(coefVec.toArray :+ interceptValue),
+ fitIntercept = true)
+ blocks.foreach(agg.add)
+ assert(agg.loss ~== loss relTol 1e-9)
+ assert(agg.gradient ~== gradient relTol 1e-9)
+ }
+ }
+ }
+
+ test("check with zero standard deviation") {
+ val coefArray = Array(1.0, 2.0)
+ val coefArrayFiltered = Array(2.0)
+ val interceptValue = 1.0
+
+ Seq(false, true).foreach { fitIntercept =>
+ val coefVec = if (fitIntercept) {
+ Vectors.dense(coefArray :+ interceptValue)
+ } else {
+ Vectors.dense(coefArray)
+ }
+ val aggConstantFeature = getNewAggregator(instancesConstantFeature,
+ coefVec, fitIntercept = fitIntercept)
+ aggConstantFeature
+ .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature)))
+ val grad = aggConstantFeature.gradient
+
+ val coefVecFiltered = if (fitIntercept) {
+ Vectors.dense(coefArrayFiltered :+ interceptValue)
+ } else {
+ Vectors.dense(coefArrayFiltered)
+ }
+ val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered,
+ coefVecFiltered, fitIntercept = fitIntercept)
+ aggConstantFeatureFiltered
+ .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered)))
+ val gradFiltered = aggConstantFeatureFiltered.gradient
+
+ // constant features should not affect gradient
+ assert(aggConstantFeature.loss ~== aggConstantFeatureFiltered.loss relTol 1e-9)
+ assert(grad(0) === 0)
+ assert(grad(1) ~== gradFiltered(0) relTol 1e-9)
+ if (fitIntercept) {
+ assert(grad.toArray.last ~== gradFiltered.toArray.last relTol 1e-9)
+ }
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 12ab2ac3cc..91d1e9a447 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.util.TestingUtils._
class BLASSuite extends SparkFunSuite {
test("nativeL1Threshold") {
- assert(getBLAS(128) == BLAS.f2jBLAS)
+ assert(getBLAS(128) == BLAS.javaBLAS)
assert(getBLAS(256) == BLAS.nativeBLAS)
assert(getBLAS(512) == BLAS.nativeBLAS)
}
diff --git a/pom.xml b/pom.xml
index 22d794ccde..9402fd4528 100644
--- a/pom.xml
+++ b/pom.xml
@@ -133,12 +133,12 @@
2.3
- 2.6.0
+ 2.8.0
10.14.2.0
1.12.0
1.6.7
- 9.4.39.v20210325
+ 9.4.40.v20210413
4.0.3
0.9.5
2.4.0
@@ -172,6 +172,7 @@
2.12.2
1.1.8.2
1.1.2
+ 1.3.2
1.15
1.20
2.8.0
@@ -2455,6 +2456,21 @@
commons-cli
${commons-cli.version}
+
+ dev.ludovic.netlib
+ blas
+ ${netlib.ludovic.dev.version}
+
+
+ dev.ludovic.netlib
+ lapack
+ ${netlib.ludovic.dev.version}
+
+
+ dev.ludovic.netlib
+ arpack
+ ${netlib.ludovic.dev.version}
+
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 54ac3c19fa..906065ca09 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -294,7 +294,7 @@ object SparkBuild extends PomBuild {
javaOptions ++= {
val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3)
var major = versionParts(0).toInt
- if (major >= 16) Seq("--add-modules=jdk.incubator.vector") else Seq.empty
+ if (major >= 16) Seq("--add-modules=jdk.incubator.vector,jdk.incubator.foreign", "-Dforeign.restricted=warn") else Seq.empty
},
(Compile / doc / javacOptions) ++= {
@@ -414,6 +414,10 @@ object SparkBuild extends PomBuild {
enable(YARN.settings)(yarn)
+ if (profiles.contains("sparkr")) {
+ enable(SparkR.settings)(core)
+ }
+
/**
* Adds the ability to run the spark shell directly from SBT without building an assembly
* jar.
@@ -888,6 +892,25 @@ object PySparkAssembly {
}
+object SparkR {
+ import scala.sys.process.Process
+
+ val buildRPackage = taskKey[Unit]("Build the R package")
+ lazy val settings = Seq(
+ buildRPackage := {
+ val command = baseDirectory.value / ".." / "R" / "install-dev.sh"
+ Process(command.toString).!!
+ },
+ (Compile / compile) := (Def.taskDyn {
+ val c = (Compile / compile).value
+ Def.task {
+ (Compile / buildRPackage).value
+ c
+ }
+ }).value
+ )
+}
+
object Unidoc {
import BuildCommons._
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 17994ed5e3..620760905a 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -571,9 +571,9 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
>>> model.getMaxBlockSizeInMB()
0.0
>>> model.coefficients
- DenseVector([0.0, -0.2792, -0.1833])
+ DenseVector([0.0, -1.0319, -0.5159])
>>> model.intercept
- 1.0206118982229047
+ 2.579645978780695
>>> model.numClasses
2
>>> model.numFeatures
@@ -582,12 +582,12 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
>>> model.predict(test0.head().features)
1.0
>>> model.predictRaw(test0.head().features)
- DenseVector([-1.4831, 1.4831])
+ DenseVector([-4.1274, 4.1274])
>>> result = model.transform(test0).head()
>>> result.newPrediction
1.0
>>> result.rawPrediction
- DenseVector([-1.4831, 1.4831])
+ DenseVector([-4.1274, 4.1274])
>>> svm_path = temp_path + "/svm"
>>> svm.save(svm_path)
>>> svm2 = LinearSVC.load(svm_path)
diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index fb245a3d05..1eadbd6942 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -71,7 +71,8 @@ def vector_to_array(col, dtype="float64"):
def array_to_vector(col):
"""
- Converts a column of array of numeric type into a column of dense vectors in MLlib
+    Converts a column of arrays of numeric type into a column of
+    :py:class:`pyspark.ml.linalg.DenseVector` instances.
.. versionadded:: 3.1.0
@@ -83,7 +84,7 @@ def array_to_vector(col):
Returns
-------
:py:class:`pyspark.sql.Column`
- The converted column of MLlib dense vectors.
+ The converted column of dense vectors.
Examples
--------
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 28c4499f77..5bc1801a0c 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -260,9 +260,9 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
>>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
- Row(user=0, item=2, newPrediction=0.692910...)
+ Row(user=0, item=2, newPrediction=0.69291...)
>>> predictions[1]
- Row(user=1, item=0, newPrediction=3.473569...)
+ Row(user=1, item=0, newPrediction=3.47356...)
>>> predictions[2]
Row(user=2, item=0, newPrediction=-0.899198...)
>>> user_recs = model.recommendForAllUsers(3)
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
index 7dafdcb3d6..5b31c871fb 100644
--- a/python/pyspark/ml/tests/test_training_summary.py
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -223,12 +223,12 @@ def test_linear_svc_summary(self):
self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
print(s.weightedTruePositiveRate)
- self.assertAlmostEqual(s.weightedTruePositiveRate, 0.5, 2)
- self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.5, 2)
- self.assertAlmostEqual(s.weightedRecall, 0.5, 2)
- self.assertAlmostEqual(s.weightedPrecision, 0.25, 2)
- self.assertAlmostEqual(s.weightedFMeasure(), 0.3333333333333333, 2)
- self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.3333333333333333, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+ self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
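
The new expected metrics encode that, after the LinearSVC change, the small training frame in this test is classified perfectly, so every weighted metric collapses to 1.0 and the false positive rate to 0.0. A minimal sketch of that arithmetic, with assumed counts of two positive and two negative rows:

```python
# Hypothetical counts for a perfectly classified 4-row training set.
tp, fp, fn, tn = 2, 0, 0, 2

precision = tp / (tp + fp)                          # 1.0
recall = tp / (tp + fn)                             # 1.0 (== true positive rate)
fpr = fp / (fp + tn)                                # 0.0
f1 = 2 * precision * recall / (precision + recall)  # 1.0
print(precision, recall, fpr, f1)
```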
diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py
index aeb603c5e7..a0eb243a6c 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -22,7 +22,6 @@
import numpy as np
import pandas as pd
-import pyspark
import pyspark.pandas as ps
from pyspark.pandas.exceptions import PandasNotImplementedError
@@ -32,10 +31,10 @@
MissingPandasLikeIndex,
MissingPandasLikeMultiIndex,
)
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils, SPARK_CONF_ARROW_ENABLED
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils, SPARK_CONF_ARROW_ENABLED
-class IndexesTest(ReusedSQLTestCase, TestUtils):
+class IndexesTest(PandasOnSparkTestCase, TestUtils):
@property
def pdf(self):
return pd.DataFrame(
@@ -280,12 +279,7 @@ def test_multi_index_names(self):
pidx.names = ["renamed_number", None]
kidx.names = ["renamed_number", None]
self.assertEqual(kidx.names, pidx.names)
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- # PySpark < 2.4 does not support struct type with arrow enabled.
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(kidx, pidx)
- else:
- self.assert_eq(kidx, pidx)
+ self.assert_eq(kidx, pidx)
with self.assertRaises(PandasNotImplementedError):
kidx.name
@@ -1401,11 +1395,7 @@ def test_asof(self):
self.assert_eq(kidx.asof("2014-01-01"), pidx.asof("2014-01-01"))
self.assert_eq(kidx.asof("2014-01-02"), pidx.asof("2014-01-02"))
- if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"):
- self.assert_eq(repr(kidx.asof("1999-01-02")), repr(pidx.asof("1999-01-02")))
- else:
- # FIXME: self.assert_eq(repr(kidx.asof("1999-01-02")), repr(pidx.asof("1999-01-02")))
- pass
+ self.assert_eq(repr(kidx.asof("1999-01-02")), repr(pidx.asof("1999-01-02")))
# Decreasing values
pidx = pd.Index(["2014-01-03", "2014-01-02", "2013-12-31"])
@@ -1427,11 +1417,7 @@ def test_asof(self):
self.assert_eq(kidx.asof("2014-01-01"), pd.Timestamp("2014-01-02 00:00:00"))
self.assert_eq(kidx.asof("2014-01-02"), pd.Timestamp("2014-01-02 00:00:00"))
self.assert_eq(kidx.asof("1999-01-02"), pd.Timestamp("2013-12-31 00:00:00"))
- if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"):
- self.assert_eq(repr(kidx.asof("2015-01-02")), repr(pd.NaT))
- else:
- # FIXME: self.assert_eq(repr(kidx.asof("2015-01-02")), repr(pd.NaT))
- pass
+ self.assert_eq(repr(kidx.asof("2015-01-02")), repr(pd.NaT))
# Not increasing, neither decreasing (ValueError)
kidx = ps.Index(["2013-12-31", "2015-01-02", "2014-01-03"])
@@ -2249,13 +2235,7 @@ def test_to_list(self):
kmidx = ps.from_pandas(pmidx)
self.assert_eq(kidx.tolist(), pidx.tolist())
-
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- # PySpark < 2.4 does not support struct type with arrow enabled.
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(kmidx.tolist(), pmidx.tolist())
- else:
- self.assert_eq(kidx.tolist(), pidx.tolist())
+ self.assert_eq(kmidx.tolist(), pmidx.tolist())
def test_index_ops(self):
pidx = pd.Index([1, 2, 3, 4, 5])
diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py
index 0fe5eeb209..360e31863d 100644
--- a/python/pyspark/pandas/tests/indexes/test_category.py
+++ b/python/pyspark/pandas/tests/indexes/test_category.py
@@ -21,10 +21,10 @@
from pandas.api.types import CategoricalDtype
import pyspark.pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class CategoricalIndexTest(ReusedSQLTestCase, TestUtils):
+class CategoricalIndexTest(PandasOnSparkTestCase, TestUtils):
def test_categorical_index(self):
pidx = pd.CategoricalIndex([1, 2, 3])
kidx = ps.CategoricalIndex([1, 2, 3])
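
Throughout these pandas-on-Spark test modules, the old `pyspark.pandas.testing.utils.ReusedSQLTestCase` base is replaced by `pyspark.testing.pandasutils.PandasOnSparkTestCase`. A minimal sketch of a test module after this refactor (class and column names are illustrative):

```python
import pandas as pd
import pyspark.pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils


class ExampleRoundTripTest(PandasOnSparkTestCase, TestUtils):
    def test_from_pandas_round_trip(self):
        pdf = pd.DataFrame({"a": [1, 2, 3]})
        kdf = ps.from_pandas(pdf)
        # assert_eq comes from the shared pandas-on-Spark test base.
        self.assert_eq(kdf, pdf)
```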
diff --git a/python/pyspark/pandas/tests/indexes/test_datetime.py b/python/pyspark/pandas/tests/indexes/test_datetime.py
index 407565b46d..af511ed6c2 100644
--- a/python/pyspark/pandas/tests/indexes/test_datetime.py
+++ b/python/pyspark/pandas/tests/indexes/test_datetime.py
@@ -22,10 +22,10 @@
import pandas as pd
import pyspark.pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class DatetimeIndexTest(ReusedSQLTestCase, TestUtils):
+class DatetimeIndexTest(PandasOnSparkTestCase, TestUtils):
@property
def fixed_freqs(self):
return [
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot.py b/python/pyspark/pandas/tests/plot/test_frame_plot.py
index 70822e1e32..b57acd4959 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot.py
@@ -22,10 +22,10 @@
from pyspark.pandas.config import set_option, reset_option, option_context
from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, HistogramPlotBase
from pyspark.pandas.exceptions import PandasNotImplementedError
-from pyspark.pandas.testing.utils import ReusedSQLTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-class DataFramePlotTest(ReusedSQLTestCase):
+class DataFramePlotTest(PandasOnSparkTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
index 6e8f3c0256..5de5c90856 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
@@ -25,7 +25,12 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import have_matplotlib, ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import (
+ have_matplotlib,
+ matplotlib_requirement_message,
+ PandasOnSparkTestCase,
+ TestUtils,
+)
if have_matplotlib:
import matplotlib
@@ -34,8 +39,8 @@
matplotlib.use("agg")
-@unittest.skipIf(not have_matplotlib, "matplotlib is not installed.")
-class DataFramePlotMatplotlibTest(ReusedSQLTestCase, TestUtils):
+@unittest.skipIf(not have_matplotlib, matplotlib_requirement_message)
+class DataFramePlotMatplotlibTest(PandasOnSparkTestCase, TestUtils):
sample_ratio_default = None
@classmethod
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
index dca5a4307b..33d6bef2c8 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
@@ -24,7 +24,12 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils, have_plotly
+from pyspark.testing.pandasutils import (
+ have_plotly,
+ plotly_requirement_message,
+ PandasOnSparkTestCase,
+ TestUtils,
+)
from pyspark.pandas.utils import name_like_string
if have_plotly:
@@ -34,10 +39,10 @@
@unittest.skipIf(
not have_plotly or LooseVersion(pd.__version__) < "1.0.0",
- "plotly is not installed or pandas<1.0. pandas<1.0 does not support latest plotly "
+ plotly_requirement_message + " Or pandas<1.0; pandas<1.0 does not support latest plotly "
"and/or 'plotting.backend' option.",
)
-class DataFramePlotPlotlyTest(ReusedSQLTestCase, TestUtils):
+class DataFramePlotPlotlyTest(PandasOnSparkTestCase, TestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot.py b/python/pyspark/pandas/tests/plot/test_series_plot.py
index 4292c960a2..fbfda88648 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot.py
@@ -22,7 +22,7 @@
from pyspark import pandas as ps
from pyspark.pandas.plot import PandasOnSparkPlotAccessor, BoxPlotBase
-from pyspark.pandas.testing.utils import have_plotly
+from pyspark.testing.pandasutils import have_plotly, plotly_requirement_message
class SeriesPlotTest(unittest.TestCase):
@@ -36,7 +36,7 @@ def pdf1(self):
def kdf1(self):
return ps.from_pandas(self.pdf1)
- @unittest.skipIf(not have_plotly, "plotly is not installed")
+ @unittest.skipIf(not have_plotly, plotly_requirement_message)
def test_plot_backends(self):
plot_backend = "plotly"
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
index 6bef5c9316..364a39bff8 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
@@ -25,7 +25,12 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import have_matplotlib, ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import (
+ have_matplotlib,
+ matplotlib_requirement_message,
+ PandasOnSparkTestCase,
+ TestUtils,
+)
if have_matplotlib:
import matplotlib
@@ -34,8 +39,8 @@
matplotlib.use("agg")
-@unittest.skipIf(not have_matplotlib, "matplotlib is not installed.")
-class SeriesPlotMatplotlibTest(ReusedSQLTestCase, TestUtils):
+@unittest.skipIf(not have_matplotlib, matplotlib_requirement_message)
+class SeriesPlotMatplotlibTest(PandasOnSparkTestCase, TestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
index 5c0f2f7e89..2a14d373d2 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
@@ -24,8 +24,13 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import have_plotly, ReusedSQLTestCase, TestUtils
from pyspark.pandas.utils import name_like_string
+from pyspark.testing.pandasutils import (
+ have_plotly,
+ plotly_requirement_message,
+ PandasOnSparkTestCase,
+ TestUtils,
+)
if have_plotly:
from plotly import express
@@ -34,10 +39,10 @@
@unittest.skipIf(
not have_plotly or LooseVersion(pd.__version__) < "1.0.0",
- "plotly is not installed or pandas<1.0. pandas<1.0 does not support latest plotly "
+ plotly_requirement_message + " Or pandas<1.0; pandas<1.0 does not support latest plotly "
"and/or 'plotting.backend' option.",
)
-class SeriesPlotPlotlyTest(ReusedSQLTestCase, TestUtils):
+class SeriesPlotPlotlyTest(PandasOnSparkTestCase, TestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py
index 90e37ddbf6..28de94bbcb 100644
--- a/python/pyspark/pandas/tests/test_categorical.py
+++ b/python/pyspark/pandas/tests/test_categorical.py
@@ -22,10 +22,10 @@
from pandas.api.types import CategoricalDtype
import pyspark.pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class CategoricalTest(ReusedSQLTestCase, TestUtils):
+class CategoricalTest(PandasOnSparkTestCase, TestUtils):
@property
def pdf(self):
return pd.DataFrame(
diff --git a/python/pyspark/pandas/tests/test_config.py b/python/pyspark/pandas/tests/test_config.py
index 1fb2cd344a..ba717a9712 100644
--- a/python/pyspark/pandas/tests/test_config.py
+++ b/python/pyspark/pandas/tests/test_config.py
@@ -18,10 +18,10 @@
from pyspark import pandas as ps
from pyspark.pandas import config
from pyspark.pandas.config import Option, DictWrapper
-from pyspark.pandas.testing.utils import ReusedSQLTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-class ConfigTest(ReusedSQLTestCase):
+class ConfigTest(PandasOnSparkTestCase):
def setUp(self):
config._options_dict["test.config"] = Option(key="test.config", doc="", default="default")
diff --git a/python/pyspark/pandas/tests/test_csv.py b/python/pyspark/pandas/tests/test_csv.py
index 7d32d819b5..17b3060c92 100644
--- a/python/pyspark/pandas/tests/test_csv.py
+++ b/python/pyspark/pandas/tests/test_csv.py
@@ -24,14 +24,14 @@
import numpy as np
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
def normalize_text(s):
return "\n".join(map(str.strip, s.strip().split("\n")))
-class CsvTest(ReusedSQLTestCase, TestUtils):
+class CsvTest(PandasOnSparkTestCase, TestUtils):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp(prefix=CsvTest.__name__)
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 6fa4933c25..d7cb3ab359 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -25,7 +25,6 @@
import numpy as np
import pandas as pd
from pandas.tseries.offsets import DateOffset
-import pyspark
from pyspark import StorageLevel
from pyspark.ml.linalg import SparseVector
from pyspark.sql import functions as F
@@ -41,16 +40,17 @@
extension_float_dtypes_available,
extension_object_dtypes_available,
)
-from pyspark.pandas.testing.utils import (
+from pyspark.testing.pandasutils import (
have_tabulate,
- ReusedSQLTestCase,
- SQLTestUtils,
+ PandasOnSparkTestCase,
SPARK_CONF_ARROW_ENABLED,
+ tabulate_requirement_message,
)
+from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.pandas.utils import name_like_string
-class DataFrameTest(ReusedSQLTestCase, SQLTestUtils):
+class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
@property
def pdf(self):
return pd.DataFrame(
@@ -565,11 +565,7 @@ def test_empty_dataframe(self):
pdf = pd.DataFrame({"a": pd.Series([], dtype="i1"), "b": pd.Series([], dtype="str")})
kdf = ps.from_pandas(pdf)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(kdf, pdf)
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(kdf, pdf)
+ self.assert_eq(kdf, pdf)
with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
kdf = ps.from_pandas(pdf)
@@ -601,11 +597,7 @@ def test_all_null_dataframe(self):
)
kdf = ps.from_pandas(pdf)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(kdf, pdf)
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(kdf, pdf)
+ self.assert_eq(kdf, pdf)
with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
kdf = ps.from_pandas(pdf)
@@ -2990,10 +2982,6 @@ def test_pivot_table_and_index(self):
self.assert_eq(ktable.index, ptable.index)
self.assert_eq(repr(ktable.index), repr(ptable.index))
- @unittest.skipIf(
- LooseVersion(pyspark.__version__) < LooseVersion("2.4"),
- "stack won't work properly with PySpark<2.4",
- )
def test_stack(self):
pdf_single_level_cols = pd.DataFrame(
[[0, 1], [2, 3]], index=["cat", "dog"], columns=["weight", "height"]
@@ -3235,22 +3223,13 @@ def _test_cumprod(self, pdf, kdf):
self.assert_eq(pdf.cumprod().sum(), kdf.cumprod().sum(), almost=True)
def test_cumprod(self):
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- pdf = pd.DataFrame(
- [[2.0, 1.0, 1], [5, None, 2], [1.0, -1.0, -3], [2.0, 0, 4], [4.0, 9.0, 5]],
- columns=list("ABC"),
- index=np.random.rand(5),
- )
- kdf = ps.from_pandas(pdf)
- self._test_cumprod(pdf, kdf)
- else:
- pdf = pd.DataFrame(
- [[2, 1, 1], [5, 1, 2], [1, -1, -3], [2, 0, 4], [4, 9, 5]],
- columns=list("ABC"),
- index=np.random.rand(5),
- )
- kdf = ps.from_pandas(pdf)
- self._test_cumprod(pdf, kdf)
+ pdf = pd.DataFrame(
+ [[2.0, 1.0, 1], [5, None, 2], [1.0, -1.0, -3], [2.0, 0, 4], [4.0, 9.0, 5]],
+ columns=list("ABC"),
+ index=np.random.rand(5),
+ )
+ kdf = ps.from_pandas(pdf)
+ self._test_cumprod(pdf, kdf)
def test_cumprod_multiindex_columns(self):
arrays = [np.array(["A", "A", "B", "B"]), np.array(["one", "two", "one", "two"])]
@@ -4725,13 +4704,8 @@ def test_udt(self):
sparse_vector = SparseVector(len(sparse_values), sparse_values)
pdf = pd.DataFrame({"a": [sparse_vector], "b": [10]})
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- kdf = ps.from_pandas(pdf)
- self.assert_eq(kdf, pdf)
- else:
- kdf = ps.from_pandas(pdf)
- self.assert_eq(kdf, pdf)
+ kdf = ps.from_pandas(pdf)
+ self.assert_eq(kdf, pdf)
def test_eval(self):
pdf = pd.DataFrame({"A": range(1, 6), "B": range(10, 0, -2)})
@@ -4767,7 +4741,7 @@ def test_eval(self):
kdf.columns = columns
self.assertRaises(ValueError, lambda: kdf.eval("x.a + y.b"))
- @unittest.skipIf(not have_tabulate, "tabulate not installed")
+ @unittest.skipIf(not have_tabulate, tabulate_requirement_message)
def test_to_markdown(self):
pdf = pd.DataFrame(data={"animal_1": ["elk", "pig"], "animal_2": ["dog", "quetzal"]})
kdf = ps.from_pandas(pdf)
@@ -5161,10 +5135,6 @@ def test_iteritems(self):
self.assert_eq(p_name, k_name)
self.assert_eq(p_items, k_items)
- @unittest.skipIf(
- LooseVersion(pyspark.__version__) < LooseVersion("3.0"),
- "tail won't work properly with PySpark<3.0",
- )
def test_tail(self):
pdf = pd.DataFrame({"x": range(1000)})
kdf = ps.from_pandas(pdf)
@@ -5184,10 +5154,6 @@ def test_tail(self):
with self.assertRaisesRegex(TypeError, "bad operand type for unary -: 'str'"):
kdf.tail("10")
- @unittest.skipIf(
- LooseVersion(pyspark.__version__) < LooseVersion("3.0"),
- "last_valid_index won't work properly with PySpark<3.0",
- )
def test_last_valid_index(self):
pdf = pd.DataFrame(
{"a": [1, 2, 3, None], "b": [1.0, 2.0, 3.0, None], "c": [100, 200, 400, None]},
diff --git a/python/pyspark/pandas/tests/test_dataframe_conversion.py b/python/pyspark/pandas/tests/test_dataframe_conversion.py
index 8b64398634..92ddef6014 100644
--- a/python/pyspark/pandas/tests/test_dataframe_conversion.py
+++ b/python/pyspark/pandas/tests/test_dataframe_conversion.py
@@ -24,12 +24,13 @@
import numpy as np
import pandas as pd
-from pyspark import pandas as pp
from distutils.version import LooseVersion
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils, TestUtils
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.sqlutils import SQLTestUtils
-class DataFrameConversionTest(ReusedSQLTestCase, SQLTestUtils, TestUtils):
+class DataFrameConversionTest(PandasOnSparkTestCase, SQLTestUtils, TestUtils):
"""Test cases for "small data" conversion and I/O."""
def setUp(self):
@@ -44,7 +45,7 @@ def pdf(self):
@property
def kdf(self):
- return pp.from_pandas(self.pdf)
+ return ps.from_pandas(self.pdf)
@staticmethod
def strip_all_whitespace(str):
@@ -113,7 +114,7 @@ def test_to_excel(self):
pdf = pd.DataFrame({"a": [1, None, 3], "b": ["one", "two", None]}, index=[0, 1, 3])
- kdf = pp.from_pandas(pdf)
+ kdf = ps.from_pandas(pdf)
kdf.to_excel(koalas_location, na_rep="null")
pdf.to_excel(pandas_location, na_rep="null")
@@ -122,7 +123,7 @@ def test_to_excel(self):
pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}, index=[0, 1, 3])
- kdf = pp.from_pandas(pdf)
+ kdf = ps.from_pandas(pdf)
kdf.to_excel(koalas_location, float_format="%.1f")
pdf.to_excel(pandas_location, float_format="%.1f")
@@ -141,12 +142,12 @@ def test_to_excel(self):
def test_to_json(self):
pdf = self.pdf
- kdf = pp.from_pandas(pdf)
+ kdf = ps.from_pandas(pdf)
self.assert_eq(kdf.to_json(orient="records"), pdf.to_json(orient="records"))
def test_to_json_negative(self):
- kdf = pp.from_pandas(self.pdf)
+ kdf = ps.from_pandas(self.pdf)
with self.assertRaises(NotImplementedError):
kdf.to_json(orient="table")
@@ -156,11 +157,11 @@ def test_to_json_negative(self):
def test_read_json_negative(self):
with self.assertRaises(NotImplementedError):
- pp.read_json("invalid", lines=False)
+ ps.read_json("invalid", lines=False)
def test_to_json_with_path(self):
pdf = pd.DataFrame({"a": [1], "b": ["a"]})
- kdf = pp.DataFrame(pdf)
+ kdf = ps.DataFrame(pdf)
kdf.to_json(self.tmp_dir, num_files=1)
expected = pdf.to_json(orient="records")
@@ -172,7 +173,7 @@ def test_to_json_with_path(self):
def test_to_json_with_partition_cols(self):
pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
- kdf = pp.DataFrame(pdf)
+ kdf = ps.DataFrame(pdf)
kdf.to_json(self.tmp_dir, partition_cols="b", num_files=1)
@@ -224,7 +225,7 @@ def test_to_records(self):
if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"):
pdf = pd.DataFrame({"A": [1, 2], "B": [0.5, 0.75]}, index=["a", "b"])
- kdf = pp.from_pandas(pdf)
+ kdf = ps.from_pandas(pdf)
self.assert_eq(kdf.to_records(), pdf.to_records())
self.assert_eq(kdf.to_records(index=False), pdf.to_records(index=False))
@@ -233,32 +234,32 @@ def test_to_records(self):
def test_from_records(self):
# Assert using a dict as input
self.assert_eq(
- pp.DataFrame.from_records({"A": [1, 2, 3]}), pd.DataFrame.from_records({"A": [1, 2, 3]})
+ ps.DataFrame.from_records({"A": [1, 2, 3]}), pd.DataFrame.from_records({"A": [1, 2, 3]})
)
# Assert using a list of tuples as input
self.assert_eq(
- pp.DataFrame.from_records([(1, 2), (3, 4)]), pd.DataFrame.from_records([(1, 2), (3, 4)])
+ ps.DataFrame.from_records([(1, 2), (3, 4)]), pd.DataFrame.from_records([(1, 2), (3, 4)])
)
# Assert using a NumPy array as input
- self.assert_eq(pp.DataFrame.from_records(np.eye(3)), pd.DataFrame.from_records(np.eye(3)))
+ self.assert_eq(ps.DataFrame.from_records(np.eye(3)), pd.DataFrame.from_records(np.eye(3)))
# Asserting using a custom index
self.assert_eq(
- pp.DataFrame.from_records([(1, 2), (3, 4)], index=[2, 3]),
+ ps.DataFrame.from_records([(1, 2), (3, 4)], index=[2, 3]),
pd.DataFrame.from_records([(1, 2), (3, 4)], index=[2, 3]),
)
# Assert excluding column(s)
self.assert_eq(
- pp.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, exclude=["B"]),
+ ps.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, exclude=["B"]),
pd.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, exclude=["B"]),
)
# Assert limiting to certain column(s)
self.assert_eq(
- pp.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, columns=["A"]),
+ ps.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, columns=["A"]),
pd.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, columns=["A"]),
)
# Assert limiting to a number of rows
self.assert_eq(
- pp.DataFrame.from_records([(1, 2), (3, 4)], nrows=1),
+ ps.DataFrame.from_records([(1, 2), (3, 4)], nrows=1),
pd.DataFrame.from_records([(1, 2), (3, 4)], nrows=1),
)
diff --git a/python/pyspark/pandas/tests/test_dataframe_spark_io.py b/python/pyspark/pandas/tests/test_dataframe_spark_io.py
index 818ce61dd8..f0982bd4e2 100644
--- a/python/pyspark/pandas/tests/test_dataframe_spark_io.py
+++ b/python/pyspark/pandas/tests/test_dataframe_spark_io.py
@@ -23,13 +23,12 @@
import numpy as np
import pandas as pd
import pyarrow as pa
-import pyspark
-from pyspark import pandas as pp
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class DataFrameSparkIOTest(ReusedSQLTestCase, TestUtils):
+class DataFrameSparkIOTest(PandasOnSparkTestCase, TestUtils):
"""Test cases for big data I/O using Spark."""
@property
@@ -60,7 +59,7 @@ def test_parquet_read(self):
def check(columns, expected):
if LooseVersion("0.21.1") <= LooseVersion(pd.__version__):
expected = pd.read_parquet(tmp, columns=columns)
- actual = pp.read_parquet(tmp, columns=columns)
+ actual = ps.read_parquet(tmp, columns=columns)
self.assertPandasEqual(expected, actual.to_pandas())
check(None, data)
@@ -82,24 +81,20 @@ def check(columns, expected):
expected = pd.read_parquet(tmp)
else:
expected = data
- actual = pp.read_parquet(tmp)
+ actual = ps.read_parquet(tmp)
self.assertPandasEqual(expected, actual.to_pandas())
# When index columns are known
pdf = self.test_pdf
- expected = pp.DataFrame(pdf)
+ expected = ps.DataFrame(pdf)
expected_idx = expected.set_index("bhello")[["f", "i32", "i64"]]
- actual_idx = pp.read_parquet(tmp, index_col="bhello")[["f", "i32", "i64"]]
+ actual_idx = ps.read_parquet(tmp, index_col="bhello")[["f", "i32", "i64"]]
self.assert_eq(
actual_idx.sort_values(by="f").to_spark().toPandas(),
expected_idx.sort_values(by="f").to_spark().toPandas(),
)
- @unittest.skipIf(
- LooseVersion(pyspark.__version__) < LooseVersion("3.0.0"),
- "The test only works with Spark>=3.0",
- )
def test_parquet_read_with_pandas_metadata(self):
with self.temp_dir() as tmp:
expected1 = self.test_pdf
@@ -107,32 +102,32 @@ def test_parquet_read_with_pandas_metadata(self):
path1 = "{}/file1.parquet".format(tmp)
expected1.to_parquet(path1)
- self.assert_eq(pp.read_parquet(path1, pandas_metadata=True), expected1)
+ self.assert_eq(ps.read_parquet(path1, pandas_metadata=True), expected1)
expected2 = expected1.reset_index()
path2 = "{}/file2.parquet".format(tmp)
expected2.to_parquet(path2)
- self.assert_eq(pp.read_parquet(path2, pandas_metadata=True), expected2)
+ self.assert_eq(ps.read_parquet(path2, pandas_metadata=True), expected2)
expected3 = expected2.set_index("index", append=True)
path3 = "{}/file3.parquet".format(tmp)
expected3.to_parquet(path3)
- self.assert_eq(pp.read_parquet(path3, pandas_metadata=True), expected3)
+ self.assert_eq(ps.read_parquet(path3, pandas_metadata=True), expected3)
def test_parquet_write(self):
with self.temp_dir() as tmp:
pdf = self.test_pdf
- expected = pp.DataFrame(pdf)
+ expected = ps.DataFrame(pdf)
# Write out partitioned by one column
expected.to_parquet(tmp, mode="overwrite", partition_cols="i32")
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_parquet(tmp)
+ actual = ps.read_parquet(tmp)
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -144,7 +139,7 @@ def test_parquet_write(self):
expected.to_parquet(tmp, mode="overwrite", partition_cols=["i32", "bhello"])
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_parquet(tmp)
+ actual = ps.read_parquet(tmp)
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -155,13 +150,13 @@ def test_parquet_write(self):
def test_table(self):
with self.table("test_table"):
pdf = self.test_pdf
- expected = pp.DataFrame(pdf)
+ expected = ps.DataFrame(pdf)
# Write out partitioned by one column
expected.spark.to_table("test_table", mode="overwrite", partition_cols="i32")
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_table("test_table")
+ actual = ps.read_table("test_table")
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -173,7 +168,7 @@ def test_table(self):
expected.to_table("test_table", mode="overwrite", partition_cols=["i32", "bhello"])
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_table("test_table")
+ actual = ps.read_table("test_table")
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -183,21 +178,21 @@ def test_table(self):
# When index columns are known
expected_idx = expected.set_index("bhello")[["f", "i32", "i64"]]
- actual_idx = pp.read_table("test_table", index_col="bhello")[["f", "i32", "i64"]]
+ actual_idx = ps.read_table("test_table", index_col="bhello")[["f", "i32", "i64"]]
self.assert_eq(
actual_idx.sort_values(by="f").to_spark().toPandas(),
expected_idx.sort_values(by="f").to_spark().toPandas(),
)
expected_idx = expected.set_index(["bhello"])[["f", "i32", "i64"]]
- actual_idx = pp.read_table("test_table", index_col=["bhello"])[["f", "i32", "i64"]]
+ actual_idx = ps.read_table("test_table", index_col=["bhello"])[["f", "i32", "i64"]]
self.assert_eq(
actual_idx.sort_values(by="f").to_spark().toPandas(),
expected_idx.sort_values(by="f").to_spark().toPandas(),
)
expected_idx = expected.set_index(["i32", "bhello"])[["f", "i64"]]
- actual_idx = pp.read_table("test_table", index_col=["i32", "bhello"])[["f", "i64"]]
+ actual_idx = ps.read_table("test_table", index_col=["i32", "bhello"])[["f", "i64"]]
self.assert_eq(
actual_idx.sort_values(by="f").to_spark().toPandas(),
expected_idx.sort_values(by="f").to_spark().toPandas(),
@@ -206,13 +201,13 @@ def test_table(self):
def test_spark_io(self):
with self.temp_dir() as tmp:
pdf = self.test_pdf
- expected = pp.DataFrame(pdf)
+ expected = ps.DataFrame(pdf)
# Write out partitioned by one column
expected.to_spark_io(tmp, format="json", mode="overwrite", partition_cols="i32")
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_spark_io(tmp, format="json")
+ actual = ps.read_spark_io(tmp, format="json")
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -226,7 +221,7 @@ def test_spark_io(self):
)
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_spark_io(path=tmp, format="json")
+ actual = ps.read_spark_io(path=tmp, format="json")
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -236,11 +231,11 @@ def test_spark_io(self):
# When index columns are known
pdf = self.test_pdf
- expected = pp.DataFrame(pdf)
+ expected = ps.DataFrame(pdf)
col_order = ["f", "i32", "i64"]
expected_idx = expected.set_index("bhello")[col_order]
- actual_idx = pp.read_spark_io(tmp, format="json", index_col="bhello")[col_order]
+ actual_idx = ps.read_spark_io(tmp, format="json", index_col="bhello")[col_order]
self.assert_eq(
actual_idx.sort_values(by="f").to_spark().toPandas(),
expected_idx.sort_values(by="f").to_spark().toPandas(),
@@ -253,45 +248,42 @@ def test_read_excel(self):
path1 = "{}/file1.xlsx".format(tmp)
self.test_pdf[["i32"]].to_excel(path1)
- self.assert_eq(pp.read_excel(open(path1, "rb")), pd.read_excel(open(path1, "rb")))
+ self.assert_eq(ps.read_excel(open(path1, "rb")), pd.read_excel(open(path1, "rb")))
self.assert_eq(
- pp.read_excel(open(path1, "rb"), index_col=0),
+ ps.read_excel(open(path1, "rb"), index_col=0),
pd.read_excel(open(path1, "rb"), index_col=0),
)
self.assert_eq(
- pp.read_excel(open(path1, "rb"), index_col=0, squeeze=True),
+ ps.read_excel(open(path1, "rb"), index_col=0, squeeze=True),
pd.read_excel(open(path1, "rb"), index_col=0, squeeze=True),
)
- if LooseVersion(pyspark.__version__) >= LooseVersion("3.0.0"):
- self.assert_eq(pp.read_excel(path1), pd.read_excel(path1))
- self.assert_eq(pp.read_excel(path1, index_col=0), pd.read_excel(path1, index_col=0))
- self.assert_eq(
- pp.read_excel(path1, index_col=0, squeeze=True),
- pd.read_excel(path1, index_col=0, squeeze=True),
- )
+ self.assert_eq(ps.read_excel(path1), pd.read_excel(path1))
+ self.assert_eq(ps.read_excel(path1, index_col=0), pd.read_excel(path1, index_col=0))
+ self.assert_eq(
+ ps.read_excel(path1, index_col=0, squeeze=True),
+ pd.read_excel(path1, index_col=0, squeeze=True),
+ )
- self.assert_eq(pp.read_excel(tmp), pd.read_excel(path1))
+ self.assert_eq(ps.read_excel(tmp), pd.read_excel(path1))
- path2 = "{}/file2.xlsx".format(tmp)
- self.test_pdf[["i32"]].to_excel(path2)
- self.assert_eq(
- pp.read_excel(tmp, index_col=0).sort_index(),
- pd.concat(
- [pd.read_excel(path1, index_col=0), pd.read_excel(path2, index_col=0)]
- ).sort_index(),
- )
- self.assert_eq(
- pp.read_excel(tmp, index_col=0, squeeze=True).sort_index(),
- pd.concat(
- [
- pd.read_excel(path1, index_col=0, squeeze=True),
- pd.read_excel(path2, index_col=0, squeeze=True),
- ]
- ).sort_index(),
- )
- else:
- self.assertRaises(ValueError, lambda: pp.read_excel(tmp))
+ path2 = "{}/file2.xlsx".format(tmp)
+ self.test_pdf[["i32"]].to_excel(path2)
+ self.assert_eq(
+ ps.read_excel(tmp, index_col=0).sort_index(),
+ pd.concat(
+ [pd.read_excel(path1, index_col=0), pd.read_excel(path2, index_col=0)]
+ ).sort_index(),
+ )
+ self.assert_eq(
+ ps.read_excel(tmp, index_col=0, squeeze=True).sort_index(),
+ pd.concat(
+ [
+ pd.read_excel(path1, index_col=0, squeeze=True),
+ pd.read_excel(path2, index_col=0, squeeze=True),
+ ]
+ ).sort_index(),
+ )
with self.temp_dir() as tmp:
path1 = "{}/file1.xlsx".format(tmp)
@@ -307,79 +299,76 @@ def test_read_excel(self):
)
for sheet_name in sheet_names:
- kdfs = pp.read_excel(open(path1, "rb"), sheet_name=sheet_name, index_col=0)
+ kdfs = ps.read_excel(open(path1, "rb"), sheet_name=sheet_name, index_col=0)
self.assert_eq(kdfs["Sheet_name_1"], pdfs1["Sheet_name_1"])
self.assert_eq(kdfs["Sheet_name_2"], pdfs1["Sheet_name_2"])
- kdfs = pp.read_excel(
+ kdfs = ps.read_excel(
open(path1, "rb"), sheet_name=sheet_name, index_col=0, squeeze=True
)
self.assert_eq(kdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"])
self.assert_eq(kdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
- if LooseVersion(pyspark.__version__) >= LooseVersion("3.0.0"):
- self.assert_eq(
- pp.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"),
- pdfs1["Sheet_name_2"],
- )
+ self.assert_eq(
+ ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"),
+ pdfs1["Sheet_name_2"],
+ )
- for sheet_name in sheet_names:
- kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0)
- self.assert_eq(kdfs["Sheet_name_1"], pdfs1["Sheet_name_1"])
- self.assert_eq(kdfs["Sheet_name_2"], pdfs1["Sheet_name_2"])
+ for sheet_name in sheet_names:
+ kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0)
+ self.assert_eq(kdfs["Sheet_name_1"], pdfs1["Sheet_name_1"])
+ self.assert_eq(kdfs["Sheet_name_2"], pdfs1["Sheet_name_2"])
- kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True)
- self.assert_eq(kdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"])
- self.assert_eq(kdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
+ kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True)
+ self.assert_eq(kdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"])
+ self.assert_eq(kdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
- path2 = "{}/file2.xlsx".format(tmp)
- with pd.ExcelWriter(path2) as writer:
- self.test_pdf.to_excel(writer, sheet_name="Sheet_name_1")
- self.test_pdf[["i32"]].to_excel(writer, sheet_name="Sheet_name_2")
+ path2 = "{}/file2.xlsx".format(tmp)
+ with pd.ExcelWriter(path2) as writer:
+ self.test_pdf.to_excel(writer, sheet_name="Sheet_name_1")
+ self.test_pdf[["i32"]].to_excel(writer, sheet_name="Sheet_name_2")
- pdfs2 = pd.read_excel(path2, sheet_name=None, index_col=0)
- pdfs2_squeezed = pd.read_excel(path2, sheet_name=None, index_col=0, squeeze=True)
+ pdfs2 = pd.read_excel(path2, sheet_name=None, index_col=0)
+ pdfs2_squeezed = pd.read_excel(path2, sheet_name=None, index_col=0, squeeze=True)
+ self.assert_eq(
+ ps.read_excel(tmp, sheet_name="Sheet_name_2", index_col=0).sort_index(),
+ pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(),
+ )
+ self.assert_eq(
+ ps.read_excel(
+ tmp, sheet_name="Sheet_name_2", index_col=0, squeeze=True
+ ).sort_index(),
+ pd.concat(
+ [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]]
+ ).sort_index(),
+ )
+
+ for sheet_name in sheet_names:
+ kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0)
self.assert_eq(
- pp.read_excel(tmp, sheet_name="Sheet_name_2", index_col=0).sort_index(),
+ kdfs["Sheet_name_1"].sort_index(),
+ pd.concat([pdfs1["Sheet_name_1"], pdfs2["Sheet_name_1"]]).sort_index(),
+ )
+ self.assert_eq(
+ kdfs["Sheet_name_2"].sort_index(),
pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(),
)
+
+ kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True)
self.assert_eq(
- pp.read_excel(
- tmp, sheet_name="Sheet_name_2", index_col=0, squeeze=True
+ kdfs["Sheet_name_1"].sort_index(),
+ pd.concat(
+ [pdfs1_squeezed["Sheet_name_1"], pdfs2_squeezed["Sheet_name_1"]]
).sort_index(),
+ )
+ self.assert_eq(
+ kdfs["Sheet_name_2"].sort_index(),
pd.concat(
[pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]]
).sort_index(),
)
- for sheet_name in sheet_names:
- kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0)
- self.assert_eq(
- kdfs["Sheet_name_1"].sort_index(),
- pd.concat([pdfs1["Sheet_name_1"], pdfs2["Sheet_name_1"]]).sort_index(),
- )
- self.assert_eq(
- kdfs["Sheet_name_2"].sort_index(),
- pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(),
- )
-
- kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True)
- self.assert_eq(
- kdfs["Sheet_name_1"].sort_index(),
- pd.concat(
- [pdfs1_squeezed["Sheet_name_1"], pdfs2_squeezed["Sheet_name_1"]]
- ).sort_index(),
- )
- self.assert_eq(
- kdfs["Sheet_name_2"].sort_index(),
- pd.concat(
- [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]]
- ).sort_index(),
- )
- else:
- self.assertRaises(ValueError, lambda: pp.read_excel(tmp))
-
def test_read_orc(self):
with self.temp_dir() as tmp:
path = "{}/file1.orc".format(tmp)
@@ -393,50 +382,50 @@ def test_read_orc(self):
orc_file_path = glob.glob(os.path.join(path, "*.orc"))[0]
expected = data.reset_index()[data.columns]
- actual = pp.read_orc(path)
+ actual = ps.read_orc(path)
self.assertPandasEqual(expected, actual.to_pandas())
# columns
columns = ["i32", "i64"]
expected = data.reset_index()[columns]
- actual = pp.read_orc(path, columns=columns)
+ actual = ps.read_orc(path, columns=columns)
self.assertPandasEqual(expected, actual.to_pandas())
# index_col
expected = data.set_index("i32")
- actual = pp.read_orc(path, index_col="i32")
+ actual = ps.read_orc(path, index_col="i32")
self.assert_eq(actual, expected)
expected = data.set_index(["i32", "f"])
- actual = pp.read_orc(path, index_col=["i32", "f"])
+ actual = ps.read_orc(path, index_col=["i32", "f"])
self.assert_eq(actual, expected)
# index_col with columns
expected = data.set_index("i32")[["i64", "bhello"]]
- actual = pp.read_orc(path, index_col=["i32"], columns=["i64", "bhello"])
+ actual = ps.read_orc(path, index_col=["i32"], columns=["i64", "bhello"])
self.assert_eq(actual, expected)
expected = data.set_index(["i32", "f"])[["bhello", "i64"]]
- actual = pp.read_orc(path, index_col=["i32", "f"], columns=["bhello", "i64"])
+ actual = ps.read_orc(path, index_col=["i32", "f"], columns=["bhello", "i64"])
self.assert_eq(actual, expected)
msg = "Unknown column name 'i'"
with self.assertRaises(ValueError, msg=msg):
- pp.read_orc(path, columns="i32")
+ ps.read_orc(path, columns="i32")
msg = "Unknown column name 'i34'"
with self.assertRaises(ValueError, msg=msg):
- pp.read_orc(path, columns=["i34", "i64"])
+ ps.read_orc(path, columns=["i34", "i64"])
def test_orc_write(self):
with self.temp_dir() as tmp:
pdf = self.test_pdf
- expected = pp.DataFrame(pdf)
+ expected = ps.DataFrame(pdf)
# Write out partitioned by one column
expected.to_orc(tmp, mode="overwrite", partition_cols="i32")
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_orc(tmp)
+ actual = ps.read_orc(tmp)
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
@@ -448,7 +437,7 @@ def test_orc_write(self):
expected.to_orc(tmp, mode="overwrite", partition_cols=["i32", "bhello"])
# Reset column order, as once the data is written out, Spark rearranges partition
# columns to appear first.
- actual = pp.read_orc(tmp)
+ actual = ps.read_orc(tmp)
self.assertFalse((actual.columns == self.test_column_order).all())
actual = actual[self.test_column_order]
self.assert_eq(
diff --git a/python/pyspark/pandas/tests/test_default_index.py b/python/pyspark/pandas/tests/test_default_index.py
index 4075de4f11..838e04a9eb 100644
--- a/python/pyspark/pandas/tests/test_default_index.py
+++ b/python/pyspark/pandas/tests/test_default_index.py
@@ -18,10 +18,10 @@
import pandas as pd
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-class DefaultIndexTest(ReusedSQLTestCase):
+class DefaultIndexTest(PandasOnSparkTestCase):
def test_default_index_sequence(self):
with ps.option_context("compute.default_index_type", "sequence"):
sdf = self.spark.range(1000)
diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py
index c341892c51..7198a1d5d0 100644
--- a/python/pyspark/pandas/tests/test_expanding.py
+++ b/python/pyspark/pandas/tests/test_expanding.py
@@ -21,11 +21,11 @@
import pandas as pd
import pyspark.pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
from pyspark.pandas.window import Expanding
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class ExpandingTest(ReusedSQLTestCase, TestUtils):
+class ExpandingTest(PandasOnSparkTestCase, TestUtils):
def _test_expanding_func(self, f):
pser = pd.Series([1, 2, 3], index=np.random.rand(3))
kser = ps.from_pandas(pser)
diff --git a/python/pyspark/pandas/tests/test_extension.py b/python/pyspark/pandas/tests/test_extension.py
index 9fb61d02cc..17dc2bcd8b 100644
--- a/python/pyspark/pandas/tests/test_extension.py
+++ b/python/pyspark/pandas/tests/test_extension.py
@@ -21,7 +21,7 @@
import pandas as pd
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import assert_produces_warning, ReusedSQLTestCase
+from pyspark.testing.pandasutils import assert_produces_warning, PandasOnSparkTestCase
from pyspark.pandas.extensions import (
register_dataframe_accessor,
register_series_accessor,
@@ -66,7 +66,7 @@ def check_length(self, col=None):
raise ValueError(str(e))
-class ExtensionTest(ReusedSQLTestCase):
+class ExtensionTest(PandasOnSparkTestCase):
@property
def pdf(self):
return pd.DataFrame(
diff --git a/python/pyspark/pandas/tests/test_frame_spark.py b/python/pyspark/pandas/tests/test_frame_spark.py
index 3dca25f6ab..6a226a740f 100644
--- a/python/pyspark/pandas/tests/test_frame_spark.py
+++ b/python/pyspark/pandas/tests/test_frame_spark.py
@@ -15,22 +15,21 @@
# limitations under the License.
#
-from distutils.version import LooseVersion
import os
import pandas as pd
-import pyspark
-from pyspark import pandas as pp
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils, TestUtils
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.sqlutils import SQLTestUtils
-class SparkFrameMethodsTest(ReusedSQLTestCase, SQLTestUtils, TestUtils):
+class SparkFrameMethodsTest(PandasOnSparkTestCase, SQLTestUtils, TestUtils):
def test_frame_apply_negative(self):
with self.assertRaisesRegex(
ValueError, "The output of the function.* pyspark.sql.DataFrame.*int"
):
- pp.range(10).spark.apply(lambda scol: 1)
+ ps.range(10).spark.apply(lambda scol: 1)
def test_hint(self):
pdf1 = pd.DataFrame(
@@ -39,13 +38,10 @@ def test_hint(self):
pdf2 = pd.DataFrame(
{"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]}
).set_index("rkey")
- kdf1 = pp.from_pandas(pdf1)
- kdf2 = pp.from_pandas(pdf2)
+ kdf1 = ps.from_pandas(pdf1)
+ kdf2 = ps.from_pandas(pdf2)
- if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"):
- hints = ["broadcast", "merge", "shuffle_hash", "shuffle_replicate_nl"]
- else:
- hints = ["broadcast"]
+ hints = ["broadcast", "merge", "shuffle_hash", "shuffle_replicate_nl"]
for hint in hints:
self.assert_eq(
@@ -68,7 +64,7 @@ def test_hint(self):
)
def test_repartition(self):
- kdf = pp.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]})
+ kdf = ps.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]})
num_partitions = kdf.to_spark().rdd.getNumPartitions() + 1
num_partitions += 1
@@ -91,7 +87,7 @@ def test_repartition(self):
self.assert_eq(kdf2.sort_index(), (kdf + 1).spark.repartition(num_partitions).sort_index())
# Reserves MultiIndex
- kdf = pp.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]])
+ kdf = ps.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]])
num_partitions = kdf.to_spark().rdd.getNumPartitions() + 1
new_kdf = kdf.spark.repartition(num_partitions)
self.assertEqual(new_kdf.to_spark().rdd.getNumPartitions(), num_partitions)
@@ -99,7 +95,7 @@ def test_repartition(self):
def test_coalesce(self):
num_partitions = 10
- kdf = pp.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]})
+ kdf = ps.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]})
kdf = kdf.spark.repartition(num_partitions)
num_partitions -= 1
@@ -122,7 +118,7 @@ def test_coalesce(self):
self.assert_eq(kdf2.sort_index(), (kdf + 1).spark.coalesce(num_partitions).sort_index())
# Reserves MultiIndex
- kdf = pp.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]])
+ kdf = ps.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]])
num_partitions -= 1
kdf = kdf.spark.repartition(num_partitions)
@@ -134,13 +130,13 @@ def test_coalesce(self):
def test_checkpoint(self):
with self.temp_dir() as tmp:
self.spark.sparkContext.setCheckpointDir(tmp)
- kdf = pp.DataFrame({"a": ["a", "b", "c"]})
+ kdf = ps.DataFrame({"a": ["a", "b", "c"]})
new_kdf = kdf.spark.checkpoint()
self.assertIsNotNone(os.listdir(tmp))
self.assert_eq(kdf, new_kdf)
def test_local_checkpoint(self):
- kdf = pp.DataFrame({"a": ["a", "b", "c"]})
+ kdf = ps.DataFrame({"a": ["a", "b", "c"]})
new_kdf = kdf.spark.local_checkpoint()
self.assert_eq(kdf, new_kdf)
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index ca1f9e9f72..a6d006fad9 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -30,11 +30,11 @@
MissingPandasLikeDataFrameGroupBy,
MissingPandasLikeSeriesGroupBy,
)
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
from pyspark.pandas.groupby import is_multi_agg_with_relabel
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class GroupByTest(ReusedSQLTestCase, TestUtils):
+class GroupByTest(PandasOnSparkTestCase, TestUtils):
def test_groupby_simple(self):
pdf = pd.DataFrame(
{
diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py
index 8298767f67..0d02c46b0e 100644
--- a/python/pyspark/pandas/tests/test_indexing.py
+++ b/python/pyspark/pandas/tests/test_indexing.py
@@ -24,7 +24,7 @@
from pyspark import pandas as ps
from pyspark.pandas.exceptions import SparkPandasIndexingError
-from pyspark.pandas.testing.utils import ComparisonTestBase, ReusedSQLTestCase, compare_both
+from pyspark.testing.pandasutils import ComparisonTestBase, PandasOnSparkTestCase, compare_both
class BasicIndexingTest(ComparisonTestBase):
@@ -153,7 +153,7 @@ def test_limitations(self):
)
-class IndexingTest(ReusedSQLTestCase):
+class IndexingTest(PandasOnSparkTestCase):
@property
def pdf(self):
return pd.DataFrame(
diff --git a/python/pyspark/pandas/tests/test_indexops_spark.py b/python/pyspark/pandas/tests/test_indexops_spark.py
index ae659ac17f..831b764271 100644
--- a/python/pyspark/pandas/tests/test_indexops_spark.py
+++ b/python/pyspark/pandas/tests/test_indexops_spark.py
@@ -20,10 +20,11 @@
from pyspark.sql import functions as F
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class SparkIndexOpsMethodsTest(ReusedSQLTestCase, SQLTestUtils):
+class SparkIndexOpsMethodsTest(PandasOnSparkTestCase, SQLTestUtils):
@property
def pser(self):
return pd.Series([1, 2, 3, 4, 5, 6, 7], name="x")
diff --git a/python/pyspark/pandas/tests/test_internal.py b/python/pyspark/pandas/tests/test_internal.py
index f93b24bbe1..f9e96cd995 100644
--- a/python/pyspark/pandas/tests/test_internal.py
+++ b/python/pyspark/pandas/tests/test_internal.py
@@ -22,10 +22,11 @@
SPARK_DEFAULT_INDEX_NAME,
SPARK_INDEX_NAME_FORMAT,
)
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class InternalFrameTest(ReusedSQLTestCase, SQLTestUtils):
+class InternalFrameTest(PandasOnSparkTestCase, SQLTestUtils):
def test_from_pandas(self):
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
diff --git a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py
index 9172f045c2..e8787397e1 100644
--- a/python/pyspark/pandas/tests/test_namespace.py
+++ b/python/pyspark/pandas/tests/test_namespace.py
@@ -20,11 +20,12 @@
import pandas as pd
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
from pyspark.pandas.namespace import _get_index_map
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class NamespaceTest(ReusedSQLTestCase, SQLTestUtils):
+class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
def test_from_pandas(self):
pdf = pd.DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]})
kdf = ps.from_pandas(pdf)
diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py
index e278739c31..ce2bbe1702 100644
--- a/python/pyspark/pandas/tests/test_numpy_compat.py
+++ b/python/pyspark/pandas/tests/test_numpy_compat.py
@@ -23,10 +23,11 @@
from pyspark import pandas as ps
from pyspark.pandas import set_option, reset_option
from pyspark.pandas.numpy_compat import unary_np_spark_mappings, binary_np_spark_mappings
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils):
+class NumPyCompatTest(PandasOnSparkTestCase, SQLTestUtils):
blacklist = [
# Koalas does not currently support
"conj",
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index d567bae3cd..a998414542 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -22,12 +22,11 @@
import pandas as pd
import numpy as np
-import pyspark
-
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
from pyspark.pandas.frame import DataFrame
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.pandas.typedef.typehints import (
extension_dtypes,
extension_dtypes_available,
@@ -36,7 +35,7 @@
)
-class OpsOnDiffFramesEnabledTest(ReusedSQLTestCase, SQLTestUtils):
+class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
@@ -1549,10 +1548,7 @@ def test_series_repeat(self):
kser1 = ps.from_pandas(pser1)
kser2 = ps.from_pandas(pser2)
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- self.assertRaises(ValueError, lambda: kser1.repeat(kser2))
- else:
- self.assert_eq(kser1.repeat(kser2).sort_index(), pser1.repeat(pser2).sort_index())
+ self.assert_eq(kser1.repeat(kser2).sort_index(), pser1.repeat(pser2).sort_index())
def test_series_ops(self):
pser1 = pd.Series([1, 2, 3, 4, 5, 6, 7], name="x", index=[11, 12, 13, 14, 15, 16, 17])
@@ -1774,7 +1770,7 @@ def test_rank(self):
)
-class OpsOnDiffFramesDisabledTest(ReusedSQLTestCase, SQLTestUtils):
+class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
index 84c72cbbbb..ce4653868d 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
@@ -21,10 +21,11 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class OpsOnDiffFramesGroupByTest(ReusedSQLTestCase, SQLTestUtils):
+class OpsOnDiffFramesGroupByTest(PandasOnSparkTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py
index 88cf84e95d..afd81854e8 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py
@@ -22,10 +22,10 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class OpsOnDiffFramesGroupByExpandingTest(ReusedSQLTestCase, TestUtils):
+class OpsOnDiffFramesGroupByExpandingTest(PandasOnSparkTestCase, TestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py
index 8b7e3edb43..158af35f61 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py
@@ -19,10 +19,10 @@
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class OpsOnDiffFramesGroupByRollingTest(ReusedSQLTestCase, TestUtils):
+class OpsOnDiffFramesGroupByRollingTest(PandasOnSparkTestCase, TestUtils):
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/python/pyspark/pandas/tests/test_repr.py b/python/pyspark/pandas/tests/test_repr.py
index 7259639ae9..e2b1c166f1 100644
--- a/python/pyspark/pandas/tests/test_repr.py
+++ b/python/pyspark/pandas/tests/test_repr.py
@@ -15,17 +15,14 @@
# limitations under the License.
#
-from distutils.version import LooseVersion
-
import numpy as np
-import pyspark
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option, option_context
-from pyspark.pandas.testing.utils import ReusedSQLTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-class ReprTest(ReusedSQLTestCase):
+class ReprTest(PandasOnSparkTestCase):
max_display_count = 23
@classmethod
@@ -82,26 +79,25 @@ def test_repr_series(self):
kser = ps.range(ReprTest.max_display_count + 1).id.rename()
self.assert_eq(repr(kser), repr(kser.to_pandas()))
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- kser = ps.MultiIndex.from_tuples(
- [(100 * i, i) for i in range(ReprTest.max_display_count)]
- ).to_series()
- self.assertTrue("Showing only the first" not in repr(kser))
- self.assert_eq(repr(kser), repr(kser.to_pandas()))
+ kser = ps.MultiIndex.from_tuples(
+ [(100 * i, i) for i in range(ReprTest.max_display_count)]
+ ).to_series()
+ self.assertTrue("Showing only the first" not in repr(kser))
+ self.assert_eq(repr(kser), repr(kser.to_pandas()))
+
+ kser = ps.MultiIndex.from_tuples(
+ [(100 * i, i) for i in range(ReprTest.max_display_count + 1)]
+ ).to_series()
+ self.assertTrue("Showing only the first" in repr(kser))
+ self.assertTrue(
+ repr(kser).startswith(repr(kser.to_pandas().head(ReprTest.max_display_count)))
+ )
+ with option_context("display.max_rows", None):
kser = ps.MultiIndex.from_tuples(
[(100 * i, i) for i in range(ReprTest.max_display_count + 1)]
).to_series()
- self.assertTrue("Showing only the first" in repr(kser))
- self.assertTrue(
- repr(kser).startswith(repr(kser.to_pandas().head(ReprTest.max_display_count)))
- )
-
- with option_context("display.max_rows", None):
- kser = ps.MultiIndex.from_tuples(
- [(100 * i, i) for i in range(ReprTest.max_display_count + 1)]
- ).to_series()
- self.assert_eq(repr(kser), repr(kser.to_pandas()))
+ self.assert_eq(repr(kser), repr(kser.to_pandas()))
def test_repr_indexes(self):
kidx = ps.range(ReprTest.max_display_count).index
diff --git a/python/pyspark/pandas/tests/test_reshape.py b/python/pyspark/pandas/tests/test_reshape.py
index 1f3dfbe2d7..96665dfa01 100644
--- a/python/pyspark/pandas/tests/test_reshape.py
+++ b/python/pyspark/pandas/tests/test_reshape.py
@@ -21,14 +21,13 @@
import numpy as np
import pandas as pd
-import pyspark
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SPARK_CONF_ARROW_ENABLED
from pyspark.pandas.utils import name_like_string
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-class ReshapeTest(ReusedSQLTestCase):
+class ReshapeTest(PandasOnSparkTestCase):
def test_get_dummies(self):
for pdf_or_ps in [
pd.Series([1, 1, 1, 2, 2, 1, 3, 4]),
@@ -111,41 +110,23 @@ def test_get_dummies_date_datetime(self):
)
kdf = ps.from_pandas(pdf)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
def test_get_dummies_boolean(self):
pdf = pd.DataFrame({"b": [True, False, True]})
kdf = ps.from_pandas(pdf)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
def test_get_dummies_decimal(self):
pdf = pd.DataFrame({"d": [Decimal(1.0), Decimal(2.0), Decimal(1)]})
kdf = ps.from_pandas(pdf)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True)
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(
- ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True
- )
+ self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True)
def test_get_dummies_kwargs(self):
# pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category')
diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py
index f664b2ac9f..3827a6017e 100644
--- a/python/pyspark/pandas/tests/test_rolling.py
+++ b/python/pyspark/pandas/tests/test_rolling.py
@@ -19,11 +19,11 @@
import pandas as pd
import pyspark.pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
from pyspark.pandas.window import Rolling
-class RollingTest(ReusedSQLTestCase, TestUtils):
+class RollingTest(PandasOnSparkTestCase, TestUtils):
def test_rolling_error(self):
with self.assertRaisesRegex(ValueError, "window must be >= 0"):
ps.range(10).rolling(window=-1)
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index b5960d65ab..eae26bc4c8 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -24,17 +24,17 @@
import numpy as np
import pandas as pd
-import pyspark
from pyspark.ml.linalg import SparseVector
from pyspark.sql import functions as F
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import (
+from pyspark.testing.pandasutils import (
have_tabulate,
- ReusedSQLTestCase,
- SQLTestUtils,
+ PandasOnSparkTestCase,
SPARK_CONF_ARROW_ENABLED,
+ tabulate_requirement_message,
)
+from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.pandas.missing.series import MissingPandasLikeSeries
from pyspark.pandas.typedef.typehints import (
@@ -45,7 +45,7 @@
)
-class SeriesTest(ReusedSQLTestCase, SQLTestUtils):
+class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
@property
def pser(self):
return pd.Series([1, 2, 3, 4, 5, 6, 7], name="x")
@@ -146,11 +146,7 @@ def test_empty_series(self):
self.assert_eq(ps.from_pandas(pser_a), pser_a)
kser_b = ps.from_pandas(pser_b)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(kser_b, pser_b)
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(kser_b, pser_b)
+ self.assert_eq(kser_b, pser_b)
with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
self.assert_eq(ps.from_pandas(pser_a), pser_a)
@@ -163,11 +159,7 @@ def test_all_null_series(self):
self.assert_eq(ps.from_pandas(pser_a), pser_a)
kser_b = ps.from_pandas(pser_b)
- if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
- self.assert_eq(kser_b, pser_b)
- else:
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self.assert_eq(kser_b, pser_b)
+ self.assert_eq(kser_b, pser_b)
with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
self.assert_eq(ps.from_pandas(pser_a), pser_a)
@@ -628,7 +620,7 @@ def test_nunique(self):
self.assertEqual(ps.Series(range(100)).nunique(approx=True), 103)
self.assertEqual(ps.Series(range(100)).nunique(approx=True, rsd=0.01), 100)
- def _test_value_counts(self):
+ def test_value_counts(self):
# this is also containing test for Index & MultiIndex
pser = pd.Series(
[1, 2, 1, 3, 3, np.nan, 1, 4, 2, np.nan, 3, np.nan, 3, 1, 3],
@@ -856,17 +848,6 @@ def _test_value_counts(self):
almost=True,
)
- def test_value_counts(self):
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- self._test_value_counts()
- self.assertRaises(
- RuntimeError,
- lambda: ps.MultiIndex.from_tuples([("x", "a"), ("x", "b")]).value_counts(),
- )
- else:
- self._test_value_counts()
-
def test_nsmallest(self):
sample_lst = [1, 2, 3, 4, np.nan, 6]
pser = pd.Series(sample_lst, name="x")
@@ -1891,14 +1872,8 @@ def test_udt(self):
sparse_values = {0: 0.1, 1: 1.1}
sparse_vector = SparseVector(len(sparse_values), sparse_values)
pser = pd.Series([sparse_vector])
-
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
- kser = ps.from_pandas(pser)
- self.assert_eq(kser, pser)
- else:
- kser = ps.from_pandas(pser)
- self.assert_eq(kser, pser)
+ kser = ps.from_pandas(pser)
+ self.assert_eq(kser, pser)
def test_repeat(self):
pser = pd.Series(["a", "b", "c"], name="0", index=np.random.rand(3))
@@ -1913,10 +1888,7 @@ def test_repeat(self):
pdf = pd.DataFrame({"a": ["a", "b", "c"], "rep": [10, 20, 30]}, index=np.random.rand(3))
kdf = ps.from_pandas(pdf)
- if LooseVersion(pyspark.__version__) < LooseVersion("2.4"):
- self.assertRaises(ValueError, lambda: kdf.a.repeat(kdf.rep))
- else:
- self.assert_eq(kdf.a.repeat(kdf.rep).sort_index(), pdf.a.repeat(pdf.rep).sort_index())
+ self.assert_eq(kdf.a.repeat(kdf.rep).sort_index(), pdf.a.repeat(pdf.rep).sort_index())
def test_take(self):
pser = pd.Series([100, 200, 300, 400, 500], name="Koalas")
@@ -2209,7 +2181,7 @@ def test_shape(self):
self.assert_eq(pser.shape, kser.shape)
- @unittest.skipIf(not have_tabulate, "tabulate not installed")
+ @unittest.skipIf(not have_tabulate, tabulate_requirement_message)
def test_to_markdown(self):
pser = pd.Series(["elk", "pig", "dog", "quetzal"], name="animal")
kser = ps.from_pandas(pser)
@@ -2407,10 +2379,6 @@ def test_dot(self):
self.assert_eq((kdf["b"] * 10).dot(kdf), (pdf["b"] * 10).dot(pdf))
self.assert_eq((kdf["b"] * 10).dot(kdf + 1), (pdf["b"] * 10).dot(pdf + 1))
- @unittest.skipIf(
- LooseVersion(pyspark.__version__) < LooseVersion("3.0"),
- "tail won't work properly with PySpark<3.0",
- )
def test_tail(self):
pser = pd.Series(range(1000), name="Koalas")
kser = ps.from_pandas(pser)
@@ -2508,10 +2476,6 @@ def test_hasnans(self):
kser = ps.from_pandas(pser)
self.assert_eq(pser.hasnans, kser.hasnans)
- @unittest.skipIf(
- LooseVersion(pyspark.__version__) < LooseVersion("3.0"),
- "last_valid_index won't work properly with PySpark<3.0",
- )
def test_last_valid_index(self):
pser = pd.Series([250, 1.5, 320, 1, 0.3, None, None, None, None])
kser = ps.from_pandas(pser)
diff --git a/python/pyspark/pandas/tests/test_series_conversion.py b/python/pyspark/pandas/tests/test_series_conversion.py
index 2b19249c0d..18ce24de74 100644
--- a/python/pyspark/pandas/tests/test_series_conversion.py
+++ b/python/pyspark/pandas/tests/test_series_conversion.py
@@ -21,10 +21,11 @@
import pandas as pd
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class SeriesConversionTest(ReusedSQLTestCase, SQLTestUtils):
+class SeriesConversionTest(PandasOnSparkTestCase, SQLTestUtils):
@property
def pser(self):
return pd.Series([1, 2, 3, 4, 5, 6, 7], name="x")
diff --git a/python/pyspark/pandas/tests/test_series_datetime.py b/python/pyspark/pandas/tests/test_series_datetime.py
index fc27c96edf..deb4497483 100644
--- a/python/pyspark/pandas/tests/test_series_datetime.py
+++ b/python/pyspark/pandas/tests/test_series_datetime.py
@@ -22,10 +22,11 @@
import pandas as pd
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class SeriesDateTimeTest(ReusedSQLTestCase, SQLTestUtils):
+class SeriesDateTimeTest(PandasOnSparkTestCase, SQLTestUtils):
@property
def pdf1(self):
date1 = pd.Series(pd.date_range("2012-1-1 12:45:31", periods=3, freq="M"))
diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/test_series_string.py
index 053c4d79be..69a9ab3424 100644
--- a/python/pyspark/pandas/tests/test_series_string.py
+++ b/python/pyspark/pandas/tests/test_series_string.py
@@ -20,10 +20,11 @@
import re
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class SeriesStringTest(ReusedSQLTestCase, SQLTestUtils):
+class SeriesStringTest(PandasOnSparkTestCase, SQLTestUtils):
@property
def pser(self):
return pd.Series(
diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py
index 6d29beee97..6c3405f0f0 100644
--- a/python/pyspark/pandas/tests/test_sql.py
+++ b/python/pyspark/pandas/tests/test_sql.py
@@ -16,12 +16,12 @@
#
from pyspark import pandas as ps
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
-
from pyspark.sql.utils import ParseException
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
-class SQLTest(ReusedSQLTestCase, SQLTestUtils):
+class SQLTest(PandasOnSparkTestCase, SQLTestUtils):
def test_error_variable_not_exist(self):
msg = "The key variable_foo in the SQL statement was not found.*"
with self.assertRaisesRegex(ValueError, msg):
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index 1a1885c845..905add4baa 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -27,14 +27,11 @@
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
-from pyspark.pandas.testing.utils import (
- ReusedSQLTestCase,
- SQLTestUtils,
- SPARK_CONF_ARROW_ENABLED,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, SPARK_CONF_ARROW_ENABLED
+from pyspark.testing.sqlutils import SQLTestUtils
-class StatsTest(ReusedSQLTestCase, SQLTestUtils):
+class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
def _test_stat_functions(self, pdf_or_pser, kdf_or_kser):
functions = ["max", "min", "mean", "sum", "count"]
for funcname in functions:
diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py
index 8ab1e0324b..2f4039ba20 100644
--- a/python/pyspark/pandas/tests/test_utils.py
+++ b/python/pyspark/pandas/tests/test_utils.py
@@ -17,17 +17,18 @@
import pandas as pd
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
from pyspark.pandas.utils import (
lazy_property,
validate_arguments_and_invoke_function,
validate_bool_kwarg,
)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
some_global_variable = 0
-class UtilsTest(ReusedSQLTestCase, SQLTestUtils):
+class UtilsTest(PandasOnSparkTestCase, SQLTestUtils):
# a dummy to_html version with an extra parameter that pandas does not support
# used in test_validate_arguments_and_invoke_function
diff --git a/python/pyspark/pandas/tests/test_window.py b/python/pyspark/pandas/tests/test_window.py
index 742b3b9cbd..8c347b8687 100644
--- a/python/pyspark/pandas/tests/test_window.py
+++ b/python/pyspark/pandas/tests/test_window.py
@@ -25,10 +25,10 @@
MissingPandasLikeExpandingGroupby,
MissingPandasLikeRollingGroupby,
)
-from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-class ExpandingRollingTest(ReusedSQLTestCase, TestUtils):
+class ExpandingRollingTest(PandasOnSparkTestCase, TestUtils):
def test_missing(self):
kdf = ps.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9]})
diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py
new file mode 100644
index 0000000000..4a5bfe8d56
--- /dev/null
+++ b/python/pyspark/testing/pandasutils.py
@@ -0,0 +1,373 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+import shutil
+import tempfile
+import unittest
+import warnings
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+
+import pandas as pd
+from pandas.api.types import is_list_like
+from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
+
+from pyspark import pandas as ps
+from pyspark.pandas.frame import DataFrame
+from pyspark.pandas.indexes import Index
+from pyspark.pandas.series import Series
+from pyspark.pandas.utils import default_session, SPARK_CONF_ARROW_ENABLED
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+tabulate_requirement_message = None
+try:
+ from tabulate import tabulate # noqa: F401
+except ImportError as e:
+ # If tabulate requirement is not satisfied, skip related tests.
+ tabulate_requirement_message = str(e)
+have_tabulate = tabulate_requirement_message is None
+
+matplotlib_requirement_message = None
+try:
+ import matplotlib # type: ignore # noqa: F401
+except ImportError as e:
+ # If matplotlib requirement is not satisfied, skip related tests.
+ matplotlib_requirement_message = str(e)
+have_matplotlib = matplotlib_requirement_message is None
+
+plotly_requirement_message = None
+try:
+ import plotly # type: ignore # noqa: F401
+except ImportError as e:
+ # If plotly requirement is not satisfied, skip related tests.
+ plotly_requirement_message = str(e)
+have_plotly = plotly_requirement_message is None
+
+
+class PandasOnSparkTestCase(unittest.TestCase, SQLTestUtils):
+ @classmethod
+ def setUpClass(cls):
+ cls.spark = default_session()
+ cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
+
+ @classmethod
+ def tearDownClass(cls):
+ # We don't stop the Spark session here so that it can be reused across all tests.
+ # The Spark session will be started and stopped at the PyTest session level.
+ # Please see databricks/koalas/conftest.py.
+ pass
+
+ def assertPandasEqual(self, left, right, check_exact=True):
+ if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+ try:
+ if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+ kwargs = dict(check_freq=False)
+ else:
+ kwargs = dict()
+
+ assert_frame_equal(
+ left,
+ right,
+ check_index_type=("equiv" if len(left.index) > 0 else False),
+ check_column_type=("equiv" if len(left.columns) > 0 else False),
+ check_exact=check_exact,
+ **kwargs
+ )
+ except AssertionError as e:
+ msg = (
+ str(e)
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
+ )
+ raise AssertionError(msg) from e
+ elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+ try:
+ if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+ kwargs = dict(check_freq=False)
+ else:
+ kwargs = dict()
+
+ assert_series_equal(
+ left,
+ right,
+ check_index_type=("equiv" if len(left.index) > 0 else False),
+ check_exact=check_exact,
+ **kwargs
+ )
+ except AssertionError as e:
+ msg = (
+ str(e)
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+ )
+ raise AssertionError(msg) from e
+ elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+ try:
+ assert_index_equal(left, right, check_exact=check_exact)
+ except AssertionError as e:
+ msg = (
+ str(e)
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+ )
+ raise AssertionError(msg) from e
+ else:
+ raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+ def assertPandasAlmostEqual(self, left, right):
+ """
+ This function checks whether the given pandas objects are approximately the same,
+ which means the conditions below hold:
+ - missing values (NaN, NaT, None) appear at the same positions in both objects
+ - float values are compared after dropping missing values, rounded to 7 decimal
+ places (via `assertAlmostEqual`)
+ """
+ if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+ msg = (
+ "DataFrames are not almost equal: "
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
+ )
+ self.assertEqual(left.shape, right.shape, msg=msg)
+ for lcol, rcol in zip(left.columns, right.columns):
+ self.assertEqual(lcol, rcol, msg=msg)
+ for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
+ self.assertEqual(lnull, rnull, msg=msg)
+ for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
+ self.assertAlmostEqual(lval, rval, msg=msg)
+ self.assertEqual(left.columns.names, right.columns.names, msg=msg)
+ elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+ msg = (
+ "Series are not almost equal: "
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+ )
+ self.assertEqual(left.name, right.name, msg=msg)
+ self.assertEqual(len(left), len(right), msg=msg)
+ for lnull, rnull in zip(left.isnull(), right.isnull()):
+ self.assertEqual(lnull, rnull, msg=msg)
+ for lval, rval in zip(left.dropna(), right.dropna()):
+ self.assertAlmostEqual(lval, rval, msg=msg)
+ elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
+ msg = (
+ "MultiIndices are not almost equal: "
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+ )
+ self.assertEqual(len(left), len(right), msg=msg)
+ for lval, rval in zip(left, right):
+ self.assertAlmostEqual(lval, rval, msg=msg)
+ elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+ msg = (
+ "Indices are not almost equal: "
+ + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+ + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+ )
+ self.assertEqual(len(left), len(right), msg=msg)
+ for lnull, rnull in zip(left.isnull(), right.isnull()):
+ self.assertEqual(lnull, rnull, msg=msg)
+ for lval, rval in zip(left.dropna(), right.dropna()):
+ self.assertAlmostEqual(lval, rval, msg=msg)
+ else:
+ raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+ def assert_eq(self, left, right, check_exact=True, almost=False):
+ """
+ Asserts whether two arbitrary objects are equal. If the given objects are Koalas DataFrames,
+ Series, or Indexes, they are converted into pandas objects and compared.
+
+ :param left: object to compare
+ :param right: object to compare
+ :param check_exact: if this is False, the comparison is done less precisely.
+ :param almost: if this is enabled, the comparison is delegated to `unittest`'s
+ `assertAlmostEqual`. See its documentation for more details.
+ """
+ lobj = self._to_pandas(left)
+ robj = self._to_pandas(right)
+ if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)):
+ if almost:
+ self.assertPandasAlmostEqual(lobj, robj)
+ else:
+ self.assertPandasEqual(lobj, robj, check_exact=check_exact)
+ elif is_list_like(lobj) and is_list_like(robj):
+ self.assertTrue(len(left) == len(right))
+ for litem, ritem in zip(left, right):
+ self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost)
+ elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)):
+ pass
+ else:
+ if almost:
+ self.assertAlmostEqual(lobj, robj)
+ else:
+ self.assertEqual(lobj, robj)
+
+ @staticmethod
+ def _to_pandas(obj):
+ if isinstance(obj, (DataFrame, Series, Index)):
+ return obj.to_pandas()
+ else:
+ return obj
+
+
+class TestUtils(object):
+ @contextmanager
+ def temp_dir(self):
+ tmp = tempfile.mkdtemp()
+ try:
+ yield tmp
+ finally:
+ shutil.rmtree(tmp)
+
+ @contextmanager
+ def temp_file(self):
+ with self.temp_dir() as tmp:
+ yield tempfile.mktemp(dir=tmp)
+
+
+class ComparisonTestBase(PandasOnSparkTestCase):
+ @property
+ def kdf(self):
+ return ps.from_pandas(self.pdf)
+
+ @property
+ def pdf(self):
+ return self.kdf.to_pandas()
+
+
+def compare_both(f=None, almost=True):
+
+ if f is None:
+ return functools.partial(compare_both, almost=almost)
+ elif isinstance(f, bool):
+ return functools.partial(compare_both, almost=f)
+
+ @functools.wraps(f)
+ def wrapped(self):
+ if almost:
+ compare = self.assertPandasAlmostEqual
+ else:
+ compare = self.assertPandasEqual
+
+ for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.kdf)):
+ compare(result_pandas, result_spark.to_pandas())
+
+ return wrapped
+
+
+@contextmanager
+def assert_produces_warning(
+ expected_warning=Warning,
+ filter_level="always",
+ check_stacklevel=True,
+ raise_on_extra_warnings=True,
+):
+ """
+ Context manager for running code expected to either raise a specific
+ warning, or not raise any warnings. Verifies that the code raises the
+ expected warning, and that it does not raise any other unexpected
+ warnings. It is basically a wrapper around ``warnings.catch_warnings``.
+
+ Notes
+ -----
+ Replicated from pandas/_testing/_warnings.py.
+
+ Parameters
+ ----------
+ expected_warning : {Warning, False, None}, default Warning
+ The type of Exception raised. ``exception.Warning`` is the base
+ class for all warnings. To check that no warning is returned,
+ specify ``False`` or ``None``.
+ filter_level : str or None, default "always"
+ Specifies whether warnings are ignored, displayed, or turned
+ into errors.
+ Valid values are:
+ * "error" - turns matching warnings into exceptions
+ * "ignore" - discard the warning
+ * "always" - always emit a warning
+ * "default" - print the warning the first time it is generated
+ from each location
+ * "module" - print the warning the first time it is generated
+ from each module
+ * "once" - print the warning the first time it is generated
+ check_stacklevel : bool, default True
+ If True, displays the line that called the function containing
+ the warning to show where the function is called. Otherwise, the
+ line that implements the function is displayed.
+ raise_on_extra_warnings : bool, default True
+ Whether extra warnings not of the type `expected_warning` should
+ cause the test to fail.
+
+ Examples
+ --------
+ >>> import warnings
+ >>> with assert_produces_warning():
+ ... warnings.warn(UserWarning())
+ ...
+ >>> with assert_produces_warning(False): # doctest: +SKIP
+ ... warnings.warn(RuntimeWarning())
+ ...
+ Traceback (most recent call last):
+ ...
+ AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
+ >>> with assert_produces_warning(UserWarning): # doctest: +SKIP
+ ... warnings.warn(RuntimeWarning())
+ Traceback (most recent call last):
+ ...
+ AssertionError: Did not see expected warning of class 'UserWarning'
+ .. warning:: This is *not* thread-safe.
+ """
+ __tracebackhide__ = True
+
+ with warnings.catch_warnings(record=True) as w:
+
+ saw_warning = False
+ warnings.simplefilter(filter_level)
+ yield w
+ extra_warnings = []
+
+ for actual_warning in w:
+ if expected_warning and issubclass(actual_warning.category, expected_warning):
+ saw_warning = True
+
+ if check_stacklevel and issubclass(
+ actual_warning.category, (FutureWarning, DeprecationWarning)
+ ):
+ from inspect import getframeinfo, stack
+
+ caller = getframeinfo(stack()[2][0])
+ msg = (
+ "Warning not set with correct stacklevel. ",
+ "File where warning is raised: {} != ".format(actual_warning.filename),
+ "{}. Warning message: {}".format(caller.filename, actual_warning.message),
+ )
+ assert actual_warning.filename == caller.filename, msg
+ else:
+ extra_warnings.append(
+ (
+ actual_warning.category.__name__,
+ actual_warning.message,
+ actual_warning.filename,
+ actual_warning.lineno,
+ )
+ )
+ if expected_warning:
+ msg = "Did not see expected warning of class {}".format(repr(expected_warning.__name__))
+ assert saw_warning, msg
+ if raise_on_extra_warnings and extra_warnings:
+ raise AssertionError("Caused unexpected warning(s): {}".format(repr(extra_warnings)))
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index f960aa4fee..bbd93d1c38 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
import glob
import os
import struct
@@ -29,13 +30,13 @@
try:
import scipy.sparse # noqa: F401
have_scipy = True
-except:
+except ImportError:
# No SciPy, but that's okay, we'll skip those tests
pass
try:
import numpy as np # noqa: F401
have_numpy = True
-except:
+except ImportError:
# No NumPy, but that's okay, we'll skip those tests
pass
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index 174c62183c..aa41b7ec2d 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -62,6 +62,14 @@ private[spark] object Config extends Logging {
.booleanConf
.createWithDefault(true)
+ val KUBERNETES_DRIVER_OWN_PVC =
+ ConfigBuilder("spark.kubernetes.driver.ownPersistentVolumeClaim")
+ .doc("If true, driver pod becomes the owner of on-demand persistent volume claims " +
+ "instead of the executor pods")
+ .version("3.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
val KUBERNETES_NAMESPACE =
ConfigBuilder("spark.kubernetes.namespace")
.doc("The namespace that will be used for running the driver and executor pods.")
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
index c66756fd69..4e1647372e 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import io.fabric8.kubernetes.api.model._
import org.apache.spark.deploy.k8s._
-import org.apache.spark.deploy.k8s.Constants.ENV_EXECUTOR_ID
+import org.apache.spark.deploy.k8s.Constants.{ENV_EXECUTOR_ID, SPARK_APP_ID_LABEL}
private[spark] class MountVolumesFeatureStep(conf: KubernetesConf)
extends KubernetesFeatureConfigStep {
@@ -85,6 +85,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf)
.withApiVersion("v1")
.withNewMetadata()
.withName(claimName)
+ .addToLabels(SPARK_APP_ID_LABEL, conf.sparkConf.getAppId)
.endMetadata()
.withNewSpec()
.withStorageClassName(storageClass.get)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
index 5ebd172f7d..d54f665a38 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
@@ -339,6 +339,9 @@ private[spark] class ExecutorPodsAllocator(
resources
.filter(_.getKind == "PersistentVolumeClaim")
.foreach { resource =>
+ if (conf.get(KUBERNETES_DRIVER_OWN_PVC) && driverPod.nonEmpty) {
+ addOwnerReference(driverPod.get, Seq(resource))
+ }
val pvc = resource.asInstanceOf[PersistentVolumeClaim]
logInfo(s"Trying to create PersistentVolumeClaim ${pvc.getMetadata.getName} with " +
s"StorageClass ${pvc.getSpec.getStorageClassName}")
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
index 780b08bd0e..d5a4856d37 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
@@ -134,6 +134,13 @@ private[spark] class KubernetesClusterSchedulerBackend(
}
}
+ Utils.tryLogNonFatalError {
+ kubernetesClient
+ .persistentVolumeClaims()
+ .withLabel(SPARK_APP_ID_LABEL, applicationId())
+ .delete()
+ }
+
if (shouldDeleteExecutors) {
Utils.tryLogNonFatalError {
kubernetesClient
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index c958f9c387..5566687054 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -606,7 +606,13 @@ groupByClause
;
groupingAnalytics
- : (ROLLUP | CUBE | GROUPING SETS) '(' groupingSet (',' groupingSet)* ')'
+ : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')'
+ | GROUPING SETS '(' groupingElement (',' groupingElement)* ')'
+ ;
+
+groupingElement
+ : groupingAnalytics
+ | groupingSet
;
groupingSet
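The relaxed rule lets a grouping element inside GROUPING SETS itself be ROLLUP, CUBE, or another GROUPING SETS. A rough sketch of the kind of query the new grammar is meant to parse, assuming the analyzer handles the nesting end to end (table and data are made up):

  spark.sql(
    """
      |SELECT a, b, count(*)
      |FROM VALUES (1, 2), (1, 3), (2, 2) AS t(a, b)
      |GROUP BY GROUPING SETS (ROLLUP(a, b), GROUPING SETS((a), (a, b)))
    """.stripMargin).show()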
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java
new file mode 100644
index 0000000000..71e83002dd
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.metric;
+
+import org.apache.spark.annotation.Evolving;
+
+import java.util.Arrays;
+import java.text.DecimalFormat;
+
+/**
+ * Built-in `CustomMetric` that computes the average of metric values. To create a custom metric
+ * for real usage, extend this class and override `name` and `description`.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public abstract class CustomAvgMetric implements CustomMetric {
+ @Override
+ public String aggregateTaskMetrics(long[] taskMetrics) {
+ if (taskMetrics.length > 0) {
+ double average = ((double)Arrays.stream(taskMetrics).sum()) / taskMetrics.length;
+ return new DecimalFormat("#0.000").format(average);
+ } else {
+ return "0";
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomMetric.java
new file mode 100644
index 0000000000..4c4151ad96
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomMetric.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.metric;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A custom metric. A data source can define the custom metrics it supports using this interface.
+ * During query execution, Spark collects the task metrics using {@link CustomTaskMetric}
+ * and combines them at the driver side. How the task metrics are combined is defined by the
+ * metric class with the same metric name.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public interface CustomMetric {
+ /**
+ * Returns the name of custom metric.
+ */
+ String name();
+
+ /**
+ * Returns the description of custom metric.
+ */
+ String description();
+
+ /**
+ * The initial value of this metric.
+ */
+ long initialValue = 0L;
+
+ /**
+ * Given an array of task metric values, returns aggregated final metric value.
+ */
+ String aggregateTaskMetrics(long[] taskMetrics);
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java
new file mode 100644
index 0000000000..ba28e9b918
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.metric;
+
+import org.apache.spark.annotation.Evolving;
+
+import java.util.Arrays;
+
+/**
+ * Built-in `CustomMetric` that sums up metric values. To create a custom metric for real usage,
+ * extend this class and override `name` and `description`.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public abstract class CustomSumMetric implements CustomMetric {
+ @Override
+ public String aggregateTaskMetrics(long[] taskMetrics) {
+ return String.valueOf(Arrays.stream(taskMetrics).sum());
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java
new file mode 100644
index 0000000000..1b6f04d927
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.metric;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.read.PartitionReader;
+
+/**
+ * A custom task metric. This is a logical representation of a metric reported by data sources
+ * at the executor side. During query execution, Spark collects the task metrics per partition
+ * from each {@link PartitionReader} and updates internal metrics based on the collected values.
+ * For streaming queries, Spark collects and combines the metrics into a final result per micro-batch.
+ *
+ * During query execution the metrics are sent back to the driver and then combined. How the
+ * task metrics are combined is defined by the corresponding {@link CustomMetric} with the same
+ * metric name. The final result shows up in the data source scan operator in the Spark UI.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public interface CustomTaskMetric {
+ /**
+ * Returns the name of custom task metric.
+ */
+ String name();
+
+ /**
+ * Returns the long value of custom task metric.
+ */
+ long value();
+}
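Taken together, these interfaces let a data source declare a driver-side metric and report per-task values that Spark aggregates by matching names. A minimal Scala sketch using only the interfaces added above; the metric name "bytesRead" and the class names are illustrative:

  import org.apache.spark.sql.connector.metric.{CustomSumMetric, CustomTaskMetric}

  // Driver-side definition: how values reported under "bytesRead" are combined (summed here).
  class BytesReadMetric extends CustomSumMetric {
    override def name(): String = "bytesRead"
    override def description(): String = "total bytes read"
  }

  // Executor-side value reported per partition; it is matched to the CustomMetric above
  // by the identical name and aggregated at the driver.
  class BytesReadTaskMetric(bytes: Long) extends CustomTaskMetric {
    override def name(): String = "bytesRead"
    override def value(): Long = bytes
  }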
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index d6cf070cf4..5286bbf9f8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -21,7 +21,7 @@
import java.io.IOException;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.CustomTaskMetric;
+import org.apache.spark.sql.connector.metric.CustomTaskMetric;
/**
* A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
index b70a656c49..0c009f5c56 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.read;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.CustomMetric;
+import org.apache.spark.sql.connector.metric.CustomMetric;
import org.apache.spark.sql.connector.read.streaming.ContinuousStream;
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream;
import org.apache.spark.sql.types.StructType;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index f149c9bb0c..fe48670cb3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -19,12 +19,16 @@
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
+import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;
+import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY;
+import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS;
+
/**
* A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
* supported.
@@ -172,6 +176,10 @@ public ArrowColumnVector(ValueVector vector) {
}
} else if (vector instanceof NullVector) {
accessor = new NullAccessor((NullVector) vector);
+ } else if (vector instanceof IntervalYearVector) {
+ accessor = new IntervalYearAccessor((IntervalYearVector) vector);
+ } else if (vector instanceof IntervalDayVector) {
+ accessor = new IntervalDayAccessor((IntervalDayVector) vector);
} else {
throw new UnsupportedOperationException();
}
@@ -508,4 +516,37 @@ private static class NullAccessor extends ArrowVectorAccessor {
super(vector);
}
}
+
+ private static class IntervalYearAccessor extends ArrowVectorAccessor {
+
+ private final IntervalYearVector accessor;
+
+ IntervalYearAccessor(IntervalYearVector vector) {
+ super(vector);
+ this.accessor = vector;
+ }
+
+ @Override
+ int getInt(int rowId) {
+ return accessor.get(rowId);
+ }
+ }
+
+ private static class IntervalDayAccessor extends ArrowVectorAccessor {
+
+ private final IntervalDayVector accessor;
+ private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();
+
+ IntervalDayAccessor(IntervalDayVector vector) {
+ super(vector);
+ this.accessor = vector;
+ }
+
+ @Override
+ long getLong(int rowId) {
+ accessor.get(rowId, intervalDayHolder);
+ return Math.addExact(Math.multiplyExact(intervalDayHolder.days, MICROS_PER_DAY),
+ intervalDayHolder.milliseconds * MICROS_PER_MILLIS);
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index b55d1b725f..ccf0a50b73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -304,18 +304,23 @@ object CatalystTypeConverters {
row.getUTF8String(column).toString
}
- private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
- override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
+ private object DateConverter extends CatalystTypeConverter[Any, Date, Any] {
+ override def toCatalystImpl(scalaValue: Any): Int = scalaValue match {
+ case d: Date => DateTimeUtils.fromJavaDate(d)
+ case l: LocalDate => DateTimeUtils.localDateToDays(l)
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + s"cannot be converted to the ${DateType.sql} type")
+ }
override def toScala(catalystValue: Any): Date =
if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
override def toScalaImpl(row: InternalRow, column: Int): Date =
DateTimeUtils.toJavaDate(row.getInt(column))
}
- private object LocalDateConverter extends CatalystTypeConverter[LocalDate, LocalDate, Any] {
- override def toCatalystImpl(scalaValue: LocalDate): Int = {
- DateTimeUtils.localDateToDays(scalaValue)
- }
+ private object LocalDateConverter extends CatalystTypeConverter[Any, LocalDate, Any] {
+ override def toCatalystImpl(scalaValue: Any): Int =
+ DateConverter.toCatalystImpl(scalaValue)
override def toScala(catalystValue: Any): LocalDate = {
if (catalystValue == null) null
else DateTimeUtils.daysToLocalDate(catalystValue.asInstanceOf[Int])
@@ -324,9 +329,14 @@ object CatalystTypeConverters {
DateTimeUtils.daysToLocalDate(row.getInt(column))
}
- private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
- override def toCatalystImpl(scalaValue: Timestamp): Long =
- DateTimeUtils.fromJavaTimestamp(scalaValue)
+ private object TimestampConverter extends CatalystTypeConverter[Any, Timestamp, Any] {
+ override def toCatalystImpl(scalaValue: Any): Long = scalaValue match {
+ case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
+ case i: Instant => DateTimeUtils.instantToMicros(i)
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + s"cannot be converted to the ${TimestampType.sql} type")
+ }
override def toScala(catalystValue: Any): Timestamp =
if (catalystValue == null) null
else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
@@ -334,9 +344,9 @@ object CatalystTypeConverters {
DateTimeUtils.toJavaTimestamp(row.getLong(column))
}
- private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] {
- override def toCatalystImpl(scalaValue: Instant): Long =
- DateTimeUtils.instantToMicros(scalaValue)
+ private object InstantConverter extends CatalystTypeConverter[Any, Instant, Any] {
+ override def toCatalystImpl(scalaValue: Any): Long =
+ TimestampConverter.toCatalystImpl(scalaValue)
override def toScala(catalystValue: Any): Instant =
if (catalystValue == null) null
else DateTimeUtils.microsToInstant(catalystValue.asInstanceOf[Long])
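With the converters relaxed as above, the same Catalyst converter accepts either the legacy java.sql type or the java.time type for a date or timestamp column. A rough sketch against the internal converter entry point, shown only for illustration:

  import java.time.LocalDate
  import org.apache.spark.sql.catalyst.CatalystTypeConverters
  import org.apache.spark.sql.types.DateType

  val toCatalyst = CatalystTypeConverters.createToCatalystConverter(DateType)
  toCatalyst(java.sql.Date.valueOf("2021-04-01"))  // Int: days since the epoch
  toCatalyst(LocalDate.of(2021, 4, 1))             // the same Int, via the new LocalDate branch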
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c2c146c7de..87b8d52ac2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -39,9 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
- EXPRESSION_WITH_RANDOM_SEED, NATURAL_LIKE_JOIN, WINDOW_EXPRESSION
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -2179,7 +2177,8 @@ class Analyzer(override val catalogManager: CatalogManager)
* outer plan to get evaluated.
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
- plan transformExpressions {
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
+ EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
@@ -2196,7 +2195,8 @@ class Analyzer(override val catalogManager: CatalogManager)
/**
* Resolve and rewrite all subqueries in an operator tree.
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -3790,9 +3790,9 @@ object UpdateOuterReferences extends Rule[LogicalPlan] {
}
def apply(plan: LogicalPlan): LogicalPlan = {
- plan resolveOperators {
+ plan.resolveOperatorsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, FILTER), ruleId) {
case f @ Filter(_, a: Aggregate) if f.resolved =>
- f transformExpressions {
+ f.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s: SubqueryExpression if s.children.nonEmpty =>
// Collect the aliases from output of aggregate.
val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
index e9673d7f20..1f3f762662 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
@@ -35,7 +35,7 @@ trait AliasHelper {
protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
- val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
+ val aliasMap = plan.aggregateExpressions.collect {
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
(a.toAttribute, a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 5d799c768a..30317c9e91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
+ final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
+
override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled
override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index de4b874637..1c185dd316 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
trait DynamicPruning extends Predicate
@@ -69,6 +70,8 @@ case class DynamicPruningSubquery(
pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
}
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
+
override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"
override lazy val canonicalized: DynamicPruning = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 03b5517f6d..a6be98c8a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) {
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
// For case-sensitivity aware field resolution, we should take `ordinal` which
- // points to correct struct field.
+ // points to correct struct field, because `ExtractValue` resolves the column
+ // name correctly.
val selectedField = a.child.dataType.asInstanceOf[ArrayType]
.elementType.asInstanceOf[StructType](a.ordinal)
val prunedField = projSchema(selectedField.name)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 4ee6488c92..30093ef085 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -22,9 +22,12 @@ import org.apache.spark.sql.types._
object SchemaPruning extends SQLConfHelper {
/**
- * Filters the schema by the requested fields. For example, if the schema is struct<a int, b int>,
- * and given requested field are "a", the field "b" is pruned in the returned schema.
- * Note that schema field ordering at original schema is still preserved in pruned schema.
+ * Prunes the nested schema by the requested fields. For example, if the schema is:
+ * `id int, s struct<a:int, b:int>`, and given requested field "s.a", the inner field "b"
+ * is pruned in the returned schema: `id int, s struct<a:int>`.
+ * Note that:
+ * 1. The schema field ordering at original schema is still preserved in pruned schema.
+ * 2. The top-level fields are not pruned here.
*/
def pruneDataSchema(
dataSchema: StructType,
@@ -34,11 +37,10 @@ object SchemaPruning extends SQLConfHelper {
// in the resulting schema may differ from their ordering in the logical relation's
// original schema
val mergedSchema = requestedRootFields
- .map { case root: RootField => StructType(Array(root.field)) }
+ .map { root: RootField => StructType(Array(root.field)) }
.reduceLeft(_ merge _)
- val dataSchemaFieldNames = dataSchema.fieldNames.toSet
val mergedDataSchema =
- StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
+ StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
// recursively. This makes mergedDataSchema a pruned schema of dataSchema
sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
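
A minimal sketch of the new merge step above (the schemas and the resolver are illustrative stand-ins, not taken from the patch): every top-level field of the data schema is kept, and a field is swapped for its pruned counterpart only when the merged requested schema contains a matching name.

    import org.apache.spark.sql.types.StructType

    // Hypothetical schemas mirroring the docstring example.
    val dataSchema   = StructType.fromDDL("id INT, s STRUCT<a: INT, b: INT>")
    val mergedSchema = StructType.fromDDL("s STRUCT<a: INT>")
    val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

    // Same shape as the replaced line: keep every top-level field, preferring
    // the pruned definition when the requested schema provides one.
    val mergedDataSchema = StructType(
      dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
    // expected: id INT plus s pruned down to STRUCT<a: INT>
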
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 4fc0256bce..8ae24e5135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -87,8 +87,12 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal()(
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
- case _: YearMonthIntervalType => DivideYMInterval(sum, count)
- case _: DayTimeIntervalType => DivideDTInterval(sum, count)
+ case _: YearMonthIntervalType =>
+ If(EqualTo(count, Literal(0L)),
+ Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count))
+ case _: DayTimeIntervalType =>
+ If(EqualTo(count, Literal(0L)),
+ Literal(null, DayTimeIntervalType), DivideDTInterval(sum, count))
case _ =>
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
}
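
A small usage sketch of the behavior change above (assumes a running SparkSession bound to `spark`; the inline data is made up): averaging an ANSI interval over zero rows now yields NULL instead of dividing by a zero count.

    // With the If(count = 0, null, DivideYMInterval(sum, count)) expression above,
    // this query is expected to return NULL for the empty group.
    spark.sql(
      "SELECT avg(col) FROM VALUES (INTERVAL '1-0' YEAR TO MONTH) AS t(col) WHERE 1 = 0"
    ).show()
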
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 1d13155ef6..dfdd828d10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(COUNT)
+
// Return data type.
override def dataType: DataType = LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8c70c86aa1..281734c6f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -80,14 +80,6 @@ object AggregateExpression {
filter,
NamedExpression.newExprId)
}
-
- def containsAggregate(expr: Expression): Boolean = {
- expr.find(isAggregate).isDefined
- }
-
- def isAggregate(expr: Expression): Boolean = {
- expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
- }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 954a4b9fc1..10b4a7be30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
+ UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression)
override def dataType: DataType = child.dataType
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)
@@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
override def dataType: DataType = left.dataType
+ final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)
+
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
/** Name of the function for this expression on a [[Decimal]] type. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 125e796a98..a408280a3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.catalyst.expressions
-import java.time.ZoneId
+import java.time.{Duration, Period, ZoneId}
import java.util.Comparator
import scala.collection.mutable
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
+ final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
+
override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
@@ -2484,8 +2487,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
The start and stop expressions must resolve to the same type.
If start and stop expressions resolve to the 'date' or 'timestamp' type
- then the step expression must resolve to the 'interval' type, otherwise to the same type
- as the start and stop expressions.
+ then the step expression must resolve to the 'interval' or 'year-month interval' or
+ 'day-time interval' type, otherwise to the same type as the start and stop expressions.
""",
arguments = """
Arguments:
@@ -2504,6 +2507,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
[5,4,3,2,1]
> SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month);
[2018-01-01,2018-02-01,2018-03-01]
+ > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval '0-1' year to month);
+ [2018-01-01,2018-02-01,2018-03-01]
""",
group = "array_funcs",
since = "2.4.0"
@@ -2550,8 +2555,13 @@ case class Sequence(
val typesCorrect =
startType.sameType(stop.dataType) &&
(startType match {
- case TimestampType | DateType =>
- stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType)
+ case TimestampType =>
+ stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) ||
+ YearMonthIntervalType.acceptsType(stepType) ||
+ DayTimeIntervalType.acceptsType(stepType)
+ case DateType =>
+ stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) ||
+ YearMonthIntervalType.acceptsType(stepType)
case _: IntegralType =>
stepOpt.isEmpty || stepType.sameType(startType)
case _ => false
@@ -2561,29 +2571,51 @@ case class Sequence(
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
- s"$prettyName only supports integral, timestamp or date types")
+ s"""
+ |$prettyName uses the wrong parameter type. The parameter type must conform to:
+ |1. The start and stop expressions must resolve to the same type.
+ |2. If start and stop expressions resolve to the 'date' or 'timestamp' type
+ |then the step expression must resolve to the 'interval' or
+ |'${YearMonthIntervalType.typeName}' or '${DayTimeIntervalType.typeName}' type,
+ |otherwise to the same type as the start and stop expressions.
+ """.stripMargin)
}
}
- def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType)
+ private def isNotIntervalType(expr: Expression) = expr.dataType match {
+ case CalendarIntervalType | YearMonthIntervalType | DayTimeIntervalType => false
+ case _ => true
+ }
+
+ def coercibleChildren: Seq[Expression] = children.filter(isNotIntervalType)
def castChildrenTo(widerType: DataType): Expression = Sequence(
Cast(start, widerType),
Cast(stop, widerType),
- stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
+ stepOpt.map(step => if (isNotIntervalType(step)) Cast(step, widerType) else step),
timeZoneId)
- @transient private lazy val impl: SequenceImpl = dataType.elementType match {
+ @transient private lazy val impl: InternalSequence = dataType.elementType match {
case iType: IntegralType =>
type T = iType.InternalType
val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
new IntegralSequenceImpl(iType)(ct, iType.integral)
case TimestampType =>
- new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
+ if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) {
+ new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
+ } else if (YearMonthIntervalType.acceptsType(stepOpt.get.dataType)) {
+ new PeriodSequenceImpl[Long](LongType, 1, identity, zoneId)
+ } else {
+ new DurationSequenceImpl[Long](LongType, 1, identity, zoneId)
+ }
case DateType =>
- new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
+ if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) {
+ new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
+ } else {
+ new PeriodSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
+ }
}
override def eval(input: InternalRow): Any = {
@@ -2666,7 +2698,7 @@ object Sequence {
}
}
- private trait SequenceImpl {
+ private trait InternalSequence {
def eval(start: Any, stop: Any, step: Any): Any
def genCode(
@@ -2681,7 +2713,7 @@ object Sequence {
}
private class IntegralSequenceImpl[T: ClassTag]
- (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl {
+ (elemType: IntegralType)(implicit num: Integral[T]) extends InternalSequence {
override val defaultStep: DefaultStep = new DefaultStep(
(elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
@@ -2695,7 +2727,7 @@ object Sequence {
val stop = input2.asInstanceOf[T]
val step = input3.asInstanceOf[T]
- var i: Int = getSequenceLength(start, stop, step)
+ var i: Int = getSequenceLength(start, stop, step, step)
val arr = new Array[T](i)
while (i > 0) {
i -= 1
@@ -2713,7 +2745,7 @@ object Sequence {
elemType: String): String = {
val i = ctx.freshName("i")
s"""
- |${genSequenceLengthCode(ctx, start, stop, step, i)}
+ |${genSequenceLengthCode(ctx, start, stop, step, step, i)}
|$arr = new $elemType[$i];
|while ($i > 0) {
| $i--;
@@ -2723,32 +2755,105 @@ object Sequence {
}
}
+ private class PeriodSequenceImpl[T: ClassTag]
+ (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
+ (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) {
+
+ override val defaultStep: DefaultStep = new DefaultStep(
+ (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
+ YearMonthIntervalType,
+ Period.of(0, 1, 0))
+
+ val intervalType: DataType = YearMonthIntervalType
+
+ def splitStep(input: Any): (Int, Int, Long) = {
+ (input.asInstanceOf[Int], 0, 0)
+ }
+
+ def stepSplitCode(
+ stepMonths: String, stepDays: String, stepMicros: String, step: String): String = {
+ s"""
+ |final int $stepMonths = $step;
+ |final int $stepDays = 0;
+ |final long $stepMicros = 0L;
+ """.stripMargin
+ }
+ }
+
+ private class DurationSequenceImpl[T: ClassTag]
+ (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
+ (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) {
+
+ override val defaultStep: DefaultStep = new DefaultStep(
+ (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
+ DayTimeIntervalType,
+ Duration.ofDays(1))
+
+ val intervalType: DataType = DayTimeIntervalType
+
+ def splitStep(input: Any): (Int, Int, Long) = {
+ (0, 0, input.asInstanceOf[Long])
+ }
+
+ def stepSplitCode(
+ stepMonths: String, stepDays: String, stepMicros: String, step: String): String = {
+ s"""
+ |final int $stepMonths = 0;
+ |final int $stepDays = 0;
+ |final long $stepMicros = $step;
+ """.stripMargin
+ }
+ }
+
private class TemporalSequenceImpl[T: ClassTag]
(dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
- (implicit num: Integral[T]) extends SequenceImpl {
+ (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) {
override val defaultStep: DefaultStep = new DefaultStep(
(dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
CalendarIntervalType,
new CalendarInterval(0, 1, 0))
+ val intervalType: DataType = CalendarIntervalType
+
+ def splitStep(input: Any): (Int, Int, Long) = {
+ val step = input.asInstanceOf[CalendarInterval]
+ (step.months, step.days, step.microseconds)
+ }
+
+ def stepSplitCode(
+ stepMonths: String, stepDays: String, stepMicros: String, step: String): String = {
+ s"""
+ |final int $stepMonths = $step.months;
+ |final int $stepDays = $step.days;
+ |final long $stepMicros = $step.microseconds;
+ """.stripMargin
+ }
+ }
+
+ private abstract class InternalSequenceBase[T: ClassTag]
+ (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
+ (implicit num: Integral[T]) extends InternalSequence {
+
+ val defaultStep: DefaultStep
+
private val backedSequenceImpl = new IntegralSequenceImpl[T](dt)
- private val microsPerDay = HOURS_PER_DAY * MICROS_PER_HOUR
// We choose a minimum days(28) in one month to calculate the `intervalStepInMicros`
// in order to make sure the estimated array length is long enough
- private val microsPerMonth = 28 * microsPerDay
+ private val microsPerMonth = 28 * MICROS_PER_DAY
+
+ protected val intervalType: DataType
+
+ protected def splitStep(input: Any): (Int, Int, Long)
override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
val start = input1.asInstanceOf[T]
val stop = input2.asInstanceOf[T]
- val step = input3.asInstanceOf[CalendarInterval]
- val stepMonths = step.months
- val stepDays = step.days
- val stepMicros = step.microseconds
+ val (stepMonths, stepDays, stepMicros) = splitStep(input3)
if (scale == MICROS_PER_DAY && stepMonths == 0 && stepDays == 0) {
throw new IllegalArgumentException(
- "sequence step must be a day interval if start and end values are dates")
+ s"sequence step must be a day ${intervalType.typeName} if start and end values are dates")
}
if (stepMonths == 0 && stepMicros == 0 && scale == MICROS_PER_DAY) {
@@ -2763,11 +2868,12 @@ object Sequence {
// To estimate the resulted array length we need to make assumptions
// about a month length in days and a day length in microseconds
val intervalStepInMicros =
- stepMicros + stepMonths * microsPerMonth + stepDays * microsPerDay
+ stepMicros + stepMonths * microsPerMonth + stepDays * MICROS_PER_DAY
val startMicros: Long = num.toLong(start) * scale
val stopMicros: Long = num.toLong(stop) * scale
+
val maxEstimatedArrayLength =
- getSequenceLength(startMicros, stopMicros, intervalStepInMicros)
+ getSequenceLength(startMicros, stopMicros, input3, intervalStepInMicros)
val stepSign = if (stopMicros >= startMicros) +1 else -1
val exclusiveItem = stopMicros + stepSign
@@ -2787,6 +2893,9 @@ object Sequence {
}
}
+ protected def stepSplitCode(
+ stepMonths: String, stepDays: String, stepMicros: String, step: String): String
+
override def genCode(
ctx: CodegenContext,
start: String,
@@ -2811,25 +2920,27 @@ object Sequence {
val sequenceLengthCode =
s"""
|final long $intervalInMicros =
- | $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${microsPerDay}L;
- |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)}
- """.stripMargin
+ | $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${MICROS_PER_DAY}L;
+ |${genSequenceLengthCode(
+ ctx, startMicros, stopMicros, step, intervalInMicros, arrLength)}
+ """.stripMargin
val check = if (scale == MICROS_PER_DAY) {
s"""
|if ($stepMonths == 0 && $stepDays == 0) {
| throw new IllegalArgumentException(
- | "sequence step must be a day interval if start and end values are dates");
+ | "sequence step must be a day ${intervalType.typeName} " +
+ | "if start and end values are dates");
|}
- """.stripMargin
+ """.stripMargin
} else {
""
}
+ val stepSplits = stepSplitCode(stepMonths, stepDays, stepMicros, step)
+
s"""
- |final int $stepMonths = $step.months;
- |final int $stepDays = $step.days;
- |final long $stepMicros = $step.microseconds;
+ |$stepSplits
|
|$check
|
@@ -2866,15 +2977,16 @@ object Sequence {
}
}
- private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = {
+ private def getSequenceLength[U](start: U, stop: U, step: Any, estimatedStep: U)
+ (implicit num: Integral[U]): Int = {
import num._
require(
- (step > num.zero && start <= stop)
- || (step < num.zero && start >= stop)
- || (step == num.zero && start == stop),
+ (estimatedStep > num.zero && start <= stop)
+ || (estimatedStep < num.zero && start >= stop)
+ || (estimatedStep == num.zero && start == stop),
s"Illegal sequence boundaries: $start to $stop by $step")
- val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong
+ val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
require(
len <= MAX_ROUNDED_ARRAY_LENGTH,
@@ -2888,16 +3000,17 @@ object Sequence {
start: String,
stop: String,
step: String,
+ estimatedStep: String,
len: String): String = {
val longLen = ctx.freshName("longLen")
s"""
- |if (!(($step > 0 && $start <= $stop) ||
- | ($step < 0 && $start >= $stop) ||
- | ($step == 0 && $start == $stop))) {
+ |if (!(($estimatedStep > 0 && $start <= $stop) ||
+ | ($estimatedStep < 0 && $start >= $stop) ||
+ | ($estimatedStep == 0 && $start == $stop))) {
| throw new IllegalArgumentException(
| "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
|}
- |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step;
+ |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
|if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
| throw new IllegalArgumentException(
| "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
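
A usage sketch of the widened type check above (assumes a SparkSession bound to `spark`): a day-time interval step is now accepted when start and stop are timestamps, in addition to the existing calendar-interval support; the dispatch on the step type routes it to DurationSequenceImpl.

    spark.sql(
      """SELECT sequence(TIMESTAMP'2021-01-01 00:00:00',
        |                TIMESTAMP'2021-01-01 06:00:00',
        |                INTERVAL '0 03:00:00' DAY TO SECOND)""".stripMargin
    ).show(false)
    // expected: [2021-01-01 00:00:00, 2021-01-01 03:00:00, 2021-01-01 06:00:00]
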
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index e708d56cd8..3e356f1e8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern}
import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def third: Expression = falseValue
override def nullable: Boolean = trueValue.nullable || falseValue.nullable
+ final override val nodePatterns : Seq[TreePattern] = Seq(IF)
+
override def checkInputDataTypes(): TypeCheckResult = {
if (predicate.dataType != BooleanType) {
TypeCheckResult.TypeCheckFailure(
@@ -139,6 +142,8 @@ case class CaseWhen(
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
+ final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)
+
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
super.legacyWithNewChildren(newChildren)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 010b9b0fae..e69bf2e5f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -2378,15 +2378,15 @@ object DatePart {
Literal(null, DoubleType)
} else {
val fieldStr = fieldEval.asInstanceOf[UTF8String].toString
- val analysisException = QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(
- fieldStr, source)
- if (source.dataType == CalendarIntervalType) {
- ExtractIntervalPart.parseExtractField(
- fieldStr,
- source,
- throw analysisException)
- } else {
- DatePart.parseExtractField(fieldStr, source, throw analysisException)
+
+ def analysisException =
+ throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr, source)
+
+ source.dataType match {
+ case YearMonthIntervalType | DayTimeIntervalType | CalendarIntervalType =>
+ ExtractIntervalPart.parseExtractField(fieldStr, source, analysisException)
+ case _ =>
+ DatePart.parseExtractField(fieldStr, source, analysisException)
}
}
}
@@ -2414,6 +2414,10 @@ object DatePart {
5
> SELECT _FUNC_('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds);
30.001001
+ > SELECT _FUNC_('MONTH', INTERVAL '2021-11' YEAR TO MONTH);
+ 11
+ > SELECT _FUNC_('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND);
+ 55
""",
note = """
The _FUNC_ function is equivalent to the SQL-standard function `EXTRACT(field FROM source)`
@@ -2479,6 +2483,10 @@ case class DatePart(field: Expression, source: Expression, child: Expression)
5
> SELECT _FUNC_(seconds FROM interval 5 hours 30 seconds 1 milliseconds 1 microseconds);
30.001001
+ > SELECT _FUNC_(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH);
+ 11
+ > SELECT _FUNC_(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND);
+ 55
""",
note = """
The _FUNC_ function is equivalent to `date_part(field, source)`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 808c8222d2..aff1806582 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -214,9 +214,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable
Examples:
> SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height);
Alice 0 2 165.0
+ Bob 0 5 180.0
Alice 1 2 165.0
NULL 3 7 172.5
- Bob 0 5 180.0
Bob 1 5 180.0
NULL 2 2 165.0
NULL 2 5 180.0
@@ -277,22 +277,3 @@ object GroupingAnalytics {
}
}
}
-
-/**
- * A reference to an grouping expression in [[Aggregate]] node.
- *
- * @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression
- * refers to.
- * @param dataType The [[DataType]] of the referenced grouping expression.
- * @param nullable True if null is a valid value for the referenced grouping expression.
- */
-case class GroupingExprRef(
- ordinal: Int,
- dataType: DataType,
- nullable: Boolean)
- extends LeafExpression with Unevaluable {
-
- override def stringArgs: Iterator[Any] = {
- Iterator(ordinal)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index b34bcaf5ce..94ca6cc65d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -25,78 +25,131 @@ import com.google.common.math.{DoubleMath, IntMath, LongMath}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils._
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-abstract class ExtractIntervalPart(
- child: Expression,
+abstract class ExtractIntervalPart[T](
val dataType: DataType,
- func: CalendarInterval => Any,
- funcName: String)
- extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable {
-
- override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType)
-
- override protected def nullSafeEval(interval: Any): Any = {
- func(interval.asInstanceOf[CalendarInterval])
- }
-
+ func: T => Any,
+ funcName: String) extends UnaryExpression with NullIntolerant with Serializable {
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$iu.$funcName($c)")
}
+
+ override protected def nullSafeEval(interval: Any): Any = {
+ func(interval.asInstanceOf[T])
+ }
}
case class ExtractIntervalYears(child: Expression)
- extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") {
+ extends ExtractIntervalPart[CalendarInterval](IntegerType, getYears, "getYears") {
override protected def withNewChildInternal(newChild: Expression): ExtractIntervalYears =
copy(child = newChild)
}
case class ExtractIntervalMonths(child: Expression)
- extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") {
+ extends ExtractIntervalPart[CalendarInterval](ByteType, getMonths, "getMonths") {
override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMonths =
copy(child = newChild)
}
case class ExtractIntervalDays(child: Expression)
- extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") {
+ extends ExtractIntervalPart[CalendarInterval](IntegerType, getDays, "getDays") {
override protected def withNewChildInternal(newChild: Expression): ExtractIntervalDays =
copy(child = newChild)
}
case class ExtractIntervalHours(child: Expression)
- extends ExtractIntervalPart(child, LongType, getHours, "getHours") {
+ extends ExtractIntervalPart[CalendarInterval](ByteType, getHours, "getHours") {
override protected def withNewChildInternal(newChild: Expression): ExtractIntervalHours =
copy(child = newChild)
}
case class ExtractIntervalMinutes(child: Expression)
- extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") {
+ extends ExtractIntervalPart[CalendarInterval](ByteType, getMinutes, "getMinutes") {
override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMinutes =
copy(child = newChild)
}
case class ExtractIntervalSeconds(child: Expression)
- extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") {
+ extends ExtractIntervalPart[CalendarInterval](DecimalType(8, 6), getSeconds, "getSeconds") {
override protected def withNewChildInternal(newChild: Expression): ExtractIntervalSeconds =
copy(child = newChild)
}
+case class ExtractANSIIntervalYears(child: Expression)
+ extends ExtractIntervalPart[Int](IntegerType, getYears, "getYears") {
+ override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalYears =
+ copy(child = newChild)
+}
+
+case class ExtractANSIIntervalMonths(child: Expression)
+ extends ExtractIntervalPart[Int](ByteType, getMonths, "getMonths") {
+ override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalMonths =
+ copy(child = newChild)
+}
+
+case class ExtractANSIIntervalDays(child: Expression)
+ extends ExtractIntervalPart[Long](IntegerType, getDays, "getDays") {
+ override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalDays = {
+ copy(child = newChild)
+ }
+}
+
+case class ExtractANSIIntervalHours(child: Expression)
+ extends ExtractIntervalPart[Long](ByteType, getHours, "getHours") {
+ override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalHours =
+ copy(child = newChild)
+}
+
+case class ExtractANSIIntervalMinutes(child: Expression)
+ extends ExtractIntervalPart[Long](ByteType, getMinutes, "getMinutes") {
+ override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalMinutes =
+ copy(child = newChild)
+}
+
+case class ExtractANSIIntervalSeconds(child: Expression)
+ extends ExtractIntervalPart[Long](DecimalType(8, 6), getSeconds, "getSeconds") {
+ override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalSeconds =
+ copy(child = newChild)
+}
+
object ExtractIntervalPart {
def parseExtractField(
extractField: String,
source: Expression,
- errorHandleFunc: => Nothing): Expression = extractField.toUpperCase(Locale.ROOT) match {
- case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => ExtractIntervalYears(source)
- case "MONTH" | "MON" | "MONS" | "MONTHS" => ExtractIntervalMonths(source)
- case "DAY" | "D" | "DAYS" => ExtractIntervalDays(source)
- case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => ExtractIntervalHours(source)
- case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => ExtractIntervalMinutes(source)
- case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => ExtractIntervalSeconds(source)
- case _ => errorHandleFunc
+ errorHandleFunc: => Nothing): Expression = {
+ (extractField.toUpperCase(Locale.ROOT), source.dataType) match {
+ case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType) =>
+ ExtractANSIIntervalYears(source)
+ case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
+ ExtractIntervalYears(source)
+ case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType) =>
+ ExtractANSIIntervalMonths(source)
+ case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
+ ExtractIntervalMonths(source)
+ case ("DAY" | "D" | "DAYS", DayTimeIntervalType) =>
+ ExtractANSIIntervalDays(source)
+ case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
+ ExtractIntervalDays(source)
+ case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType) =>
+ ExtractANSIIntervalHours(source)
+ case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
+ ExtractIntervalHours(source)
+ case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType) =>
+ ExtractANSIIntervalMinutes(source)
+ case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
+ ExtractIntervalMinutes(source)
+ case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType) =>
+ ExtractANSIIntervalSeconds(source)
+ case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
+ ExtractIntervalSeconds(source)
+ case _ => errorHandleFunc
+ }
}
}
@@ -391,11 +444,22 @@ case class MultiplyDTInterval(
copy(interval = newLeft, num = newRight)
}
+trait IntervalDivide {
+ def checkDivideOverflow(value: Any, minValue: Any, num: Expression, numValue: Any): Unit = {
+ if (value == minValue && num.dataType.isInstanceOf[IntegralType]) {
+ if (numValue.asInstanceOf[Number].longValue() == -1) {
+ throw QueryExecutionErrors.overflowInIntegralDivideError()
+ }
+ }
+ }
+}
+
// Divide an year-month interval by a numeric
case class DivideYMInterval(
interval: Expression,
num: Expression)
- extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
+ extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide
+ with NullIntolerant with Serializable {
override def left: Expression = interval
override def right: Expression = num
@@ -418,20 +482,31 @@ case class DivideYMInterval(
}
override def nullSafeEval(interval: Any, num: Any): Any = {
+ checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num)
evalFunc(interval.asInstanceOf[Int], num)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = right.dataType match {
- case LongType =>
- val math = classOf[LongMath].getName
+ case t: IntegralType =>
+ val math = t match {
+ case LongType => classOf[LongMath].getName
+ case _ => classOf[IntMath].getName
+ }
val javaType = CodeGenerator.javaType(dataType)
- defineCodeGen(ctx, ev, (m, n) =>
+ val months = left.genCode(ctx)
+ val num = right.genCode(ctx)
+ val checkIntegralDivideOverflow =
+ s"""
+ |if (${months.value} == ${Int.MinValue} && ${num.value} == -1)
+ | throw QueryExecutionErrors.overflowInIntegralDivideError();
+ |""".stripMargin
+ nullSafeCodeGen(ctx, ev, (m, n) =>
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
// Casting to `Int` is safe here.
- s"($javaType)($math.divide($m, $n, java.math.RoundingMode.HALF_UP))")
- case _: IntegralType =>
- val math = classOf[IntMath].getName
- defineCodeGen(ctx, ev, (m, n) => s"$math.divide($m, $n, java.math.RoundingMode.HALF_UP)")
+ s"""
+ |$checkIntegralDivideOverflow
+ |${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
+ """.stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
@@ -454,7 +529,8 @@ case class DivideYMInterval(
case class DivideDTInterval(
interval: Expression,
num: Expression)
- extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
+ extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide
+ with NullIntolerant with Serializable {
override def left: Expression = interval
override def right: Expression = num
@@ -473,13 +549,25 @@ case class DivideDTInterval(
}
override def nullSafeEval(interval: Any, num: Any): Any = {
+ checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num)
evalFunc(interval.asInstanceOf[Long], num)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = right.dataType match {
case _: IntegralType =>
val math = classOf[LongMath].getName
- defineCodeGen(ctx, ev, (m, n) => s"$math.divide($m, $n, java.math.RoundingMode.HALF_UP)")
+ val micros = left.genCode(ctx)
+ val num = right.genCode(ctx)
+ val checkIntegralDivideOverflow =
+ s"""
+ |if (${micros.value} == ${Long.MinValue}L && ${num.value} == -1L)
+ | throw QueryExecutionErrors.overflowInIntegralDivideError();
+ |""".stripMargin
+ nullSafeCodeGen(ctx, ev, (m, n) =>
+ s"""
+ |$checkIntegralDivideOverflow
+ |${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
+ """.stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
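
A tiny arithmetic sketch of the overflow condition guarded by `checkDivideOverflow` above: dividing the minimum representable interval value by -1 does not fit back into the same integral type, so both the interpreted and generated paths now raise an overflow error for exactly this case.

    // Int.MinValue months / -1 exceeds Int.MaxValue, and likewise for
    // Long.MinValue microseconds, so a plain integral divide would wrap around.
    assert(BigInt(Int.MinValue) / BigInt(-1) > BigInt(Int.MaxValue))
    assert(BigInt(Long.MinValue) / BigInt(-1) > BigInt(Long.MaxValue))
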
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 2c2df6bf43..d4a02c7fc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression)
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
override def eval(input: InternalRow): Any = {
child.eval(input) == null
}
@@ -375,6 +378,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
override def eval(input: InternalRow): Any = {
child.eval(input) != null
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5ae0cef7b4..a17ac203ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
override def foldable: Boolean = false
override def nullable: Boolean = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
override def flatArguments: Iterator[Any] = Iterator(child)
private val errMsg = "Null value appeared in non-nullable field:" +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index d78c726753..4885f7761f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -309,6 +309,8 @@ case class Not(child: Expression)
override def inputTypes: Seq[DataType] = Seq(BooleanType)
+ final override val nodePatterns: Seq[TreePattern] = Seq(NOT)
+
// +---------+-----------+
// | CHILD | NOT CHILD |
// +---------+-----------+
@@ -342,6 +344,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
values.head
}
+ final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)
override def checkInputDataTypes(): TypeCheckResult = {
if (values.length != query.childOutputs.length) {
@@ -434,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
- override val nodePatterns: Seq[TreePattern] = Seq(IN)
+ final override val nodePatterns: Seq[TreePattern] = Seq(IN)
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
@@ -547,6 +550,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull
+ final override val nodePatterns: Seq[TreePattern] = Seq(INSET)
+
protected override def nullSafeEval(value: Any): Any = {
if (set.contains(value)) {
true
@@ -663,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
override def sqlOperator: String = "AND"
+ final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
// +---------+---------+---------+---------+
// | AND | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
@@ -749,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
override def sqlOperator: String = "OR"
+ final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
// +---------+---------+---------+---------+
// | OR | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
@@ -820,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
override def inputType: AbstractDataType = AnyDataType
+ final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)
+
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 13d00faea3..57d7d76268 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
+ final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
override def toString: String = escapeChar match {
case '\\' => s"$left LIKE $right"
case c => s"$left LIKE $right ESCAPE '$c'"
@@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase
override def nullable: Boolean = true
+ final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
protected lazy val hasNull: Boolean = patterns.contains(null)
protected lazy val cache = patterns.filterNot(_ == null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 3d5f812af9..5956c3e882 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -406,6 +407,8 @@ case class Upper(child: Expression)
override def convert(v: UTF8String): UTF8String = v.toUpperCase
// scalastyle:on caselocale
+ final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
}
@@ -432,6 +435,8 @@ case class Lower(child: Expression)
override def convert(v: UTF8String): UTF8String = v.toLowerCase
// scalastyle:on caselocale
+ final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 2bedf84271..ac939bf6d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, LIST_SUBQUERY,
+ PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet
@@ -38,6 +40,11 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
bits
}
+ final override val nodePatterns: Seq[TreePattern] = Seq(PLAN_EXPRESSION) ++ nodePatternsInternal
+
+ // Subclasses can override this function to provide more TreePatterns.
+ def nodePatternsInternal(): Seq[TreePattern] = Seq()
+
/** The id of the subquery expression. */
def exprId: ExprId
@@ -247,6 +254,8 @@ case class ScalarSubquery(
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
+
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
}
object ScalarSubquery {
@@ -295,6 +304,8 @@ case class ListQuery(
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
copy(children = newChildren)
+
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
}
/**
@@ -340,4 +351,6 @@ case class Exists(
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
copy(children = newChildren)
+
+ final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 916f4eae7e..fe9c41e387 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -101,7 +101,10 @@ case class WindowSpecDefinition(
private def isValidFrameType(ft: DataType): Boolean = (orderSpec.head.dataType, ft) match {
case (DateType, IntegerType) => true
+ case (DateType, YearMonthIntervalType) => true
case (TimestampType, CalendarIntervalType) => true
+ case (TimestampType, YearMonthIntervalType) => true
+ case (TimestampType, DayTimeIntervalType) => true
case (a, b) => a == b
}
}
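
A usage sketch of the added frame-type combinations (assumes a SparkSession bound to `spark`; the inline data is made up): a RANGE frame ordered by a date column can now be bounded by a year-month interval.

    // Permitted by the new (DateType, YearMonthIntervalType) case above.
    spark.sql(
      """SELECT id, count(*) OVER (
        |         ORDER BY d
        |         RANGE BETWEEN INTERVAL '0-1' YEAR TO MONTH PRECEDING AND CURRENT ROW) AS cnt
        |FROM VALUES (1, DATE'2021-01-01'), (2, DATE'2021-02-01') AS t(id, d)""".stripMargin
    ).show()
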
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 8f1548a978..0ff11ca49f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
/**
@@ -26,6 +26,15 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // One place where this optimization is invalid is an aggregation where the select
+ // list expression is a function of a grouping expression:
+ //
+ // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
+ //
+ // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
+ // optimization for Aggregates (although this misses some cases where the optimization
+ // can be made).
+ case a: Aggregate => a
case p => p.transformExpressionsUp {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 0be2792bfd..5b12667f4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -231,6 +231,27 @@ object NestedColumnAliasing {
* of it.
*/
object GeneratorNestedColumnAliasing {
+ // Partitions `attrToAliases` based on whether the attribute is in Generator's output.
+ private def aliasesOnGeneratorOutput(
+ attrToAliases: Map[ExprId, Seq[Alias]],
+ generatorOutput: Seq[Attribute]) = {
+ val generatorOutputExprId = generatorOutput.map(_.exprId)
+ attrToAliases.partition { k =>
+ generatorOutputExprId.contains(k._1)
+ }
+ }
+
+ // Partitions `nestedFieldToAlias` based on whether the attribute of the nested-field
+ // extractor is in the Generator's output.
+ private def nestedFieldOnGeneratorOutput(
+ nestedFieldToAlias: Map[ExtractValue, Alias],
+ generatorOutput: Seq[Attribute]) = {
+ val generatorOutputSet = AttributeSet(generatorOutput)
+ nestedFieldToAlias.partition { pair =>
+ pair._1.references.subsetOf(generatorOutputSet)
+ }
+ }
+
def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
// Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
// need to prune nested columns through Project and under Generate. The difference is
@@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing {
// On top on `Generate`, a `Project` that might have nested column accessors.
// We try to get alias maps for both project list and generator's children expressions.
val exprsToPrune = projectList ++ g.generator.children
- NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
+ NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
case (nestedFieldToAlias, attrToAliases) =>
+ val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
+ nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
+ val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
+ aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
+
+ // Push nested column accessors through `Generator`.
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
- val newChild =
- NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
- Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
+ val newChild = NestedColumnAliasing.replaceWithAliases(g,
+ nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
+ val pushedThrough = Project(NestedColumnAliasing
+ .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)
+
+ // If the generator output is `ArrayType`, we cannot push the extractor through.
+ // This is because we don't allow a field extractor on a two-level array,
+ // i.e., attr.field when attr is an ArrayType(ArrayType(...)).
+ // Similarly, we also cannot push through if the child of the generator is `MapType`.
+ g.generator.children.head.dataType match {
+ case _: MapType => return Some(pushedThrough)
+ case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
+ case _ =>
+ }
+
+ // Pruning on `Generator`'s output. We only process the single-field case here.
+ // For the multiple-field case, we cannot directly move the field extractor into
+ // the generator expression. A workaround is to reconstruct an array of struct
+ // from the multiple fields, but that would be more complicated and may not be worth it.
+ // TODO(SPARK-34956): support multiple fields.
+ if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
+ pushedThrough
+ } else {
+ // Only one nested column accessor.
+ // E.g., df.select(explode($"items").as("item")).select($"item.a")
+ pushedThrough match {
+ case p @ Project(_, newG: Generate) =>
+ // Replace the child expression of `ExplodeBase` generator with
+ // nested column accessor.
+ // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
+ // df.select(explode($"items.a").as("item.a"))
+ val rewrittenG = newG.transformExpressions {
+ case e: ExplodeBase =>
+ val extractor = nestedFieldsOnGenerator.head._1.transformUp {
+ case _: Attribute =>
+ e.child
+ case g: GetStructField =>
+ ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
+ }
+ e.withNewChildren(Seq(extractor))
+ }
+
+ // As we change the child of the generator, its output data type must be updated.
+ val updatedGeneratorOutput = rewrittenG.generatorOutput
+ .zip(rewrittenG.generator.elementSchema.toAttributes)
+ .map { case (oldAttr, newAttr) =>
+ newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+ }
+ assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+ "Updated generator output must have the same length " +
+                "as the original generator output.")
+ val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+ // Replace nested column accessor with generator output.
+ p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+ case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
+ updatedGenerate.output
+ .find(a => attrToAliasesOnGenerator.contains(a.exprId))
+ .getOrElse(f)
+ }
+
+ case other =>
+ // We should not reach here.
+ throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+ }
+ }
}
case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
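For readers of the rule above, a minimal DataFrame-level sketch of the single-field case it now handles. The local SparkSession, data, and column names are illustrative, not part of this patch, and nested schema pruning is assumed to be enabled (the default in recent releases).

// Hypothetical illustration: explode a whole array-of-struct column, then select a
// single nested field. With the pruning above, the optimized plan should read only
// `items._1` instead of the full struct.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode

object ExplodePruningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("explode-pruning").getOrCreate()
    import spark.implicits._

    val df = Seq((1, Seq(("a1", "b1"), ("a2", "b2")))).toDF("id", "items")
    val pruned = df.select(explode($"items").as("item")).select($"item._1")
    pruned.explain(true)
    spark.stop()
  }
}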
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5343fce07c..16e3e43356 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -118,8 +119,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
OptimizeUpdateFields,
SimplifyExtractValueOps,
OptimizeCsvJsonExprs,
- CombineConcats,
- UpdateGroupingExprRefNullability) ++
+ CombineConcats) ++
extendedOperatorOptimizationRules
val operatorOptimizationBatch: Seq[Batch] = {
@@ -148,7 +148,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView,
ReplaceExpressions,
RewriteNonCorrelatedExists,
- EnforceGroupingReferencesInAggregates,
ComputeCurrentTime,
GetCurrentDatabaseAndCatalog(catalogManager)) ::
//////////////////////////////////////////////////////////////////////////////////////////
@@ -268,9 +267,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
- ReplaceUpdateFieldsExpression.ruleName ::
- EnforceGroupingReferencesInAggregates.ruleName ::
- UpdateGroupingExprRefNullability.ruleName :: Nil
+ ReplaceUpdateFieldsExpression.ruleName :: Nil
/**
* Optimize all the subqueries inside expression.
@@ -283,7 +280,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
case other => other
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(PLAN_EXPRESSION), ruleId) {
case s: SubqueryExpression =>
val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
// At this point we have an optimized subquery plan that we are going to attach
@@ -510,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)
- val newAggregate = Aggregate.withGroupingRefs(
+ val newAggregate = upper.copy(
child = lower.child,
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
aggregateExpressions = upper.aggregateExpressions.map(
@@ -526,19 +524,23 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
}
private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
- val upperHasNoAggregateExpressions =
- !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)
+ val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
lower
.aggregateExpressions
.filter(_.deterministic)
- .filterNot(AggregateExpression.containsAggregate)
+ .filter(!isAggregate(_))
.map(_.toAttribute)
))
upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
}
+
+ private def isAggregate(expr: Expression): Boolean = {
+ expr.find(e => e.isInstanceOf[AggregateExpression] ||
+ PythonUDF.isGroupedAggPandasUDF(e)).isDefined
+ }
}
/**
@@ -1976,18 +1978,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
if (newGrouping.nonEmpty) {
- val droppedGroupsBefore =
- grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray
-
- val newAggregateExpressions =
- a.aggregateExpressions.map(_.transform {
- case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
- g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
- }.asInstanceOf[NamedExpression])
-
- a.copy(
- groupingExpressions = newGrouping,
- aggregateExpressions = newAggregateExpressions)
+ a.copy(groupingExpressions = newGrouping)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
@@ -2008,25 +1999,7 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
if (newGrouping.size == grouping.size) {
a
} else {
- var i = 0
- val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
- n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
- i += 1
- 0
- } else {
- 1
- })
- ).toArray
-
- val newAggregateExpressions =
- a.aggregateExpressions.map(_.transform {
- case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
- g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
- }.asInstanceOf[NamedExpression])
-
- a.copy(
- groupingExpressions = newGrouping,
- aggregateExpressions = newAggregateExpressions)
+ a.copy(groupingExpressions = newGrouping)
}
}
}
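As a reading aid for the RemoveRedundantAggregates change above, here is a minimal sketch of the base pattern the rule collapses: two stacked aggregates where the upper one contains no aggregate functions. The change in this patch additionally treats grouped-agg Pandas UDFs as aggregate functions when making that decision; the session setup and data below are illustrative only.

// Hypothetical illustration: distinct() plans an Aggregate, so chaining it twice
// yields Aggregate-on-Aggregate; the lower one is redundant because the upper one
// only references its deterministic, non-aggregate output.
import org.apache.spark.sql.SparkSession

object RedundantAggregateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("redundant-agg").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 1), ("b", 2)).toDF("k", "v")
    // The optimized plan should contain a single Aggregate over df.
    df.distinct().distinct().explain(true)
    spark.stop()
  }
}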
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 3c48742b87..3de19afa91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter,
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{INSET, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.Utils
@@ -51,7 +51,7 @@ import org.apache.spark.util.Utils
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
- _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL), ruleId) {
+ _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
index 8964a2776b..be39c3f10e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
@@ -49,28 +49,22 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] {
val values = withFields.map(_.valExpr)
val newNames = mutable.ArrayBuffer.empty[String]
- val newValues = mutable.ArrayBuffer.empty[Expression]
+ val newValues = mutable.HashMap.empty[String, Expression]
+    // Remember the casing of the last occurrence of each field name.
+ val nameMap = mutable.HashMap.empty[String, String]
- if (caseSensitive) {
- names.zip(values).reverse.foreach { case (name, value) =>
- if (!newNames.contains(name)) {
- newNames += name
- newValues += value
- }
- }
- } else {
- val nameSet = mutable.HashSet.empty[String]
- names.zip(values).reverse.foreach { case (name, value) =>
- val lowercaseName = name.toLowerCase(Locale.ROOT)
- if (!nameSet.contains(lowercaseName)) {
- newNames += name
- newValues += value
- nameSet += lowercaseName
- }
+ names.zip(values).foreach { case (name, value) =>
+ val normalizedName = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+ if (nameMap.contains(normalizedName)) {
+ newValues += normalizedName -> value
+ } else {
+ newNames += normalizedName
+ newValues += normalizedName -> value
}
+ nameMap += normalizedName -> name
}
- val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
+ val newWithFields = newNames.map(n => WithField(nameMap(n), newValues(n)))
UpdateFields(structExpr, newWithFields.toSeq)
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
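The new dedup logic above can be read in isolation. Below is a hedged, plain-Scala sketch of the same strategy; it is not the actual rule and operates on toy (name, value) pairs rather than WithField ops.

// Standalone sketch of the dedup strategy above: the last value wins, the casing of
// the last occurrence is kept, and names keep their first-seen order.
import java.util.Locale
import scala.collection.mutable

object WithFieldDedupSketch {
  def dedup(fields: Seq[(String, Int)], caseSensitive: Boolean): Seq[(String, Int)] = {
    val newNames = mutable.ArrayBuffer.empty[String]     // first-seen order of normalized names
    val newValues = mutable.HashMap.empty[String, Int]   // latest value per normalized name
    val nameMap = mutable.HashMap.empty[String, String]  // latest casing per normalized name
    fields.foreach { case (name, value) =>
      val normalized = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
      if (!nameMap.contains(normalized)) newNames += normalized
      newValues(normalized) = value
      nameMap(normalized) = name
    }
    newNames.map(n => nameMap(n) -> newValues(n)).toSeq
  }

  def main(args: Array[String]): Unit = {
    // "A" and "a" collide case-insensitively: one field named "a" (last casing) with
    // value 3 (last value) survives, keeping its first-seen position.
    println(dedup(Seq("A" -> 1, "b" -> 2, "a" -> 3), caseSensitive = false))
    // Expected: List((a,3), (b,2))
  }
}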
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 65466c5ec0..e9752e046a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -50,8 +51,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsDown {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(AlwaysProcess.fn, ruleId) {
+ case q: LogicalPlan => q.transformExpressionsDownWithPruning(
+ AlwaysProcess.fn, ruleId) {
// Skip redundant folding of literals. This rule is technically not necessary. Placing this
// here avoids running the next rule for Literal values, which would create a new Literal
// object and running eval unnecessarily.
@@ -83,7 +85,8 @@ object ConstantFolding extends Rule[LogicalPlan] {
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+ _.containsAllPatterns(LITERAL, FILTER), ruleId) {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
if (newCondition.isDefined) {
@@ -210,14 +213,15 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
case _ => ExpressionSet(Seq.empty)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsPattern(BINARY_ARITHMETIC), ruleId) {
case q: LogicalPlan =>
// We have to respect aggregate expressions which exists in grouping expressions when plan
// is an Aggregate operator, otherwise the optimized expression could not be derived from
// grouping expressions.
// TODO: do not reorder consecutive `Add`s or `Multiply`s with different `failOnError` flags
val groupingExpressionSet = collectGroupingExpressions(q)
- q transformExpressionsDown {
+ q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) {
case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
@@ -286,8 +290,10 @@ object OptimizeIn extends Rule[LogicalPlan] {
* 4. Removes `Not` operator.
*/
object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(AND_OR, NOT), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ _.containsAnyPattern(AND_OR, NOT), ruleId) {
case TrueLiteral And e => e
case e And TrueLiteral => e
case FalseLiteral Or e => e
@@ -460,7 +466,8 @@ object SimplifyBinaryComparison
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsPattern(BINARY_COMPARISON), ruleId) {
case l: LogicalPlan =>
lazy val notNullExpressions = ExpressionSet(l match {
case Filter(fc, _) =>
@@ -470,7 +477,7 @@ object SimplifyBinaryComparison
case _ => Seq.empty
})
- l transformExpressionsUp {
+ l.transformExpressionsUpWithPruning(_.containsPattern(BINARY_COMPARISON)) {
// True with equality
case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
case a EqualTo b if canSimplifyComparison(a, b, notNullExpressions) => TrueLiteral
@@ -496,7 +503,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(IF, CASE_WHEN), ruleId) {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
@@ -584,7 +592,7 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
true
case _: CastBase => true
case _: GetDateField | _: LastDay => true
- case _: ExtractIntervalPart => true
+ case _: ExtractIntervalPart[_] => true
case _: ArraySetLike => true
case _: ExtractValue => true
case _ => false
@@ -601,8 +609,10 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ _.containsAnyPattern(CASE_WHEN, IF), ruleId) {
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
@@ -713,7 +723,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(LIKE_FAMLIY), ruleId) {
case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
if (pattern == null) {
// If pattern is null, return null value directly, since "col like null" == null.
@@ -740,8 +751,12 @@ object NullPropagation extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+ || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+ || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) =>
@@ -917,7 +932,8 @@ object FoldablePropagation extends Rule[LogicalPlan] {
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
*/
object SimplifyCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(CAST), ruleId) {
case Cast(e, dataType, _) if e.dataType == dataType => e
case c @ Cast(e, dataType, _) => (e.dataType, dataType) match {
case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
@@ -933,7 +949,8 @@ object SimplifyCasts extends Rule[LogicalPlan] {
* Removes nodes that are not necessary.
*/
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(UNARY_POSITIVE), ruleId) {
case UnaryPositive(child) => child
}
}
@@ -944,8 +961,10 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
* the inner conversion is overwritten by the outer one.
*/
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case q: LogicalPlan => q transformExpressionsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsPattern(UPPER_OR_LOWER), ruleId) {
+ case q: LogicalPlan => q.transformExpressionsUpWithPruning(
+ _.containsPattern(UPPER_OR_LOWER), ruleId) {
case Upper(Upper(child)) => Upper(child)
case Upper(Lower(child)) => Upper(child)
case Lower(Upper(child)) => Lower(child)
@@ -986,7 +1005,8 @@ object CombineConcats extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+ _.containsPattern(CONCAT), ruleId) {
case concat: Concat if hasNestedConcats(concat) =>
flattenConcats(concat)
}
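Many rules above switch from plain transform calls to the *WithPruning variants. Here is a hedged, toy sketch of the underlying idea in plain Scala; the node classes, pattern names, and pruning helper are all made up for illustration and are not Catalyst APIs.

// Each node caches the set of patterns present in its subtree, and a transform only
// descends into subtrees that can possibly contain a match.
object PatternPruningSketch {
  sealed trait Node {
    def children: Seq[Node]
    def selfPatterns: Set[String]
    // Pattern bits of this node plus everything below it (cached per node).
    lazy val treePatterns: Set[String] = children.foldLeft(selfPatterns)(_ ++ _.treePatterns)
    def withChildren(cs: Seq[Node]): Node
    // Bottom-up transform that skips whole subtrees lacking the requested pattern.
    def transformUpWithPruning(pattern: String)(rule: PartialFunction[Node, Node]): Node = {
      if (!treePatterns.contains(pattern)) {
        this
      } else {
        val rewritten = withChildren(children.map(_.transformUpWithPruning(pattern)(rule)))
        rule.applyOrElse(rewritten, identity[Node])
      }
    }
  }
  case class Lit(v: String) extends Node {
    def children: Seq[Node] = Nil
    def selfPatterns: Set[String] = Set("LITERAL")
    def withChildren(cs: Seq[Node]): Node = this
  }
  case class Upper(child: Node) extends Node {
    def children: Seq[Node] = Seq(child)
    def selfPatterns: Set[String] = Set("UPPER_OR_LOWER")
    def withChildren(cs: Seq[Node]): Node = Upper(cs.head)
  }

  def main(args: Array[String]): Unit = {
    val tree = Upper(Upper(Lit("x")))
    // Mirrors SimplifyCaseConversionExpressions: Upper(Upper(e)) => Upper(e).
    // Subtrees without the UPPER_OR_LOWER bit are never visited.
    val simplified = tree.transformUpWithPruning("UPPER_OR_LOWER") {
      case Upper(Upper(e)) => Upper(e)
    }
    println(simplified) // Upper(Lit(x))
  }
}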
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9381796d3d..ca3aca54f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY,
+ LIST_SUBQUERY, SCALAR_SUBQUERY}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -94,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ t => t.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY) && t.containsPattern(FILTER)) {
case Filter(condition, child)
if SubqueryExpression.hasInOrCorrelatedExistsSubquery(condition) =>
val (withSubquery, withoutSubquery) =
@@ -164,7 +167,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
var newPlan = plan
val newExprs = exprs.map { e =>
- e transformDown {
+ e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan =
@@ -303,7 +306,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}
}
- plan transformExpressions {
+ plan.transformExpressionsWithPruning(_.containsAnyPattern(
+ SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, outerPlans)
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
@@ -319,7 +323,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
/**
* Pull up the correlated predicates and rewrite all subqueries in an operator tree..
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+ _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case f @ Filter(_, a: Aggregate) =>
rewriteSubQueries(f, Seq(a, a.child))
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
@@ -341,7 +346,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[ScalarSubquery]): E = {
- val newExpression = expression transform {
+ val newExpression = expression.transformWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
s.plan.output.head
@@ -628,10 +633,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
- case a @ Aggregate(grouping, _, child) =>
+ case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
- val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
- .map(extractCorrelatedScalarSubqueries(_, subqueries))
+ val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
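The pruning conditions added above only run the subquery rules on plans that actually contain the relevant subquery patterns. A small, hedged end-to-end example of the kind of plan RewritePredicateSubquery targets; the table names and contents are illustrative.

// The Filter over t1 carries a correlated EXISTS subquery; the rewrite turns it into
// a LEFT SEMI join, visible in the optimized plan printed by explain(true).
import org.apache.spark.sql.SparkSession

object ExistsRewriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("exists-rewrite").getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (2, "b")).toDF("id", "name").createOrReplaceTempView("t1")
    Seq(1, 3).toDF("id").createOrReplaceTempView("t2")

    spark.sql("SELECT * FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.id = t1.id)").explain(true)
    spark.stop()
  }
}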
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8c8e8c6a8a..06bbb984d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -993,26 +993,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
.map(groupByExpr => {
val groupingAnalytics = groupByExpr.groupingAnalytics
if (groupingAnalytics != null) {
- val groupingSets = groupingAnalytics.groupingSet.asScala
- .map(_.expression.asScala.map(e => expression(e)).toSeq)
- if (groupingAnalytics.CUBE != null) {
- // CUBE(A, B, (A, B), ()) is not supported.
- if (groupingSets.exists(_.isEmpty)) {
- throw new ParseException("Empty set in CUBE grouping sets is not supported.",
- groupingAnalytics)
- }
- Cube(groupingSets.toSeq)
- } else if (groupingAnalytics.ROLLUP != null) {
- // ROLLUP(A, B, (A, B), ()) is not supported.
- if (groupingSets.exists(_.isEmpty)) {
- throw new ParseException("Empty set in ROLLUP grouping sets is not supported.",
- groupingAnalytics)
- }
- Rollup(groupingSets.toSeq)
- } else {
- assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null)
- GroupingSets(groupingSets.toSeq)
- }
+ visitGroupingAnalytics(groupingAnalytics)
} else {
expression(groupByExpr.expression)
}
@@ -1021,6 +1002,36 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
}
}
+ override def visitGroupingAnalytics(
+ groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = {
+ val groupingSets = groupingAnalytics.groupingSet.asScala
+ .map(_.expression.asScala.map(e => expression(e)).toSeq)
+ if (groupingAnalytics.CUBE != null) {
+ // CUBE(A, B, (A, B), ()) is not supported.
+ if (groupingSets.exists(_.isEmpty)) {
+ throw QueryParsingErrors.invalidGroupingSetError("CUBE", groupingAnalytics)
+ }
+ Cube(groupingSets.toSeq)
+ } else if (groupingAnalytics.ROLLUP != null) {
+ // ROLLUP(A, B, (A, B), ()) is not supported.
+ if (groupingSets.exists(_.isEmpty)) {
+ throw QueryParsingErrors.invalidGroupingSetError("ROLLUP", groupingAnalytics)
+ }
+ Rollup(groupingSets.toSeq)
+ } else {
+ assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null)
+ val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr =>
+ val groupingAnalytics = expr.groupingAnalytics()
+ if (groupingAnalytics != null) {
+ visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs
+ } else {
+ Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq)
+ }
+ }
+ GroupingSets(groupingSets.toSeq)
+ }
+ }
+
/**
* Add [[UnresolvedHint]]s to a logical plan.
*/
@@ -2395,13 +2406,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
*/
override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = {
withOrigin(ctx) {
- val value = Option(ctx.intervalValue.STRING).map(string).getOrElse {
+ val value = Option(ctx.intervalValue.STRING).map(string).map { interval =>
+ if (ctx.intervalValue().MINUS() == null) {
+ interval
+ } else {
+          if (interval.startsWith("-")) {
+            interval.replaceFirst("-", "")
+          } else {
+            s"-$interval"
+          }
+ }
+ }.getOrElse {
throw QueryParsingErrors.invalidFromToUnitValueError(ctx.intervalValue)
}
try {
val from = ctx.from.getText.toLowerCase(Locale.ROOT)
val to = ctx.to.getText.toLowerCase(Locale.ROOT)
- val interval = (from, to) match {
+ (from, to) match {
case ("year", "month") =>
IntervalUtils.fromYearMonthString(value)
case ("day", "hour") =>
@@ -2419,9 +2439,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
case _ =>
throw QueryParsingErrors.fromToIntervalUnsupportedError(from, to, ctx)
}
- Option(ctx.intervalValue.MINUS)
- .map(_ => IntervalUtils.negateExact(interval))
- .getOrElse(interval)
} catch {
// Handle Exceptions thrown by CalendarInterval
case e: IllegalArgumentException =>
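The change above folds a leading MINUS token into the interval value string before parsing, rather than negating the parsed interval afterwards. A hedged, standalone sketch of that sign handling in plain Scala (not the parser itself):

// When the interval value carries a leading MINUS token, the sign is folded into the
// string; an already-negative string and the MINUS token cancel out.
object IntervalSignSketch {
  def applySign(interval: String, hasMinusToken: Boolean): String = {
    if (!hasMinusToken) interval
    else if (interval.startsWith("-")) interval.replaceFirst("-", "")
    else s"-$interval"
  }

  def main(args: Array[String]): Unit = {
    println(applySign("10-2", hasMinusToken = true))   // -10-2
    println(applySign("-10-2", hasMinusToken = true))  // 10-2  (double negation cancels)
    println(applySign("10-2", hasMinusToken = false))  // 10-2
  }
}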
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a96674fe97..c22a874779 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -287,7 +287,7 @@ object PhysicalAggregation {
(Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
def unapply(a: Any): Option[ReturnType] = a match {
- case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of semantically distinct aggregate expressions and re-write expressions so
@@ -297,9 +297,11 @@ object PhysicalAggregation {
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
// addExpr() always returns false for non-deterministic expressions and do not add them.
- case a
- if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
- a
+ case agg: AggregateExpression
+ if !equivalentAggregateExpressions.addExpr(agg) => agg
+ case udf: PythonUDF
+ if PythonUDF.isGroupedAggPandasUDF(udf) &&
+ !equivalentAggregateExpressions.addExpr(udf) => udf
}
}
@@ -320,7 +322,7 @@ object PhysicalAggregation {
// which takes the grouping columns and final aggregate result buffer as input.
// Thus, we must re-write the result expressions so that their attributes match up with
// the attributes of the final result projection's input row:
- val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transformDown {
case ae: AggregateExpression =>
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 49e3e3c9ee..0f5bc7e1f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.plans.logical
-import scala.collection.mutable
-
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRelation, TypeCoercion, TypeCoercionBase}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
@@ -28,9 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
- INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -166,6 +162,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
override def maxRows: Option[Long] = child.maxRows
+ final override val nodePatterns: Seq[TreePattern] = Seq(FILTER)
+
override protected lazy val validConstraints: ExpressionSet = {
val predicates = splitConjunctivePredicates(condition)
.filterNot(SubqueryExpression.hasCorrelatedSubquery)
@@ -781,23 +779,14 @@ case class Range(
/**
* This is a Group by operator with the aggregate functions and projections.
*
- * @param groupingExpressions Expressions for grouping keys.
- * @param aggregateExpressions Expressions for a project list, which can contain
- * [[AggregateExpression]]s and [[GroupingExprRef]]s.
- * @param child The child of the aggregate node.
+ * @param groupingExpressions expressions for grouping keys
+ * @param aggregateExpressions expressions for a project list, which could contain
+ *                             [[AggregateExpression]]s.
+ * @param child the child of the aggregate node
*
- * Expressions without aggregate functions in [[aggregateExpressions]] can contain
- * [[GroupingExprRef]]s to refer to complex grouping expressions in [[groupingExpressions]]. These
- * references ensure that optimization rules don't change the aggregate expressions to invalid ones
- * that no longer refer to any grouping expressions and also simplify the expression transformations
- * on the node (need to transform the expression only once).
- *
- * For example, in the following query Spark shouldn't optimize the aggregate expression
- * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`:
- * SELECT not(c IS NULL)
- * FROM t
- * GROUP BY c IS NULL
- * Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`.
+ * Note: Currently, aggregateExpressions also serves as the project list of this Group by
+ * operator. Until projection is separated from grouping and aggregation, we should avoid
+ * expression-level optimization on aggregateExpressions, which may reference an expression in
+ * groupingExpressions. For example, see the rule
+ * [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]].
*/
case class Aggregate(
groupingExpressions: Seq[Expression],
@@ -824,21 +813,8 @@ case class Aggregate(
}
}
- private def expandGroupingReferences(e: Expression): Expression = {
- e match {
- case _ if AggregateExpression.isAggregate(e) => e
- case g: GroupingExprRef => groupingExpressions(g.ordinal)
- case _ => e.mapChildren(expandGroupingReferences)
- }
- }
-
- lazy val aggregateExpressionsWithoutGroupingRefs = {
- aggregateExpressions.map(expandGroupingReferences(_).asInstanceOf[NamedExpression])
- }
-
override lazy val validConstraints: ExpressionSet = {
- val nonAgg = aggregateExpressionsWithoutGroupingRefs.
- filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
+ val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
getAllValidConstraints(nonAgg)
}
@@ -846,51 +822,6 @@ case class Aggregate(
copy(child = newChild)
}
-object Aggregate {
- private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = {
- val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)]
- var i = 0
- groupingExpressions.foreach { ge =>
- if (!ge.foldable && ge.children.nonEmpty &&
- !complexGroupingExpressions.contains(ge.canonicalized)) {
- complexGroupingExpressions += ge.canonicalized -> (ge, i)
- }
- i += 1
- }
- complexGroupingExpressions
- }
-
- private def insertGroupingReferences(
- aggregateExpressions: Seq[NamedExpression],
- groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = {
- def insertGroupingExprRefs(e: Expression): Expression = {
- e match {
- case _ if AggregateExpression.isAggregate(e) => e
- case _ if groupingExpressions.contains(e.canonicalized) =>
- val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized)
- GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable)
- case _ => e.mapChildren(insertGroupingExprRefs)
- }
- }
-
- aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression])
- }
-
- def withGroupingRefs(
- groupingExpressions: Seq[Expression],
- aggregateExpressions: Seq[NamedExpression],
- child: LogicalPlan): Aggregate = {
- val complexGroupingExpressions = collectComplexGroupingExpressions(groupingExpressions)
- val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) {
- insertGroupingReferences(aggregateExpressions, complexGroupingExpressions)
- } else {
- aggregateExpressions
- }
-
- new Aggregate(groupingExpressions, aggrExprWithGroupingReferences, child)
- }
-}
-
case class Window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
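The note in the updated Aggregate scaladoc above is easiest to see on a concrete query. The illustration below is mine, not from this patch: a projected expression that is only valid while it syntactically matches a grouping expression.

// named_struct('f', a).f resolves because its child named_struct('f', a) is a grouping
// expression. Simplifying it to a bare `a` underneath the Aggregate would leave a
// reference to a non-grouping attribute, which is why expression-level optimization on
// aggregateExpressions is avoided.
import org.apache.spark.sql.SparkSession

object GroupingExpressionNoteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("grouping-note").getOrCreate()
    import spark.implicits._

    Seq(1, 2, 2).toDF("a").createOrReplaceTempView("t")
    spark.sql("SELECT named_struct('f', a).f FROM t GROUP BY named_struct('f', a)").explain(true)
    spark.stop()
  }
}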
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index d745f50604..1c997d3740 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -42,17 +42,33 @@ object RuleIdCollection {
// Catalyst Analyzer rules
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" ::
+ "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
+ "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
// Catalyst Optimizer rules
+ "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+ "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+ "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
+ "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
+ "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
+ "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
+ "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
+ "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
+ "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
+ "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
"org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
- "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: Nil
+ "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
+ "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" :: Nil
}
// Maps rule names to ids. Rule ids are continuous natural numbers starting from 0.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index bb09b9ddda..faf736d9c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -23,15 +23,36 @@ object TreePattern extends Enumeration {
// Enum Ids start from 0.
// Expression patterns (alphabetically ordered)
- val ATTRIBUTE_REFERENCE = Value(0)
- val EXPRESSION_WITH_RANDOM_SEED = Value
+ val AND_OR: Value = Value(0)
+ val ATTRIBUTE_REFERENCE: Value = Value
+ val BINARY_ARITHMETIC: Value = Value
+ val BINARY_COMPARISON: Value = Value
+ val CASE_WHEN: Value = Value
+ val CAST: Value = Value
+ val CONCAT: Value = Value
+ val COUNT: Value = Value
+ val DYNAMIC_PRUNING_SUBQUERY: Value = Value
+  val EXISTS_SUBQUERY: Value = Value
+ val EXPRESSION_WITH_RANDOM_SEED: Value = Value
+ val IF: Value = Value
val IN: Value = Value
+ val IN_SUBQUERY: Value = Value
+ val INSET: Value = Value
+ val LIKE_FAMLIY: Value = Value
+ val LIST_SUBQUERY: Value = Value
val LITERAL: Value = Value
+ val NOT: Value = Value
+ val NULL_CHECK: Value = Value
val NULL_LITERAL: Value = Value
+ val PLAN_EXPRESSION: Value = Value
+ val SCALAR_SUBQUERY: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
+ val UNARY_POSITIVE: Value = Value
+ val UPPER_OR_LOWER: Value = Value
// Logical plan patterns (alphabetically ordered)
+ val FILTER: Value = Value
val INNER_LIKE_JOIN: Value = Value
val JOIN: Value = Value
val LEFT_SEMI_OR_ANTI_JOIN: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index e52d3c8817..c9bc579ceb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -54,31 +54,39 @@ object IntervalUtils {
}
import IntervalUnit._
- def getYears(interval: CalendarInterval): Int = {
- interval.months / MONTHS_PER_YEAR
- }
+ def getYears(months: Int): Int = months / MONTHS_PER_YEAR
- def getMonths(interval: CalendarInterval): Byte = {
- (interval.months % MONTHS_PER_YEAR).toByte
- }
+ def getYears(interval: CalendarInterval): Int = getYears(interval.months)
+
+ def getMonths(months: Int): Byte = (months % MONTHS_PER_YEAR).toByte
+
+ def getMonths(interval: CalendarInterval): Byte = getMonths(interval.months)
+
+ def getDays(microseconds: Long): Int = (microseconds / MICROS_PER_DAY).toInt
def getDays(interval: CalendarInterval): Int = {
- val daysInMicroseconds = (interval.microseconds / MICROS_PER_DAY).toInt
+ val daysInMicroseconds = getDays(interval.microseconds)
Math.addExact(interval.days, daysInMicroseconds)
}
- def getHours(interval: CalendarInterval): Long = {
- (interval.microseconds % MICROS_PER_DAY) / MICROS_PER_HOUR
+ def getHours(microseconds: Long): Byte = {
+ ((microseconds % MICROS_PER_DAY) / MICROS_PER_HOUR).toByte
}
- def getMinutes(interval: CalendarInterval): Byte = {
- ((interval.microseconds % MICROS_PER_HOUR) / MICROS_PER_MINUTE).toByte
+ def getHours(interval: CalendarInterval): Byte = getHours(interval.microseconds)
+
+ def getMinutes(microseconds: Long): Byte = {
+ ((microseconds % MICROS_PER_HOUR) / MICROS_PER_MINUTE).toByte
}
- def getSeconds(interval: CalendarInterval): Decimal = {
- Decimal(interval.microseconds % MICROS_PER_MINUTE, 8, 6)
+ def getMinutes(interval: CalendarInterval): Byte = getMinutes(interval.microseconds)
+
+ def getSeconds(microseconds: Long): Decimal = {
+ Decimal(microseconds % MICROS_PER_MINUTE, 8, 6)
}
+ def getSeconds(interval: CalendarInterval): Decimal = getSeconds(interval.microseconds)
+
private def toLongWithRange(
fieldName: IntervalUnit,
s: String,
@@ -100,12 +108,11 @@ object IntervalUtils {
*/
def fromYearMonthString(input: String): CalendarInterval = {
require(input != null, "Interval year-month string must be not null")
- def toInterval(yearStr: String, monthStr: String): CalendarInterval = {
+ def toInterval(yearStr: String, monthStr: String, sign: Int): CalendarInterval = {
try {
- val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE).toInt
- val months = toLongWithRange(MONTH, monthStr, 0, 11).toInt
- val totalMonths = Math.addExact(Math.multiplyExact(years, 12), months)
- new CalendarInterval(totalMonths, 0, 0)
+ val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR)
+ val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11))
+ new CalendarInterval(Math.toIntExact(totalMonths), 0, 0)
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(
@@ -114,9 +121,9 @@ object IntervalUtils {
}
input.trim match {
case yearMonthPattern("-", yearStr, monthStr) =>
- negateExact(toInterval(yearStr, monthStr))
+ toInterval(yearStr, monthStr, -1)
case yearMonthPattern(_, yearStr, monthStr) =>
- toInterval(yearStr, monthStr)
+ toInterval(yearStr, monthStr, 1)
case _ =>
throw new IllegalArgumentException(
s"Interval string does not match year-month format of 'y-m': $input")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index a3fbe4c742..446486bdf1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1352,4 +1352,9 @@ private[spark] object QueryCompilationErrors {
s"Expected udfs have the same evalType but got different evalTypes: " +
s"${evalTypes.mkString(",")}")
}
+
+ def ambiguousFieldNameError(fieldName: String, names: String): Throwable = {
+ new AnalysisException(
+ s"Ambiguous field name: $fieldName. Found multiple columns that can match: $names")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index eb7b7b4ff6..3589c875fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -309,7 +309,7 @@ object QueryExecutionErrors {
new IllegalStateException("table stats must be specified.")
}
- def unaryMinusCauseOverflowError(originValue: Short): ArithmeticException = {
+ def unaryMinusCauseOverflowError(originValue: AnyVal): ArithmeticException = {
new ArithmeticException(s"- $originValue caused overflow.")
}
@@ -772,4 +772,55 @@ object QueryExecutionErrors {
new IllegalArgumentException(s"Unexpected: $o")
}
+ def unscaledValueTooLargeForPrecisionError(): Throwable = {
+ new ArithmeticException("Unscaled value too large for precision")
+ }
+
+ def decimalPrecisionExceedsMaxPrecisionError(precision: Int, maxPrecision: Int): Throwable = {
+ new ArithmeticException(
+ s"Decimal precision $precision exceeds max precision $maxPrecision")
+ }
+
+ def outOfDecimalTypeRangeError(str: UTF8String): Throwable = {
+ new ArithmeticException(s"out of decimal type range: $str")
+ }
+
+ def unsupportedArrayTypeError(clazz: Class[_]): Throwable = {
+ new RuntimeException(s"Do not support array of type $clazz.")
+ }
+
+ def unsupportedJavaTypeError(clazz: Class[_]): Throwable = {
+ new RuntimeException(s"Do not support type $clazz.")
+ }
+
+ def failedParsingStructTypeError(raw: String): Throwable = {
+ new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw")
+ }
+
+ def failedMergingFieldsError(leftName: String, rightName: String, e: Throwable): Throwable = {
+ new SparkException(s"Failed to merge fields '$leftName' and '$rightName'. ${e.getMessage}")
+ }
+
+ def cannotMergeDecimalTypesWithIncompatiblePrecisionAndScaleError(
+ leftPrecision: Int, rightPrecision: Int, leftScale: Int, rightScale: Int): Throwable = {
+ new SparkException("Failed to merge decimal types with incompatible " +
+ s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale")
+ }
+
+ def cannotMergeDecimalTypesWithIncompatiblePrecisionError(
+ leftPrecision: Int, rightPrecision: Int): Throwable = {
+ new SparkException("Failed to merge decimal types with incompatible " +
+ s"precision $leftPrecision and $rightPrecision")
+ }
+
+ def cannotMergeDecimalTypesWithIncompatibleScaleError(
+ leftScale: Int, rightScale: Int): Throwable = {
+ new SparkException("Failed to merge decimal types with incompatible " +
+      s"scale $leftScale and $rightScale")
+ }
+
+ def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = {
+ new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" +
+ s" and ${right.catalogString}")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index d97b19954f..b714f57875 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -367,4 +367,7 @@ object QueryParsingErrors {
new ParseException("LOCAL is supported only with file: scheme", ctx)
}
+ def invalidGroupingSetError(element: String, ctx: GroupingAnalyticsContext): Throwable = {
+ new ParseException(s"Empty set in $element grouping sets is not supported.", ctx)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 04e740039f..9d09715d25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3150,6 +3150,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val MAX_CONCURRENT_OUTPUT_FILE_WRITERS = buildConf("spark.sql.maxConcurrentOutputFileWriters")
+ .internal()
+    .doc("Maximum number of output file writers to use concurrently. If the number of writers " +
+      "needed reaches this limit, the task will sort the rest of the output and then write it.")
+ .version("3.2.0")
+ .intConf
+ .createWithDefault(0)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -3839,6 +3847,8 @@ class SQLConf extends Serializable with Logging {
def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED)
+ def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
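A hedged usage sketch for the new conf above. The write-path behavior it controls is implemented outside this section, so this only shows setting it; the output path and data are illustrative.

// Cap how many partition writers a write task keeps open at once; 0 (the default)
// disables the limit. Once the limit is hit, the remaining rows are sorted by
// partition value and written afterwards.
import org.apache.spark.sql.SparkSession

object ConcurrentWritersExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("concurrent-writers").getOrCreate()
    import spark.implicits._

    spark.conf.set("spark.sql.maxConcurrentOutputFileWriters", "4")

    Seq(("a", 1), ("b", 2), ("c", 3)).toDF("k", "v")
      .write.partitionBy("k").mode("overwrite").parquet("/tmp/concurrent_writers_example")
    spark.stop()
  }
}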
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 960e174f9c..d9f457f153 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -22,6 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin
import scala.util.Try
import org.apache.spark.annotation.Unstable
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.unsafe.types.UTF8String
@@ -80,7 +81,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
if (setOrNull(unscaled, precision, scale) == null) {
- throw new ArithmeticException("Unscaled value too large for precision")
+ throw QueryExecutionErrors.unscaledValueTooLargeForPrecisionError()
}
this
}
@@ -118,8 +119,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
DecimalType.checkNegativeScale(scale)
this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
if (decimalVal.precision > precision) {
- throw new ArithmeticException(
- s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
+ throw QueryExecutionErrors.decimalPrecisionExceedsMaxPrecisionError(
+ decimalVal.precision, precision)
}
this.longVal = 0L
this._precision = precision
@@ -251,7 +252,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toByte: Byte = toLong.toByte
private def overflowException(dataType: String) =
- throw new ArithmeticException(s"Casting $this to $dataType causes overflow")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, dataType)
/**
* @return the Byte value that is equal to the rounded decimal.
@@ -263,14 +264,14 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (actualLongVal == actualLongVal.toByte) {
actualLongVal.toByte
} else {
- overflowException("byte")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "byte")
}
} else {
val doubleVal = decimalVal.toDouble
if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= Byte.MinValue) {
doubleVal.toByte
} else {
- overflowException("byte")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "byte")
}
}
}
@@ -285,14 +286,14 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (actualLongVal == actualLongVal.toShort) {
actualLongVal.toShort
} else {
- overflowException("short")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "short")
}
} else {
val doubleVal = decimalVal.toDouble
if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= Short.MinValue) {
doubleVal.toShort
} else {
- overflowException("short")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "short")
}
}
}
@@ -307,14 +308,14 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (actualLongVal == actualLongVal.toInt) {
actualLongVal.toInt
} else {
- overflowException("int")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "int")
}
} else {
val doubleVal = decimalVal.toDouble
if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= Int.MinValue) {
doubleVal.toInt
} else {
- overflowException("int")
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "int")
}
}
}
@@ -333,7 +334,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
// `longValueExact` to make sure the range check is accurate.
decimalVal.bigDecimal.toBigInteger.longValueExact()
} catch {
- case _: ArithmeticException => overflowException("long")
+ case _: ArithmeticException =>
+ throw QueryExecutionErrors.castingCauseOverflowError(this, "long")
}
}
}
@@ -365,8 +367,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (nullOnOverflow) {
null
} else {
- throw new ArithmeticException(
- s"$toDebugString cannot be represented as Decimal($precision, $scale).")
+ throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(this, precision, scale)
}
}
}
@@ -622,13 +623,13 @@ object Decimal {
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
// For example: Decimal("6.0790316E+25569151")
if (calculatePrecision(bigDecimal) > DecimalType.MAX_PRECISION) {
- throw new ArithmeticException(s"out of decimal type range: $str")
+ throw QueryExecutionErrors.outOfDecimalTypeRangeError(str)
} else {
Decimal(bigDecimal)
}
} catch {
case _: NumberFormatException =>
- throw new NumberFormatException(s"invalid input syntax for type numeric: $str")
+ throw QueryExecutionErrors.invalidInputSyntaxForNumericError(str)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index bedf6ccf44..3e05eda344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -23,6 +23,7 @@ import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.errors.QueryExecutionErrors
/**
@@ -162,13 +163,13 @@ object Metadata {
builder.putMetadataArray(
key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray)
case other =>
- throw new RuntimeException(s"Do not support array of type ${other.getClass}.")
+ throw QueryExecutionErrors.unsupportedArrayTypeError(other.getClass)
}
}
case (key, JNull) =>
builder.putNull(key)
case (key, other) =>
- throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ throw QueryExecutionErrors.unsupportedJavaTypeError(other.getClass)
}
builder.build()
}
@@ -195,7 +196,7 @@ object Metadata {
case x: Metadata =>
toJsonValue(x.map)
case other =>
- throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ throw QueryExecutionErrors.unsupportedJavaTypeError(other.getClass)
}
}
@@ -222,7 +223,7 @@ object Metadata {
case null =>
0
case other =>
- throw new RuntimeException(s"Do not support type ${other.getClass}.")
+ throw QueryExecutionErrors.unsupportedJavaTypeError(other.getClass)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index a223344e92..8ff0536c2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -23,14 +23,13 @@ import scala.util.control.NonFatal
import org.json4s.JsonDSL._
-import org.apache.spark.SparkException
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
/**
@@ -333,9 +332,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
if (found.length > 1) {
val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
.mkString("[", ", ", " ]")
- throw new AnalysisException(
- s"Ambiguous field name: ${prettyFieldName(normalizedPath :+ searchName)}. Found " +
- s"multiple columns that can match: $names")
+ throw QueryCompilationErrors.ambiguousFieldNameError(
+ prettyFieldName(normalizedPath :+ searchName), names)
} else if (found.isEmpty) {
None
} else {
@@ -523,7 +521,7 @@ object StructType extends AbstractDataType {
private[sql] def fromString(raw: String): StructType = {
Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match {
case t: StructType => t
- case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw")
+ case _ => throw QueryExecutionErrors.failedParsingStructTypeError(raw)
}
}
@@ -586,8 +584,7 @@ object StructType extends AbstractDataType {
nullable = leftNullable || rightNullable)
} catch {
case NonFatal(e) =>
- throw new SparkException(s"Failed to merge fields '$leftName' and " +
- s"'$rightName'. " + e.getMessage)
+ throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
}
}
.orElse {
@@ -610,14 +607,14 @@ object StructType extends AbstractDataType {
if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) {
DecimalType(leftPrecision, leftScale)
} else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) {
- throw new SparkException("Failed to merge decimal types with incompatible " +
- s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale")
+ throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatiblePrecisionAndScaleError(
+ leftPrecision, rightPrecision, leftScale, rightScale)
} else if (leftPrecision != rightPrecision) {
- throw new SparkException("Failed to merge decimal types with incompatible " +
- s"precision $leftPrecision and $rightPrecision")
+ throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatiblePrecisionError(
+ leftPrecision, rightPrecision)
} else {
- throw new SparkException("Failed to merge decimal types with incompatible " +
- s"scala $leftScale and $rightScale")
+ throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(
+ leftScale, rightScale)
}
case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
@@ -627,8 +624,7 @@ object StructType extends AbstractDataType {
leftType
case _ =>
- throw new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" +
- s" and ${right.catalogString}")
+ throw QueryExecutionErrors.cannotMergeIncompatibleDataTypesError(left, right)
}
private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index 7026ff7de2..a3e76797b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -21,12 +21,13 @@ import scala.math.Numeric._
import scala.math.Ordering
import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = {
if (res > Byte.MaxValue || res < Byte.MinValue) {
- throw new ArithmeticException(s"$x $op $y caused overflow.")
+ throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y)
}
}
@@ -50,7 +51,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr
override def negate(x: Byte): Byte = {
if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
- throw new ArithmeticException(s"- $x caused overflow.")
+ throw QueryExecutionErrors.unaryMinusCauseOverflowError(x)
}
(-x).toByte
}
@@ -60,7 +61,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr
private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering {
private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = {
if (res > Short.MaxValue || res < Short.MinValue) {
- throw new ArithmeticException(s"$x $op $y caused overflow.")
+ throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y)
}
}
@@ -84,7 +85,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor
override def negate(x: Short): Short = {
if (x == Short.MinValue) { // if and only if x is Short.MinValue, overflow can happen
- throw new ArithmeticException(s"- $x caused overflow.")
+ throw QueryExecutionErrors.unaryMinusCauseOverflowError(x)
}
(-x).toShort
}
@@ -114,14 +115,11 @@ private[sql] object LongExactNumeric extends LongIsIntegral with Ordering.LongOr
if (x == x.toInt) {
x.toInt
} else {
- throw new ArithmeticException(s"Casting $x to int causes overflow")
+ throw QueryExecutionErrors.castingCauseOverflowError(x, "int")
}
}
private[sql] object FloatExactNumeric extends FloatIsFractional {
- private def overflowException(x: Float, dataType: String) =
- throw new ArithmeticException(s"Casting $x to $dataType causes overflow")
-
private val intUpperBound = Int.MaxValue
private val intLowerBound = Int.MinValue
private val longUpperBound = Long.MaxValue
@@ -137,7 +135,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional {
if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) {
x.toInt
} else {
- overflowException(x, "int")
+ throw QueryExecutionErrors.castingCauseOverflowError(x, "int")
}
}
@@ -145,7 +143,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional {
if (Math.floor(x) <= longUpperBound && Math.ceil(x) >= longLowerBound) {
x.toLong
} else {
- overflowException(x, "int")
+ throw QueryExecutionErrors.castingCauseOverflowError(x, "int")
}
}
@@ -153,9 +151,6 @@ private[sql] object FloatExactNumeric extends FloatIsFractional {
}
private[sql] object DoubleExactNumeric extends DoubleIsFractional {
- private def overflowException(x: Double, dataType: String) =
- throw new ArithmeticException(s"Casting $x to $dataType causes overflow")
-
private val intUpperBound = Int.MaxValue
private val intLowerBound = Int.MinValue
private val longUpperBound = Long.MaxValue
@@ -165,7 +160,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional {
if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) {
x.toInt
} else {
- overflowException(x, "int")
+ throw QueryExecutionErrors.castingCauseOverflowError(x, "int")
}
}
@@ -173,7 +168,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional {
if (Math.floor(x) <= longUpperBound && Math.ceil(x) >= longLowerBound) {
x.toLong
} else {
- overflowException(x, "long")
+ throw QueryExecutionErrors.castingCauseOverflowError(x, "long")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 5d5da795a5..ce8acd1825 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.MapVector
-import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.sql.internal.SQLConf
@@ -54,6 +54,8 @@ private[sql] object ArrowUtils {
new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
}
case NullType => ArrowType.Null.INSTANCE
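+ // ANSI interval types are mapped to Arrow's Interval type with the matching unit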
+ case YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
+ case DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME)
case _ =>
throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
}
@@ -74,6 +76,8 @@ private[sql] object ArrowUtils {
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case ArrowType.Null.INSTANCE => NullType
+ case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType
+ case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 0dbae707a4..169c5d6a31 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{DateTimeConstants, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -197,8 +197,8 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
"1970-01-01",
"1972-12-31",
"2019-02-16",
- "2119-03-16").foreach { timestamp =>
- val input = LocalDate.parse(timestamp)
+ "2119-03-16").foreach { date =>
+ val input = LocalDate.parse(date)
val result = CatalystTypeConverters.convertToCatalyst(input)
val expected = DateTimeUtils.localDateToDays(input)
assert(result === expected)
@@ -294,4 +294,44 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
}
}
}
+
+ test("SPARK-35204: createToCatalystConverter for date") {
+ Seq(true, false).foreach { enable =>
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> enable.toString) {
+ Seq(-1234, 0, 1234).foreach { days =>
+ val converter = CatalystTypeConverters.createToCatalystConverter(DateType)
+
+ val ld = LocalDate.ofEpochDay(days)
+ val result1 = converter(ld)
+
+ val d = java.sql.Date.valueOf(ld)
+ val result2 = converter(d)
+
+ val expected = DateTimeUtils.localDateToDays(ld)
+ assert(result1 === expected)
+ assert(result2 === expected)
+ }
+ }
+ }
+ }
+
+ test("SPARK-35204: createToCatalystConverter for timestamp") {
+ Seq(true, false).foreach { enable =>
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> enable.toString) {
+ Seq(-1234, 0, 1234).foreach { seconds =>
+ val converter = CatalystTypeConverters.createToCatalystConverter(TimestampType)
+
+ val i = Instant.ofEpochSecond(seconds)
+ val result1 = converter(i)
+
+ val t = new java.sql.Timestamp(seconds * DateTimeConstants.MILLIS_PER_SECOND)
+ val result2 = converter(t)
+
+ val expected = seconds * DateTimeConstants.MICROS_PER_SECOND
+ assert(result1 === expected)
+ assert(result2 === expected)
+ }
+ }
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 4ac7823502..dc9f92d7c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.connector.InMemoryTable
+import org.apache.spark.sql.connector.catalog.InMemoryTable
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
index 1c849fa21e..f7e57e3b27 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode}
-import org.apache.spark.sql.connector.InMemoryTableCatalog
-import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog}
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala
index ec9480514b..7d6ad3bc60 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala
@@ -29,8 +29,7 @@ import org.scalatest.matchers.must.Matchers
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, ExternalCatalog, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog}
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, InMemoryTable, InMemoryTableCatalog, Table}
import org.apache.spark.sql.types._
class TableLookupCacheSuite extends AnalysisTest with Matchers {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 095894b9ff..aec8725d51 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
import java.util.TimeZone
import scala.language.implicitConversions
@@ -28,7 +29,6 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils}
-import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.internal.SQLConf
@@ -932,7 +932,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Literal(Date.valueOf("1970-02-01")),
Literal(negateExact(stringToInterval("interval 1 month")))),
EmptyRow,
- s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}")
+ s"sequence boundaries: 0 to 2678400000000 by -1 months")
// SPARK-32133: Sequence step must be a day interval if start and end values are dates
checkExceptionInExpression[IllegalArgumentException](Sequence(
@@ -943,6 +943,178 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
}
}
+ test("SPARK-35088: Accept ANSI intervals by the Sequence expression") {
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
+ Literal(Duration.ofHours(12))),
+ Seq(
+ Timestamp.valueOf("2018-01-01 00:00:00"),
+ Timestamp.valueOf("2018-01-01 12:00:00"),
+ Timestamp.valueOf("2018-01-02 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-02 00:00:01")),
+ Literal(Duration.ofHours(12))),
+ Seq(
+ Timestamp.valueOf("2018-01-01 00:00:00"),
+ Timestamp.valueOf("2018-01-01 12:00:00"),
+ Timestamp.valueOf("2018-01-02 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Duration.ofHours(-12))),
+ Seq(
+ Timestamp.valueOf("2018-01-02 00:00:00"),
+ Timestamp.valueOf("2018-01-01 12:00:00"),
+ Timestamp.valueOf("2018-01-01 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
+ Literal(Timestamp.valueOf("2017-12-31 23:59:59")),
+ Literal(Duration.ofHours(-12))),
+ Seq(
+ Timestamp.valueOf("2018-01-02 00:00:00"),
+ Timestamp.valueOf("2018-01-01 12:00:00"),
+ Timestamp.valueOf("2018-01-01 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
+ Literal(Period.ofMonths(1))),
+ Seq(
+ Timestamp.valueOf("2018-01-01 00:00:00"),
+ Timestamp.valueOf("2018-02-01 00:00:00"),
+ Timestamp.valueOf("2018-03-01 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Period.ofMonths(-1))),
+ Seq(
+ Timestamp.valueOf("2018-03-01 00:00:00"),
+ Timestamp.valueOf("2018-02-01 00:00:00"),
+ Timestamp.valueOf("2018-01-01 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-31 00:00:00")),
+ Literal(Timestamp.valueOf("2018-04-30 00:00:00")),
+ Literal(Period.ofMonths(1))),
+ Seq(
+ Timestamp.valueOf("2018-01-31 00:00:00"),
+ Timestamp.valueOf("2018-02-28 00:00:00"),
+ Timestamp.valueOf("2018-03-31 00:00:00"),
+ Timestamp.valueOf("2018-04-30 00:00:00")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Timestamp.valueOf("2023-01-01 00:00:00")),
+ Literal(Period.of(1, 5, 0))),
+ Seq(
+ Timestamp.valueOf("2018-01-01 00:00:00.000"),
+ Timestamp.valueOf("2019-06-01 00:00:00.000"),
+ Timestamp.valueOf("2020-11-01 00:00:00.000"),
+ Timestamp.valueOf("2022-04-01 00:00:00.000")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2022-04-01 00:00:00")),
+ Literal(Timestamp.valueOf("2017-01-01 00:00:00")),
+ Literal(Period.of(-1, -5, 0))),
+ Seq(
+ Timestamp.valueOf("2022-04-01 00:00:00.000"),
+ Timestamp.valueOf("2020-11-01 00:00:00.000"),
+ Timestamp.valueOf("2019-06-01 00:00:00.000"),
+ Timestamp.valueOf("2018-01-01 00:00:00.000")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-04 00:00:00")),
+ Literal(Duration.ofDays(1))),
+ Seq(
+ Timestamp.valueOf("2018-01-01 00:00:00.000"),
+ Timestamp.valueOf("2018-01-02 00:00:00.000"),
+ Timestamp.valueOf("2018-01-03 00:00:00.000"),
+ Timestamp.valueOf("2018-01-04 00:00:00.000")))
+
+ checkEvaluation(new Sequence(
+ Literal(Timestamp.valueOf("2018-01-04 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Duration.ofDays(-1))),
+ Seq(
+ Timestamp.valueOf("2018-01-04 00:00:00.000"),
+ Timestamp.valueOf("2018-01-03 00:00:00.000"),
+ Timestamp.valueOf("2018-01-02 00:00:00.000"),
+ Timestamp.valueOf("2018-01-01 00:00:00.000")))
+
+ checkExceptionInExpression[IllegalArgumentException](
+ new Sequence(
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-04 00:00:00")),
+ Literal(Period.ofDays(1))),
+ EmptyRow, s"sequence boundaries: 1514793600000000 to 1515052800000000 by 0")
+
+ checkExceptionInExpression[IllegalArgumentException](
+ new Sequence(
+ Literal(Timestamp.valueOf("2018-01-04 00:00:00")),
+ Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+ Literal(Period.ofDays(-1))),
+ EmptyRow, s"sequence boundaries: 1515052800000000 to 1514793600000000 by 0")
+
+ DateTimeTestUtils.withDefaultTimeZone(UTC) {
+ checkEvaluation(new Sequence(
+ Literal(Date.valueOf("2018-01-01")),
+ Literal(Date.valueOf("2018-03-01")),
+ Literal(Period.ofMonths(1))),
+ Seq(
+ Date.valueOf("2018-01-01"),
+ Date.valueOf("2018-02-01"),
+ Date.valueOf("2018-03-01")))
+
+ checkEvaluation(new Sequence(
+ Literal(Date.valueOf("2018-01-31")),
+ Literal(Date.valueOf("2018-04-30")),
+ Literal(Period.ofMonths(1))),
+ Seq(
+ Date.valueOf("2018-01-31"),
+ Date.valueOf("2018-02-28"),
+ Date.valueOf("2018-03-31"),
+ Date.valueOf("2018-04-30")))
+
+ checkEvaluation(new Sequence(
+ Literal(Date.valueOf("2018-01-01")),
+ Literal(Date.valueOf("2023-01-01")),
+ Literal(Period.of(1, 5, 0))),
+ Seq(
+ Date.valueOf("2018-01-01"),
+ Date.valueOf("2019-06-01"),
+ Date.valueOf("2020-11-01"),
+ Date.valueOf("2022-04-01")))
+
+ checkExceptionInExpression[IllegalArgumentException](
+ new Sequence(
+ Literal(Date.valueOf("2018-01-01")),
+ Literal(Date.valueOf("2018-01-05")),
+ Literal(Period.ofDays(2))),
+ EmptyRow,
+ "sequence step must be a day year-month interval if start and end values are dates")
+
+ checkExceptionInExpression[IllegalArgumentException](
+ new Sequence(
+ Literal(Date.valueOf("1970-01-01")),
+ Literal(Date.valueOf("1970-02-01")),
+ Literal(Period.ofMonths(-1))),
+ EmptyRow,
+ s"sequence boundaries: 0 to 2678400000000 by -1")
+
+ assert(Sequence(
+ Cast(Literal("2011-03-01"), DateType),
+ Cast(Literal("2011-04-01"), DateType),
+ Option(Literal(Duration.ofHours(1)))).checkInputDataTypes().isFailure)
+ }
+ }
+
test("Sequence with default step") {
// +/- 1 for integral type
checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index 3f3a64ef8d..cf2f5057cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -23,8 +23,8 @@ import java.time.temporal.ChronoUnit
import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, YearMonthIntervalType}
@@ -76,17 +76,17 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("hours") {
- checkEvaluation(ExtractIntervalHours("0 hours"), 0L)
- checkEvaluation(ExtractIntervalHours("1 hour"), 1L)
- checkEvaluation(ExtractIntervalHours("-1 hour"), -1L)
- checkEvaluation(ExtractIntervalHours("23 hours"), 23L)
- checkEvaluation(ExtractIntervalHours("-23 hours"), -23L)
+ checkEvaluation(ExtractIntervalHours("0 hours"), 0.toByte)
+ checkEvaluation(ExtractIntervalHours("1 hour"), 1.toByte)
+ checkEvaluation(ExtractIntervalHours("-1 hour"), -1.toByte)
+ checkEvaluation(ExtractIntervalHours("23 hours"), 23.toByte)
+ checkEvaluation(ExtractIntervalHours("-23 hours"), -23.toByte)
// Years, months and days must not be taken into account
- checkEvaluation(ExtractIntervalHours("100 year 10 months 10 days 10 hours"), 10L)
+ checkEvaluation(ExtractIntervalHours("100 year 10 months 10 days 10 hours"), 10.toByte)
// Minutes should be taken into account
- checkEvaluation(ExtractIntervalHours("10 hours 100 minutes"), 11L)
- checkEvaluation(ExtractIntervalHours(largeInterval), 11L)
- checkEvaluation(ExtractIntervalHours("25 hours"), 1L)
+ checkEvaluation(ExtractIntervalHours("10 hours 100 minutes"), 11.toByte)
+ checkEvaluation(ExtractIntervalHours(largeInterval), 11.toByte)
+ checkEvaluation(ExtractIntervalHours("25 hours"), 1.toByte)
}
@@ -410,4 +410,40 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
DayTimeIntervalType, numType)
}
}
+
+ test("ANSI: extract years and months") {
+ Seq(Period.ZERO,
+ Period.ofMonths(100),
+ Period.ofMonths(-100),
+ Period.ofYears(100),
+ Period.ofYears(-100)).foreach { p =>
+ checkEvaluation(ExtractANSIIntervalYears(Literal(p)),
+ IntervalUtils.getYears(p.toTotalMonths.toInt))
+ checkEvaluation(ExtractANSIIntervalMonths(Literal(p)),
+ IntervalUtils.getMonths(p.toTotalMonths.toInt))
+ }
+ checkEvaluation(ExtractANSIIntervalYears(Literal(null, YearMonthIntervalType)), null)
+ checkEvaluation(ExtractANSIIntervalMonths(Literal(null, YearMonthIntervalType)), null)
+ }
+
+ test("ANSI: extract days, hours, minutes and seconds") {
+ Seq(Duration.ZERO,
+ Duration.ofMillis(1L * MILLIS_PER_DAY + 2 * MILLIS_PER_SECOND),
+ Duration.ofMillis(-1L * MILLIS_PER_DAY + 2 * MILLIS_PER_SECOND),
+ Duration.ofDays(100),
+ Duration.ofDays(-100),
+ Duration.ofHours(-100)).foreach { d =>
+
+ checkEvaluation(ExtractANSIIntervalDays(Literal(d)), d.toDays.toInt)
+ checkEvaluation(ExtractANSIIntervalHours(Literal(d)), (d.toHours % HOURS_PER_DAY).toByte)
+ checkEvaluation(ExtractANSIIntervalMinutes(Literal(d)),
+ (d.toMinutes % MINUTES_PER_HOUR).toByte)
+ checkEvaluation(ExtractANSIIntervalSeconds(Literal(d)),
+ IntervalUtils.getSeconds(IntervalUtils.durationToMicros(d)))
+ }
+ checkEvaluation(ExtractANSIIntervalDays(Literal(null, DayTimeIntervalType)), null)
+ checkEvaluation(ExtractANSIIntervalHours(Literal(null, DayTimeIntervalType)), null)
+ checkEvaluation(ExtractANSIIntervalMinutes(Literal(null, DayTimeIntervalType)), null)
+ checkEvaluation(ExtractANSIIntervalSeconds(Literal(null, DayTimeIntervalType)), null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index 926628aca9..3d11ff97f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -50,8 +51,10 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
testBothCodegenAndInterpreted("unsafe buffer") {
val inputRow = InternalRow.fromSeq(Seq(
false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L, Int.MinValue, Long.MaxValue))
- val numBytes = UnsafeRow.calculateBitSetWidthInBytes(fixedLengthTypes.length)
- val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length)
+ val numFields = fixedLengthTypes.length
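+ // allocate room for the byte-array offset, the null-tracking bitset and one 8-byte word per field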
+ val numBytes = Platform.BYTE_ARRAY_OFFSET + UnsafeRow.calculateBitSetWidthInBytes(numFields) +
+ UnsafeRow.WORD_SIZE * numFields
+ val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, numFields)
val proj = createMutableProjection(fixedLengthTypes)
val projUnsafeRow = proj.target(unsafeBuffer)(inputRow)
assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index 7895f4d5ef..2fab553183 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -18,39 +18,36 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
import org.apache.spark.sql.types._
class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
-
- def getRootFields(requestedFields: StructField*): Seq[RootField] = {
- requestedFields.map { f =>
+ private def testPrunedSchema(
+ schema: StructType,
+ requestedFields: Seq[StructField],
+ expectedSchema: StructType): Unit = {
+ val requestedRootFields = requestedFields.map { f =>
// `derivedFromAtt` doesn't affect the result of pruned schema.
SchemaPruning.RootField(field = f, derivedFromAtt = true)
}
+ val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
+ assert(prunedSchema === expectedSchema)
}
test("prune schema by the requested fields") {
- def testPrunedSchema(
- schema: StructType,
- requestedFields: StructField*): Unit = {
- val requestedRootFields = requestedFields.map { f =>
- // `derivedFromAtt` doesn't affect the result of pruned schema.
- SchemaPruning.RootField(field = f, derivedFromAtt = true)
- }
- val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
- assert(expectedSchema == StructType(requestedFields))
- }
-
- testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType))
- testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType))
+ testPrunedSchema(
+ StructType.fromDDL("a int, b int"),
+ Seq(StructField("a", IntegerType)),
+ StructType.fromDDL("a int, b int"))
val structOfStruct = StructType.fromDDL("a struct<a:int, b:int>, b int")
- testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("a int, b int")))
- testPrunedSchema(structOfStruct, StructField("b", IntegerType))
- testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("b int")))
+ testPrunedSchema(structOfStruct,
+ Seq(StructField("a", StructType.fromDDL("a int")), StructField("b", IntegerType)),
+ StructType.fromDDL("a struct<a:int>, b int"))
+ testPrunedSchema(structOfStruct,
+ Seq(StructField("a", StructType.fromDDL("a int"))),
+ StructType.fromDDL("a struct<a:int>, b int"))
val arrayOfStruct = StructField("a", ArrayType(StructType.fromDDL("a int, b int, c string")))
val mapOfStruct = StructField("d", MapType(StructType.fromDDL("a int, b int, c string"),
@@ -60,44 +57,76 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
arrayOfStruct :: StructField("b", structOfStruct) :: StructField("c", IntegerType) ::
mapOfStruct :: Nil)
- testPrunedSchema(complexStruct, StructField("a", ArrayType(StructType.fromDDL("b int"))),
- StructField("b", StructType.fromDDL("a int")))
testPrunedSchema(complexStruct,
- StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
- StructField("b", StructType.fromDDL("b int")))
+ Seq(StructField("a", ArrayType(StructType.fromDDL("b int"))),
+ StructField("b", StructType.fromDDL("a int"))),
+ StructType(
+ StructField("a", ArrayType(StructType.fromDDL("b int"))) ::
+ StructField("b", StructType.fromDDL("a int")) ::
+ StructField("c", IntegerType) ::
+ mapOfStruct :: Nil))
+ testPrunedSchema(complexStruct,
+ Seq(StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
+ StructField("b", StructType.fromDDL("b int"))),
+ StructType(
+ StructField("a", ArrayType(StructType.fromDDL("b int, c string"))) ::
+ StructField("b", StructType.fromDDL("b int")) ::
+ StructField("c", IntegerType) ::
+ mapOfStruct :: Nil))
val selectFieldInMap = StructField("d", MapType(StructType.fromDDL("a int, b int"),
StructType.fromDDL("e int, f string")))
- testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
+ testPrunedSchema(complexStruct,
+ Seq(StructField("c", IntegerType), selectFieldInMap),
+ StructType(
+ arrayOfStruct ::
+ StructField("b", structOfStruct) ::
+ StructField("c", IntegerType) ::
+ selectFieldInMap :: Nil))
}
test("SPARK-35096: test case insensitivity of pruned schema") {
- Seq(true, false).foreach(isCaseSensitive => {
+ val upperCaseSchema = StructType.fromDDL("A struct<A:int, B:int>, B int")
+ val lowerCaseSchema = StructType.fromDDL("a struct<a:int, b:int>, b int")
+ val upperCaseRequestedFields = Seq(StructField("A", StructType.fromDDL("A int")))
+ val lowerCaseRequestedFields = Seq(StructField("a", StructType.fromDDL("a int")))
+
+ Seq(true, false).foreach { isCaseSensitive =>
withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
if (isCaseSensitive) {
- // Schema is case-sensitive
- val requestedFields = getRootFields(StructField("id", IntegerType))
- val prunedSchema = SchemaPruning.pruneDataSchema(
- StructType.fromDDL("ID int, name String"), requestedFields)
- assert(prunedSchema == StructType(Seq.empty))
- // Root fields are case-sensitive
- val rootFieldsSchema = SchemaPruning.pruneDataSchema(
- StructType.fromDDL("id int, name String"),
- getRootFields(StructField("ID", IntegerType)))
- assert(rootFieldsSchema == StructType(StructType(Seq.empty)))
+ testPrunedSchema(
+ upperCaseSchema,
+ upperCaseRequestedFields,
+ StructType.fromDDL("A struct, B int"))
+ testPrunedSchema(
+ upperCaseSchema,
+ lowerCaseRequestedFields,
+ upperCaseSchema)
+
+ testPrunedSchema(
+ lowerCaseSchema,
+ upperCaseRequestedFields,
+ lowerCaseSchema)
+ testPrunedSchema(
+ lowerCaseSchema,
+ lowerCaseRequestedFields,
+ StructType.fromDDL("a struct, b int"))
} else {
- // Schema is case-insensitive
- val prunedSchema = SchemaPruning.pruneDataSchema(
- StructType.fromDDL("ID int, name String"),
- getRootFields(StructField("id", IntegerType)))
- assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
- // Root fields are case-insensitive
- val rootFieldsSchema = SchemaPruning.pruneDataSchema(
- StructType.fromDDL("id int, name String"),
- getRootFields(StructField("ID", IntegerType)))
- assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
+ Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields =>
+ testPrunedSchema(
+ upperCaseSchema,
+ requestedFields,
+ StructType.fromDDL("A struct, B int"))
+ }
+
+ Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields =>
+ testPrunedSchema(
+ lowerCaseSchema,
+ requestedFields,
+ StructType.fromDDL("a struct, b int"))
+ }
}
}
- })
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
index 441c15340a..997ccb7204 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
@@ -31,8 +31,10 @@ class CombineConcatsSuite extends PlanTest {
}
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
- val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
- val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
+ val correctAnswer = Limit(Literal(1), Project(Alias(e2, "out")() :: Nil, OneRowRelation()))
+ .analyze
+ val actual = Optimize.execute(Limit(Literal(1), Project(Alias(e1, "out")() :: Nil,
+ OneRowRelation())).analyze)
comparePlans(actual, correctAnswer)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index 0ae4d3f6e6..a856caa678 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
comparePlans(optimized, expected)
}
- test("Nested field pruning for Project and Generate: not prune on generator output") {
+ test("Nested field pruning for Project and Generate: multiple-field case is not supported") {
val companies = LocalRelation(
'id.int,
'employers.array(employer))
val query = companies
.generate(Explode('employers.getField("company")), outputNames = Seq("company"))
- .select('company.getField("name"))
+ .select('company.getField("name"), 'company.getField("address"))
.analyze
val optimized = Optimize.execute(query)
@@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
.generate(Explode($"${aliases(0)}"),
unrequiredChildIndex = Seq(0),
outputNames = Seq("company"))
- .select('company.getField("name").as("company.name"))
+ .select('company.getField("name").as("company.name"),
+ 'company.getField("address").as("company.address"))
.analyze
comparePlans(optimized, expected)
}
@@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
).analyze
comparePlans(optimized2, expected2)
}
+
+ test("SPARK-34638: nested column prune on generator output for one field") {
+ val companies = LocalRelation(
+ 'id.int,
+ 'employers.array(employer))
+
+ val query = companies
+ .generate(Explode('employers.getField("company")), outputNames = Seq("company"))
+ .select('company.getField("name"))
+ .analyze
+ val optimized = Optimize.execute(query)
+
+ val aliases = collectGeneratedAliases(optimized)
+
+ val expected = companies
+ .select('employers.getField("company").getField("name").as(aliases(0)))
+ .generate(Explode($"${aliases(0)}"),
+ unrequiredChildIndex = Seq(0),
+ outputNames = Seq("company"))
+ .select('company.as("company.name"))
+ .analyze
+ comparePlans(optimized, expected)
+ }
}
object NestedColumnAliasingSuite {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
index b093b39cc4..e63742ac0d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
@@ -126,4 +126,25 @@ class OptimizeWithFieldsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
}
+
+ test("SPARK-35213: ensure optimize WithFields maintains correct WithField ordering") {
+ val originalQuery = testRelation
+ .select(
+ Alias(UpdateFields('a,
+ WithField("a1", Literal(3)) ::
+ WithField("b1", Literal(4)) ::
+ WithField("a1", Literal(5)) ::
+ Nil), "out")())
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(
+ Alias(UpdateFields('a,
+ WithField("a1", Literal(5)) ::
+ WithField("b1", Literal(4)) ::
+ Nil), "out")())
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
index 3eba003d77..d376c31ef9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
@@ -96,7 +96,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a + 'b)(('a + 'b) as 'c)
.analyze
val optimized = Optimize.execute(query)
- comparePlans(optimized, EnforceGroupingReferencesInAggregates(expected))
+ comparePlans(optimized, expected)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index d149967094..dcd2fbbf00 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -36,8 +36,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
- Batch("Finish Analysis", Once,
- EnforceGroupingReferencesInAggregates) ::
Batch("collapse projections", FixedPoint(10),
CollapseProject) ::
Batch("Constant Folding", FixedPoint(10),
@@ -59,7 +57,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
val optimized = Optimizer.execute(originalQuery.analyze)
assert(optimized.resolved, "optimized plans must be still resolvable")
- comparePlans(optimized, EnforceGroupingReferencesInAggregates(correctAnswer.analyze))
+ comparePlans(optimized, correctAnswer.analyze)
}
test("explicit get from namedStruct") {
@@ -407,6 +405,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
val arrayAggRel = relation.groupBy(
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
checkRule(arrayAggRel, arrayAggRel)
+
+ // This could be done if we had a more complex rule that checks that
+ // the CreateMap does not come from the grouping key.
+ val originalQuery = relation
+ .groupBy('id)(
+ GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
+ )
+ checkRule(originalQuery, originalQuery)
}
test("SPARK-23500: namedStruct and getField in the same Project #1") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 5c460f70a9..87d306a495 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -169,6 +169,19 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
fromYearMonthString)
failFuncWithInvalidInput("-\t99-15", "Interval string does not match year-month format",
fromYearMonthString)
+
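+ // 178956970 years 7 months is exactly Int.MaxValue months and -178956970 years 8 months
+ // is exactly Int.MinValue months; one more month in either direction must overflow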
+ assert(fromYearMonthString("178956970-6") == new CalendarInterval(Int.MaxValue - 1, 0, 0))
+ assert(fromYearMonthString("178956970-7") == new CalendarInterval(Int.MaxValue, 0, 0))
+
+ val e1 = intercept[IllegalArgumentException]{
+ assert(fromYearMonthString("178956970-8") == new CalendarInterval(Int.MinValue, 0, 0))
+ }.getMessage
+ assert(e1.contains("integer overflow"))
+ assert(fromYearMonthString("-178956970-8") == new CalendarInterval(Int.MinValue, 0, 0))
+ val e2 = intercept[IllegalArgumentException]{
+ assert(fromYearMonthString("-178956970-9") == new CalendarInterval(Int.MinValue, 0, 0))
+ }.getMessage
+ assert(e2.contains("integer overflow"))
}
test("from day-time string - legacy") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
index aec361b979..eb35dd47a5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.connector.InMemoryTableCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala
new file mode 100644
index 0000000000..a48eb04a98
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException, PartitionsAlreadyExistException}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This class is used to test the SupportsAtomicPartitionManagement API.
+ */
+class InMemoryAtomicPartitionTable (
+ name: String,
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+ extends InMemoryPartitionTable(name, schema, partitioning, properties)
+ with SupportsAtomicPartitionManagement {
+
+ override def createPartition(
+ ident: InternalRow,
+ properties: util.Map[String, String]): Unit = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ throw new PartitionAlreadyExistsException(name, ident, partitionSchema)
+ } else {
+ createPartitionKey(ident.toSeq(schema))
+ memoryTablePartitions.put(ident, properties)
+ }
+ }
+
+ override def dropPartition(ident: InternalRow): Boolean = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ memoryTablePartitions.remove(ident)
+ removePartitionKey(ident.toSeq(schema))
+ true
+ } else {
+ false
+ }
+ }
+
+ override def createPartitions(
+ idents: Array[InternalRow],
+ properties: Array[util.Map[String, String]]): Unit = {
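+ // atomic semantics: fail the whole batch if any of the requested partitions already exists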
+ if (idents.exists(partitionExists)) {
+ throw new PartitionsAlreadyExistException(
+ name, idents.filter(partitionExists), partitionSchema)
+ }
+ idents.zip(properties).foreach { case (ident, property) =>
+ createPartition(ident, property)
+ }
+ }
+
+ override def dropPartitions(idents: Array[InternalRow]): Boolean = {
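+ // atomic semantics: drop nothing unless every requested partition exists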
+ if (!idents.forall(partitionExists)) {
+ return false
+ }
+ idents.forall(dropPartition)
+ }
+
+ override def truncatePartitions(idents: Array[InternalRow]): Boolean = {
+ val nonExistent = idents.filterNot(partitionExists)
+ if (nonExistent.isEmpty) {
+ idents.foreach(truncatePartition)
+ true
+ } else {
+ throw new NoSuchPartitionException(name, nonExistent.head, partitionSchema)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala
new file mode 100644
index 0000000000..58dc484711
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This class is used to test the SupportsPartitionManagement API.
+ */
+class InMemoryPartitionTable(
+ name: String,
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+ extends InMemoryTable(name, schema, partitioning, properties) with SupportsPartitionManagement {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ protected val memoryTablePartitions: util.Map[InternalRow, util.Map[String, String]] =
+ new ConcurrentHashMap[InternalRow, util.Map[String, String]]()
+
+ def partitionSchema: StructType = {
+ val partitionColumnNames = partitioning.toSeq.asPartitionColumns
+ new StructType(schema.filter(p => partitionColumnNames.contains(p.name)).toArray)
+ }
+
+ def createPartition(
+ ident: InternalRow,
+ properties: util.Map[String, String]): Unit = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ throw new PartitionAlreadyExistsException(name, ident, partitionSchema)
+ } else {
+ createPartitionKey(ident.toSeq(schema))
+ memoryTablePartitions.put(ident, properties)
+ }
+ }
+
+ def dropPartition(ident: InternalRow): Boolean = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ memoryTablePartitions.remove(ident)
+ removePartitionKey(ident.toSeq(schema))
+ true
+ } else {
+ false
+ }
+ }
+
+ def replacePartitionMetadata(ident: InternalRow, properties: util.Map[String, String]): Unit = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ memoryTablePartitions.put(ident, properties)
+ } else {
+ throw new NoSuchPartitionException(name, ident, partitionSchema)
+ }
+ }
+
+ def loadPartitionMetadata(ident: InternalRow): util.Map[String, String] = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ memoryTablePartitions.get(ident)
+ } else {
+ throw new NoSuchPartitionException(name, ident, partitionSchema)
+ }
+ }
+
+ override protected def addPartitionKey(key: Seq[Any]): Unit = {
+ memoryTablePartitions.putIfAbsent(InternalRow.fromSeq(key), Map.empty[String, String].asJava)
+ }
+
+ override def listPartitionIdentifiers(
+ names: Array[String],
+ ident: InternalRow): Array[InternalRow] = {
+ assert(names.length == ident.numFields,
+ s"Number of partition names (${names.length}) must be equal to " +
+ s"the number of partition values (${ident.numFields}).")
+ val schema = partitionSchema
+ assert(names.forall(fieldName => schema.fieldNames.contains(fieldName)),
+ s"Some partition names ${names.mkString("[", ", ", "]")} don't belong to " +
+ s"the partition schema '${schema.sql}'.")
+ val indexes = names.map(schema.fieldIndex)
+ val dataTypes = names.map(schema(_).dataType)
+ val currentRow = new GenericInternalRow(new Array[Any](names.length))
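+ // keep only the partition keys whose values at the given field positions match the requested ident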
+ memoryTablePartitions.keySet().asScala.filter { key =>
+ for (i <- 0 until names.length) {
+ currentRow.values(i) = key.get(indexes(i), dataTypes(i))
+ }
+ currentRow == ident
+ }.toArray
+ }
+
+ override def renamePartition(from: InternalRow, to: InternalRow): Boolean = {
+ if (memoryTablePartitions.containsKey(to)) {
+ throw new PartitionAlreadyExistsException(name, to, partitionSchema)
+ } else {
+ val partValue = memoryTablePartitions.remove(from)
+ if (partValue == null) {
+ throw new NoSuchPartitionException(name, from, partitionSchema)
+ }
+ memoryTablePartitions.put(to, partValue) == null &&
+ renamePartitionKey(partitionSchema, from.toSeq(schema), to.toSeq(schema))
+ }
+ }
+
+ override def truncatePartition(ident: InternalRow): Boolean = {
+ if (memoryTablePartitions.containsKey(ident)) {
+ clearPartition(ident.toSeq(schema))
+ true
+ } else {
+ throw new NoSuchPartitionException(name, ident, partitionSchema)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala
new file mode 100644
index 0000000000..a24f5c9a0c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+
+class InMemoryPartitionTableCatalog extends InMemoryTableCatalog {
+ import CatalogV2Implicits._
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val table = new InMemoryAtomicPartitionTable(
+ s"$name.${ident.quoted}", schema, partitions, properties)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
new file mode 100644
index 0000000000..b9069ff311
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -0,0 +1,535 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.time.{Instant, ZoneId}
+import java.time.temporal.ChronoUnit
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.scalatest.Assertions._
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
+ */
+class InMemoryTable(
+ val name: String,
+ val schema: StructType,
+ override val partitioning: Array[Transform],
+ override val properties: util.Map[String, String],
+ val distribution: Distribution = Distributions.unspecified(),
+ val ordering: Array[SortOrder] = Array.empty,
+ val numPartitions: Option[Int] = None)
+ extends Table with SupportsRead with SupportsWrite with SupportsDelete
+ with SupportsMetadataColumns {
+
+ private object PartitionKeyColumn extends MetadataColumn {
+ override def name: String = "_partition"
+ override def dataType: DataType = StringType
+ override def comment: String = "Partition key used to store the row"
+ }
+
+ private object IndexColumn extends MetadataColumn {
+ override def name: String = "index"
+ override def dataType: DataType = IntegerType
+ override def comment: String = "Metadata column used to conflict with a data column"
+ }
+
+ // purposely exposes a metadata column that conflicts with a data column in some tests
+ override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
+ private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name)
+
+ private val allowUnsupportedTransforms =
+ properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
+
+ partitioning.foreach {
+ case _: IdentityTransform =>
+ case _: YearsTransform =>
+ case _: MonthsTransform =>
+ case _: DaysTransform =>
+ case _: HoursTransform =>
+ case _: BucketTransform =>
+ case t if !allowUnsupportedTransforms =>
+ throw new IllegalArgumentException(s"Transform $t is not a supported transform")
+ }
+
+  // The `Seq[Any]` key holds the partition values.
+ val dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty
+
+ def data: Array[BufferedRows] = dataMap.values.toArray
+
+ def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq
+
+ private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref =>
+ schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
+ case Some(_) => ref.fieldNames()
+ case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
+ }
+ }
+
+ private val UTC = ZoneId.of("UTC")
+ private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate
+
+ private def getKey(row: InternalRow): Seq[Any] = {
+ def extractor(
+ fieldNames: Array[String],
+ schema: StructType,
+ row: InternalRow): (Any, DataType) = {
+ val index = schema.fieldIndex(fieldNames(0))
+ val value = row.toSeq(schema).apply(index)
+ if (fieldNames.length > 1) {
+ (value, schema(index).dataType) match {
+ case (row: InternalRow, nestedSchema: StructType) =>
+ extractor(fieldNames.drop(1), nestedSchema, row)
+ case (_, dataType) =>
+ throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
+ }
+ } else {
+ (value, schema(index).dataType)
+ }
+ }
+
+ val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
+ partitioning.map {
+ case IdentityTransform(ref) =>
+ extractor(ref.fieldNames, cleanedSchema, row)._1
+ case YearsTransform(ref) =>
+ extractor(ref.fieldNames, cleanedSchema, row) match {
+ case (days: Int, DateType) =>
+ ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+ case (micros: Long, TimestampType) =>
+ val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+ ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+ case (v, t) =>
+ throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+ }
+ case MonthsTransform(ref) =>
+ extractor(ref.fieldNames, cleanedSchema, row) match {
+ case (days: Int, DateType) =>
+ ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+ case (micros: Long, TimestampType) =>
+ val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+ ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
+ case (v, t) =>
+ throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+ }
+ case DaysTransform(ref) =>
+ extractor(ref.fieldNames, cleanedSchema, row) match {
+ case (days, DateType) =>
+ days
+ case (micros: Long, TimestampType) =>
+ ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+ case (v, t) =>
+ throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+ }
+ case HoursTransform(ref) =>
+ extractor(ref.fieldNames, cleanedSchema, row) match {
+ case (micros: Long, TimestampType) =>
+ ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+ case (v, t) =>
+ throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+ }
+ case BucketTransform(numBuckets, ref) =>
+ val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row)
+ val valueHashCode = if (value == null) 0 else value.hashCode
+ ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
+ }
+ }
+
+ protected def addPartitionKey(key: Seq[Any]): Unit = {}
+
+ protected def renamePartitionKey(
+ partitionSchema: StructType,
+ from: Seq[Any],
+ to: Seq[Any]): Boolean = {
+ val rows = dataMap.remove(from).getOrElse(new BufferedRows(from.mkString("/")))
+ val newRows = new BufferedRows(to.mkString("/"))
+ rows.rows.foreach { r =>
+ val newRow = new GenericInternalRow(r.numFields)
+ for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
+ for (i <- 0 until partitionSchema.length) {
+ val j = schema.fieldIndex(partitionSchema(i).name)
+ newRow.update(j, to(i))
+ }
+ newRows.withRow(newRow)
+ }
+ dataMap.put(to, newRows).foreach { _ =>
+ throw new IllegalStateException(
+ s"The ${to.mkString("[", ", ", "]")} partition exists already")
+ }
+ true
+ }
+
+ protected def removePartitionKey(key: Seq[Any]): Unit = dataMap.synchronized {
+ dataMap.remove(key)
+ }
+
+ protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized {
+ if (!dataMap.contains(key)) {
+ val emptyRows = new BufferedRows(key.toArray.mkString("/"))
+ val rows = if (key.length == schema.length) {
+ emptyRows.withRow(InternalRow.fromSeq(key))
+ } else emptyRows
+ dataMap.put(key, rows)
+ }
+ }
+
+ protected def clearPartition(key: Seq[Any]): Unit = dataMap.synchronized {
+ assert(dataMap.contains(key))
+ dataMap(key).clear()
+ }
+
+ def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
+ data.foreach(_.rows.foreach { row =>
+ val key = getKey(row)
+ dataMap += dataMap.get(key)
+ .map(key -> _.withRow(row))
+ .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row))
+ addPartitionKey(key)
+ })
+ this
+ }
+
+ override def capabilities: util.Set[TableCapability] = Set(
+ TableCapability.BATCH_READ,
+ TableCapability.BATCH_WRITE,
+ TableCapability.STREAMING_WRITE,
+ TableCapability.OVERWRITE_BY_FILTER,
+ TableCapability.OVERWRITE_DYNAMIC,
+ TableCapability.TRUNCATE).asJava
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryScanBuilder(schema)
+ }
+
+ class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
+ with SupportsPushDownRequiredColumns {
+ private var schema: StructType = tableSchema
+
+ override def build: Scan =
+ new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema)
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ // if metadata columns are projected, return the table schema and metadata columns
+ val hasMetadataColumns = requiredSchema.map(_.name).exists(metadataColumnNames.contains)
+ if (hasMetadataColumns) {
+ schema = StructType(tableSchema ++ metadataColumnNames
+ .flatMap(name => metadataColumns.find(_.name == name))
+ .map(col => StructField(col.name, col.dataType, col.isNullable)))
+ }
+ }
+ }
+
+ class InMemoryBatchScan(data: Array[InputPartition], schema: StructType) extends Scan with Batch {
+ override def readSchema(): StructType = schema
+
+ override def toBatch: Batch = this
+
+ override def planInputPartitions(): Array[InputPartition] = data
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ val metadataColumns = schema.map(_.name).filter(metadataColumnNames.contains)
+ new BufferedRowsReaderFactory(metadataColumns)
+ }
+ }
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ InMemoryTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+ InMemoryTable.maybeSimulateFailedTableWrite(info.options)
+
+ new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
+ private var writer: BatchWrite = Append
+ private var streamingWriter: StreamingWrite = StreamingAppend
+
+ override def truncate(): WriteBuilder = {
+ assert(writer == Append)
+ writer = TruncateAndAppend
+ streamingWriter = StreamingTruncateAndAppend
+ this
+ }
+
+ override def overwrite(filters: Array[Filter]): WriteBuilder = {
+ assert(writer == Append)
+ writer = new Overwrite(filters)
+ streamingWriter = new StreamingNotSupportedOperation(s"overwrite ($filters)")
+ this
+ }
+
+ override def overwriteDynamicPartitions(): WriteBuilder = {
+ assert(writer == Append)
+ writer = DynamicOverwrite
+ streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
+ this
+ }
+
+ override def build(): Write = new Write with RequiresDistributionAndOrdering {
+ override def requiredDistribution: Distribution = distribution
+
+ override def requiredOrdering: Array[SortOrder] = ordering
+
+ override def requiredNumPartitions(): Int = {
+ numPartitions.getOrElse(0)
+ }
+
+ override def toBatch: BatchWrite = writer
+
+ override def toStreaming: StreamingWrite = streamingWriter match {
+ case exc: StreamingNotSupportedOperation => exc.throwsException()
+ case s => s
+ }
+ }
+ }
+ }
+
+ private abstract class TestBatchWrite extends BatchWrite {
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
+ BufferedRowsWriterFactory
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+ }
+
+ private object Append extends TestBatchWrite {
+ override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
+
+ private object DynamicOverwrite extends TestBatchWrite {
+ override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ val newData = messages.map(_.asInstanceOf[BufferedRows])
+ dataMap --= newData.flatMap(_.rows.map(getKey))
+ withData(newData)
+ }
+ }
+
+ private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ val deleteKeys = InMemoryTable.filtersToKeys(
+ dataMap.keys, partCols.map(_.toSeq.quoted), filters)
+ dataMap --= deleteKeys
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
+
+ private object TruncateAndAppend extends TestBatchWrite {
+ override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+      dataMap.clear()
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
+
+ private abstract class TestStreamingWrite extends StreamingWrite {
+ def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
+ BufferedRowsWriterFactory
+ }
+
+ def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+ }
+
+ private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
+ override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
+ throwsException()
+
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
+ throwsException()
+
+ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
+ throwsException()
+
+ def throwsException[T](): T = throw new IllegalStateException("The operation " +
+      s"${operation} isn't supported for streaming queries.")
+ }
+
+ private object StreamingAppend extends TestStreamingWrite {
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ dataMap.synchronized {
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
+ }
+
+ private object StreamingTruncateAndAppend extends TestStreamingWrite {
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ dataMap.synchronized {
+        dataMap.clear()
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
+ }
+
+ override def canDeleteWhere(filters: Array[Filter]): Boolean = {
+ InMemoryTable.supportsFilters(filters)
+ }
+
+ override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
+ }
+}
+
+object InMemoryTable {
+ val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite"
+
+ def filtersToKeys(
+ keys: Iterable[Seq[Any]],
+ partitionNames: Seq[String],
+ filters: Array[Filter]): Iterable[Seq[Any]] = {
+ keys.filter { partValues =>
+ filters.flatMap(splitAnd).forall {
+ case EqualTo(attr, value) =>
+ value == extractValue(attr, partitionNames, partValues)
+ case EqualNullSafe(attr, value) =>
+ val attrVal = extractValue(attr, partitionNames, partValues)
+ if (attrVal == null && value === null) {
+ true
+ } else if (attrVal == null || value === null) {
+ false
+ } else {
+ value == attrVal
+ }
+ case IsNull(attr) =>
+ null == extractValue(attr, partitionNames, partValues)
+ case IsNotNull(attr) =>
+ null != extractValue(attr, partitionNames, partValues)
+ case AlwaysTrue() => true
+ case f =>
+ throw new IllegalArgumentException(s"Unsupported filter type: $f")
+ }
+ }
+ }
+
+ def supportsFilters(filters: Array[Filter]): Boolean = {
+ filters.flatMap(splitAnd).forall {
+ case _: EqualTo => true
+ case _: EqualNullSafe => true
+ case _: IsNull => true
+ case _: IsNotNull => true
+ case _: AlwaysTrue => true
+ case _ => false
+ }
+ }
+
+ private def extractValue(
+ attr: String,
+ partFieldNames: Seq[String],
+ partValues: Seq[Any]): Any = {
+ partFieldNames.zipWithIndex.find(_._1 == attr) match {
+ case Some((_, partIndex)) =>
+ partValues(partIndex)
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
+ }
+ }
+
+ private def splitAnd(filter: Filter): Seq[Filter] = {
+ filter match {
+ case And(left, right) => splitAnd(left) ++ splitAnd(right)
+ case _ => filter :: Nil
+ }
+ }
+
+ def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = {
+ if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) {
+ throw new IllegalStateException("Manual write to table failure.")
+ }
+ }
+}
+
+class BufferedRows(
+ val key: String = "") extends WriterCommitMessage with InputPartition with Serializable {
+ val rows = new mutable.ArrayBuffer[InternalRow]()
+
+ def withRow(row: InternalRow): BufferedRows = {
+ rows.append(row)
+ this
+ }
+
+ def clear(): Unit = rows.clear()
+}
+
+private class BufferedRowsReaderFactory(
+ metadataColumns: Seq[String]) extends PartitionReaderFactory {
+ override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+ new BufferedRowsReader(partition.asInstanceOf[BufferedRows], metadataColumns)
+ }
+}
+
+private class BufferedRowsReader(
+ partition: BufferedRows,
+ metadataColumns: Seq[String]) extends PartitionReader[InternalRow] {
+ private def addMetadata(row: InternalRow): InternalRow = {
+ val metadataRow = new GenericInternalRow(metadataColumns.map {
+ case "index" => index
+ case "_partition" => UTF8String.fromString(partition.key)
+ }.toArray)
+ new JoinedRow(row, metadataRow)
+ }
+
+ private var index: Int = -1
+
+ override def next(): Boolean = {
+ index += 1
+ index < partition.rows.length
+ }
+
+ override def get(): InternalRow = addMetadata(partition.rows(index))
+
+ override def close(): Unit = {}
+}
+
+private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new BufferWriter
+ }
+
+ override def createWriter(
+ partitionId: Int,
+ taskId: Long,
+ epochId: Long): DataWriter[InternalRow] = {
+ new BufferWriter
+ }
+}
+
+private class BufferWriter extends DataWriter[InternalRow] {
+ private val buffer = new BufferedRows
+
+ override def write(row: InternalRow): Unit = buffer.rows.append(row.copy())
+
+ override def commit(): WriterCommitMessage = buffer
+
+ override def abort(): Unit = {}
+
+ override def close(): Unit = {}
+}
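Reviewer note (illustration only, not part of the patch): a rough sketch of how the pieces above fit together — rows buffered per task are grouped by partition key in `withData`, and partition-level filters are resolved through `InMemoryTable.filtersToKeys`. The table name and column names below are made up:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.connector.expressions.Expressions
  import org.apache.spark.sql.sources.{EqualTo, Filter}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
  import org.apache.spark.unsafe.types.UTF8String

  val schema = new StructType().add("id", IntegerType).add("data", StringType)
  val table = new InMemoryTable("cat.ns.t", schema,
    Array(Expressions.identity("id")), new java.util.HashMap[String, String]())

  // Each BufferedRows plays the role of one task's commit message.
  table.withData(Array(
    new BufferedRows().withRow(InternalRow(1, UTF8String.fromString("a"))),
    new BufferedRows().withRow(InternalRow(2, UTF8String.fromString("b")))))
  assert(table.dataMap.size == 2)  // one bucket per identity-partition key

  // deleteWhere relies on filtersToKeys to find the partition keys matching the filters.
  val filters: Array[Filter] = Array(EqualTo("id", 1))
  assert(InMemoryTable.filtersToKeys(table.dataMap.keys, Seq("id"), filters).toSeq == Seq(Seq(1)))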
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
new file mode 100644
index 0000000000..38113f9ea1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class BasicInMemoryTableCatalog extends TableCatalog {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ protected val namespaces: util.Map[List[String], Map[String, String]] =
+ new ConcurrentHashMap[List[String], Map[String, String]]()
+
+ protected val tables: util.Map[Identifier, Table] =
+ new ConcurrentHashMap[Identifier, Table]()
+
+ private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet()
+
+ private var _name: Option[String] = None
+
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
+ _name = Some(name)
+ }
+
+ override def name: String = _name.get
+
+ override def listTables(namespace: Array[String]): Array[Identifier] = {
+ tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
+ }
+
+ override def loadTable(ident: Identifier): Table = {
+ Option(tables.get(ident)) match {
+ case Some(table) =>
+ table
+ case _ =>
+ throw new NoSuchTableException(ident)
+ }
+ }
+
+ override def invalidateTable(ident: Identifier): Unit = {
+ invalidatedTables.add(ident)
+ }
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ createTable(ident, schema, partitions, properties, Distributions.unspecified(),
+ Array.empty, None)
+ }
+
+ def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String],
+ distribution: Distribution,
+ ordering: Array[SortOrder],
+ requiredNumPartitions: Option[Int]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val tableName = s"$name.${ident.quoted}"
+ val table = new InMemoryTable(tableName, schema, partitions, properties, distribution,
+ ordering, requiredNumPartitions)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+
+ override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+ val table = loadTable(ident).asInstanceOf[InMemoryTable]
+ val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
+ val schema = CatalogV2Util.applySchemaChanges(table.schema, changes)
+
+ // fail if the last column in the schema was dropped
+ if (schema.fields.isEmpty) {
+ throw new IllegalArgumentException(s"Cannot drop all fields")
+ }
+
+ val newTable = new InMemoryTable(table.name, schema, table.partitioning, properties)
+ .withData(table.data)
+
+ tables.put(ident, newTable)
+
+ newTable
+ }
+
+ override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined
+
+ override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
+ if (tables.containsKey(newIdent)) {
+ throw new TableAlreadyExistsException(newIdent)
+ }
+
+ Option(tables.remove(oldIdent)) match {
+ case Some(table) =>
+ tables.put(newIdent, table)
+ case _ =>
+ throw new NoSuchTableException(oldIdent)
+ }
+ }
+
+ def isTableInvalidated(ident: Identifier): Boolean = {
+ invalidatedTables.contains(ident)
+ }
+
+ def clearTables(): Unit = {
+ tables.clear()
+ }
+}
+
+class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces {
+ private def allNamespaces: Seq[Seq[String]] = {
+ (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct
+ }
+
+ override def namespaceExists(namespace: Array[String]): Boolean = {
+ allNamespaces.exists(_.startsWith(namespace))
+ }
+
+ override def listNamespaces: Array[Array[String]] = {
+ allNamespaces.map(_.head).distinct.map(Array(_)).toArray
+ }
+
+ override def listNamespaces(namespace: Array[String]): Array[Array[String]] = {
+ allNamespaces
+ .filter(_.size > namespace.length)
+ .filter(_.startsWith(namespace))
+ .map(_.take(namespace.length + 1))
+ .distinct
+ .map(_.toArray)
+ .toArray
+ }
+
+ override def loadNamespaceMetadata(namespace: Array[String]): util.Map[String, String] = {
+ Option(namespaces.get(namespace.toSeq)) match {
+ case Some(metadata) =>
+ metadata.asJava
+ case _ if namespaceExists(namespace) =>
+ util.Collections.emptyMap[String, String]
+ case _ =>
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+
+ override def createNamespace(
+ namespace: Array[String],
+ metadata: util.Map[String, String]): Unit = {
+ if (namespaceExists(namespace)) {
+ throw new NamespaceAlreadyExistsException(namespace)
+ }
+
+ Option(namespaces.putIfAbsent(namespace.toList, metadata.asScala.toMap)) match {
+ case Some(_) =>
+ throw new NamespaceAlreadyExistsException(namespace)
+ case _ =>
+ // created successfully
+ }
+ }
+
+ override def alterNamespace(
+ namespace: Array[String],
+ changes: NamespaceChange*): Unit = {
+ val metadata = loadNamespaceMetadata(namespace).asScala.toMap
+ namespaces.put(namespace.toList, CatalogV2Util.applyNamespaceChanges(metadata, changes))
+ }
+
+ override def dropNamespace(namespace: Array[String]): Boolean = {
+ listNamespaces(namespace).foreach(dropNamespace)
+ try {
+ listTables(namespace).foreach(dropTable)
+ } catch {
+ case _: NoSuchNamespaceException =>
+ }
+ Option(namespaces.remove(namespace.toList)).isDefined
+ }
+
+ override def listTables(namespace: Array[String]): Array[Identifier] = {
+ if (namespace.isEmpty || namespaceExists(namespace)) {
+ super.listTables(namespace)
+ } else {
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+}
+
+object InMemoryTableCatalog {
+ val SIMULATE_FAILED_CREATE_PROPERTY = "spark.sql.test.simulateFailedCreate"
+ val SIMULATE_DROP_BEFORE_REPLACE_PROPERTY = "spark.sql.test.simulateDropBeforeReplace"
+
+ def maybeSimulateFailedTableCreation(tableProperties: util.Map[String, String]): Unit = {
+ if ("true".equalsIgnoreCase(tableProperties.get(SIMULATE_FAILED_CREATE_PROPERTY))) {
+ throw new IllegalStateException("Manual create table failure.")
+ }
+ }
+}
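Reviewer note (illustration only, not part of the patch): the namespace support above treats the set of namespaces as the union of explicitly created ones and those implied by existing table identifiers, e.g.:

  import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
  import org.apache.spark.sql.connector.expressions.Transform
  import org.apache.spark.sql.types.{IntegerType, StructType}
  import org.apache.spark.sql.util.CaseInsensitiveStringMap

  val catalog = new InMemoryTableCatalog
  catalog.initialize("test_cat", CaseInsensitiveStringMap.empty())

  catalog.createNamespace(Array("a", "b"), new java.util.HashMap[String, String]())
  catalog.createTable(Identifier.of(Array("c"), "t"), new StructType().add("id", IntegerType),
    Array.empty[Transform], new java.util.HashMap[String, String]())

  assert(catalog.namespaceExists(Array("a")))  // prefix of an explicitly created namespace
  assert(catalog.namespaceExists(Array("c")))  // implied by the table identifier
  assert(catalog.listNamespaces.map(_.toSeq).toSet == Set(Seq("a"), Seq("c")))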
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
new file mode 100644
index 0000000000..954650ae0e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTableCatalog {
+ import InMemoryTableCatalog._
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override def stageCreate(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable = {
+ validateStagedTable(partitions, properties)
+ new TestStagedCreateTable(
+ ident,
+ new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties))
+ }
+
+ override def stageReplace(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable = {
+ validateStagedTable(partitions, properties)
+ new TestStagedReplaceTable(
+ ident,
+ new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties))
+ }
+
+ override def stageCreateOrReplace(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable = {
+ validateStagedTable(partitions, properties)
+ new TestStagedCreateOrReplaceTable(
+ ident,
+ new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties))
+ }
+
+ private def validateStagedTable(
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Unit = {
+ if (partitions.nonEmpty) {
+ throw new UnsupportedOperationException(
+ s"Catalog $name: Partitioned tables are not supported")
+ }
+
+ maybeSimulateFailedTableCreation(properties)
+ }
+
+ private abstract class TestStagedTable(
+ ident: Identifier,
+ delegateTable: InMemoryTable)
+ extends StagedTable with SupportsWrite with SupportsRead {
+
+ override def abortStagedChanges(): Unit = {}
+
+ override def name(): String = delegateTable.name
+
+ override def schema(): StructType = delegateTable.schema
+
+ override def capabilities(): util.Set[TableCapability] = delegateTable.capabilities
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ delegateTable.newWriteBuilder(info)
+ }
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ delegateTable.newScanBuilder(options)
+ }
+ }
+
+ private class TestStagedCreateTable(
+ ident: Identifier,
+ delegateTable: InMemoryTable) extends TestStagedTable(ident, delegateTable) {
+
+ override def commitStagedChanges(): Unit = {
+ val maybePreCommittedTable = tables.putIfAbsent(ident, delegateTable)
+ if (maybePreCommittedTable != null) {
+ throw new TableAlreadyExistsException(
+ s"Table with identifier $ident and name $name was already created.")
+ }
+ }
+ }
+
+ private class TestStagedReplaceTable(
+ ident: Identifier,
+ delegateTable: InMemoryTable) extends TestStagedTable(ident, delegateTable) {
+
+ override def commitStagedChanges(): Unit = {
+ maybeSimulateDropBeforeCommit()
+ val maybePreCommittedTable = tables.replace(ident, delegateTable)
+ if (maybePreCommittedTable == null) {
+ throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
+ }
+ }
+
+ private def maybeSimulateDropBeforeCommit(): Unit = {
+ if ("true".equalsIgnoreCase(
+ delegateTable.properties.get(SIMULATE_DROP_BEFORE_REPLACE_PROPERTY))) {
+ tables.remove(ident)
+ }
+ }
+ }
+
+ private class TestStagedCreateOrReplaceTable(
+ ident: Identifier,
+ delegateTable: InMemoryTable) extends TestStagedTable(ident, delegateTable) {
+
+ override def commitStagedChanges(): Unit = {
+ tables.put(ident, delegateTable)
+ }
+ }
+}
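Reviewer note (illustration only, not part of the patch): the staging catalog keeps a staged table invisible until commitStagedChanges() is called, e.g.:

  import org.apache.spark.sql.connector.catalog.{Identifier, StagingInMemoryTableCatalog}
  import org.apache.spark.sql.connector.expressions.Transform
  import org.apache.spark.sql.types.{IntegerType, StructType}
  import org.apache.spark.sql.util.CaseInsensitiveStringMap

  val catalog = new StagingInMemoryTableCatalog
  catalog.initialize("test_cat", CaseInsensitiveStringMap.empty())

  val ident = Identifier.of(Array("ns"), "t")
  val staged = catalog.stageCreate(ident, new StructType().add("id", IntegerType),
    Array.empty[Transform], new java.util.HashMap[String, String]())

  assert(!catalog.tableExists(ident))  // nothing is visible before the commit
  staged.commitStagedChanges()
  assert(catalog.loadTable(ident).name == "test_cat.ns.t")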
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala
index ecfc6adff7..df2fbd6d17 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala
@@ -22,7 +22,6 @@ import java.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException}
-import org.apache.spark.sql.connector.{BufferedRows, InMemoryAtomicPartitionTable, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
index c95c459721..e5aeb90b84 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
-import org.apache.spark.sql.connector.{BufferedRows, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala
index 485e41f9eb..5560bda928 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.connector.{BufferedRows, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTable, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.LogicalExpressions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index e6565feebf..5ae74c5eaf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -48,6 +48,8 @@ class ArrowUtilsSuite extends SparkFunSuite {
roundtrip(BinaryType)
roundtrip(DecimalType.SYSTEM_DEFAULT)
roundtrip(DateType)
+ roundtrip(YearMonthIntervalType)
+ roundtrip(DayTimeIntervalType)
val tsExMsg = intercept[UnsupportedOperationException] {
roundtrip(TimestampType)
}
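Reviewer note: `roundtrip` is the suite's existing helper (defined above this hunk); it presumably converts a Catalyst type to its Arrow counterpart and back, roughly along the lines of the sketch below — the exact helper body is an assumption here, not part of this patch:

  import org.apache.spark.sql.types.YearMonthIntervalType
  import org.apache.spark.sql.util.ArrowUtils

  val arrowType = ArrowUtils.toArrowType(YearMonthIntervalType, null)
  assert(ArrowUtils.fromArrowType(arrowType) == YearMonthIntervalType)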
diff --git a/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt b/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt
index 546face681..7b1e82d64a 100644
--- a/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt
+++ b/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt
@@ -2,142 +2,147 @@
aggregate without grouping
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
agg w/o group: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-agg w/o group wholestage off 63666 64021 502 32.9 30.4 1.0X
-agg w/o group wholestage on 882 912 37 2376.9 0.4 72.2X
+agg w/o group wholestage off 82274 82877 853 25.5 39.2 1.0X
+agg w/o group wholestage on 1322 1358 37 1586.7 0.6 62.2X
================================================================================================
stat functions
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
stddev: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-stddev wholestage off 7370 7688 450 14.2 70.3 1.0X
-stddev wholestage on 931 997 50 112.6 8.9 7.9X
+stddev wholestage off 8975 9129 219 11.7 85.6 1.0X
+stddev wholestage on 1424 1444 34 73.6 13.6 6.3X
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
kurtosis: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-kurtosis wholestage off 30901 31209 436 3.4 294.7 1.0X
-kurtosis wholestage on 950 996 33 110.4 9.1 32.5X
+kurtosis wholestage off 42273 42424 213 2.5 403.1 1.0X
+kurtosis wholestage on 1492 1528 27 70.3 14.2 28.3X
================================================================================================
aggregate with linear keys
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 8845 8874 41 9.5 105.4 1.0X
-codegen = T hashmap = F 5804 5854 47 14.5 69.2 1.5X
-codegen = T hashmap = T 954 1001 35 87.9 11.4 9.3X
+codegen = F 10873 10998 176 7.7 129.6 1.0X
+codegen = T, hashmap = F 5906 6005 95 14.2 70.4 1.8X
+codegen = T, row-based hashmap = T 2325 2410 94 36.1 27.7 4.7X
+codegen = T, vectorized hashmap = T 1185 1259 78 70.8 14.1 9.2X
================================================================================================
aggregate with randomized keys
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 10398 10788 552 8.1 124.0 1.0X
-codegen = T hashmap = F 7426 7520 84 11.3 88.5 1.4X
-codegen = T hashmap = T 1883 1917 31 44.5 22.4 5.5X
+codegen = F 12385 12470 120 6.8 147.6 1.0X
+codegen = T, hashmap = F 7734 8110 378 10.8 92.2 1.6X
+codegen = T, row-based hashmap = T 3663 3702 37 22.9 43.7 3.4X
+codegen = T, vectorized hashmap = T 2532 2621 54 33.1 30.2 4.9X
================================================================================================
aggregate with string key
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
Aggregate w string key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 3615 3888 386 5.8 172.4 1.0X
-codegen = T hashmap = F 2253 2381 168 9.3 107.4 1.6X
-codegen = T hashmap = T 1242 1316 59 16.9 59.2 2.9X
+codegen = F 4465 4517 73 4.7 212.9 1.0X
+codegen = T, hashmap = F 2667 2825 208 7.9 127.2 1.7X
+codegen = T, row-based hashmap = T 1436 1466 21 14.6 68.5 3.1X
+codegen = T, vectorized hashmap = T 1297 1301 5 16.2 61.8 3.4X
================================================================================================
aggregate with decimal key
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
Aggregate w decimal key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 3437 3534 137 6.1 163.9 1.0X
-codegen = T hashmap = F 2122 2226 147 9.9 101.2 1.6X
-codegen = T hashmap = T 638 678 36 32.9 30.4 5.4X
+codegen = F 3722 3746 34 5.6 177.5 1.0X
+codegen = T, hashmap = F 2229 2297 96 9.4 106.3 1.7X
+codegen = T, row-based hashmap = T 927 957 28 22.6 44.2 4.0X
+codegen = T, vectorized hashmap = T 772 796 22 27.2 36.8 4.8X
================================================================================================
aggregate with multiple key types
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
Aggregate w multiple keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 6549 6648 140 3.2 312.3 1.0X
-codegen = T hashmap = F 3591 3693 144 5.8 171.2 1.8X
-codegen = T hashmap = T 2822 2922 141 7.4 134.6 2.3X
+codegen = F 7013 7060 67 3.0 334.4 1.0X
+codegen = T, hashmap = F 3750 3894 205 5.6 178.8 1.9X
+codegen = T, row-based hashmap = T 2948 2952 5 7.1 140.6 2.4X
+codegen = T, vectorized hashmap = T 2986 3145 226 7.0 142.4 2.3X
================================================================================================
max function bytecode size of wholestagecodegen
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
max function bytecode size: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 531 571 36 1.2 810.7 1.0X
-codegen = T hugeMethodLimit = 10000 223 282 36 2.9 340.1 2.4X
-codegen = T hugeMethodLimit = 1500 264 308 27 2.5 402.2 2.0X
+codegen = F 567 620 37 1.2 864.6 1.0X
+codegen = T, hugeMethodLimit = 10000 283 316 26 2.3 431.9 2.0X
+codegen = T, hugeMethodLimit = 1500 275 324 40 2.4 420.2 2.1X
================================================================================================
cube
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
cube: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-cube wholestage off 2963 3099 193 1.8 565.1 1.0X
-cube wholestage on 1624 1767 98 3.2 309.8 1.8X
+cube wholestage off 3389 3476 123 1.5 646.4 1.0X
+cube wholestage on 1692 1726 34 3.1 322.7 2.0X
================================================================================================
hash and BytesToBytesMap
================================================================================================
-OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz
BytesToBytesMap: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-UnsafeRowhash 247 268 19 84.8 11.8 1.0X
-murmur3 hash 99 123 40 211.3 4.7 2.5X
-fast hash 56 66 5 374.0 2.7 4.4X
-arrayEqual 186 200 8 113.0 8.8 1.3X
-Java HashMap (Long) 121 207 65 173.5 5.8 2.0X
-Java HashMap (two ints) 147 233 61 142.8 7.0 1.7X
-Java HashMap (UnsafeRow) 733 778 45 28.6 34.9 0.3X
-LongToUnsafeRowMap (opt=false) 489 504 15 42.8 23.3 0.5X
-LongToUnsafeRowMap (opt=true) 125 154 29 168.2 5.9 2.0X
-BytesToBytesMap (off Heap) 840 895 48 25.0 40.1 0.3X
-BytesToBytesMap (on Heap) 853 904 60 24.6 40.7 0.3X
-Aggregate HashMap 38 46 8 546.3 1.8 6.4X
+UnsafeRowhash 302 306 4 69.5 14.4 1.0X
+murmur3 hash 125 129 3 167.8 6.0 2.4X
+fast hash 69 73 3 304.1 3.3 4.4X
+arrayEqual 192 195 3 109.0 9.2 1.6X
+Java HashMap (Long) 133 187 53 157.2 6.4 2.3X
+Java HashMap (two ints) 156 230 62 134.3 7.4 1.9X
+Java HashMap (UnsafeRow) 807 812 6 26.0 38.5 0.4X
+LongToUnsafeRowMap (opt=false) 502 529 24 41.8 23.9 0.6X
+LongToUnsafeRowMap (opt=true) 148 164 20 141.7 7.1 2.0X
+BytesToBytesMap (off Heap) 936 950 23 22.4 44.6 0.3X
+BytesToBytesMap (on Heap) 954 956 2 22.0 45.5 0.3X
+Aggregate HashMap 46 54 11 455.4 2.2 6.6X
diff --git a/sql/core/benchmarks/AggregateBenchmark-results.txt b/sql/core/benchmarks/AggregateBenchmark-results.txt
index f18c470831..d4de806d03 100644
--- a/sql/core/benchmarks/AggregateBenchmark-results.txt
+++ b/sql/core/benchmarks/AggregateBenchmark-results.txt
@@ -2,142 +2,147 @@
aggregate without grouping
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
agg w/o group: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-agg w/o group wholestage off 47798 50190 NaN 43.9 22.8 1.0X
-agg w/o group wholestage on 1091 1128 28 1922.6 0.5 43.8X
+agg w/o group wholestage off 53440 63455 NaN 39.2 25.5 1.0X
+agg w/o group wholestage on 1157 1216 39 1812.5 0.6 46.2X
================================================================================================
stat functions
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
stddev: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-stddev wholestage off 7884 7959 106 13.3 75.2 1.0X
-stddev wholestage on 1012 1072 34 103.6 9.6 7.8X
+stddev wholestage off 7920 7947 39 13.2 75.5 1.0X
+stddev wholestage on 1147 1160 11 91.4 10.9 6.9X
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
kurtosis: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-kurtosis wholestage off 34023 34576 783 3.1 324.5 1.0X
-kurtosis wholestage on 1092 1121 30 96.1 10.4 31.2X
+kurtosis wholestage off 35143 35319 250 3.0 335.1 1.0X
+kurtosis wholestage on 1239 1258 20 84.6 11.8 28.4X
================================================================================================
aggregate with linear keys
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 9309 9379 99 9.0 111.0 1.0X
-codegen = T hashmap = F 5453 5643 223 15.4 65.0 1.7X
-codegen = T hashmap = T 1084 1110 16 77.4 12.9 8.6X
+codegen = F 9147 9183 50 9.2 109.0 1.0X
+codegen = T, hashmap = F 5794 5949 226 14.5 69.1 1.6X
+codegen = T, row-based hashmap = T 1378 1397 14 60.9 16.4 6.6X
+codegen = T, vectorized hashmap = T 996 1034 25 84.3 11.9 9.2X
================================================================================================
aggregate with randomized keys
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 10707 10950 344 7.8 127.6 1.0X
-codegen = T hashmap = F 7295 7423 145 11.5 87.0 1.5X
-codegen = T hashmap = T 2057 2199 199 40.8 24.5 5.2X
+codegen = F 9356 9425 98 9.0 111.5 1.0X
+codegen = T, hashmap = F 5787 5912 176 14.5 69.0 1.6X
+codegen = T, row-based hashmap = T 2569 2602 49 32.7 30.6 3.6X
+codegen = T, vectorized hashmap = T 2094 2128 27 40.1 25.0 4.5X
================================================================================================
aggregate with string key
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
Aggregate w string key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 4570 4573 4 4.6 217.9 1.0X
-codegen = T hashmap = F 3600 3686 74 5.8 171.7 1.3X
-codegen = T hashmap = T 2384 2432 45 8.8 113.7 1.9X
+codegen = F 4270 4322 75 4.9 203.6 1.0X
+codegen = T, hashmap = F 3241 3264 30 6.5 154.6 1.3X
+codegen = T, row-based hashmap = T 2196 2247 32 9.6 104.7 1.9X
+codegen = T, vectorized hashmap = T 2291 2306 14 9.2 109.3 1.9X
================================================================================================
aggregate with decimal key
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
Aggregate w decimal key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 2966 3011 64 7.1 141.4 1.0X
-codegen = T hashmap = F 1857 1908 73 11.3 88.5 1.6X
-codegen = T hashmap = T 695 702 8 30.2 33.2 4.3X
+codegen = F 2993 3010 23 7.0 142.7 1.0X
+codegen = T, hashmap = F 1940 1945 7 10.8 92.5 1.5X
+codegen = T, row-based hashmap = T 738 752 20 28.4 35.2 4.1X
+codegen = T, vectorized hashmap = T 620 650 21 33.8 29.6 4.8X
================================================================================================
aggregate with multiple key types
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
Aggregate w multiple keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 7361 7385 35 2.8 351.0 1.0X
-codegen = T hashmap = F 4525 4688 231 4.6 215.8 1.6X
-codegen = T hashmap = T 3865 3977 159 5.4 184.3 1.9X
+codegen = F 6635 6636 2 3.2 316.4 1.0X
+codegen = T, hashmap = F 4236 4269 47 5.0 202.0 1.6X
+codegen = T, row-based hashmap = T 3118 3158 57 6.7 148.7 2.1X
+codegen = T, vectorized hashmap = T 3259 3278 27 6.4 155.4 2.0X
================================================================================================
max function bytecode size of wholestagecodegen
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
max function bytecode size: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-codegen = F 451 489 23 1.5 688.5 1.0X
-codegen = T hugeMethodLimit = 10000 211 229 19 3.1 322.4 2.1X
-codegen = T hugeMethodLimit = 1500 203 226 20 3.2 309.5 2.2X
+codegen = F 467 492 33 1.4 712.4 1.0X
+codegen = T, hugeMethodLimit = 10000 216 231 19 3.0 329.7 2.2X
+codegen = T, hugeMethodLimit = 1500 209 221 9 3.1 319.0 2.2X
================================================================================================
cube
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
cube: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-cube wholestage off 2479 2548 97 2.1 472.9 1.0X
-cube wholestage on 1487 1567 62 3.5 283.7 1.7X
+cube wholestage off 2490 2529 56 2.1 474.8 1.0X
+cube wholestage on 1401 1416 22 3.7 267.3 1.8X
================================================================================================
hash and BytesToBytesMap
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure
-Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure
+Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
BytesToBytesMap: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-UnsafeRowhash 826 837 16 25.4 39.4 1.0X
-murmur3 hash 537 553 11 39.1 25.6 1.5X
-fast hash 559 572 14 37.5 26.6 1.5X
-arrayEqual 1665 1728 90 12.6 79.4 0.5X
-Java HashMap (Long) 732 739 7 28.7 34.9 1.1X
-Java HashMap (two ints) 682 694 15 30.7 32.5 1.2X
-Java HashMap (UnsafeRow) 1486 1499 19 14.1 70.9 0.6X
-LongToUnsafeRowMap (opt=false) 1235 1240 8 17.0 58.9 0.7X
-LongToUnsafeRowMap (opt=true) 718 736 17 29.2 34.2 1.2X
-BytesToBytesMap (off Heap) 945 965 20 22.2 45.1 0.9X
-BytesToBytesMap (on Heap) 870 895 28 24.1 41.5 0.9X
-Aggregate HashMap 64 71 5 325.6 3.1 12.8X
+UnsafeRowhash 259 264 5 81.0 12.3 1.0X
+murmur3 hash 113 121 3 185.7 5.4 2.3X
+fast hash 84 87 2 249.8 4.0 3.1X
+arrayEqual 172 180 4 121.9 8.2 1.5X
+Java HashMap (Long) 155 161 5 135.2 7.4 1.7X
+Java HashMap (two ints) 147 157 8 142.6 7.0 1.8X
+Java HashMap (UnsafeRow) 739 742 4 28.4 35.2 0.4X
+LongToUnsafeRowMap (opt=false) 489 491 3 42.9 23.3 0.5X
+LongToUnsafeRowMap (opt=true) 93 100 6 224.8 4.4 2.8X
+BytesToBytesMap (off Heap) 882 896 16 23.8 42.1 0.3X
+BytesToBytesMap (on Heap) 833 863 36 25.2 39.7 0.3X
+Aggregate HashMap 66 69 1 317.0 3.2 3.9X
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 829025f3dc..0795776eb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -214,11 +214,13 @@ class QueryExecution(
QueryPlan.append(logical, append, verbose, addSuffix, maxFields)
append("\n== Analyzed Logical Plan ==\n")
try {
- append(
- truncatedString(
- analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields)
- )
- append("\n")
+ if (analyzed.output.nonEmpty) {
+ append(
+ truncatedString(
+ analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields)
+ )
+ append("\n")
+ }
QueryPlan.append(analyzed, append, verbose, addSuffix, maxFields)
append("\n== Optimized Logical Plan ==\n")
QueryPlan.append(optimizedPlan, append, verbose, addSuffix, maxFields)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index bd45863652..d50e32c8b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.internal.SQLConf
@@ -54,8 +54,21 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
plan
} else {
+ def insertCustomShuffleReader(partitionSpecs: Seq[ShufflePartitionSpec]): SparkPlan = {
+ // This transformation adds new nodes, so we must use `transformUp` here.
+ val stageIds = shuffleStages.map(_.id).toSet
+ plan.transformUp {
+ // even for shuffle exchange whose input RDD has 0 partition, we should still update its
+ // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
+ // number of output partitions.
+ case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
+ CustomShuffleReaderExec(stage, partitionSpecs)
+ }
+ }
+
// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
// we should skip it when calculating the `partitionStartIndices`.
+ // If all input RDDs have 0 partitions, we create an empty partition for every shuffle reader.
val validMetrics = shuffleStages.flatMap(_.mapStats)
// We may have different pre-shuffle partition numbers, don't reduce shuffle partition number
@@ -63,7 +76,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
// partition) and a result of a SortMergeJoin (multiple partitions).
val distinctNumPreShufflePartitions =
validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
- if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
+ if (validMetrics.isEmpty) {
+ insertCustomShuffleReader(ShufflePartitionsUtil.createEmptyPartition() :: Nil)
+ } else if (distinctNumPreShufflePartitions.length == 1) {
// We fall back to Spark default parallelism if the minimum number of coalesced partitions
// is not set, so to avoid perf regressions compared to no coalescing.
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
@@ -77,15 +92,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
if (partitionSpecs.length == distinctNumPreShufflePartitions.head) {
plan
} else {
- // This transformation adds new nodes, so we must use `transformUp` here.
- val stageIds = shuffleStages.map(_.id).toSet
- plan.transformUp {
- // even for shuffle exchange whose input RDD has 0 partition, we should still update its
- // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
- // number of output partitions.
- case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
- CustomShuffleReaderExec(stage, partitionSpecs)
- }
+ insertCustomShuffleReader(partitionSpecs)
}
} else {
plan
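
For orientation, the behavioral change above is that a shuffle-reader node is now inserted even when every input shuffle reports no map statistics, using a single empty `CoalescedPartitionSpec(0, 0)` produced by the new `createEmptyPartition()` helper. The following self-contained sketch (illustrative names only, not the Spark implementation) shows the same size-based coalescing idea, with an empty input collapsing to one empty range:

```scala
object CoalesceSketch {
  // (startIndex, endIndex) is half-open, like CoalescedPartitionSpec's reducer indices.
  def coalesce(bytesByPartition: Array[Long], targetSize: Long): Seq[(Int, Int)] = {
    if (bytesByPartition.isEmpty) return Seq((0, 0)) // the "empty partition" case for 0-partition shuffles
    val specs = Seq.newBuilder[(Int, Int)]
    var start = 0
    var acc = 0L
    bytesByPartition.indices.foreach { i =>
      if (acc > 0 && acc + bytesByPartition(i) > targetSize) {
        specs += ((start, i))   // close the current coalesced range before it overflows
        start = i
        acc = 0L
      }
      acc += bytesByPartition(i)
    }
    specs += ((start, bytesByPartition.length))
    specs.result()
  }

  def main(args: Array[String]): Unit = {
    println(coalesce(Array(10L, 10L, 70L, 20L, 30L), targetSize = 64)) // List((0,2), (2,3), (3,5))
    println(coalesce(Array.empty[Long], targetSize = 64))              // List((0,0))
  }
}
```
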
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index d98b7c29a3..1065519256 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
@@ -113,6 +114,9 @@ case class InsertAdaptiveSparkPlan(
*/
private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
+ if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
+ return subqueryMap.toMap
+ }
plan.foreach(_.expressions.foreach(_.foreach {
case expressions.ScalarSubquery(p, _, exprId)
if !subqueryMap.contains(exprId.id) =>
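
`buildSubqueryMap` now bails out early when the plan contains none of the subquery tree patterns, and `PlanAdaptiveSubqueries` below switches to the pruning variant of expression transformation for the same reason. As a rough illustration of the idea, here is a toy tree (not Catalyst's `TreeNode` API) where a transform guarded by a cached per-subtree pattern set skips entire subtrees that cannot contain a match:

```scala
sealed trait Pattern
case object ScalarSubqueryPattern extends Pattern
case object InSubqueryPattern extends Pattern

final case class Node(name: String, patterns: Set[Pattern], children: Seq[Node] = Nil) {
  // The subtree's pattern set: this node's own patterns plus everything below it.
  lazy val treePatterns: Set[Pattern] = patterns ++ children.flatMap(_.treePatterns)

  def containsAnyPattern(ps: Pattern*): Boolean = ps.exists(treePatterns.contains)

  def transformWithPruning(cond: Node => Boolean)(rule: PartialFunction[Node, Node]): Node =
    if (!cond(this)) this // prune: nothing below can match, skip the whole subtree
    else {
      val withNewChildren = copy(children = children.map(_.transformWithPruning(cond)(rule)))
      rule.applyOrElse(withNewChildren, (n: Node) => n)
    }
}

object PruningDemo {
  def main(args: Array[String]): Unit = {
    val plan = Node("Project", Set.empty, Seq(Node("Filter", Set(InSubqueryPattern))))
    val rewritten =
      plan.transformWithPruning(_.containsAnyPattern(ScalarSubqueryPattern, InSubqueryPattern)) {
        case n if n.patterns.contains(InSubqueryPattern) => n.copy(name = n.name + " [subquery planned]")
      }
    println(rewritten)
  }
}
```
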
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 13ff236d20..a2e4397a36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
+ SCALAR_SUBQUERY}
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
@@ -27,7 +29,8 @@ case class PlanAdaptiveSubqueries(
subqueryMap: Map[Long, BaseSubqueryExec]) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
- plan.transformAllExpressions {
+ plan.transformAllExpressionsWithPruning(
+ _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
case expressions.ScalarSubquery(_, _, exprId) =>
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
index ed92af6adc..a70a5322a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala
@@ -125,6 +125,10 @@ object ShufflePartitionsUtil extends Logging {
partitionSpecs.toSeq
}
+ def createEmptyPartition(): ShufflePartitionSpec = {
+ CoalescedPartitionSpec(0, 0)
+ }
+
/**
* Given a list of size, return an array of indices to split the list into multiple partitions,
* so that the size sum of each partition is close to the target size. Each index indicates the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 6e23a2844d..3c1304e9cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._
import scala.collection.mutable
import org.apache.spark.TaskContext
-import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.memory.SparkOutOfMemoryError
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -435,8 +435,8 @@ case class HashAggregateExec(
)
}
- def getTaskMemoryManager(): TaskMemoryManager = {
- TaskContext.get().taskMemoryManager()
+ def getTaskContext(): TaskContext = {
+ TaskContext.get()
}
def getEmptyAggregationBuffer(): InternalRow = {
@@ -647,7 +647,7 @@ case class HashAggregateExec(
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
f.dataType.isInstanceOf[CalendarIntervalType]) &&
- bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
+ bufferSchema.nonEmpty
// For vectorized hash map, We do not support byte array based decimal type for aggregate values
// as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
@@ -663,7 +663,7 @@ case class HashAggregateExec(
private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
if (!checkIfFastHashMapSupported(ctx)) {
- if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
+ if (!Utils.isTesting) {
logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but"
+ " current version of codegened fast hashmap does not support this aggregate.")
}
@@ -683,7 +683,18 @@ case class HashAggregateExec(
} else if (sqlContext.conf.enableVectorizedHashMap) {
logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
}
- val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+ val bitMaxCapacity = testFallbackStartsAt match {
+ case Some((fastMapCounter, _)) =>
+ // In testing, with the fast hash map's fallback counter (`fastMapCounter`), cap the max bit
+ // of the map at log2(`fastMapCounter`). This helps control the number of keys in the map
+ // to mimic the fallback.
+ if (fastMapCounter <= 1) {
+ 0
+ } else {
+ (math.log10(fastMapCounter) / math.log10(2)).floor.toInt
+ }
+ case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+ }
val thisPlan = ctx.addReferenceObj("plan", this)
@@ -717,11 +728,28 @@ case class HashAggregateExec(
"org.apache.spark.unsafe.KVIterator",
"fastHashMapIter", forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
- s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());"
+ s"$thisPlan.getTaskContext().taskMemoryManager(), " +
+ s"$thisPlan.getEmptyAggregationBuffer());"
(iter, create)
}
} else ("", "")
+ // Generates the code to register a cleanup task with TaskContext to ensure that memory
+ // is guaranteed to be freed at the end of the task. This is necessary to avoid memory
+ // leaks when the downstream operator does not fully consume the aggregation map's
+ // output (e.g. aggregate followed by limit).
+ val addHookToCloseFastHashMap = if (isFastHashMapEnabled) {
+ s"""
+ |$thisPlan.getTaskContext().addTaskCompletionListener(
+ | new org.apache.spark.util.TaskCompletionListener() {
+ | @Override
+ | public void onTaskCompletion(org.apache.spark.TaskContext context) {
+ | $fastHashMapTerm.close();
+ | }
+ |});
+ """.stripMargin
+ } else ""
+
// Create a name for the iterator from the regular hash map.
// Inline mutable state since not many aggregation operations in a task
val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
@@ -761,6 +789,8 @@ case class HashAggregateExec(
val bufferTerm = ctx.freshName("aggBuffer")
val outputFunc = generateResultFunction(ctx)
+ val limitNotReachedCondition = limitNotReachedCond
+
def outputFromFastHashMap: String = {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
@@ -773,7 +803,7 @@ case class HashAggregateExec(
def outputFromRowBasedMap: String = {
s"""
- |while ($iterTermForFastHashMap.next()) {
+ |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -798,7 +828,7 @@ case class HashAggregateExec(
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
s"""
- |while ($iterTermForFastHashMap.hasNext()) {
+ |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
@@ -813,7 +843,7 @@ case class HashAggregateExec(
def outputFromRegularHashMap: String = {
s"""
- |while ($limitNotReachedCond $iterTerm.next()) {
+ |while ($limitNotReachedCondition $iterTerm.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -832,6 +862,7 @@ case class HashAggregateExec(
|if (!$initAgg) {
| $initAgg = true;
| $createFastHashMap
+ | $addHookToCloseFastHashMap
| $hashMapTerm = $thisPlan.createHashMap();
| long $beforeAgg = System.nanoTime();
| $doAggFuncName();
@@ -866,13 +897,11 @@ case class HashAggregateExec(
}
}
- val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
- incCounter) = if (testFallbackStartsAt.isDefined) {
- val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
- (s"$countTerm < ${testFallbackStartsAt.get._1}",
- s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
- } else {
- ("true", "true", "", "")
+ val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match {
+ case Some((_, regularMapCounter)) =>
+ val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
+ (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;")
+ case _ => ("true", "", "")
}
val oomeClassName = classOf[SparkOutOfMemoryError].getName
@@ -912,12 +941,10 @@ case class HashAggregateExec(
// If fast hash map is on, we first generate code to probe and update the fast hash map.
// If the probe is successful the corresponding fast row buffer will hold the mutable row.
s"""
- |if ($checkFallbackForGeneratedHashMap) {
- | ${fastRowKeys.map(_.code).mkString("\n")}
- | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
- | $fastRowBuffer = $fastHashMapTerm.findOrInsert(
- | ${fastRowKeys.map(_.value).mkString(", ")});
- | }
+ |${fastRowKeys.map(_.code).mkString("\n")}
+ |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+ | $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+ | ${fastRowKeys.map(_.value).mkString(", ")});
|}
|// Cannot find the key in fast hash map, try regular hash map.
|if ($fastRowBuffer == null) {
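
The test-only fallback path above caps the fast hash map's capacity bit at floor(log2(`fastMapCounter`)), so that inserting the configured number of distinct keys forces a miss in the fast map and exercises the regular `BytesToBytesMap` path. A minimal sketch of that arithmetic (helper and parameter names are assumptions, not Spark code):

```scala
object FastMapCapacityBitSketch {
  // With a test fallback counter, the fast map is capped at 2^floor(log2(counter)) keys;
  // `defaultBit` stands in for the session config value used when no test fallback is set.
  def bitMaxCapacity(testFastMapCounter: Option[Int], defaultBit: Int): Int =
    testFastMapCounter match {
      case Some(counter) if counter <= 1 => 0
      case Some(counter) => (math.log10(counter) / math.log10(2)).floor.toInt
      case None => defaultBit
    }

  def main(args: Array[String]): Unit = {
    // With a fallback counter of 5 the cap is 2^2 = 4 keys, so a 5th distinct key
    // cannot be inserted into the fast map and falls through to the regular map.
    println(bitMaxCapacity(Some(5), defaultBit = 16)) // 2
    println(bitMaxCapacity(None, defaultBit = 16))    // 16
  }
}
```
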
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index d3a02f2451..fcae7ac32b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
@@ -74,6 +75,8 @@ object ArrowWriter {
}
new StructWriter(vector, children.toArray)
case (NullType, vector: NullVector) => new NullWriter(vector)
+ case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
+ case (DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector)
case (dt, _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(dt)
}
@@ -394,3 +397,28 @@ private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldW
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
}
}
+
+private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector)
+ extends ArrowFieldWriter {
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getInt(ordinal));
+ }
+}
+
+private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector)
+ extends ArrowFieldWriter {
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val totalMicroseconds = input.getLong(ordinal)
+ val days = totalMicroseconds / MICROS_PER_DAY
+ val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS
+ valueVector.set(count, days.toInt, millis.toInt)
+ }
+}
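
The new `IntervalDayWriter` splits the stored microsecond total into whole days plus the remaining milliseconds, which is the granularity Arrow's `IntervalDayVector` stores; any sub-millisecond remainder is truncated. A standalone sketch of that decomposition, with the constants inlined for illustration:

```scala
object IntervalDayDecompositionSketch {
  val MicrosPerMillis: Long = 1000L
  val MicrosPerDay: Long = 24L * 60 * 60 * 1000 * MicrosPerMillis // 86_400_000_000

  // Returns (days, millis) as written to the vector; sub-millisecond precision is dropped.
  def split(totalMicroseconds: Long): (Int, Int) = {
    val days = totalMicroseconds / MicrosPerDay
    val millis = (totalMicroseconds % MicrosPerDay) / MicrosPerMillis
    (days.toInt, millis.toInt)
  }

  def main(args: Array[String]): Unit = {
    // 2 days, 3 hours and 1500 microseconds.
    val micros = 2 * MicrosPerDay + 3L * 60 * 60 * 1000 * 1000 + 1500
    println(split(micros)) // (2,10800001): 3h = 10_800_000 ms, plus 1 ms; the trailing 500 µs are dropped
  }
}
```
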
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index c6b6a21da5..c32d1d74c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -139,7 +139,7 @@ case class SetCommand(kv: Option[(String, Option[String])])
s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.")
Seq(Row(
SQLConf.SHUFFLE_PARTITIONS.key,
- sparkSession.sessionState.conf.numShufflePartitions.toString))
+ sparkSession.sessionState.conf.defaultNumShufflePartitions.toString))
}
(keyValueOutput, runFunc)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
index b6b07de8a5..4f60a9d4c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -53,11 +53,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty
private[this] var numFiles: Int = 0
- private[this] var submittedFiles: Int = 0
+ private[this] var numSubmittedFiles: Int = 0
private[this] var numBytes: Long = 0L
private[this] var numRows: Long = 0L
- private[this] var curFile: Option[String] = None
+ private[this] val submittedFiles = mutable.HashSet[String]()
/**
* Get the size of the file expected to have been written by a worker.
@@ -134,23 +134,20 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
partitions.append(partitionValues)
}
- override def newBucket(bucketId: Int): Unit = {
- // currently unhandled
+ override def newFile(filePath: String): Unit = {
+ submittedFiles += filePath
+ numSubmittedFiles += 1
}
- override def newFile(filePath: String): Unit = {
- statCurrentFile()
- curFile = Some(filePath)
- submittedFiles += 1
+ override def closeFile(filePath: String): Unit = {
+ updateFileStats(filePath)
+ submittedFiles.remove(filePath)
}
- private def statCurrentFile(): Unit = {
- curFile.foreach { path =>
- getFileSize(path).foreach { len =>
- numBytes += len
- numFiles += 1
- }
- curFile = None
+ private def updateFileStats(filePath: String): Unit = {
+ getFileSize(filePath).foreach { len =>
+ numBytes += len
+ numFiles += 1
}
}
@@ -159,7 +156,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
}
override def getFinalStats(): WriteTaskStats = {
- statCurrentFile()
+ submittedFiles.foreach(updateFileStats)
+ submittedFiles.clear()
// Reports bytesWritten and recordsWritten to the Spark output metrics.
Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics =>
@@ -167,8 +165,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
outputMetrics.setRecordsWritten(numRows)
}
- if (submittedFiles != numFiles) {
- logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " +
+ if (numSubmittedFiles != numFiles) {
+ logInfo(s"Expected $numSubmittedFiles files, but only saw $numFiles. " +
"This could be due to the output format not writing empty files, " +
"or files being not immediately visible in the filesystem.")
}
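
With `closeFile` added, the tracker no longer assumes files are written strictly one after another: every submitted path stays in a set until it is closed (or until `getFinalStats`), which is what makes the concurrent writers introduced further down safe to track. A condensed, framework-free sketch of that bookkeeping, where the `fileSize` callback stands in for the Hadoop filesystem lookup:

```scala
import scala.collection.mutable

class FileStatsTrackerSketch(fileSize: String => Option[Long]) {
  private val submittedFiles = mutable.HashSet[String]()
  private var numSubmittedFiles = 0
  private var numFiles = 0
  private var numBytes = 0L

  def newFile(path: String): Unit = { submittedFiles += path; numSubmittedFiles += 1 }

  def closeFile(path: String): Unit = { updateFileStats(path); submittedFiles.remove(path) }

  private def updateFileStats(path: String): Unit =
    fileSize(path).foreach { len => numBytes += len; numFiles += 1 }

  // Any file that was submitted but never explicitly closed is still accounted for here.
  def finalStats(): (Int, Int, Long) = {
    submittedFiles.foreach(updateFileStats)
    submittedFiles.clear()
    (numSubmittedFiles, numFiles, numBytes)
  }
}

object FileStatsTrackerDemo {
  def main(args: Array[String]): Unit = {
    val sizes = Map("p=1/part-0" -> 10L, "p=2/part-0" -> 20L)
    val tracker = new FileStatsTrackerSketch(sizes.get)
    tracker.newFile("p=1/part-0"); tracker.newFile("p=2/part-0")
    tracker.closeFile("p=1/part-0")  // closed eagerly by one writer
    println(tracker.finalStats())    // (2,2,30): the still-open file is flushed at the end
  }
}
```
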
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index 6de9b1d7ce..8230737a61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
@@ -28,6 +29,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.SerializableConfiguration
@@ -52,19 +55,35 @@ abstract class FileFormatDataWriter(
protected val statsTrackers: Seq[WriteTaskStatsTracker] =
description.statsTrackers.map(_.newTaskInstance())
- protected def releaseResources(): Unit = {
+ /** Release resources of `currentWriter`. */
+ protected def releaseCurrentWriter(): Unit = {
if (currentWriter != null) {
try {
currentWriter.close()
+ statsTrackers.foreach(_.closeFile(currentWriter.path()))
} finally {
currentWriter = null
}
}
}
- /** Writes a record */
+ /** Release all resources. */
+ protected def releaseResources(): Unit = {
+ // Call `releaseCurrentWriter()` by default, as this is the only resource to be released.
+ releaseCurrentWriter()
+ }
+
+ /** Writes a record. */
def write(record: InternalRow): Unit
+
+ /** Write an iterator of records. */
+ def writeWithIterator(iterator: Iterator[InternalRow]): Unit = {
+ while (iterator.hasNext) {
+ write(iterator.next())
+ }
+ }
+
/**
* Returns the summary of relative information which
* includes the list of partition strings written out. The list of partitions is sent back
@@ -144,34 +163,38 @@ class SingleDirectoryDataWriter(
}
/**
- * Writes data to using dynamic partition writes, meaning this single function can write to
+ * Holds common logic for writing data with dynamic partition writes, meaning it can write to
* multiple directories (partitions) or files (bucketing).
*/
-class DynamicPartitionDataWriter(
+abstract class BaseDynamicPartitionDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol)
extends FileFormatDataWriter(description, taskAttemptContext, committer) {
/** Flag saying whether or not the data to be written out is partitioned. */
- private val isPartitioned = description.partitionColumns.nonEmpty
+ protected val isPartitioned = description.partitionColumns.nonEmpty
/** Flag saying whether or not the data to be written out is bucketed. */
- private val isBucketed = description.bucketIdExpression.isDefined
+ protected val isBucketed = description.bucketIdExpression.isDefined
assert(isPartitioned || isBucketed,
s"""DynamicPartitionWriteTask should be used for writing out data that's either
- |partitioned or bucketed. In this case neither is true.
- |WriteJobDescription: $description
+ |partitioned or bucketed. In this case neither is true.
+ |WriteJobDescription: $description
""".stripMargin)
- private var fileCounter: Int = _
- private var recordsInFile: Long = _
- private var currentPartitionValues: Option[UnsafeRow] = None
- private var currentBucketId: Option[Int] = None
+ /** Number of records in current file. */
+ protected var recordsInFile: Long = _
+
+ /**
+ * File counter for writing the current partition or bucket. For the same partition or bucket,
+ * we may have more than one file, due to the records-per-file limit.
+ */
+ protected var fileCounter: Int = _
/** Extracts the partition values out of an input row. */
- private lazy val getPartitionValues: InternalRow => UnsafeRow = {
+ protected lazy val getPartitionValues: InternalRow => UnsafeRow = {
val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns)
row => proj(row)
}
@@ -186,22 +209,24 @@ class DynamicPartitionDataWriter(
if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
})
- /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
- * the partition string. */
+ /**
+ * Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
+ * the partition string.
+ */
private lazy val getPartitionPath: InternalRow => String = {
val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns)
row => proj(row).getString(0)
}
/** Given an input row, returns the corresponding `bucketId` */
- private lazy val getBucketId: InternalRow => Int = {
+ protected lazy val getBucketId: InternalRow => Int = {
val proj =
UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
row => proj(row).getInt(0)
}
/** Returns the data columns to be written given an input row */
- private val getOutputRow =
+ protected val getOutputRow =
UnsafeProjection.create(description.dataColumns, description.allColumns)
/**
@@ -209,13 +234,20 @@ class DynamicPartitionDataWriter(
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
- * @param partitionValues the partition which all tuples being written by this `OutputWriter`
+ * @param partitionValues the partition which all tuples being written by this OutputWriter
* belong to
- * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
+ * @param bucketId the bucket which all tuples being written by this OutputWriter belong to
+ * @param closeCurrentWriter whether to close and release the resources of the current writer
*/
- private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = {
+ protected def renewCurrentWriter(
+ partitionValues: Option[InternalRow],
+ bucketId: Option[Int],
+ closeCurrentWriter: Boolean): Unit = {
+
recordsInFile = 0
- releaseResources()
+ if (closeCurrentWriter) {
+ releaseCurrentWriter()
+ }
val partDir = partitionValues.map(getPartitionPath(_))
partDir.foreach(updatedPartitions.add)
@@ -243,6 +275,51 @@ class DynamicPartitionDataWriter(
statsTrackers.foreach(_.newFile(currentPath))
}
+ /**
+ * Open a new output writer when the number of records exceeds the limit.
+ *
+ * @param partitionValues the partition which all tuples being written by this `OutputWriter`
+ * belong to
+ * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
+ */
+ protected def renewCurrentWriterIfTooManyRecords(
+ partitionValues: Option[InternalRow],
+ bucketId: Option[Int]): Unit = {
+ // Exceeded the threshold in terms of the number of records per file.
+ // Create a new file by increasing the file counter.
+ fileCounter += 1
+ assert(fileCounter < MAX_FILE_COUNTER,
+ s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+ renewCurrentWriter(partitionValues, bucketId, closeCurrentWriter = true)
+ }
+
+ /**
+ * Writes the given record with the current writer.
+ *
+ * @param record The record to write
+ */
+ protected def writeRecord(record: InternalRow): Unit = {
+ val outputRow = getOutputRow(record)
+ currentWriter.write(outputRow)
+ statsTrackers.foreach(_.newRow(outputRow))
+ recordsInFile += 1
+ }
+}
+
+/**
+ * Dynamic partition writer with a single writer, meaning only one writer is open at any time.
+ * The records to be written are required to be sorted on the partition and/or bucket
+ * column(s) before writing.
+ */
+class DynamicPartitionDataSingleWriter(
+ description: WriteJobDescription,
+ taskAttemptContext: TaskAttemptContext,
+ committer: FileCommitProtocol)
+ extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) {
+
+ private var currentPartitionValues: Option[UnsafeRow] = None
+ private var currentBucketId: Option[Int] = None
+
override def write(record: InternalRow): Unit = {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
@@ -255,25 +332,199 @@ class DynamicPartitionDataWriter(
}
if (isBucketed) {
currentBucketId = nextBucketId
- statsTrackers.foreach(_.newBucket(currentBucketId.get))
}
fileCounter = 0
- newOutputWriter(currentPartitionValues, currentBucketId)
+ renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true)
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile >= description.maxRecordsPerFile) {
- // Exceeded the threshold in terms of the number of records per file.
- // Create a new file by increasing the file counter.
- fileCounter += 1
- assert(fileCounter < MAX_FILE_COUNTER,
- s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+ renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId)
+ }
+ writeRecord(record)
+ }
+}
+
+/**
+ * Dynamic partition writer with concurrent writers, meaning multiple concurrent writers are opened
+ * for writing.
+ *
+ * The process has the following steps:
+ * - Step 1: Maintain a map of output writers, keyed by partition and/or bucket values. Keep all
+ * writers open and write rows one by one.
+ * - Step 2: If the number of concurrent writers exceeds the limit, sort the rest of the rows on
+ * the partition and/or bucket column(s). Write rows one by one, and eagerly close each
+ * writer when its partition and/or bucket is finished.
+ *
+ * Caller is expected to call `writeWithIterator()` instead of `write()` to write records.
+ */
+class DynamicPartitionDataConcurrentWriter(
+ description: WriteJobDescription,
+ taskAttemptContext: TaskAttemptContext,
+ committer: FileCommitProtocol,
+ concurrentOutputWriterSpec: ConcurrentOutputWriterSpec)
+ extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer)
+ with Logging {
+
+ /** Wrapper class to index a unique concurrent output writer. */
+ private case class WriterIndex(
+ var partitionValues: Option[UnsafeRow],
+ var bucketId: Option[Int])
+
+ /** Wrapper class for status of a unique concurrent output writer. */
+ private class WriterStatus(
+ var outputWriter: OutputWriter,
+ var recordsInFile: Long,
+ var fileCounter: Int)
+
+ /**
+ * State to indicate whether we have fallen back to the sort-based writer.
+ * Because we first try to use concurrent writers, its initial value is false.
+ */
+ private var sorted: Boolean = false
+ private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]()
+
+ /**
+ * The index for the current writer. The index is intentionally mutable and reusable,
+ * to avoid the JVM GC pressure of creating many short-lived `WriterIndex` objects
+ * when switching between concurrent writers frequently.
+ */
+ private val currentWriterId = WriterIndex(None, None)
+
+ /**
+ * Release resources for all concurrent output writers.
+ */
+ override protected def releaseResources(): Unit = {
+ currentWriter = null
+ concurrentWriters.values.foreach(status => {
+ if (status.outputWriter != null) {
+ try {
+ status.outputWriter.close()
+ } finally {
+ status.outputWriter = null
+ }
+ }
+ })
+ concurrentWriters.clear()
+ }
- newOutputWriter(currentPartitionValues, currentBucketId)
+ override def write(record: InternalRow): Unit = {
+ val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
+ val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
+
+ if (currentWriterId.partitionValues != nextPartitionValues ||
+ currentWriterId.bucketId != nextBucketId) {
+ // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+ if (currentWriter != null) {
+ if (!sorted) {
+ // Update writer status in concurrent writers map, because the writer is probably needed
+ // again later for writing other rows.
+ updateCurrentWriterStatusInMap()
+ } else {
+ // Remove writer status in concurrent writers map and release current writer resource,
+ // because the writer is not needed any more.
+ concurrentWriters.remove(currentWriterId)
+ releaseCurrentWriter()
+ }
+ }
+
+ if (isBucketed) {
+ currentWriterId.bucketId = nextBucketId
+ }
+ if (isPartitioned && currentWriterId.partitionValues != nextPartitionValues) {
+ currentWriterId.partitionValues = Some(nextPartitionValues.get.copy())
+ if (!concurrentWriters.contains(currentWriterId)) {
+ statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get))
+ }
+ }
+ setupCurrentWriterUsingMap()
}
- val outputRow = getOutputRow(record)
- currentWriter.write(outputRow)
- statsTrackers.foreach(_.newRow(outputRow))
- recordsInFile += 1
+
+ if (description.maxRecordsPerFile > 0 &&
+ recordsInFile >= description.maxRecordsPerFile) {
+ renewCurrentWriterIfTooManyRecords(currentWriterId.partitionValues, currentWriterId.bucketId)
+ // Update writer status in concurrent writers map, as a new writer is created.
+ updateCurrentWriterStatusInMap()
+ }
+ writeRecord(record)
+ }
+
+ /**
+ * Write an iterator of records with concurrent writers.
+ */
+ override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = {
+ while (iterator.hasNext && !sorted) {
+ write(iterator.next())
+ }
+
+ if (iterator.hasNext) {
+ clearCurrentWriterStatus()
+ val sorter = concurrentOutputWriterSpec.createSorter()
+ val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]])
+ while (sortIterator.hasNext) {
+ write(sortIterator.next())
+ }
+ }
+ }
+
+ /**
+ * Update current writer status in map.
+ */
+ private def updateCurrentWriterStatusInMap(): Unit = {
+ val status = concurrentWriters(currentWriterId)
+ status.outputWriter = currentWriter
+ status.recordsInFile = recordsInFile
+ status.fileCounter = fileCounter
+ }
+
+ /**
+ * Retrieve the writer from the map, or create a new writer if it does not exist.
+ */
+ private def setupCurrentWriterUsingMap(): Unit = {
+ if (concurrentWriters.contains(currentWriterId)) {
+ val status = concurrentWriters(currentWriterId)
+ currentWriter = status.outputWriter
+ recordsInFile = status.recordsInFile
+ fileCounter = status.fileCounter
+ } else {
+ fileCounter = 0
+ renewCurrentWriter(
+ currentWriterId.partitionValues,
+ currentWriterId.bucketId,
+ closeCurrentWriter = false)
+ if (!sorted) {
+ assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters,
+ s"Number of concurrent output file writers is ${concurrentWriters.size} " +
+ s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters}")
+ } else {
+ assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters + 1,
+ s"Number of output file writers after sort is ${concurrentWriters.size} " +
+ s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters + 1}")
+ }
+ concurrentWriters.put(
+ currentWriterId.copy(),
+ new WriterStatus(currentWriter, recordsInFile, fileCounter))
+ if (concurrentWriters.size >= concurrentOutputWriterSpec.maxWriters && !sorted) {
+ // Fall back to sort-based sequential writer mode.
+ logInfo(s"Number of concurrent writers ${concurrentWriters.size} reaches the threshold. " +
+ "Fall back from concurrent writers to sort-based sequential writer. You may change " +
+ s"threshold with configuration ${SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key}")
+ sorted = true
+ }
+ }
+ }
+
+ /**
+ * Clear the current writer status in map.
+ */
+ private def clearCurrentWriterStatus(): Unit = {
+ if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) {
+ updateCurrentWriterStatusInMap()
+ }
+ currentWriterId.partitionValues = None
+ currentWriterId.bucketId = None
+ currentWriter = null
+ recordsInFile = 0
+ fileCounter = 0
}
}
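
A toy model of the two-phase behavior of `DynamicPartitionDataConcurrentWriter` (plain strings instead of `InternalRow`, buffers in a map instead of `OutputWriter`s): rows are routed to per-partition writers kept in a map, and once the map reaches the writer limit the remaining rows are sorted by key and written sequentially, mirroring the fallback described above.

```scala
import scala.collection.mutable

object ConcurrentWriterSketch {
  // Rows are (partitionKey, value); "writers" are just buffers keyed by partition.
  def write(rows: Iterator[(String, String)], maxWriters: Int): Map[String, Seq[String]] = {
    val out = mutable.Map[String, mutable.Buffer[String]]()
    def writer(key: String): mutable.Buffer[String] =
      out.getOrElseUpdate(key, mutable.Buffer[String]())

    // Phase 1: concurrent mode, keep every writer open until the limit is hit.
    var fallBackToSort = false
    val it = rows.buffered
    while (it.hasNext && !fallBackToSort) {
      val (key, value) = it.next()
      writer(key) += value
      if (out.size >= maxWriters) fallBackToSort = true
    }

    // Phase 2: sort the remainder so each partition finishes before the next one starts.
    it.toSeq.sortBy(_._1).foreach { case (key, value) => writer(key) += value }

    out.map { case (k, v) => k -> v.toSeq }.toMap
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator("a" -> "1", "b" -> "2", "c" -> "3", "a" -> "4", "b" -> "5")
    // Contains a -> Seq(1, 4), b -> Seq(2, 5), c -> Seq(3); only rows after the limit get sorted.
    println(write(rows, maxWriters = 2))
  }
}
```
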
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 6300e10c0b..6839a4db0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String
@@ -73,6 +73,11 @@ object FileFormatWriter extends Logging {
copy(child = newChild)
}
+ /** Describes how concurrent output writers should be executed. */
+ case class ConcurrentOutputWriterSpec(
+ maxWriters: Int,
+ createSorter: () => UnsafeExternalRowSorter)
+
/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -177,18 +182,27 @@ object FileFormatWriter extends Logging {
committer.setupJob(job)
try {
- val rdd = if (orderingMatched) {
- empty2NullPlan.execute()
+ val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
+ (empty2NullPlan.execute(), None)
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = bindReferences(
requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
- SortExec(
+ val sortPlan = SortExec(
orderingExpr,
global = false,
- child = empty2NullPlan).execute()
+ child = empty2NullPlan)
+
+ val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
+ val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
+ if (concurrentWritersEnabled) {
+ (empty2NullPlan.execute(),
+ Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())))
+ } else {
+ (sortPlan.execute(), None)
+ }
}
// SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
@@ -211,7 +225,8 @@ object FileFormatWriter extends Logging {
sparkPartitionId = taskContext.partitionId(),
sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE,
committer,
- iterator = iter)
+ iterator = iter,
+ concurrentOutputWriterSpec = concurrentOutputWriterSpec)
},
rddWithNonEmptyPartitions.partitions.indices,
(index, res: WriteTaskResult) => {
@@ -245,7 +260,8 @@ object FileFormatWriter extends Logging {
sparkPartitionId: Int,
sparkAttemptNumber: Int,
committer: FileCommitProtocol,
- iterator: Iterator[InternalRow]): WriteTaskResult = {
+ iterator: Iterator[InternalRow],
+ concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]): WriteTaskResult = {
val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -273,15 +289,19 @@ object FileFormatWriter extends Logging {
} else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
} else {
- new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
+ concurrentOutputWriterSpec match {
+ case Some(spec) =>
+ new DynamicPartitionDataConcurrentWriter(
+ description, taskAttemptContext, committer, spec)
+ case _ =>
+ new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer)
+ }
}
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Execute the task to write rows out and commit the task.
- while (iterator.hasNext) {
- dataWriter.write(iterator.next())
- }
+ dataWriter.writeWithIterator(iterator)
dataWriter.commit()
})(catchBlock = {
// If there is an error, abort the task
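
The driver-side decision above only enables concurrent writers when `SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS` is positive and there are no sort columns; otherwise the existing sort-then-single-writer path is kept. A small sketch of that selection, with types that are illustrative stand-ins rather than Spark classes:

```scala
object WriterChoiceSketch {
  sealed trait WriterChoice
  case object SortThenSingleWriter extends WriterChoice
  final case class ConcurrentWriters(maxWriters: Int) extends WriterChoice

  // maxWriters stands in for the SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS value referenced
  // above; sort columns disable the concurrent path because their ordering requires the sort.
  def chooseWriter(maxWriters: Int, hasSortColumns: Boolean): WriterChoice =
    if (maxWriters > 0 && !hasSortColumns) ConcurrentWriters(maxWriters)
    else SortThenSingleWriter

  def main(args: Array[String]): Unit = {
    println(chooseWriter(maxWriters = 0, hasSortColumns = false)) // SortThenSingleWriter
    println(chooseWriter(maxWriters = 8, hasSortColumns = false)) // ConcurrentWriters(8)
  }
}
```
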
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
index a0b191e60f..4ed8943ef4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
@@ -40,6 +40,8 @@ import org.apache.spark.sql.types.{StructField, StructType}
case class HadoopFsRelation(
location: FileIndex,
partitionSchema: StructType,
+ // The top-level columns in `dataSchema` should match the actual physical file schema, otherwise
+ // the ORC data source may not work with the by-ordinal mode.
dataSchema: StructType,
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
index 1d7abe5b93..7c479d986f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
@@ -57,7 +57,7 @@ abstract class OutputWriterFactory extends Serializable {
*/
abstract class OutputWriter {
/**
- * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
+ * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
* tables, dynamic partition columns are not included in rows to be written.
*/
def write(row: InternalRow): Unit
@@ -67,4 +67,9 @@ abstract class OutputWriter {
* the task output is committed.
*/
def close(): Unit
+
+ /**
+ * The file path to write. Invoked on the executor side.
+ */
+ def path(): String
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
index c39a82ee03..aaf866bced 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
@@ -32,20 +32,7 @@ trait WriteTaskStats extends Serializable
* A trait for classes that are capable of collecting statistics on data that's being processed by
* a single write task in [[FileFormatWriter]] - i.e. there should be one instance per executor.
*
- * This trait is coupled with the way [[FileFormatWriter]] works, in the sense that its methods
- * will be called according to how tuples are being written out to disk, namely in sorted order
- * according to partitionValue(s), then bucketId.
- *
- * As such, a typical call scenario is:
- *
- * newPartition -> newBucket -> newFile -> newRow -.
- * ^ |______^___________^ ^ ^____|
- * | | |______________|
- * | |____________________________|
- * |____________________________________________|
- *
- * newPartition and newBucket events are only triggered if the relation to be written out is
- * partitioned and/or bucketed, respectively.
+ * The newPartition event is only triggered if the relation to be written out is partitioned.
*/
trait WriteTaskStatsTracker {
@@ -56,22 +43,20 @@ trait WriteTaskStatsTracker {
*/
def newPartition(partitionValues: InternalRow): Unit
- /**
- * Process the fact that a new bucket is about to written.
- * Only triggered when the relation is bucketed by a (non-empty) sequence of columns.
- * @param bucketId The bucket number.
- */
- def newBucket(bucketId: Int): Unit
-
/**
* Process the fact that a new file is about to be written.
* @param filePath Path of the file into which future rows will be written.
*/
def newFile(filePath: String): Unit
+ /**
+ * Process the fact that a file has finished being written and has been closed.
+ * @param filePath Path of the file.
+ */
+ def closeFile(filePath: String): Unit
+
/**
* Process the fact that a new row to update the tracked statistics accordingly.
- * The row will be written to the most recently witnessed file (via `newFile`).
* @note Keep in mind that any overhead here is per-row, obviously,
* so implementations should be as lightweight as possible.
* @param row Current data row to be processed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
index 2b549536ae..35d0e098b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
import org.apache.spark.sql.types.StructType
class CsvOutputWriter(
- path: String,
+ val path: String,
dataSchema: StructType,
context: TaskAttemptContext,
params: CSVOptions) extends OutputWriter with Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala
index 719d72f5b9..55602ce2ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
import org.apache.spark.sql.types.StructType
class JsonOutputWriter(
- path: String,
+ val path: String,
options: JSONOptions,
dataSchema: StructType,
context: TaskAttemptContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
index 08086bcd91..6f215737f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._
private[sql] class OrcOutputWriter(
- path: String,
+ val path: String,
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
index 70f6726c58..efb322f3fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
-class ParquetOutputWriter(path: String, context: TaskAttemptContext)
+class ParquetOutputWriter(val path: String, context: TaskAttemptContext)
extends OutputWriter {
private val recordWriter: RecordWriter[Void, InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala
index 2b1b81f60c..2fb37c0dc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
import org.apache.spark.sql.types.StructType
class TextOutputWriter(
- path: String,
+ val path: String,
dataSchema: StructType,
lineSeparator: Array[Byte],
context: TaskAttemptContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala
index 1f25fed300..d827e83623 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
-import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataWriter, SingleDirectoryDataWriter, WriteJobDescription}
+import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataSingleWriter, SingleDirectoryDataWriter, WriteJobDescription}
case class FileWriterFactory (
description: WriteJobDescription,
@@ -35,7 +35,7 @@ case class FileWriterFactory (
if (description.partitionColumns.isEmpty) {
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
} else {
- new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
+ new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 167ba45b88..1f57f17911 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -81,6 +81,10 @@ object PushDownUtils extends PredicateHelper {
relation: DataSourceV2Relation,
projects: Seq[NamedExpression],
filters: Seq[Expression]): (Scan, Seq[AttributeReference]) = {
+ val exprs = projects ++ filters
+ val requiredColumns = AttributeSet(exprs.flatMap(_.references))
+ val neededOutput = relation.output.filter(requiredColumns.contains)
+
scanBuilder match {
case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
val rootFields = SchemaPruning.identifyRootFields(projects, filters)
@@ -89,14 +93,12 @@ object PushDownUtils extends PredicateHelper {
} else {
new StructType()
}
- r.pruneColumns(prunedSchema)
+ val neededFieldNames = neededOutput.map(_.name).toSet
+ r.pruneColumns(StructType(prunedSchema.filter(f => neededFieldNames.contains(f.name))))
val scan = r.build()
scan -> toOutputAttrs(scan.readSchema(), relation)
case r: SupportsPushDownRequiredColumns =>
- val exprs = projects ++ filters
- val requiredColumns = AttributeSet(exprs.flatMap(_.references))
- val neededOutput = relation.output.filter(requiredColumns.contains)
r.pruneColumns(neededOutput.toStructType)
val scan = r.build()
// always project, in case the relation's output has been updated and doesn't match
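
The fix above additionally filters the nested-field-pruned schema down to the top-level columns actually referenced by `projects ++ filters`, so a source that supports nested pruning is no longer asked to read top-level columns nobody needs. A standalone sketch of that intersection (the simple `Field` case class is an assumption for illustration, not Spark's `StructType`):

```scala
object ColumnPruningSketch {
  // A top-level column after nested-field pruning; `keptNestedFields` is illustrative.
  final case class Field(name: String, keptNestedFields: Seq[String])

  def prune(nestedPrunedSchema: Seq[Field], referencedTopLevelColumns: Set[String]): Seq[Field] =
    nestedPrunedSchema.filter(f => referencedTopLevelColumns.contains(f.name))

  def main(args: Array[String]): Unit = {
    val prunedSchema = Seq(Field("a", Seq("x")), Field("b", Seq("y")), Field("c", Nil))
    // Only `a` and `c` are referenced by projects ++ filters, so `b` is dropped entirely.
    println(prune(prunedSchema, referencedTopLevelColumns = Set("a", "c")))
  }
}
```
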
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index d3bc4aed57..9a05e396d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY
import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins._
@@ -49,7 +50,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession)
return plan
}
- plan transformAllExpressions {
+ plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
case DynamicPruningSubquery(
value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
val sparkPlan = QueryExecution.createSparkPlan(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
index 3cb20f87ae..f2449a1ec5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
@@ -17,10 +17,7 @@
package org.apache.spark.sql.execution.metric
-import java.text.NumberFormat
-import java.util.Locale
-
-import org.apache.spark.sql.connector.CustomMetric
+import org.apache.spark.sql.connector.metric.CustomMetric
object CustomMetrics {
private[spark] val V2_CUSTOM = "v2Custom"
@@ -45,35 +42,3 @@ object CustomMetrics {
}
}
}
-
-/**
- * Built-in `CustomMetric` that sums up metric values.
- */
-class CustomSumMetric extends CustomMetric {
- override def name(): String = "CustomSumMetric"
-
- override def description(): String = "Sum up CustomMetric"
-
- override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
- taskMetrics.sum.toString
- }
-}
-
-/**
- * Built-in `CustomMetric` that computes average of metric values.
- */
-class CustomAvgMetric extends CustomMetric {
- override def name(): String = "CustomAvgMetric"
-
- override def description(): String = "Average CustomMetric"
-
- override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
- val average = if (taskMetrics.isEmpty) {
- 0.0
- } else {
- taskMetrics.sum.toDouble / taskMetrics.length
- }
- val numberFormat = NumberFormat.getNumberInstance(Locale.US)
- numberFormat.format(average)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index da39e8c455..959144bab3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -24,7 +24,7 @@ import scala.concurrent.duration._
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
-import org.apache.spark.sql.connector.CustomMetric
+import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
@@ -113,7 +113,7 @@ object SQLMetrics {
*/
def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = {
val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric))
- acc.register(sc, name = Some(customMetric.name()), countFailedValues = false)
+ acc.register(sc, name = Some(customMetric.description()), countFailedValues = false)
acc
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 1c018be6d5..c6a70fb204 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -40,7 +40,6 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e) ||
- e.isInstanceOf[GroupingExprRef] ||
agg.groupingExpressions.exists(_.semanticEquals(e))
}
@@ -120,8 +119,23 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
groupingExpr += expr
}
}
+ val aggExpr = agg.aggregateExpressions.map { expr =>
+ expr.transformUp {
+ // PythonUDFs over aggregates were already pulled out by ExtractPythonUDFFromAggregate.
+ // A PythonUDF here should be one of:
+ // 1. An argument of an aggregate function.
+ //    CheckAnalysis guarantees the arguments are deterministic.
+ // 2. A PythonUDF in a grouping key. Grouping keys must be deterministic.
+ // 3. A PythonUDF not in a grouping key. It either takes no arguments or takes grouping
+ //    keys as arguments. Such PythonUDFs were pulled out by ExtractPythonUDFFromAggregate too.
+ case p: PythonUDF if p.udfDeterministic =>
+ val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
+ attributeMap.getOrElse(canonicalized, p)
+ }.asInstanceOf[NamedExpression]
+ }
agg.copy(
groupingExpressions = groupingExpr.toSeq,
+ aggregateExpressions = aggExpr,
child = Project((projList ++ agg.child.output).toSeq, agg.child))
}
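
The comment in the hunk above leans on canonicalization: a deterministic PythonUDF in the aggregate expressions is swapped for the attribute produced when a semantically equal UDF was pulled into the child Project, and the lookup key is the canonicalized form so that cosmetic differences such as expression ids do not defeat the match. A toy sketch of that lookup idea, using stand-in types rather than Catalyst's classes:

```scala
// Stand-in expression tree (NOT Catalyst): deterministic nodes are substituted by
// looking up their canonical form, so semantically equal occurrences share one replacement.
sealed trait Expr { def canonical: Expr }
case class Attr(name: String, id: Long) extends Expr {
  // Canonical form erases the expression id, keeping only what matters semantically.
  override def canonical: Expr = Attr(name.toLowerCase, 0L)
}
case class Udf(fn: String, args: Seq[Expr], deterministic: Boolean) extends Expr {
  override def canonical: Expr = Udf(fn, args.map(_.canonical), deterministic)
}

def substitute(e: Expr, replacements: Map[Expr, Expr]): Expr = e match {
  case u: Udf if u.deterministic =>
    // Same role as `attributeMap.getOrElse(canonicalized, p)` in the hunk above.
    replacements.getOrElse(u.canonical, u.copy(args = u.args.map(substitute(_, replacements))))
  case other => other
}

// Keys are stored in canonical form; a UDF occurrence with a different attribute id
// still resolves to the attribute that was pulled out for it.
val pulledOut = Map[Expr, Expr](
  Udf("f", Seq(Attr("a", 0L)), deterministic = true) -> Attr("groupingexpr", 7L))
substitute(Udf("f", Seq(Attr("A", 42L)), deterministic = true), pulledOut)
// => Attr("groupingexpr", 7L)
```
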
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 15b85013c4..f96e9ee3ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression,
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN_SUBQUERY, SCALAR_SUBQUERY}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
@@ -176,7 +177,7 @@ case class InSubqueryExec(
*/
case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
- plan.transformAllExpressions {
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) {
case subquery: expressions.ScalarSubquery =>
val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
ScalarSubquery(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index a3238551b2..e7ab4a184b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{JobExecutionStatus, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Status._
import org.apache.spark.scheduler._
-import org.apache.spark.sql.connector.CustomMetric
+import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index ed1d24b682..b7f3dec224 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType}
+import org.apache.spark.sql.types.{CalendarIntervalType, DateType, DayTimeIntervalType, IntegerType, TimestampType, YearMonthIntervalType}
trait WindowExecBase extends UnaryExecNode {
def windowExpression: Seq[NamedExpression]
@@ -95,8 +95,11 @@ trait WindowExecBase extends UnaryExecNode {
// Create the projection which returns the current 'value' modified by adding the offset.
val boundExpr = (expr.dataType, boundOffset.dataType) match {
case (DateType, IntegerType) => DateAdd(expr, boundOffset)
- case (TimestampType, CalendarIntervalType) =>
- TimeAdd(expr, boundOffset, Some(timeZone))
+ case (DateType, YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset)
+ case (TimestampType, CalendarIntervalType) => TimeAdd(expr, boundOffset, Some(timeZone))
+ case (TimestampType, YearMonthIntervalType) =>
+ TimestampAddYMInterval(expr, boundOffset, Some(timeZone))
+ case (TimestampType, DayTimeIntervalType) => TimeAdd(expr, boundOffset, Some(timeZone))
case (a, b) if a == b => Add(expr, boundOffset)
}
val bound = MutableProjection.create(boundExpr :: Nil, child.output)
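
The new match arms above allow a RANGE frame bound to be an ANSI year-month or day-time interval, where previously only `CalendarIntervalType` over timestamps was handled. A rough end-to-end sketch of the kind of query this enables, with made-up data (the `testData` view in the window.sql tests below exercises the same shapes):

```scala
import java.sql.Timestamp
import org.apache.spark.sql.SparkSession

// Assumed local session and illustrative rows only.
val spark = SparkSession.builder().master("local[*]").appName("interval-range-frames").getOrCreate()
import spark.implicits._

Seq(
  ("a", Timestamp.valueOf("2017-07-31 17:00:00")),
  ("a", Timestamp.valueOf("2017-08-05 23:13:20")),
  ("b", Timestamp.valueOf("2020-12-30 16:00:00"))
).toDF("cate", "val_timestamp").createOrReplaceTempView("testData")

// Year-month interval bound over a timestamp ordering column; per the match above this
// frame bound is computed via TimestampAddYMInterval.
spark.sql(
  """SELECT val_timestamp, cate,
    |       avg(val_timestamp) OVER (PARTITION BY cate ORDER BY val_timestamp
    |         RANGE BETWEEN CURRENT ROW AND INTERVAL '1-1' YEAR TO MONTH FOLLOWING) AS avg_ts
    |FROM testData
    |ORDER BY cate, val_timestamp""".stripMargin).show(truncate = false)
```
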
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java
index e418958bef..59c5263563 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java
@@ -23,7 +23,7 @@
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
-import org.apache.spark.sql.connector.InMemoryTableCatalog;
+import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql
index 0f1fd5bbcc..31603fba99 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/extract.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql
@@ -128,3 +128,34 @@ select c - i from t;
select year(c - i) from t;
select extract(year from c - i) from t;
select extract(month from to_timestamp(c) - i) from t;
+
+-- extract fields from year-month/day-time intervals
+select extract(YEAR from interval '2-1' YEAR TO MONTH);
+select date_part('YEAR', interval '2-1' YEAR TO MONTH);
+select extract(YEAR from -interval '2-1' YEAR TO MONTH);
+select extract(MONTH from interval '2-1' YEAR TO MONTH);
+select date_part('MONTH', interval '2-1' YEAR TO MONTH);
+select extract(MONTH from -interval '2-1' YEAR TO MONTH);
+select date_part(NULL, interval '2-1' YEAR TO MONTH);
+
+-- invalid
+select extract(DAY from interval '2-1' YEAR TO MONTH);
+select date_part('DAY', interval '2-1' YEAR TO MONTH);
+select date_part('not_supported', interval '2-1' YEAR TO MONTH);
+
+select extract(DAY from interval '123 12:34:56.789123123' DAY TO SECOND);
+select date_part('DAY', interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(DAY from -interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(HOUR from interval '123 12:34:56.789123123' DAY TO SECOND);
+select date_part('HOUR', interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(HOUR from -interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(MINUTE from interval '123 12:34:56.789123123' DAY TO SECOND);
+select date_part('MINUTE', interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(MINUTE from -interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(SECOND from interval '123 12:34:56.789123123' DAY TO SECOND);
+select date_part('SECOND', interval '123 12:34:56.789123123' DAY TO SECOND);
+select extract(SECOND from -interval '123 12:34:56.789123123' DAY TO SECOND);
+select date_part(NULL, interval '123 12:34:56.789123123' DAY TO SECOND);
+
+select extract(MONTH from interval '123 12:34:56.789123123' DAY TO SECOND);
+select date_part('not_supported', interval '123 12:34:56.789123123' DAY TO SECOND);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql
index 6dfe31e270..d6381e59e0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql
@@ -80,3 +80,14 @@ SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), ());
SELECT a, b, count(1) FROM testData GROUP BY a, CUBE(a, b), GROUPING SETS((a, b), (a), ());
SELECT a, b, count(1) FROM testData GROUP BY a, CUBE(a, b), ROLLUP(a, b), GROUPING SETS((a, b), (a), ());
+-- Support nested CUBE/ROLLUP/GROUPING SETS in GROUPING SETS
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b));
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()));
+
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), GROUPING SETS(ROLLUP(a, b)));
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b, a, b), (a, b, a), (a, b));
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b, a, b), (a, b, a), (a, b)));
+
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b), CUBE(a, b));
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()), GROUPING SETS((a, b), (a), (b), ()));
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ());
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 988ad99418..6ee1014739 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -179,12 +179,3 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(
-- Aggregate with multiple distinct decimal columns
SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);
-
--- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function
-SELECT not(a IS NULL), count(*) AS c
-FROM testData
-GROUP BY a IS NULL;
-
-SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
-FROM testData
-GROUP BY a IS NULL;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql
index 3fcbdacda6..063727a76e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql
@@ -102,6 +102,10 @@ select interval 30 day day day;
select interval (-30) days;
select interval (a + 1) days;
select interval 30 days days days;
+SELECT INTERVAL '178956970-7' YEAR TO MONTH;
+SELECT INTERVAL '178956970-8' YEAR TO MONTH;
+SELECT INTERVAL '-178956970-8' YEAR TO MONTH;
+SELECT INTERVAL -'178956970-8' YEAR TO MONTH;
-- Interval year-month arithmetic
@@ -218,3 +222,17 @@ select interval '1 day 1';
select interval '1 day 2' day;
select interval 'interval 1' day;
select interval '-\t 1' day;
+
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 2;
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 5;
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1;
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1L;
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0;
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D;
+
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 2;
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 5;
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1;
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1L;
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0;
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql
index 7419ca1bd0..d84659c4cc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql
@@ -11,6 +11,18 @@ CREATE OR REPLACE TEMPORARY VIEW script_trans AS SELECT * FROM VALUES
(7, 8, 9)
AS script_trans(a, b, c);
+CREATE OR REPLACE TEMPORARY VIEW complex_trans AS SELECT * FROM VALUES
+(1, 1),
+(1, 1),
+(2, 2),
+(2, 2),
+(3, 3),
+(2, 2),
+(3, 3),
+(1, 1),
+(3, 3)
+as complex_trans(a, b);
+
SELECT TRANSFORM(a)
USING 'cat' AS (a)
FROM t;
@@ -342,3 +354,22 @@ SELECT TRANSFORM(b, MAX(a) AS max_a, CAST(sum(c) AS STRING))
FROM script_trans
WHERE a <= 2
GROUP BY b;
+
+-- SPARK-33985: TRANSFORM with CLUSTER BY/ORDER BY/SORT BY
+FROM (
+ SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b)
+ FROM complex_trans
+ CLUSTER BY a
+) map_output
+SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b);
+
+FROM (
+ SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b)
+ FROM complex_trans
+ ORDER BY a
+) map_output
+SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b);
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index 56f2b0b20c..46d3629a5d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -70,6 +70,18 @@ RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM testData ORDER BY cate, val_date
SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp
RANGE BETWEEN CURRENT ROW AND interval 23 days 4 hours FOLLOWING) FROM testData
ORDER BY cate, val_timestamp;
+SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp
+RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData
+ORDER BY cate, val_timestamp;
+SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp
+RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData
+ORDER BY cate, val_timestamp;
+SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date
+RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData
+ORDER BY cate, val_date;
+SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date
+RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData
+ORDER BY cate, val_date;
-- RangeBetween with reverse OrderBy
SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out
index 1db8febb81..9dbfc4cf4f 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 44
+-- Number of queries: 52
-- !query
@@ -1067,3 +1067,227 @@ struct
3 NULL 2
3 NULL 2
3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 2 1
+1 NULL 2
+1 NULL 2
+2 1 1
+2 2 1
+2 NULL 2
+2 NULL 2
+3 1 1
+3 2 1
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 2 1
+1 NULL 2
+1 NULL 2
+2 1 1
+2 2 1
+2 NULL 2
+2 NULL 2
+3 1 1
+3 2 1
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), GROUPING SETS(ROLLUP(a, b)))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b, a, b), (a, b, a), (a, b))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b, a, b), (a, b, a), (a, b)))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b), CUBE(a, b))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()), GROUPING SETS((a, b), (a), (b), ()))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ())
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+3 NULL 2
+3 NULL 2
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index 781a7d739c..e383fc1b85 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 118
+-- Number of queries: 134
-- !query
@@ -780,6 +780,44 @@ select interval 30 days days days
-----------------------------^^^
+-- !query
+SELECT INTERVAL '178956970-7' YEAR TO MONTH
+-- !query schema
+struct
+-- !query output
+178956970-7
+
+
+-- !query
+SELECT INTERVAL '178956970-8' YEAR TO MONTH
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Error parsing interval year-month string: integer overflow(line 1, pos 16)
+
+== SQL ==
+SELECT INTERVAL '178956970-8' YEAR TO MONTH
+----------------^^^
+
+
+-- !query
+SELECT INTERVAL '-178956970-8' YEAR TO MONTH
+-- !query schema
+struct
+-- !query output
+-178956970-8
+
+
+-- !query
+SELECT INTERVAL -'178956970-8' YEAR TO MONTH
+-- !query schema
+struct
+-- !query output
+-178956970-8
+
+
-- !query
create temporary view interval_arithmetic as
select CAST(dateval AS date), CAST(tsval AS timestamp), dateval as strval from values
@@ -1221,3 +1259,107 @@ select interval '-\t 1' day
struct
-- !query output
-1 days
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 2
+-- !query schema
+struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 2):year-month interval>
+-- !query output
+-89478485-4
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 5
+-- !query schema
+struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 5):year-month interval>
+-- !query output
+-35791394-2
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1L
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+not in range
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 2
+-- !query schema
+struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 2):day-time interval>
+-- !query output
+-53375995 14:00:27.387904000
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 5
+-- !query schema
+struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 5):day-time interval>
+-- !query output
+-21350398 05:36:10.955162000
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1L
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+not in range
diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out
index 35cfda1767..63b5caac48 100644
--- a/sql/core/src/test/resources/sql-tests/results/extract.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 100
+-- Number of queries: 125
-- !query
@@ -197,7 +197,7 @@ struct
-- !query
select extract(hour from c), extract(hour from i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -205,7 +205,7 @@ struct
-- !query
select extract(h from c), extract(h from i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -213,7 +213,7 @@ struct
-- !query
select extract(hours from c), extract(hours from i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -221,7 +221,7 @@ struct
-- !query
select extract(hr from c), extract(hr from i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -229,7 +229,7 @@ struct
-- !query
select extract(hrs from c), extract(hrs from i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -519,7 +519,7 @@ struct
-- !query
select date_part('hour', c), date_part('hour', i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -527,7 +527,7 @@ struct
-- !query
select date_part('h', c), date_part('h', i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -535,7 +535,7 @@ struct
-- !query
select date_part('hours', c), date_part('hours', i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -543,7 +543,7 @@ struct
-- !query
select date_part('hr', c), date_part('hr', i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -551,7 +551,7 @@ struct
-- !query
select date_part('hrs', c), date_part('hrs', i) from t
-- !query schema
-struct
+struct
-- !query output
7 16
@@ -805,3 +805,208 @@ select extract(month from to_timestamp(c) - i) from t
struct
-- !query output
8
+
+
+-- !query
+select extract(YEAR from interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+2
+
+
+-- !query
+select date_part('YEAR', interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+2
+
+
+-- !query
+select extract(YEAR from -interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+-2
+
+
+-- !query
+select extract(MONTH from interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+1
+
+
+-- !query
+select date_part('MONTH', interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+1
+
+
+-- !query
+select extract(MONTH from -interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+-1
+
+
+-- !query
+select date_part(NULL, interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+select extract(DAY from interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Literals of type 'DAY' are currently not supported for the year-month interval type.; line 1 pos 7
+
+
+-- !query
+select date_part('DAY', interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Literals of type 'DAY' are currently not supported for the year-month interval type.; line 1 pos 7
+
+
+-- !query
+select date_part('not_supported', interval '2-1' YEAR TO MONTH)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Literals of type 'not_supported' are currently not supported for the year-month interval type.; line 1 pos 7
+
+
+-- !query
+select extract(DAY from interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+123
+
+
+-- !query
+select date_part('DAY', interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+123
+
+
+-- !query
+select extract(DAY from -interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+-123
+
+
+-- !query
+select extract(HOUR from interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+12
+
+
+-- !query
+select date_part('HOUR', interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+12
+
+
+-- !query
+select extract(HOUR from -interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+-12
+
+
+-- !query
+select extract(MINUTE from interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+34
+
+
+-- !query
+select date_part('MINUTE', interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+34
+
+
+-- !query
+select extract(MINUTE from -interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+-34
+
+
+-- !query
+select extract(SECOND from interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+56.789123
+
+
+-- !query
+select date_part('SECOND', interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+56.789123
+
+
+-- !query
+select extract(SECOND from -interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+-56.789123
+
+
+-- !query
+select date_part(NULL, interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+select extract(MONTH from interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Literals of type 'MONTH' are currently not supported for the day-time interval type.; line 1 pos 7
+
+
+-- !query
+select date_part('not_supported', interval '123 12:34:56.789123123' DAY TO SECOND)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Literals of type 'not_supported' are currently not supported for the day-time interval type.; line 1 pos 7
diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out
index 6dc02ead9d..f249908163 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 44
+-- Number of queries: 52
-- !query
@@ -1087,3 +1087,227 @@ struct
3 NULL 2
3 NULL 2
3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 2 1
+1 NULL 2
+1 NULL 2
+2 1 1
+2 2 1
+2 NULL 2
+2 NULL 2
+3 1 1
+3 2 1
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 2 1
+1 NULL 2
+1 NULL 2
+2 1 1
+2 2 1
+2 NULL 2
+2 NULL 2
+3 1 1
+3 2 1
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), GROUPING SETS(ROLLUP(a, b)))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b, a, b), (a, b, a), (a, b))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b, a, b), (a, b, a), (a, b)))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b), CUBE(a, b))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()), GROUPING SETS((a, b), (a), (b), ()))
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+3 NULL 2
+3 NULL 2
+
+
+-- !query
+SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ())
+-- !query schema
+struct
+-- !query output
+1 1 1
+1 1 1
+1 1 1
+1 2 1
+1 2 1
+1 2 1
+1 NULL 2
+1 NULL 2
+1 NULL 2
+1 NULL 2
+2 1 1
+2 1 1
+2 1 1
+2 2 1
+2 2 1
+2 2 1
+2 NULL 2
+2 NULL 2
+2 NULL 2
+2 NULL 2
+3 1 1
+3 1 1
+3 1 1
+3 2 1
+3 2 1
+3 2 1
+3 NULL 2
+3 NULL 2
+3 NULL 2
+3 NULL 2
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index b5471a785a..1d8c44c291 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 64
+-- Number of queries: 62
-- !query
@@ -642,25 +642,3 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1
struct
-- !query output
1.0000 1
-
-
--- !query
-SELECT not(a IS NULL), count(*) AS c
-FROM testData
-GROUP BY a IS NULL
--- !query schema
-struct<(NOT (a IS NULL)):boolean,c:bigint>
--- !query output
-false 2
-true 7
-
-
--- !query
-SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
-FROM testData
-GROUP BY a IS NULL
--- !query schema
-struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
--- !query output
-0.7604953758285915 7
-1.0 2
diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
index d525d8044a..a2cbea2906 100644
--- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 118
+-- Number of queries: 134
-- !query
@@ -774,6 +774,44 @@ select interval 30 days days days
-----------------------------^^^
+-- !query
+SELECT INTERVAL '178956970-7' YEAR TO MONTH
+-- !query schema
+struct
+-- !query output
+178956970-7
+
+
+-- !query
+SELECT INTERVAL '178956970-8' YEAR TO MONTH
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Error parsing interval year-month string: integer overflow(line 1, pos 16)
+
+== SQL ==
+SELECT INTERVAL '178956970-8' YEAR TO MONTH
+----------------^^^
+
+
+-- !query
+SELECT INTERVAL '-178956970-8' YEAR TO MONTH
+-- !query schema
+struct
+-- !query output
+-178956970-8
+
+
+-- !query
+SELECT INTERVAL -'178956970-8' YEAR TO MONTH
+-- !query schema
+struct
+-- !query output
+-178956970-8
+
+
-- !query
create temporary view interval_arithmetic as
select CAST(dateval AS date), CAST(tsval AS timestamp), dateval as strval from values
@@ -1210,3 +1248,107 @@ select interval '-\t 1' day
struct
-- !query output
-1 days
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 2
+-- !query schema
+struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 2):year-month interval>
+-- !query output
+-89478485-4
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 5
+-- !query schema
+struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 5):year-month interval>
+-- !query output
+-35791394-2
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1L
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow
+
+
+-- !query
+SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+not in range
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 2
+-- !query schema
+struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 2):day-time interval>
+-- !query output
+-53375995 14:00:27.387904000
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 5
+-- !query schema
+struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 5):day-time interval>
+-- !query output
+-21350398 05:36:10.955162000
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1L
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow in integral divide.
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+Overflow
+
+
+-- !query
+SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArithmeticException
+not in range
diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out
index 1d7e9cdb43..6f94e742b8 100644
--- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 44
+-- Number of queries: 47
-- !query
@@ -26,6 +26,24 @@ struct<>
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW complex_trans AS SELECT * FROM VALUES
+(1, 1),
+(1, 1),
+(2, 2),
+(2, 2),
+(3, 3),
+(2, 2),
+(3, 3),
+(1, 1),
+(3, 3)
+as complex_trans(a, b)
+-- !query schema
+struct<>
+-- !query output
+
+
+
-- !query
SELECT TRANSFORM(a)
USING 'cat' AS (a)
@@ -717,3 +735,49 @@ SELECT TRANSFORM(b, MAX(a) AS max_a, CAST(sum(c) AS STRING))
FROM script_trans
WHERE a <= 2
GROUP BY b
+
+
+-- !query
+FROM (
+ SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b)
+ FROM complex_trans
+ CLUSTER BY a
+) map_output
+SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b)
+-- !query schema
+struct
+-- !query output
+1 1
+1 1
+1 1
+2 2
+2 2
+2 2
+3 3
+3 3
+3 3
+
+
+-- !query
+FROM (
+ SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b)
+ FROM complex_trans
+ ORDER BY a
+) map_output
+SELECT TRANSFORM(a, b)
+ USING 'cat' AS (a, b)
+-- !query schema
+struct
+-- !query output
+1 1
+1 1
+1 1
+2 2
+2 2
+2 2
+3 3
+3 3
+3 3
diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out
index c377658722..7443b95582 100644
--- a/sql/core/src/test/resources/sql-tests/results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 46
+-- Number of queries: 50
-- !query
@@ -211,6 +211,71 @@ NULL NULL NULL
2020-12-30 16:00:00 b 1.6093728E9
+-- !query
+SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp
+RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData
+ORDER BY cate, val_timestamp
+-- !query schema
+struct
+-- !query output
+NULL NULL NULL
+2017-07-31 17:00:00 NULL 1.5015456E9
+2017-07-31 17:00:00 a 1.5016970666666667E9
+2017-07-31 17:00:00 a 1.5016970666666667E9
+2017-08-05 23:13:20 a 1.502E9
+2020-12-30 16:00:00 a 1.6093728E9
+2017-07-31 17:00:00 b 1.5022728E9
+2017-08-17 13:00:00 b 1.503E9
+2020-12-30 16:00:00 b 1.6093728E9
+
+
+-- !query
+SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp
+RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData
+ORDER BY cate, val_timestamp
+-- !query schema
+struct
+-- !query output
+NULL NULL NULL
+2017-07-31 17:00:00 NULL 1.5015456E9
+2017-07-31 17:00:00 a 1.5015456E9
+2017-07-31 17:00:00 a 1.5015456E9
+2017-08-05 23:13:20 a 1.502E9
+2020-12-30 16:00:00 a 1.6093728E9
+2017-07-31 17:00:00 b 1.5015456E9
+2017-08-17 13:00:00 b 1.503E9
+2020-12-30 16:00:00 b 1.6093728E9
+
+
+-- !query
+SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date
+RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData
+ORDER BY cate, val_date
+-- !query schema
+struct
+-- !query output
+NULL NULL NULL
+2017-08-01 NULL 1.5015456E9
+2017-08-01 a 1.5016970666666667E9
+2017-08-01 a 1.5016970666666667E9
+2017-08-02 a 1.502E9
+2020-12-31 a 1.6093728E9
+2017-08-01 b 1.5022728E9
+2017-08-03 b 1.503E9
+2020-12-31 b 1.6093728E9
+
+
+-- !query
+SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date
+RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData
+ORDER BY cate, val_date
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(PARTITION BY testdata.cate ORDER BY testdata.val_date ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND INTERVAL '1 02:03:04.001' DAY TO SECOND FOLLOWING)' due to data type mismatch: The data type 'date' used in the order specification does not match the data type 'day-time interval' which is used in the range frame.; line 1 pos 46
+
+
-- !query
SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC
RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index 7d1e4ff040..c06544ee00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.{InMemoryPartitionTableCatalog, SchemaRequiredDataSource}
+import org.apache.spark.sql.connector.SchemaRequiredDataSource
+import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 5108c0169b..ad5d73c774 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -1686,6 +1686,61 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
StructType(Seq(StructField("a", IntegerType, nullable = true))))
}
+ test("SPARK-35213: chained withField operations should have correct schema for new columns") {
+ val df = spark.createDataFrame(
+ sparkContext.parallelize(Row(null) :: Nil),
+ StructType(Seq(StructField("data", NullType))))
+
+ checkAnswer(
+ df.withColumn("data", struct()
+ .withField("a", struct())
+ .withField("b", struct())
+ .withField("a.aa", lit("aa1"))
+ .withField("b.ba", lit("ba1"))
+ .withField("a.ab", lit("ab1"))),
+ Row(Row(Row("aa1", "ab1"), Row("ba1"))) :: Nil,
+ StructType(Seq(
+ StructField("data", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("aa", StringType, nullable = false),
+ StructField("ab", StringType, nullable = false)
+ )), nullable = false),
+ StructField("b", StructType(Seq(
+ StructField("ba", StringType, nullable = false)
+ )), nullable = false)
+ )), nullable = false)
+ ))
+ )
+ }
+
+ test("SPARK-35213: optimized withField operations should maintain correct nested struct " +
+ "ordering") {
+ val df = spark.createDataFrame(
+ sparkContext.parallelize(Row(null) :: Nil),
+ StructType(Seq(StructField("data", NullType))))
+
+ checkAnswer(
+ df.withColumn("data", struct()
+ .withField("a", struct().withField("aa", lit("aa1")))
+ .withField("b", struct().withField("ba", lit("ba1")))
+ )
+ .withColumn("data", col("data").withField("b.bb", lit("bb1")))
+ .withColumn("data", col("data").withField("a.ab", lit("ab1"))),
+ Row(Row(Row("aa1", "ab1"), Row("ba1", "bb1"))) :: Nil,
+ StructType(Seq(
+ StructField("data", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("aa", StringType, nullable = false),
+ StructField("ab", StringType, nullable = false)
+ )), nullable = false),
+ StructField("b", StructType(Seq(
+ StructField("ba", StringType, nullable = false),
+ StructField("bb", StringType, nullable = false)
+ )), nullable = false)
+ )), nullable = false)
+ ))
+ )
+ }
test("dropFields should throw an exception if called on a non-StructType column") {
intercept[AnalysisException] {
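
The two tests added above exercise chained `withField` calls across different branches of a nested struct (SPARK-35213). For readers less familiar with the API, a minimal standalone sketch of the same pattern; the session setup and field names are illustrative:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, struct}

val spark = SparkSession.builder().master("local[*]").appName("with-field").getOrCreate()
import spark.implicits._

// Build a nested struct column incrementally: each withField adds or replaces one field,
// and chained calls on different branches must keep the resulting schema consistent
// (the behavior the tests above pin down).
val df = Seq(1).toDF("id")
  .withColumn("data", struct().withField("a", struct().withField("aa", lit("aa1"))))
  .withColumn("data", col("data").withField("b", struct().withField("ba", lit("ba1"))))
  .withColumn("data", col("data").withField("a.ab", lit("ab1")))

df.printSchema()
// data: struct<a: struct<aa: string, ab: string>, b: struct<ba: string>>  (roughly)
```
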
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c53bcf045d..c6f6cbdbf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1135,7 +1135,7 @@ class DataFrameAggregateSuite extends QueryTest
val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time"))
checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
- Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) ::Nil)
+ Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil)
assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
StructField("sum(year-month)", YearMonthIntervalType),
@@ -1173,7 +1173,7 @@ class DataFrameAggregateSuite extends QueryTest
val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
- Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil)
+ Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
StructField("avg(year-month)", YearMonthIntervalType),
@@ -1188,6 +1188,13 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df2.select(avg($"day-time")), Nil)
}
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+
+ val df3 = df.filter($"class" > 4)
+ val avgDF3 = df3.select(avg($"year-month"), avg($"day-time"))
+ checkAnswer(avgDF3, Row(null, null) :: Nil)
+
+ val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+ checkAnswer(avgDF4, Nil)
}
}
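
The added assertions check that `avg` over ANSI interval columns returns a single NULL row for an empty ungrouped aggregation and an empty result for an empty grouped one. A small sketch of aggregating `java.time.Period`/`Duration` columns, with made-up data and names, relying on the same typed encoders the surrounding suite uses:

```scala
import java.time.{Duration, Period}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, sum}

val spark = SparkSession.builder().master("local[*]").appName("interval-agg").getOrCreate()
import spark.implicits._

val df = Seq(
  (1, Period.ofMonths(10), Duration.ofDays(10)),
  (2, Period.ofMonths(1),  Duration.ofDays(1))
).toDF("class", "year-month", "day-time")

// Sums and averages of year-month and day-time intervals stay interval-typed.
df.groupBy($"class").agg(sum($"year-month"), avg($"day-time")).show()

// With no input rows, an ungrouped avg yields one row of NULLs; a grouped avg yields no rows.
val empty = df.filter($"class" > 4)
empty.select(avg($"year-month"), avg($"day-time")).show()   // one row of NULLs
empty.groupBy($"class").agg(avg($"year-month")).show()      // zero rows
```
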
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index 35e732e084..8aef27a1b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -25,8 +25,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
-import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog}
-import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog}
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index 688288abee..13d1285401 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -506,6 +506,14 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
checkKeywordsExistsInExplain(df2, keywords = "[key1=value1, KEY2=VALUE2]")
}
}
+
+ test("SPARK-35225: Handle empty output for analyzed plan") {
+ withTempView("test") {
+ checkKeywordsExistsInExplain(
+ sql("CREATE TEMPORARY VIEW test AS SELECT 1"),
+ "== Analyzed Logical Plan ==\nCreateViewCommand")
+ }
+ }
}
class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
index 807c6b2a67..2f56fbaf7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.connector.InMemoryPartitionTableCatalog
+import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.unsafe.types.UTF8String
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index efb87dafe0..d83d1a2755 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -22,7 +22,7 @@ import java.util.Collections
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan}
-import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.StructType
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
index c973e2ba30..44fbc639a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
-import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog}
class DataSourceV2SQLSessionCatalogSuite
extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
index 8922eea8e0..3ef242f90f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.connector
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.connector.catalog.CatalogPlugin
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog}
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index 3aad644655..076dad7530 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
-import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsCatalogOptions, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.execution.QueryExecution
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
index 9beef690cb..847953e09c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext}
-import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, V1Scan}
import org.apache.spark.sql.execution.RowDataSourceScanExec
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 45ddc6a6fc..7effc747ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveM
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 6fd9dc4e39..db4a9c153c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
-import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.expressions.LogicalExpressions._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index 40f25d5599..c845dd81f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
@@ -43,6 +44,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
import testImplicits._
import ScriptTransformationIOSchema._
+ protected def defaultSerDe(): String
+
protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler
private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _
@@ -599,6 +602,37 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
'e.cast("string")).collect())
}
}
+
+ test("SPARK-35220: DayTimeIntervalType/YearMonthIntervalType show different " +
+ "between hive serde and row format delimited\t") {
+ assume(TestUtils.testCommandAvailable("/bin/bash"))
+ withTempView("v") {
+ val df = Seq(
+ (Duration.ofDays(1), Period.ofMonths(10))
+ ).toDF("a", "b")
+ df.createTempView("v")
+
+ if (defaultSerDe == "hive-serde") {
+ checkAnswer(sql(
+ """
+ |SELECT TRANSFORM(a, b)
+ | USING 'cat' AS (a, b)
+ |FROM v
+ |""".stripMargin),
+ identity,
+ Row("1 00:00:00.000000000", "0-10") :: Nil)
+ } else {
+ checkAnswer(sql(
+ """
+ |SELECT TRANSFORM(a, b)
+ | USING 'cat' AS (a, b)
+ |FROM v
+ |""".stripMargin),
+ identity,
+ Row("INTERVAL '1 00:00:00' DAY TO SECOND", "INTERVAL '0-10' YEAR TO MONTH") :: Nil)
+ }
+ }
+ }
}
case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala
index f16265ee61..f8366b3f7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.time.{Duration, Period}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
-import org.apache.spark.sql.connector.InMemoryTableCatalog
+import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
import org.apache.spark.sql.execution.HiveResult._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala
index f69aea3729..5638743b76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala
@@ -25,6 +25,8 @@ import org.apache.spark.sql.test.SharedSparkSession
class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession {
import testImplicits._
+ override protected def defaultSerDe(): String = "row-format-delimited"
+
override def createScriptTransformationExec(
script: String,
output: Seq[Attribute],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 31b6921132..2598d3ba8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1575,4 +1575,19 @@ class AdaptiveQueryExecSuite
checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS)
}
}
+
+ test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") {
+ withTable("t") {
+ withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+ spark.sql("CREATE TABLE t (c1 int) USING PARQUET")
+ val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1")
+ assert(
+ collect(adaptive) {
+ case c @ CustomShuffleReaderExec(_, partitionSpecs) if partitionSpecs.length == 1 => c
+ }.length == 1
+ )
+ }
+ }
+ }
}
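
For context, the new AQE test covers the case where the shuffle input RDD is empty: grouping an empty Parquet table shuffles zero rows, and the partition coalescer is still expected to emit exactly one partition spec. A minimal stand-alone reproduction, mirroring the withSQLConf block above (a sketch that assumes a running `spark` session; it is not part of the patch):

    // Same settings the test applies via withSQLConf.
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionNum", "1")
    spark.conf.set("spark.sql.shuffle.partitions", "2")

    spark.sql("CREATE TABLE t (c1 INT) USING PARQUET")
    // Zero rows are shuffled; with the fix the coalesced read still carries a single
    // (empty) partition spec, which is what the assertion above checks.
    spark.sql("SELECT c1, count(*) FROM t GROUP BY c1").collect()  // empty result
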
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index f58d3246f5..1684633c92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.arrow
+import org.apache.arrow.vector.IntervalDayVector
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util._
@@ -54,6 +56,8 @@ class ArrowWriterSuite extends SparkFunSuite {
case BinaryType => reader.getBinary(rowId)
case DateType => reader.getInt(rowId)
case TimestampType => reader.getLong(rowId)
+ case YearMonthIntervalType => reader.getInt(rowId)
+ case DayTimeIntervalType => reader.getLong(rowId)
}
assert(value === datum)
}
@@ -73,6 +77,33 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
check(NullType, Seq(null, null, null))
+ check(YearMonthIntervalType, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue))
+ check(DayTimeIntervalType, Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L),
+ (Long.MinValue + 808L)))
+ }
+
+ test("long overflow for DayTimeIntervalType")
+ {
+ val schema = new StructType().add("value", DayTimeIntervalType, nullable = true)
+ val writer = ArrowWriter.create(schema, null)
+ val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+ val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector]
+
+ valueVector.set(0, 106751992, 0)
+ valueVector.set(1, 106751991, Int.MaxValue)
+
+ // first read: long overflow thrown by Math.multiplyExact()
+ val msg = intercept[java.lang.ArithmeticException] {
+ reader.getLong(0)
+ }.getMessage
+ assert(msg.equals("long overflow"))
+
+ // second read: long overflow thrown by Math.addExact()
+ val msg1 = intercept[java.lang.ArithmeticException] {
+ reader.getLong(1)
+ }.getMessage
+ assert(msg1.equals("long overflow"))
+ writer.root.close()
}
test("get multiple") {
@@ -97,6 +128,8 @@ class ArrowWriterSuite extends SparkFunSuite {
case DoubleType => reader.getDoubles(0, data.size)
case DateType => reader.getInts(0, data.size)
case TimestampType => reader.getLongs(0, data.size)
+ case YearMonthIntervalType => reader.getInts(0, data.size)
+ case DayTimeIntervalType => reader.getLongs(0, data.size)
}
assert(values === data)
@@ -111,6 +144,8 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DoubleType, (0 until 10).map(_.toDouble))
check(DateType, (0 until 10))
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
+ check(YearMonthIntervalType, (0 until 10))
+ check(DayTimeIntervalType, (-10 until 10).map(_ * 1000.toLong))
}
test("array") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index d77ef6e6bd..b8d7b774d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -80,7 +80,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ =>
+ benchmark.addCase("codegen = T, hashmap = F", numIters = 3) { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false",
@@ -89,7 +89,16 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ =>
+ benchmark.addCase("codegen = T, row-based hashmap = T", numIters = 5) { _ =>
+ withSQLConf(
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+ SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
+ SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") {
+ f()
+ }
+ }
+
+ benchmark.addCase("codegen = T, vectorized hashmap = T", numIters = 5) { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
@@ -116,7 +125,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ =>
+ benchmark.addCase("codegen = T, hashmap = F", numIters = 3) { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false",
@@ -125,7 +134,16 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ =>
+ benchmark.addCase("codegen = T, row-based hashmap = T", numIters = 5) { _ =>
+ withSQLConf(
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+ SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
+ SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") {
+ f()
+ }
+ }
+
+ benchmark.addCase("codegen = T, vectorized hashmap = T", numIters = 5) { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
@@ -151,7 +169,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ =>
+ benchmark.addCase("codegen = T, hashmap = F", numIters = 3) { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false",
@@ -160,7 +178,16 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ =>
+ benchmark.addCase("codegen = T, row-based hashmap = T", numIters = 5) { _ =>
+ withSQLConf(
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+ SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
+ SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") {
+ f()
+ }
+ }
+
+ benchmark.addCase("codegen = T, vectorized hashmap = T", numIters = 5) { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
@@ -186,7 +213,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = F") { _ =>
+ benchmark.addCase("codegen = T, hashmap = F") { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false",
@@ -195,7 +222,16 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = T") { _ =>
+ benchmark.addCase("codegen = T, row-based hashmap = T") { _ =>
+ withSQLConf(
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+ SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
+ SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") {
+ f()
+ }
+ }
+
+ benchmark.addCase("codegen = T, vectorized hashmap = T") { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
@@ -231,7 +267,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = F") { _ =>
+ benchmark.addCase("codegen = T, hashmap = F") { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false",
@@ -240,7 +276,16 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hashmap = T") { _ =>
+ benchmark.addCase("codegen = T, row-based hashmap = T") { _ =>
+ withSQLConf(
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+ SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
+ SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") {
+ f()
+ }
+ }
+
+ benchmark.addCase("codegen = T, vectorized hashmap = T") { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true",
@@ -291,7 +336,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hugeMethodLimit = 10000") { _ =>
+ benchmark.addCase("codegen = T, hugeMethodLimit = 10000") { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "10000") {
@@ -299,7 +344,7 @@ object AggregateBenchmark extends SqlBasedBenchmark {
}
}
- benchmark.addCase("codegen = T hugeMethodLimit = 1500") { _ =>
+ benchmark.addCase("codegen = T, hugeMethodLimit = 1500") { _ =>
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "1500") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala
index 1f47744ce4..ba683c049a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
import org.apache.spark.SparkConf
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.InMemoryPartitionTableCatalog
+import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala
index 2dd80b7bb6..bed04f4f26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.command.v2
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog}
-import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog}
import org.apache.spark.sql.test.SharedSparkSession
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala
index 7a2c136eea..bafb6608c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2
import org.apache.spark.SparkConf
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.connector.BasicInMemoryTableCatalog
+import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog
import org.apache.spark.sql.execution.command
import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 765d2fc584..ac5c28953a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -351,6 +351,43 @@ abstract class SchemaPruningSuite
}
}
+ testSchemaPruning("SPARK-34638: nested column prune on generator output") {
+ val query1 = spark.table("contacts")
+ .select(explode(col("friends")).as("friend"))
+ .select("friend.first")
+ checkScan(query1, "struct<friends:array<struct<first:string>>>")
+ checkAnswer(query1, Row("Susan") :: Nil)
+
+ // Currently we don't prune the multiple-field case.
+ val query2 = spark.table("contacts")
+ .select(explode(col("friends")).as("friend"))
+ .select("friend.first", "friend.middle")
+ checkScan(query2, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query2, Row("Susan", "Z.") :: Nil)
+
+ val query3 = spark.table("contacts")
+ .select(explode(col("friends")).as("friend"))
+ .select("friend.first", "friend.middle", "friend")
+ checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil)
+ }
+
+ testSchemaPruning("SPARK-34638: nested column prune on generator output - case-sensitivity") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ val query1 = spark.table("contacts")
+ .select(explode(col("friends")).as("friend"))
+ .select("friend.First")
+ checkScan(query1, "struct<friends:array<struct<first:string>>>")
+ checkAnswer(query1, Row("Susan") :: Nil)
+
+ val query2 = spark.table("contacts")
+ .select(explode(col("friends")).as("friend"))
+ .select("friend.MIDDLE")
+ checkScan(query2, "struct<friends:array<struct<middle:string>>>")
+ checkAnswer(query2, Row("Z.") :: Nil)
+ }
+ }
+
testSchemaPruning("select one deep nested complex field after repartition") {
val query = sql("select * from contacts")
.repartition(100)
@@ -816,4 +853,21 @@ abstract class SchemaPruningSuite
Row("John", "Y.") :: Nil)
}
}
+
+ test("SPARK-34638: queries should not fail on unsupported cases") {
+ withTable("nested_array") {
+ sql("select * from values array(array(named_struct('a', 1, 'b', 3), " +
+ "named_struct('a', 2, 'b', 4))) T(items)").write.saveAsTable("nested_array")
+ val query = sql("select d.a from (select explode(c) d from " +
+ "(select explode(items) c from nested_array))")
+ checkAnswer(query, Row(1) :: Row(2) :: Nil)
+ }
+
+ withTable("map") {
+ sql("select * from values map(1, named_struct('a', 1, 'b', 3), " +
+ "2, named_struct('a', 2, 'b', 4)) T(items)").write.saveAsTable("map")
+ val query = sql("select d.a from (select explode(items) (c, d) from map)")
+ checkAnswer(query, Row(1) :: Row(2) :: Nil)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index 7bd068c0f9..eee8e2ecc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -633,4 +633,20 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession {
}
}
}
+
+ test("SPARK-34897: Support reconcile schemas based on index after nested column pruning") {
+ withTable("t1") {
+ spark.sql(
+ """
+ |CREATE TABLE t1 (
+ | _col0 INT,
+ | _col1 STRING,
+ | _col2 STRUCT<c1: STRING, c2: STRING, c3: STRING, c4: BIGINT>)
+ |USING ORC
+ |""".stripMargin)
+
+ spark.sql("INSERT INTO t1 values(1, '2', struct('a', 'b', 'c', 10L))")
+ checkAnswer(spark.sql("SELECT _col0, _col2.c1 FROM t1"), Seq(Row(1, "a")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala
index e2fa03ff23..440b0dc08e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.execution.metric
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.connector.metric.{CustomAvgMetric, CustomSumMetric}
class CustomMetricsSuite extends SparkFunSuite {
test("Build/parse custom metric metric type") {
- Seq(new CustomSumMetric, new CustomAvgMetric).foreach { customMetric =>
+ Seq(new TestCustomSumMetric, new TestCustomAvgMetric).foreach { customMetric =>
val metricType = CustomMetrics.buildV2CustomMetricTypeName(customMetric)
assert(metricType == CustomMetrics.V2_CUSTOM + "_" + customMetric.getClass.getCanonicalName)
@@ -33,7 +34,7 @@ class CustomMetricsSuite extends SparkFunSuite {
}
test("Built-in CustomSumMetric") {
- val metric = new CustomSumMetric
+ val metric = new TestCustomSumMetric
val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L)
assert(metric.aggregateTaskMetrics(metricValues1) == metricValues1.sum.toString)
@@ -43,7 +44,7 @@ class CustomMetricsSuite extends SparkFunSuite {
}
test("Built-in CustomAvgMetric") {
- val metric = new CustomAvgMetric
+ val metric = new TestCustomAvgMetric
val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L)
assert(metric.aggregateTaskMetrics(metricValues1) == "4.667")
@@ -52,3 +53,13 @@ class CustomMetricsSuite extends SparkFunSuite {
assert(metric.aggregateTaskMetrics(metricValues2) == "0")
}
}
+
+private[spark] class TestCustomSumMetric extends CustomSumMetric {
+ override def name(): String = "CustomSumMetric"
+ override def description(): String = "Sum up CustomMetric"
+}
+
+private[spark] class TestCustomAvgMetric extends CustomAvgMetric {
+ override def name(): String = "CustomAvgMetric"
+ override def description(): String = "Average CustomMetric"
+}
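
The suite needs the TestCustom* helpers because CustomSumMetric and CustomAvgMetric now live in org.apache.spark.sql.connector.metric as abstract classes: they ship the aggregation logic but leave name() and description() to the connector. A minimal sketch of a connector-side metric built the same way (the metric name here is hypothetical):

    import org.apache.spark.sql.connector.metric.CustomSumMetric

    // aggregateTaskMetrics is inherited from CustomSumMetric and sums the per-task values.
    class BytesWrittenMetric extends CustomSumMetric {
      override def name(): String = "bytesWritten"
      override def description(): String = "total bytes written by the data source"
    }
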
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index a58265124d..612b74a661 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -37,7 +37,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.quietly
-import org.apache.spark.sql.connector.{CustomMetric, CustomTaskMetric, RangeInputPartition, SimpleScanBuilder}
+import org.apache.spark.sql.connector.{RangeInputPartition, SimpleScanBuilder}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index b85a668e5b..90127557f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -117,6 +117,21 @@ class SQLConfSuite extends QueryTest with SharedSparkSession {
}
}
+ test(s"SPARK-35168: ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} should respect" +
+ s" ${SQLConf.SHUFFLE_PARTITIONS.key}") {
+ spark.sessionState.conf.clear()
+ try {
+ sql(s"SET ${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key}=true")
+ sql(s"SET ${SQLConf.COALESCE_PARTITIONS_ENABLED.key}=true")
+ sql(s"SET ${SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key}=1")
+ sql(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=2")
+ checkAnswer(sql(s"SET ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}"),
+ Row(SQLConf.SHUFFLE_PARTITIONS.key, "2"))
+ } finally {
+ spark.sessionState.conf.clear()
+ }
+ }
+
test("SPARK-31234: reset will not change static sql configs and spark core configs") {
val conf = spark.sparkContext.getConf.getAll.toMap
val appName = conf.get("spark.app.name")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 02f91399fc..0e2fcfbd46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -54,27 +54,31 @@ class ContinuousSuiteBase extends StreamTest {
protected def waitForRateSourceCommittedValue(
query: ContinuousExecution,
- desiredValue: Long,
+ partitionIdToDesiredValue: Map[Int, Long],
maxWaitTimeMs: Long): Unit = {
- def readHighestCommittedValue(c: ContinuousExecution): Option[Long] = {
+ def readCommittedValues(c: ContinuousExecution): Option[Map[Int, Long]] = {
c.committedOffsets.lastOption.map { case (_, offset) =>
offset match {
case o: RateStreamOffset =>
- o.partitionToValueAndRunTimeMs.map {
- case (_, ValueRunTimeMsPair(value, _)) => value
- }.max
+ o.partitionToValueAndRunTimeMs.mapValues(_.value).toMap
}
}
}
+ def reachDesiredValues: Boolean = {
+ val committedValues = readCommittedValues(query).getOrElse(Map.empty)
+ partitionIdToDesiredValue.forall { case (key, value) =>
+ committedValues.contains(key) && committedValues(key) > value
+ }
+ }
+
val maxWait = System.currentTimeMillis() + maxWaitTimeMs
- while (System.currentTimeMillis() < maxWait &&
- readHighestCommittedValue(query).getOrElse(Long.MinValue) < desiredValue) {
+ while (System.currentTimeMillis() < maxWait && !reachDesiredValues) {
Thread.sleep(100)
}
if (System.currentTimeMillis() > maxWait) {
logWarning(s"Couldn't reach desired value in $maxWaitTimeMs milliseconds!" +
- s"Current highest committed value is ${readHighestCommittedValue(query)}")
+ s"Current committed values is ${readCommittedValues(query)}")
}
}
@@ -264,7 +268,7 @@ class ContinuousSuite extends ContinuousSuiteBase {
val expected = Set(0, 1, 2, 3)
val continuousExecution =
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution]
- waitForRateSourceCommittedValue(continuousExecution, expected.max, 20 * 1000)
+ waitForRateSourceCommittedValue(continuousExecution, Map(0 -> 2, 1 -> 3), 20 * 1000)
query.stop()
val results = spark.read.table("noharness").collect()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index 4c5c5e63ce..49e5218ea3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
-import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableCatalog, InMemoryTableSessionCatalog}
-import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, V2TableWithV1Fallback}
+import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableSessionCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsRead, Table, TableCapability, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.streaming.{MemoryStream, MemoryStreamScanBuilder}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 273658fcfa..41d1156875 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.test
import java.io.File
-import java.util.Locale
+import java.util.{Locale, Random}
import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._
@@ -1219,4 +1219,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
}
}
}
+
+ test("SPARK-26164: Allow concurrent writers for multiple partitions and buckets") {
+ withTable("t1", "t2") {
+ // Uses fixed seed to ensure reproducible test execution
+ val r = new Random(31)
+ val df = spark.range(200).map(_ => {
+ val n = r.nextInt()
+ (n, n.toString, n % 5)
+ }).toDF("k1", "k2", "part")
+ df.write.format("parquet").saveAsTable("t1")
+ spark.sql("CREATE TABLE t2(k1 int, k2 string, part int) USING parquet PARTITIONED " +
+ "BY (part) CLUSTERED BY (k1) INTO 3 BUCKETS")
+ val queryToInsertTable = "INSERT OVERWRITE TABLE t2 SELECT k1, k2, part FROM t1"
+
+ Seq(
+ // Single writer
+ 0,
+ // Concurrent writers without fallback
+ 200,
+ // Concurrent writers with fallback
+ 3
+ ).foreach { maxWriters =>
+ withSQLConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key -> maxWriters.toString) {
+ spark.sql(queryToInsertTable).collect()
+ checkAnswer(spark.table("t2").orderBy("k1"),
+ spark.table("t1").orderBy("k1"))
+
+ withSQLConf(SQLConf.MAX_RECORDS_PER_FILE.key -> "1") {
+ spark.sql(queryToInsertTable).collect()
+ checkAnswer(spark.table("t2").orderBy("k1"),
+ spark.table("t1").orderBy("k1"))
+ }
+ }
+ }
+ }
+ }
}
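
The three maxWriters values in the new test correspond to the three write paths for a partitioned and bucketed insert: 0 keeps the existing single sorted writer, 200 is large enough for every (partition, bucket) output file to stay open concurrently, and 3 forces the concurrent path to fall back once the limit is exceeded. A minimal usage sketch using the same SQLConf constant as the test (assumes the t1/t2 tables created above):

    import org.apache.spark.sql.internal.SQLConf

    // Allow up to 3 output file writers to stay open at once; past that the write
    // falls back to the sort-based path for the remaining partition/bucket values.
    spark.conf.set(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key, "3")
    spark.sql("INSERT OVERWRITE TABLE t2 SELECT k1, k2, part FROM t1")
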
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 4108d0f04b..729d3f4142 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -72,6 +72,13 @@
<type>test-jar</type>
<scope>test</scope>
+ <dependency>
+ <groupId>org.apache.parquet</groupId>
+ <artifactId>parquet-hadoop</artifactId>
+ <version>${parquet.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>