diff --git a/NEWS.md b/NEWS.md index d035ae64..fc9c8d82 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,8 +18,9 @@ This release mainly changes the *output*. The numeric results are unchanged. - `summary.hstats()` now returns an object of class "hstats_summary" instead of "summary_hstats". - `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument. - `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`. -- Case weights `w` can now also be passed as column name of `X`. -- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column name of `w` if passed as name). +- Case weights `w` can now also be passed as column name of `X` (to any function). +- `perm_importance()` and `average_loss()`: The response(s) `y` can now also be passed as column name(s) of `X`. +- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as names). # hstats 0.3.0 diff --git a/R/average_loss.R b/R/average_loss.R index b90931a2..33cc5d94 100644 --- a/R/average_loss.R +++ b/R/average_loss.R @@ -31,7 +31,7 @@ #' vector or matrix of the same length as the input. #' #' @inheritParams hstats -#' @param y Vector/matrix of the response corresponding to `X`. +#' @param y Vector/matrix of the response, or the corresponding column names in `X`. #' @param loss One of "squared_error", "logloss", "mlogloss", "poisson", #' "gamma", "absolute_error", "classification_error". Alternatively, a loss function #' can be provided that turns observed and predicted values into a numeric vector or @@ -49,14 +49,16 @@ #' @examples #' # MODEL 1: Linear regression #' fit <- lm(Sepal.Length ~ ., data = iris) -#' average_loss(fit, X = iris, y = iris$Sepal.Length) -#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Species) -#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = "Sepal.Width") +#' average_loss(fit, X = iris, y = "Sepal.Length") +#' average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Sepal.Width) +#' average_loss(fit, X = iris, y = "Sepal.Length", BY = "Sepal.Width") #' #' # MODEL 2: Multi-response linear regression #' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) #' average_loss(fit, X = iris, y = iris[1:2]) -#' L <- average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species") +#' L <- average_loss( +#' fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species" +#' ) #' L #' plot(L) average_loss <- function(object, ...) { @@ -72,9 +74,9 @@ average_loss.default <- function(object, X, y, w = NULL, ...)
{ stopifnot( is.matrix(X) || is.data.frame(X), - is.function(pred_fun), - NROW(y) == nrow(X) + is.function(pred_fun) ) + y <- prepare_y(y = y, X = X)[["y"]] if (!is.null(w)) { w <- prepare_w(w = w, X = X)[["w"]] } diff --git a/R/losses.R b/R/losses.R index c13fffbd..c5f9556b 100644 --- a/R/losses.R +++ b/R/losses.R @@ -149,6 +149,13 @@ expand_actual <- function(actual, predicted) { pp <- NCOL(predicted) pa <- NCOL(actual) if (pa == pp) { + if (pa > 1L) { + nmp <- colnames(predicted) + nma <- colnames(actual) + if (!is.null(nmp) && !is.null(nma) && !identical(nmp, nma)) { + stop("Column names of multi-output response must correspond to predictions.") + } + } return(actual) } if (pp > 1L && pa == 1L) { diff --git a/R/perm_importance.R b/R/perm_importance.R index 02c3d7d5..426b3a01 100644 --- a/R/perm_importance.R +++ b/R/perm_importance.R @@ -12,8 +12,8 @@ #' @inheritSection average_loss Losses #' #' @param v Vector of feature names, or named list of feature groups. -#' The default (`NULL`) will use all column names of `X` except the column name -#' of the optional case weight `w` (if specified as name). +#' The default (`NULL`) will use all column names of `X` with the following exception: +#' If `y` or `w` are passed as column names, they are dropped. #' @param m_rep Number of permutations (default 4). #' @param agg_cols Should multivariate losses be summed up? Default is `FALSE`. #' @param normalize Should importance statistics be divided by average loss? @@ -30,7 +30,7 @@ #' @examples #' # MODEL 1: Linear regression #' fit <- lm(Sepal.Length ~ ., data = iris) -#' s <- perm_importance(fit, X = iris[-1], y = iris$Sepal.Length) +#' s <- perm_importance(fit, X = iris, y = "Sepal.Length") #' s #' s$M #' s$SE # Standard errors are available thanks to repeated shuffling @@ -39,7 +39,7 @@ #' #' # Groups of features can be passed as named list #' v <- list(petal = c("Petal.Length", "Petal.Width"), species = "Species") -#' s <- perm_importance(fit, X = iris, y = iris$Sepal.Length, v = v) +#' s <- perm_importance(fit, X = iris, y = "Sepal.Length", v = v) #' s #' plot(s) #' @@ -64,10 +64,14 @@ perm_importance.default <- function(object, X, y, v = NULL, stopifnot( is.matrix(X) || is.data.frame(X), is.function(pred_fun), - NROW(y) == nrow(X), m_rep >= 1L ) + # Are y column names or a vector/matrix? + y2 <- prepare_y(y = y, X = X) + y <- y2[["y"]] + y_names <- y2[["y_names"]] + # Is w a column name or a vector? if (!is.null(w)) { w2 <- prepare_w(w = w, X = X) @@ -81,6 +85,9 @@ perm_importance.default <- function(object, X, y, v = NULL, if (!is.null(w) && !is.null(w_name)) { v <- setdiff(v, w_name) } + if (!is.null(y_names)) { + v <- setdiff(v, y_names) + } } else { v_c <- unlist(v, use.names = FALSE, recursive = FALSE) stopifnot(all(v_c %in% colnames(X))) diff --git a/R/utils_input.R b/R/utils_input.R index 446dcef5..ac671a9c 100644 --- a/R/utils_input.R +++ b/R/utils_input.R @@ -72,6 +72,27 @@ prepare_w <- function(w, X) { list(w = w, w_name = w_name) } +#' Prepares Response y +#' +#' Internal function that prepares the response `y`. +#' +#' @noRd +#' @keywords internal +#' @param y Vector/matrix-like of the same length as `X`, or column names in `X`. +#' @param X Matrix-like. +#' +#' @returns A list. +prepare_y <- function(y, X) { + if (NROW(y) < nrow(X) && all(y %in% colnames(X))) { + y_names <- y + y <- X[, y] + } else { + stopifnot(NROW(y) == nrow(X)) + y_names <- NULL + } + list(y = y, y_names = y_names) +} + #' mlr3 Helper #' #' Returns the prediction function of a mlr3 Learner. 
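For readers skimming the patch, here is the new `prepare_y()` dispatch in isolation. It is an internal helper (not exported), so the sketch below reproduces it verbatim from the hunk above so that both branches can be tried interactively:

```r
# Sketch: prepare_y() copied from R/utils_input.R above (internal, not exported).
prepare_y <- function(y, X) {
  if (NROW(y) < nrow(X) && all(y %in% colnames(X))) {
    y_names <- y   # y given as column name(s): extract the column(s) from X
    y <- X[, y]
  } else {
    stopifnot(NROW(y) == nrow(X))  # y given as vector/matrix: length must match X
    y_names <- NULL
  }
  list(y = y, y_names = y_names)
}

str(prepare_y("Sepal.Length", X = iris))     # name branch: y_names = "Sepal.Length"
str(prepare_y(iris$Sepal.Length, X = iris))  # vector branch: y_names = NULL
```

Callers like `perm_importance.default()` use `y_names` only to drop the response columns from the default `v`, which is why `v = NULL` now works even when `X` contains the response.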
diff --git a/README.md b/README.md index de2b7922..0d3eee51 100644 --- a/README.md +++ b/README.md @@ -253,7 +253,7 @@ library(ggplot2) set.seed(1) fit <- ranger(Species ~ ., data = iris, probability = TRUE) -average_loss(fit, X = iris, y = iris$Species, loss = "mlogloss") # 0.0521 +average_loss(fit, X = iris, y = "Species", loss = "mlogloss") # 0.0521 s <- hstats(fit, X = iris[-5]) s @@ -267,7 +267,7 @@ ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> plot(center = TRUE) + ggtitle("Centered ICE plots") -perm_importance(fit, X = iris[-5], y = iris$Species, loss = "mlogloss") +perm_importance(fit, X = iris, y = "Species", loss = "mlogloss") # Permutation importance # Petal.Length Petal.Width Sepal.Length Sepal.Width # 0.50941613 0.49187688 0.05669978 0.00950009 @@ -306,7 +306,7 @@ s <- hstats(fit, X = iris[-1]) s # 0 -> no interactions plot(partial_dep(fit, v = "Petal.Width", X = iris)) -imp <- perm_importance(fit, X = iris[-1], y = iris$Sepal.Length) +imp <- perm_importance(fit, X = iris, y = "Sepal.Length") imp # Permutation importance # Petal.Length Species Petal.Width Sepal.Width @@ -334,7 +334,7 @@ fit <- train( h2(hstats(fit, X = iris[-1])) # 0 plot(ice(fit, v = "Petal.Width", X = iris), center = TRUE) -plot(perm_importance(fit, X = iris[-1], y = iris$Sepal.Length)) +plot(perm_importance(fit, X = iris, y = "Sepal.Length")) ``` ### mlr3 @@ -354,7 +354,7 @@ s <- hstats(fit_rf, X = iris[-5], threeway_m = 0) plot(s) # Permutation importance -perm_importance(fit_rf, X = iris[-5], y = iris$Species, loss = "mlogloss") |> +perm_importance(fit_rf, X = iris, y = "Species", loss = "mlogloss") |> plot() ``` diff --git a/backlog/hstats_explainer.R b/backlog/hstats_explainer.R new file mode 100644 index 00000000..b7622b8e --- /dev/null +++ b/backlog/hstats_explainer.R @@ -0,0 +1,16 @@ +hstats_explainer <- function(object, X, pred_fun = stats::predict, + y = NULL, loss = "squared_error", + w = NULL, ...) 
{ + structure( + list( + object = object, + X = X, + pred_fun = function(m, x) pred_fun(m, x, ...), + y = y, + loss = loss, + w = w + ), + class = "hstats_explainer" + ) +} + diff --git a/backlog/modeltuner.R b/backlog/modeltuner.R index f23fd87c..a78c5b32 100644 --- a/backlog/modeltuner.R +++ b/backlog/modeltuner.R @@ -8,14 +8,14 @@ fit_glm <- model(glm(form, iris, weights = Petal.Width, family = Gamma(link = "l mm <- c(lm = fit_lm, glm = fit_glm) predict(mm, head(iris)) -average_loss(mm, X = iris, y = iris$Sepal.Length, BY = "Species", w = "Petal.Width") |> +average_loss(mm, X = iris, y = "Sepal.Length", BY = "Species", w = "Petal.Width") |> plot() partial_dep(mm, v = "Sepal.Width", X = iris, BY = "Species", w = "Petal.Width") |> plot(show_points = FALSE) ice(mm, v = "Sepal.Width", X = iris, BY = "Species") |> plot(facet_scales = "fixed") -perm_importance(mm, X = iris[-1], y = iris[, 1], w = "Petal.Width") |> +perm_importance(mm, X = iris, y = "Sepal.Length", w = "Petal.Width") |> plot() # Interaction statistics (H-statistics) @@ -24,4 +24,3 @@ H plot(H) h2_pairwise(H, normalize = FALSE, squared = FALSE) |> plot() - diff --git a/man/average_loss.Rd b/man/average_loss.Rd index 85d06b6b..3033aa49 100644 --- a/man/average_loss.Rd +++ b/man/average_loss.Rd @@ -67,7 +67,7 @@ for instance \code{type = "response"} in a \code{\link[=glm]{glm()}} model.} \item{X}{A data.frame or matrix serving as background dataset.} -\item{y}{Vector/matrix of the response corresponding to \code{X}.} +\item{y}{Vector/matrix of the response, or the corresponding column names in \code{X}.} \item{pred_fun}{Prediction function of the form \verb{function(object, X, ...)}, providing \eqn{K \ge 1} predictions per row. Its first argument represents the @@ -158,14 +158,16 @@ vector or matrix of the same length as the input. \examples{ # MODEL 1: Linear regression fit <- lm(Sepal.Length ~ ., data = iris) -average_loss(fit, X = iris, y = iris$Sepal.Length) -average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Species) -average_loss(fit, X = iris, y = iris$Sepal.Length, BY = "Sepal.Width") +average_loss(fit, X = iris, y = "Sepal.Length") +average_loss(fit, X = iris, y = iris$Sepal.Length, BY = iris$Sepal.Width) +average_loss(fit, X = iris, y = "Sepal.Length", BY = "Sepal.Width") # MODEL 2: Multi-response linear regression fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) average_loss(fit, X = iris, y = iris[1:2]) -L <- average_loss(fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species") +L <- average_loss( + fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species" +) L plot(L) } diff --git a/man/perm_importance.Rd b/man/perm_importance.Rd index a1fd37a0..60a671bc 100644 --- a/man/perm_importance.Rd +++ b/man/perm_importance.Rd @@ -82,11 +82,11 @@ for instance \code{type = "response"} in a \code{\link[=glm]{glm()}} model.} \item{X}{A data.frame or matrix serving as background dataset.} -\item{y}{Vector/matrix of the response corresponding to \code{X}.} +\item{y}{Vector/matrix of the response, or the corresponding column names in \code{X}.} \item{v}{Vector of feature names, or named list of feature groups. 
-The default (\code{NULL}) will use all column names of \code{X} except the column name -of the optional case weight \code{w} (if specified as name).} +The default (\code{NULL}) will use all column names of \code{X} with the following exception: +If \code{y} or \code{w} are passed as column names, they are dropped.} \item{pred_fun}{Prediction function of the form \verb{function(object, X, ...)}, providing \eqn{K \ge 1} predictions per row. Its first argument represents the @@ -187,7 +187,7 @@ vector or matrix of the same length as the input. \examples{ # MODEL 1: Linear regression fit <- lm(Sepal.Length ~ ., data = iris) -s <- perm_importance(fit, X = iris[-1], y = iris$Sepal.Length) +s <- perm_importance(fit, X = iris, y = "Sepal.Length") s s$M s$SE # Standard errors are available thanks to repeated shuffling @@ -196,7 +196,7 @@ plot(s, err_type = "SD") # Standard deviations instead of standard errors # Groups of features can be passed as named list v <- list(petal = c("Petal.Length", "Petal.Width"), species = "Species") -s <- perm_importance(fit, X = iris, y = iris$Sepal.Length, v = v) +s <- perm_importance(fit, X = iris, y = "Sepal.Length", v = v) s plot(s) diff --git a/tests/testthat/test_average_loss.R b/tests/testthat/test_average_loss.R index ad37915e..d2766f25 100644 --- a/tests/testthat/test_average_loss.R +++ b/tests/testthat/test_average_loss.R @@ -4,7 +4,9 @@ y <- iris$Sepal.Length test_that("average_loss() works ungrouped for regression", { s <- average_loss(fit, X = iris, y = y)$M + s2 <- average_loss(fit, X = iris, y = "Sepal.Length")$M expect_equal(drop(s), mean((y - predict(fit, iris))^2)) + expect_equal(s, s2) s <- average_loss(fit, X = iris, y = y, loss = "absolute_error")$M expect_equal(drop(s), mean(abs(y - predict(fit, iris)))) @@ -21,7 +23,7 @@ test_that("average_loss() works ungrouped for regression", { test_that("average_loss() works with groups for regression", { s <- average_loss(fit, X = iris, y = y, BY = iris$Species)$M - s2 <- average_loss(fit, X = iris, y = y, BY = "Species")$M + s2 <- average_loss(fit, X = iris, y = "Sepal.Length", BY = "Species")$M xpect <- by((y - predict(fit, iris))^2, FUN = mean, INDICES = iris$Species) expect_equal(drop(s), c(xpect)) @@ -29,16 +31,16 @@ test_that("average_loss() works with groups for regression", { expect_equal(dim(average_loss(fit, X = iris, y = y, BY = "Sepal.Width")$M), c(4L, 1L)) expect_equal( - dim(average_loss(fit, X = iris, y = y, BY = "Sepal.Width", by_size = 2L)$M), + dim(average_loss(fit, X = iris, y = "Sepal.Width", BY = "Sepal.Width", by_size = 2L)$M), c(2L, 1L) ) }) test_that("average_loss() works with weights for regression", { s1 <- average_loss(fit, X = iris, y = y) - s2 <- average_loss(fit, X = iris, y = y, w = rep(2, times = 150)) + s2 <- average_loss(fit, X = iris, y = "Sepal.Length", w = rep(2, times = 150)) s3 <- average_loss(fit, X = iris, y = y, w = "Petal.Width") - s4 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width) + s4 <- average_loss(fit, X = iris, y = "Sepal.Length", w = iris$Petal.Width) expect_equal(s1, s2) expect_false(identical(s2, s3)) @@ -51,9 +53,9 @@ test_that("average_loss() works with weights and grouped for regression", { g <- iris$Species s1 <- average_loss(fit, X = iris, y = y, BY = g) s2 <- average_loss( - fit, X = iris, y = y, w = rep(2, times = 150), BY = "Species" + fit, X = iris, y = "Sepal.Length", w = rep(2, times = 150), BY = "Species" ) - s3 <- average_loss(fit, X = iris, y = y, w = "Petal.Width", BY = g) + s3 <- average_loss(fit, X = iris, y = "Sepal.Length", 
w = "Petal.Width", BY = g) s4 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width, BY = g) expect_equal(s1, s2) @@ -66,11 +68,14 @@ test_that("average_loss() works with weights and grouped for regression", { #================================================ y <- as.matrix(iris[1:2]) +yy <- colnames(y) fit <- lm(y ~ Petal.Length + Species, data = iris) test_that("average_loss() works ungrouped (multi regression)", { s <- average_loss(fit, X = iris, y = y)$M expect_equal(drop(s), colMeans((y - predict(fit, iris))^2)) + s2 <- average_loss(fit, X = iris, y = yy)$M + expect_equal(s, s2) s <- average_loss(fit, X = iris, y = y, loss = "absolute_error")$M expect_equal(drop(s), colMeans(abs(y - predict(fit, iris)))) @@ -81,7 +86,7 @@ test_that("average_loss() works ungrouped (multi regression)", { s <- average_loss(fit, X = iris, y = y, loss = "poisson")$M expect_equal(drop(s), colMeans(poisson()$dev.resid(y, predict(fit, iris), 1))) - s <- average_loss(fit, X = iris, y = y, loss = "gamma")$M + s <- average_loss(fit, X = iris, y = yy, loss = "gamma")$M expect_equal(drop(s), colMeans(Gamma()$dev.resid(y, predict(fit, iris), 1))) }) @@ -93,9 +98,9 @@ test_that("average_loss() works with groups (multi regression)", { test_that("average_loss() works with weights (multi regression)", { s1 <- average_loss(fit, X = iris, y = y) - s2 <- average_loss(fit, X = iris, y = y, w = rep(2, times = 150)) + s2 <- average_loss(fit, X = iris, y = yy, w = rep(2, times = 150)) s3 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width) - s4 <- average_loss(fit, X = iris, y = y, w = "Petal.Width") + s4 <- average_loss(fit, X = iris, y = yy, w = "Petal.Width") expect_equal(s1, s2) expect_false(identical(s2, s3)) @@ -105,9 +110,9 @@ test_that("average_loss() works with weights (multi regression)", { test_that("average_loss() works with weights and grouped (multi regression)", { g <- iris$Species s1 <- average_loss(fit, X = iris, y = y, BY = g) - s2 <- average_loss(fit, X = iris, y = y, w = rep(2, times = 150), BY = g) + s2 <- average_loss(fit, X = iris, y = yy, w = rep(2, times = 150), BY = g) s3 <- average_loss(fit, X = iris, y = y, w = iris$Petal.Width, BY = g) - s4 <- average_loss(fit, X = iris, y = y, w = "Petal.Width", BY = g) + s4 <- average_loss(fit, X = iris, y = yy, w = "Petal.Width", BY = g) expect_equal(s1, s2) expect_false(identical(s2, s3)) diff --git a/tests/testthat/test_perm_importance.R b/tests/testthat/test_perm_importance.R index d5b57352..eed6aebb 100644 --- a/tests/testthat/test_perm_importance.R +++ b/tests/testthat/test_perm_importance.R @@ -2,6 +2,7 @@ fit <- lm(Sepal.Length ~ Sepal.Width + Species, data = iris) v <- setdiff(names(iris), "Sepal.Length") y <- iris$Sepal.Length +yy <- "Sepal.Length" set.seed(1L) s1 <- perm_importance(fit, X = iris[-1L], y = y) @@ -23,6 +24,12 @@ test_that("v can be selected (univariate)", { expect_equal(s1, s2) }) +test_that("y can also be passed as name (univariate)", { + set.seed(1L) + s2 <- perm_importance(fit, X = iris, y = yy) + expect_equal(s1, s2) +}) + test_that("results are positive for modeled features and zero otherwise (univariate)", { expect_true(all(s1$M[c("Sepal.Width", "Species"), ] > 1e-8)) expect_true(all(s1$M[c("Petal.Length", "Petal.Width"), ] < 1e-8)) @@ -30,6 +37,7 @@ test_that("results are positive for modeled features and zero otherwise (univari test_that("perm_importance() raises some errors (univariate)", { expect_error(perm_importance(fit, X = iris[-1L], y = 1:10)) + expect_error(perm_importance(fit, X = iris[-1], y = "Hello")) }) 
test_that("constant weights is same as unweighted (univariate)", { @@ -40,7 +48,8 @@ test_that("constant weights is same as unweighted (univariate)", { test_that("non-constant weights is different from unweighted (univariate)", { set.seed(1L) - s2 <- perm_importance(fit, X = iris[-1L], y = y, w = "Petal.Width") + s2 <- perm_importance(fit, X = iris, y = yy, w = "Petal.Width") + set.seed(1L) s3 <- perm_importance( fit, @@ -49,6 +58,7 @@ test_that("non-constant weights is different from unweighted (univariate)", { y = y, w = iris$Petal.Width ) + set.seed(1L) s4 <- perm_importance( fit, X = iris, v = colnames(iris[-1L]), y = y, w = "Petal.Width" @@ -142,6 +152,7 @@ test_that("non-numeric predictions can work as well (classification error)", { #================================================ y <- as.matrix(iris[1:2]) +yy <- colnames(y) fit <- lm(y ~ Petal.Length + Species, data = iris) v <- c("Petal.Length", "Petal.Width", "Species") set.seed(1L) @@ -152,6 +163,16 @@ test_that("print() does not give error (multivariate)", { capture_output(expect_no_error(print(s1))) }) +test_that("response can be passed as vector (multivariate)", { + set.seed(1L) + s2 <- perm_importance(fit, X = iris, y = yy) + expect_equal(s1, s2) + + set.seed(1L) + s3 <- perm_importance(fit, X = iris, y = yy, v = colnames(iris)) + expect_true(nrow(s2$M) < nrow(s3$M)) +}) + test_that("agg_cols works (multivariate)", { set.seed(1L) s2 <- perm_importance(fit, X = iris[3:5], y = y, agg_cols = TRUE) @@ -193,6 +214,8 @@ test_that("results are positive for modeled features and zero otherwise (multiva test_that("perm_importance() raises some errors (multivariate)", { expect_error(perm_importance(fit, X = iris[3:5], y = 1:10)) + expect_error(perm_importance(fit, X = iris[3:5], y = "hi")) + expect_error(perm_importance(fit, X = iris, y = rev(yy))) }) test_that("constant weights is same as unweighted (multivariate)", {