From a8d01effffe7d854ec9e5d12fbf52c7a676e5c53 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Thu, 12 Oct 2023 18:59:28 +0200 Subject: [PATCH] Update examples in README --- R/hstats.R | 13 +- README.md | 90 ++- man/figures/lightgbm.svg | 350 ++++++----- man/figures/multivariate.svg | 263 ++++---- man/figures/multivariate_ice.svg | 997 +++++++++++++------------------ man/figures/xgboost.svg | 330 +++++----- man/hstats.Rd | 11 +- 7 files changed, 962 insertions(+), 1092 deletions(-) diff --git a/R/hstats.R b/R/hstats.R index 6ed5e11e..d1d0c6f9 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -72,20 +72,23 @@ #' - `h2`: List with numerator and denominator of \eqn{H^2}. #' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}. #' - `v_pairwise`: Subset of `v` with largest \eqn{H^2_j} used for pairwise -#' calculations. +#' calculations. Only if pairwise calculations have been done. #' - `combs2`: Named list of variable pairs for which pairwise partial -#' dependence functions are available. +#' dependence functions are available. Only if pairwise calculations have been done. #' - `F_jk`: List of matrices, each representing (centered) bivariate #' partial dependence functions \eqn{F_{jk}}. +#' Only if pairwise calculations have been done. #' - `h2_pairwise`: List with numerator and denominator of \eqn{H^2_{jk}}. #' Only if pairwise calculations have been done. #' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way -#' calculations. +#' calculations. Only if three-way calculations have been done. #' - `combs3`: Named list of variable triples for which three-way partial -#' dependence functions are available. +#' dependence functions are available. Only if three-way calculations have been done. #' - `F_jkl`: List of matrices, each representing (centered) three-way #' partial dependence functions \eqn{F_{jkl}}. -#' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}. +#' Only if three-way calculations have been done. +#' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}. +#' Only if three-way calculations have been done. #' @references #' Friedman, Jerome H., and Bogdan E. Popescu. *"Predictive Learning via Rule Ensembles."* #' The Annals of Applied Statistics 2, no. 3 (2008): 916-54. diff --git a/README.md b/README.md index 83fef540..768a1ae2 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ In the spirit of [1], and related to [4], we can extract from the "hstats" objec ```r plot(pd_importance(s)) -# Compared with repeated permutation importance regarding MSE +# Compared with four times repeated permutation importance regarding MSE set.seed(10) plot(perm_importance(fit, X = X_valid, y = y_valid)) ``` @@ -227,36 +227,46 @@ Strongest relative interaction shown as ICE plot. ## Multivariate responses -{hstats} works also with multivariate output such as probabilistic classification, see examples with {ranger}, LightGBM, and XGBoost. +{hstats} works also with multivariate output such as probabilistic classification, see examples with -### {ranger} +- ranger, +- LightGBM, and +- XGBoost. + +### Common preparation ```r library(hstats) + +set.seed(1) + +ix <- c(1:40, 51:90, 101:140) +train <- iris[ix, ] +valid <- iris[-ix, ] + +X_train <- data.matrix(train[-5]) +X_valid <- data.matrix(valid[-5]) +y_train <- train[, 5] +y_valid <- valid[, 5] +``` + +### ranger + +```r library(ranger) -library(ggplot2) set.seed(1) -fit <- ranger(Species ~ ., data = iris, probability = TRUE) -average_loss(fit, X = iris, y = "Species", loss = "mlogloss") # 0.0521 +fit <- ranger(Species ~ ., data = train, probability = TRUE) +average_loss(fit, X = valid, y = "Species", loss = "mlogloss") # 0.02 -s <- hstats(fit, X = iris[-5]) -s -# H^2 (normalized) -# setosa versicolor virginica -# 0.001547791 0.064550141 0.049758237 +perm_importance(fit, X = iris, y = "Species", loss = "mlogloss") +(s <- hstats(fit, X = iris[-5])) plot(s, normalize = FALSE, squared = FALSE) -ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> - plot(center = TRUE) + - ggtitle("Centered ICE plots") - -perm_importance(fit, X = iris, y = "Species", loss = "mlogloss") -# Permutation importance -# Petal.Length Petal.Width Sepal.Length Sepal.Width -# 0.50941613 0.49187688 0.05669978 0.00950009 +ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |> + plot(center = TRUE) ``` ![](man/figures/multivariate.svg) @@ -268,22 +278,13 @@ perm_importance(fit, X = iris, y = "Species", loss = "mlogloss") Note: Versions from 4.0.0 upwards to not anymore require passing `reshape = TRUE` to the prediction function. ```r -library(hstats) library(lightgbm) set.seed(1) -ix <- c(1:40, 51:90, 101:140) -X <- data.matrix(iris[, -5]) -y <- as.integer(iris[, 5]) - 1 -X_train <- X[ix, ] -X_valid <- X[-ix, ] -y_train <- y[ix] -y_valid <- y[-ix] - params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2) -dtrain <- lgb.Dataset(X_train, label = y_train) -dvalid <- lgb.Dataset(X_valid, label = y_valid) +dtrain <- lgb.Dataset(X_train, label = as.integer(y_train) - 1) +dvalid <- lgb.Dataset(X_valid, label = as.integer(y_valid) - 1) fit <- lgb.train( params = params, @@ -302,22 +303,22 @@ predict(fit, head(X_train, 2), reshape = TRUE) # mlogloss: 9.331699e-05 average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE) -partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> - plot(show_points = FALSE) - -ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> - plot(swap_dim = TRUE, alpha = 0.05) - perm_importance( fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100 ) # Permutation importance regarding mlogloss # Petal.Length Petal.Width Sepal.Width Sepal.Length -# 2.61783760 1.00647382 0.08414687 0.01011645 +# 2.624241332 1.011168660 0.082477177 0.009757393 + +partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> + plot(show_points = FALSE) + +ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> + plot(swap_dim = TRUE, alpha = 0.05) # Interaction statistics, including three-way stats (H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4)) # 0.3010446 0.4167927 0.1623982 -plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1) +plot(H, ncol = 1) ``` ![](man/figures/lightgbm.svg) @@ -327,22 +328,13 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1) Also here, mind the `reshape = TRUE` sent to the prediction function. ```r -library(hstats) library(xgboost) set.seed(1) -ix <- c(1:40, 51:90, 101:140) -X <- data.matrix(iris[, -5]) -y <- as.integer(iris[, 5]) - 1 -X_train <- X[ix, ] -X_valid <- X[-ix, ] -y_train <- y[ix] -y_valid <- y[-ix] - params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2) -dtrain <- xgb.DMatrix(X_train, label = y_train) -dvalid <- xgb.DMatrix(X_valid, label = y_valid) +dtrain <- xgb.DMatrix(X_train, label = as.integer(y_train) - 1) +dvalid <- xgb.DMatrix(X_valid, label = as.integer(y_valid) - 1) fit <- xgb.train( params = params, diff --git a/man/figures/lightgbm.svg b/man/figures/lightgbm.svg index f2dc015a..b9e51bf8 100644 --- a/man/figures/lightgbm.svg +++ b/man/figures/lightgbm.svg @@ -1,5 +1,5 @@ - +