-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from kapsner/master
Added ranger_surv.unify for random survival forests with {ranger}
- Loading branch information
Showing
6 changed files
with
274 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#' Unify ranger survival model
#'
#' Convert your ranger survival model into a standardised representation.
#' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function.
#'
#' The prediction attached to each node is the row sum of that node's
#' cumulative hazard function (CHF) values as stored in
#' \code{rf_model$forest$chf}. Nodes for which no CHF is stored receive
#' \code{NA}.
#'
#' @param rf_model An object of \code{ranger} class fitted as a survival forest
#' (it must contain a \code{survival} element). At the moment, models built on data with categorical features
#' are not supported - please encode them before training.
#' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.
#'
#' @return a unified model representation - a \code{\link{model_unified.object}} object
#'
#' @import data.table
#'
#' @export
#'
#' @seealso
#' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}}
#'
#' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}}
#'
#' \code{\link{catboost.unify}} for \code{\link[catboost:catboost.train]{Catboost models}}
#'
#' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}}
#'
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#'
#' @examples
#'
#' library(ranger)
#' data_colon <- data.table::data.table(survival::colon)
#' data_colon <- na.omit(data_colon[get("etype") == 2, ])
#' surv_cols <- c("status", "time", "rx")
#'
#' feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]
#'
#' train_x <- model.matrix(
#'   ~ -1 + .,
#'   data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
#' )
#' train_y <- survival::Surv(
#'   event = (data_colon[, get("status")] |>
#'              as.character() |>
#'              as.integer()),
#'   time = data_colon[, get("time")],
#'   type = "right"
#' )
#'
#' rf <- ranger::ranger(
#'   x = train_x,
#'   y = train_y,
#'   data = data_colon,
#'   max.depth = 10,
#'   num.trees = 10
#' )
#' unified_model <- ranger_surv.unify(rf, train_x)
#' shaps <- treeshap(unified_model, train_x[1:2,])
#'
ranger_surv.unify <- function(rf_model, data) {
  # inherits() is the idiomatic S3 class test and handles multi-class objects.
  if (!inherits(rf_model, "ranger")) {
    stop("Object rf_model was not of class \"ranger\"")
  }
  if (!"survival" %in% names(rf_model)) {
    stop("Object rf_model is not a survival random forest.")
  }
  n <- rf_model$num.trees
  # seq_len() stays empty for n == 0, whereas 1:n would yield c(1, 0).
  x <- lapply(seq_len(n), function(tree) {
    tree_data <- data.table::as.data.table(
      ranger::treeInfo(rf_model, tree = tree)
    )

    # One CHF vector per node of this tree; some entries are empty
    # (presumably the non-terminal nodes -- confirm against ranger).
    chf_node <- rf_model$forest$chf[[tree]]
    # Width of the CHF matrix, taken from the vector lengths directly
    # instead of binding a throwaway matrix just to read ncol().
    chf_length <- max(lengths(chf_node))
    # Pad empty entries with NA so every node contributes exactly one row.
    nodes_chf <- do.call(
      rbind,
      lapply(
        X = chf_node,
        FUN = function(node) {
          if (length(node) == 0L) {
            rep(NA_real_, chf_length)
          } else {
            node
          }
        }
      )
    )
    # Rows that were padded with NA sum to NA.
    tree_data$prediction <- rowSums(nodes_chf)
    tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName",
                  "splitval", "prediction")]
  })
  ranger_unify.common(x = x, n = n, data = data)
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
library(treeshap)

# Shared fixture: colon cancer data restricted to etype == 2 with incomplete
# rows dropped, a numeric design matrix `x`, a right-censored survival
# response `y`, and a small ranger forest fitted on them.
data_colon <- data.table::data.table(survival::colon)
data_colon <- na.omit(data_colon[get("etype") == 2, ])
surv_cols <- c("status", "time", "rx")

# Columns 3 .. (last - 1) of the filtered data
feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]

# Design matrix without intercept; status and time are excluded from the
# predictors
x <- model.matrix(
  ~ -1 + .,
  data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)
y <- survival::Surv(
  event = as.integer(as.character(data_colon[, get("status")])),
  time = data_colon[, get("time")],
  type = "right"
)

