Skip to content

Commit

Permalink
Add dedicated utilities for getting, setting, and modifying columns by index
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Sep 25, 2023
1 parent 2a9c5db commit 51cac22
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 5 deletions.
5 changes: 5 additions & 0 deletions legateboost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@
ExponentialObjective,
BaseObjective,
)
from .utils import (
pick_col_by_idx,
set_col_by_idx,
mod_col_by_idx,
)
11 changes: 9 additions & 2 deletions legateboost/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import cunumeric as cn

from .utils import pick_col_by_idx, set_col_by_idx


class BaseMetric(ABC):
"""The base class for metrics.
Expand Down Expand Up @@ -169,7 +171,10 @@ def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> float:
# multi-class case
assert pred.ndim == 2
label = y.astype(cn.int32)
logloss = -cn.log(pred[cn.arange(label.size), label])

logloss = -cn.log(pick_col_by_idx(pred, label))
# logloss = -cn.log(pred[cn.arange(label.size), label])

return float((logloss * w).sum() / w_sum)

def name(self) -> str:
Expand Down Expand Up @@ -201,7 +206,9 @@ def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> float:
K = pred.shape[1] # number of classes
f = cn.log(pred) * (K - 1) # undo softmax
y_k = cn.full((y.size, K), -1.0 / (K - 1.0))
y_k[cn.arange(y.size), y.astype(cn.int32)] = 1.0

set_col_by_idx(y_k, y.astype(cn.int32), 1.0)
# y_k[cn.arange(y.size), y.astype(cn.int32)] = 1.0

exp = cn.exp(-1 / K * cn.sum(y_k * f, axis=1))
return float((exp * w).sum() / w.sum())
Expand Down
8 changes: 5 additions & 3 deletions legateboost/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
NormalLLMetric,
QuantileMetric,
)
from .utils import preround
from .utils import mod_col_by_idx, preround, set_col_by_idx


class BaseObjective(ABC):
Expand Down Expand Up @@ -254,7 +254,8 @@ def gradient(
label = y.astype(cn.int32).squeeze()
h = pred * (1.0 - pred)
g = pred.copy()
g[cn.arange(y.size), label] -= 1.0
mod_col_by_idx(g, label, -1.0)
# g[cn.arange(y.size), label] -= 1.0
return g, cn.maximum(h, eps)

def transform(self, pred: cn.ndarray) -> cn.ndarray:
Expand Down Expand Up @@ -324,7 +325,8 @@ def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> cn.ndarray:
f = cn.log(pred) * (K - 1) # undo softmax
y_k = cn.full((y.size, K), -1.0 / (K - 1.0))
labels = y.astype(cn.int32).squeeze()
y_k[cn.arange(y.size), labels] = 1.0
set_col_by_idx(y_k, labels, 1.0)
# y_k[cn.arange(y.size), labels] = 1.0
exp = cn.exp(-1 / K * cn.sum(y_k * f, axis=1))

return (
Expand Down
40 changes: 40 additions & 0 deletions legateboost/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,46 @@ def replace(data: Any) -> None:
self.__dict__.update(state)


def pick_col_by_idx(a: cn.ndarray, b: cn.ndarray) -> cn.ndarray:
    """Return the equivalent of ``a[cn.arange(b.size), b]``.

    For each row ``i`` of the 2-D array ``a``, select the element in
    column ``b[i]``. Implemented with a one-hot comparison mask and a
    row-wise sum instead of advanced (fancy) indexing — presumably to
    avoid fancy-indexing limitations in the ``cn`` backend; confirm with
    callers.

    Parameters
    ----------
    a : cn.ndarray
        2-D array of shape ``(n, k)``.
    b : cn.ndarray
        1-D integer array of shape ``(n,)``; one column index per row.

    Returns
    -------
    cn.ndarray
        1-D array of shape ``(n,)`` where entry ``i`` is ``a[i, b[i]]``.
    """
    assert a.ndim == 2
    assert b.ndim == 1
    assert a.shape[0] == b.shape[0]

    # `cols` rather than `range`: avoid shadowing the builtin.
    cols = cn.arange(a.shape[1])
    # One-hot mask: True exactly where the column index equals b[i].
    one_hot = b[:, cn.newaxis] == cols[cn.newaxis, :]
    # Zero out all but the selected entry, then collapse each row.
    return (a * one_hot).sum(axis=1)


def set_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> None:
    """In-place equivalent of ``a[cn.arange(b.size), b] = delta``.

    For each row ``i`` of the 2-D array ``a``, overwrite the element in
    column ``b[i]`` with ``delta``. Uses a one-hot mask plus arithmetic
    rather than fancy indexing — presumably to avoid fancy-indexing
    limitations in the ``cn`` backend; confirm with callers.

    Parameters
    ----------
    a : cn.ndarray
        2-D array of shape ``(n, k)``; modified in place.
    b : cn.ndarray
        1-D integer array of shape ``(n,)``; one column index per row.
    delta : float
        Value written into the selected entries.
    """
    assert a.ndim == 2
    assert b.ndim == 1
    assert a.shape[0] == b.shape[0]

    # `cols` rather than `range`: avoid shadowing the builtin.
    cols = cn.arange(a.shape[1])
    # One-hot mask: True exactly where the column index equals b[i].
    one_hot = b[:, cn.newaxis] == cols[cn.newaxis, :]
    # Clear the selected entries, then write `delta` into them.
    a -= a * one_hot
    a += delta * one_hot


def mod_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> None:
    """In-place equivalent of ``a[cn.arange(b.size), b] += delta``.

    For each row ``i`` of the 2-D array ``a``, add ``delta`` to the
    element in column ``b[i]``. Uses a one-hot mask plus arithmetic
    rather than fancy indexing — presumably to avoid fancy-indexing
    limitations in the ``cn`` backend; confirm with callers.

    Parameters
    ----------
    a : cn.ndarray
        2-D array of shape ``(n, k)``; modified in place.
    b : cn.ndarray
        1-D integer array of shape ``(n,)``; one column index per row.
    delta : float
        Value added to the selected entries.
    """
    assert a.ndim == 2
    assert b.ndim == 1
    assert a.shape[0] == b.shape[0]

    # `cols` rather than `range`: avoid shadowing the builtin.
    cols = cn.arange(a.shape[1])
    # One-hot mask: True exactly where the column index equals b[i].
    one_hot = b[:, cn.newaxis] == cols[cn.newaxis, :]
    # Only the selected entry of each row receives the increment.
    a += delta * one_hot


def preround(x: cn.ndarray) -> cn.ndarray:
"""Apply this function to grad/hess ensure reproducible floating point
summation.
Expand Down

0 comments on commit 51cac22

Please sign in to comment.