Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ICE #18

Merged
merged 3 commits into from
Jul 2, 2023
Merged

ICE #18

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ S3method(H2_j,default)
S3method(H2_j,interact)
S3method(H2_jk,default)
S3method(H2_jk,interact)
S3method(ice,Learner)
S3method(ice,default)
S3method(ice,ranger)
S3method(interact,Learner)
S3method(interact,default)
S3method(interact,ranger)
Expand All @@ -14,14 +17,17 @@ S3method(partial_dep,default)
S3method(partial_dep,ranger)
S3method(pd_importance,default)
S3method(pd_importance,interact)
S3method(plot,ice)
S3method(plot,interact)
S3method(plot,partial_dep)
S3method(print,ice)
S3method(print,interact)
S3method(print,partial_dep)
S3method(summary,interact)
export(H2)
export(H2_j)
export(H2_jk)
export(ice)
export(interact)
export(multivariate_grid)
export(partial_dep)
Expand Down
2 changes: 1 addition & 1 deletion R/H2_jk.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ H2_jk.interact <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE
return(NULL)
}

# Note that F_jk are in the same order as combn
# Note that F_jk are in the same order as combs
num <- denom <- with(
object,
matrix(
Expand Down
245 changes: 245 additions & 0 deletions R/ice.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#' Individual Conditional Expectations
#'
#' Disaggregated partial dependencies, see reference. The plot method supports
#' up to two grouping variables via `BY`.
#'
#' @inheritParams partial_dep
#' @param BY Optional grouping vector/matrix/data.frame (up to two columns),
#' or up to two column names. Unlike with [partial_dep()], these variables are not
#' binned. The first variable is visualized on the color scale, while the second
#' one goes into a `facet_wrap()`. Thus, make sure that the second variable is
#' discrete.
#' @returns
#' An object of class "ice" containing these elements:
#' - `ice_curves`: data.frame containing the ice values.
#' - `grid`: Vector, matrix or data.frame of grid values.
#' - `v`: Same as input `v`.
#' - `K`: Number of columns of prediction matrix.
#' - `pred_names`: Column names of prediction matrix.
#' - `by_names`: Column name(s) of grouping variable(s) (or `NULL`).
#' @references
#' Goldstein, Alex, and Adam Kapelner and Justin Bleich and Emil Pitkin.
#' *Peeking inside the black box: Visualizing statistical learning with plots of individual conditional expectation.*
#' Journal of Computational and Graphical Statistics, 24, no. 1 (2015): 44-65.
#' @export
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . + Species * Petal.Length, data = iris)
#' plot(ice(fit, v = "Sepal.Width", X = iris))
#'
#' # Stratified by one variable
#' ic <- ice(fit, v = "Petal.Length", X = iris, BY = "Species")
#' ic
#' plot(ic)
#' plot(ic, center = TRUE)
#'
#' # Stratified by two variables (the second one goes into facets)
#' ic <- ice(fit, v = "Petal.Length", X = iris, BY = c("Petal.Width", "Species"))
#' plot(ic)
#' plot(ic, center = TRUE)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' ic <- ice(fit, v = "Petal.Width", X = iris, BY = iris$Species)
#' plot(ic)
#' plot(ic, center = TRUE)
#'
#' # MODEL 3: Gamma GLM -> pass options to predict() via ...
#' fit <- glm(
#' Sepal.Length ~ . + Petal.Width:Species,
#' data = iris,
#' family = Gamma(link = log)
#' )
#' plot(ice(fit, v = "Petal.Length", X = iris, BY = "Species"))
#' plot(ice(fit, v = "Petal.Length", X = iris, type = "response", BY = "Species"))
# S3 generic: picks the ice.<class>() method matching the class of `object`.
ice <- function(object, ...) UseMethod("ice")

#' @describeIn ice Default method.
#' @export
ice.default <- function(object, v, X, pred_fun = stats::predict,
                        BY = NULL, grid = NULL, grid_size = 49L,
                        trim = c(0.01, 0.99),
                        strategy = c("uniform", "quantile"), n_max = 100L, ...) {
  basic_check(X = X, v = v, pred_fun = pred_fun)

  # Prepare grid: derive it from the data unless the caller supplied one,
  # in which case it is only validated.
  if (is.null(grid)) {
    grid <- multivariate_grid(
      x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy
    )
  } else {
    check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
  }

  # Prepare BY: either up to two column names of X, or explicit grouping data
  # with one row per row of X. Afterwards, BY is always a data.frame.
  by_names <- NULL
  if (!is.null(BY)) {
    if (length(BY) <= 2L && all(BY %in% colnames(X))) {
      # BY holds column names: remember them and pull the columns from X
      by_names <- BY
      BY <- X[, BY]
    } else {
      # BY holds the grouping values themselves: use generic names
      n_by <- NCOL(BY)
      by_names <- if (n_by == 1L) "Group" else paste0("Group_", seq_len(n_by))
      if (NROW(BY) != nrow(X)) {
        stop("BY variable(s) must have same length as X.")
      }
    }
    if (!is.data.frame(BY)) {
      BY <- as.data.frame(BY)
    }
  }

  # Reduce size of X (and BY) by random sampling, keeping their rows aligned
  if (nrow(X) > n_max) {
    ix <- sample(nrow(X), n_max)
    X <- X[ix, , drop = FALSE]
    if (!is.null(BY)) {
      BY <- BY[ix, , drop = FALSE]
    }
  }

  # ice_raw() provides the predictions ("pred") and the matching grid values
  # ("grid_pred") in stacked form.
  ice_out <- ice_raw(
    object, v = v, X = X, grid = grid, pred_fun = pred_fun, pred_only = FALSE, ...
  )
  pred <- ice_out[["pred"]]
  grid_pred <- ice_out[["grid_pred"]]

  # Make sure the prediction columns are named
  K <- ncol(pred)
  if (is.null(colnames(pred))) {
    colnames(pred) <- if (K == 1L) "y" else paste0("y", seq_len(K))
  }
  pred_names <- colnames(pred)

  # A plain vector of grid values becomes a one-column data.frame named like v
  if (!is.data.frame(grid_pred) && !is.matrix(grid_pred)) {
    grid_pred <- stats::setNames(as.data.frame(grid_pred), v)
  }

  # The observation id is recycled over the stacked rows.
  # NOTE(review): this (and the BY replication below) assumes that ice_raw()
  # stacks all rows of X per grid value, grid value by grid value - confirm.
  ice_curves <- cbind.data.frame(obs_ = seq_len(nrow(X)), grid_pred, pred)
  if (!is.null(BY)) {
    # drop = FALSE keeps a data.frame regardless of the number of BY columns
    by_rows <- rep(seq_len(nrow(BY)), times = NROW(grid))
    ice_curves[by_names] <- BY[by_rows, , drop = FALSE]
  }
  row.names(ice_curves) <- NULL

  out <- list(
    ice_curves = ice_curves,
    grid = grid,
    v = v,
    K = K,
    pred_names = pred_names,
    by_names = by_names
  )
  structure(out, class = "ice")
}

#' @describeIn ice Method for "ranger" models.
#' @export
ice.ranger <- function(object, v, X,
                       pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
                       BY = NULL, grid = NULL, grid_size = 49L,
                       trim = c(0.01, 0.99),
                       strategy = c("uniform", "quantile"), n_max = 100L, ...) {
  # Thin wrapper around ice.default(): the only difference is the default
  # `pred_fun`, which extracts the `predictions` element of the object returned
  # by predict(). `n_max` defaults to the integer 100L, like in the other
  # ice() methods.
  ice.default(
    object = object,
    v = v,
    X = X,
    pred_fun = pred_fun,
    BY = BY,
    grid = grid,
    grid_size = grid_size,
    trim = trim,
    strategy = strategy,
    n_max = n_max,
    ...
  )
}

#' @describeIn ice Method for "mlr3" models.
#' @export
ice.Learner <- function(object, v, X,
                        pred_fun = function(m, X) m$predict_newdata(X)$response,
                        BY = NULL, grid = NULL, grid_size = 49L, trim = c(0.01, 0.99),
                        strategy = c("uniform", "quantile"), n_max = 100L, ...) {
  # All work is delegated to the default method. This method only swaps in a
  # default `pred_fun` that calls `predict_newdata()` on the model and returns
  # its `response` element.
  ice.default(
    object = object, v = v, X = X, pred_fun = pred_fun,
    BY = BY, grid = grid, grid_size = grid_size,
    trim = trim, strategy = strategy, n_max = n_max,
    ...
  )
}

#' Prints "ice" Object
#'
#' Print method for object of class "ice".
#'
#' @param x An object of class "ice".
#' @param n Number of rows of the ICE curves (`ice_curves`) to show.
#' @param ... Further arguments passed from other methods.
#' @returns Invisibly, the input is returned.
#' @export
#' @seealso See [ice()] for examples.
print.ice <- function(x, n = 3L, ...) {
  curves <- x[["ice_curves"]]

  # One-line header with the total row count, followed by an empty line
  cat(
    "'ice' object (",
    nrow(curves),
    " rows). Extract via $ice_curves. Top rows:\n\n",
    sep = ""
  )

  # Show only the first n rows; the input is handed back invisibly
  print(utils::head(curves, n))
  invisible(x)
}

#' Plots "ice" Object
#'
#' Plot method for objects of class "ice".
#'
#' @importFrom ggplot2 .data
#' @inheritParams plot.partial_dep
#' @param x An object of class "ice".
#' @param center Should curves be centered? Default is `FALSE`.
#' @param alpha Transparency passed to `ggplot2::geom_line()`.
#' @export
#' @returns An object of class "ggplot".
#' @seealso See [ice()] for examples.
plot.ice <- function(x, center = FALSE, alpha = 0.2, rotate_x = FALSE,
                     color = "#2b51a1", facet_scales = "fixed", ...) {
  v <- x[["v"]]
  K <- x[["K"]]
  ice_curves <- x[["ice_curves"]]
  pred_names <- x[["pred_names"]]
  by_names <- x[["by_names"]]

  # Color and facets are the only two slots besides the x axis: only one
  # feature can be shown, and a multivariate prediction (K > 1) uses up one of
  # the two slots otherwise available for BY variables.
  if (length(v) > 1L) {
    stop("Maximal one feature v can be plotted.")
  }
  if ((K > 1L) + length(by_names) > 2L) {
    stop("Two BY variables and multivariate output has no plot method yet.")
  }

  if (center) {
    # Shift every curve (one per value of obs_) so that it is 0 at the middle
    # grid position.
    # NOTE(review): relies on the rows of each observation being stored in
    # grid order - confirm against how ice_curves is built.
    pos <- trunc((NROW(x[["grid"]]) + 1) / 2)
    ice_curves[pred_names] <- lapply(
      ice_curves[pred_names],
      function(z) stats::ave(z, ice_curves[["obs_"]], FUN = function(zz) zz - zz[pos])
    )
  }

  # Presumably stacks the prediction columns into a value column "value_" and
  # a label column "varying_" (both names are used below) - verify against
  # poor_man_stack().
  data <- poor_man_stack(ice_curves, to_stack = pred_names)

  # All columns are accessed via the .data pronoun, so that no bare column
  # names (undefined globals for R CMD check) appear in aes().
  p <- ggplot2::ggplot(
    data,
    ggplot2::aes(
      x = .data[[v]], y = .data[["value_"]], group = .data[["obs_"]]
    )
  ) +
    ggplot2::labs(x = v, y = if (center) "Centered ICE" else "ICE")

  if (is.null(by_names)) {
    p <- p + ggplot2::geom_line(color = color, alpha = alpha, ...)
  } else {
    # The first BY variable goes to the color scale. The legend keys are drawn
    # fully opaque to stay readable with transparent lines.
    p <- p +
      ggplot2::geom_line(
        ggplot2::aes(color = .data[[by_names[1L]]]), alpha = alpha, ...
      ) +
      ggplot2::labs(color = by_names[1L]) +
      ggplot2::guides(color = ggplot2::guide_legend(override.aes = list(alpha = 1)))
  }

  if (K > 1L || length(by_names) == 2L) {  # Only one is possible
    wrp <- if (K > 1L) "varying_" else by_names[2L]
    p <- p + ggplot2::facet_wrap(wrp, scales = facet_scales)
  }

  if (rotate_x) {
    p <- p + rotate_x_labs()
  }
  p
}
Loading