Skip to content

Commit

Permalink
Implement CRPS for normal distribution. (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 1, 2023
1 parent 454f72d commit 998b9ea
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 10 deletions.
9 changes: 5 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ if (NOT legateboost_ROOT)
set(legateboost_ROOT ${CMAKE_SOURCE_DIR}/build)
endif()

if (NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()

set(BUILD_SHARED_LIBS ON)

# Look for an existing C++ editable build
Expand All @@ -26,6 +22,11 @@ legate_add_cpp_subdirectory(src TARGET legateboost EXPORT legateboost-export)
legate_add_cffi(${CMAKE_SOURCE_DIR}/src/legateboost.h TARGET legateboost)
legate_python_library_template(legateboost)
legate_default_python_install(legateboost EXPORT legateboost-export)
set_target_properties(legateboost
PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)

if (SANITIZE)
message(STATUS "Adding sanitizer flags")
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Metrics
.. autoclass:: legateboost.NormalLLMetric
:members:

.. autoclass:: legateboost.NormalCRPSMetric
:members:

.. autoclass:: legateboost.QuantileMetric
:members:

Expand Down
1 change: 1 addition & 0 deletions legateboost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .metrics import (
MSEMetric,
NormalLLMetric,
NormalCRPSMetric,
QuantileMetric,
LogLossMetric,
ExponentialMetric,
Expand Down
87 changes: 81 additions & 6 deletions legateboost/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Tuple

import cunumeric as cn

Expand Down Expand Up @@ -67,6 +68,16 @@ def name(self) -> str:
return "mse"


def check_normal(y: cn.ndarray, pred: cn.ndarray) -> Tuple[cn.ndarray, cn.ndarray]:
"""Checks for normal distribution inputs."""
if y.size * 2 != pred.size:
raise ValueError("Expected pred to contain mean and sd for each y_i")
if y.ndim == 1:
y = y.reshape((y.size, 1))
pred = pred.reshape((y.shape[0], y.shape[1], 2))
return y, pred


class NormalLLMetric(BaseMetric):
"""The mean negative log likelihood of the labels, given mean and variance
parameters.
Expand All @@ -80,12 +91,7 @@ class NormalLLMetric(BaseMetric):
""" # noqa: E501

def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> float:
assert (
y.size * 2 == pred.size
), "Expected pred to contain mean and sd for each y_i"
if y.ndim == 1:
y = y.reshape((y.size, 1))
pred = pred.reshape((y.shape[0], y.shape[1], 2))
y, pred = check_normal(y, pred)
w_sum = w.sum()
if w_sum == 0:
return 0
Expand All @@ -103,6 +109,74 @@ def name(self) -> str:
return "normal_neg_ll"


def erf(x: cn.ndarray) -> cn.ndarray:
"""Element-wise error function.
Parameters
----------
x :
Input array.
Returns :
The error function applied element-wise to the input array.
"""
# Code from https://www.johndcook.com/blog/python_erf/
a1 = 0.254829592
a2 = -0.284496736
a3 = 1.421413741
a4 = -1.453152027
a5 = 1.061405429
p = 0.3275911

# Save the sign of x
sign = cn.ones(shape=x.shape, dtype=cn.int8)
sign[x < 0.0] = -1
x = cn.abs(x)

# A&S formula 7.1.26
t = 1.0 / (1.0 + p * x)
y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * cn.exp(-x * x)

return sign * y


def norm_cdf(x: cn.ndarray) -> cn.ndarray:
"""CDF function for standard normal distribution."""
return 0.5 * (1.0 + erf(x / cn.sqrt(2.0)))


def norm_pdf(x: cn.ndarray) -> cn.ndarray:
"""PDF function for standard normal distribution."""
return cn.exp(-0.5 * (x) ** 2) / (cn.sqrt(2.0 * cn.pi))


class NormalCRPSMetric(BaseMetric):
"""Continuous Ranked Probability Score for normal distribution. Can be used with
:py:class:`~legateboost.objectives.NormalObjective`.
References
----------
[1] Tilmann Gneiting, Adrian E. Raftery (2007)
`Strictly Proper Scoring Rules, Prediction, and Estimation`
"""

def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> float:
y, pred = check_normal(y, pred)
loc = pred[:, :, 0]
# `NormalObjective` outputs variance instead of scale.
scale = cn.sqrt(pred[:, :, 1])
z = (y - loc) / scale
# This is negating the definition in [1] to make it a loss.
v = scale * (z * (2 * norm_cdf(z) - 1) + 2 * norm_pdf(z) - 1 / cn.sqrt(cn.pi))

v = cn.average(v, weights=w[:, cn.newaxis])
return float(v)

def name(self) -> str:
return "normal_crps"


class QuantileMetric(BaseMetric):
"""The quantile loss, otherwise known as check loss or pinball loss.
Expand Down Expand Up @@ -222,4 +296,5 @@ def name(self) -> str:
"mse": MSEMetric,
"exp": ExponentialMetric,
"normal_neg_ll": NormalLLMetric,
"normal_crps": NormalCRPSMetric,
}
28 changes: 28 additions & 0 deletions legateboost/test/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import cunumeric as cn
import legateboost as lb
from legateboost.metrics import erf


def test_multiple_metrics():
Expand Down Expand Up @@ -174,6 +175,33 @@ def neg_ll(y, p):
assert cn.allclose(our_metric, ref_metric)


def test_erf() -> None:
from scipy.special import erf as scipy_erf

rng = cn.random.default_rng(0)
for t in [cn.float32, cn.float64]:
for s in [(100,), (100, 10), (100, 10, 10)]:
x = rng.normal(size=s)
y0 = erf(x)
y1 = scipy_erf(x)
assert y0.shape == x.shape
assert cn.allclose(y0, y1)


def test_normal_crps() -> None:
"""Tests for the `NormalCRPSMetric`."""
cprs = lb.NormalCRPSMetric()
y = cn.array([1.0, 0.0, 1.0]).T
p = cn.array([[1.0, 1.0], [0.0, 1.0], [1.0, 1.0]])
score = cprs.metric(y, p, cn.ones(y.shape))
assert np.isclose(score, 0.233695)

y = cn.array([12.0, 13.0, 14.0]).T
p = cn.array([[4.0, 8.0], [5.0, 9.0], [6.0, 10.0]])
score = cprs.metric(y, p, cn.ones(y.shape))
assert np.isclose(score, 6.316697)


def test_quantile_metric():
quantiles = cn.array([0.1, 0.5, 0.9])
metric = lb.QuantileMetric(quantiles)
Expand Down

0 comments on commit 998b9ea

Please sign in to comment.