Skip to content

Commit

Permalink
Merge pull request #69 from mayer79/organize_utils
Browse files Browse the repository at this point in the history
Organize utils
  • Loading branch information
mayer79 authored Oct 7, 2023
2 parents 61e2b48 + bd0291c commit 4cc732f
Show file tree
Hide file tree
Showing 28 changed files with 904 additions and 766 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ This release mainly changes the *output*. The numeric results are unchanged.
- `summary.hstats()` now returns an object of class "hstats_summary" instead of "summary_hstats".
- `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument.
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
- Case weights `w` can now also be passed as column name of `X`.
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column name of `w` if passed as name).

# hstats 0.3.0

Expand Down
14 changes: 9 additions & 5 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,16 @@ average_loss <- function(object, ...) {
average_loss.default <- function(object, X, y,
pred_fun = stats::predict,
loss = "squared_error",
BY = NULL, by_size = 4L, w = NULL, ...) {
BY = NULL, by_size = 4L,
w = NULL, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
nrow(X) >= 1L,
is.function(pred_fun),
is.null(w) || length(w) == nrow(X),
NROW(y) == nrow(X)
)
if (!is.null(w)) {
w <- prepare_w(w = w, X = X)[["w"]]
}
if (!is.null(BY)) {
BY <- prepare_by(BY = BY, X = X, by_size = by_size)[["BY"]]
}
Expand Down Expand Up @@ -104,7 +106,8 @@ average_loss.default <- function(object, X, y,
average_loss.ranger <- function(object, X, y,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
loss = "squared_error",
BY = NULL, by_size = 4L, w = NULL, ...) {
BY = NULL, by_size = 4L,
w = NULL, ...) {
average_loss.default(
object = object,
X = X,
Expand All @@ -122,7 +125,8 @@ average_loss.ranger <- function(object, X, y,
average_loss.Learner <- function(object, v, X, y,
pred_fun = NULL,
loss = "squared_error",
BY = NULL, by_size = 4L, w = NULL, ...) {
BY = NULL, by_size = 4L,
w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand Down
111 changes: 32 additions & 79 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#'
#' @param object Fitted model object.
#' @param X A data.frame or matrix serving as background dataset.
#' @param v Vector of feature names, by default `colnames(X)`.
#' @param v Vector of feature names. The default (`NULL`) will use all column names of
#' `X` except the column name of the optional case weight `w` (if specified as name).
#' @param pred_fun Prediction function of the form `function(object, X, ...)`,
#' providing \eqn{K \ge 1} predictions per row. Its first argument represents the
#' model `object`, its second argument a data structure like `X`. Additional arguments
Expand All @@ -32,7 +33,7 @@
#' most cases.
#' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
#' selected from `X`. In this case, set a random seed for reproducibility.
#' @param w Optional vector of case weights for each row of `X`.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param pairwise_m Number of features for which pairwise statistics are to be
#' calculated. The features are selected based on Friedman and Popescu's overall
#' interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations.
Expand All @@ -48,8 +49,9 @@
#' @returns
#' An object of class "hstats" containing these elements:
#' - `X`: Input `X` (sampled to `n_max` rows).
#' - `w`: Input `w` (sampled to `n_max` values, or `NULL`).
#' - `v`: Same as input `v`.
#' - `w`: Case weight vector `w` (sampled to `n_max` values), or `NULL`.
#' - `v`: Vector of column names in `X` for which overall
#' H statistics have been calculated.
#' - `f`: Matrix with (centered) predictions \eqn{F}.
#' - `mean_f2`: (Weighted) column means of `f`. Used to normalize \eqn{H^2} and
#' \eqn{H^2_j}.
Expand Down Expand Up @@ -124,12 +126,33 @@ hstats <- function(object, ...) {

#' @describeIn hstats Default hstats method.
#' @export
hstats.default <- function(object, X, v = colnames(X),
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict, n_max = 300L,
w = NULL, pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
eps = 1e-10, verbose = TRUE, ...) {
basic_check(X = X, v = v, pred_fun = pred_fun, w = w)
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
)

# Is w a column name or a vector?
if (!is.null(w)) {
w2 <- prepare_w(w = w, X = X)
w <- w2[["w"]]
w_name <- w2[["w_name"]]
}

# Determine missing v or check consistency with X
if (is.null(v)) {
v <- colnames(X)
if (!is.null(w) && !is.null(w_name)) {
v <- setdiff(v, w_name)
}
} else {
stopifnot(all(v %in% colnames(X)))
}

p <- length(v)
stopifnot(p >= 2L)
pairwise_m <- min(pairwise_m, p)
Expand Down Expand Up @@ -234,7 +257,7 @@ hstats.default <- function(object, X, v = colnames(X),

#' @describeIn hstats Method for "ranger" models.
#' @export
hstats.ranger <- function(object, X, v = colnames(X),
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
n_max = 300L, w = NULL, pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
Expand All @@ -256,7 +279,7 @@ hstats.ranger <- function(object, X, v = colnames(X),

#' @describeIn hstats Method for "mlr3" models.
#' @export
hstats.Learner <- function(object, X, v = colnames(X),
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
n_max = 300L, w = NULL, pairwise_m = 5L,
threeway_m = min(pairwise_m, 5L),
Expand All @@ -282,7 +305,7 @@ hstats.Learner <- function(object, X, v = colnames(X),
#' @describeIn hstats Method for DALEX "explainer".
#' @export
hstats.explainer <- function(object, X = object[["data"]],
v = colnames(X),
v = NULL,
pred_fun = object[["predict_function"]],
n_max = 300L, w = object[["weights"]],
pairwise_m = 5L,
Expand Down Expand Up @@ -436,73 +459,3 @@ plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE,
}
p
}

# Helper functions used only in this script

#' Pairwise or 3-Way Partial Dependencies
#'
#' Calculates centered partial dependence functions for selected pairwise or three-way
#' situations.
#'
#' @noRd
#' @keywords internal
#'
#' @param v Vector of column names to calculate `way` order interactions.
#' @inheritParams hstats
#' @param way Pairwise (`way = 2`) or three-way (`way = 3`) interactions.
#' @param verb Verbose (`TRUE`/`FALSE`).
#'
#' @returns
#' A list with a named list of feature combinations (pairs or triples), and
#' corresponding centered partial dependencies.
mway <- function(object, v, X, pred_fun = stats::predict, w = NULL,
way = 2L, verb = TRUE, ...) {
combs <- utils::combn(v, way, simplify = FALSE)
n_combs <- length(combs)
F_way <- vector("list", length = n_combs)
names(F_way) <- names(combs) <- sapply(combs, paste, collapse = ":")

if (verb) {
cat(way, "way calculations...\n", sep = "-")
pb <- utils::txtProgressBar(max = n_combs, style = 3)
}

for (i in seq_len(n_combs)) {
z <- combs[[i]]
F_way[[i]] <- wcenter(
pd_raw(object, v = z, X = X, grid = X[, z], pred_fun = pred_fun, w = w, ...),
w = w
)
if (verb) {
utils::setTxtProgressBar(pb, i)
}
}
if (verb) {
cat("\n")
}
list(combs, F_way)
}

#' Get Feature Names
#'
#' This function takes the unsorted and unnormalized H2_j matrix and extracts the top
#' m feature names (unsorted). If H2_j has multiple columns, this is done per column and
#' then the union is returned.
#'
#' @noRd
#' @keywords internal
#'
#' @param H Unnormalized, unsorted H2_j values.
#' @param m Number of features to pick per column.
#'
#' @returns A vector of the union of the m column-wise most important features.
get_v <- function(H, m) {
v <- rownames(H)
selector <- function(vv) names(utils::head(sort(-vv[vv > 0]), m))
if (NCOL(H) == 1L) {
v_cand <- selector(drop(H))
} else {
v_cand <- Reduce(union, lapply(asplit(H, MARGIN = 2L), FUN = selector))
}
v[v %in% v_cand]
}
6 changes: 5 additions & 1 deletion R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ ice.default <- function(object, v, X, pred_fun = stats::predict,
BY = NULL, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 100L, ...) {
basic_check(X = X, v = v, pred_fun = pred_fun)
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
all(v %in% colnames(X))
)

# Prepare grid
if (is.null(grid)) {
Expand Down
3 changes: 3 additions & 0 deletions R/onLoad.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
}
invisible()
}

# Fix undefined global variable note
utils::globalVariables(c("varying_", "value_", "id_", "variable_", "obs_", "error_"))
10 changes: 9 additions & 1 deletion R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), n_max = 1000L,
w = NULL, ...) {
basic_check(X = X, v = v, pred_fun = pred_fun, w = w)
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
all(v %in% colnames(X))
)

# Care about grid
if (is.null(grid)) {
Expand All @@ -110,6 +114,10 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
}

if (!is.null(w)) {
w <- prepare_w(w = w, X = X)[["w"]]
}

# The function itself is called per BY group
if (!is.null(BY)) {
BY2 <- prepare_by(BY = BY, X = X, by_size = by_size)
Expand Down
46 changes: 31 additions & 15 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#' @inheritSection average_loss Losses
#'
#' @param v Vector of feature names, or named list of feature groups.
#' The default (`NULL`) will use all column names of `X` except the column name
#' of the optional case weight `w` (if specified as name).
#' @param m_rep Number of permutations (default 4).
#' @param agg_cols Should multivariate losses be summed up? Default is `FALSE`.
#' @param normalize Should importance statistics be divided by average loss?
Expand Down Expand Up @@ -53,23 +55,42 @@ perm_importance <- function(object, ...) {

#' @describeIn perm_importance Default method.
#' @export
perm_importance.default <- function(object, X, y, v = colnames(X),
perm_importance.default <- function(object, X, y, v = NULL,
pred_fun = stats::predict,
loss = "squared_error",
m_rep = 4L, agg_cols = FALSE,
normalize = FALSE, n_max = 10000L,
w = NULL, verbose = FALSE, ...) {
basic_check(
X = X,
v = unlist(v, use.names = FALSE, recursive = FALSE),
pred_fun = pred_fun,
w = w
)
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun),
NROW(y) == nrow(X),
m_rep >= 1L
)

# Is w a column name or a vector?
if (!is.null(w)) {
w2 <- prepare_w(w = w, X = X)
w <- w2[["w"]]
w_name <- w2[["w_name"]]
}

# Prepare v
if (is.null(v)) {
v <- colnames(X)
if (!is.null(w) && !is.null(w_name)) {
v <- setdiff(v, w_name)
}
} else {
v_c <- unlist(v, use.names = FALSE, recursive = FALSE)
stopifnot(all(v_c %in% colnames(X)))
}
if (!is.list(v)) {
v <- as.list(v)
names(v) <- v
}
p <- length(v)

# Reduce size of X, y (and w)
if (nrow(X) > n_max) {
ix <- sample(nrow(X), n_max)
Expand All @@ -84,11 +105,6 @@ perm_importance.default <- function(object, X, y, v = colnames(X),
}
}
n <- nrow(X)
p <- length(v)
if (!is.list(v)) {
v <- as.list(v)
names(v) <- v
}

if (!is.function(loss)) {
loss <- get_loss_fun(loss)
Expand Down Expand Up @@ -175,7 +191,7 @@ perm_importance.default <- function(object, X, y, v = colnames(X),

#' @describeIn perm_importance Method for "ranger" models.
#' @export
perm_importance.ranger <- function(object, X, y, v = colnames(X),
perm_importance.ranger <- function(object, X, y, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
loss = "squared_error", m_rep = 4L,
agg_cols = FALSE,
Expand All @@ -200,7 +216,7 @@ perm_importance.ranger <- function(object, X, y, v = colnames(X),

#' @describeIn perm_importance Method for "mlr3" models.
#' @export
perm_importance.Learner <- function(object, X, y, v = colnames(X),
perm_importance.Learner <- function(object, X, y, v = NULL,
pred_fun = NULL,
loss = "squared_error", m_rep = 4L,
agg_cols = FALSE,
Expand Down Expand Up @@ -231,7 +247,7 @@ perm_importance.Learner <- function(object, X, y, v = colnames(X),
perm_importance.explainer <- function(object,
X = object[["data"]],
y = object[["y"]],
v = colnames(X),
v = NULL,
pred_fun = object[["predict_function"]],
loss = "squared_error",
m_rep = 4L,
Expand Down
Loading

0 comments on commit 4cc732f

Please sign in to comment.