Skip to content

Commit

Permalink
Merge pull request #22 from kapsner/master
Browse files Browse the repository at this point in the history
Added ranger_surv.unify for random survival forests with {ranger}
  • Loading branch information
pbiecek authored Jan 15, 2023
2 parents 8049a53 + 35eb9c7 commit 7d28afe
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 12 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ License: GPL-3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.2
RoxygenNote: 7.2.2
LinkingTo:
Rcpp
Imports:
Expand All @@ -40,6 +40,7 @@ Suggests:
catboost (>= 0.22),
jsonlite,
testthat,
scales
scales,
survival
URL: https://github.com/ModelOriented/treeshap
BugReports: https://github.com/ModelOriented/treeshap/issues
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export(plot_feature_importance)
export(plot_interaction)
export(randomForest.unify)
export(ranger.unify)
export(ranger_surv.unify)
export(set_reference_dataset)
export(theme_drwhy)
export(theme_drwhy_vertical)
Expand Down
27 changes: 17 additions & 10 deletions R/unify_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,35 +42,42 @@ ranger.unify <- function(rf_model, data) {
}
n <- rf_model$num.trees
x <- lapply(1:n, function(tree) {
tree_data <- as.data.table(ranger::treeInfo(rf_model, tree = tree))
tree_data <- data.table::as.data.table(ranger::treeInfo(rf_model, tree = tree))
tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName", "splitval", "prediction")]
})
return(ranger_unify.common(x = x, n = n, data = data))
}


ranger_unify.common <- function(x, n, data) {
times_vec <- sapply(x, nrow)
y <- rbindlist(x)
y[, Tree := rep(0:(n - 1), times = times_vec)]
setnames(y, c("Node", "Yes", "No", "Feature", "Split", "Prediction", "Tree"))
y[, Feature := as.character(Feature)]
y <- data.table::rbindlist(x)
y[, ("Tree") := rep(0:(n - 1), times = times_vec)]
data.table::setnames(y, c("Node", "Yes", "No", "Feature", "Split", "Prediction", "Tree"))
y[, ("Feature") := as.character(get("Feature"))]
y[y$Yes < 0, "Yes"] <- NA
y[y$No < 0, "No"] <- NA
y[, Missing := NA]
y[, ("Missing") := NA]
y$Cover <- 0
y$Decision.type <- factor(x = rep("<=", times = nrow(y)), levels = c("<=", "<"))
y[is.na(Feature), Decision.type := NA]
y[is.na(get("Feature")), ("Decision.type") := NA]

ID <- paste0(y$Node, "-", y$Tree)
y$Yes <- match(paste0(y$Yes, "-", y$Tree), ID)
y$No <- match(paste0(y$No, "-", y$Tree), ID)

# Here we lose "Quality" information
y[!is.na(Feature), Prediction := NA]
y[!is.na(get("Feature")), ("Prediction") := NA]

# treeSHAP assumes, that [prediction = sum of predictions of the trees]
# in random forest [prediction = mean of predictions of the trees]
# so here we correct it by adjusting leaf prediction values
y[is.na(Feature), Prediction := Prediction / n]
y[is.na(get("Feature")), ("Prediction") := I(get("Prediction") / n)]


setcolorder(y, c("Tree", "Node", "Feature", "Decision.type", "Split", "Yes", "No", "Missing", "Prediction", "Cover"))
data.table::setcolorder(
y, c("Tree", "Node", "Feature", "Decision.type", "Split",
"Yes", "No", "Missing", "Prediction", "Cover"))

ret <- list(model = as.data.frame(y), data = as.data.frame(data))
class(ret) <- "model_unified"
Expand Down
89 changes: 89 additions & 0 deletions R/unify_ranger_surv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#' Unify ranger survival model
#'
#' Convert your ranger model into a standarised representation.
#' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function.
#'
#' @param rf_model An object of \code{ranger} class. At the moment, models built on data with categorical features
#' are not supported - please encode them before training.
#' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.
#'
#' @return a unified model representation - a \code{\link{model_unified.object}} object
#'
#' @import data.table
#'
#' @export
#'
#' @seealso
#' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}}
#'
#' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}}
#'
#' \code{\link{catboost.unify}} for \code{\link[catboost:catboost.train]{Catboost models}}
#'
#' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}}
#'
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#'
#' @examples
#'
#' library(ranger)
#' data_colon <- data.table::data.table(survival::colon)
#' data_colon <- na.omit(data_colon[get("etype") == 2, ])
#' surv_cols <- c("status", "time", "rx")
#'
#' feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]
#'
#' train_x <- model.matrix(
#' ~ -1 + .,
#' data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
#' )
#' train_y <- survival::Surv(
#' event = (data_colon[, get("status")] |>
#' as.character() |>
#' as.integer()),
#' time = data_colon[, get("time")],
#' type = "right"
#' )
#'
#' rf <- ranger::ranger(
#' x = train_x,
#' y = train_y,
#' data = data_colon,
#' max.depth = 10,
#' num.trees = 10
#' )
#' unified_model <- ranger_surv.unify(rf, train_x)
#' shaps <- treeshap(unified_model, train_x[1:2,])
#'
ranger_surv.unify <- function(rf_model, data) {
if (!"ranger" %in% class(rf_model)) {
stop("Object rf_model was not of class \"ranger\"")
}
if (!"survival" %in% names(rf_model)) {
stop("Object rf_model is not a survival random forest.")
}
n <- rf_model$num.trees
x <- lapply(1:n, function(tree) {
tree_data <- data.table::as.data.table(ranger::treeInfo(rf_model,
tree = tree))

# first get number of columns
chf_node <- rf_model$forest$chf[[tree]]
nodes_chf_n <- ncol(do.call(rbind, chf_node))
nodes_prepare_chf_list <- lapply(
X = chf_node,
FUN = function(node) {
if (identical(node, numeric(0L))) {
rep(NA, nodes_chf_n)
} else {
node
}
}
)
nodes_chf <- do.call(rbind, nodes_prepare_chf_list)
tree_data$prediction <- rowSums(nodes_chf)
tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName",
"splitval", "prediction")]
})
return(ranger_unify.common(x = x, n = n, data = data))
}
64 changes: 64 additions & 0 deletions man/ranger_surv.unify.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

