Skip to content

Commit

Permalink
pseudocounts
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 9, 2024
1 parent 27dd4f6 commit 039f5f8
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 44 deletions.
8 changes: 4 additions & 4 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ logSumExp <- function(x) {
.Call(`_seqHMM_logSumExp`, x)
}

EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) {
.Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda)
EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount) {
.Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount)
}

EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) {
.Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda)
EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount) {
.Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount)
}

backward_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) {
Expand Down
6 changes: 4 additions & 2 deletions R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, lambda = 0, method = "EM", store_data = TRUE, ...) {
restarts = 0L, lambda = 0, method = "EM", pseudocount = 0,
store_data = TRUE, ...) {

call <- match.call()
model <- build_mnhmm(
Expand All @@ -63,7 +64,8 @@ estimate_mnhmm <- function(
if (store_data) {
model$data <- data
}
out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method, ...)
out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method,
pseudocount, ...)

attr(out, "call") <- call
out
Expand Down
9 changes: 7 additions & 2 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@
#' @param method Optimization method used. Default is `"EM"` which uses EM
#' algorithm with L-BFGS in the M-step. Another option is `"DNM"` which uses
#' direct maximization of the log-likelihood using [nloptr::nloptr()].
#' @param pseudocount. A positive scalar to be added for the expected counts of
#' E-step. Only used in EM algorithm. Default is 0. Larger values can be used
#' to avoid zero probabilities in initial, transition, and emission
#' probabilities, i.e. these have similar role as `lambda`.
#' @param store_data If `TRUE` (default), original data frame passed as `data`
#' is stored to the model object. For large datasets, this can be set to
#' `FALSE`, in which case you might need to pass the data separately to some
Expand Down Expand Up @@ -88,7 +92,7 @@ estimate_nhmm <- function(
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL,
inits = "random", init_sd = 2, restarts = 0L, lambda = 0, method = "EM",
store_data = TRUE, ...) {
pseudocount = 0, store_data = TRUE, ...) {

call <- match.call()

Expand All @@ -110,7 +114,8 @@ estimate_nhmm <- function(
if (store_data) {
model$data <- data
}
out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, ...)
out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, pseudocount,
...)
attr(out, "call") <- call
out
}
11 changes: 4 additions & 7 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' Estimate a Mixture Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
save_all_solutions = FALSE, control_restart = list(),
control_mstep = list(), ...) {
fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
pseudocount = 0, save_all_solutions = FALSE,
control_restart = list(), control_mstep = list(), ...) {
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
Expand All @@ -21,10 +21,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
list(...)
)
control_restart <- utils::modifyList(control, control_restart)
control_mstep <- utils::modifyList(
c(control, list(pseudocount = 0)),
control_mstep
)
control_mstep <- utils::modifyList(control, control_mstep)

M <- model$n_symbols
S <- model$n_states
Expand Down
15 changes: 6 additions & 9 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Estimate a Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocount = 0,
save_all_solutions = FALSE, control_restart = list(),
control_mstep = list(), ...) {

Expand All @@ -22,10 +22,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
list(...)
)
control_restart <- utils::modifyList(control, control_restart)
control_mstep <- utils::modifyList(
c(control, list(pseudocount = 0)),
control_mstep
)
control_mstep <- utils::modifyList(control, control_mstep)

M <- model$n_symbols
S <- model$n_states
Expand Down Expand Up @@ -235,7 +232,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control_restart$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda)
control_mstep$print_level, lambda, pseudocount)
} else {
EM_LBFGS_nhmm_multichannel(
init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs,
Expand All @@ -246,7 +243,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control_restart$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda)
control_mstep$print_level, lambda, pseudocount)
}
},
future.seed = TRUE)
Expand Down Expand Up @@ -274,7 +271,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, control_mstep$pseudocount)
control_mstep$print_level, lambda, pseudocount)
} else {
out <- EM_LBFGS_nhmm_multichannel(
init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs,
Expand All @@ -285,7 +282,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, control_mstep$pseudocount)
control_mstep$print_level, lambda, pseudocount)
}
end_time <- proc.time()
# if (out$status < 0) {
Expand Down
1 change: 1 addition & 0 deletions man/estimate_mnhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/estimate_nhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ BEGIN_RCPP
END_RCPP
}
// EM_LBFGS_nhmm_singlechannel
Rcpp::List EM_LBFGS_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda);
RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP) {
Rcpp::List EM_LBFGS_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount);
RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand Down Expand Up @@ -432,13 +432,14 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP);
Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP);
Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP);
rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda));
Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP);
rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount));
return rcpp_result_gen;
END_RCPP
}
// EM_LBFGS_nhmm_multichannel
Rcpp::List EM_LBFGS_nhmm_multichannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::field<arma::cube>& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda);
RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP) {
Rcpp::List EM_LBFGS_nhmm_multichannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::field<arma::cube>& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount);
RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand Down Expand Up @@ -471,7 +472,8 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP);
Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP);
Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP);
rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda));
Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP);
rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -1355,8 +1357,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_seqHMM_get_B_ame", (DL_FUNC) &_seqHMM_get_B_ame, 5},
{"_seqHMM_get_omega_ame", (DL_FUNC) &_seqHMM_get_omega_ame, 4},
{"_seqHMM_logSumExp", (DL_FUNC) &_seqHMM_logSumExp, 1},
{"_seqHMM_EM_LBFGS_nhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_singlechannel, 29},
{"_seqHMM_EM_LBFGS_nhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_multichannel, 29},
{"_seqHMM_EM_LBFGS_nhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_singlechannel, 30},
{"_seqHMM_EM_LBFGS_nhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_multichannel, 30},
{"_seqHMM_backward_nhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_nhmm_singlechannel, 15},
{"_seqHMM_backward_nhmm_multichannel", (DL_FUNC) &_seqHMM_backward_nhmm_multichannel, 15},
{"_seqHMM_backward_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_mnhmm_singlechannel, 18},
Expand Down
Loading

0 comments on commit 039f5f8

Please sign in to comment.