Skip to content

Commit

Permalink
Merge pull request #66 from mayer79/different_pd
Browse files Browse the repository at this point in the history
WIP: Cleaner output API
  • Loading branch information
mayer79 authored Oct 3, 2023
2 parents fd42bac + ce1ddfa commit 9cc20fe
Show file tree
Hide file tree
Showing 45 changed files with 1,488 additions and 1,433 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hstats
Title: Interaction Statistics
Version: 1.0.0
Version: 0.4.0
Authors@R:
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"))
Description: Fast, model-agnostic implementation of different H-statistics
Expand Down
6 changes: 3 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ S3method(perm_importance,default)
S3method(perm_importance,explainer)
S3method(perm_importance,ranger)
S3method(plot,hstats)
S3method(plot,hstats_matrix)
S3method(plot,ice)
S3method(plot,partial_dep)
S3method(plot,perm_importance)
S3method(print,hstats)
S3method(print,hstats_matrix)
S3method(print,hstats_summary)
S3method(print,ice)
S3method(print,partial_dep)
S3method(print,perm_importance)
S3method(print,summary_hstats)
S3method(summary,hstats)
export(average_loss)
export(h2)
Expand Down
19 changes: 16 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
# hstats 1.0.0
# hstats 0.4.0

## Visible changes
This release comes with a cleaner output API. The numeric results are unchanged.

## Major changes

- `h2()`, `h2_overall()`, `h2_pairwise()`, `h2_threeway()`, `perm_importance()`, and `pd_importance()` now return an object of class "hstats_matrix" with a `print()` and `plot()` method. The values can be extracted via `$M`.
- Their argument `top_m` has been moved to `plot()`.
- `perm_importance()`: The `perms` argument has been renamed to `m_rep`. Since the output is now of class "hstats_matrix", the resulting importance values are stored as `$M`.
- All `print()`, `summary()`, and `plot()` methods have been revised.

## Minor changes

- plot.perm_importance() now represents importance values of *multi-output* models as stacked bars. Set `multi_output = "facets"` for the old behaviour.
- Plotting the result of `perm_importance()` on a multi-output model now produces a stacked barplot. Set `multi_output = "facets"` for the old behaviour.
- `H-squared`: The $H^2$ statistic stored in a "hstats" object is now a matrix with one row (it was a vector).
- `eps`: The clipping threshold of squared numerator statistics has been reduced from 1e-8 to 1e-10. It is now handled in `hstats()` instead of the statistic functions.
- `pd_importance()`: The "hstats" object now contains pre-calculated PD-based importance values in `$pd_importance`.
- `summary.hstats()` now returns an object of class "hstats_summary" instead of "summary_hstats".

# hstats 0.3.0

