Skip to content

Commit

Permalink
boostrap method to type
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 8, 2024
1 parent 36206c1 commit 5a56bff
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 15 deletions.
28 changes: 17 additions & 11 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ permute_clusters <- function(model, pcp_mle) {
#'
#' @param model An `nhmm` or `mnhmm` object.
#' @param B number of bootstrap samples.
#' @param method Either `"nonparametric"` or `"parametric"`, to define whether
#' @param type Either `"nonparametric"` or `"parametric"`, to define whether
#' nonparametric or parametric bootstrap should be used. The former samples
#' sequences with replacement, whereas the latter simulates new datasets based
#' on the model.
Expand All @@ -84,9 +84,9 @@ bootstrap_coefs <- function(model, ...) {
#' @rdname bootstrap
#' @export
bootstrap_coefs.nhmm <- function(model, B = 1000,
method = c("nonparametric", "parametric"),
type = c("nonparametric", "parametric"),
verbose = FALSE, ...) {
method <- match.arg(method)
type <- match.arg(type)
stopifnot_(
checkmate::test_int(x = B, lower = 0L),
"Argument {.arg B} must be a single positive integer."
Expand All @@ -96,13 +96,15 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
gamma_pi <- replicate(B, gammas_mle$pi, simplify = FALSE)
gamma_A <- replicate(B, gammas_mle$A, simplify = FALSE)
gamma_B <- replicate(B, gammas_mle$B, simplify = FALSE)
lambda <- model$estimation_results$lambda

if (verbose) pb <- utils::txtProgressBar(min = 0, max = B, style = 3)
if (method == "nonparametric") {
if (type == "nonparametric") {
out <- future.apply::future_lapply(
seq_len(B), function(i) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, ...)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
...)
if (verbose) utils::setTxtProgressBar(pb, i)
permute_states(fit$gammas, gammas_mle)
}
Expand All @@ -123,7 +125,8 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
mod <- simulate_nhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
data = d, time, id, init)$model
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, ...)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
...)
if (verbose) utils::setTxtProgressBar(pb, i)
fit$gammas <- permute_states(fit$gammas, gammas_mle)
}
Expand All @@ -137,9 +140,9 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
#' @rdname bootstrap
#' @export
bootstrap_coefs.mnhmm <- function(model, B = 1000,
method = c("nonparametric", "parametric"),
type = c("nonparametric", "parametric"),
verbose = FALSE, ...) {
method <- match.arg(method)
type <- match.arg(type)
stopifnot_(
checkmate::test_int(x = B, lower = 0L),
"Argument {.arg B} must be a single positive integer."
Expand All @@ -151,12 +154,14 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
gamma_A <- replicate(B, gammas_mle$A, simplify = FALSE)
gamma_B <- replicate(B, gammas_mle$B, simplify = FALSE)
gamma_omega <- replicate(B, gammas_mle$omega, simplify = FALSE)
lambda <- model$estimation_results$lambda
D <- model$n_clusters
if (verbose) pb <- utils::txtProgressBar(min = 0, max = B, style = 3)
if (method == "nonparametric") {
if (type == "nonparametric") {
for (i in seq_len(B)) {
mod <- bootstrap_model(model)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, ...)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
...)
fit <- permute_clusters(fit, pcp_mle)
for (j in seq_len(D)) {
out <- permute_states(
Expand Down Expand Up @@ -189,7 +194,8 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
mod <- simulate_mnhmm(
N, T_, M, S, D, formula_pi, formula_A, formula_B, formula_omega,
data = d, time, id, init)$model
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, ...)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
...)
fit <- permute_clusters(fit, pcp_mle)
for (j in seq_len(D)) {
out <- permute_states(
Expand Down
2 changes: 2 additions & 0 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
all_solutions = all_solutions,
time = end_time - start_time
)

model$estimation_results$lambda <- lambda
model
}
2 changes: 1 addition & 1 deletion R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,6 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
x_abs_change = out$absolute_x_change
)
}

model$estimation_results$lambda <- lambda
model
}
6 changes: 3 additions & 3 deletions man/bootstrap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5a56bff

Please sign in to comment.