Skip to content

Commit

Permalink
Rework the MAP metric. (#8931)
Browse files Browse the repository at this point in the history
- The new implementation is more strict as only binary labels are accepted. The previous implementation converts values greater than 1 to 1.
- Deterministic GPU. (no atomic add).
- Fix top-k handling.
- Precise definition of MAP. (There are other variants on how to handle top-k).
- Refactor GPU ranking tests.
  • Loading branch information
trivialfis committed Mar 22, 2023
1 parent b240f05 commit 5891f75
Show file tree
Hide file tree
Showing 18 changed files with 474 additions and 339 deletions.
13 changes: 11 additions & 2 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv

- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
- ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation.
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.

The `average precision` is defined as:

.. math::
AP@l = \frac{1}{\min(l, N)}\sum^l_{k=1}P@k \cdot I_{(k)}

where :math:`l` is the cut-off position (the ``n`` in ``map@n``) and :math:`I_{(k)}` is an indicator function that equals :math:`1` when the document at position :math:`k` is relevant and :math:`0` otherwise. :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries.

- ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
- ``poisson-nloglik``: negative log-likelihood for Poisson regression
- ``gamma-nloglik``: negative log-likelihood for gamma regression
- ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression
Expand Down
9 changes: 5 additions & 4 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from platform import system
from typing import (
Any,
Expand Down Expand Up @@ -443,7 +444,7 @@ def get_mq2008(
from sklearn.datasets import load_svmlight_files

src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
target = dpath + "/MQ2008.zip"
target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
if not os.path.exists(target):
request.urlretrieve(url=src, filename=target)

Expand All @@ -462,9 +463,9 @@ def get_mq2008(
qid_valid,
) = load_svmlight_files(
(
dpath + "MQ2008/Fold1/train.txt",
dpath + "MQ2008/Fold1/test.txt",
dpath + "MQ2008/Fold1/vali.txt",
Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
),
query_id=True,
zero_based=False,
Expand Down
7 changes: 6 additions & 1 deletion python-package/xgboost/testing/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
def neg_mse(*args: Any, **kwargs: Any) -> float:
return -float(mean_squared_error(*args, **kwargs))

ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
ranker = xgb.XGBRanker(
n_estimators=3,
eval_metric=neg_mse,
tree_method=tree_method,
disable_default_eval_metric=True,
)
ranker.fit(df, y, eval_set=[(valid_df, y)])
score = ranker.score(valid_df, y)
assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])
Expand Down
2 changes: 1 addition & 1 deletion src/common/error_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ constexpr StringView LabelScoreSize() {
}

/**
 * \brief Error message for input data that contains `inf`, or a value too large, while
 *        `missing` is not set to `inf`.
 */
constexpr StringView InfInData() {
  return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
26 changes: 13 additions & 13 deletions src/common/numeric.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
/*!
* Copyright 2022, XGBoost contributors.
/**
* Copyright 2022-2023 by XGBoost contributors.
*/
#ifndef XGBOOST_COMMON_NUMERIC_H_
#define XGBOOST_COMMON_NUMERIC_H_

#include <dmlc/common.h> // OMPException

#include <algorithm> // std::max
#include <iterator> // std::iterator_traits
#include <algorithm> // for std::max
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <iterator> // for iterator_traits
#include <vector>

#include "common.h" // AssertGPUSupport
#include "threading_utils.h" // MemStackAllocator, DefaultMaxThreads
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector

namespace xgboost {
namespace common {
namespace xgboost::common {

/**
* \brief Run length encode on CPU, input must be sorted.
Expand Down Expand Up @@ -111,11 +112,11 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
namespace cpu_impl {
/**
 * \brief Parallel sum over the range [first, second) on the CPU.
 *
 * Each thread accumulates into its own slot, and the slots are folded together at the
 * end.
 *
 * \param ctx    Context that supplies the upper bound on the number of threads.
 * \param first  Random-access iterator to the first element.
 * \param second Iterator one past the last element.
 * \param init   Initial value. It seeds every per-thread slot as well as the final
 *               fold, so it should be the additive identity (e.g. 0).
 *
 * \return The accumulated value.
 */
template <typename It, typename V = typename It::value_type>
V Reduce(Context const* ctx, It first, It second, V const& init) {
  std::size_t n = std::distance(first, second);
  if (n == 0) {
    // Nothing to accumulate. Returning here also avoids asking for a parallel loop with
    // a thread count of zero, which is what the clamp below would produce.
    return init;
  }
  // Never use more threads than there are elements.
  auto n_threads = std::min(n, static_cast<std::size_t>(ctx->Threads()));
  common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(n_threads, init);
  common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
  // Only the first n_threads slots were initialized and written.
  return std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init);
}
} // namespace cpu_impl
Expand Down Expand Up @@ -144,7 +145,6 @@ void Iota(Context const* ctx, It first, It last,
});
}
}
} // namespace common
} // namespace xgboost
} // namespace xgboost::common

#endif // XGBOOST_COMMON_NUMERIC_H_
9 changes: 9 additions & 0 deletions src/common/ranking_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS

DMLC_REGISTER_PARAMETER(LambdaRankParam);

