From 90527f560462cc2d693176bd961b02767e460e06 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 May 2015 14:41:16 -0700 Subject: [PATCH] [SPARK-7390] [SQL] Only merge other CovarianceCounter when its count is greater than zero JIRA: https://issues.apache.org/jira/browse/SPARK-7390 Also fix a minor typo. Author: Liang-Chi Hsieh Closes #5931 from viirya/fix_covariancecounter and squashes the following commits: 352eda6 [Liang-Chi Hsieh] Only merge other CovarianceCounter when its count is greater than zero. --- .../sql/execution/stat/StatFunctions.scala | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 386ac969f1e7d..71b7f6c2a6756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -38,7 +38,7 @@ private[sql] object StatFunctions extends Logging { var yAvg = 0.0 // the mean of all examples seen so far in col2 var Ck = 0.0 // the co-moment after k examples var MkX = 0.0 // sum of squares of differences from the (current) mean for col1 - var MkY = 0.0 // sum of squares of differences from the (current) mean for col1 + var MkY = 0.0 // sum of squares of differences from the (current) mean for col2 var count = 0L // count of observed examples // add an example to the calculation def add(x: Double, y: Double): this.type = { @@ -55,15 +55,17 @@ private[sql] object StatFunctions extends Logging { // merge counters from other partitions. Formula can be found at: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance def merge(other: CovarianceCounter): this.type = { - val totalCount = count + other.count - val deltaX = xAvg - other.xAvg - val deltaY = yAvg - other.yAvg - Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count - xAvg = (xAvg * count + other.xAvg * other.count) / totalCount - yAvg = (yAvg * count + other.yAvg * other.count) / totalCount - MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count - MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count - count = totalCount + if (other.count > 0) { + val totalCount = count + other.count + val deltaX = xAvg - other.xAvg + val deltaY = yAvg - other.yAvg + Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count + xAvg = (xAvg * count + other.xAvg * other.count) / totalCount + yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count + MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count + count = totalCount + } this } // return the sample covariance for the observed examples