[SPARK-39446][MLLIB] Add relevance score for nDCG evaluation #36843

Closed
wants to merge 11 commits
@@ -32,11 +32,23 @@ import org.apache.spark.rdd.RDD
*
* Java users should use `RankingMetrics$.of` to create a [[RankingMetrics]] instance.
*
* @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
* @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs
*                            or (predicted ranking, ground truth set,
*                            relevance value of ground truth set) triplets.
*                            Since 3.4.0, it supports NDCG evaluation with relevance values.
*/
@Since("1.2.0")
Member

Oh, we need to update the @Since tags. I think you can somehow add a @Since annotation to the default constructor args? as it's since 3.4.0. I'm not sure exactly where it goes. The old constructor can remain since 1.2.0. If that doesn't work, maybe we can leave this constructor and add the new one as a new def this(...) since 3.4.0?
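
A minimal sketch of the annotation placement being discussed, mirroring what the diff below ends up doing (class body trimmed, names abbreviated for illustration):

```scala
import scala.reflect.ClassTag
import org.apache.spark.annotation.Since
import org.apache.spark.rdd.RDD

// @Since can sit between the class name and the primary-constructor
// parameter list; the auxiliary constructor is annotated like a member.
class Metrics[T: ClassTag] @Since("3.4.0") (
    predictionAndLabels: RDD[(Array[T], Array[T], Array[Double])]) {

  @Since("1.2.0")
  def this(pairs: => RDD[(Array[T], Array[T])]) =
    this(pairs.map { case (pred, lab) => (pred, lab, Array.empty[Double]) })
}
```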

Contributor Author

I made changes and updated the docs.

Contributor

I think that we should refer to BinaryClassificationMetrics and MulticlassMetrics, in which RDD[_ <: Product] was used as the input.

Contributor

Since end users now mainly use .ml, is there any plan to expose this function in .ml as well?

Contributor Author

Sorry in advance for my poor understanding. I have some questions.

  • What is .ml? Do you mean org.apache.spark.ml?
  • Is RDD[_ <: Product] used to make the MulticlassMetrics class available to .ml?

Member

For the first one - yeah, this change is in the 'older' .mllib package. I don't think there is an equivalent for it in the DataFrame-based .ml packages, so maybe we can ignore that here. But if nDCG is supported in the .ml package somewhere and I forgot it, it would be good to add it there too.

The declaration suggested here might actually work for both input types without a separate constructor. Try it, maybe? If it works, yes, that is simpler, and it lets this API support even more inputs.
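
A rough sketch of that single-constructor approach, modeled on how BinaryClassificationMetrics accepts RDD[_ <: Product] (hypothetical class name; the real change landed in the follow-up PR referenced below):

```scala
import scala.reflect.ClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD

// One constructor accepts RDD[_ <: Product]; pairs and triplets are
// normalized to (pred, labels, relevance), with empty relevance for pairs.
class RankingMetricsSketch[T: ClassTag](raw: RDD[_ <: Product])
    extends Logging with Serializable {

  private val predictionAndLabels: RDD[(Array[T], Array[T], Array[Double])] =
    raw.map {
      case (pred: Array[T], lab: Array[T], rel: Array[Double]) =>
        (pred, lab, rel)
      case (pred: Array[T], lab: Array[T]) =>
        (pred, lab, Array.empty[Double])
      case other =>
        throw new IllegalArgumentException(
          s"Expected (pred, labels) pairs or (pred, labels, relevance) triplets, got $other")
    }
}
```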

Contributor Author

Thank you for your explanation!

Could you review this?
#36920

class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
    extends Logging with Serializable {
class RankingMetrics[T: ClassTag] @Since("3.4.0") (
    predictionAndLabels: RDD[(Array[T], Array[T], Array[Double])])
    extends Logging
    with Serializable {

  @Since("1.2.0")
  def this(predictionAndLabelsWithoutRelevance: => RDD[(Array[T], Array[T])]) = {
    this(predictionAndLabelsWithoutRelevance.map {
      case (pred, lab) => (pred, lab, Array.empty[Double])
    })
  }
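
Illustrative usage of the two constructors (assumes a live SparkContext `sc`; the values are arbitrary):

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics

// Original pair form: relevance is implicitly binary.
val binaryMetrics = new RankingMetrics(
  sc.parallelize(Seq((Array(1, 2, 3), Array(1, 3)))))

// New triplet form: graded relevance per ground-truth item.
val gradedMetrics = new RankingMetrics(
  sc.parallelize(Seq((Array(1, 2, 3), Array(1, 3), Array(3.0, 1.0)))))

println(gradedMetrics.ndcgAt(2))
```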

/**
* Compute the average precision of all the queries, truncated at ranking position k.
@@ -58,7 +70,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
@Since("1.2.0")
def precisionAt(k: Int): Double = {
  require(k > 0, "ranking position k should be positive")
  predictionAndLabels.map { case (pred, lab) =>
  predictionAndLabels.map { case (pred, lab, _) =>
    countRelevantItemRatio(pred, lab, k, k)
  }.mean()
}
@@ -70,7 +82,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
*/
@Since("1.2.0")
lazy val meanAveragePrecision: Double = {
predictionAndLabels.map { case (pred, lab) =>
predictionAndLabels.map { case (pred, lab, _) =>
val labSet = lab.toSet
val k = math.max(pred.length, labSet.size)
averagePrecision(pred, labSet, k)
@@ -87,7 +99,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
@Since("3.0.0")
def meanAveragePrecisionAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
predictionAndLabels.map { case (pred, lab) =>
predictionAndLabels.map { case (pred, lab, _) =>
averagePrecision(pred, lab.toSet, k)
}.mean()
}
@@ -127,7 +139,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
* The discounted cumulative gain at position k is computed as:
* sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
* and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
* implementation, the relevance value is binary.
* implementation, the relevance value is treated as binary if the relevance array is empty.

* If a query has an empty ground truth set, zero will be used as ndcg together with
* a log warning.
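
Restated in standard notation (with $rel_i$ the relevance of the item at rank $i$, and IDCG@k the DCG@k of the ground truth set):

```latex
\mathrm{DCG}@k = \sum_{i=1}^{k} \frac{2^{rel_i} - 1}{\log(i + 1)},
\qquad
\mathrm{NDCG}@k = \frac{\mathrm{DCG}@k}{\mathrm{IDCG}@k}
```

A change of log base scales DCG and IDCG by the same constant, so it does not affect NDCG; the implementation below uses log(i + 2) only because its loop index is 0-based.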
@@ -142,8 +154,15 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
@Since("1.2.0")
def ndcgAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
predictionAndLabels.map { case (pred, lab) =>
predictionAndLabels.map { case (pred, lab, rel) =>
val useBinary = rel.isEmpty
val labSet = lab.toSet
val relMap = lab.zip(rel).toMap
if (useBinary && lab.size != rel.size) {
logWarning(
"# of ground truth set and # of relevance value set should be equal, " +
"check input data")
}

if (labSet.nonEmpty) {
val labSetSize = labSet.size
@@ -152,18 +171,32 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
      var dcg = 0.0
      var i = 0
      while (i < n) {
        // Base of the log doesn't matter for calculating NDCG,
        // if the relevance value is binary.
        val gain = 1.0 / math.log(i + 2)
        if (i < pred.length && labSet.contains(pred(i))) {
          dcg += gain
        }
        if (i < labSetSize) {
          maxDcg += gain
        }
        if (useBinary) {
          // Base of the log doesn't matter for calculating NDCG,
          // if the relevance value is binary.
          val gain = 1.0 / math.log(i + 2)
          if (i < pred.length && labSet.contains(pred(i))) {
            dcg += gain
          }
          if (i < labSetSize) {
            maxDcg += gain
          }
        } else {
          if (i < pred.length) {
            dcg += (math.pow(2.0, relMap.getOrElse(pred(i), 0.0)) - 1) / math.log(i + 2)
          }
          if (i < labSetSize) {
            maxDcg += (math.pow(2.0, relMap.getOrElse(lab(i), 0.0)) - 1) / math.log(i + 2)
          }
        }
        i += 1
      }
      dcg / maxDcg
      if (maxDcg == 0.0) {
        logWarning("Maximum of relevance of ground truth set is zero, check input data")
        0.0
      } else {
        dcg / maxDcg
      }
    } else {
      logWarning("Empty ground truth set, check input data")
      0.0
@@ -191,7 +224,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
@Since("3.0.0")
def recallAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
predictionAndLabels.map { case (pred, lab) =>
predictionAndLabels.map { case (pred, lab, _) =>
countRelevantItemRatio(pred, lab, k, lab.toSet.size)
}.mean()
}
@@ -207,10 +240,11 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
* @param denominator the denominator of ratio
* @return relevant item ratio at the first k ranking positions
*/
private def countRelevantItemRatio(pred: Array[T],
    lab: Array[T],
    k: Int,
    denominator: Int): Double = {
private def countRelevantItemRatio(
    pred: Array[T],
    lab: Array[T],
    k: Int,
    denominator: Int): Double = {
  val labSet = lab.toSet
  if (labSet.nonEmpty) {
    val n = math.min(pred.length, k)
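A standalone sketch of the ratio logic shared by precisionAt and recallAt (hypothetical helper name; the empty-ground-truth branch is omitted here):

```scala
// Count how many of the first min(k, pred.length) predictions are relevant,
// then divide by the caller's denominator: k for precision@k,
// |ground truth| for recall@k.
def relevantItemRatio[T](pred: Array[T], lab: Set[T], k: Int, denominator: Int): Double = {
  val n = math.min(pred.length, k)
  val hits = pred.take(n).count(lab.contains)
  hits.toDouble / denominator
}

// Query 2 of the suite below: the top 4 of [4, 1, 5, 6, ...] contains one
// relevant item (1), so precision@4 = 1/4 = 0.25.
val p4 = relevantItemRatio(Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Set(1, 2, 3), 4, 4)
```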
@@ -28,48 +28,89 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
      Seq(
        (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
        (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)),
        (Array(1, 2, 3, 4, 5), Array.empty[Int])
      ), 2)
    val eps = 1.0E-5
        (Array(1, 2, 3, 4, 5), Array.empty[Int])),
      2)
    val eps = 1.0e-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val map = metrics.meanAveragePrecision

    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)
    assert(metrics.precisionAt(1) ~== 1.0 / 3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0 / 3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0 / 3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75 / 3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8 / 3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8 / 3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0 / 45 absTol eps)

    assert(map ~== 0.355026 absTol eps)

    assert(metrics.meanAveragePrecisionAt(1) ~== 0.333334 absTol eps)
    assert(metrics.meanAveragePrecisionAt(2) ~== 0.25 absTol eps)
    assert(metrics.meanAveragePrecisionAt(3) ~== 0.24074 absTol eps)

    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
    assert(metrics.ndcgAt(3) ~== 1.0 / 3 absTol eps)
    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)

    assert(metrics.recallAt(1) ~== 1.0/15 absTol eps)
    assert(metrics.recallAt(2) ~== 8.0/45 absTol eps)
    assert(metrics.recallAt(3) ~== 11.0/45 absTol eps)
    assert(metrics.recallAt(4) ~== 11.0/45 absTol eps)
    assert(metrics.recallAt(5) ~== 16.0/45 absTol eps)
    assert(metrics.recallAt(10) ~== 2.0/3 absTol eps)
    assert(metrics.recallAt(15) ~== 2.0/3 absTol eps)
    assert(metrics.recallAt(1) ~== 1.0 / 15 absTol eps)
    assert(metrics.recallAt(2) ~== 8.0 / 45 absTol eps)
    assert(metrics.recallAt(3) ~== 11.0 / 45 absTol eps)
    assert(metrics.recallAt(4) ~== 11.0 / 45 absTol eps)
    assert(metrics.recallAt(5) ~== 16.0 / 45 absTol eps)
    assert(metrics.recallAt(10) ~== 2.0 / 3 absTol eps)
    assert(metrics.recallAt(15) ~== 2.0 / 3 absTol eps)
  }

test("MAP, NDCG, Recall with few predictions (SPARK-14886)") {
test("Ranking metrics: NDCG with relevance") {
val predictionAndLabels = sc.parallelize(
Seq(
(Array(1, 6, 2), Array(1, 2, 3, 4, 5)),
(Array.empty[Int], Array(1, 2, 3))
), 2)
val eps = 1.0E-5
(
Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5),
Array(1, 2, 3, 4, 5),
Array(3.0, 2.0, 1.0, 1.0, 1.0)),
(Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3), Array(2.0, 0.0, 0.0)),
(Array(1, 2, 3, 4, 5), Array.empty[Int], Array.empty[Double])),
2)
val eps = 1.0e-5

    val metrics = new RankingMetrics(predictionAndLabels)
    val map = metrics.meanAveragePrecision

    assert(metrics.precisionAt(1) ~== 1.0 / 3 absTol eps)
    assert(metrics.precisionAt(2) ~== 1.0 / 3 absTol eps)
    assert(metrics.precisionAt(3) ~== 1.0 / 3 absTol eps)
    assert(metrics.precisionAt(4) ~== 0.75 / 3 absTol eps)
    assert(metrics.precisionAt(5) ~== 0.8 / 3 absTol eps)
    assert(metrics.precisionAt(10) ~== 0.8 / 3 absTol eps)
    assert(metrics.precisionAt(15) ~== 8.0 / 45 absTol eps)

    assert(map ~== 0.355026 absTol eps)

    assert(metrics.meanAveragePrecisionAt(1) ~== 0.333334 absTol eps)
    assert(metrics.meanAveragePrecisionAt(2) ~== 0.25 absTol eps)
    assert(metrics.meanAveragePrecisionAt(3) ~== 0.24074 absTol eps)

    assert(metrics.ndcgAt(3) ~== 0.511959 absTol eps)
    assert(metrics.ndcgAt(5) ~== 0.487806 absTol eps)
    assert(metrics.ndcgAt(10) ~== 0.518700 absTol eps)
    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)

    assert(metrics.recallAt(1) ~== 1.0 / 15 absTol eps)
    assert(metrics.recallAt(2) ~== 8.0 / 45 absTol eps)
    assert(metrics.recallAt(3) ~== 11.0 / 45 absTol eps)
    assert(metrics.recallAt(4) ~== 11.0 / 45 absTol eps)
    assert(metrics.recallAt(5) ~== 16.0 / 45 absTol eps)
    assert(metrics.recallAt(10) ~== 2.0 / 3 absTol eps)
    assert(metrics.recallAt(15) ~== 2.0 / 3 absTol eps)
  }
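
As a quick hand check of the graded ndcgAt(3) value above (natural log; the code's 0-based index means ranks 1..3 divide by ln 2, ln 3, ln 4): for query 1, DCG@3 = (2^3 - 1)/ln 2 + 0 + (2^2 - 1)/ln 4 ≈ 12.263 and IDCG@3 = 7/ln 2 + 3/ln 3 + 1/ln 4 ≈ 13.551, giving ≈ 0.90495; for query 2, DCG@3 = 3/ln 3 ≈ 2.731 and IDCG@3 = 3/ln 2 ≈ 4.328, giving ≈ 0.63093; query 3 has an empty ground truth set and contributes 0. The mean (0.90495 + 0.63093 + 0) / 3 ≈ 0.51196 matches the asserted 0.511959.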

test("MAP, NDCG, Recall with few predictions (SPARK-14886)") {
val predictionAndLabels = sc.parallelize(
Seq((Array(1, 6, 2), Array(1, 2, 3, 4, 5)), (Array.empty[Int], Array(1, 2, 3))),
2)
val eps = 1.0e-5

    val metrics = new RankingMetrics(predictionAndLabels)
    assert(metrics.precisionAt(1) ~== 0.5 absTol eps)