100 changes: 100 additions & 0 deletions tests/testthat/test_ranger_surv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
library(treeshap)

data_colon <- data.table::data.table(survival::colon)
data_colon <- na.omit(data_colon[get("etype") == 2, ])
surv_cols <- c("status", "time", "rx")

feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]

x <- model.matrix(
~ -1 + .,
data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)
y <- survival::Surv(
event = (data_colon[, get("status")] |>
as.character() |>
as.integer()),
time = data_colon[, get("time")],
type = "right"
)

ranger_num_model <- ranger::ranger(
x = x,
y = y,
data = data_colon,
max.depth = 10,
num.trees = 10
)


test_that('ranger_surv.unify creates an object of the appropriate class', {
expect_true(is.model_unified(ranger_surv.unify(ranger_num_model, x)))
})

test_that('ranger_surv.unify returns an object with correct attributes', {
unified_model <- ranger_surv.unify(ranger_num_model, x)

expect_equal(attr(unified_model, "missing_support"), FALSE)
expect_equal(attr(unified_model, "model"), "ranger")
})

test_that('the ranger_surv.unify function returns data frame with columns of appropriate column', {
unifier <- ranger_surv.unify(ranger_num_model, x)$model
expect_true(is.integer(unifier$Tree))
expect_true(is.integer(unifier$Node))
expect_true(is.character(unifier$Feature))
expect_true(is.factor(unifier$Decision.type))
expect_true(is.numeric(unifier$Split))
expect_true(is.integer(unifier$Yes))
expect_true(is.integer(unifier$No))
expect_true(all(is.na(unifier$Missing)))
expect_true(is.numeric(unifier$Prediction))
expect_true(is.numeric(unifier$Cover))
})

test_that("ranger_surv: shap calculates without an error", {
unifier <- ranger_surv.unify(ranger_num_model, x)
expect_error(treeshap(unifier, x[1:3,], verbose = FALSE), NA)
})

test_that("ranger_surv: predictions from unified == original predictions", {
unifier <- ranger_surv.unify(ranger_num_model, x)
obs <- x[1:800, ]
surv_preds <- stats::predict(ranger_num_model, obs)
original <- rowSums(surv_preds$chf)
from_unified <- predict(unifier, obs)
expect_true(all(abs((from_unified - original) / original) < 10**(-14)))
})

test_that("ranger_surv: mean prediction calculated using predict == using covers", {
unifier <- ranger_surv.unify(ranger_num_model, x)

intercept_predict <- mean(predict(unifier, x))

ntrees <- sum(unifier$model$Node == 0)
leaves <- unifier$model[is.na(unifier$model$Feature), ]
intercept_covers <- sum(leaves$Prediction * leaves$Cover) / sum(leaves$Cover) * ntrees

#expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14)))
expect_equal(intercept_predict, intercept_covers)
})

test_that("ranger_surv: covers correctness", {
unifier <- ranger_surv.unify(ranger_num_model, x)

roots <- unifier$model[unifier$model$Node == 0, ]
expect_true(all(roots$Cover == nrow(x)))

internals <- unifier$model[!is.na(unifier$model$Feature), ]
yes_child_cover <- unifier$model[internals$Yes, ]$Cover
no_child_cover <- unifier$model[internals$No, ]$Cover
if (all(is.na(internals$Missing))) {
children_cover <- yes_child_cover + no_child_cover
} else {
missing_child_cover <- unifier$model[internals$Missing, ]$Cover
missing_child_cover[is.na(missing_child_cover)] <- 0
missing_child_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0
children_cover <- yes_child_cover + no_child_cover + missing_child_cover
}
expect_true(all(internals$Cover == children_cover))
})

0 comments on commit 7d28afe

Please sign in to comment.