Expand Down
28 changes: 13 additions & 15 deletions R/H2.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#' Total Interaction Strength
#'
#' Proportion of prediction variability unexplained by main effects of `v`, see Details.
#' Proportion of prediction variability unexplained by main effects of `v`, see Details.
#' Use `plot()` to get a barplot.
#'
#' @details
#' If the model is additive in all features, then the (centered) prediction
#' function \eqn{F} equals the sum of the (centered) partial dependence
#' functions \eqn{F_j(x_j)}, i.e.,
Expand All @@ -25,8 +25,7 @@
#' A similar measure using accumulated local effects is discussed in Molnar (2020).
#'
#' @inheritParams h2_overall
#' @param ... Currently unused.
#' @returns Vector of total interaction strength (one value per prediction dimension).
#' @inherit h2_overall return
#' @export
#' @seealso [hstats()], [h2_overall()], [h2_pairwise()], [h2_threeway()]
#' @references
Expand Down Expand Up @@ -65,14 +64,14 @@ h2.default <- function(object, ...) {

#' @describeIn h2 Total interaction strength from "hstats" object.
#' @export
h2.hstats <- function(object, normalize = TRUE, squared = TRUE, eps = 1e-8, ...) {
postprocess(
num = object$h2$num,
denom = object$h2$denom,
h2.hstats <- function(object, normalize = TRUE, squared = TRUE, ...) {
get_hstats_matrix(
statistic = "h2",
object = object,
normalize = normalize,
squared = squared,
squared = squared,
sort = FALSE,
eps = eps
zero = TRUE
)
}

Expand All @@ -83,12 +82,11 @@ h2.hstats <- function(object, normalize = TRUE, squared = TRUE, eps = 1e-8, ...)
#'
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "f", "F_j", "w", and "mean_f2".
#' @param x A list containing the elements "f", "F_j", "w", "eps", and "mean_f2".
#' @returns A list with the numerator and denominator statistics.
h2_raw <- function(x) {
list(
num = with(x, wcolMeans((f - Reduce("+", F_j))^2, w = w)),
denom = x[["mean_f2"]]
)
num <- with(x, rbind(wcolMeans((f - Reduce("+", F_j))^2, w = w)))
num <- .zap_small(num, eps = x[["eps"]]) # Numeric precision
list(num = num, denom = x[["mean_f2"]])
}

49 changes: 24 additions & 25 deletions R/H2_overall.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#' Overall Interaction Strength
#'
#' Friedman and Popescu's statistic of overall interaction strength per
#' feature, see Details. Set `plot = TRUE` to plot the results as barplot.
#' feature, see Details. Use `plot()` to get a barplot.
#'
#' @details
#' The logic of Friedman and Popescu (2008) is as follows:
#' If there are no interactions involving feature \eqn{x_j}, we can decompose the
#' (centered) prediction function \eqn{F} into the sum of the (centered) partial
Expand Down Expand Up @@ -42,16 +41,19 @@
#' @param normalize Should statistics be normalized? Default is `TRUE`.
#' @param squared Should *squared* statistics be returned? Default is `TRUE`.
#' @param sort Should results be sorted? Default is `TRUE`.
#' (Multioutput is sorted by row means.)
#' @param top_m How many rows should be shown? (`Inf` to show all.)
#' (Multi-output is sorted by row means.)
#' @param zero Should rows with all 0 be shown? Default is `TRUE`.
#' @param eps Threshold below which numerator values are set to 0.
#' @param plot Should results be plotted as barplot? Default is `FALSE`.
#' @param fill Color of bar (only for univariate statistics).
#' @param ... Further parameters passed to `geom_bar()`.
#' @param ... Currently unused.
#' @returns
#' A matrix of statistics (one row per variable, one column per prediction dimension),
#' or a "ggplot" object (if `plot = TRUE`).
#' An object of class "hstats_matrix" containing these elements:
#' - `M`: Matrix of statistics (one column per prediction dimension), or `NULL`.
#' - `SE`: Matrix with standard errors of `M`, or `NULL`.
#' Multiply with `sqrt(m_rep)` to get *standard deviations* instead.
#' Currently supported only for [perm_importance()].
#' - `m_rep`: The number of repetitions behind standard errors `SE`, or `NULL`.
#' Currently supported only for [perm_importance()].
#' - `statistic`: Name of the function that generated the statistic.
#' - `description`: Description of the statistic.
#' @inherit hstats references
#' @seealso [hstats()], [h2()], [h2_pairwise()], [h2_threeway()]
#' @export
Expand All @@ -60,12 +62,12 @@
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[-1])
#' h2_overall(s)
#' h2_overall(s, plot = TRUE)
#' plot(h2_overall(s))
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' h2_overall(s, plot = TRUE, zero = FALSE)
#' plot(h2_overall(s, zero = FALSE))
h2_overall <- function(object, ...) {
# S3 generic: UseMethod() selects the "h2_overall" method matching the class
# of `object`; all arguments, including `...`, are passed on to that method.
UseMethod("h2_overall")
}
Expand All @@ -78,21 +80,16 @@ h2_overall.default <- function(object, ...) {

#' @describeIn h2_overall Overall interaction strength from "hstats" object.
#' @export
h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, zero = TRUE, eps = 1e-8,
plot = FALSE, fill = "#2b51a1", ...) {
s <- object$h2_overall
out <- postprocess(
num = s$num,
denom = s$denom,
h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
get_hstats_matrix(
statistic = "h2_overall",
object = object,
normalize = normalize,
squared = squared,
sort = sort,
top_m = top_m,
zero = zero,
eps = eps
sort = sort,
zero = zero
)
if (plot) plot_stat(out, fill = fill, ...) else out
}

# Helper function
Expand All @@ -105,12 +102,14 @@ h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = T
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "v", "K", "pred_names",
#' "f", "F_not_j", "F_j", "mean_f2", and "w".
#' "f", "F_not_j", "F_j", "mean_f2", "eps", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_overall_raw <- function(x) {
  # Container for the numerators; it is filled row by row below, one row per
  # entry of x[["v"]]. NOTE(review): its shape comes from init_numerator(),
  # which is not visible here - presumably features x prediction dimensions.
  numerator <- init_numerator(x, way = 1L)

  for (feature in x[["v"]]) {
    # Part of the prediction that is not reproduced by adding the effect of
    # `feature` (F_j) and the effect of all remaining features (F_not_j).
    unexplained <- x[["f"]] - x[["F_j"]][[feature]] - x[["F_not_j"]][[feature]]

    # Weighted column means of the squared remainder
    numerator[feature, ] <- wcolMeans(unexplained^2, w = x[["w"]])
  }

  list(
    # Numeric precision: .zap_small() is applied with the threshold x[["eps"]]
    num = .zap_small(numerator, eps = x[["eps"]]),
    denom = x[["mean_f2"]]
  )
}
45 changes: 19 additions & 26 deletions R/H2_pairwise.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#' Pairwise Interaction Strength
#'
#' Friedman and Popescu's statistic of pairwise interaction strength, see Details.
#' Set `plot = TRUE` to plot the results as barplot.
#' Use `plot()` to get a barplot.
#'
#' @details
#' Following Friedman and Popescu (2008), if there are no interaction effects between
#' features \eqn{x_j} and \eqn{x_k}, their two-dimensional (centered) partial dependence
#' function \eqn{F_{jk}} can be written as the sum of the (centered) univariate partial
Expand Down Expand Up @@ -46,10 +45,7 @@
#' rather for those features with *strongest overall interactions*.
#'
#' @inheritParams h2_overall
#' @returns
#' A matrix of statistics (one row per variable, one column per prediction dimension),
#' or a "ggplot" object (if `plot = TRUE`). If no pairwise
#' statistics have been calculated, the function returns `NULL`.
#' @inherit h2_overall return
#' @inherit hstats references
#' @export
#' @seealso [hstats()], [h2()], [h2_overall()], [h2_threeway()]
Expand All @@ -64,13 +60,15 @@
#' h2_pairwise(s, zero = FALSE) # Drop 0
#'
#' # Absolute measure as alternative
#' h2_pairwise(s, normalize = FALSE, squared = FALSE)
#' abs_h <- h2_pairwise(s, normalize = FALSE, squared = FALSE, zero = FALSE)
#' abs_h
#' abs_h$M
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' h2_pairwise(s, plot = TRUE)
#' h2_pairwise(s, zero = FALSE, plot = TRUE)
#' h2_pairwise(s)
#' plot(h2_pairwise(s))
h2_pairwise <- function(object, ...) {
# S3 generic: UseMethod() selects the "h2_pairwise" method matching the class
# of `object`; all arguments, including `...`, are passed on to that method.
UseMethod("h2_pairwise")
}
Expand All @@ -83,24 +81,16 @@ h2_pairwise.default <- function(object, ...) {

#' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object.
#' @export
h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, zero = TRUE, eps = 1e-8,
plot = FALSE, fill = "#2b51a1", ...) {
s <- object$h2_pairwise
if (is.null(s)) {
return(NULL)
}
out <- postprocess(
num = s$num,
denom = s$denom,
h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
get_hstats_matrix(
statistic = "h2_pairwise",
object = object,
normalize = normalize,
squared = squared,
sort = sort,
top_m = top_m,
zero = zero,
eps = eps
sort = sort,
zero = zero
)
if (plot) plot_stat(out, fill = fill, ...) else out
}

#' Raw H2 Pairwise
Expand All @@ -111,7 +101,7 @@ h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, sort =
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "combs2", "v_pairwise_0", "K", "pred_names",
#' "F_jk", "F_j", and "w".
#' "F_jk", "F_j", "eps", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_pairwise_raw <- function(x) {
num <- init_numerator(x, way = 2L)
Expand All @@ -122,10 +112,13 @@ h2_pairwise_raw <- function(x) {
if (!is.null(combs)) {
for (nm in names(combs)) {
z <- combs[[nm]]
num[nm, ] <- with(x, wcolMeans((F_jk[[nm]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w))
num[nm, ] <- with(
x, wcolMeans((F_jk[[nm]] - F_j[[z[1L]]] - F_j[[z[2L]]])^2, w = w)
)
denom[nm, ] <- with(x, wcolMeans(F_jk[[nm]]^2, w = w))
}
}
num <- .zap_small(num, eps = x[["eps"]]) # Numeric precision

list(num = num, denom = denom)
}
38 changes: 13 additions & 25 deletions R/H2_threeway.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#' Three-way Interaction Strength
#'
#' Friedman and Popescu's statistic of three-way interaction strength, see Details.
#' Set `plot = TRUE` to plot the results as barplot.
#' Use `plot()` to get a barplot.
#'
#' @details
#' Friedman and Popescu (2008) describe a test statistic to measure three-way
#' interactions: in case there are no three-way interactions between features
#' \eqn{x_j}, \eqn{x_k} and \eqn{x_l}, their (centered) three-dimensional partial
Expand Down Expand Up @@ -36,10 +35,7 @@
#' Similar remarks as for [h2_pairwise()] apply.
#'
#' @inheritParams h2_overall
#' @returns
#' A matrix of statistics (one row per variable, one column per prediction dimension),
#' or a "ggplot" object (if `plot = TRUE`). If no three-way
#' statistics have been calculated, the function returns `NULL`.
#' @inherit h2_overall return
#' @inherit hstats references
#' @export
#' @seealso [hstats()], [h2()], [h2_overall()], [h2_pairwise()]
Expand All @@ -53,9 +49,8 @@
#' fit <- lm(cbind(up = uptake, up2 = 2 * uptake) ~ Type * Treatment * conc, data = CO2)
#' s <- hstats(fit, X = CO2[2:4], verbose = FALSE)
#' h2_threeway(s)
#'
#' # Unnormalized H
#' h2_threeway(s, normalize = FALSE, squared = FALSE)
#' h2_threeway(s, normalize = FALSE, squared = FALSE) # Unnormalized H
#' plot(h2_threeway(s))
h2_threeway <- function(object, ...) {
# S3 generic: UseMethod() selects the "h2_threeway" method matching the class
# of `object`; all arguments, including `...`, are passed on to that method.
UseMethod("h2_threeway")
}
Expand All @@ -68,24 +63,16 @@ h2_threeway.default <- function(object, ...) {

#' @describeIn h2_threeway Three-way interaction strength from "hstats" object.
#' @export
h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort = TRUE,
top_m = 15L, zero = TRUE, eps = 1e-8,
plot = FALSE, fill = "#2b51a1", ...) {
s <- object$h2_threeway
if (is.null(s)) {
return(NULL)
}
out <- postprocess(
num = s$num,
denom = s$denom,
h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
get_hstats_matrix(
statistic = "h2_threeway",
object = object,
normalize = normalize,
squared = squared,
sort = sort,
top_m = top_m,
zero = zero,
eps = eps
sort = sort,
zero = zero
)
if (plot) plot_stat(out, fill = fill, ...) else out
}

#' Raw H2 Threeway
Expand All @@ -96,7 +83,7 @@ h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, sort =
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "combs3", "v_threeway_0", "K", "pred_names",
#' "F_jkl", "F_jk", "F_j", and "w".
#' "F_jkl", "F_jk", "F_j", "eps", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_threeway_raw <- function(x) {
num <- init_numerator(x, way = 3L)
Expand All @@ -116,6 +103,7 @@ h2_threeway_raw <- function(x) {
denom[nm, ] <- with(x, wcolMeans(F_jkl[[nm]]^2, w = w))
}
}
num <- .zap_small(num, eps = x[["eps"]]) # Numeric precision

list(num = num, denom = denom)
}
Loading

0 comments on commit 9cc20fe

Please sign in to comment.