Skip to content

Commit

Permalink
tweak tolerances, permutation of states for bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Oct 5, 2024
1 parent be42ffa commit 440f3d8
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 34 deletions.
8 changes: 6 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ backward_mnhmm_multichannel <- function(eta_A, X_s, eta_B, X_o, obs, M) {
.Call(`_seqHMM_backward_mnhmm_multichannel`, eta_A, X_s, eta_B, X_o, obs, M)
}

cost_matrix <- function(gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref) {
.Call(`_seqHMM_cost_matrix`, gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref)
cost_matrix_singlechannel <- function(gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref) {
.Call(`_seqHMM_cost_matrix_singlechannel`, gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref)
}

cost_matrix_multichannel <- function(gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref) {
.Call(`_seqHMM_cost_matrix_multichannel`, gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref)
}

create_Q <- function(n) {
Expand Down
60 changes: 44 additions & 16 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,59 @@ bootstrap_model <- function(model) {
}
model
}
permute_states <- function(gammas_boot, gammas_mle) {
C <- if(is.list(gammas_mle$B)) length(gammas_mle$B) else 1
if (C == 1) {
m <- cost_matrix_singlechannel(
gammas_boot$pi, gammas_mle$pi,
gammas_boot$A, gammas_mle$A,
gammas_boot$B, gammas_mle$B
)
} else {
m <- cost_matrix_multichannel(
gammas_boot$pi, gammas_mle$pi,
gammas_boot$A, gammas_mle$A,
gammas_boot$B, gammas_mle$B
)
}
perm <- RcppHungarian::HungarianSolver(m)$pairs[, 2]
gammas_boot$pi <- gammas_boot$pi[perm, , drop = FALSE]
gammas_boot$A <- gammas_boot$A[perm, , perm, drop = FALSE]
if (C == 1) {
gammas_boot$B <- gammas_boot$B[, , perm, drop = FALSE]
} else {
for (c in seq_len(C)) {
gammas_boot$B[[c]] <- gammas_boot$B[[c]][, , perm, drop = FALSE]
}
}
gammas_boot
}
#' @export
bootstrap_coefs.nhmm <- function(model, B = 1000,
method = c("nonparametric", "parametric"),
penalty) {
penalty, verbose = FALSE, ...) {
method <- match.arg(method)
stopifnot_(
checkmate::test_int(x = B, lower = 0L),
"Argument {.arg B} must be a single positive integer."
)
init <- model$coefficients
init <- model$etas
if (missing(penalty)) {
penalty <- model$estimation_results$penalty
} else {
penalty <- 4
}
mle_coefficients <- model$coefficients
gammas_mle <- model$gammas

coefs <- matrix(NA, length(unlist(gammas_mle)), B)
if (method == "nonparametric") {
coefs <- matrix(NA, length(unlist(init)), B)
for (i in seq_len(B)) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, 0, 0, 1, penalty, FALSE)
m <- cost_matrix(fit$coefficients$eta_pi, mle_coefficients$eta_pi,
fit$coefficients$eta_A, mle_coefficients$eta_A,
fit$coefficients$eta_B, mle_coefficients$eta_B,
fit$X_initial, fit$X_transition, fit$X_emission)
perm <- RcppHungarian:HungarianSolver(m)$pairs[, 2]
fit$coefficients$gamma_pi[perm]
coefs[, i] <- unlist(fit$coefficients)
fit <- fit_nhmm(mod, init, 0, 0, 1, penalty, ...)
coefs[, i] <- unlist(permute_states(fit$gammas, gammas_mle))
if(verbose) print(paste0("Bootstrap replication ", i, " complete."))
}
} else {
coefs <- matrix(NA, length(unlist(init)), B)
N <- model$n_sequences
T_ <- model$sequence_lengths
M <- model$n_symbols
Expand All @@ -58,15 +82,17 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
mod <- simulate_nhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
data = d, time, id, init)$model
fit <- fit_nhmm(mod, init, 0, 0, 1, penalty, FALSE)
coefs[, i] <- unlist(fit$coefficients)
fit <- fit_nhmm(mod, init, 0, 0, 1, penalty, ...)
coefs[, i] <- unlist(permute_states(fit$gammas, gammas_mle))
print(paste0("Bootstrap replication ", i, " complete."))
}
}
return(coefs)
}
#' @export
bootstrap_coefs.mnhmm <- function(model, B = 1000,
method = c("nonparametric", "parametric")) {
method = c("nonparametric", "parametric"),
verbose = FALSE) {
method <- match.arg(method)
stopifnot_(
checkmate::test_int(x = B, lower = 0L),
Expand All @@ -82,6 +108,7 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
mod <- bootstrap_model(model)
fit <- fit_mnhmm(mod, init, 0, 0, 1, penalty, FALSE)
coefs[, i] <- unlist(fit$coefficients)
print(paste0("Bootstrap replication ", i, " complete."))
}
} else {
coefs <- matrix(NA, length(unlist(init)), B)
Expand All @@ -103,6 +130,7 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
data = d, time, id, init)$model
fit <- fit_mnhmm(mod, init, 0, 0, 1, penalty, FALSE)
coefs[, i] <- unlist(fit$coefficients)
print(paste0("Bootstrap replication ", i, " complete."))
}
}
return(coefs)
Expand Down
18 changes: 12 additions & 6 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,20 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
}
}
}

