Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sum of NULL values should return NULL #76

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
private val sum = MutableLiteral(null, calcType)

private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))

override def update(input: Row): Unit = {
sum.update(addFunction, input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ case class GeneratedAggregate(
Add(
Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)
) :: currentSum :: zero :: Nil)
) :: currentSum :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}

import org.apache.spark.sql.types._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.TestData.NullInts
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.errors.DialectException
Expand All @@ -27,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation}
import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._

/** Single-field case class holding a [[Nested2]] value in `f1`. */
case class Nested1(f1: Nested2)
Expand Down Expand Up @@ -59,7 +61,9 @@ class MyDialect extends DefaultParserDialect
* Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
* valid, but Hive currently cannot execute it.
*/
class SQLQuerySuite extends QueryTest {
class SQLQuerySuite extends QueryTest with SQLTestUtils {
override val sqlContext: SQLContext = TestHive

test("SPARK-6835: udtf in lateral view") {
val df = Seq((1, 1)).toDF("c1", "c2")
df.registerTempTable("table1")
Expand Down Expand Up @@ -946,4 +950,28 @@ class SQLQuerySuite extends QueryTest {

checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}

test("SPARK-8828 sum should return null if all input values are null") {
  // Four rows, each built from NullInts(null), registered as temp table "allNulls".
  val nullRows = List.fill(4)(NullInts(null))
  TestHive.sparkContext.parallelize(nullRows).toDF().registerTempTable("allNulls")

  // The same query is checked with code generation enabled first, then disabled;
  // in both cases sum and avg over the all-null column are expected to be null.
  List("true", "false").foreach { codegenFlag =>
    withSQLConf(SQLConf.CODEGEN_ENABLED -> codegenFlag) {
      checkAnswer(
        sql("select sum(a), avg(a) from allNulls"),
        Seq(Row(null, null))
      )
    }
  }
}

}