Merge pull request #53 from mayer79/prep_release
Print, plot etc.
mayer79 authored Sep 2, 2023
2 parents 455b1da + eec823e commit 80208d0
Showing 11 changed files with 901 additions and 850 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
@@ -38,6 +38,7 @@ S3method(print,hstats)
S3method(print,ice)
S3method(print,partial_dep)
S3method(print,perm_importance)
S3method(print,summary_hstats)
S3method(summary,hstats)
export(average_loss)
export(h2)
6 changes: 4 additions & 2 deletions NEWS.md
@@ -5,10 +5,12 @@
- **average_loss()**: This new function calculates the average loss of a model for a given dataset, optionally grouped by a discrete vector. It supports the most important loss functions (squared error, Poisson deviance, Gamma deviance, Log loss, multivariate Log loss, absolute error, classification error) and allows for case weights. Custom losses can be passed as vector- or matrix-valued functions of signature `f(obs, pred)`.
Note that such a custom function needs to return per-row losses, not their average.

- **perm_importance()**: H-statistics are often calculated for important features only. To support this workflow, we have added permutation importance regarding the most important loss functions. Multivariate losses can be studied individually or collapsed over dimensions. The importance of *feature groups* can be studied as well. Note that the API is different from the experimental `pd_importance()`, which is calculated from a "hstats" object, while `perm_importance()` acts on the fitted model.
- **perm_importance()**: H-statistics are often calculated for important features only. To support this workflow, we have added permutation importance with respect to the most important loss functions. Multivariate losses can be studied individually or collapsed over dimensions. The importance of *feature groups* can be studied as well (see the sketch below). Note that the API of `perm_importance()` is different from the experimental `pd_importance()`, which is calculated from a "hstats" object.
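To make these two additions concrete, here is a minimal sketch (not part of this commit). It reuses the ranger/iris setup from the README examples; passing feature groups as a list in `v` is an assumption based on the wording above, not something shown in this diff.

```r
library(ranger)
library(hstats)

set.seed(1)
fit <- ranger(Sepal.Length ~ ., data = iris)

# Custom loss with signature f(obs, pred); must return per-row losses, not their mean
mae <- function(obs, pred) abs(obs - pred)
average_loss(fit, X = iris, y = iris$Sepal.Length, loss = mae)

# Permutation importance with respect to the same loss
perm_importance(fit, v = colnames(iris)[-1], X = iris, y = iris$Sepal.Length, loss = mae)

# Feature groups (assumed interface: groups passed as a list in `v`)
perm_importance(
  fit,
  v = list(Petal = c("Petal.Length", "Petal.Width"), Sepal = "Sepal.Width"),
  X = iris, y = iris$Sepal.Length, loss = mae
)
```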

## Minor improvements
## Minor changes

- `summary.hstats()` now returns an object of class "summary_hstats" with its own `print()` method. This way, `su <- summary(...)` can be used without printing anything to the console (see the sketch below this list).
- The output of `summary.hstats()` is printed slightly more compactly.
- `plot.hstats()` has received a `rotate_x = FALSE` argument for rotating the x axis labels by 45 degrees.
- `plot.hstats()` and `summary.hstats()` have received explicit arguments `normalize`, `squared`, `sort`, and `eps` instead of having them passed via `...`.
- `plot.hstats()` now passes `...` to `geom_bar()`.
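A minimal sketch of the new summary and plot behavior (not part of this commit), again assuming the ranger/iris setup used in the README examples; the `alpha` value is only there to illustrate passing `...` to `geom_bar()`:

```r
library(ranger)
library(hstats)

set.seed(1)
fit <- ranger(Species ~ ., data = iris, probability = TRUE)
s <- hstats(fit, v = colnames(iris)[-5], X = iris)

su <- summary(s)  # a "summary_hstats" object; nothing is printed yet
su                # printed via print.summary_hstats()

# Explicit arguments, rotated x labels, and `...` forwarded to geom_bar()
plot(s, normalize = FALSE, squared = FALSE, rotate_x = TRUE, alpha = 0.5)
```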
48 changes: 27 additions & 21 deletions R/hstats.R
@@ -307,17 +307,18 @@ print.hstats <- function(x, ...) {
#'
#' @inheritParams h2_overall
#' @param ... Currently not used.
#' @returns A named list of statistics.
#' @returns
#' An object of class "summary_hstats" representing a named list with statistics.
#' @export
#' @seealso See [hstats()] for examples.
summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 6L, eps = 1e-8, ...) {
top_m = Inf, eps = 1e-8, ...) {
args <- list(
object = object,
normalize = normalize,
squared = squared,
sort = sort,
top_m = Inf,
top_m = top_m,
eps = eps,
plot = FALSE
)
@@ -327,23 +328,35 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE
h2_pairwise = do.call(h2_pairwise, args),
h2_threeway = do.call(h2_threeway, args)
)
out <- out[sapply(out, Negate(is.null))]

addon <- "(only for features with strong overall interactions)"
class(out) <- "summary_hstats"
out
}

#' Print Method
#'
#' Print method for object of class "summary_hstats".
#'
#' @param x An object of class "summary_hstats".
#' @param ... Further arguments passed from other methods.
#' @returns Invisibly, the input is returned.
#' @export
#' @seealso See [hstats()] for examples.
print.summary_hstats <- function(x, ...) {
addon <- "(for features with strong overall interactions)"
txt <- c(
h2 = "Proportion of prediction variability unexplained by main effects of v",
h2_overall = "Strongest overall interactions",
h2_pairwise = paste0("Strongest relative pairwise interactions\n", addon),
h2_threeway = paste0("Strongest relative three-way interactions\n", addon)
)

for (nm in names(out)) {
for (nm in names(Filter(Negate(is.null), x))) {
cat(txt[[nm]])
cat("\n")
print(utils::head(out[[nm]], top_m))
print(utils::head(drop(x[[nm]])))
cat("\n")
}
invisible(out)
invisible(x)
}

#' Plot Method for "hstats" Object
@@ -365,19 +378,12 @@ summary.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE
plot.hstats <- function(x, which = 1:2, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, eps = 1e-8, fill = "#2b51a1",
facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) {
su <- summary(
x, normalize = normalize, squared = squared, sort = sort, top_m = top_m, eps = eps
)
nms <- c("h2_overall", "h2_pairwise", "h2_threeway")
ids <- c("Overall", "Pairwise", "Threeway")
funs <- c(h2_overall, h2_pairwise, h2_threeway)
dat <- list()
i <- 1L
for (f in funs) {
if (i %in% which)
dat[[i]] <- mat2df(
f(x, normalize = normalize, squared = squared,
sort = sort, top_m = top_m, eps = eps, plot = FALSE),
id = ids[i]
)
i <- i + 1L
}
dat <- lapply(which, FUN = function(j) mat2df(su[[nms[j]]], id = ids[j]))
dat <- do.call(rbind, dat)
p <- ggplot2::ggplot(dat, ggplot2::aes(x = value_, y = variable_)) +
ggplot2::ylab(ggplot2::element_blank()) +
30 changes: 23 additions & 7 deletions README.md
@@ -147,7 +147,6 @@ plot(s, which = 1:3, normalize = F, squared = F, facet_scales = "free_y", ncol =

![](man/figures/hstats3.svg)


### Describe interaction

Let's study different plots to understand *what* the strong interaction between distance to the ocean and age looks like. We will check the following three visualizations.
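A rough sketch of such plots (hedged: the feature names `"age"` and `"log_ocean"` stand in for the age and distance-to-ocean variables and are not taken from this diff; `fit` and `X_valid` are the objects from the modeling step above):

```r
# Stratified PDP of age, grouped by distance to the ocean
# (assuming partial_dep() accepts BY like ice() does)
partial_dep(fit, v = "age", X = X_valid, BY = "log_ocean") |> plot()

# Two-dimensional PDP of both features (assuming `v` accepts a pair of features)
partial_dep(fit, v = c("age", "log_ocean"), X = X_valid) |> plot()

# Centered ICE curves of age, grouped by distance to the ocean
ice(fit, v = "age", X = X_valid, BY = "log_ocean") |> plot(center = TRUE)
```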
@@ -200,7 +199,7 @@ pd_importance(s) +
set.seed(10)
imp <- perm_importance(fit, v = x, X = X_valid, y = y_valid)
plot(imp) +
ggtitle("Permutation importance with standard errors")
ggtitle("Permutation importance + standard errors")
```

![](man/figures/importance.svg)
@@ -221,7 +220,7 @@ library(hstats)
set.seed(1)

fit <- ranger(Sepal.Length ~ ., data = iris)
ex <- explain(fit, data = iris[-1], y = iris[, 1])
ex <- DALEX::explain(fit, data = iris[-1], y = iris[, 1])

s <- hstats(ex)
s # Non-additivity index 0.054
@@ -254,14 +253,15 @@ library(ranger)
library(ggplot2)
library(hstats)

fit <- ranger(Species ~ ., data = iris, probability = TRUE, seed = 1)
set.seed(1)
fit <- ranger(Species ~ ., data = iris, probability = TRUE)
average_loss(fit, X = iris, y = iris$Species, loss = "mlogloss") # 0.054

s <- hstats(fit, v = colnames(iris)[-5], X = iris)
s
# Proportion of prediction variability unexplained by main effects of v:
# setosa versicolor virginica
# 0.002705945 0.065629375 0.046742035
# 0.001547791 0.064550141 0.049758237

plot(s, normalize = FALSE, squared = FALSE) +
ggtitle("Unnormalized statistics") +
@@ -270,6 +270,14 @@ plot(s, normalize = FALSE, squared = FALSE) +
ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |>
plot(center = TRUE) +
ggtitle("Centered ICE plots")

# Permutation importance
perm_importance(
fit, v = colnames(iris)[-5], X = iris, y = iris$Species, loss = "mlogloss"
)

# Petal.Length Petal.Width Sepal.Length Sepal.Width
# 0.50941613 0.49187688 0.05669978 0.00950009
```

![](man/figures/multivariate.svg)
@@ -302,6 +310,13 @@ fit <- iris_wf %>%
s <- hstats(fit, v = colnames(iris[-1]), X = iris)
s # 0 -> no interactions
plot(partial_dep(fit, v = "Petal.Width", X = iris))

imp <- perm_importance(fit, v = colnames(iris[-1]), X = iris, y = iris$Sepal.Length)
imp
# Petal.Length Species Petal.Width Sepal.Width
# 4.44682039 0.34064367 0.10195946 0.09520902

plot(imp)
```

### caret
@@ -321,6 +336,7 @@ fit <- train(
h2(hstats(fit, v = colnames(iris[-1]), X = iris)) # 0

plot(ice(fit, v = "Petal.Width", X = iris), center = TRUE)
plot(perm_importance(fit, v = colnames(iris[-1]), X = iris, y = iris$Sepal.Length))
```

### mlr3
@@ -332,10 +348,10 @@

# Probabilistic classification
task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
fit_rf <- lrn("classif.ranger", predict_type = "prob", num.trees = 50)
fit_rf <- lrn("classif.ranger", predict_type = "prob")
fit_rf$train(task_iris)
v <- colnames(iris[-5])
s <- hstats(fit_rf, v = v, X = iris)
s <- hstats(fit_rf, v = v, X = iris, threeway_m = 0)
plot(s)

# Permutation importance
