Update examples in README #75

Merged
merged 1 commit into from
Oct 12, 2023
13 changes: 8 additions & 5 deletions R/hstats.R
@@ -72,20 +72,23 @@
#' - `h2`: List with numerator and denominator of \eqn{H^2}.
#' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}.
#' - `v_pairwise`: Subset of `v` with largest \eqn{H^2_j} used for pairwise
#' calculations.
#' calculations. Only if pairwise calculations have been done.
#' - `combs2`: Named list of variable pairs for which pairwise partial
#' dependence functions are available.
#' dependence functions are available. Only if pairwise calculations have been done.
#' - `F_jk`: List of matrices, each representing (centered) bivariate
#' partial dependence functions \eqn{F_{jk}}.
#' Only if pairwise calculations have been done.
#' - `h2_pairwise`: List with numerator and denominator of \eqn{H^2_{jk}}.
#' Only if pairwise calculations have been done.
#' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way
#' calculations.
#' calculations. Only if three-way calculations have been done.
#' - `combs3`: Named list of variable triples for which three-way partial
#' dependence functions are available.
#' dependence functions are available. Only if three-way calculations have been done.
#' - `F_jkl`: List of matrices, each representing (centered) three-way
#' partial dependence functions \eqn{F_{jkl}}.
#' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}.
#' Only if three-way calculations have been done.
#' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}.
#' Only if three-way calculations have been done.
#' @references
#' Friedman, Jerome H., and Bogdan E. Popescu. *"Predictive Learning via Rule Ensembles."*
#' The Annals of Applied Statistics 2, no. 3 (2008): 916-54.
90 changes: 41 additions & 49 deletions README.md
@@ -178,7 +178,7 @@ In the spirit of [1], and related to [4], we can extract from the "hstats" objec
```r
plot(pd_importance(s))

# Compared with repeated permutation importance regarding MSE
# Compared with four times repeated permutation importance regarding MSE
set.seed(10)
plot(perm_importance(fit, X = X_valid, y = y_valid))
```
@@ -227,36 +227,46 @@ Strongest relative interaction shown as ICE plot.

## Multivariate responses

{hstats} works also with multivariate output such as probabilistic classification, see examples with {ranger}, LightGBM, and XGBoost.
{hstats} also works with multivariate output such as probabilistic classification; see the examples with

### {ranger}
- ranger,
- LightGBM, and
- XGBoost.

### Common preparation

```r
library(hstats)

set.seed(1)

ix <- c(1:40, 51:90, 101:140)
train <- iris[ix, ]
valid <- iris[-ix, ]

X_train <- data.matrix(train[-5])
X_valid <- data.matrix(valid[-5])
y_train <- train[, 5]
y_valid <- valid[, 5]
```
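
The shared preparation block relies on `iris` being ordered by species in blocks of 50 rows. As a minimal sanity check (base R only, not part of the README), the index vector `ix` can be confirmed to give a balanced 120/30 split:

```r
# Rebuild the split from the common preparation block
ix <- c(1:40, 51:90, 101:140)
train <- iris[ix, ]
valid <- iris[-ix, ]

nrow(train)           # 120
nrow(valid)           # 30
table(train$Species)  # 40 rows per species
table(valid$Species)  # 10 rows per species
```

Because the split is deterministic, `set.seed(1)` only affects the model fits that follow, not the split itself.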

### ranger

```r
library(ranger)
library(ggplot2)

set.seed(1)

fit <- ranger(Species ~ ., data = iris, probability = TRUE)
average_loss(fit, X = iris, y = "Species", loss = "mlogloss") # 0.0521
fit <- ranger(Species ~ ., data = train, probability = TRUE)
average_loss(fit, X = valid, y = "Species", loss = "mlogloss") # 0.02

s <- hstats(fit, X = iris[-5])
s
# H^2 (normalized)
# setosa versicolor virginica
# 0.001547791 0.064550141 0.049758237
perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")

(s <- hstats(fit, X = iris[-5]))
plot(s, normalize = FALSE, squared = FALSE)

ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |>
plot(center = TRUE) +
ggtitle("Centered ICE plots")

perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")
# Permutation importance
# Petal.Length Petal.Width Sepal.Length Sepal.Width
# 0.50941613 0.49187688 0.05669978 0.00950009
ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |>
plot(center = TRUE)
```

![](man/figures/multivariate.svg)
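
The examples pass `loss = "mlogloss"` without spelling out the formula. The sketch below assumes the usual definition of multi-class log loss (the mean negative log of the probability assigned to the observed class). The helper name `mlogloss_sketch()` and the clipping constant `eps` are hypothetical and are not taken from {hstats}:

```r
# Hypothetical helper, not part of {hstats}: multi-class log loss from a
# matrix of predicted class probabilities and a factor of observed classes.
mlogloss_sketch <- function(prob, y, eps = 1e-15) {
  stopifnot(is.factor(y), nrow(prob) == length(y), ncol(prob) == nlevels(y))
  # Probability assigned to the observed class of each row
  p_true <- prob[cbind(seq_along(y), as.integer(y))]
  # Clip to avoid log(0); an assumption of this sketch
  p_true <- pmin(pmax(p_true, eps), 1 - eps)
  -mean(log(p_true))
}

# Toy example with three observations and three classes
y <- factor(c("a", "b", "c"), levels = c("a", "b", "c"))
prob <- rbind(
  c(0.8, 0.1, 0.1),
  c(0.2, 0.6, 0.2),
  c(0.1, 0.1, 0.8)
)
mlogloss_sketch(prob, y)  # equals -(log(0.8) + log(0.6) + log(0.8)) / 3
```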
@@ -268,22 +278,13 @@ perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")
Note: Versions from 4.0.0 upwards no longer require passing `reshape = TRUE` to the prediction function.

```r
library(hstats)
library(lightgbm)

set.seed(1)

ix <- c(1:40, 51:90, 101:140)
X <- data.matrix(iris[, -5])
y <- as.integer(iris[, 5]) - 1
X_train <- X[ix, ]
X_valid <- X[-ix, ]
y_train <- y[ix]
y_valid <- y[-ix]

params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2)
dtrain <- lgb.Dataset(X_train, label = y_train)
dvalid <- lgb.Dataset(X_valid, label = y_valid)
dtrain <- lgb.Dataset(X_train, label = as.integer(y_train) - 1)
dvalid <- lgb.Dataset(X_valid, label = as.integer(y_valid) - 1)

fit <- lgb.train(
params = params,
@@ -302,22 +303,22 @@ predict(fit, head(X_train, 2), reshape = TRUE)
# mlogloss: 9.331699e-05
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(swap_dim = TRUE, alpha = 0.05)

perm_importance(
fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
# Permutation importance regarding mlogloss
# Petal.Length Petal.Width Sepal.Width Sepal.Length
# 2.61783760 1.00647382 0.08414687 0.01011645
# 2.624241332 1.011168660 0.082477177 0.009757393

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(swap_dim = TRUE, alpha = 0.05)

# Interaction statistics, including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4)) # 0.3010446 0.4167927 0.1623982
plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)
plot(H, ncol = 1)
```

![](man/figures/lightgbm.svg)
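
With the shared preparation, `y_train` and `y_valid` stay factors, and the conversion `as.integer(y) - 1` moves into the `lgb.Dataset()` and `xgb.DMatrix()` calls, since both libraries expect class labels coded as 0, ..., `num_class - 1`. A short base R sketch of what that conversion produces:

```r
# Factor response as in the common preparation block
y_train <- iris[c(1:40, 51:90, 101:140), 5]

# Zero-based class codes as passed to lgb.Dataset() / xgb.DMatrix()
label <- as.integer(y_train) - 1

levels(y_train)  # "setosa" "versicolor" "virginica"
table(label)     # codes 0, 1, 2 with 40 observations each

# Map the codes back to class names
head(levels(y_train)[label + 1])
```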
@@ -327,22 +328,13 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)
As with LightGBM, note the `reshape = TRUE` passed to the prediction function.

```r
library(hstats)
library(xgboost)

set.seed(1)

ix <- c(1:40, 51:90, 101:140)
X <- data.matrix(iris[, -5])
y <- as.integer(iris[, 5]) - 1
X_train <- X[ix, ]
X_valid <- X[-ix, ]
y_train <- y[ix]
y_valid <- y[-ix]

params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2)
dtrain <- xgb.DMatrix(X_train, label = y_train)
dvalid <- xgb.DMatrix(X_valid, label = y_valid)
dtrain <- xgb.DMatrix(X_train, label = as.integer(y_train) - 1)
dvalid <- xgb.DMatrix(X_valid, label = as.integer(y_valid) - 1)

fit <- xgb.train(
params = params,