user_def_penalty <- penalty
if (penalty == 0) penalty <- 4
if (restarts > 0L) {
if (threads > 1L) {
future::plan(future::multisession, workers = threads)
} else {
future::plan(future::sequential)
}
if (is.null(dots$maxeval)) dots$maxeval <- 1000L
if (is.null(dots$print_level )) dots$print_level <- 0
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-2
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-4
if (is.null(dots$print_level)) dots$print_level <- 0
if (is.null(dots$xtol_abs)) dots$xtol_abs <- 1e-2
if (is.null(dots$ftol_abs)) dots$ftol_abs <- 1e-2
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-4
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-8
if (is.null(dots$check_derivatives)) dots$check_derivatives <- FALSE
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- unlist(create_initial_values(
Expand All @@ -227,9 +230,12 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
dots <- list(...)
if (is.null(dots$algorithm)) dots$algorithm <- "NLOPT_LD_LBFGS"
if (is.null(dots$maxeval)) dots$maxeval <- 10000L
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-6
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-12
if (is.null(dots$xtol_abs)) dots$xtol_abs <- 1e-4
if (is.null(dots$ftol_abs)) dots$ftol_abs <- 1e-4
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-4
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-8
if (is.null(dots$check_derivatives)) dots$check_derivatives <- FALSE
penalty <- user_def_penalty
out <- nloptr(
x0 = init, eval_f = objectivef,
opts = dots
Expand Down
17 changes: 12 additions & 5 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,20 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
}
}
}
user_def_penalty <- penalty
if (penalty == 0) penalty <- 4
if (restarts > 0L) {
if (threads > 1L) {
future::plan(future::multisession, workers = threads)
} else {
future::plan(future::sequential)
}
if (is.null(dots$maxeval)) dots$maxeval <- 1000L
if (is.null(dots$print_level )) dots$print_level <- 0
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-2
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-4
if (is.null(dots$print_level)) dots$print_level <- 0
if (is.null(dots$xtol_abs)) dots$xtol_abs <- 1e-2
if (is.null(dots$ftol_abs)) dots$ftol_abs <- 1e-2
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-4
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-8
if (is.null(dots$check_derivatives)) dots$check_derivatives <- FALSE
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- unlist(create_initial_values(
Expand All @@ -181,9 +185,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
dots <- list(...)
if (is.null(dots$algorithm)) dots$algorithm <- "NLOPT_LD_LBFGS"
if (is.null(dots$maxeval)) dots$maxeval <- 10000L
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-6
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-12
if (is.null(dots$xtol_abs)) dots$xtol_abs <- 1e-4
if (is.null(dots$ftol_abs)) dots$ftol_abs <- 1e-4
if (is.null(dots$xtol_rel)) dots$xtol_rel <- 1e-4
if (is.null(dots$xtol_rel)) dots$ftol_rel <- 1e-8
if (is.null(dots$check_derivatives)) dots$check_derivatives <- FALSE
penalty <- user_def_penalty
out <- nloptr(
x0 = init, eval_f = objectivef,
opts = dots
Expand Down
27 changes: 22 additions & 5 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// cost_matrix
arma::mat cost_matrix(const arma::mat& gamma_pi_est, const arma::mat& gamma_pi_ref, const arma::cube& gamma_A_est, const arma::cube& gamma_A_ref, const arma::cube& gamma_B_est, const arma::cube& gamma_B_ref);
RcppExport SEXP _seqHMM_cost_matrix(SEXP gamma_pi_estSEXP, SEXP gamma_pi_refSEXP, SEXP gamma_A_estSEXP, SEXP gamma_A_refSEXP, SEXP gamma_B_estSEXP, SEXP gamma_B_refSEXP) {
// cost_matrix_singlechannel
arma::mat cost_matrix_singlechannel(const arma::mat& gamma_pi_est, const arma::mat& gamma_pi_ref, const arma::cube& gamma_A_est, const arma::cube& gamma_A_ref, const arma::cube& gamma_B_est, const arma::cube& gamma_B_ref);
RcppExport SEXP _seqHMM_cost_matrix_singlechannel(SEXP gamma_pi_estSEXP, SEXP gamma_pi_refSEXP, SEXP gamma_A_estSEXP, SEXP gamma_A_refSEXP, SEXP gamma_B_estSEXP, SEXP gamma_B_refSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -126,7 +126,23 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const arma::cube& >::type gamma_A_ref(gamma_A_refSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type gamma_B_est(gamma_B_estSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type gamma_B_ref(gamma_B_refSEXP);
rcpp_result_gen = Rcpp::wrap(cost_matrix(gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref));
rcpp_result_gen = Rcpp::wrap(cost_matrix_singlechannel(gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref));
return rcpp_result_gen;
END_RCPP
}
// cost_matrix_multichannel
arma::mat cost_matrix_multichannel(const arma::mat& gamma_pi_est, const arma::mat& gamma_pi_ref, const arma::cube& gamma_A_est, const arma::cube& gamma_A_ref, const arma::field<arma::cube>& gamma_B_est, arma::field<arma::cube>& gamma_B_ref);
RcppExport SEXP _seqHMM_cost_matrix_multichannel(SEXP gamma_pi_estSEXP, SEXP gamma_pi_refSEXP, SEXP gamma_A_estSEXP, SEXP gamma_A_refSEXP, SEXP gamma_B_estSEXP, SEXP gamma_B_refSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::mat& >::type gamma_pi_est(gamma_pi_estSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type gamma_pi_ref(gamma_pi_refSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type gamma_A_est(gamma_A_estSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type gamma_A_ref(gamma_A_refSEXP);
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type gamma_B_est(gamma_B_estSEXP);
Rcpp::traits::input_parameter< arma::field<arma::cube>& >::type gamma_B_ref(gamma_B_refSEXP);
rcpp_result_gen = Rcpp::wrap(cost_matrix_multichannel(gamma_pi_est, gamma_pi_ref, gamma_A_est, gamma_A_ref, gamma_B_est, gamma_B_ref));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -940,7 +956,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_seqHMM_backward_nhmm_multichannel", (DL_FUNC) &_seqHMM_backward_nhmm_multichannel, 6},
{"_seqHMM_backward_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_mnhmm_singlechannel, 5},
{"_seqHMM_backward_mnhmm_multichannel", (DL_FUNC) &_seqHMM_backward_mnhmm_multichannel, 6},
{"_seqHMM_cost_matrix", (DL_FUNC) &_seqHMM_cost_matrix, 6},
{"_seqHMM_cost_matrix_singlechannel", (DL_FUNC) &_seqHMM_cost_matrix_singlechannel, 6},
{"_seqHMM_cost_matrix_multichannel", (DL_FUNC) &_seqHMM_cost_matrix_multichannel, 6},
{"_seqHMM_create_Q", (DL_FUNC) &_seqHMM_create_Q, 1},
{"_seqHMM_eta_to_gamma_mat", (DL_FUNC) &_seqHMM_eta_to_gamma_mat, 1},
{"_seqHMM_eta_to_gamma_cube", (DL_FUNC) &_seqHMM_eta_to_gamma_cube, 1},
Expand Down
44 changes: 44 additions & 0 deletions src/cost_matrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include<RcppArmadillo.h>

// [[Rcpp::export]]
arma::mat cost_matrix_singlechannel(
const arma::mat& gamma_pi_est, const arma::mat& gamma_pi_ref,
const arma::cube& gamma_A_est, const arma::cube& gamma_A_ref,
const arma::cube& gamma_B_est, const arma::cube& gamma_B_ref) {

unsigned int S = gamma_A_ref.n_slices;
arma::mat costs(S, S);

for (unsigned int j = 0; j < S; j++) {
for (unsigned int k = 0; k < S; k++) {
double cost_pi = arma::norm(gamma_pi_est.row(j) - gamma_pi_ref.row(k));
double cost_A = arma::norm(arma::vectorise(gamma_A_est.slice(j) - gamma_A_ref.slice(k)));
double cost_B = arma::norm(arma::vectorise(gamma_B_est.slice(j) - gamma_B_ref.slice(k)));
costs(k, j) = cost_pi + cost_A + cost_B;
}
}
return costs.t();
}
// [[Rcpp::export]]
arma::mat cost_matrix_multichannel(
const arma::mat& gamma_pi_est, const arma::mat& gamma_pi_ref,
const arma::cube& gamma_A_est, const arma::cube& gamma_A_ref,
const arma::field<arma::cube>& gamma_B_est, arma::field<arma::cube>& gamma_B_ref) {

unsigned int S = gamma_A_ref.n_slices;
unsigned int C = gamma_B_est.n_elem;
arma::mat costs(S, S);

for (unsigned int j = 0; j < S; j++) {
for (unsigned int k = 0; k < S; k++) {
double cost_pi = arma::norm(gamma_pi_est.row(j) - gamma_pi_ref.row(k));
double cost_A = arma::norm(arma::vectorise(gamma_A_est.slice(j) - gamma_A_ref.slice(k)));
double cost_B = 0;
for (unsigned int c = 0; c < C; c++){
cost_B += arma::norm(arma::vectorise(gamma_B_est(c).slice(j) - gamma_B_ref(c).slice(k)));
}
costs(k, j) = cost_pi + cost_A + cost_B;
}
}
return costs.t();
}

0 comments on commit 440f3d8

Please sign in to comment.