Add plot option d2_geom = "line" to PDP
mayer79 committed Oct 28, 2023
1 parent 9fcab31 commit 2d1d6a5
Showing 10 changed files with 1,583 additions and 32 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
@@ -1,8 +1,12 @@
# hstats 1.0.1

## Enhancements

- The plot method of a two-dimensional PDP has received the option `d2_geom = "line"`. Instead of a heatmap of the two features, one of the features is moved to color grouping. This might give a better impression of where the interaction happens. Combined with `swap_dim = TRUE`, you can swap the role of the two `v` variables without recalculating anything. The idea was proposed by [Roel Verbelen](https://github.com/RoelVerbelen) in [issue #91](https://github.com/mayer79/hstats/issues/91), see also [issue #94](https://github.com/mayer79/hstats/issues/94).
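
To illustrate why swapping the two variables needs no recalculation, here is a minimal base R sketch. It does not use {hstats}: the grid values are made up, and `plot_pd_lines()` is a hypothetical helper introduced only for this illustration. The point is that a 2D PD grid already contains everything both line plots need, so swapping roles amounts to transposing the grid.

```r
# Hypothetical 2D partial dependence grid (made-up numbers, not hstats output)
age <- c(10, 20, 30, 40)
dist <- c(1, 5, 10)
pd_grid <- outer(age, dist, function(a, d) 0.1 * a + 0.5 * d + 0.02 * a * d)
dimnames(pd_grid) <- list(age = age, dist = dist)

# Hypothetical helper: rows go to the x axis, columns become colored lines.
# With swap = TRUE, the same values are reused with the roles exchanged.
plot_pd_lines <- function(m, swap = FALSE) {
  if (swap) {
    m <- t(m)  # no new predictions needed, only a transpose
  }
  x <- as.numeric(rownames(m))
  cols <- seq_len(ncol(m))
  matplot(
    x, m, type = "l", lty = 1, col = cols,
    xlab = names(dimnames(m))[1L], ylab = "PD"
  )
  legend(
    "topleft", legend = colnames(m), col = cols, lty = 1,
    title = names(dimnames(m))[2L]
  )
  invisible(m)
}

plot_pd_lines(pd_grid)               # x axis: age, one line per dist value
plot_pd_lines(pd_grid, swap = TRUE)  # x axis: dist, one line per age value
```

Non-parallel lines in either plot indicate an interaction between the two features.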

## Bug fixes

- Using `BY` and `w` via column names would fail for tibbles. This problem was described in [#92](https://github.com/mayer79/hstats/issues/92) by @RoelVerbelen. Thx!
- Using `BY` and `w` via column names would fail for tibbles. This problem was described in [#92](https://github.com/mayer79/hstats/issues/92) by [Roel Verbelen](https://github.com/RoelVerbelen). Thx!

## Other changes

62 changes: 42 additions & 20 deletions R/partial_dep.R
@@ -66,19 +66,25 @@
#' plot(pd)
#'
#' # Multivariable input
#' v <- c("Species", "Petal.Width")
#' v <- c("Species", "Petal.Length")
#' pd <- partial_dep(fit, v = v, X = iris, grid_size = 100L)
#' plot(pd, rotate_x = TRUE)
#' plot(pd, d2_geom = "line") # often better to read
#'
#' # With grouping
#' pd <- partial_dep(fit, v = v, X = iris, grid_size = 100L, BY = "Petal.Length")
#' pd <- partial_dep(fit, v = v, X = iris, grid_size = 100L, BY = "Petal.Width")
#' plot(pd, rotate_x = TRUE)
#' plot(pd, rotate_x = TRUE, d2_geom = "line")
#' plot(pd, rotate_x = TRUE, d2_geom = "line", swap_dim = TRUE)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' pd <- partial_dep(fit, v = "Petal.Width", X = iris, BY = "Species")
#' plot(pd, show_points = FALSE)
#' plot(partial_dep(fit, v = c("Species", "Petal.Width"), X = iris), rotate_x = TRUE)
#' pd <- partial_dep(fit, v = c("Species", "Petal.Width"), X = iris)
#' plot(pd, rotate_x = TRUE)
#' plot(pd, d2_geom = "line", rotate_x = TRUE)
#' plot(pd, d2_geom = "line", rotate_x = TRUE, swap_dim = TRUE)
#'
#' # Multivariate, multivariable, and BY (no plot available)
#' pd <- partial_dep(
@@ -286,10 +292,14 @@ print.partial_dep <- function(x, n = 3L, ...) {
#' @param color Color of lines and points (in case there is no color/fill aesthetic).
#' The default equals the global option `hstats.color = "#3b528b"`.
#' To change the global option, use `options(hstats.color = new value)`.
#' @param swap_dim Switches the role of grouping and facetting (default is `FALSE`).
#' Exception: For the 2D PDP with `d2_geom = "line"`, it swaps the role of the two
#' variables in `v`.
#' @param show_points Logical flag indicating whether to show points (default) or not.
#' No effect for 2D PDPs.
#' @param d2_geom The geometry used for 2D PDPs, by default "tile". The other option is
#' "point", which is useful, e.g., when the grid represents spatial points.
#' @param d2_geom The geometry used for 2D PDPs, by default "tile". Option "point"
#' is useful, e.g., when the grid represents spatial points. Option "line" produces
#' lines grouped by the second variable.
#' @param ... Arguments passed to geometries.
#' @inheritParams plot.hstats_matrix
#' @export
@@ -301,14 +311,15 @@ plot.partial_dep <- function(x,
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "fixed",
rotate_x = FALSE, show_points = TRUE,
d2_geom = c("tile", "point"), ...) {
d2_geom = c("tile", "point", "line"), ...) {
d2_geom <- match.arg(d2_geom)
v <- x[["v"]]
by_name <- x[["by_name"]]
K <- x[["K"]]
if (length(v) > 2L) {
stop("Maximal two features can be plotted.")
}
if ((K > 1L) + (!is.null(by_name)) + length(v) > 3L) {
if (((K > 1L) + (!is.null(by_name)) + length(v)) > 3L) {
stop("No plot implemented for this case.")
}
if (is.null(viridis_args)) {
@@ -317,16 +328,31 @@ plot.partial_dep <- function(x,

data <- with(x, poor_man_stack(data, to_stack = pred_names))

wrp <- NULL
if (length(v) == 1L) {
if (length(v) == 2L && (K > 1L || !is.null(by_name))) { # Only one is possible
wrp <- if (K > 1L) "varying_" else by_name
} else {
wrp <- NULL
}
if (length(v) == 1L || d2_geom == "line") {
# Line plots
grp <- if (is.null(by_name) && K > 1L) "varying_" else by_name # can be NULL
wrp <- if (!is.null(by_name) && K > 1L) "varying_"
if (swap_dim) {
tmp <- grp
grp <- wrp
wrp <- tmp

# Determine the role of x axis, color axis and facetting
if (length(v) == 1L) {
grp <- if (is.null(by_name) && K > 1L) "varying_" else by_name # can be NULL
wrp <- if (!is.null(by_name) && K > 1L) "varying_"
if (swap_dim) {
tmp <- grp
grp <- wrp
wrp <- tmp
}
} else { # length(v) == 2
if (swap_dim) {
v <- rev(v)
}
grp <- v[2L]
v <- v[1L]
}

p <- ggplot2::ggplot(data, ggplot2::aes(x = .data[[v]], y = value_)) +
ggplot2::labs(x = v, y = "PD")

@@ -352,11 +378,7 @@ plot.partial_dep <- function(x,
}
}
} else if (length(v) == 2L) {
# Heat maps
d2_geom <- match.arg(d2_geom)
if (K > 1L || !is.null(by_name)) { # Only one is possible
wrp <- if (K > 1L) "varying_" else by_name
}
# Heat maps ("tile" or "point", "line" has been treated above)
p <- ggplot2::ggplot(data, ggplot2::aes(x = .data[[v[1L]]], y = .data[[v[2L]]]))
if (d2_geom == "tile") {
p <- p + ggplot2::geom_tile(ggplot2::aes(fill = value_), ...) +
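
As a reading aid for the hunk above, the role assignment (x axis, color, facets) can be paraphrased as a standalone function. This is a sketch only: `pd_plot_roles()` is a hypothetical name that is not part of {hstats}, it covers the line-plot branch (one feature, or two features with `d2_geom = "line"`), and it assumes the checks on `length(v)` and on unsupported combinations have already passed.

```r
# Hypothetical helper mirroring the role logic of the diff (not hstats API).
# v:        one or two feature names
# by_name:  name of the BY variable, or NULL
# K:        number of response columns ("varying_" indexes them after stacking)
# swap_dim: see the roxygen text of `swap_dim` in the diff
pd_plot_roles <- function(v, by_name = NULL, K = 1L, swap_dim = FALSE) {
  stopifnot(length(v) %in% 1:2)
  if (length(v) == 1L) {
    # One feature: color shows BY (or the responses), facets take the rest
    grp <- if (is.null(by_name) && K > 1L) "varying_" else by_name
    wrp <- if (!is.null(by_name) && K > 1L) "varying_"
    if (swap_dim) {
      tmp <- grp
      grp <- wrp
      wrp <- tmp
    }
    x <- v
  } else {
    # Two features: the second one moves to color, BY/responses go to facets
    wrp <- if (K > 1L) "varying_" else by_name
    if (swap_dim) {
      v <- rev(v)
    }
    x <- v[1L]
    grp <- v[2L]
  }
  list(x = x, color = grp, facet = wrp)
}

str(pd_plot_roles(c("Species", "Petal.Length"), by_name = "Petal.Width"))
str(pd_plot_roles(c("Species", "Petal.Length"), by_name = "Petal.Width", swap_dim = TRUE))
```

Note the asymmetry: for a single feature, `swap_dim` exchanges color and facets, whereas for two features it exchanges the x axis and color while the facets stay put.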
6 changes: 4 additions & 2 deletions README.md
@@ -146,7 +146,7 @@ Note: {hstats} can crunch **three-way** interaction statistics $H^2_{jkl}$ as we
Let's study different plots to understand *what* the strong interaction between distance to the ocean and age looks like. We will check the following three visualizations.

1. Stratified PDP
2. Two-dimensional PDP
2. Two-dimensional PDP (once as heatmap, once by representing the second variable on the color scale)
3. Centered ICE plot with colors

They all reveal a substantial interaction between the two variables in the sense that the age effect gets weaker the closer to the ocean. Note that numeric `BY` features are automatically binned into quartile groups.
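
What "binned into quartile groups" means can be sketched in base R as follows. This is an illustration under assumptions, not the package's internal code: `bin_into_quartiles()` is a hypothetical helper, and the exact break and label conventions used by {hstats} may differ.

```r
# Hypothetical helper: split a numeric vector into (up to) four groups
# at its empirical quartiles. Not the hstats implementation.
bin_into_quartiles <- function(z) {
  breaks <- unique(quantile(z, probs = seq(0, 1, by = 0.25), na.rm = TRUE))
  # include.lowest = TRUE keeps the minimum inside the first bin
  cut(z, breaks = breaks, include.lowest = TRUE)
}

set.seed(1)
log_ocean <- rnorm(1000)  # stand-in values, not the README's data
table(bin_into_quartiles(log_ocean))
```

With heavily tied data, some quartiles coincide and `unique()` leaves fewer than four groups; a constant vector has a single break and would make `cut()` fail, so a real implementation needs a fallback for that case.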
@@ -160,9 +160,11 @@ plot(partial_dep(fit, v = "age", X = X_train, BY = "log_ocean"), show_points = F
```r
pd <- partial_dep(fit, v = c("age", "log_ocean"), X = X_train, grid_size = 1000)
plot(pd)
plot(pd, d2_geom = "line", show_points = FALSE)
```

![](man/figures/pdp_2d.png)
![](man/figures/pdp_2d.svg)
![](man/figures/pdp_2d_line.svg)

```r
ic <- ice(fit, v = "age", X = X_train, BY = "log_ocean")