diff --git a/.github/workflows/rcmdcheck.yml b/.github/workflows/rcmdcheck.yml index c2fa22a..0d3cb3e 100644 --- a/.github/workflows/rcmdcheck.yml +++ b/.github/workflows/rcmdcheck.yml @@ -71,7 +71,7 @@ jobs: reticulate::install_miniconda() install.packages('keras') keras::install_keras(extra_packages = c('IPython', 'requests', 'certifi', 'urllib3', 'tensorflow-hub', 'tabnet==0.1.4.1')) - reticulate::py_install(c('torch', 'pycox'), pip = TRUE) + reticulate::py_install(c('torch', 'pycox', 'pandas==1.4.4'), pip = TRUE) shell: Rscript {0} - name: Session info diff --git a/DESCRIPTION b/DESCRIPTION index 4114c84..0ec2fbe 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,12 +1,16 @@ Package: survivalmodels Title: Models for Survival Analysis -Version: 0.1.18 -Authors@R: - person(given = "Raphael", +Version: 0.1.19 +Authors@R: + c(person(given = "Raphael", family = "Sonabend", - role = c("aut", "cre"), - email = "raphaelsonabend@gmail.com", - comment = c(ORCID = "0000-0001-9225-4654")) + role = c("aut"), + comment = c(ORCID = "0000-0001-9225-4654")), + person(given = "Yohann", + family = "Foucher", + role = c("cre"), + email = "yohann.foucher@univ-poitiers.fr", + comment = c(ORCID = "0000-0003-0330-7457"))) Description: Implementations of classical and machine learning models for survival analysis, including deep neural networks via 'keras' and 'tensorflow'. Each model includes a separated fit and predict interface with consistent prediction types for predicting risk, survival probabilities, or survival distributions with 'distr6' . Models are either implemented from 'Python' via 'reticulate' , from code in GitHub packages, or novel implementations using 'Rcpp' . Novel machine learning survival models wil be included in the package in near-future updates. Neural networks are implemented from the 'Python' package 'pycox' and are detailed by Kvamme et al. (2019) . The 'Akritas' estimator is defined in Akritas (1994) . 'DNNSurv' is defined in Zhao and Feng (2020) . 
License: MIT + file LICENSE URL: https://github.com/RaphaelS1/survivalmodels/ @@ -16,8 +20,10 @@ Imports: Suggests: distr6 (>= 1.6.6), keras (>= 2.11.0), + param6, pseudo, reticulate, + set6, survival, testthat LinkingTo: diff --git a/NAMESPACE b/NAMESPACE index 4a5a2ae..ef0b06a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(predict,akritas) S3method(predict,dnnsurv) +S3method(predict,parametric) S3method(predict,pycox) S3method(print,survivalmodel) S3method(summary,survivalmodel) @@ -22,6 +23,7 @@ export(install_keras) export(install_pycox) export(install_torch) export(loghaz) +export(parametric) export(pchazard) export(pycox_prepare_train_data) export(requireNamespaces) diff --git a/NEWS.md b/NEWS.md index 8650792..961006c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# survivalmodels 0.1.19 + +* Add fully parametric survival models +* Fix broken CI + # survivalmodels 0.1.18 * Slight speed up in Akritas model by adding parameters to control number of unique time-points diff --git a/R/clean_data.R b/R/clean_data.R index 378682b..1fdfabe 100644 --- a/R/clean_data.R +++ b/R/clean_data.R @@ -48,21 +48,25 @@ clean_train_data <- function(formula = NULL, data = NULL, time_variable = NULL, } clean_test_data <- function(object, newdata) { + if (missing(newdata)) { - newdata <- object$x - } else { - newdata <- stats::model.matrix(~., newdata)[, -1, drop = FALSE] + newdata <- object$x[, !(colnames(object$x) %in% "(Intercept)"), drop = FALSE] + colnames(newdata) <- gsub("data$x", "", colnames(newdata), fixed = TRUE) + return(newdata) } - ord <- match(colnames(newdata), colnames(object$x), nomatch = NULL) - newdata <- newdata[, !is.na(ord), drop = FALSE] + newdata <- stats::model.matrix(~., newdata)[, -1, drop = FALSE] + old_features <- setdiff(colnames(object$x), "(Intercept)") + # fix for passing formula as data directly + old_features <- gsub("data$x", "", old_features, fixed = TRUE) + ord <- match(old_features, colnames(newdata), nomatch = NULL) + newdata <-
newdata[, ord[!is.na(ord)], drop = FALSE] - if (!all(suppressWarnings(colnames(newdata) == colnames(object$x)))) { + if (!all(suppressWarnings(colnames(newdata) == old_features))) { stop(sprintf( "Names in newdata should be identical to {%s}.", paste0(colnames(object$x), collapse = ", ") )) } - return(newdata) + newdata } diff --git a/R/coxtime.R b/R/coxtime.R index fbe257b..b77d9ce 100644 --- a/R/coxtime.R +++ b/R/coxtime.R @@ -28,7 +28,7 @@ #' #' #' @examples -#' \donttest{ +#' \dontrun{ #' if (requireNamespaces("reticulate")) { #' # all defaults #' coxtime(data = simsurvdata(50)) diff --git a/R/helpers_pycox.R b/R/helpers_pycox.R index 1b4f120..e9dd5ed 100644 --- a/R/helpers_pycox.R +++ b/R/helpers_pycox.R @@ -698,7 +698,7 @@ def init_weights(m): #' Currently ignored. #' #' @examples -#' \donttest{ +#' \dontrun{ #' if (requireNamespaces("reticulate")) { #' fit <- coxtime(data = simsurvdata(50)) #' diff --git a/R/methods.R b/R/methods.R index bfd13b4..e1e7da4 100644 --- a/R/methods.R +++ b/R/methods.R @@ -2,7 +2,12 @@ print.survivalmodel <- function(x, ...) { cat("\n", attr(x, "name"), "\n\n") cat("Call:\n ", deparse(x$call)) - cat("\n\nResponse:\n Surv(", paste0(colnames(x$y), collapse = ", "), ")\n", sep = "") + if (is.null(x[["y"]])) { + ynames <- x$ynames + } else { + ynames <- colnames(x$y) + } + cat("\n\nResponse:\n Surv(", paste0(ynames, collapse = ", "), ")\n", sep = "") cat("Features:\n ", setcollapse(x$xnames), "\n") } diff --git a/R/parametric.R b/R/parametric.R new file mode 100644 index 0000000..bc03fd5 --- /dev/null +++ b/R/parametric.R @@ -0,0 +1,362 @@ +#' @title Fully Parametric Survival Model +#' @name parametric +#' +#' @description +#' Fit/predict implementation of [survival::survreg()], which can return +#' absolutely continuous distribution predictions using \pkg{distr6}. +#' +#' @param eps `(numeric(1))` \cr +#' Used when the fitted `scale` parameter is too small. Default `1e-15`. +#' @param ... 
`ANY` \cr +#' Additional arguments passed to [survival::survreg()]. +#' +#' @template param_traindata +#' +#' @return An object inheriting from class `parametric`. +#' +#' @examples +#' if (requireNamespaces(c("distr6", "survival"))) { +#' library(survival) +#' parametric(Surv(time, status) ~ ., data = simsurvdata(10)) +#' } +#' @export +parametric <- function( + formula = NULL, data = NULL, reverse = FALSE, + time_variable = "time", status_variable = "status", + x = NULL, y = NULL, eps = 1e-15, ...) { + if (!requireNamespaces("distr6")) { + stop("Package 'distr6' required but not installed.") # nocov + } + + call <- match.call() + + data <- clean_train_data(formula, data, time_variable, status_variable, x, y, reverse) + + fit <- survival::survreg(survival::Surv(data$y) ~ data$x, x = TRUE, ...) + + location <- as.numeric(fit$coefficients[1]) + + if (is.na(location)) { + stop("Failed to fit survreg, coefficients all NA") + } + + scale <- fit$scale + + if (scale < eps) { + scale <- eps + } else if (scale > .Machine$double.xmax) { + scale <- .Machine$double.xmax + } + + if (location < -709 && + fit$dist %in% c("weibull", "exponential", "loglogistic")) { + location <- -709 + } + + basedist <- switch(fit$dist, + "weibull" = distr6::Weibull$new( + shape = 1 / scale, scale = exp(location), + decorators = "ExoticStatistics" + ), + "exponential" = distr6::Exponential$new( + scale = exp(location), + decorators = "ExoticStatistics" + ), + "gaussian" = distr6::Normal$new( + mean = location, sd = scale, + decorators = "ExoticStatistics" + ), + "lognormal" = distr6::Lognormal$new( + meanlog = location, sdlog = scale, + decorators = "ExoticStatistics" + ), + "loglogistic" = distr6::Loglogistic$new( + scale = exp(location), + shape = 1 / scale, + decorators = "ExoticStatistics" + ) + ) + + return(structure( + list( + model = fit, + basedist = basedist, + call = call, + xnames = colnames(fit$x), + ynames = unique(colnames(fit$y)) + ), + name = "Parametric survival model", + class = 
c("parametric", "survivalmodel") + )) +} + +#' @title Predict method for Parametric Model +#' +#' @description Predicted values from a fitted Parametric survival model. +#' +#' @details +#' The `form` parameter determines how the distribution is created. +#' Options are: +#' +#' - Accelerated failure time (`"aft"`) \deqn{h(t) = h_0(\frac{t}{exp(lp)})exp(-lp)} +#' - Proportional Hazards (`"ph"`) \deqn{h(t) = h_0(t)exp(lp)} +#' - Tobit (`"tobit"`) \deqn{h(t) = \Phi(\frac{t - lp}{scale})} +#' - Proportional odds (`"po"`) \deqn{h(t) = \frac{h_0(t)}{1 + (exp(lp)-1)S_0(t)}} +#' +#' where \eqn{h_0,S_0} are the estimated baseline hazard and survival functions +#' (in this case with a given parametric form), \eqn{lp} is the predicted linear +#' predictor calculated using the formula \eqn{lp = \hat{\beta} X_{new}} where +#' \eqn{X_{new}} are the variables in the test data set and \eqn{\hat{\beta}} +#' are the coefficients from the fitted parametric survival model (`object`). +#' \eqn{\Phi} is the cdf of a N(0, 1) distribution, and \eqn{scale} is the +#' fitted scale parameter. +#' +#' @param object (`parametric(1)`)\cr +#' Object of class inheriting from `"parametric"`. +#' @param newdata `(data.frame(1))`\cr +#' Testing data of `data.frame` like object, internally is coerced with [stats::model.matrix()]. +#' If missing then training data from fitted object is used. +#' @param form `(character(1))` \cr +#' The form of the predicted distribution, see `details` for options. +#' @param times `(numeric())`\cr +#' Times at which to evaluate the estimator. If `NULL` (default) then evaluated at all unique times +#' in the training set. +#' @param type (`character(1)`)\cr +#' Type of predicted value. Choices are survival probabilities over all time-points in training +#' data (`"survival"`) or a relative risk ranking (`"risk"`), which is the negated fitted linear +#' predictor (intercept included) so higher rank implies higher risk of event, or both (`"all"`).
+#' @param distr6 (`logical(1)`)\cr +#' If `FALSE` (default) and `type` is `"survival"` or `"all"` returns matrix of survival +#' probabilities, otherwise returns a [distr6::Distribution()]. +#' @param ntime `(numeric(1))`\cr +#' Number of unique time-points in the training set, default is 150. +#' @param round_time `(numeric(1))`\cr +#' Number of decimal places to round time-points to, default is 2, set to `FALSE` for no rounding. +#' @param ... `ANY` \cr +#' Currently ignored. +#' +#' @return A `numeric` if `type = "risk"`, a [distr6::Distribution()] +#' (if `distr6 = TRUE`) and `type = "survival"`; a `matrix` if +#' (`distr6 = FALSE`) and `type = "survival"` where entries are survival +#' probabilities with rows of observations and columns are time-points; +#' or a list combining above if `type = "all"`. +#' +#' @examples +#' if (requireNamespaces(c("distr6", "survival"))) { +#' library(survival) +#' +#' set.seed(42) +#' train <- simsurvdata(10) +#' test <- simsurvdata(5) +#' fit <- parametric(Surv(time, status) ~ ., data = train) +#' +#' # Return a discrete distribution survival matrix +#' predict_distr <- predict(fit, newdata = test) +#' predict_distr +#' +#' # Return a relative risk ranking with type = "risk" +#' predict(fit, newdata = test, type = "risk") +#' +#' # Or survival probabilities and a rank +#' predict(fit, newdata = test, type = "all", distr6 = TRUE) +#' } +#' @export +predict.parametric <- function(object, newdata, + form = c("aft", "ph", "tobit", "po"), times = NULL, + type = c("survival", "risk", "all"), distr6 = FALSE, + ntime = 150, round_time = 2, ...) 
{ + + form <- match.arg(form) + type <- match.arg(type) + + unique_times <- sort(unique(object$model$y[, 1, drop = FALSE])) + if (!is.logical(round_time) || round_time) { + unique_times <- unique(round(unique_times, round_time)) + } + # using same method as in ranger + unique_times <- unique_times[ + unique(round(seq.int(1, length(unique_times), length.out = ntime))) + ] + + truth <- object$model$y + newdata <- clean_test_data(object$model, newdata) + + if (is.null(times)) { + predict_times <- unique_times + } else { + predict_times <- sort(unique(times)) + } + + basedist <- object$basedist + fit <- object$model + lp <- matrix(fit$coefficients[-1], nrow = 1) %*% t(newdata) + + if (type %in% c("survival", "all") && distr6) { + surv <- .predict_survreg_continuous(object, newdata, form, + basedist, fit, lp) + } else { + surv <- .predict_survreg_discrete(object, newdata, form, + predict_times, basedist, fit, lp) + } + + ret <- list() + + if (type %in% c("risk", "all")) { + ret$risk <- -surv$lp + } + + if (type %in% c("survival", "all")) { + ret$surv <- surv$distr + } + + if (length(ret) == 1) { + return(ret[[1]]) + } else { + return(ret) + } +} + +.predict_survreg_continuous <- function(object, newdata, form, basedist, + fit, lp) { + + dist <- toproper(fit$dist) + + if (form == "tobit") { + name = paste(dist, "Tobit Model") + short_name = paste0(dist, "Tobit") + description = paste(dist, "Tobit Model with negative log-likelihood", + -fit$loglik[2]) + } else if (form == "ph") { + name = paste(dist, "Proportional Hazards Model") + short_name = paste0(dist, "PH") + description = paste(dist, "Proportional Hazards Model with negative log-likelihood", + -fit$loglik[2]) + } else if (form == "aft") { + name = paste(dist, "Accelerated Failure Time Model") + short_name = paste0(dist, "AFT") + description = paste(dist, "Accelerated Failure Time Model with negative log-likelihood", + -fit$loglik[2]) + } else if (form == "po") { + name = paste(dist, "Proportional Odds Model") + 
short_name = paste0(dist, "PO") + description = paste(dist, "Proportional Odds Model with negative log-likelihood", + -fit$loglik[2]) + } + + params = list(list(name = name, + short_name = short_name, + type = set6::PosReals$new(), + support = set6::PosReals$new(), + valueSupport = "continuous", + variateForm = "univariate", + description = description, + .suppressChecks = TRUE, + pdf = function() { + }, + cdf = function() { + }, + parameters = param6::pset() + )) + + params = rep(params, length(lp)) + + pdf = function(x) {} # nolint + cdf = function(x) {} # nolint + quantile = function(p) {} # nolint + + if (form == "tobit") { + for (i in seq_along(lp)) { + body(pdf) = substitute(dnorm((x - y) / scale) / scale, list( + y = lp[i] + fit$coefficients[1], + scale = basedist$stdev() + )) + body(cdf) = substitute(pnorm((x - y) / scale), list( + y = lp[i] + fit$coefficients[1], + scale = basedist$stdev() + )) + body(quantile) = substitute(qnorm(p) * scale + y, list( + y = lp[i] + fit$coefficients[1], + scale = basedist$stdev() + )) + params[[i]]$pdf = pdf + params[[i]]$cdf = cdf + params[[i]]$quantile = quantile + } + } else if (form == "ph") { + for (i in seq_along(lp)) { + body(pdf) = substitute((exp(y) * basedist$hazard(x)) * (1 - self$cdf(x)), list(y = -lp[i])) + body(cdf) = substitute(1 - (basedist$survival(x)^exp(y)), list(y = -lp[i])) + body(quantile) = substitute( + basedist$quantile(1 - exp(exp(-y) * log(1 - p))), # nolint + list(y = -lp[i]) + ) + params[[i]]$pdf = pdf + params[[i]]$cdf = cdf + params[[i]]$quantile = quantile + } + } else if (form == "aft") { + for (i in seq_along(lp)) { + body(pdf) = substitute((exp(-y) * basedist$hazard(x / exp(y))) * (1 - self$cdf(x)), + list(y = lp[i])) + body(cdf) = substitute(1 - (basedist$survival(x / exp(y))), list(y = lp[i])) + body(quantile) = substitute(exp(y) * basedist$quantile(p), list(y = lp[i])) + params[[i]]$pdf = pdf + params[[i]]$cdf = cdf + params[[i]]$quantile = quantile + } + } else if (form == "po") { + for (i in
seq_along(lp)) { + body(pdf) = substitute((basedist$hazard(x) * + (1 - (basedist$survival(x) / + (((exp(y) - 1)^-1) + basedist$survival(x))))) * + (1 - self$cdf(x)), list(y = lp[i])) + body(cdf) = substitute(1 - (basedist$survival(x) * + (exp(-y) + (1 - exp(-y)) * basedist$survival(x))^-1), # nolint + list(y = lp[i])) + body(quantile) = substitute(basedist$quantile(-p / ((exp(-y) * (p - 1)) - p)), # nolint + list(y = lp[i])) + params[[i]]$pdf = pdf + params[[i]]$cdf = cdf + params[[i]]$quantile = quantile + } + } + + distlist = lapply(params, function(.x) do.call(distr6::Distribution$new, .x)) + names(distlist) = paste0(short_name, seq_along(distlist)) + + distr = distr6::VectorDistribution$new(distlist, + decorators = c("CoreStatistics", "ExoticStatistics")) + + lp = lp + fit$coefficients[1] + + list(lp = as.numeric(lp), distr = distr) +} + +.predict_survreg_discrete <- function(object, newdata, form, predict_times, + basedist, fit, lp) { + + if (form == "tobit") { + fun = function(y) stats::pnorm((predict_times - y - fit$coefficients[1]) / basedist$stdev()) + } else if (form == "ph") { + fun = function(y) 1 - (basedist$survival(predict_times)^exp(-y)) + } else if (form == "aft") { + fun = function(y) 1 - (basedist$survival(predict_times / exp(y))) + } else if (form == "po") { + fun = function(y) { + surv = basedist$survival(predict_times) + 1 - (surv * (exp(-y) + (1 - exp(-y)) * surv)^-1) + } + } + + if (length(predict_times) == 1) { # edge case + mat <- as.matrix(vapply(lp, fun, numeric(1)), ncol = 1) + } else { + mat <- t(vapply(lp, fun, numeric(length(predict_times)))) + } + colnames(mat) <- predict_times + + list( + lp = as.numeric(lp + fit$coefficients[1]), + distr = 1 - mat + ) +} diff --git a/R/utils.R b/R/utils.R index 4221f09..68b4d64 100644 --- a/R/utils.R +++ b/R/utils.R @@ -26,4 +26,8 @@ fill_na <- function(x, along = 1) { } .x })) -} \ No newline at end of file +} + +toproper <- function(str) { + paste0(toupper(substr(str, 1, 1)), substr(str, 2, 
100)) +} diff --git a/man/coxtime.Rd b/man/coxtime.Rd index 3ca0a01..94e07a9 100644 --- a/man/coxtime.Rd +++ b/man/coxtime.Rd @@ -129,7 +129,7 @@ Implemented from the \code{pycox} Python package via \CRANpkg{reticulate}. Calls \code{pycox.models.Coxtime}. } \examples{ -\donttest{ +\dontrun{ if (requireNamespaces("reticulate")) { # all defaults coxtime(data = simsurvdata(50)) diff --git a/man/parametric.Rd b/man/parametric.Rd new file mode 100644 index 0000000..be6e356 --- /dev/null +++ b/man/parametric.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/parametric.R +\name{parametric} +\alias{parametric} +\title{Fully Parametric Survival Model} +\usage{ +parametric( + formula = NULL, + data = NULL, + reverse = FALSE, + time_variable = "time", + status_variable = "status", + x = NULL, + y = NULL, + eps = 1e-15, + ... +) +} +\arguments{ +\item{formula}{\code{(formula(1))}\cr +Object specifying the model fit, left-hand-side of formula should describe a \code{\link[survival:Surv]{survival::Surv()}} +object.} + +\item{data}{\code{(data.frame(1))}\cr +Training data of \code{data.frame} like object, internally is coerced with \code{\link[stats:model.matrix]{stats::model.matrix()}}.} + +\item{reverse}{\code{(logical(1))}\cr +If \code{TRUE} fits estimator on censoring distribution, otherwise (default) survival distribution.} + +\item{time_variable}{\code{(character(1))}\cr +Alternative method to call the function. Name of the 'time' variable, required if \code{formula}. +or \code{x} and \code{Y} not given.} + +\item{status_variable}{\code{(character(1))}\cr +Alternative method to call the function. Name of the 'status' variable, required if \code{formula} +or \code{x} and \code{Y} not given.} + +\item{x}{\code{(data.frame(1))}\cr +Alternative method to call the function. Required if \verb{formula, time_variable} and +\code{status_variable} not given. 
Data frame like object of features which is internally +coerced with \code{model.matrix}.} + +\item{y}{\verb{([survival::Surv()])}\cr +Alternative method to call the function. Required if \verb{formula, time_variable} and +\code{status_variable} not given. Survival outcome of right-censored observations.} + +\item{eps}{\code{(numeric(1))} \cr +Used when the fitted \code{scale} parameter is too small. Default \code{1e-15}.} + +\item{...}{\code{ANY} \cr +Additional arguments passed to \code{\link[survival:survreg]{survival::survreg()}}.} +} +\value{ +An object inheriting from class \code{parametric}. +} +\description{ +Fit/predict implementation of \code{\link[survival:survreg]{survival::survreg()}}, which can return +absolutely continuous distribution predictions using \pkg{distr6}. +} +\examples{ +if (requireNamespaces(c("distr6", "survival"))) { + library(survival) + parametric(Surv(time, status) ~ ., data = simsurvdata(10)) +} +} diff --git a/man/predict.parametric.Rd b/man/predict.parametric.Rd new file mode 100644 index 0000000..2aa28e4 --- /dev/null +++ b/man/predict.parametric.Rd @@ -0,0 +1,99 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/parametric.R +\name{predict.parametric} +\alias{predict.parametric} +\title{Predict method for Parametric Model} +\usage{ +\method{predict}{parametric}( + object, + newdata, + form = c("aft", "ph", "tobit", "po"), + times = NULL, + type = c("survival", "risk", "all"), + distr6 = FALSE, + ntime = 150, + round_time = 2, + ... +) +} +\arguments{ +\item{object}{(\code{parametric(1)})\cr +Object of class inheriting from \code{"parametric"}.} + +\item{newdata}{\code{(data.frame(1))}\cr +Testing data of \code{data.frame} like object, internally is coerced with \code{\link[stats:model.matrix]{stats::model.matrix()}}. 
+If missing then training data from fitted object is used.} + +\item{form}{\code{(character(1))} \cr +The form of the predicted distribution, see \code{details} for options.} + +\item{times}{\code{(numeric())}\cr +Times at which to evaluate the estimator. If \code{NULL} (default) then evaluated at all unique times +in the training set.} + +\item{type}{(\code{character(1)})\cr +Type of predicted value. Choices are survival probabilities over all time-points in training +data (\code{"survival"}) or a relative risk ranking (\code{"risk"}), which is the negated fitted linear +predictor (intercept included) so higher rank implies higher risk of event, or both (\code{"all"}).} + +\item{distr6}{(\code{logical(1)})\cr +If \code{FALSE} (default) and \code{type} is \code{"survival"} or \code{"all"} returns matrix of survival +probabilities, otherwise returns a \code{\link[distr6:Distribution]{distr6::Distribution()}}.} + +\item{ntime}{\code{(numeric(1))}\cr +Number of unique time-points in the training set, default is 150.} + +\item{round_time}{\code{(numeric(1))}\cr +Number of decimal places to round time-points to, default is 2, set to \code{FALSE} for no rounding.} + +\item{...}{\code{ANY} \cr +Currently ignored.} +} +\value{ +A \code{numeric} if \code{type = "risk"}, a \code{\link[distr6:Distribution]{distr6::Distribution()}} +(if \code{distr6 = TRUE}) and \code{type = "survival"}; a \code{matrix} if +(\code{distr6 = FALSE}) and \code{type = "survival"} where entries are survival +probabilities with rows of observations and columns are time-points; +or a list combining above if \code{type = "all"}. +} +\description{ +Predicted values from a fitted Parametric survival model. +} +\details{ +The \code{form} parameter determines how the distribution is created.
+Options are: +\itemize{ +\item Accelerated failure time (\code{"aft"}) \deqn{h(t) = h_0(\frac{t}{exp(lp)})exp(-lp)} +\item Proportional Hazards (\code{"ph"}) \deqn{h(t) = h_0(t)exp(lp)} +\item Tobit (\code{"tobit"}) \deqn{h(t) = \Phi(\frac{t - lp}{scale})} +\item Proportional odds (\code{"po"}) \deqn{h(t) = \frac{h_0(t)}{1 + (exp(lp)-1)S_0(t)}} +} + +where \eqn{h_0,S_0} are the estimated baseline hazard and survival functions +(in this case with a given parametric form), \eqn{lp} is the predicted linear +predictor calculated using the formula \eqn{lp = \hat{\beta} X_{new}} where +\eqn{X_{new}} are the variables in the test data set and \eqn{\hat{\beta}} +are the coefficients from the fitted parametric survival model (\code{object}). +\eqn{\Phi} is the cdf of a N(0, 1) distribution, and \eqn{scale} is the +fitted scale parameter. +} +\examples{ +if (requireNamespaces(c("distr6", "survival"))) { + library(survival) + + set.seed(42) + train <- simsurvdata(10) + test <- simsurvdata(5) + fit <- parametric(Surv(time, status) ~ ., data = train) + + # Return a discrete distribution survival matrix + predict_distr <- predict(fit, newdata = test) + predict_distr + + # Return a relative risk ranking with type = "risk" + predict(fit, newdata = test, type = "risk") + + # Or survival probabilities and a rank + predict(fit, newdata = test, type = "all", distr6 = TRUE) +} +} diff --git a/man/predict.pycox.Rd b/man/predict.pycox.Rd index 0cde50c..86e1377 100644 --- a/man/predict.pycox.Rd +++ b/man/predict.pycox.Rd @@ -67,7 +67,7 @@ or a list combining above if \code{type = "all"}. Predicted values from a fitted pycox ANN. 
} \examples{ -\donttest{ +\dontrun{ if (requireNamespaces("reticulate")) { fit <- coxtime(data = simsurvdata(50)) diff --git a/man/survivalmodels-package.Rd b/man/survivalmodels-package.Rd index 674c61d..4581198 100644 --- a/man/survivalmodels-package.Rd +++ b/man/survivalmodels-package.Rd @@ -18,6 +18,11 @@ Useful links: } \author{ -\strong{Maintainer}: Raphael Sonabend \email{raphaelsonabend@gmail.com} (\href{https://orcid.org/0000-0001-9225-4654}{ORCID}) +\strong{Maintainer}: Yohann Foucher \email{yohann.foucher@univ-poitiers.fr} (\href{https://orcid.org/0000-0003-0330-7457}{ORCID}) + +Authors: +\itemize{ + \item Raphael Sonabend (\href{https://orcid.org/0000-0001-9225-4654}{ORCID}) +} } diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 86a6293..6af38b7 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -26,12 +26,14 @@ sanity_check <- function(model, pars) { c(list(formula = Surv(time, status) ~ ., data = train), pars) ) - p <- predict(fit, newdata = test, type = "all", distr6 = TRUE) - - - expect_equal(length(p$risk), nrow(distr6::gprm(p$surv, "cdf"))) + p <- predict(fit, newdata = test, type = "all", distr6 = TRUE, + return_method = "discrete") + if (model != "parametric") { + expect_equal(length(p$risk), nrow(distr6::gprm(p$surv, "cdf"))) + } - p <- predict(fit, newdata = test, type = "all", distr6 = FALSE) + p <- predict(fit, newdata = test, type = "all", distr6 = FALSE, + return_method = "discrete") expect_equal(length(p$risk), nrow(p$surv)) } diff --git a/tests/testthat/test_parametric.R b/tests/testthat/test_parametric.R new file mode 100644 index 0000000..b1b03f1 --- /dev/null +++ b/tests/testthat/test_parametric.R @@ -0,0 +1,122 @@ +if (!requireNamespace("distr6", quietly = TRUE)) { + skip("distr6 not installed.") +} + +test_that("silent", { + expect_error(parametric(Surv(time, status) ~ .)) + expect_silent(parametric(Surv(time, status) ~ ., data = rats[1:10, ])) + fit <- parametric(Surv(time, status) ~ ., data 
= rats[1:10, ]) + expect_equal(predict(fit), predict(fit, rats[1:10, ])) + expect_error(parametric(x = "litter"), "Both 'x' and 'y'") + expect_error(parametric(time_variable = "time"), "'time_variable'") + expect_error(parametric( + x = rats[, c("rx", "litter")], + y = rats$time), "is not TRUE") + expect_error(parametric( + x = rats$rx, + y = Surv(rats$time, rats$status) + ), "data.frame") +}) + +test_that("auto sanity", { + sanity_check( + model = "parametric", + pars = list() + ) +}) + +form_opts <- c("aft", "ph", "po", "tobit") + +test_that("confirm lp and risk directions the same", { + + for (form in form_opts) { + fit <- parametric(Surv(time, status) ~ ., data = rats) + pred <- predict(fit, newdata = rats, type = "all", form = form) + expect_true(all.equal(order(surv_to_risk(pred$surv)), order(pred$risk))) + } +}) + +test_that("manualtest - aft", { + df = simsurvdata(50) + fit = parametric(Surv(time, status) ~ ., df, dist = "weibull") + p = predict(fit, df, type = "all", distr6 = TRUE) + + expect_equal(-p$risk, unname(predict(fit$model, type = "lp"))) + expect_equal(p$surv[1]$survival(predict( + fit$model, type = "quantile", p = c(0.2, 0.8) + )[1, ]), c(0.8, 0.2)) + expect_equal(p$surv[10]$cdf(predict( + fit$model, type = "quantile", p = seq.int(0, 1, 0.1) + )[10, ]), + seq.int(0, 1, 0.1)) + + fit = parametric(Surv(time, status) ~ ., df, dist = "lognormal") + p = predict(fit, df, type = "all", distr6 = TRUE) + + expect_equal(p$surv[15]$cdf(predict( + fit$model, type = "quantile", p = seq.int(0, 1, 0.1) + )[15, ]), seq.int(0, 1, 0.1)) +}) + +test_that("quantile type", { + df <- simsurvdata(50) + fit <- parametric(Surv(time, status) ~ ., df) + + p <- predict(fit, df, type = "all", form = "aft", distr6 = TRUE) + quantile <- p$surv$quantile(c(0.2, 0.8)) + expect_equal(matrix(t(quantile), ncol = 2), + predict(fit$model, type = "quantile", p = c(0.2, 0.8))) + + for (form in form_opts) { + p <- predict(fit, df, type = "all", form = form, distr6 = TRUE) + quantile <- 
p$surv$quantile(0.5) + expect_equal(unlist(p$surv$cdf(quantile), use.names = FALSE), rep(0.5, 50)) + } +}) + +dist_opts <- c("weibull", "exponential", "lognormal", "gaussian", "loglogistic") + +test_that("quantile dist", { + + df <- simsurvdata(50) + + for (dist in dist_opts) { + if (dist == "loglogistic") skip_if_not_installed("actuar") + fit <- parametric(Surv(time, status) ~ ., df, dist = dist) + form <- ifelse(dist == "gaussian", "tobit", "aft") + p <- predict(fit, df, form = form, distr6 = TRUE)$quantile(c(0.2, 0.8)) + expect_equal( + matrix(t(p), ncol = 2), + predict(fit$model, type = "quantile", p = c(0.2, 0.8), distr6 = TRUE) + ) + } +}) + +test_that("cdf dist", { + df <- simsurvdata(50) + + for (dist in dist_opts) { + if (dist == "loglogistic") skip_if_not_installed("actuar") + fit <- parametric(Surv(time, status) ~ ., df, dist = dist) + form <- ifelse(dist == "gaussian", "tobit", "aft") + p <- predict(fit, df, form = form, distr6 = TRUE) + cdf <- predict(fit$model, type = "quantile", p = c(0.2, 0.8)) + expect_equal(unname(as.matrix(p$cdf(data = t(cdf)))), + matrix(c(rep(0.2, 50), rep(0.8, 50)), byrow = TRUE, nrow = 2)) + } +}) + + +test_that("discrete = continuous when expected", { + fit <- parametric(Surv(time, status) ~ ., rats) + + for (form in form_opts) { + p_cont <- predict(fit, rats, form = form, type = "all", distr6 = TRUE) + p_disc <- predict(fit, rats, form = form, type = "all") + expect_equal(p_cont$risk, p_disc$risk) + utimes <- sort(unique(rats$time)) + s_cont <- as.matrix(p_cont$surv$survival(utimes)) + dimnames(s_cont) <- list(utimes, NULL) + expect_equal(s_cont, t(p_disc$surv)) + } +})