ranger_num_model <- ranger::ranger(
  x = x,
  y = y,
  data = data_colon,
  max.depth = 10,
  num.trees = 10
)
|
||
|
||
test_that("ranger_surv.unify creates an object of the appropriate class", {
  # The result must be recognised by is.model_unified()
  unified <- ranger_surv.unify(ranger_num_model, x)
  expect_true(is.model_unified(unified))
})
|
||
test_that("ranger_surv.unify returns an object with correct attributes", {
  # The unified object carries "model" and "missing_support" attributes
  unified <- ranger_surv.unify(ranger_num_model, x)

  expect_equal(attr(unified, "model"), "ranger")
  expect_equal(attr(unified, "missing_support"), FALSE)
})
|
||
test_that("the ranger_surv.unify function returns data frame with columns of appropriate type", {
  # Check the storage type of every column of the unified model table
  unifier <- ranger_surv.unify(ranger_num_model, x)$model
  expect_true(is.integer(unifier$Tree))
  expect_true(is.integer(unifier$Node))
  expect_true(is.character(unifier$Feature))
  expect_true(is.factor(unifier$Decision.type))
  expect_true(is.numeric(unifier$Split))
  expect_true(is.integer(unifier$Yes))
  expect_true(is.integer(unifier$No))
  # No node may carry a Missing entry
  expect_true(all(is.na(unifier$Missing)))
  expect_true(is.numeric(unifier$Prediction))
  expect_true(is.numeric(unifier$Cover))
})
|
||
test_that("ranger_surv: shap calculates without an error", {
  # Smoke test on the first three observations; NA means "expect no error"
  unified <- ranger_surv.unify(ranger_num_model, x)
  first_rows <- x[1:3, ]
  expect_error(treeshap(unified, first_rows, verbose = FALSE), NA)
})
|
||
test_that("ranger_surv: predictions from unified == original predictions", {
  unifier <- ranger_surv.unify(ranger_num_model, x)
  # Use at most the first 800 rows so the index never runs past the fixture;
  # drop = FALSE keeps `obs` a matrix whatever the number of rows.
  obs <- x[seq_len(min(800L, nrow(x))), , drop = FALSE]
  surv_preds <- stats::predict(ranger_num_model, obs)
  # Reference value: cumulative hazard summed per observation
  original <- rowSums(surv_preds$chf)
  from_unified <- predict(unifier, obs)
  # Element-wise relative error must stay below 1e-14
  expect_true(all(abs((from_unified - original) / original) < 10^(-14)))
})
|
||
test_that("ranger_surv: mean prediction calculated using predict == using covers", {
  unified <- ranger_surv.unify(ranger_num_model, x)
  model_table <- unified$model

  # Intercept as the plain average of the unified model's predictions
  intercept_predict <- mean(predict(unified, x))

  # Intercept rebuilt from leaves: cover-weighted mean leaf prediction,
  # scaled by the number of trees (one Node == 0 row per tree)
  ntrees <- sum(model_table$Node == 0)
  leaves <- model_table[is.na(model_table$Feature), ]
  weighted_total <- sum(leaves$Prediction * leaves$Cover)
  intercept_covers <- weighted_total / sum(leaves$Cover) * ntrees

  #expect_true(all(abs((intercept_predict - intercept_covers) / intercept_predict) < 10**(-14)))
  expect_equal(intercept_predict, intercept_covers)
})
|
||
test_that("ranger_surv: covers correctness", {
  unified <- ranger_surv.unify(ranger_num_model, x)
  model_table <- unified$model

  # Every root (Node == 0) must cover the whole reference dataset
  root_rows <- model_table[model_table$Node == 0, ]
  expect_true(all(root_rows$Cover == nrow(x)))

  # For each split node, its cover must equal the sum of its children's covers
  internals <- model_table[!is.na(model_table$Feature), ]
  yes_cover <- model_table[internals$Yes, ]$Cover
  no_cover <- model_table[internals$No, ]$Cover
  children_cover <- yes_cover + no_cover
  if (!all(is.na(internals$Missing))) {
    # Add the missing-branch child only when it is a distinct child
    missing_cover <- model_table[internals$Missing, ]$Cover
    missing_cover[is.na(missing_cover)] <- 0
    missing_cover[internals$Missing == internals$Yes | internals$Missing == internals$No] <- 0
    children_cover <- yes_cover + no_cover + missing_cover
  }
  expect_true(all(internals$Cover == children_cover))
})