From 3c647a6131a7957b751f4424120cee6f4c314a7b Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sat, 2 Sep 2023 13:39:14 +0200 Subject: [PATCH 1/2] update news and readme --- NEWS.md | 2 +- README.md | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/NEWS.md b/NEWS.md index 915df4f7..6ba3e8e5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,7 +5,7 @@ - **average_loss()**: This new function calculates the average loss of a model for a given dataset, optionally grouped by a discrete vector. It supports the most important loss functions (squared error, Poisson deviance, Gamma deviance, Log loss, multivariate Log loss, absolute error, classification error), and allows for case weights. Custom losses can be passed as vector/matrix valued functions of signature `f(obs, pred)`. Note that such a custom function needs to return per-row losses, not their average. -- **perm_importance()**: H-statistics are often calculated for important features only. To support this workflow, we have added permutation importance regarding the most important loss functions. Multivariate losses can be studied individually or collapsed over dimensions. The importance of *feature groups* can be studied as well. Note that the API is different from the experimental `pd_importance()`, which is calculated from a "hstats" object, while `perm_importance()` acts on the fitted model. +- **perm_importance()**: H-statistics are often calculated for important features only. To support this workflow, we have added permutation importance regarding the most important loss functions. Multivariate losses can be studied individually or collapsed over dimensions. The importance of *feature groups* can be studied as well. Note that the API of `perm_importance()` is different from the experimental `pd_importance()`, which is calculated from a "hstats" object. 
## Minor improvements diff --git a/README.md b/README.md index c31e08eb..a8a8c517 100644 --- a/README.md +++ b/README.md @@ -270,6 +270,14 @@ plot(s, normalize = FALSE, squared = FALSE) + ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> plot(center = TRUE) + ggtitle("Centered ICE plots") + +# Permutation importance +perm_importance( + fit, v = colnames(iris)[-5], X = iris, y = iris$Species, loss = "mlogloss" +) +# +# Petal.Length Petal.Width Sepal.Length Sepal.Width +# 0.48918225 0.47393814 0.05435491 0.01426659 ``` ![](man/figures/multivariate.svg) @@ -302,6 +310,10 @@ fit <- iris_wf %>% s <- hstats(fit, v = colnames(iris[-1]), X = iris) s # 0 -> no interactions plot(partial_dep(fit, v = "Petal.Width", X = iris)) + +perm_importance(fit, v = colnames(iris[-1]), X = iris, y = iris$Sepal.Length) +# Petal.Length Species Petal.Width Sepal.Width +# 4.29502933 0.36451226 0.11146004 0.09371835 ``` ### caret @@ -321,6 +333,7 @@ fit <- train( h2(hstats(fit, v = colnames(iris[-1]), X = iris)) # 0 plot(ice(fit, v = "Petal.Width", X = iris), center = TRUE) +plot(perm_importance(fit, v = colnames(iris[-1]), X = iris, y = iris$Sepal.Length)) ``` ### mlr3 @@ -332,10 +345,10 @@ library(mlr3learners) # Probabilistic classification task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species") -fit_rf <- lrn("classif.ranger", predict_type = "prob", num.trees = 50) +fit_rf <- lrn("classif.ranger", predict_type = "prob") fit_rf$train(task_iris) v <- colnames(iris[-5]) -s <- hstats(fit_rf, v = v, X = iris) +s <- hstats(fit_rf, v = v, X = iris, threeway_m = 0) plot(s) # Permutation importance From eec823ea2f8316d3fb7106258405d1efd8e7b5e7 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sat, 2 Sep 2023 14:49:39 +0200 Subject: [PATCH 2/2] Change print logic of summary.hstats(), and simplify code of plot.hstats() --- NAMESPACE | 1 + NEWS.md | 4 +- R/hstats.R | 48 +- README.md | 21 +- man/figures/importance.svg | 94 +-- 
man/figures/importance_perm.svg | 132 ++-- man/figures/multivariate.svg | 265 +++---- man/figures/multivariate_ice.svg | 1146 +++++++++++++++--------------- man/print.summary_hstats.Rd | 22 + man/summary.hstats.Rd | 4 +- tests/testthat/test_hstats.R | 3 +- 11 files changed, 889 insertions(+), 851 deletions(-) create mode 100644 man/print.summary_hstats.Rd diff --git a/NAMESPACE b/NAMESPACE index 309d58c7..c0c763c4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -38,6 +38,7 @@ S3method(print,hstats) S3method(print,ice) S3method(print,partial_dep) S3method(print,perm_importance) +S3method(print,summary_hstats) S3method(summary,hstats) export(average_loss) export(h2) diff --git a/NEWS.md b/NEWS.md index 6ba3e8e5..2a54ec63 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,8 +7,10 @@ Note that such a custom function needs to return per-row losses, not their avera - **perm_importance()**: H-statistics are often calculated for important features only. To support this workflow, we have added permutation importance regarding the most important loss functions. Multivariate losses can be studied individually or collapsed over dimensions. The importance of *feature groups* can be studied as well. Note that the API of `perm_importance()` is different from the experimental `pd_importance()`, which is calculated from a "hstats" object. -## Minor improvements +## Minor changes +- `summary.hstats()` now returns an object of class "summary_hstats" with its own `print()` method. Like this, one can use `su <- summary()` without printing to the console. +- The output of `summary.hstats()` is printed slightly more compact. - `plot.hstats()` has recieved a `rotate_x = FALSE` argument for rotating x labels by 45 degrees. - `plot.hstats()` and `summary.hstats()` have received explicit arguments `normalize`, `squared`, `sort`, `eps` instead of passing them via `...`. - `plot.hstats()` now passes `...` to `geom_bar()`. 
diff --git a/R/hstats.R b/R/hstats.R index d8e66448..32feb16c 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -307,17 +307,18 @@ print.hstats <- function(x, ...) { #' #' @inheritParams h2_overall #' @param ... Currently not used. -#' @returns A named list of statistics. +#' @returns +#' An object of class "summary_hstats" representing a named list with statistics. #' @export #' @seealso See [hstats()] for examples. summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE, - top_m = 6L, eps = 1e-8, ...) { + top_m = Inf, eps = 1e-8, ...) { args <- list( object = object, normalize = normalize, squared = squared, sort = sort, - top_m = Inf, + top_m = top_m, eps = eps, plot = FALSE ) @@ -327,9 +328,21 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE h2_pairwise = do.call(h2_pairwise, args), h2_threeway = do.call(h2_threeway, args) ) - out <- out[sapply(out, Negate(is.null))] - - addon <- "(only for features with strong overall interactions)" + class(out) <- "summary_hstats" + out +} + +#' Print Method +#' +#' Print method for object of class "summary_hstats". +#' +#' @param x An object of class "summary_hstats". +#' @param ... Further arguments passed from other methods. +#' @returns Invisibly, the input is returned. +#' @export +#' @seealso See [hstats()] for examples. +print.summary_hstats <- function(x, ...) 
{ + addon <- "(for features with strong overall interactions)" txt <- c( h2 = "Proportion of prediction variability unexplained by main effects of v", h2_overall = "Strongest overall interactions", @@ -337,13 +350,13 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE h2_threeway = paste0("Strongest relative three-way interactions\n", addon) ) - for (nm in names(out)) { + for (nm in names(Filter(Negate(is.null), x))) { cat(txt[[nm]]) cat("\n") - print(utils::head(out[[nm]], top_m)) + print(utils::head(drop(x[[nm]]))) cat("\n") } - invisible(out) + invisible(x) } #' Plot Method for "hstats" Object @@ -365,19 +378,12 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = TRUE, top_m = 15L, eps = 1e-8, fill = "#2b51a1", facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) { + su <- summary( + x, normalize = normalize, squared = squared, sort = sort, top_m = top_m, eps = eps + ) + nms <- c("h2_overall", "h2_pairwise", "h2_threeway") ids <- c("Overall", "Pairwise", "Threeway") - funs <- c(h2_overall, h2_pairwise, h2_threeway) - dat <- list() - i <- 1L - for (f in funs) { - if (i %in% which) - dat[[i]] <- mat2df( - f(x, normalize = normalize, squared = squared, - sort = sort, top_m = top_m, eps = eps, plot = FALSE), - id = ids[i] - ) - i <- i + 1L - } + dat <- lapply(which, FUN = function(j) mat2df(su[[nms[j]]], id = ids[j])) dat <- do.call(rbind, dat) p <- ggplot2::ggplot(dat, ggplot2::aes(x = value_, y = variable_)) + ggplot2::ylab(ggplot2::element_blank()) + diff --git a/README.md b/README.md index a8a8c517..f1369b96 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,6 @@ plot(s, which = 1:3, normalize = F, squared = F, facet_scales = "free_y", ncol = ![](man/figures/hstats3.svg) - ### Describe interaction Let's study different plots to understand *how* the strong interaction between distance to the ocean and age looks 
like. We will check the following three visualizations. @@ -200,7 +199,7 @@ pd_importance(s) + set.seed(10) imp <- perm_importance(fit, v = x, X = X_valid, y = y_valid) plot(imp) + - ggtitle("Permutation importance with standard errors") + ggtitle("Permutation importance + standard errors") ``` ![](man/figures/importance.svg) @@ -221,7 +220,7 @@ library(hstats) set.seed(1) fit <- ranger(Sepal.Length ~ ., data = iris) -ex <- explain(fit, data = iris[-1], y = iris[, 1]) +ex <- DALEX::explain(fit, data = iris[-1], y = iris[, 1]) s <- hstats(ex) s # Non-additivity index 0.054 @@ -254,14 +253,15 @@ library(ranger) library(ggplot2) library(hstats) -fit <- ranger(Species ~ ., data = iris, probability = TRUE, seed = 1) +set.seed(1) +fit <- ranger(Species ~ ., data = iris, probability = TRUE) average_loss(fit, X = iris, y = iris$Species, loss = "mlogloss") # 0.054 s <- hstats(fit, v = colnames(iris)[-5], X = iris) s # Proportion of prediction variability unexplained by main effects of v: # setosa versicolor virginica -# 0.002705945 0.065629375 0.046742035 +# 0.001547791 0.064550141 0.049758237 plot(s, normalize = FALSE, squared = FALSE) + ggtitle("Unnormalized statistics") + @@ -275,9 +275,9 @@ ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> perm_importance( fit, v = colnames(iris)[-5], X = iris, y = iris$Species, loss = "mlogloss" ) -# + # Petal.Length Petal.Width Sepal.Length Sepal.Width -# 0.48918225 0.47393814 0.05435491 0.01426659 +# 0.50941613 0.49187688 0.05669978 0.00950009 ``` ![](man/figures/multivariate.svg) @@ -311,9 +311,12 @@ s <- hstats(fit, v = colnames(iris[-1]), X = iris) s # 0 -> no interactions plot(partial_dep(fit, v = "Petal.Width", X = iris)) -perm_importance(fit, v = colnames(iris[-1]), X = iris, y = iris$Sepal.Length) +imp <- perm_importance(fit, v = colnames(iris[-1]), X = iris, y = iris$Sepal.Length) +imp # Petal.Length Species Petal.Width Sepal.Width -# 4.29502933 0.36451226 0.11146004 0.09371835 +# 4.44682039 
0.34064367 0.10195946 0.09520902 + +plot(imp) ``` ### caret diff --git a/man/figures/importance.svg b/man/figures/importance.svg index 3ab033c3..d50b64c5 100644 --- a/man/figures/importance.svg +++ b/man/figures/importance.svg @@ -1,5 +1,5 @@ - +