Commit 366bd46

Merge pull request #75 from mayer79/update_docu

Update examples in README

mayer79 authored Oct 12, 2023
2 parents 6fb1884 + a8d01ef commit 366bd46
Showing 7 changed files with 962 additions and 1,092 deletions.
13 changes: 8 additions & 5 deletions R/hstats.R
#' - `h2`: List with numerator and denominator of \eqn{H^2}.
#' - `h2_overall`: List with numerator and denominator of \eqn{H^2_j}.
#' - `v_pairwise`: Subset of `v` with largest \eqn{H^2_j} used for pairwise
#'   calculations. Only if pairwise calculations have been done.
#' - `combs2`: Named list of variable pairs for which pairwise partial
#'   dependence functions are available. Only if pairwise calculations have been done.
#' - `F_jk`: List of matrices, each representing (centered) bivariate
#'   partial dependence functions \eqn{F_{jk}}.
#'   Only if pairwise calculations have been done.
#' - `h2_pairwise`: List with numerator and denominator of \eqn{H^2_{jk}}.
#'   Only if pairwise calculations have been done.
#' - `v_threeway`: Subset of `v` with largest `h2_overall()` used for three-way
#'   calculations. Only if three-way calculations have been done.
#' - `combs3`: Named list of variable triples for which three-way partial
#'   dependence functions are available. Only if three-way calculations have been done.
#' - `F_jkl`: List of matrices, each representing (centered) three-way
#'   partial dependence functions \eqn{F_{jkl}}.
#'   Only if three-way calculations have been done.
#' - `h2_threeway`: List with numerator and denominator of \eqn{H^2_{jkl}}.
#'   Only if three-way calculations have been done.
#' @references
#' Friedman, Jerome H., and Bogdan E. Popescu. *"Predictive Learning via Rule Ensembles."*
#' The Annals of Applied Statistics 2, no. 3 (2008): 916-54.
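
A minimal sketch of how the components listed above might be accessed, assuming a plain list-like "hstats" object and any fitted model with a `predict()` method (the `lm()` fit here is purely illustrative; `h2_overall()` is the extractor referenced in the docs above, while the `$` accesses assume list behavior):

```r
library(hstats)

fit <- lm(Sepal.Length ~ ., data = iris)  # illustrative model
s <- hstats(fit, X = iris[-1], threeway_m = 4)

h2_overall(s)  # H^2_j per feature, built from the stored numerator/denominator
s$combs2       # variable pairs with pairwise PD functions (if pairwise stats were computed)
s$v_threeway   # variables entering the three-way calculations (if computed)
```
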
90 changes: 41 additions & 49 deletions README.md
In the spirit of [1], and related to [4], we can extract from the "hstats" object …
```r
plot(pd_importance(s))

# Compared with four times repeated permutation importance regarding MSE
set.seed(10)
plot(perm_importance(fit, X = X_valid, y = y_valid))
```
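
The "four times repeated" comment suggests the default number of permutation repetitions; the count appears to be controlled by `m_rep`, the same argument this README later passes as `m_rep = 100` in the LightGBM example. A hedged sketch:

```r
# More repetitions stabilize the importance estimates at the cost of compute
set.seed(10)
plot(perm_importance(fit, X = X_valid, y = y_valid, m_rep = 20))
```
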
Strongest relative interaction, shown as an ICE plot.

## Multivariate responses

{hstats} also works with multivariate output such as probabilistic classification; see the examples with

- ranger,
- LightGBM, and
- XGBoost.

### Common preparation

```r
library(hstats)

set.seed(1)

ix <- c(1:40, 51:90, 101:140)
train <- iris[ix, ]
valid <- iris[-ix, ]

X_train <- data.matrix(train[-5])
X_valid <- data.matrix(valid[-5])
y_train <- train[, 5]
y_valid <- valid[, 5]
```
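
Because `iris` is sorted by species (rows 1–50 setosa, 51–100 versicolor, 101–150 virginica), the index `ix` yields a stratified split, which a quick check confirms:

```r
table(train$Species)  # 40 rows per class
table(valid$Species)  # 10 rows per class
```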

### ranger

```r
library(ranger)
library(ggplot2)

set.seed(1)

fit <- ranger(Species ~ ., data = train, probability = TRUE)
average_loss(fit, X = valid, y = "Species", loss = "mlogloss")  # 0.02

perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")

(s <- hstats(fit, X = iris[-5]))
plot(s, normalize = FALSE, squared = FALSE)

ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |>
  plot(center = TRUE)
```

![](man/figures/multivariate.svg)
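
For comparison with the unnormalized plot above, a minimal sketch of the default view (normalized, squared statistics; `ggplot2` is already loaded, so a title can be tacked on):

```r
plot(s) +  # default scale: normalized H^2 per class
  ggtitle("Normalized interaction statistics")
```
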
### LightGBM
Note: Versions from 4.0.0 upwards no longer require passing `reshape = TRUE` to the prediction function.
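
As a sketch of what the flag changes (assuming pre-4.0.0 semantics and the `fit` trained in the block below): without reshaping, multiclass predictions come back as one flat vector.

```r
p_flat <- predict(fit, head(X_valid, 2))                  # flat vector of length 2 * 3
p_mat  <- predict(fit, head(X_valid, 2), reshape = TRUE)  # 2 x 3 matrix of class probabilities
```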

```r
library(lightgbm)

params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2)
dtrain <- lgb.Dataset(X_train, label = as.integer(y_train) - 1)
dvalid <- lgb.Dataset(X_valid, label = as.integer(y_valid) - 1)

fit <- lgb.train(
params = params,
  # ... (remaining arguments not shown in this excerpt)
)

predict(fit, head(X_train, 2), reshape = TRUE)
# mlogloss: 9.331699e-05
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)

perm_importance(
  fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
# Permutation importance regarding mlogloss
# Petal.Length  Petal.Width  Sepal.Width Sepal.Length
#  2.624241332  1.011168660  0.082477177  0.009757393

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
  plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
  plot(swap_dim = TRUE, alpha = 0.05)

# Interaction statistics, including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4))  # 0.3010446 0.4167927 0.1623982
plot(H, ncol = 1)
```

![](man/figures/lightgbm.svg)
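
The individual statistics can then be pulled out of `H`; a sketch, assuming extractor functions mirroring the component names documented in `R/hstats.R` above:

```r
h2_overall(H)   # H^2_j per feature and class
h2_pairwise(H)  # H^2_jk among the strongest features
h2_threeway(H)  # H^2_jkl, available here because threeway_m = 4
```
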
### XGBoost
Here too, mind the `reshape = TRUE` passed to the prediction function.

```r
library(xgboost)

params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2)
dtrain <- xgb.DMatrix(X_train, label = as.integer(y_train) - 1)
dvalid <- xgb.DMatrix(X_valid, label = as.integer(y_valid) - 1)

fit <- xgb.train(
params = params,
  # ... (remaining arguments not shown in this excerpt)
)
```
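
A sketch of the reshaped prediction, assuming the `fit` from the (truncated) block above:

```r
predict(fit, head(X_valid, 2), reshape = TRUE)  # 2 x 3 matrix of class probabilities
```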