// Host-side initialization: validate the first label column with std::all_of as the
// `all_of` implementation expected by CheckMapLabels.
void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) {
  auto const first_target = info.labels.HostView().Slice(linalg::All(), 0);
  auto host_all_of = [](auto first, auto last, auto pred) {
    return std::all_of(first, last, pred);
  };
  CheckMapLabels(first_target, host_all_of);
}

#if !defined(XGBOOST_USE_CUDA)
// Definition used when XGBOOST_USE_CUDA is not defined: both arguments are ignored and the
// call is forwarded to common::AssertGPUSupport().
void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)

std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) {
std::string out_name;
if (!param.empty()) {
Expand Down
5 changes: 5 additions & 0 deletions src/common/ranking_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,9 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
[=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
}

// Device-side initialization: validate the first label column, viewed on the device
// selected by the context, using a CheckMAPOp as the `all_of` implementation.
void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
  auto const first_target = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
  // Functor built from the context's CUDA context.
  CheckMAPOp device_all_of{ctx->CUDACtx()};
  CheckMapLabels(first_target, device_all_of);
}
} // namespace xgboost::ltr
65 changes: 65 additions & 0 deletions src/common/ranking_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,71 @@ void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView<float con
}
}

/**
 * \brief Test whether every label is (within kRtEps) either 0 or 1.
 *
 * \tparam AllOf Callable with the signature of std::all_of: (begin, end, predicate).
 *
 * \param label  Labels to inspect.
 * \param all_of Implementation used to apply the predicate to the whole range.
 *
 * \return Whatever `all_of` returns for the binary-relevance predicate.
 */
template <typename AllOf>
bool IsBinaryRel(linalg::VectorView<float const> label, AllOf all_of) {
  auto values = label.Values();
  auto first = values.data();
  auto last = first + values.size();
  return all_of(first, last, [] XGBOOST_DEVICE(float y) {
    bool const is_one = std::abs(y - 1.0f) < kRtEps;
    return is_one || std::abs(y) < kRtEps;
  });
}
/**
 * \brief Validate label for MAP
 *
 * Fails a CHECK when any label is not a binary relevance value.
 *
 * \tparam AllOf Implementation of std::all_of. Specified as a parameter to reuse the
 *               check for both CPU and GPU.
 *
 * \param label  Labels to validate.
 * \param all_of Callable invoked as all_of(begin, end, predicate).
 */
template <typename AllOf>
void CheckMapLabels(linalg::VectorView<float const> label, AllOf all_of) {
  auto is_binary = IsBinaryRel(label, all_of);
  CHECK(is_binary) << "MAP can only be used with binary labels.";
}

/**
 * \brief Ranking cache for MAP. Runs the label initialization on construction and hands
 *        out buffers that are allocated on first request.
 */
class MAPCache : public RankingCache {
  // Total number of relevant documents for each group
  HostDeviceVector<double> n_rel_;
  // \sum l_k/k
  HostDeviceVector<double> acc_;
  HostDeviceVector<double> map_;
  // Number of samples in this dataset.
  std::size_t n_samples_{0};

  void InitOnCPU(Context const* ctx, MetaInfo const& info);
  void InitOnCUDA(Context const* ctx, MetaInfo const& info);

 public:
  MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
      : RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
    if (!ctx->IsCPU()) {
      this->InitOnCUDA(ctx, info);
      return;
    }
    this->InitOnCPU(ctx, info);
  }

  // Buffer with one entry per sample; sized on first use.
  common::Span<double> NumRelevant(Context const* ctx) {
    if (n_rel_.Empty()) {
      n_rel_.SetDevice(ctx->gpu_id);
      n_rel_.Resize(n_samples_);
    }
    if (ctx->IsCPU()) {
      return n_rel_.HostSpan();
    }
    return n_rel_.DeviceSpan();
  }

  // Buffer with one entry per sample; sized on first use.
  common::Span<double> Acc(Context const* ctx) {
    if (acc_.Empty()) {
      acc_.SetDevice(ctx->gpu_id);
      acc_.Resize(n_samples_);
    }
    if (ctx->IsCPU()) {
      return acc_.HostSpan();
    }
    return acc_.DeviceSpan();
  }

  // Buffer with one entry per group; sized on first use.
  common::Span<double> Map(Context const* ctx) {
    if (map_.Empty()) {
      map_.SetDevice(ctx->gpu_id);
      map_.Resize(this->Groups());
    }
    if (ctx->IsCPU()) {
      return map_.HostSpan();
    }
    return map_.DeviceSpan();
  }
};

