From 7556ffa67dbf5e206dc63d0337c33a0ca033ff83 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 30 Aug 2023 14:47:28 +0200 Subject: [PATCH] [SPARK-45022][SQL] Provide context for dataset API errors --- .../java/org/apache/spark/QueryContext.java | 12 + .../org/apache/spark/QueryContextType.java | 31 + .../apache/spark/SparkThrowableHelper.scala | 20 +- .../CheckConnectJvmClientCompatibility.scala | 2 + .../org/apache/spark/SparkFunSuite.scala | 57 +- .../apache/spark/SparkThrowableSuite.scala | 53 + .../spark/sql/catalyst/parser/parsers.scala | 7 +- ...QueryContext.scala => QueryContexts.scala} | 52 +- .../spark/sql/catalyst/trees/origin.scala | 27 +- .../spark/sql/catalyst/util/MathUtils.scala | 16 +- .../catalyst/util/SparkDateTimeUtils.scala | 6 +- .../spark/sql/errors/DataTypeErrors.scala | 12 +- .../spark/sql/errors/DataTypeErrorsBase.scala | 9 +- .../spark/sql/errors/ExecutionErrors.scala | 11 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../spark/sql/catalyst/expressions/Cast.scala | 8 +- .../sql/catalyst/expressions/Expression.scala | 10 +- .../expressions/aggregate/Average.scala | 5 +- .../catalyst/expressions/aggregate/Sum.scala | 5 +- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/collectionOperations.scala | 7 +- .../expressions/complexTypeExtractors.scala | 4 +- .../expressions/decimalExpressions.scala | 12 +- .../expressions/intervalExpressions.scala | 6 +- .../expressions/mathExpressions.scala | 6 +- .../expressions/stringExpressions.scala | 5 +- .../sql/catalyst/util/DateTimeUtils.scala | 6 +- .../sql/catalyst/util/NumberConverter.scala | 6 +- .../sql/catalyst/util/UTF8StringUtils.scala | 12 +- .../sql/errors/QueryExecutionErrors.scala | 30 +- .../sql/catalyst/analysis/AnalysisTest.scala | 4 +- .../analysis/V2WriteAnalysisSuite.scala | 5 +- .../scala/org/apache/spark/sql/Column.scala | 196 ++-- .../apache/spark/sql/execution/subquery.scala | 5 +- .../org/apache/spark/sql/functions.scala | 999 +++++++++++------- .../scala/org/apache/spark/sql/package.scala | 22 + .../spark/sql/DataFrameAggregateSuite.scala | 8 +- .../spark/sql/DataFrameFunctionsSuite.scala | 147 ++- .../sql/DataFrameWindowFunctionsSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 17 + .../spark/sql/GeneratorFunctionSuite.scala | 10 +- .../org/apache/spark/sql/QueryTest.scala | 12 + .../spark/sql/StringFunctionsSuite.scala | 9 +- .../errors/QueryCompilationErrorsSuite.scala | 5 +- .../QueryExecutionAnsiErrorsSuite.scala | 63 ++ .../spark/sql/execution/SQLViewSuite.scala | 4 +- 46 files changed, 1300 insertions(+), 657 deletions(-) create mode 100644 common/utils/src/main/java/org/apache/spark/QueryContextType.java rename sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/{SQLQueryContext.scala => QueryContexts.scala} (73%) diff --git a/common/utils/src/main/java/org/apache/spark/QueryContext.java b/common/utils/src/main/java/org/apache/spark/QueryContext.java index de5b29d02951d..de79c80ffb83d 100644 --- a/common/utils/src/main/java/org/apache/spark/QueryContext.java +++ b/common/utils/src/main/java/org/apache/spark/QueryContext.java @@ -27,6 +27,9 @@ */ @Evolving public interface QueryContext { + // The type of this query context. + QueryContextType contextType(); + // The object type of the query which throws the exception. // If the exception is directly from the main query, it should be an empty string. // Otherwise, it should be the exact object type in upper case. For example, a "VIEW". 
@@ -45,4 +48,13 @@ public interface QueryContext { // The corresponding fragment of the query which throws the exception. String fragment(); + + // The Spark code (API) that caused throwing the exception. + String code(); + + // The user code (call site of the API) that caused throwing the exception. + String callSite(); + + // Summary of the exception cause. + String summary(); } diff --git a/common/utils/src/main/java/org/apache/spark/QueryContextType.java b/common/utils/src/main/java/org/apache/spark/QueryContextType.java new file mode 100644 index 0000000000000..d7a28e63b79b5 --- /dev/null +++ b/common/utils/src/main/java/org/apache/spark/QueryContextType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.annotation.Evolving; + +/** + * The type of {@link QueryContext}. + * + * @since 3.5.0 + */ +@Evolving +public enum QueryContextType { + SQL, + Dataset +} diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 0f329b5655b32..cb508be6db47b 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -104,13 +104,19 @@ private[spark] object SparkThrowableHelper { g.writeArrayFieldStart("queryContext") e.getQueryContext.foreach { c => g.writeStartObject() - g.writeStringField("objectType", c.objectType()) - g.writeStringField("objectName", c.objectName()) - val startIndex = c.startIndex() + 1 - if (startIndex > 0) g.writeNumberField("startIndex", startIndex) - val stopIndex = c.stopIndex() + 1 - if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) - g.writeStringField("fragment", c.fragment()) + c.contextType() match { + case QueryContextType.SQL => + g.writeStringField("objectType", c.objectType()) + g.writeStringField("objectName", c.objectName()) + val startIndex = c.startIndex() + 1 + if (startIndex > 0) g.writeNumberField("startIndex", startIndex) + val stopIndex = c.stopIndex() + 1 + if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) + g.writeStringField("fragment", c.fragment()) + case QueryContextType.Dataset => + g.writeStringField("code", c.code()) + g.writeStringField("callSite", c.callSite()) + } g.writeEndObject() } g.writeEndArray() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 1e536cd37fec1..8cf238a946c4c 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -209,6 +209,8 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.TypedColumn.withExprTyped"), + // KeyValueGroupedDataset ProblemFilters.exclude[Problem]( "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"), diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index f5819b9508777..0b6e804664001 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -318,7 +318,7 @@ abstract class SparkFunSuite sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, - queryContext: Array[QueryContext] = Array.empty): Unit = { + queryContext: Array[ExpectedContext] = Array.empty): Unit = { assert(exception.getErrorClass === errorClass) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala @@ -340,16 +340,23 @@ abstract class SparkFunSuite val actualQueryContext = exception.getQueryContext() assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context") actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert(actual.objectType() === expected.objectType(), - "Invalid objectType of a query context Actual:" + actual.toString) - assert(actual.objectName() === expected.objectName(), - "Invalid objectName of a query context. Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex(), - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex(), - "Invalid stopIndex of a query context. Actual:" + actual.toString) - assert(actual.fragment() === expected.fragment(), - "Invalid fragment of a query context. Actual:" + actual.toString) + if (actual.contextType() == QueryContextType.SQL) { + assert(actual.objectType() === expected.objectType, + "Invalid objectType of a query context Actual:" + actual.toString) + assert(actual.objectName() === expected.objectName, + "Invalid objectName of a query context. Actual:" + actual.toString) + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + assert(actual.fragment() === expected.fragment, + "Invalid fragment of a query context. Actual:" + actual.toString) + } else if (actual.contextType() == QueryContextType.Dataset) { + assert(actual.code() === expected.code, + "Invalid code of a query context. Actual:" + actual.toString) + assert(actual.callSite().matches(expected.callSitePattern), + "Invalid callSite of a query context. 
Actual:" + actual.toString) + } } } @@ -365,21 +372,21 @@ abstract class SparkFunSuite errorClass: String, sqlState: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, sqlState: String, - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, Map.empty, false, Array(context)) protected def checkError( @@ -387,7 +394,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, false, Array(context)) @@ -402,7 +409,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, matchPVals = true, Array(context)) @@ -433,12 +440,28 @@ abstract class SparkFunSuite objectName: String, startIndex: Int, stopIndex: Int, - fragment: String) extends QueryContext + fragment: String, + code: String, + callSitePattern: String + ) object ExpectedContext { def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { ExpectedContext("", "", start, stop, fragment) } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(objectType, objectName, startIndex, stopIndex, fragment, "", "") + } + + def apply(code: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext("", "", -1, -1, "", code, callSitePattern) + } } class LogAppender(msg: String = "", maxEvents: Int = 1000) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 57c4fe31b3b92..065880175746e 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -460,11 +460,15 @@ class SparkThrowableSuite extends SparkFunSuite { test("Get message in the specified format") { import ErrorMessageFormat._ class TestQueryContext extends QueryContext { + override val contextType = QueryContextType.SQL override val objectName = "v1" override val objectType = "VIEW" override val startIndex = 2 override val stopIndex = -1 override val fragment = "1 / 0" + override def code: String = throw new UnsupportedOperationException + override def callSite: String = throw new UnsupportedOperationException + override val summary = "" } val e = new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", @@ -532,6 +536,55 @@ class SparkThrowableSuite extends SparkFunSuite { | "message" : "Test message" | } |}""".stripMargin) + + class TestQueryContext2 extends QueryContext { + override val contextType = QueryContextType.Dataset + override def objectName: String = throw new UnsupportedOperationException + override def objectType: String = throw new UnsupportedOperationException + override def startIndex: Int = throw 
new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override def fragment: String = throw new UnsupportedOperationException + override val code: String = "div" + override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)" + override val summary = "" + } + val e4 = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array(new TestQueryContext2), + summary = "Query summary") + + assert(SparkThrowableHelper.getMessage(e4, PRETTY) === + "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " + + "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." + + "\nQuery summary") + // scalastyle:off line.size.limit + assert(SparkThrowableHelper.getMessage(e4, MINIMAL) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "code" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + assert(SparkThrowableHelper.getMessage(e4, STANDARD) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set to \"false\" to bypass this error.", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "code" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + // scalastyle:on line.size.limit } test("overwrite error classes") { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index c3a051be89bcc..dca111e55c283 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin} import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf @@ -229,7 +229,7 @@ class ParseException( val builder = new StringBuilder builder ++= "\n" ++= message start match { - case Origin(Some(l), Some(p), _, _, _, _, _) => + case Origin(Some(l), Some(p), _, _, _, _, _, _) => builder ++= s"(line $l, pos $p)\n" command.foreach { cmd => val (above, below) = cmd.split("\n").splitAt(l) @@ -262,8 +262,7 @@ class ParseException( object ParseException { def getQueryContext(): Array[QueryContext] = { - val context = CurrentOrigin.get.context - if (context.isValid) Array(context) else Array.empty + Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala similarity index 73% rename from sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala rename to 
sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 99889cf7dae96..d7c51138322a0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.trees -import org.apache.spark.QueryContext +import org.apache.spark.{QueryContext, QueryContextType} /** The class represents error context of a SQL query. */ case class SQLQueryContext( @@ -28,11 +28,12 @@ case class SQLQueryContext( sqlText: Option[String], originObjectType: Option[String], originObjectName: Option[String]) extends QueryContext { + override val contextType = QueryContextType.SQL - override val objectType = originObjectType.getOrElse("") - override val objectName = originObjectName.getOrElse("") - override val startIndex = originStartIndex.getOrElse(-1) - override val stopIndex = originStopIndex.getOrElse(-1) + val objectType = originObjectType.getOrElse("") + val objectName = originObjectName.getOrElse("") + val startIndex = originStartIndex.getOrElse(-1) + val stopIndex = originStopIndex.getOrElse(-1) /** * The SQL query context of current node. For example: @@ -40,7 +41,7 @@ case class SQLQueryContext( * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i * ^^^^^^^^^^^^^^^ */ - lazy val summary: String = { + override lazy val summary: String = { // If the query context is missing or incorrect, simply return an empty string. if (!isValid) { "" @@ -116,7 +117,7 @@ case class SQLQueryContext( } /** Gets the textual fragment of a SQL query. */ - override lazy val fragment: String = { + lazy val fragment: String = { if (!isValid) { "" } else { @@ -128,6 +129,43 @@ case class SQLQueryContext( sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined && originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && originStartIndex.get <= originStopIndex.get + } + + override def code: String = throw new UnsupportedOperationException + override def callSite: String = throw new UnsupportedOperationException +} + +case class DatasetQueryContext( + override val code: String, + override val callSite: String) extends QueryContext { + override val contextType = QueryContextType.Dataset + + override def objectType: String = throw new UnsupportedOperationException + override def objectName: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override def fragment: String = throw new UnsupportedOperationException + + override lazy val summary: String = { + val builder = new StringBuilder + builder ++= "== Dataset ==\n" + builder ++= "\"" + + builder ++= code + builder ++= "\"" + builder ++= " was called from " + builder ++= callSite + builder += '\n' + builder.result() + } +} + +object DatasetQueryContext { + def apply(elements: Array[StackTraceElement]): DatasetQueryContext = { + val methodName = elements(0).getMethodName + val code = (if (methodName.startsWith("$")) methodName.substring(1) else methodName) + val callSite = elements(1).toString + DatasetQueryContext(code, callSite) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index ec3e627ac9585..8c5887bda3912 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala 
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -30,15 +30,28 @@ case class Origin( stopIndex: Option[Int] = None, sqlText: Option[String] = None, objectType: Option[String] = None, - objectName: Option[String] = None) { + objectName: Option[String] = None, + stackTrace: Option[Array[StackTraceElement]] = None) { - lazy val context: SQLQueryContext = SQLQueryContext( - line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) - - def getQueryContext: Array[QueryContext] = if (context.isValid) { - Array(context) + lazy val context: QueryContext = if (stackTrace.isDefined) { + DatasetQueryContext(stackTrace.get) } else { - Array.empty + SQLQueryContext( + line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) + } + + def getQueryContext: Array[QueryContext] = { + Some(context).filter { + case s: SQLQueryContext => s.isValid + case _ => true + }.toArray + } +} + +object Origin { + def fromCurrentStackTrace(framesToDrop: Int = 0): Origin = { + val from = framesToDrop + 2 + Origin(stackTrace = Some(Thread.currentThread().getStackTrace.slice(from, from + 2))) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 7c1b37e9e5815..99caef978bb4a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.ExecutionErrors /** @@ -27,37 +27,37 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def addExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def addExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def subtractExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def subtractExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def multiplyExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def multiplyExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } @@ -78,7 +78,7 @@ object MathUtils { def withOverflow[A]( f: => A, 
hint: String = "", - context: SQLQueryContext = null): A = { + context: QueryContext = null): A = { try { f } catch { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 698e7b37a9ef0..f8a9274a5646c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import sun.util.calendar.ZoneInfo -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, rebaseJulianToGregorianDays, rebaseJulianToGregorianMicros} import org.apache.spark.sql.errors.ExecutionErrors @@ -355,7 +355,7 @@ trait SparkDateTimeUtils { def stringToDateAnsi( s: UTF8String, - context: SQLQueryContext = null): Int = { + context: QueryContext = null): Int = { stringToDate(s).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context) } @@ -567,7 +567,7 @@ trait SparkDateTimeUtils { def stringToTimestampAnsi( s: UTF8String, timeZoneId: ZoneId, - context: SQLQueryContext = null): Long = { + context: QueryContext = null): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampType, context) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 5e52e283338d3..b30f7b7a00e91 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLSchema import org.apache.spark.sql.types.{DataType, Decimal, StringType} @@ -191,7 +191,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -199,7 +199,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -207,7 +207,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: 
Int, - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -222,7 +222,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { val convertedValueStr = "'" + s.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala index aed3c681365dc..7e039cec980cd 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.errors import java.util.Locale import org.apache.spark.QueryContext -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -89,11 +88,11 @@ private[sql] trait DataTypeErrorsBase { "\"" + elem + "\"" } - def getSummary(sqlContext: SQLQueryContext): String = { - if (sqlContext == null) "" else sqlContext.summary + def getSummary(context: QueryContext): String = { + if (context == null) "" else context.summary } - def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { - if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) + def getQueryContext(context: QueryContext): Array[QueryContext] = { + if (context == null) Array.empty else Array(context) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index c8321e81027ba..394e56062071b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -21,9 +21,8 @@ import java.time.temporal.ChronoField import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.{SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.catalyst.WalkedTypePath -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{DataType, DoubleType, StringType, UserDefinedType} import org.apache.spark.unsafe.types.UTF8String @@ -83,14 +82,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def invalidInputInCastToDatetimeError( value: UTF8String, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), StringType, to, context) } def invalidInputInCastToDatetimeError( value: Double, to: DataType, - 
context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context) } @@ -98,7 +97,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { sqlValue: String, from: DataType, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -113,7 +112,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." } else "" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afe73635a6824..c1661038025c0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,8 +21,8 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try +import org.apache.spark.QueryContext import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.unsafe.types.UTF8String @@ -341,7 +341,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -617,7 +617,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b975dc3c7a596..4925b87afdc4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,13 +21,13 @@ import java.time.{ZoneId, ZoneOffset} import java.util.Locale import java.util.concurrent.TimeUnit._ -import org.apache.spark.SparkArithmeticException +import org.apache.spark.{QueryContext, SparkArithmeticException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.{PhysicalFractionalType, PhysicalIntegralType, PhysicalNumericType} import org.apache.spark.sql.catalyst.util._ @@ -524,7 +524,7 @@ case class Cast( } } - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -942,7 +942,7 @@ case class Cast( private[this] def toPrecision( value: Decimal, decimalType: DecimalType, - context: SQLQueryContext): Decimal = + context: QueryContext): Decimal = value.toPrecision( decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c2330cdb59dbc..3159b0e21e5d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.SparkException +import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString @@ -613,11 +613,11 @@ abstract class UnaryExpression extends Expression with UnaryLike[Expression] { * to executors. It will also be kept after rule transforms. 
*/ trait SupportQueryContext extends Expression with Serializable { - protected var queryContext: Option[SQLQueryContext] = initQueryContext() + protected var queryContext: Option[QueryContext] = initQueryContext() - def initQueryContext(): Option[SQLQueryContext] + def initQueryContext(): Option[QueryContext] - def getContextOrNull(): SQLQueryContext = queryContext.orNull + def getContextOrNull(): QueryContext = queryContext.orNull def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = { if (withErrorContext && queryContext.isDefined) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index fd6131f185606..fe30e2ea6f3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -134,7 +135,7 @@ case class Average( override protected def withNewChildInternal(newChild: Expression): Average = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index e3881520e4902..dfd41ad12a280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{EvalMode, _} -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -186,7 +187,7 @@ case class Sum( // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2d9bccc0854a3..09079dd61ca4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions import scala.math.{max, min} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE} import org.apache.spark.sql.catalyst.types.{PhysicalDecimalType, PhysicalFractionalType, PhysicalIntegerType, PhysicalIntegralType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{IntervalMathUtils, IntervalUtils, MathUtils, TypeUtils} @@ -264,7 +264,7 @@ abstract class BinaryArithmetic extends BinaryOperator final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 957aa1ab2d583..686c997e23eea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,13 +22,14 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ @@ -2525,7 +2526,7 @@ case class ElementAt( override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError && left.resolved && left.dataType.isInstanceOf[ArrayType]) { Some(origin.context) } else { @@ 
-5046,7 +5047,7 @@ case class ArrayInsert( newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert = copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr) - override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + override def initQueryContext(): Option[QueryContext] = Some(origin.context) } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index e22af21daaad5..edd824b2d111e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -316,7 +316,7 @@ case class GetArrayItem( newLeft: Expression, newRight: Expression): GetArrayItem = copy(child = newLeft, ordinal = newRight) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 378920856eb11..5f13d397d1bf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.types.PhysicalDecimalType import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors @@ -146,7 +146,7 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) { + override def initQueryContext(): Option[QueryContext] = if (!nullOnOverflow) { Some(origin.context) } else { None @@ -158,7 +158,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - context: SQLQueryContext) extends UnaryExpression with SupportQueryContext { + context: QueryContext) extends UnaryExpression with SupportQueryContext { override def nullable: Boolean = true @@ -210,7 +210,7 @@ case class 
CheckOverflowInSum( override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) } /** @@ -256,12 +256,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - context: SQLQueryContext, + context: QueryContext, nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5378639e6838b..13676733a9bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -22,8 +22,8 @@ import java.util.Locale import com.google.common.math.{DoubleMath, IntMath, LongMath} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ @@ -604,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: SQLQueryContext): Unit = { + context: QueryContext): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.intervalArithmeticOverflowError( @@ -616,7 +616,7 @@ trait IntervalDivide { def divideByZeroCheck( dataType: DataType, num: Any, - context: SQLQueryContext): Unit = dataType match { + context: QueryContext): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) { throw QueryExecutionErrors.intervalDividedByZeroError(context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 89f354db5a97c..033eb5bba1f85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import java.util.Locale +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -480,7 +480,7 @@ case class Conv( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird) - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -1523,7 +1523,7 @@ abstract class RoundBase(child: Expression, scale: Expression, private lazy val scaleV: Any = scale.eval(EmptyRow) protected lazy val _scale: Int = scaleV.asInstanceOf[Int] - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (ansiEnabled) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 46f8e1a9d673d..3d4964a3acaee 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import scala.collection.mutable.ArrayBuffer +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -411,7 +412,7 @@ case class Elt( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = copy(children = newChildren) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 23bbc91c16d54..8fabb44876208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit._ import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import 
org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{Decimal, DoubleExactNumeric, TimestampNTZType, TimestampType} @@ -70,7 +70,7 @@ object DateTimeUtils extends SparkDateTimeUtils { // the "GMT" string. For example, it returns 2000-01-01T00:00+01:00 for 2000-01-01T00:00GMT+01:00. def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8) - def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = { + def doubleToTimestampAnsi(d: Double, context: QueryContext): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(d, TimestampType, context) } else { @@ -91,7 +91,7 @@ object DateTimeUtils extends SparkDateTimeUtils { def stringToTimestampWithoutTimeZoneAnsi( s: UTF8String, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampNTZType, context) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 59765cde1f926..2730ab8f4b890 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.unsafe.types.UTF8String @@ -54,7 +54,7 @@ object NumberConverter { fromPos: Int, value: Array[Byte], ansiEnabled: Boolean, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { var v: Long = 0L // bound will always be positive since radix >= 2 // Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value @@ -134,7 +134,7 @@ object NumberConverter { fromBase: Int, toBase: Int, ansiEnabled: Boolean, - context: SQLQueryContext): UTF8String = { + context: QueryContext): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index f7800469c3528..1c3a5075dab2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType} import org.apache.spark.unsafe.types.UTF8String @@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, context: SQLQueryContext): Long = + def toLongExact(s: UTF8String, context: QueryContext): Long = withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, context: SQLQueryContext): Int = + def toIntExact(s: UTF8String, context: QueryContext): Int = 
withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, context: SQLQueryContext): Short = + def toShortExact(s: UTF8String, context: QueryContext): Short = withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, context: SQLQueryContext): Byte = + def toByteExact(s: UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) private def withException[A]( f: => A, - context: SQLQueryContext, + context: QueryContext, to: DataType, s: UTF8String): A = { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 2d655be0e700c..ec082f8a3cb55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -103,7 +103,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -117,7 +117,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputSyntaxForBooleanError( s: UTF8String, - context: SQLQueryContext): SparkRuntimeException = { + context: QueryContext): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -132,7 +132,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -193,15 +193,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def divideByZeroError(context: SQLQueryContext): ArithmeticException = { + def divideByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = getQueryContext(context), + context = Array(context), summary = getSummary(context)) } - def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { + def intervalDividedByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_DIVIDED_BY_ZERO", messageParameters = Map.empty, @@ -212,7 +212,7 @@ private[sql] object 
QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidArrayIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", messageParameters = Map( @@ -226,7 +226,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidElementAtIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = Map( @@ -291,15 +291,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { + def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { + def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } - def overflowInConvError(context: SQLQueryContext): ArithmeticException = { + def overflowInConvError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in function conv()", context = context) } @@ -629,7 +629,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def intervalArithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" @@ -1395,7 +1395,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "prettyName" -> prettyName)) } - def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = { + def invalidIndexOfZeroError(context: QueryContext): RuntimeException = { new SparkRuntimeException( errorClass = "INVALID_INDEX_OF_ZERO", cause = null, @@ -2553,7 +2553,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def multipleRowScalarSubqueryError(context: SQLQueryContext): Throwable = { + def multipleRowScalarSubqueryError(context: QueryContext): Throwable = { new SparkException( errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 997308c6ef44f..ba4e7b279f512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { - import org.apache.spark.QueryContext - protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -177,7 +175,7 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index d91a080d8fe89..3fd0c1ee5de4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} @@ -159,7 +158,7 @@ abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { super.assertAnalysisErrorClass( @@ -196,7 +195,7 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { 
super.assertAnalysisErrorClass( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 39e4815fc57c5..97874017f2c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.TypedAggUtils import org.apache.spark.sql.types._ @@ -86,6 +85,10 @@ class TypedColumn[-T, U]( new TypedColumn[T, U](newExpr, encoder) } + override protected def withExprTyped(f: => Expression, framesToDrop: Int = 0) = { + withOrigin(new TypedColumn[T, U](f, encoder), framesToDrop + 1) + } + /** * Gives the [[TypedColumn]] a name (alias). * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated @@ -94,9 +97,7 @@ class TypedColumn[-T, U]( * @group expr_ops * @since 2.0.0 */ - override def name(alias: String): TypedColumn[T, U] = - new TypedColumn[T, U](super.name(alias).expr, encoder) - + override def name(alias: String): TypedColumn[T, U] = withExprTyped { nameImpl(alias) } } /** @@ -153,9 +154,6 @@ class Column(val expr: Expression) extends Logging { case a: AttributeReference => Column.stripColumnReferenceMetadata(a) } - /** Creates a column based on the given expression. */ - private def withExpr(newExpr: Expression): Column = new Column(newExpr) - /** * Returns the expression for this column either with an existing or auto assigned name. */ @@ -201,7 +199,7 @@ class Column(val expr: Expression) extends Logging { * @since 1.4.0 */ def apply(extraction: Any): Column = withExpr { - UnresolvedExtractValue(expr, lit(extraction).expr) + UnresolvedExtractValue(expr, litImpl(extraction).expr) } /** @@ -236,6 +234,16 @@ class Column(val expr: Expression) extends Logging { */ def unary_! : Column = withExpr { Not(expr) } + private def equalToImpl(other: Any) = { + val right = litImpl(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + + "Perhaps you need to use aliases.") + } + EqualTo(expr, right) + } + /** * Equality test. * {{{ @@ -250,15 +258,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def === (other: Any): Column = withExpr { - val right = lit(other).expr - if (this.expr == right) { - logWarning( - s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + - "Perhaps you need to use aliases.") - } - EqualTo(expr, right) - } + def === (other: Any): Column = withExpr { equalToImpl (other) } /** * Equality test. @@ -274,7 +274,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def equalTo(other: Any): Column = this === other + def equalTo(other: Any): Column = withExpr { equalToImpl (other) } /** * Inequality test. @@ -291,7 +291,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.0.0 */ - def =!= (other: Any): Column = withExpr{ Not(EqualTo(expr, lit(other).expr)) } + def =!= (other: Any): Column = withExpr { Not(EqualTo(expr, litImpl(other).expr)) } /** * Inequality test. 
@@ -309,7 +309,7 @@ class Column(val expr: Expression) extends Logging { * @since 1.3.0 */ @deprecated("!== does not have the same precedence as ===, use =!= instead", "2.0.0") - def !== (other: Any): Column = this =!= other + def !== (other: Any): Column = withExpr { Not(EqualTo(expr, litImpl(other).expr)) } /** * Inequality test. @@ -326,7 +326,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def notEqual(other: Any): Column = withExpr { Not(EqualTo(expr, lit(other).expr)) } + def notEqual(other: Any): Column = withExpr { Not(EqualTo(expr, litImpl(other).expr)) } /** * Greater than. @@ -342,7 +342,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def > (other: Any): Column = withExpr { GreaterThan(expr, lit(other).expr) } + def > (other: Any): Column = withExpr { GreaterThan(expr, litImpl(other).expr) } /** * Greater than. @@ -358,7 +358,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def gt(other: Any): Column = this > other + def gt(other: Any): Column = withExpr { GreaterThan(expr, litImpl(other).expr) } /** * Less than. @@ -373,7 +373,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def < (other: Any): Column = withExpr { LessThan(expr, lit(other).expr) } + def < (other: Any): Column = withExpr { LessThan(expr, litImpl(other).expr) } /** * Less than. @@ -388,7 +388,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def lt(other: Any): Column = this < other + def lt(other: Any): Column = withExpr { LessThan(expr, litImpl(other).expr) } /** * Less than or equal to. @@ -403,7 +403,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <= (other: Any): Column = withExpr { LessThanOrEqual(expr, lit(other).expr) } + def <= (other: Any): Column = withExpr { LessThanOrEqual(expr, litImpl(other).expr) } /** * Less than or equal to. @@ -418,7 +418,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def leq(other: Any): Column = this <= other + def leq(other: Any): Column = withExpr { LessThanOrEqual(expr, litImpl(other).expr) } /** * Greater than or equal to an expression. @@ -433,7 +433,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def >= (other: Any): Column = withExpr { GreaterThanOrEqual(expr, lit(other).expr) } + def >= (other: Any): Column = withExpr { GreaterThanOrEqual(expr, litImpl(other).expr) } /** * Greater than or equal to an expression. @@ -448,31 +448,33 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def geq(other: Any): Column = this >= other + def geq(other: Any): Column = withExpr { GreaterThanOrEqual(expr, litImpl(other).expr) } - /** - * Equality test that is safe for null values. - * - * @group expr_ops - * @since 1.3.0 - */ - def <=> (other: Any): Column = withExpr { - val right = lit(other).expr + private def eqNullSafeImpl(other: Any) = { + val right = litImpl(other).expr if (this.expr == right) { logWarning( s"Constructing trivially true equals predicate, '${this.expr} <=> $right'. " + - "Perhaps you need to use aliases.") + "Perhaps you need to use aliases.") } EqualNullSafe(expr, right) } + /** + * Equality test that is safe for null values. 
+ * + * @group expr_ops + * @since 1.3.0 + */ + def <=> (other: Any): Column = withExpr { eqNullSafeImpl(other) } + /** * Equality test that is safe for null values. * * @group java_expr_ops * @since 1.3.0 */ - def eqNullSafe(other: Any): Column = this <=> other + def eqNullSafe(other: Any): Column = withExpr { eqNullSafeImpl(other) } /** * Evaluates a list of conditions and returns one of multiple possible result expressions. @@ -497,7 +499,7 @@ class Column(val expr: Expression) extends Logging { */ def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches, None) => - withExpr { CaseWhen(branches :+ ((condition.expr, lit(value).expr))) } + withExpr { CaseWhen(branches :+ ((condition.expr, litImpl(value).expr))) } case CaseWhen(branches, Some(_)) => throw new IllegalArgumentException( "when() cannot be applied once otherwise() is applied") @@ -529,7 +531,7 @@ class Column(val expr: Expression) extends Logging { */ def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches, None) => - withExpr { CaseWhen(branches, Option(lit(value).expr)) } + withExpr { CaseWhen(branches, Option(litImpl(value).expr)) } case CaseWhen(branches, Some(_)) => throw new IllegalArgumentException( "otherwise() can only be applied once on a Column previously generated by when()") @@ -544,8 +546,9 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.4.0 */ - def between(lowerBound: Any, upperBound: Any): Column = { - (this >= lowerBound) && (this <= upperBound) + def between(lowerBound: Any, upperBound: Any): Column = withExpr { + And(GreaterThanOrEqual(expr, litImpl(lowerBound).expr), + LessThanOrEqual(expr, litImpl(upperBound).expr)) } /** @@ -585,7 +588,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def || (other: Any): Column = withExpr { Or(expr, lit(other).expr) } + def || (other: Any): Column = withExpr { Or(expr, litImpl(other).expr) } /** * Boolean OR. @@ -600,7 +603,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def or(other: Column): Column = this || other + def or(other: Column): Column = withExpr { Or(expr, other.expr) } /** * Boolean AND. @@ -615,7 +618,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def && (other: Any): Column = withExpr { And(expr, lit(other).expr) } + def && (other: Any): Column = withExpr { And(expr, litImpl(other).expr) } /** * Boolean AND. @@ -630,7 +633,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def and(other: Column): Column = this && other + def and(other: Column): Column = withExpr { And(expr, other.expr) } /** * Sum of this expression and another expression. @@ -645,7 +648,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def + (other: Any): Column = withExpr { Add(expr, lit(other).expr) } + def + (other: Any): Column = withExpr { Add(expr, litImpl(other).expr) } /** * Sum of this expression and another expression. @@ -660,7 +663,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def plus(other: Any): Column = this + other + def plus(other: Any): Column = withExpr { Add(expr, litImpl(other).expr) } /** * Subtraction. Subtract the other expression from this expression. 
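For illustration, a minimal sketch of the user-facing pattern these wrappers target: each operator now builds its expression inside withExpr, so the origin captured for the returned Column is the caller's line rather than an internal forwarding method such as the old `gt`/`leq` delegations. The session and data below are hypothetical.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2), (3, 3)).toDF("a", "b")
    // Every operator used here goes through withExpr { ... }, so the Column it
    // returns carries the origin of this call site.
    df.select(col("a") === col("b"), col("a") > 1, col("a").between(0, 2)).show()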
@@ -675,7 +678,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def - (other: Any): Column = withExpr { Subtract(expr, lit(other).expr) } + def - (other: Any): Column = withExpr { Subtract(expr, litImpl(other).expr) } /** * Subtraction. Subtract the other expression from this expression. @@ -690,7 +693,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def minus(other: Any): Column = this - other + def minus(other: Any): Column = withExpr { Subtract(expr, litImpl(other).expr) } /** * Multiplication of this expression and another expression. @@ -705,7 +708,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def * (other: Any): Column = withExpr { Multiply(expr, lit(other).expr) } + def * (other: Any): Column = withExpr { Multiply(expr, litImpl(other).expr) } /** * Multiplication of this expression and another expression. @@ -720,7 +723,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def multiply(other: Any): Column = this * other + def multiply(other: Any): Column = withExpr { Multiply(expr, litImpl(other).expr) } /** * Division this expression by another expression. @@ -735,7 +738,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def / (other: Any): Column = withExpr { Divide(expr, lit(other).expr) } + def / (other: Any): Column = withExpr { Divide(expr, litImpl(other).expr) } /** * Division this expression by another expression. @@ -750,7 +753,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def divide(other: Any): Column = this / other + def divide(other: Any): Column = withExpr { Divide(expr, litImpl(other).expr) } /** * Modulo (a.k.a. remainder) expression. @@ -758,7 +761,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def % (other: Any): Column = withExpr { Remainder(expr, lit(other).expr) } + def % (other: Any): Column = withExpr { Remainder(expr, litImpl(other).expr) } /** * Modulo (a.k.a. remainder) expression. 
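A sketch of the ANSI-mode scenario the arithmetic wrappers above are meant to improve; the exact error output is exercised by the new QueryExecutionAnsiErrorsSuite cases, so this is illustrative only, with a hypothetical session and data.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "true")
    import spark.implicits._

    val df = Seq(1, 2).toDF("a")
    // Under ANSI mode this fails with DIVIDE_BY_ZERO; with this patch the error
    // should carry a Dataset-type QueryContext pointing at the `/` call below.
    try df.select(col("a") / 0).collect()
    catch { case e: ArithmeticException => println(e.getMessage) }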
@@ -766,7 +769,9 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def mod(other: Any): Column = this % other + def mod(other: Any): Column = withExpr { Remainder(expr, litImpl(other).expr) } + + private def isinImpl(list: Any*) = In(expr, list.map(litImpl(_).expr)) /** * A boolean expression that is evaluated to true if the value of this expression is contained @@ -784,7 +789,7 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + def isin(list: Any*): Column = withExpr { isinImpl(list: _*) } /** * A boolean expression that is evaluated to true if the value of this expression is contained @@ -801,7 +806,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.4.0 */ - def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { + isinImpl(values.toSeq: _*) + } /** * A boolean expression that is evaluated to true if the value of this expression is contained @@ -818,7 +825,9 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 2.4.0 */ - def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + def isInCollection(values: java.lang.Iterable[_]): Column = withExpr { + isinImpl(values.asScala.toSeq: _*) + } /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. @@ -826,7 +835,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def like(literal: String): Column = withExpr { new Like(expr, lit(literal).expr) } + def like(literal: String): Column = withExpr { new Like(expr, litImpl(literal).expr) } /** * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex @@ -835,7 +844,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def rlike(literal: String): Column = withExpr { RLike(expr, lit(literal).expr) } + def rlike(literal: String): Column = withExpr { RLike(expr, litImpl(literal).expr) } /** * SQL ILIKE expression (case insensitive LIKE). @@ -843,7 +852,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 3.3.0 */ - def ilike(literal: String): Column = withExpr { new ILike(expr, lit(literal).expr) } + def ilike(literal: String): Column = withExpr { new ILike(expr, litImpl(literal).expr) } + /** * An expression that gets an item at position `ordinal` out of an array, @@ -1008,7 +1018,7 @@ class Column(val expr: Expression) extends Logging { * @since 1.3.0 */ def substr(startPos: Int, len: Int): Column = withExpr { - Substring(expr, lit(startPos).expr, lit(len).expr) + Substring(expr, litImpl(startPos).expr, litImpl(len).expr) } /** @@ -1017,7 +1027,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } + def contains(other: Any): Column = withExpr { Contains(expr, litImpl(other).expr) } /** * String starts with. Returns a boolean column based on a string match. 
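For illustration, both entry points below now delegate to the shared isinImpl, so they build the same In predicate while each keeps its own call-site origin. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.col
    import spark.implicits._

    val df = Seq("a", "b", "c").toDF("letter")
    // Varargs form and collection form produce the same predicate.
    df.filter(col("letter").isin("a", "c")).show()
    df.filter(col("letter").isInCollection(Seq("a", "c"))).show()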
@@ -1025,7 +1035,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } + def startsWith(other: Column): Column = withExpr { StartsWith(expr, other.expr) } /** * String starts with another string literal. Returns a boolean column based on a string match. @@ -1033,7 +1043,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(literal: String): Column = this.startsWith(lit(literal)) + def startsWith(literal: String): Column = withExpr { StartsWith(expr, litImpl(literal).expr) } /** * String ends with. Returns a boolean column based on a string match. @@ -1041,7 +1051,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } + def endsWith(other: Column): Column = withExpr { EndsWith(expr, other.expr) } /** * String ends with another string literal. Returns a boolean column based on a string match. @@ -1049,7 +1059,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(literal: String): Column = this.endsWith(lit(literal)) + def endsWith(literal: String): Column = withExpr { EndsWith(expr, litImpl(literal).expr) } /** * Gives the column an alias. Same as `as`. @@ -1061,7 +1071,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def alias(alias: String): Column = name(alias) + def alias(alias: String): Column = withExprTyped { nameImpl(alias) } /** * Gives the column an alias. @@ -1077,7 +1087,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = name(alias) + def as(alias: String): Column = withExprTyped { nameImpl (alias) } /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -1117,7 +1127,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = name(alias.name) + def as(alias: Symbol): Column = withExprTyped { nameImpl(alias.name) } /** * Gives the column an alias with metadata. @@ -1133,6 +1143,18 @@ class Column(val expr: Expression) extends Logging { Alias(expr, alias)(explicitMetadata = Some(metadata)) } + protected def withExprTyped(f: => Expression, framesToDrop: Int = 0) = { + withExpr(f, framesToDrop + 1) + } + + protected def nameImpl(alias: String) = { + // SPARK-33536: an alias is no longer a column reference. Therefore, + // we should not inherit the column reference related metadata in an alias + // so that it is not caught as a column reference in DetectAmbiguousSelfJoin. + Alias(expr, alias)( + nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) + } + /** * Gives the column a name (alias). * {{{ @@ -1147,12 +1169,12 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.0.0 */ - def name(alias: String): Column = withExpr { - // SPARK-33536: an alias is no longer a column reference. Therefore, - // we should not inherit the column reference related metadata in an alias - // so that it is not caught as a column reference in DetectAmbiguousSelfJoin. 
- Alias(expr, alias)( - nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) + def name(alias: String): Column = withExprTyped { nameImpl (alias) } + + private def castImpl(to: DataType) = { + val cast = Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to)) + cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) + cast } /** @@ -1169,11 +1191,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { - val cast = Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to)) - cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) - cast - } + def cast(to: DataType): Column = withExpr { castImpl(to) } /** * Casts the column to a different data type, using the canonical string representation @@ -1187,7 +1205,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to)) + def cast(to: String): Column = withExpr { castImpl(CatalystSqlParser.parseDataType(to)) } /** * Returns a sort expression based on the descending order of the column. @@ -1308,7 +1326,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseOR(other: Any): Column = withExpr { BitwiseOr(expr, lit(other).expr) } + def bitwiseOR(other: Any): Column = withExpr { BitwiseOr(expr, litImpl(other).expr) } /** * Compute bitwise AND of this expression with another expression. @@ -1319,7 +1337,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseAND(other: Any): Column = withExpr { BitwiseAnd(expr, lit(other).expr) } + def bitwiseAND(other: Any): Column = withExpr { BitwiseAnd(expr, litImpl(other).expr) } /** * Compute bitwise XOR of this expression with another expression. @@ -1330,7 +1348,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, lit(other).expr) } + def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, litImpl(other).expr) } /** * Defines a windowing column. @@ -1346,7 +1364,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + def over(window: expressions.WindowSpec): Column = withOrigin { window.withAggregate(this) } /** * Defines an empty analytic clause. 
In this case the analytic function is applied @@ -1362,7 +1380,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.0.0 */ - def over(): Column = over(Window.spec) + def over(): Column = withOrigin { Window.spec.withAggregate(this) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 41230c7792c50..b2a7f96b3b419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.QueryContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression, Predicate, SupportQueryContext} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.{LeafLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ case class ScalarSubquery( override def nullable: Boolean = true override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query) - def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + def initQueryContext(): Option[QueryContext] = Some(origin.context) override lazy val canonicalized: Expression = { ScalarSubquery(plan.canonicalized.asInstanceOf[BaseSubqueryExec], ExprId(0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a04a5e471ec59..99ccca827df2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -80,12 +80,8 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: Expression): Column = Column(expr) - - private def withAggregateFunction( - func: AggregateFunction, - isDistinct: Boolean = false): Column = { - Column(func.toAggregateExpression(isDistinct)) + private def withAggregateFunction(func: => AggregateFunction, isDistinct: Boolean = false) = { + withExpr(func.toAggregateExpression(isDistinct), 1) } /** @@ -94,7 +90,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def col(colName: String): Column = Column(colName) + def col(colName: String): Column = withOrigin { Column(colName) } /** * Returns a [[Column]] based on the given column name. Alias of [[col]]. @@ -102,7 +98,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def column(colName: String): Column = Column(colName) + def column(colName: String): Column = withOrigin { Column(colName) } /** * Creates a [[Column]] of literal value. 
@@ -114,16 +110,12 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = literal match { + def lit(literal: Any): Column = withOrigin { litImpl(literal) } + + private def typedLitImpl[T: TypeTag](literal: T): Column = literal match { case c: Column => c case s: Symbol => new ColumnName(s.name) - case _ => - // This is different from `typedlit`. `typedlit` calls `Literal.create` to use - // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method, - // `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can - // just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is - // significantly better when there are many threads calling `lit` concurrently. - Column(Literal(literal)) + case _ => Column(Literal.create(literal)) } /** @@ -134,7 +126,7 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = typedlit(literal) + def typedLit[T : TypeTag](literal: T): Column = withOrigin { typedLitImpl(literal) } /** * Creates a [[Column]] of literal value. @@ -151,11 +143,7 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) - } + def typedlit[T : TypeTag](literal: T): Column = withOrigin { typedLitImpl(literal) } ////////////////////////////////////////////////////////////////////////////////////////////// // Sort functions @@ -170,7 +158,7 @@ object functions { * @group sort_funcs * @since 1.3.0 */ - def asc(columnName: String): Column = Column(columnName).asc + def asc(columnName: String): Column = withExpr { SortOrder(Column(columnName).expr, Ascending) } /** * Returns a sort expression based on ascending order of the column, @@ -182,7 +170,9 @@ object functions { * @group sort_funcs * @since 2.1.0 */ - def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first + def asc_nulls_first(columnName: String): Column = withExpr { + SortOrder(Column(columnName).expr, Ascending, NullsFirst, Seq.empty) + } /** * Returns a sort expression based on ascending order of the column, @@ -194,7 +184,9 @@ object functions { * @group sort_funcs * @since 2.1.0 */ - def asc_nulls_last(columnName: String): Column = Column(columnName).asc_nulls_last + def asc_nulls_last(columnName: String): Column = withExpr { + SortOrder(Column(columnName).expr, Ascending, NullsLast, Seq.empty) + } /** * Returns a sort expression based on the descending order of the column. 
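A small sketch of the lit versus typedlit distinction preserved by typedLitImpl, which goes through Literal.create and can therefore handle parameterized Scala types such as Seq and Map, plus one of the sort helpers that now builds its SortOrder directly. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{asc_nulls_last, lit, typedLit}
    import spark.implicits._

    val df = Seq((1, "x"), (2, "y")).toDF("a", "b")
    // lit covers plain literals; typedLit also covers Seq, Map and case classes.
    df.select(lit(1), typedLit(Seq(1, 2, 3)), typedLit(Map("k" -> 1))).show()
    df.sort(asc_nulls_last("b")).show()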
@@ -205,7 +197,7 @@ object functions { * @group sort_funcs * @since 1.3.0 */ - def desc(columnName: String): Column = Column(columnName).desc + def desc(columnName: String): Column = withExpr { SortOrder(Column(columnName).expr, Descending) } /** * Returns a sort expression based on the descending order of the column, @@ -217,7 +209,9 @@ object functions { * @group sort_funcs * @since 2.1.0 */ - def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first + def desc_nulls_first(columnName: String): Column = withExpr { + SortOrder(Column(columnName).expr, Descending, NullsFirst, Seq.empty) + } /** * Returns a sort expression based on the descending order of the column, @@ -229,7 +223,9 @@ object functions { * @group sort_funcs * @since 2.1.0 */ - def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last + def desc_nulls_last(columnName: String): Column = withExpr { + SortOrder(Column(columnName).expr, Descending, NullsLast, Seq.empty) + } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -241,29 +237,35 @@ object functions { * @since 1.3.0 */ @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(e: Column): Column = approx_count_distinct(e) + def approxCountDistinct(e: Column): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr) + } /** * @group agg_funcs * @since 1.3.0 */ @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(columnName: String): Column = approx_count_distinct(columnName) + def approxCountDistinct(columnName: String): Column = withAggregateFunction { + HyperLogLogPlusPlus(Column(columnName).expr) + } /** * @group agg_funcs * @since 1.3.0 */ @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(e: Column, rsd: Double): Column = approx_count_distinct(e, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr, rsd) + } /** * @group agg_funcs * @since 1.3.0 */ @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(columnName: String, rsd: Double): Column = { - approx_count_distinct(Column(columnName), rsd) + def approxCountDistinct(columnName: String, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(Column(columnName).expr, rsd) } /** @@ -282,7 +284,9 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def approx_count_distinct(columnName: String): Column = approx_count_distinct(column(columnName)) + def approx_count_distinct(columnName: String): Column = withAggregateFunction { + HyperLogLogPlusPlus(Column(columnName).expr) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -304,8 +308,8 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def approx_count_distinct(columnName: String, rsd: Double): Column = { - approx_count_distinct(Column(columnName), rsd) + def approx_count_distinct(columnName: String, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(Column(columnName).expr, rsd, 0, 0) } /** @@ -322,7 +326,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(columnName: String): Column = avg(Column(columnName)) + def avg(columnName: String): Column = withAggregateFunction { Average(Column(columnName).expr) } /** * Aggregate function: returns a list of objects with duplicates. 
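An illustrative use of the aggregate entry points rewritten above; the string-parameter overloads now construct the aggregate expression directly instead of forwarding to their Column counterparts. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{approx_count_distinct, avg, desc_nulls_last}
    import spark.implicits._

    val df = Seq(("u1", 10), ("u2", 20), ("u1", 30)).toDF("user", "amount")
    // approx_count_distinct with an explicit relative standard deviation.
    df.agg(approx_count_distinct("user", 0.05), avg("amount")).show()
    df.orderBy(desc_nulls_last("amount")).show()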
@@ -344,7 +348,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def collect_list(columnName: String): Column = collect_list(Column(columnName)) + def collect_list(columnName: String): Column = withAggregateFunction { + CollectList(Column(columnName).expr) + } /** * Aggregate function: returns a set of objects with duplicate elements eliminated. @@ -366,7 +372,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def collect_set(columnName: String): Column = collect_set(Column(columnName)) + def collect_set(columnName: String): Column = withAggregateFunction { + CollectSet(Column(columnName).expr) + } /** * Returns a count-min sketch of a column with the given esp, confidence and seed. The result @@ -378,10 +386,10 @@ object functions { * @since 3.5.0 */ def count_min_sketch( - e: Column, - eps: Column, - confidence: Column, - seed: Column): Column = withAggregateFunction { + e: Column, + eps: Column, + confidence: Column, + seed: Column): Column = withAggregateFunction { new CountMinSketchAgg(e.expr, eps.expr, confidence.expr, seed.expr) } @@ -404,17 +412,11 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(columnName1: String, columnName2: String): Column = { - corr(Column(columnName1), Column(columnName2)) + def corr(columnName1: String, columnName2: String): Column = withAggregateFunction { + Corr(Column(columnName1).expr, Column(columnName2).expr) } - /** - * Aggregate function: returns the number of items in a group. - * - * @group agg_funcs - * @since 1.3.0 - */ - def count(e: Column): Column = withAggregateFunction { + private def countImpl(e: Column) = { e.expr match { // Turn count(*) into count(1) case s: Star => Count(Literal(1)) @@ -428,8 +430,23 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(ExpressionEncoder[Long]()) + def count(e: Column): Column = withAggregateFunction { countImpl(e) } + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def count(columnName: String): TypedColumn[Any, Long] = withAggregateFunction { + countImpl(Column(columnName)) + }.as(ExpressionEncoder[Long]()) + + private def count_distinctImpl(expr: Column, exprs: Seq[Column]) = { + // For usage like countDistinct("*"), we should let analyzer expand star and + // resolve function. + UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true) + } /** * Aggregate function: returns the number of distinct items in a group. @@ -440,7 +457,9 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = count_distinct(expr, exprs: _*) + def countDistinct(expr: Column, exprs: Column*): Column = withExpr { + count_distinctImpl(expr, exprs) + } /** * Aggregate function: returns the number of distinct items in a group. @@ -451,8 +470,9 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(columnName: String, columnNames: String*): Column = - count_distinct(Column(columnName), columnNames.map(Column.apply) : _*) + def countDistinct(columnName: String, columnNames: String*): Column = withExpr { + count_distinctImpl(Column(columnName), columnNames.map(Column.apply)) + } /** * Aggregate function: returns the number of distinct items in a group. 
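For illustration, the count family after the refactor: countImpl still rewrites count(*) into count(1), and the distinct variants go through the shared count_distinctImpl. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{collect_set, count, countDistinct}
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 2), ("b", 2)).toDF("k", "v")
    // count("*") becomes count(1); countDistinct builds an unresolved "count"
    // call with isDistinct = true.
    df.agg(count("*"), countDistinct("v"), collect_set("k")).show()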
@@ -461,10 +481,9 @@ object functions { * @since 3.2.0 */ @scala.annotation.varargs - def count_distinct(expr: Column, exprs: Column*): Column = - // For usage like countDistinct("*"), we should let analyzer expand star and - // resolve function. - Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true)) + def count_distinct(expr: Column, exprs: Column*): Column = withExpr { + count_distinctImpl(expr, exprs) + } /** * Aggregate function: returns the population covariance for two columns. @@ -482,8 +501,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def covar_pop(columnName1: String, columnName2: String): Column = { - covar_pop(Column(columnName1), Column(columnName2)) + def covar_pop(columnName1: String, columnName2: String): Column = withAggregateFunction { + CovPopulation(Column(columnName1).expr, Column(columnName2).expr) } /** @@ -502,8 +521,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def covar_samp(columnName1: String, columnName2: String): Column = { - covar_samp(Column(columnName1), Column(columnName2)) + def covar_samp(columnName1: String, columnName2: String): Column = withAggregateFunction { + CovSample(Column(columnName1).expr, Column(columnName2).expr) } /** @@ -534,8 +553,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def first(columnName: String, ignoreNulls: Boolean): Column = { - first(Column(columnName), ignoreNulls) + def first(columnName: String, ignoreNulls: Boolean): Column = withAggregateFunction { + First(Column(columnName).expr, ignoreNulls) } /** @@ -550,7 +569,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = first(e, ignoreNulls = false) + def first(e: Column): Column = withAggregateFunction { + First(e.expr, ignoreNulls = false) + } /** * Aggregate function: returns the first value of a column in a group. @@ -564,7 +585,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def first(columnName: String): Column = first(Column(columnName)) + def first(columnName: String): Column = withAggregateFunction { + First(Column(columnName).expr, ignoreNulls = false) + } /** * Aggregate function: returns the first value in a group. @@ -575,7 +598,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column): Column = call_function("first_value", e) + def first_value(e: Column): Column = withExpr { + call_functionImpl("first_value", e) + } /** * Aggregate function: returns the first value in a group. 
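A sketch exercising the refactored first and covariance helpers. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{covar_samp, first}
    import spark.implicits._

    val df = Seq((1.0, 2.0), (2.0, 4.0), (3.0, 7.0)).toDF("x", "y")
    // The string overloads now construct First / CovSample directly.
    df.agg(first("x", ignoreNulls = true), covar_samp("x", "y")).show()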
@@ -589,8 +614,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column, ignoreNulls: Column): Column = - call_function("first_value", e, ignoreNulls) + def first_value(e: Column, ignoreNulls: Column): Column = withExpr { + call_functionImpl("first_value", e, ignoreNulls) + } /** * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -599,7 +625,7 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def grouping(e: Column): Column = Column(Grouping(e.expr)) + def grouping(e: Column): Column = withExpr { Grouping(e.expr) } /** * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -608,7 +634,7 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def grouping(columnName: String): Column = grouping(Column(columnName)) + def grouping(columnName: String): Column = withExpr { Grouping(Column(columnName).expr) } /** * Aggregate function: returns the level of grouping, equals to @@ -623,7 +649,7 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) + def grouping_id(cols: Column*): Column = withExpr { GroupingID(cols.map(_.expr)) } /** * Aggregate function: returns the level of grouping, equals to @@ -637,8 +663,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def grouping_id(colName: String, colNames: String*): Column = { - grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) + def grouping_id(colName: String, colNames: String*): Column = withExpr { + GroupingID((Seq(colName) ++ colNames).map(Column(_).expr)) } /** @@ -671,8 +697,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_sketch_agg(columnName: String, lgConfigK: Int): Column = { - hll_sketch_agg(Column(columnName), lgConfigK) + def hll_sketch_agg(columnName: String, lgConfigK: Int): Column = withAggregateFunction { + new HllSketchAgg(Column(columnName).expr, Literal(lgConfigK)) } /** @@ -693,8 +719,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_sketch_agg(columnName: String): Column = { - hll_sketch_agg(Column(columnName)) + def hll_sketch_agg(columnName: String): Column = withAggregateFunction { + new HllSketchAgg(Column(columnName).expr) } /** @@ -732,9 +758,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_union_agg(columnName: String, allowDifferentLgConfigK: Boolean): Column = { - hll_union_agg(Column(columnName), allowDifferentLgConfigK) - } + def hll_union_agg(columnName: String, allowDifferentLgConfigK: Boolean): Column = + withAggregateFunction { new HllUnionAgg(Column(columnName).expr, allowDifferentLgConfigK) } /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -758,8 +783,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_union_agg(columnName: String): Column = { - hll_union_agg(Column(columnName)) + def hll_union_agg(columnName: String): Column = withAggregateFunction { + new HllUnionAgg(Column(columnName).expr) } /** @@ -776,7 +801,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + def kurtosis(columnName: String): Column = withAggregateFunction { + Kurtosis(Column(columnName).expr) + } /** * Aggregate function: returns the last value in a group. 
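For illustration, grouping_id in a cube aggregation, the setting the grouping and grouping_id docs above describe. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{grouping_id, sum}
    import spark.implicits._

    val df = Seq(("d1", "nyc", 10), ("d1", "sf", 20), ("d2", "nyc", 5))
      .toDF("dept", "city", "sales")
    // grouping_id tells which grouping columns are aggregated away in each row.
    df.cube("dept", "city").agg(grouping_id(), sum("sales")).show()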
@@ -806,8 +833,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def last(columnName: String, ignoreNulls: Boolean): Column = { - last(Column(columnName), ignoreNulls) + def last(columnName: String, ignoreNulls: Boolean): Column = withAggregateFunction { + Last(Column(columnName).expr, ignoreNulls) } /** @@ -822,7 +849,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = last(e, ignoreNulls = false) + def last(e: Column): Column = withAggregateFunction { + Last(e.expr, ignoreNulls = false) + } /** * Aggregate function: returns the last value of the column in a group. @@ -836,7 +865,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) + def last(columnName: String): Column = withAggregateFunction { + Last(Column(columnName).expr, ignoreNulls = false) + } /** * Aggregate function: returns the last value in a group. @@ -847,7 +878,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column): Column = call_function("last_value", e) + def last_value(e: Column): Column = withExpr { + call_functionImpl("last_value", e) + } /** * Aggregate function: returns the last value in a group. @@ -861,8 +894,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column, ignoreNulls: Column): Column = - call_function("last_value", e, ignoreNulls) + def last_value(e: Column, ignoreNulls: Column): Column = withExpr { + call_functionImpl("last_value", e, ignoreNulls) + } /** * Aggregate function: returns the most frequent value in a group. @@ -886,7 +920,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(columnName: String): Column = max(Column(columnName)) + def max(columnName: String): Column = withAggregateFunction { Max(Column(columnName).expr) } /** * Aggregate function: returns the value associated with the maximum value of ord. @@ -903,7 +937,7 @@ object functions { * @group agg_funcs * @since 1.4.0 */ - def mean(e: Column): Column = avg(e) + def mean(e: Column): Column = withAggregateFunction { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -912,7 +946,7 @@ object functions { * @group agg_funcs * @since 1.4.0 */ - def mean(columnName: String): Column = avg(columnName) + def mean(columnName: String): Column = withAggregateFunction { Average(Column(columnName).expr) } /** * Aggregate function: returns the median of the values in a group. @@ -936,7 +970,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(columnName: String): Column = min(Column(columnName)) + def min(columnName: String): Column = withAggregateFunction { Min(Column(columnName).expr) } /** * Aggregate function: returns the value associated with the minimum value of ord. @@ -1015,8 +1049,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = - call_function("approx_percentile", e, percentage, accuracy) + def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = withExpr { + call_functionImpl("approx_percentile", e, percentage, accuracy) + } /** * Aggregate function: returns the product of all numerical elements in a group. 
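An illustrative call to approx_percentile, whose three-argument form above routes through call_functionImpl. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{approx_percentile, col, lit}
    import spark.implicits._

    val df = Seq(1, 2, 3, 4, 5, 100).toDF("v")
    // Median estimate; the accuracy literal trades memory for precision.
    df.agg(approx_percentile(col("v"), lit(0.5), lit(10000))).show()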
@@ -1041,7 +1076,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def skewness(columnName: String): Column = skewness(Column(columnName)) + def skewness(columnName: String): Column = withAggregateFunction { + Skewness(Column(columnName).expr) + } /** * Aggregate function: alias for `stddev_samp`. @@ -1049,7 +1086,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def std(e: Column): Column = call_function("std", e) + def std(e: Column): Column = withExpr { + call_functionImpl("std", e) + } /** * Aggregate function: alias for `stddev_samp`. @@ -1057,7 +1096,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = call_function("stddev", e) + def stddev(e: Column): Column = withExpr { + call_functionImpl("stddev", e) + } /** * Aggregate function: alias for `stddev_samp`. @@ -1065,7 +1106,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(columnName: String): Column = stddev(Column(columnName)) + def stddev(columnName: String): Column = withAggregateFunction { + StddevSamp(Column(columnName).expr) + } /** * Aggregate function: returns the sample standard deviation of @@ -1083,7 +1126,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) + def stddev_samp(columnName: String): Column = withAggregateFunction { + StddevSamp(Column(columnName).expr) + } /** * Aggregate function: returns the population standard deviation of @@ -1101,7 +1146,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) + def stddev_pop(columnName: String): Column = withAggregateFunction { + StddevPop(Column(columnName).expr) + } /** * Aggregate function: returns the sum of all values in the expression. @@ -1117,7 +1164,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sum(columnName: String): Column = sum(Column(columnName)) + def sum(columnName: String): Column = withAggregateFunction { Sum(Column(columnName).expr) } /** * Aggregate function: returns the sum of distinct values in the expression. @@ -1135,7 +1182,9 @@ object functions { * @since 1.3.0 */ @deprecated("Use sum_distinct", "3.2.0") - def sumDistinct(columnName: String): Column = sum_distinct(Column(columnName)) + def sumDistinct(columnName: String): Column = withAggregateFunction({ + Sum(Column(columnName).expr) + }, isDistinct = true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -1159,7 +1208,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(columnName: String): Column = variance(Column(columnName)) + def variance(columnName: String): Column = withAggregateFunction { + VarianceSamp(Column(columnName).expr) + } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -1175,7 +1226,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(columnName: String): Column = var_samp(Column(columnName)) + def var_samp(columnName: String): Column = withAggregateFunction { + VarianceSamp(Column(columnName).expr) + } /** * Aggregate function: returns the population variance of the values in a group. 
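A sketch of the deviation and variance helpers rewritten above; stddev and var_samp are the sample (n - 1) estimators, stddev_pop and var_pop the population ones. Hypothetical data; assumes a local SparkSession named spark.

    // assumes: val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import org.apache.spark.sql.functions.{stddev, sum, var_samp}
    import spark.implicits._

    val df = Seq(2.0, 4.0, 4.0, 6.0).toDF("x")
    df.agg(sum("x"), stddev("x"), var_samp("x")).show()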
@@ -1191,7 +1244,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(columnName: String): Column = var_pop(Column(columnName)) + def var_pop(columnName: String): Column = withAggregateFunction { + VariancePop(Column(columnName).expr) + } /** * Aggregate function: returns the average of the independent variable for non-null pairs @@ -1327,7 +1382,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def every(e: Column): Column = call_function("every", e) + def every(e: Column): Column = withExpr { + call_functionImpl("every", e) + } /** * Aggregate function: returns true if all values of `e` are true. @@ -1343,7 +1400,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def some(e: Column): Column = call_function("some", e) + def some(e: Column): Column = withExpr { + call_functionImpl("some", e) + } /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1351,7 +1410,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def any(e: Column): Column = call_function("any", e) + def any(e: Column): Column = withExpr { + call_functionImpl("any", e) + } /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1429,7 +1490,9 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int): Column = lag(e, offset, null) + def lag(e: Column, offset: Int): Column = withExpr { + Lag(e.expr, Literal(offset), Literal(null), false) + } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -1441,7 +1504,9 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(columnName: String, offset: Int): Column = lag(columnName, offset, null) + def lag(columnName: String, offset: Int): Column = withExpr { + Lag(Column(columnName).expr, Literal(offset), Literal(null), false) + } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -1453,8 +1518,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(columnName: String, offset: Int, defaultValue: Any): Column = { - lag(Column(columnName), offset, defaultValue) + def lag(columnName: String, offset: Int, defaultValue: Any): Column = withExpr { + Lag(Column(columnName).expr, Literal(offset), Literal(defaultValue), false) } /** @@ -1467,8 +1532,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int, defaultValue: Any): Column = { - lag(e, offset, defaultValue, false) + def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { + Lag(e.expr, Literal(offset), Literal(defaultValue), false) } /** @@ -1497,7 +1562,9 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) } + def lead(columnName: String, offset: Int): Column = withExpr { + Lead(Column(columnName).expr, Literal(offset), Literal(null), false) + } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -1509,7 +1576,9 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int): Column = { lead(e, offset, null) } + def lead(e: Column, offset: Int): Column = withExpr { + Lead(e.expr, Literal(offset), Literal(null), false) + } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -1521,8 +1590,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(columnName: String, offset: Int, defaultValue: Any): 
Column = { - lead(Column(columnName), offset, defaultValue) + def lead(columnName: String, offset: Int, defaultValue: Any): Column = withExpr { + Lead(Column(columnName).expr, Literal(offset), Literal(defaultValue), false) } /** @@ -1535,8 +1604,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int, defaultValue: Any): Column = { - lead(e, offset, defaultValue, false) + def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { + Lead(e.expr, Literal(offset), Literal(defaultValue), false) } /** @@ -1655,8 +1724,8 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def array(colName: String, colNames: String*): Column = { - array((colName +: colNames).map(col) : _*) + def array(colName: String, colNames: String*): Column = withExpr { + CreateArray((colName +: colNames).map(col(_).expr)) } /** @@ -1793,7 +1862,7 @@ object functions { * @since 1.4.0 */ @deprecated("Use monotonically_increasing_id()", "2.0.0") - def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -1839,7 +1908,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def negate(e: Column): Column = -e + def negate(e: Column): Column = withExpr { UnaryMinus(e.expr) } /** * Inversion of boolean expression, i.e. NOT. @@ -1854,7 +1923,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def not(e: Column): Column = !e + def not(e: Column): Column = withExpr { Not(e.expr) } /** * Generate a random column with independent and identically distributed (i.i.d.) samples @@ -1876,7 +1945,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def rand(): Column = rand(Utils.random.nextLong) + def rand(): Column = withExpr { Rand(Utils.random.nextLong) } /** * Generate a column with independent and identically distributed (i.i.d.) samples from @@ -1898,7 +1967,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def randn(): Column = randn(Utils.random.nextLong) + def randn(): Column = withExpr { Randn(Utils.random.nextLong) } /** * Partition ID. @@ -1924,7 +1993,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def sqrt(colName: String): Column = sqrt(Column(colName)) + def sqrt(colName: String): Column = withExpr { Sqrt(Column(colName).expr) } /** * Returns the sum of `left` and `right` and the result is null on overflow. The acceptable @@ -1933,7 +2002,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_add(left: Column, right: Column): Column = call_function("try_add", left, right) + def try_add(left: Column, right: Column): Column = withExpr { + call_functionImpl("try_add", left, right) + } /** * Returns the mean calculated from values of a group and the result is null on overflow. @@ -1941,8 +2012,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_avg(e: Column): Column = - call_function("try_avg", e) + def try_avg(e: Column): Column = withExpr { + call_functionImpl("try_avg", e) + } /** * Returns `dividend``/``divisor`. It always performs floating point division. 
Its result is @@ -1951,8 +2023,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_divide(dividend: Column, divisor: Column): Column = - call_function("try_divide", dividend, divisor) + def try_divide(dividend: Column, divisor: Column): Column = withExpr { + call_functionImpl("try_divide", dividend, divisor) + } /** * Returns `left``*``right` and the result is null on overflow. The acceptable input types are @@ -1961,8 +2034,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_multiply(left: Column, right: Column): Column = - call_function("try_multiply", left, right) + def try_multiply(left: Column, right: Column): Column = withExpr { + call_functionImpl("try_multiply", left, right) + } /** * Returns `left``-``right` and the result is null on overflow. The acceptable input types are @@ -1971,8 +2045,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_subtract(left: Column, right: Column): Column = - call_function("try_subtract", left, right) + def try_subtract(left: Column, right: Column): Column = withExpr { + call_functionImpl("try_subtract", left, right) + } /** * Returns the sum calculated from values of a group and the result is null on overflow. @@ -1980,7 +2055,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_sum(e: Column): Column = call_function("try_sum", e) + def try_sum(e: Column): Column = withExpr { + call_functionImpl("try_sum", e) + } /** * Creates a new struct column. @@ -2002,8 +2079,8 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(colName: String, colNames: String*): Column = { - struct((colName +: colNames).map(col) : _*) + def struct(colName: String, colNames: String*): Column = withExpr { + CreateStruct.create((colName +: colNames).map(Column(_).expr)) } /** @@ -2028,7 +2105,7 @@ object functions { * @since 1.4.0 */ def when(condition: Column, value: Any): Column = withExpr { - CaseWhen(Seq((condition.expr, lit(value).expr))) + CaseWhen(Seq((condition.expr, litImpl(value).expr))) } /** @@ -2038,7 +2115,7 @@ object functions { * @since 1.4.0 */ @deprecated("Use bitwise_not", "3.2.0") - def bitwiseNOT(e: Column): Column = bitwise_not(e) + def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) } /** * Computes bitwise NOT (~) of a number. 
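The hunks above all apply the same mechanical rewrite: overloads that used to forward to another public overload (or to `call_function`) now build their Catalyst expression directly inside `withExpr`, and literals that the implementations need now go through the internal `litImpl` instead of the public `lit`. A minimal, hypothetical sketch of the two shapes is below (the method names are made up; `withExpr`, `Column` and `Sqrt` are used as elsewhere in this patch). The apparent motivation is that the expression, and the origin captured for it via the `withOrigin`/`Origin.fromCurrentStackTrace` helpers added to `package.scala` later in this patch, is attributed to the overload the user actually called rather than to an intermediate frame inside `functions.scala`.

    // Illustrative sketch only, not part of the patch.
    // Before: the String overload forwarded to the Column overload, so the
    // expression was created in a different functions.scala method than the
    // one the user invoked.
    def sqrtByName_before(colName: String): Column = sqrt(Column(colName))

    // After: the Catalyst expression is built in place under withExpr, and any
    // literals the implementation needs go through the internal litImpl.
    def sqrtByName_after(colName: String): Column = withExpr {
      Sqrt(Column(colName).expr)
    }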
@@ -2075,7 +2152,9 @@ object functions { * @group bitwise_funcs * @since 3.5.0 */ - def getbit(e: Column, pos: Column): Column = call_function("getbit", e, pos) + def getbit(e: Column, pos: Column): Column = withExpr { + call_functionImpl("getbit", e, pos) + } /** * Parses the expression string into the column that it represents, similar to @@ -2087,11 +2166,11 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = { + def expr(expr: String): Column = withExpr { val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { new SparkSqlParser() } - Column(parser.parseExpression(expr)) + parser.parseExpression(expr) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2120,7 +2199,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def acos(columnName: String): Column = acos(Column(columnName)) + def acos(columnName: String): Column = withExpr { Acos(Column(columnName).expr) } /** * @return inverse hyperbolic cosine of `e` @@ -2136,7 +2215,7 @@ object functions { * @group math_funcs * @since 3.1.0 */ - def acosh(columnName: String): Column = acosh(Column(columnName)) + def acosh(columnName: String): Column = withExpr { Acosh(Column(columnName).expr) } /** * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` @@ -2152,7 +2231,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def asin(columnName: String): Column = asin(Column(columnName)) + def asin(columnName: String): Column = withExpr { Asin(Column(columnName).expr) } /** * @return inverse hyperbolic sine of `e` @@ -2168,7 +2247,7 @@ object functions { * @group math_funcs * @since 3.1.0 */ - def asinh(columnName: String): Column = asinh(Column(columnName)) + def asinh(columnName: String): Column = withExpr { Asinh(Column(columnName).expr) } /** * @return inverse tangent of `e` as if computed by `java.lang.Math.atan` @@ -2184,7 +2263,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan(columnName: String): Column = atan(Column(columnName)) + def atan(columnName: String): Column = withExpr { Atan(Column(columnName).expr) } /** * @param y coordinate on y-axis @@ -2212,7 +2291,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) + def atan2(y: Column, xName: String): Column = withExpr { Atan2(y.expr, Column(xName).expr) } /** * @param yName coordinate on y-axis @@ -2226,7 +2305,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) + def atan2(yName: String, x: Column): Column = withExpr { Atan2(Column(yName).expr, x.expr) } /** * @param yName coordinate on y-axis @@ -2240,8 +2319,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(yName: String, xName: String): Column = - atan2(Column(yName), Column(xName)) + def atan2(yName: String, xName: String): Column = withExpr { + Atan2(Column(yName).expr, Column(xName).expr) + } /** * @param y coordinate on y-axis @@ -2255,7 +2335,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) + def atan2(y: Column, xValue: Double): Column = withExpr { Atan2(y.expr, litImpl(xValue).expr) } /** * @param yName coordinate on y-axis @@ -2269,7 +2349,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), 
xValue) + def atan2(yName: String, xValue: Double): Column = withExpr { + Atan2(Column(yName).expr, litImpl(xValue).expr) + } /** * @param yValue coordinate on y-axis @@ -2283,7 +2365,8 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) + // todo + def atan2(yValue: Double, x: Column): Column = withExpr { Atan2(litImpl(yValue).expr, x.expr) } /** * @param yValue coordinate on y-axis @@ -2297,7 +2380,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) + def atan2(yValue: Double, xName: String): Column = withExpr { + Atan2(litImpl(yValue).expr, Column(xName).expr) + } /** * @return inverse hyperbolic tangent of `e` @@ -2313,7 +2398,7 @@ object functions { * @group math_funcs * @since 3.1.0 */ - def atanh(columnName: String): Column = atanh(Column(columnName)) + def atanh(columnName: String): Column = withExpr { Atanh(Column(columnName).expr) } /** * An expression that returns the string representation of the binary value of the given long @@ -2331,7 +2416,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def bin(columnName: String): Column = bin(Column(columnName)) + def bin(columnName: String): Column = withExpr { Bin(Column(columnName).expr) } /** * Computes the cube-root of the given value. @@ -2347,7 +2432,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cbrt(columnName: String): Column = cbrt(Column(columnName)) + def cbrt(columnName: String): Column = withExpr { Cbrt(Column(columnName).expr) } /** * Computes the ceiling of the given value of `e` to `scale` decimal places. @@ -2355,7 +2440,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def ceil(e: Column, scale: Column): Column = call_function("ceil", e, scale) + def ceil(e: Column, scale: Column): Column = withExpr { call_functionImpl ("ceil", e, scale) } /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2363,7 +2448,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = call_function("ceil", e) + def ceil(e: Column): Column = withExpr { call_functionImpl("ceil", e) } /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2371,7 +2456,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(columnName: String): Column = ceil(Column(columnName)) + def ceil(columnName: String): Column = withExpr { call_functionImpl("ceil", Column(columnName)) } /** * Computes the ceiling of the given value of `e` to `scale` decimal places. @@ -2379,8 +2464,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column, scale: Column): Column = - call_function("ceiling", e, scale) + def ceiling(e: Column, scale: Column): Column = withExpr { + call_functionImpl("ceiling", e, scale) + } /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2388,7 +2474,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column): Column = call_function("ceiling", e) + def ceiling(e: Column): Column = withExpr { + call_functionImpl("ceiling", e) + } /** * Convert a number in a string column from one base to another. 
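One change above goes slightly further than the mechanical rewrites: `expr(String)` now runs the SQL parser inside `withExpr` and returns the parsed expression directly, instead of wrapping the result in `Column(...)` itself. User-facing behaviour is unchanged; a typical call looks like this (the DataFrame `df` and the column `name` are hypothetical):

    // Parse a SQL expression string into a Column, exactly as before.
    df.select(expr("upper(name)"))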
@@ -2397,7 +2485,7 @@ object functions { * @since 1.5.0 */ def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr { - Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + Conv(num.expr, litImpl(fromBase).expr, litImpl(toBase).expr) } /** @@ -2416,7 +2504,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cos(columnName: String): Column = cos(Column(columnName)) + def cos(columnName: String): Column = withExpr { Cos(Column(columnName).expr) } /** * @param e hyperbolic angle @@ -2434,7 +2522,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cosh(columnName: String): Column = cosh(Column(columnName)) + def cosh(columnName: String): Column = withExpr { Cosh(Column(columnName).expr) } /** * @param e angle in radians @@ -2476,7 +2564,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def exp(columnName: String): Column = exp(Column(columnName)) + def exp(columnName: String): Column = withExpr { Exp(Column(columnName).expr) } /** * Computes the exponential of the given value minus one. @@ -2492,7 +2580,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def expm1(columnName: String): Column = expm1(Column(columnName)) + def expm1(columnName: String): Column = withExpr { Expm1(Column(columnName).expr) } /** * Computes the factorial of the given value. @@ -2508,7 +2596,9 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def floor(e: Column, scale: Column): Column = call_function("floor", e, scale) + def floor(e: Column, scale: Column): Column = withExpr { + call_functionImpl("floor", e, scale) + } /** * Computes the floor of the given value of `e` to 0 decimal places. @@ -2516,7 +2606,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = call_function("floor", e) + def floor(e: Column): Column = withExpr { call_functionImpl("floor", e) } /** * Computes the floor of the given column value to 0 decimal places. @@ -2524,7 +2614,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(columnName: String): Column = floor(Column(columnName)) + def floor(columnName: String): Column = withExpr { + call_functionImpl("floor", Column(columnName)) + } /** * Returns the greatest value of the list of values, skipping null values. @@ -2544,8 +2636,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(columnName: String, columnNames: String*): Column = { - greatest((columnName +: columnNames).map(Column.apply): _*) + def greatest(columnName: String, columnNames: String*): Column = withExpr { + Greatest((columnName +: columnNames).map(Column(_).expr)) } /** @@ -2579,7 +2671,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + def hypot(l: Column, rightName: String): Column = withExpr { + Hypot(l.expr, Column(rightName).expr) + } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2587,7 +2681,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + def hypot(leftName: String, r: Column): Column = withExpr { Hypot(Column(leftName).expr, r.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. 
@@ -2595,8 +2689,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(leftName: String, rightName: String): Column = - hypot(Column(leftName), Column(rightName)) + def hypot(leftName: String, rightName: String): Column = withExpr { + Hypot(Column(leftName).expr, Column(rightName).expr) + } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2604,7 +2699,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Double): Column = hypot(l, lit(r)) + def hypot(l: Column, r: Double): Column = withExpr { Hypot(l.expr, litImpl(r).expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2612,7 +2707,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + def hypot(leftName: String, r: Double): Column = withExpr { + Hypot(Column(leftName).expr, litImpl(r).expr) + } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2620,7 +2717,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Double, r: Column): Column = hypot(lit(l), r) + def hypot(l: Double, r: Column): Column = withExpr { Hypot(litImpl(l).expr, r.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2628,7 +2725,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + def hypot(l: Double, rightName: String): Column = withExpr { + Hypot(litImpl(l).expr, Column(rightName).expr) + } /** * Returns the least value of the list of values, skipping null values. @@ -2648,8 +2747,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(columnName: String, columnNames: String*): Column = { - least((columnName +: columnNames).map(Column.apply): _*) + def least(columnName: String, columnNames: String*): Column = withExpr { + Least((columnName +: columnNames).map(Column(_).expr)) } /** @@ -2658,7 +2757,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ln(e: Column): Column = log(e) + def ln(e: Column): Column = withExpr { Log(e.expr) } /** * Computes the natural logarithm of the given value. @@ -2674,7 +2773,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(columnName: String): Column = log(Column(columnName)) + def log(columnName: String): Column = withExpr { Log(Column(columnName).expr) } /** * Returns the first argument-base logarithm of the second argument. @@ -2682,7 +2781,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) } + def log(base: Double, a: Column): Column = withExpr { Logarithm(litImpl(base).expr, a.expr) } /** * Returns the first argument-base logarithm of the second argument. @@ -2690,7 +2789,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, columnName: String): Column = log(base, Column(columnName)) + def log(base: Double, columnName: String): Column = withExpr { + Logarithm(litImpl(base).expr, Column(columnName).expr) + } /** * Computes the logarithm of the given value in base 10. 
@@ -2706,7 +2807,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log10(columnName: String): Column = log10(Column(columnName)) + def log10(columnName: String): Column = withExpr { Log10(Column(columnName).expr) } /** * Computes the natural logarithm of the given value plus one. @@ -2722,7 +2823,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log1p(columnName: String): Column = log1p(Column(columnName)) + def log1p(columnName: String): Column = withExpr { Log1p(Column(columnName).expr) } /** * Computes the logarithm of the given column in base 2. @@ -2738,7 +2839,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def log2(columnName: String): Column = log2(Column(columnName)) + def log2(columnName: String): Column = withExpr { Log2(Column(columnName).expr) } /** * Returns the negated value. @@ -2746,7 +2847,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def negative(e: Column): Column = call_function("negative", e) + def negative(e: Column): Column = withExpr { + call_functionImpl("negative", e) + } /** * Returns Pi. @@ -2778,7 +2881,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + def pow(l: Column, rightName: String): Column = withExpr { Pow(l.expr, Column(rightName).expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -2786,7 +2889,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + def pow(leftName: String, r: Column): Column = withExpr { Pow(Column(leftName).expr, r.expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -2794,7 +2897,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + def pow(leftName: String, rightName: String): Column = withExpr { + Pow(Column(leftName).expr, Column(rightName).expr) + } /** * Returns the value of the first argument raised to the power of the second argument. @@ -2802,7 +2907,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Double): Column = pow(l, lit(r)) + def pow(l: Column, r: Double): Column = withExpr { Pow(l.expr, litImpl(r).expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -2810,7 +2915,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + def pow(leftName: String, r: Double): Column = withExpr { + Pow(Column(leftName).expr, litImpl(r).expr) + } /** * Returns the value of the first argument raised to the power of the second argument. @@ -2818,7 +2925,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Double, r: Column): Column = pow(lit(l), r) + def pow(l: Double, r: Column): Column = withExpr { Pow(litImpl(l).expr, r.expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -2826,7 +2933,9 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + def pow(l: Double, rightName: String): Column = withExpr { + Pow(litImpl(l).expr, Column(rightName).expr) + } /** * Returns the value of the first argument raised to the power of the second argument. 
@@ -2834,7 +2943,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def power(l: Column, r: Column): Column = pow(l, r) + def power(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) } /** * Returns the positive value of dividend mod divisor. @@ -2862,7 +2971,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rint(columnName: String): Column = rint(Column(columnName)) + def rint(columnName: String): Column = withExpr { Rint(Column(columnName).expr) } /** * Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. @@ -2870,7 +2979,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column): Column = round(e, 0) + def round(e: Column): Column = withExpr { Round(e.expr, Literal(0)) } /** * Round the value of `e` to `scale` decimal places with HALF_UP round mode @@ -2887,7 +2996,7 @@ object functions { * @group math_funcs * @since 2.0.0 */ - def bround(e: Column): Column = bround(e, 0) + def bround(e: Column): Column = withExpr { BRound(e.expr, Literal(0)) } /** * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode @@ -2915,7 +3024,9 @@ object functions { * @since 1.5.0 */ @deprecated("Use shiftleft", "3.2.0") - def shiftLeft(e: Column, numBits: Int): Column = shiftleft(e, numBits) + def shiftLeft(e: Column, numBits: Int): Column = withExpr { + ShiftLeft(e.expr, litImpl(numBits).expr) + } /** * Shift the given value numBits left. If the given value is a long value, this function @@ -2924,7 +3035,9 @@ object functions { * @group math_funcs * @since 3.2.0 */ - def shiftleft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } + def shiftleft(e: Column, numBits: Int): Column = withExpr { + ShiftLeft(e.expr, litImpl(numBits).expr) + } /** * (Signed) shift the given value numBits right. If the given value is a long value, it will @@ -2934,7 +3047,9 @@ object functions { * @since 1.5.0 */ @deprecated("Use shiftright", "3.2.0") - def shiftRight(e: Column, numBits: Int): Column = shiftright(e, numBits) + def shiftRight(e: Column, numBits: Int): Column = withExpr { + ShiftRight(e.expr, litImpl(numBits).expr) + } /** * (Signed) shift the given value numBits right. If the given value is a long value, it will @@ -2944,7 +3059,7 @@ object functions { * @since 3.2.0 */ def shiftright(e: Column, numBits: Int): Column = withExpr { - ShiftRight(e.expr, lit(numBits).expr) + ShiftRight(e.expr, litImpl(numBits).expr) } /** @@ -2955,7 +3070,9 @@ object functions { * @since 1.5.0 */ @deprecated("Use shiftrightunsigned", "3.2.0") - def shiftRightUnsigned(e: Column, numBits: Int): Column = shiftrightunsigned(e, numBits) + def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr { + ShiftRightUnsigned(e.expr, litImpl(numBits).expr) + } /** * Unsigned shift the given value numBits right. If the given value is a long value, @@ -2965,7 +3082,7 @@ object functions { * @since 3.2.0 */ def shiftrightunsigned(e: Column, numBits: Int): Column = withExpr { - ShiftRightUnsigned(e.expr, lit(numBits).expr) + ShiftRightUnsigned(e.expr, litImpl(numBits).expr) } /** @@ -2974,7 +3091,9 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def sign(e: Column): Column = call_function("sign", e) + def sign(e: Column): Column = withExpr { + call_functionImpl("sign", e) + } /** * Computes the signum of the given value. 
@@ -2990,7 +3109,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def signum(columnName: String): Column = signum(Column(columnName)) + def signum(columnName: String): Column = withExpr { Signum(Column(columnName).expr) } /** * @param e angle in radians @@ -3008,7 +3127,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sin(columnName: String): Column = sin(Column(columnName)) + def sin(columnName: String): Column = withExpr { Sin(Column(columnName).expr) } /** * @param e hyperbolic angle @@ -3026,7 +3145,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sinh(columnName: String): Column = sinh(Column(columnName)) + def sinh(columnName: String): Column = withExpr { Sinh(Column(columnName).expr) } /** * @param e angle in radians @@ -3044,7 +3163,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tan(columnName: String): Column = tan(Column(columnName)) + def tan(columnName: String): Column = withExpr { Tan(Column(columnName).expr) } /** * @param e hyperbolic angle @@ -3062,21 +3181,21 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tanh(columnName: String): Column = tanh(Column(columnName)) + def tanh(columnName: String): Column = withExpr { Tanh(Column(columnName).expr) } /** * @group math_funcs * @since 1.4.0 */ @deprecated("Use degrees", "2.1.0") - def toDegrees(e: Column): Column = degrees(e) + def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * @group math_funcs * @since 1.4.0 */ @deprecated("Use degrees", "2.1.0") - def toDegrees(columnName: String): Column = degrees(Column(columnName)) + def toDegrees(columnName: String): Column = withExpr { ToDegrees(Column(columnName).expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. @@ -3098,21 +3217,21 @@ object functions { * @group math_funcs * @since 2.1.0 */ - def degrees(columnName: String): Column = degrees(Column(columnName)) + def degrees(columnName: String): Column = withExpr { ToDegrees(Column(columnName).expr) } /** * @group math_funcs * @since 1.4.0 */ @deprecated("Use radians", "2.1.0") - def toRadians(e: Column): Column = radians(e) + def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * @group math_funcs * @since 1.4.0 */ @deprecated("Use radians", "2.1.0") - def toRadians(columnName: String): Column = radians(Column(columnName)) + def toRadians(columnName: String): Column = withExpr { ToRadians(Column(columnName).expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. @@ -3134,7 +3253,7 @@ object functions { * @group math_funcs * @since 2.1.0 */ - def radians(columnName: String): Column = radians(Column(columnName)) + def radians(columnName: String): Column = withExpr { ToRadians(Column(columnName).expr) } /** * Returns the bucket number into which the value of this expression would fall @@ -3179,7 +3298,9 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_schema(): Column = call_function("current_schema") + def current_schema(): Column = withExpr { + call_functionImpl("current_schema") + } /** * Returns the user name of current execution context. 
@@ -3220,7 +3341,7 @@ object functions { def sha2(e: Column, numBits: Int): Column = { require(Seq(0, 224, 256, 384, 512).contains(numBits), s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") - withExpr { Sha2(e.expr, lit(numBits).expr) } + withExpr { Sha2(e.expr, litImpl(numBits).expr) } } /** @@ -3304,8 +3425,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def hll_sketch_estimate(columnName: String): Column = { - hll_sketch_estimate(Column(columnName)) + def hll_sketch_estimate(columnName: String): Column = withExpr { + HllSketchEstimate(Column(columnName).expr) } /** @@ -3328,8 +3449,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def hll_union(columnName1: String, columnName2: String): Column = { - hll_union(Column(columnName1), Column(columnName2)) + def hll_union(columnName1: String, columnName2: String): Column = withExpr { + new HllUnion(Column(columnName1).expr, Column(columnName2).expr) } /** @@ -3363,7 +3484,9 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def user(): Column = call_function("user") + def user(): Column = withExpr { + call_functionImpl("user") + } /** * Returns the user name of current execution context. @@ -3641,7 +3764,9 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def sha(col: Column): Column = call_function("sha", col) + def sha(col: Column): Column = withExpr { + call_functionImpl("sha", col) + } /** * Returns the length of the block being read, or -1 if not available. @@ -3679,8 +3804,9 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def java_method(cols: Column*): Column = - call_function("java_method", cols: _*) + def java_method(cols: Column*): Column = withExpr { + call_functionImpl("java_method", cols: _*) + } /** * This is a special version of `reflect` that performs the same operation, but returns a NULL @@ -3732,7 +3858,9 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def random(seed: Column): Column = call_function("random", seed) + def random(seed: Column): Column = withExpr { + call_functionImpl("random", seed) + } /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly @@ -3741,7 +3869,9 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def random(): Column = call_function("random") + def random(): Column = withExpr { + call_functionImpl("random") + } /** * Returns the bucket number for the given input column. 
@@ -3848,7 +3978,7 @@ object functions { * @since 1.5.0 */ def decode(value: Column, charset: String): Column = withExpr { - StringDecode(value.expr, lit(charset).expr) + StringDecode(value.expr, litImpl(charset).expr) } /** @@ -3860,7 +3990,7 @@ object functions { * @since 1.5.0 */ def encode(value: Column, charset: String): Column = withExpr { - Encode(value.expr, lit(charset).expr) + Encode(value.expr, litImpl(charset).expr) } /** @@ -3874,7 +4004,7 @@ object functions { * @since 1.5.0 */ def format_number(x: Column, d: Int): Column = withExpr { - FormatNumber(x.expr, lit(d).expr) + FormatNumber(x.expr, litImpl(d).expr) } /** @@ -3885,7 +4015,7 @@ object functions { */ @scala.annotation.varargs def format_string(format: String, arguments: Column*): Column = withExpr { - FormatString((lit(format) +: arguments).map(_.expr): _*) + FormatString((litImpl(format) +: arguments).map(_.expr): _*) } /** @@ -3910,7 +4040,7 @@ object functions { * @since 1.5.0 */ def instr(str: Column, substring: String): Column = withExpr { - StringInstr(str.expr, lit(substring).expr) + StringInstr(str.expr, litImpl(substring).expr) } /** @@ -3969,7 +4099,7 @@ object functions { * @since 1.5.0 */ def locate(substr: String, str: Column): Column = withExpr { - new StringLocate(lit(substr).expr, str.expr) + new StringLocate(litImpl(substr).expr, str.expr) } /** @@ -3982,7 +4112,7 @@ object functions { * @since 1.5.0 */ def locate(substr: String, str: Column, pos: Int): Column = withExpr { - StringLocate(lit(substr).expr, str.expr, lit(pos).expr) + StringLocate(litImpl(substr).expr, str.expr, litImpl(pos).expr) } /** @@ -3993,7 +4123,7 @@ object functions { * @since 1.5.0 */ def lpad(str: Column, len: Int, pad: String): Column = withExpr { - StringLPad(str.expr, lit(len).expr, lit(pad).expr) + StringLPad(str.expr, litImpl(len).expr, litImpl(pad).expr) } /** @@ -4003,8 +4133,9 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def lpad(str: Column, len: Int, pad: Array[Byte]): Column = - call_function("lpad", str, lit(len), lit(pad)) + def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { + call_functionImpl("lpad", str, litImpl(len), litImpl(pad)) + } /** * Trim the spaces from left end for the specified string value. @@ -4047,8 +4178,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp(str: Column, regexp: Column): Column = - call_function("regexp", str, regexp) + def regexp(str: Column, regexp: Column): Column = withExpr { + call_functionImpl("regexp", str, regexp) + } /** * Returns true if `str` matches `regexp`, or false otherwise. 
@@ -4056,8 +4188,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_like(str: Column, regexp: Column): Column = - call_function("regexp_like", str, regexp) + def regexp_like(str: Column, regexp: Column): Column = withExpr { + call_functionImpl("regexp_like", str, regexp) + } /** * Returns a count of the number of times that the regular expression pattern `regexp` @@ -4080,7 +4213,7 @@ object functions { * @since 1.5.0 */ def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr { - RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) + RegExpExtract(e.expr, litImpl(exp).expr, litImpl(groupIdx).expr) } /** @@ -4112,7 +4245,7 @@ object functions { * @since 1.5.0 */ def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr { - RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) + RegExpReplace(e.expr, litImpl(pattern).expr, litImpl(replacement).expr) } /** @@ -4177,7 +4310,7 @@ object functions { * @since 1.5.0 */ def rpad(str: Column, len: Int, pad: String): Column = withExpr { - StringRPad(str.expr, lit(len).expr, lit(pad).expr) + StringRPad(str.expr, litImpl(len).expr, litImpl(pad).expr) } /** @@ -4187,8 +4320,9 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def rpad(str: Column, len: Int, pad: Array[Byte]): Column = - call_function("rpad", str, lit(len), lit(pad)) + def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { + call_functionImpl("rpad", str, litImpl(len), litImpl(pad)) + } /** * Repeats a string column n times, and returns it as a new string column. @@ -4197,7 +4331,7 @@ object functions { * @since 1.5.0 */ def repeat(str: Column, n: Int): Column = withExpr { - StringRepeat(str.expr, lit(n).expr) + StringRepeat(str.expr, litImpl(n).expr) } /** @@ -4282,7 +4416,7 @@ object functions { * @since 1.5.0 */ def substring(str: Column, pos: Int, len: Int): Column = withExpr { - Substring(str.expr, lit(pos).expr, lit(len).expr) + Substring(str.expr, litImpl(pos).expr, litImpl(len).expr) } /** @@ -4294,7 +4428,7 @@ object functions { * @group string_funcs */ def substring_index(str: Column, delim: String, count: Int): Column = withExpr { - SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + SubstringIndex(str.expr, litImpl(delim).expr, litImpl(count).expr) } /** @@ -4348,7 +4482,7 @@ object functions { * @since 1.5.0 */ def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr { - StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + StringTranslate(src.expr, litImpl(matchingString).expr, litImpl(replaceString).expr) } /** @@ -4424,7 +4558,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def to_char(e: Column, format: Column): Column = call_function("to_char", e, format) + def to_char(e: Column, format: Column): Column = withExpr { + call_functionImpl("to_char", e, format) + } /** * Convert `e` to a string based on the `format`. @@ -4450,7 +4586,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def to_varchar(e: Column, format: Column): Column = call_function("to_varchar", e, format) + def to_varchar(e: Column, format: Column): Column = withExpr { + call_functionImpl("to_varchar", e, format) + } /** * Convert string 'e' to a number based on the string format 'format'. 
@@ -4533,8 +4671,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column, len: Column): Column = - call_function("substr", str, pos, len) + def substr(str: Column, pos: Column, len: Column): Column = withExpr { + call_functionImpl("substr", str, pos, len) + } /** * Returns the substring of `str` that starts at `pos`, @@ -4543,8 +4682,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column): Column = - call_function("substr", str, pos) + def substr(str: Column, pos: Column): Column = withExpr { + call_functionImpl("substr", str, pos) + } /** * Extracts a part from a URL. @@ -4572,8 +4712,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def printf(format: Column, arguments: Column*): Column = - call_function("printf", (format +: arguments): _*) + def printf(format: Column, arguments: Column*): Column = withExpr { + call_functionImpl("printf", (format +: arguments): _*) + } /** * Decodes a `str` in 'application/x-www-form-urlencoded' format @@ -4604,8 +4745,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def position(substr: Column, str: Column, start: Column): Column = - call_function("position", substr, str, start) + def position(substr: Column, str: Column, start: Column): Column = withExpr { + call_functionImpl("position", substr, str, start) + } /** * Returns the position of the first occurrence of `substr` in `str` after position `1`. @@ -4614,8 +4756,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def position(substr: Column, str: Column): Column = - call_function("position", substr, str) + def position(substr: Column, str: Column): Column = withExpr { + call_functionImpl("position", substr, str) + } /** * Returns a boolean. The value is True if str ends with suffix. @@ -4625,7 +4768,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def endswith(str: Column, suffix: Column): Column = call_function("endswith", str, suffix) + def endswith(str: Column, suffix: Column): Column = withExpr { + call_functionImpl("endswith", str, suffix) + } /** * Returns a boolean. The value is True if str starts with prefix. @@ -4635,7 +4780,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def startswith(str: Column, prefix: Column): Column = call_function("startswith", str, prefix) + def startswith(str: Column, prefix: Column): Column = withExpr { + call_functionImpl("startswith", str, prefix) + } /** * Returns the ASCII character having the binary equivalent to `n`. @@ -4644,7 +4791,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char(n: Column): Column = call_function("char", n) + def char(n: Column): Column = withExpr { + call_functionImpl("char", n) + } /** * Removes the leading and trailing space characters from `str`. @@ -4708,7 +4857,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char_length(str: Column): Column = call_function("char_length", str) + def char_length(str: Column): Column = withExpr { + call_functionImpl("char_length", str) + } /** * Returns the character length of string data or number of bytes of binary data. @@ -4718,7 +4869,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def character_length(str: Column): Column = call_function("character_length", str) + def character_length(str: Column): Column = withExpr { + call_functionImpl("character_length", str) + } /** * Returns the ASCII character having the binary equivalent to `n`. 
@@ -4739,7 +4892,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def contains(left: Column, right: Column): Column = call_function("contains", left, right) + def contains(left: Column, right: Column): Column = withExpr { + call_functionImpl("contains", left, right) + } /** * Returns the `n`-th input, e.g., returns `input2` when `n` is 2. @@ -4827,15 +4982,18 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def lcase(str: Column): Column = call_function("lcase", str) - + def lcase(str: Column): Column = withExpr { + call_functionImpl("lcase", str) + } /** * Returns `str` with all characters changed to uppercase. * * @group string_funcs * @since 3.5.0 */ - def ucase(str: Column): Column = call_function("ucase", str) + def ucase(str: Column): Column = withExpr { + call_functionImpl("ucase", str) + } /** * Returns the leftmost `len`(`len` can be string type) characters from the string `str`, @@ -4873,7 +5031,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def add_months(startDate: Column, numMonths: Int): Column = add_months(startDate, lit(numMonths)) + def add_months(startDate: Column, numMonths: Int): Column = withExpr { + AddMonths(startDate.expr, litImpl(numMonths).expr) + } /** * Returns the date that is `numMonths` after `startDate`. @@ -4897,7 +5057,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def curdate(): Column = call_function("curdate") + def curdate(): Column = withExpr { + call_functionImpl("curdate") + } /** * Returns the current date at the start of query evaluation as a date column. @@ -4975,7 +5137,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_add(start: Column, days: Int): Column = date_add(start, lit(days)) + def date_add(start: Column, days: Int): Column = withExpr { + DateAdd(start.expr, litImpl(days).expr) + } /** * Returns the date that is `days` days after `start` @@ -4999,8 +5163,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def dateadd(start: Column, days: Column): Column = - call_function("dateadd", start, days) + def dateadd(start: Column, days: Column): Column = withExpr { + call_functionImpl("dateadd", start, days) + } /** * Returns the date that is `days` days before `start` @@ -5012,7 +5177,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_sub(start: Column, days: Int): Column = date_sub(start, lit(days)) + def date_sub(start: Column, days: Int): Column = withExpr { + DateSub(start.expr, litImpl(days).expr) + } /** * Returns the date that is `days` days before `start` @@ -5065,8 +5232,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_diff(end: Column, start: Column): Column = - call_function("date_diff", end, start) + def date_diff(end: Column, start: Column): Column = withExpr { + call_functionImpl("date_diff", end, start) + } /** * Create date from the number of `days` since 1970-01-01. @@ -5123,7 +5291,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def day(e: Column): Column = call_function("day", e) + def day(e: Column): Column = withExpr { + call_functionImpl("day", e) + } /** * Extracts the day of the year as an integer from a given date/timestamp/string. 
@@ -5150,7 +5320,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def extract(field: Column, source: Column): Column = call_function("extract", field, source) + def extract(field: Column, source: Column): Column = withExpr { + call_functionImpl("extract", field, source) + } /** * Extracts a part of the date/timestamp or interval source. @@ -5162,8 +5334,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_part(field: Column, source: Column): Column = - call_function("date_part", field, source) + def date_part(field: Column, source: Column): Column = withExpr { + call_functionImpl("date_part", field, source) + } /** * Extracts a part of the date/timestamp or interval source. @@ -5175,8 +5348,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def datepart(field: Column, source: Column): Column = - call_function("datepart", field, source) + def datepart(field: Column, source: Column): Column = withExpr { + call_functionImpl("datepart", field, source) + } /** * Returns the last day of the month which the given date belongs to. @@ -5249,7 +5423,7 @@ object functions { * @since 2.4.0 */ def months_between(end: Column, start: Column, roundOff: Boolean): Column = withExpr { - MonthsBetween(end.expr, start.expr, lit(roundOff).expr) + MonthsBetween(end.expr, start.expr, litImpl(roundOff).expr) } /** @@ -5267,7 +5441,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def next_day(date: Column, dayOfWeek: String): Column = next_day(date, lit(dayOfWeek)) + def next_day(date: Column, dayOfWeek: String): Column = withExpr { + NextDay(date.expr, litImpl(dayOfWeek).expr) + } /** * Returns the first date which is later than the value of the `date` column that is on the @@ -5387,7 +5563,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } + def unix_timestamp(s: Column, p: String): Column = withExpr { + UnixTimestamp(s.expr, Literal(p)) + } /** * Converts to a timestamp by casting rules to `TimestampType`. @@ -5429,8 +5607,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column, format: Column): Column = - call_function("try_to_timestamp", s, format) + def try_to_timestamp(s: Column, format: Column): Column = withExpr { + call_functionImpl("try_to_timestamp", s, format) + } /** * Parses the `s` to a timestamp. The function always returns null on an invalid @@ -5440,8 +5619,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column): Column = - call_function("try_to_timestamp", s) + def try_to_timestamp(s: Column): Column = withExpr { + call_functionImpl("try_to_timestamp", s) + } /** * Converts the column into `DateType` by casting rules to `DateType`. 
@@ -5717,7 +5897,9 @@ object functions { * @since 2.0.0 */ def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { - window(timeColumn, windowDuration, slideDuration, "0 second") + withExpr { + TimeWindow(timeColumn.expr, windowDuration, slideDuration, "0 second") + }.as("window") } /** @@ -5754,7 +5936,9 @@ object functions { * @since 2.0.0 */ def window(timeColumn: Column, windowDuration: String): Column = { - window(timeColumn, windowDuration, windowDuration, "0 second") + withExpr { + TimeWindow(timeColumn.expr, windowDuration, windowDuration, "0 second") + }.as("window") } /** @@ -5878,8 +6062,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ltz(timestamp: Column, format: Column): Column = - call_function("to_timestamp_ltz", timestamp, format) + def to_timestamp_ltz(timestamp: Column, format: Column): Column = withExpr { + call_functionImpl("to_timestamp_ltz", timestamp, format) + } /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5888,8 +6073,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ltz(timestamp: Column): Column = - call_function("to_timestamp_ltz", timestamp) + def to_timestamp_ltz(timestamp: Column): Column = withExpr { + call_functionImpl("to_timestamp_ltz", timestamp) + } /** * Parses the `timestamp_str` expression with the `format` expression @@ -5898,8 +6084,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ntz(timestamp: Column, format: Column): Column = - call_function("to_timestamp_ntz", timestamp, format) + def to_timestamp_ntz(timestamp: Column, format: Column): Column = withExpr { + call_functionImpl("to_timestamp_ntz", timestamp, format) + } /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5908,8 +6095,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ntz(timestamp: Column): Column = - call_function("to_timestamp_ntz", timestamp) + def to_timestamp_ntz(timestamp: Column): Column = withExpr { + call_functionImpl("to_timestamp_ntz", timestamp) + } /** * Returns the UNIX timestamp of the given time. 
@@ -5941,7 +6129,7 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, lit(value).expr) + ArrayContains(column.expr, litImpl(value).expr) } /** @@ -5952,7 +6140,7 @@ object functions { * @since 3.4.0 */ def array_append(column: Column, element: Any): Column = withExpr { - ArrayAppend(column.expr, lit(element).expr) + ArrayAppend(column.expr, litImpl(element).expr) } @@ -5978,8 +6166,9 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def slice(x: Column, start: Int, length: Int): Column = - slice(x, lit(start), lit(length)) + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, litImpl(start).expr, litImpl(length).expr) + } /** * Returns an array containing all the elements in `x` from index `start` (or starting from the @@ -6038,7 +6227,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, lit(value).expr) + ArrayPosition(column.expr, litImpl(value).expr) } /** @@ -6049,7 +6238,7 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, lit(value).expr) + ElementAt(column.expr, litImpl(value).expr) } /** @@ -6108,7 +6297,7 @@ object functions { * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, lit(element).expr) + ArrayRemove(column.expr, litImpl(element).expr) } /** @@ -6129,7 +6318,7 @@ object functions { * @since 3.5.0 */ def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) + ArrayPrepend(column.expr, litImpl(element).expr) } /** @@ -6347,8 +6536,15 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def aggregate(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column = - aggregate(expr, initialValue, merge, c => c) + def aggregate(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column = { + withExpr { + ArrayAggregate( + expr.expr, + initialValue.expr, + createLambda(merge), + createLambda(identity(_))) + } + } /** * Applies a binary operator to an initial state and all elements in the array, @@ -6372,7 +6568,13 @@ object functions { expr: Column, initialValue: Column, merge: (Column, Column) => Column, - finish: Column => Column): Column = aggregate(expr, initialValue, merge, finish) + finish: Column => Column): Column = withExpr { + ArrayAggregate( + expr.expr, + initialValue.expr, + createLambda(merge), + createLambda(finish)) + } /** * Applies a binary operator to an initial state and all elements in the array, @@ -6388,8 +6590,15 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def reduce(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column = - aggregate(expr, initialValue, merge, c => c) + def reduce(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column = { + withExpr { + ArrayAggregate( + expr.expr, + initialValue.expr, + createLambda(merge), + createLambda(identity(_))) + } + } /** * Merge two given arrays, element-wise, into a single array using a function. 
@@ -6548,7 +6757,7 @@ object functions { * @since 1.6.0 */ def get_json_object(e: Column, path: String): Column = withExpr { - GetJsonObject(e.expr, lit(path).expr) + GetJsonObject(e.expr, litImpl(path).expr) } /** @@ -6581,8 +6790,9 @@ object functions { * @since 2.1.0 */ // scalastyle:on line.size.limit - def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = - from_json(e, schema.asInstanceOf[DataType], options) + def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), options, e.expr) + } // scalastyle:off line.size.limit /** @@ -6625,8 +6835,11 @@ object functions { * @since 2.1.0 */ // scalastyle:on line.size.limit - def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = - from_json(e, schema, options.asScala.toMap) + def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = { + withExpr { + JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), options.asScala.toMap, e.expr) + } + } // scalastyle:off line.size.limit /** @@ -6648,7 +6861,9 @@ object functions { */ // scalastyle:on line.size.limit def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column = { - from_json(e, CharVarcharUtils.failIfHasCharVarchar(schema), options.asScala.toMap) + withExpr { + JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), options.asScala.toMap, e.expr) + } } /** @@ -6661,8 +6876,9 @@ object functions { * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: StructType): Column = - from_json(e, schema, Map.empty[String, String]) + def from_json(e: Column, schema: StructType): Column = withExpr { + JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), Map.empty[String, String], e.expr) + } /** * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, @@ -6675,8 +6891,9 @@ object functions { * @group collection_funcs * @since 2.2.0 */ - def from_json(e: Column, schema: DataType): Column = - from_json(e, schema, Map.empty[String, String]) + def from_json(e: Column, schema: DataType): Column = withExpr { + JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), Map.empty[String, String], e.expr) + } // scalastyle:off line.size.limit /** @@ -6698,7 +6915,11 @@ object functions { */ // scalastyle:on line.size.limit def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { - from_json(e, schema, options.asScala.toMap) + val dataType = parseTypeWithFallback( + schema, + DataType.fromJson, + fallbackParser = DataType.fromDDL) + withExpr { JsonToStructs(dataType, options.asScala.toMap, e.expr) } } // scalastyle:off line.size.limit @@ -6725,7 +6946,7 @@ object functions { schema, DataType.fromJson, fallbackParser = DataType.fromDDL) - from_json(e, dataType, options) + withExpr { JsonToStructs(dataType, options, e.expr) } } /** @@ -6739,8 +6960,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def from_json(e: Column, schema: Column): Column = { - from_json(e, schema, Map.empty[String, String].asJava) + def from_json(e: Column, schema: Column): Column = withExpr { + new JsonToStructs(e.expr, schema.expr, Map.empty[String, String]) } // scalastyle:off line.size.limit @@ -6774,7 +6995,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def schema_of_json(json: String): Column = schema_of_json(lit(json)) + def 
schema_of_json(json: String): Column = schema_of_json(litImpl(json)) /** * Parses a JSON string and infers its schema in DDL format. @@ -6874,8 +7095,9 @@ object functions { * @since 2.1.0 */ // scalastyle:on line.size.limit - def to_json(e: Column, options: java.util.Map[String, String]): Column = - to_json(e, options.asScala.toMap) + def to_json(e: Column, options: java.util.Map[String, String]): Column = withExpr { + StructsToJson(options.asScala.toMap, e.expr) + } /** * Converts a column containing a `StructType`, `ArrayType` or @@ -6887,8 +7109,9 @@ object functions { * @group collection_funcs * @since 2.1.0 */ - def to_json(e: Column): Column = - to_json(e, Map.empty[String, String]) + def to_json(e: Column): Column = withExpr { + StructsToJson(Map.empty, e.expr) + } /** * Masks the given string value. The function replaces characters with 'X' or 'x', and numbers @@ -7014,7 +7237,9 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def cardinality(e: Column): Column = call_function("cardinality", e) + def cardinality(e: Column): Column = withExpr { + call_functionImpl("cardinality", e) + } /** * Sorts the input array for the given column in ascending order, @@ -7036,7 +7261,9 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + def sort_array(e: Column, asc: Boolean): Column = withExpr { + SortArray(e.expr, litImpl(asc).expr) + } /** * Returns the minimum value in the array. NaN is greater than any non-NaN elements for @@ -7072,7 +7299,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def array_agg(e: Column): Column = call_function("array_agg", e) + def array_agg(e: Column): Column = withExpr { + call_functionImpl("array_agg", e) + } /** * Returns a random permutation of the given array. @@ -7138,7 +7367,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + def array_repeat(e: Column, count: Int): Column = array_repeat(e, litImpl(count)) /** * Returns true if the map contains the key. @@ -7146,7 +7375,7 @@ object functions { * @since 3.3.0 */ def map_contains_key(column: Column, key: Any): Column = withExpr { - ArrayContains(MapKeys(column.expr), lit(key).expr) + ArrayContains(MapKeys(column.expr), litImpl(key).expr) } /** @@ -7247,7 +7476,7 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv)) + def schema_of_csv(csv: String): Column = schema_of_csv(litImpl(csv)) /** * Parses a CSV string and infers its schema in DDL format. @@ -7367,8 +7596,9 @@ object functions { * @group collection_funcs * @since */ - def from_xml(e: Column, schema: StructType): Column = - from_xml(e, schema, Map.empty[String, String]) + def from_xml(e: Column, schema: StructType): Column = withExpr { + XmlToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), Map.empty, e.expr) + } /** * Parses a XML string and infers its schema in DDL format. @@ -7377,7 +7607,7 @@ object functions { * @group collection_funcs * @since 4.0.0 */ - def schema_of_xml(xml: String): Column = schema_of_xml(lit(xml)) + def schema_of_xml(xml: String): Column = schema_of_xml(litImpl(xml)) /** * Parses a XML string and infers its schema in DDL format. 
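The JSON and XML helpers above follow the same pattern: each `from_json`, `to_json` and `from_xml` overload now constructs `JsonToStructs`, `StructsToJson` or `XmlToStructs` itself, applying `CharVarcharUtils.failIfHasCharVarchar` or `parseTypeWithFallback` where that overload needs it, rather than funnelling through a single overload. Calling code is unaffected; for example (the DataFrame `df`, the column `payload` and the schema are hypothetical):

    // Parse a JSON string column against an explicit struct schema.
    import org.apache.spark.sql.types.{LongType, StringType, StructType}
    val schema = new StructType().add("id", LongType).add("name", StringType)
    df.select(from_json(col("payload"), schema))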
@@ -7471,8 +7701,9 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_number(x: Column, p: Column): Column = - call_function("xpath_number", x, p) + def xpath_number(x: Column, p: Column): Column = withExpr { + call_functionImpl("xpath_number", x, p) + } /** * Returns a float value, the value zero if no match is found, @@ -7771,9 +8002,10 @@ object functions { hours: Column, mins: Column, secs: Column, - timezone: Column): Column = - call_function("make_timestamp_ltz", + timezone: Column): Column = withExpr { + call_functionImpl("make_timestamp_ltz", years, months, days, hours, mins, secs, timezone) + } /** * Create the current timestamp with local time zone from years, months, days, hours, mins and @@ -7789,9 +8021,10 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = - call_function("make_timestamp_ltz", + secs: Column): Column = withExpr { + call_functionImpl("make_timestamp_ltz", years, months, days, hours, mins, secs) + } /** * Create local date-time from years, months, days, hours, mins, secs fields. If the @@ -7807,9 +8040,10 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = - call_function("make_timestamp_ntz", + secs: Column): Column = withExpr { + call_functionImpl("make_timestamp_ntz", years, months, days, hours, mins, secs) + } /** * Make year-month interval from years, months. @@ -7876,8 +8110,9 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def ifnull(col1: Column, col2: Column): Column = - call_function("ifnull", col1, col2) + def ifnull(col1: Column, col2: Column): Column = withExpr { + call_functionImpl("ifnull", col1, col2) + } /** * Returns true if `col` is not null, or false otherwise. @@ -8431,8 +8666,9 @@ object functions { */ @scala.annotation.varargs @deprecated("Use call_udf") - def callUDF(udfName: String, cols: Column*): Column = - call_function(Seq(udfName), cols: _*) + def callUDF(udfName: String, cols: Column*): Column = withExpr { + call_function(Seq(udfName), cols) + } /** * Call an user-defined function. @@ -8450,8 +8686,17 @@ object functions { * @since 3.2.0 */ @scala.annotation.varargs - def call_udf(udfName: String, cols: Column*): Column = - call_function(Seq(udfName), cols: _*) + def call_udf(udfName: String, cols: Column*): Column = withExpr { + call_function(Seq(udfName), cols) + } + + private def call_functionImpl(funcName: String, cols: Column*) = { + val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { + new SparkSqlParser() + } + val nameParts = parser.parseMultipartIdentifier(funcName) + call_function(nameParts, cols) + } /** * Call a SQL function. 
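From the caller's point of view, callUDF, call_udf and call_function keep the same observable behaviour after this refactor; the new private call_functionImpl only centralises the step of running the function name through the session's SQL parser (so multi-part and back-quoted names keep working), while each public entry point now captures the caller's frame via withExpr. A small usage sketch follows; the object name, the local SparkSession and the UDF name plus_one are assumptions made for illustration, and it presumes Spark 3.5+ on the classpath.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._

    object CallFunctionUsageSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical local session, for illustration only.
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("call_function sketch")
          .getOrCreate()
        import spark.implicits._

        // Register a UDF under an assumed name.
        spark.udf.register("plus_one", (i: Int) => i + 1)

        val df = Seq(1, 2, 3).toDF("i")

        // Both entry points resolve to an UnresolvedFunction over the same name parts.
        df.select(call_udf("plus_one", $"i")).show()
        df.select(call_function("plus_one", $"i")).show()

        // call_function additionally parses the name with the session's SQL parser, so a
        // multi-part name such as "my_db.plus_one" would also be accepted (left as a
        // comment here because it requires that function to exist in a catalog):
        //   df.select(call_function("my_db.plus_one", $"i"))

        spark.stop()
      }
    }
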
@@ -8462,15 +8707,11 @@ object functions { * @since 3.5.0 */ @scala.annotation.varargs - def call_function(funcName: String, cols: Column*): Column = { - val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser() - } - val nameParts = parser.parseMultipartIdentifier(funcName) - call_function(nameParts, cols: _*) + def call_function(funcName: String, cols: Column*): Column = withExpr { + call_functionImpl(funcName, cols: _*) } - private def call_function(nameParts: Seq[String], cols: Column*): Column = withExpr { + private def call_function(nameParts: Seq[String], cols: Seq[Column]) = { UnresolvedFunction(nameParts, cols.map(_.expr), false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 1794ac513749f..a1c3ef62f4345 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,6 +18,8 @@ package org.apache.spark import org.apache.spark.annotation.{DeveloperApi, Unstable} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.execution.SparkStrategy /** @@ -73,4 +75,24 @@ package object sql { * with rebasing. */ private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96" + + private[sql] def withOrigin[T](f: => T, framesToDrop: Int = 0): T = { + CurrentOrigin.withOrigin(Origin.fromCurrentStackTrace(framesToDrop + 1))(f) + } + + private[sql] def withExpr(f: => Expression, framesToDrop: Int = 0) = { + withOrigin(Column(f), framesToDrop + 1) + } + + private[sql] def litImpl(literal: Any): Column = literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => + // This is different from `typedlit`. `typedlit` calls `Literal.create` to use + // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this + // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, + // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. + // This is significantly better when there are many threads calling `lit` concurrently. 
+ Column(Literal(literal)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d78771a8f19bc..8ffb441ffdfbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -632,7 +633,9 @@ class DataFrameAggregateSuite extends QueryTest "functionName" -> "`collect_set`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"collect_set(b)\"" - ) + ), + context = + ExpectedContext(code = "collect_set", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -705,7 +708,8 @@ class DataFrameAggregateSuite extends QueryTest testData.groupBy(sum($"key")).count() }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "sum(key)") + parameters = Map("sqlExpr" -> "sum(key)"), + context = ExpectedContext(code = "sum", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8ca14385e597f..8401d0b2c9c1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -25,6 +25,7 @@ import java.sql.{Date, Timestamp} import scala.util.Random import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException} +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression} @@ -174,7 +175,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"k\"", "inputType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "map_from_arrays", callSitePattern = getCurrentClassCallSitePattern)) ) val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") @@ -761,7 +764,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("The given function only supports array input") { val df = Seq(1, 2, 3).toDF("a") - checkErrorMatchPVals( + checkError( exception = intercept[AnalysisException] { df.select(array_sort(col("a"), (x, y) => x - y)) }, @@ -772,7 +775,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + matchPVals = true, + queryContext = Array( + ExpectedContext(code = "array_sort", callSitePattern = getCurrentClassCallSitePattern)) + ) } test("sort_array/array_sort functions") { @@ -1304,7 +1311,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", "dataType" -> 
"(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext(code = "map_concat", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -1332,7 +1341,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext(code = "map_concat", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1401,7 +1412,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"a\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of pair \"STRUCT\"" - ) + ), + context = + ExpectedContext(code = "map_from_entries", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1438,7 +1451,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" - ) + ), + context = + ExpectedContext(code = "array_contains", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2347,7 +2362,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"ARRAY\"")) + "rightType" -> "\"ARRAY\""), + context = + ExpectedContext(code = "array_union", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { @@ -2378,7 +2395,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"VOID\"") + "rightType" -> "\"VOID\""), + context = + ExpectedContext(code = "array_union", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2409,7 +2428,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY>\"", - "rightType" -> "\"ARRAY\"") + "rightType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "array_union", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2646,7 +2667,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"arr\"", "inputType" -> "\"ARRAY\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + context = + ExpectedContext(code = "flatten", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2659,7 +2682,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2672,7 +2697,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"s\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( 
exception = intercept[AnalysisException] { @@ -2781,7 +2808,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"b\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + context = + ExpectedContext(code = "array_repeat", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2794,7 +2823,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"1\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_repeat", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3122,7 +3153,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3150,7 +3183,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3178,7 +3213,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"VOID\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3206,7 +3243,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3275,7 +3314,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = + ExpectedContext(code = "array_intersect", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3304,7 +3345,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3333,7 +3376,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3749,7 +3794,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_filter", callSitePattern = 
getCurrentClassCallSitePattern))) checkError( exception = @@ -3932,7 +3979,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4111,7 +4160,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "exists", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4303,7 +4354,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "forall", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4342,7 +4395,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(forall(col("a"), x => x))), errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`")) + parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), + queryContext = Array( + ExpectedContext(code = "col", callSitePattern = getCurrentClassCallSitePattern))) } test("aggregate function - array for primitive type not containing null") { @@ -4580,7 +4635,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "aggregate", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit // scalastyle:off line.size.limit @@ -4718,7 +4775,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "functionName" -> "`map_zip_with`", "leftType" -> "\"INT\"", - "rightType" -> "\"MAP\"")) + "rightType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4748,7 +4807,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(i, mis, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4778,7 +4839,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, i, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "2", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", 
"requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -5234,7 +5297,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"x\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext( + code = "transform_values", + callSitePattern = getCurrentClassCallSitePattern))) } testInvalidLambdaFunctions() @@ -5374,7 +5441,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "zip_with", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -5630,7 +5699,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map(m, 1)\"", "keyType" -> "\"MAP\"" - ) + ), + context = + ExpectedContext(code = "map", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( df.select(map(map_entries($"m"), lit(1))), @@ -5752,7 +5823,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + context = + ExpectedContext(code = "array_compact", callSitePattern = getCurrentClassCallSitePattern)) } test("array_append -> Unit Test cases for the function ") { @@ -5771,7 +5844,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "dataType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"INT\"", - "sqlExpr" -> "\"array_append(a, b)\"") + "sqlExpr" -> "\"array_append(a, b)\""), + context = + ExpectedContext(code = "array_append", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer(df1.selectExpr("array_append(a, 3)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index a57e927ba8427..f421880e4e29d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalatest.matchers.must.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Lag, Literal, NonFoldableLiteral} import org.apache.spark.sql.catalyst.optimizer.TransposeWindow import org.apache.spark.sql.catalyst.plans.logical.{Window => LogicalWindow} @@ -412,7 +413,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid`", - "proposal" -> "`value`, `key`")) + "proposal" -> "`value`, `key`"), + context = ExpectedContext(code = "count", callSitePattern = getCurrentClassCallSitePattern)) } test("numerical aggregate functions on string column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e05b545f235ba..dd5a568f8c811 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -2561,6 +2561,23 @@ class DatasetSuite extends QueryTest checkDataset(ds.filter(f(col("_1"))), Tuple1(ValueClass(2))) } + + test("SPARK-45022: exact DatasetQueryContext call site") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df = Seq(1).toDS + var callSitePattern: String = null + checkError( + exception = intercept[AnalysisException] { + callSitePattern = QueryTest.getNextLineCallSitePattern() + val c = col("a") + df.select(c) + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), + context = ExpectedContext(code = "col", callSitePattern = callSitePattern)) + } + } } class DatasetLargeResultCollectingSuite extends QueryTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index abec582d43a30..669b5a24d34ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -293,7 +294,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"array()\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"ARRAY\"") + "requiredType" -> "\"ARRAY\""), + context = ExpectedContext(code = "inline", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -331,7 +333,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(b))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext(code = "array", callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct('a), struct('b.alias("a"))))), @@ -346,7 +349,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(2))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext(code = "array", callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct('a), struct(lit(2).alias("a"))))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index c2c333a998b43..251dd97b8d417 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.TimeZone +import java.util.regex.Pattern import scala.collection.JavaConverters._ @@ -424,6 +425,17 @@ object QueryTest extends Assertions { case None => } } + + def getCurrentClassCallSitePattern: String = { + val cs = 
Thread.currentThread().getStackTrace()(2) + s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)" + } + + def getNextLineCallSitePattern(lines: Int = 1): String = { + val cs = Thread.currentThread().getStackTrace()(2) + Pattern.quote( + s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})") + } } class QueryTestSuite extends QueryTest with test.SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 8e9be5dcdced5..468fb40e62b53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.{SPARK_DOC_ROOT, SparkRuntimeException} +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.execution.FormattedMode import org.apache.spark.sql.functions._ @@ -879,7 +880,9 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "funcName" -> s"`$funcName`", "paramName" -> "`format`", - "paramType" -> "\"STRING\"")) + "paramType" -> "\"STRING\""), + context = + ExpectedContext(code = funcName, callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() @@ -888,7 +891,9 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "parameter" -> "`format`", "functionName" -> s"`$funcName`", - "invalidFormat" -> "'invalid_format'")) + "invalidFormat" -> "'invalid_format'"), + context = + ExpectedContext(code = funcName, callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { sql(s"select $funcName('a', 'b', 'c')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 7f938deaaa645..14e236acd73fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.errors import org.apache.spark.SPARK_DOC_ROOT import org.apache.spark.sql.{AnalysisException, ClassData, IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.api.java.{UDF1, UDF2, UDF23Test} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -696,7 +697,9 @@ class QueryCompilationErrorsSuite Seq("""{"a":1}""").toDF("a").select(from_json($"a", IntegerType)).collect() }, errorClass = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", - parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\"")) + parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\""), + context = + ExpectedContext(code = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("WRONG_NUM_ARGS.WITHOUT_SUGGESTION: wrong args of CAST(parameter types contains DataType)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index ee28a90aed9af..779f6bfc4a92b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -18,7 +18,10 @@ package org.apache.spark.sql.errors import org.apache.spark._ import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.QueryTest.getCurrentClassCallSitePattern import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime} +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.ByteType @@ -53,6 +56,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "6/0", start = 7, stop = 9)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5) / lit(0)).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext(code = "div", callSitePattern = getCurrentClassCallSitePattern)) } test("INTERVAL_DIVIDED_BY_ZERO: interval divided by zero") { @@ -92,6 +104,19 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('66666666666666.666' AS DECIMAL(8, 1))", start = 7, stop = 49)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit("66666666666666.666").cast("DECIMAL(8, 1)")).collect() + }, + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + sqlState = "22003", + parameters = Map( + "value" -> "66666666666666.666", + "precision" -> "8", + "scale" -> "1", + "config" -> ansiConf), + context = ExpectedContext(code = "cast", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX: get element from array") { @@ -102,6 +127,14 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest errorClass = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "array(1, 2, 3, 4, 5)[8]", start = 7, stop = 29)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(lit(Array(1, 2, 3, 4, 5))(8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = ExpectedContext(code = "apply", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX_IN_ELEMENT_AT: element_at from array") { @@ -115,6 +148,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "element_at(array(1, 2, 3, 4, 5), 8)", start = 7, stop = 41)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = + ExpectedContext(code = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") { @@ -129,6 +171,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest 
start = 7, stop = 41) ) + + checkError( + exception = intercept[SparkRuntimeException]( + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 0)).collect() + ), + errorClass = "INVALID_INDEX_OF_ZERO", + parameters = Map.empty, + context = + ExpectedContext(code = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("CAST_INVALID_INPUT: cast string to double") { @@ -146,6 +197,18 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('111111111111xe23' AS DOUBLE)", start = 7, stop = 40)) + + checkError( + exception = intercept[SparkNumberFormatException] { + OneRowRelation().select(lit("111111111111xe23").cast("DOUBLE")).collect() + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'111111111111xe23'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"DOUBLE\"", + "ansiConfig" -> ansiConf), + context = ExpectedContext(code = "cast", callSitePattern = getCurrentClassCallSitePattern)) } test("CANNOT_PARSE_TIMESTAMP: parse string to timestamp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 0ab292ee6c3ee..28e6492aa66f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -1207,7 +1207,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v5", fragment = "1/0", @@ -1226,7 +1226,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v1", fragment = "1/0",