/**
* \brief Parse name for ranking metric given parameters.
*
Expand Down
10 changes: 6 additions & 4 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
#include <dmlc/omp.h>

#include <algorithm>
#include <cstdint> // std::int32_t
#include <cstdint> // for int32_t
#include <cstdlib> // for malloc, free
#include <limits>
#include <type_traits> // std::is_signed
#include <new> // for bad_alloc
#include <type_traits> // for is_signed
#include <vector>

#include "xgboost/logging.h"
Expand Down Expand Up @@ -266,7 +268,7 @@ class MemStackAllocator {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
ptr_ = reinterpret_cast<T*>(std::malloc(required_size_ * sizeof(T)));
}
if (!ptr_) {
throw std::bad_alloc{};
Expand All @@ -278,7 +280,7 @@ class MemStackAllocator {

~MemStackAllocator() {
  // Release the buffer exactly once, and only when the requested size exceeded the
  // inline capacity.
  if (required_size_ > MaxStackSize) {
    std::free(ptr_);
  }
}
// Unchecked access to the i-th element of the buffer.
T& operator[](size_t i) { return ptr_[i]; }
Expand Down
97 changes: 62 additions & 35 deletions src/metric/rank_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,37 +284,6 @@ struct EvalPrecision : public EvalRank {
}
};

/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAP : public EvalRank {
public:
explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {}

double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
for (size_t i = 0; i < rec.size(); ++i) {
if (rec[i].second != 0) {
nhits += 1;
if (i < this->topn) {
sumap += static_cast<double>(nhits) / (i + 1);
}
}
}
if (nhits != 0) {
sumap /= nhits;
return sumap;
} else {
if (this->minus) {
return 0.0;
} else {
return 1.0;
}
}
}
};

/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public MetricNoCache {
public:
Expand Down Expand Up @@ -370,10 +339,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });

// Registers the EvalMAP struct under the metric name "map".
// NOTE(review): "map" is registered a second time further down, with EvalMAPScore as the
// body -- confirm that only one of the two registrations is meant to remain.
XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP("map", param); });

// Registers EvalCox under the metric name "cox-nloglik". The parameter string passed to
// the factory is ignored.
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportional hazards model.")
.set_body([](const char*) { return new EvalCox(); });
Expand Down Expand Up @@ -516,6 +481,68 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
}
};

/**
 * \brief Mean average precision (MAP@k) for ranking.
 *
 * For each query group the average precision is
 *
 *   AP@k = 1 / min(k, N) * \sum_{i=1}^{k} P@i * I(i)
 *
 * where I(i) is the label of the document at rank i and N is the number of relevant
 * documents in the whole group. The result is the weighted average of AP over groups.
 */
class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
 public:
  using EvalRankWithCache::EvalRankWithCache;
  const char* Name() const override { return name_.c_str(); }

  /**
   * \param predt   Predictions for all samples.
   * \param info    Meta info providing the labels and the optional per-group weights.
   * \param p_cache Cache providing the group pointers, the rank order and the
   *                per-group scratch buffer.
   *
   * \return The value produced by Finalize from the weighted sum of per-group AP and
   *         the sum of weights.
   */
  double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
              std::shared_ptr<ltr::MAPCache> p_cache) override {
    if (ctx_->IsCUDA()) {
      auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
      return Finalize(map.Residue(), map.Weights());
    }

    auto gptr = p_cache->DataGroupPtr(ctx_);
    auto h_label = info.labels.HostView().Slice(linalg::All(), 0);

    // Per-group scratch buffer; reset because it is reused between evaluations.
    auto map_gloc = p_cache->Map(ctx_);
    std::fill_n(map_gloc.data(), map_gloc.size(), 0.0);
    // Ordering returned by the cache for these predictions; indices are relative to the
    // start of each group (presumably sorted by descending score -- confirm).
    auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());

    common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
      auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
      auto g_rank = rank_idx.subspan(gptr[g]);

      // Only the first min(top-k, group size) ranks contribute to the precision sum.
      auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
      double n_hits{0.0};
      for (std::size_t i = 0; i < n; ++i) {
        auto p = g_label(g_rank[i]);
        n_hits += p;
        // P@(i+1), counted only when the document at this rank is relevant.
        map_gloc[g] += n_hits / static_cast<double>((i + 1)) * p;
      }
      // Keep counting relevant documents below the cut-off: the normalizer uses the
      // total number of relevant documents in the group.
      for (std::size_t i = n; i < g_label.Size(); ++i) {
        n_hits += g_label(g_rank[i]);
      }
      if (n_hits > 0.0) {
        map_gloc[g] /= std::min(n_hits, static_cast<double>(param_.TopK()));
      } else {
        // A group without any relevant document scores 0 when `minus_` is set, else 1.
        map_gloc[g] = minus_ ? 0.0 : 1.0;
      }
    });

    auto sw = 0.0;
    auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
    if (!weight.Empty()) {
      // Weights are per group, not per sample.
      CHECK_EQ(weight.weights.size(), p_cache->Groups());
    }
    for (std::size_t i = 0; i < map_gloc.size(); ++i) {
      map_gloc[i] = map_gloc[i] * weight[i];
      sw += weight[i];
    }
    auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
    return Finalize(sum, sw);
  }
};

// Registers EvalMAPScore under the metric name "map"; the factory forwards the metric
// parameter string to its constructor.
XGBOOST_REGISTER_METRIC(EvalMAP, "map")
.describe("map@k for ranking.")
.set_body([](char const* param) {
return new EvalMAPScore{"map", param};
});

XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
.describe("ndcg@k for ranking.")
.set_body([](char const* param) {
Expand Down
Loading

0 comments on commit 5891f75

Please sign in to comment.