Skip to content

Commit

Permalink
fixed tests, tweaked M-steps, penalty to lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 7, 2024
1 parent 8cfee87 commit 99d6883
Show file tree
Hide file tree
Showing 20 changed files with 560 additions and 395 deletions.
8 changes: 4 additions & 4 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ log_objectivex <- function(transition, emission, init, obs, ANZ, BNZ, INZ, nSymb
.Call(`_seqHMM_log_objectivex`, transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols, coef, X, numberOfStates, threads)
}

EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level, penalty) {
.Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level, penalty)
EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) {
.Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda)
}

EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level, penalty) {
.Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level, penalty)
EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) {
.Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda)
}

backward_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, iv_pi, iv_A, iv_B, tv_A, tv_B) {
Expand Down
8 changes: 6 additions & 2 deletions R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, store_data = TRUE, ...) {
restarts = 0L, lambda = 0, method = "EM", store_data = TRUE, ...) {

call <- match.call()
model <- build_mnhmm(
Expand All @@ -56,10 +56,14 @@ estimate_mnhmm <- function(
checkmate::test_flag(x = store_data),
"Argument {.arg store_data} must be a single {.cls logical} value."
)
stopifnot_(
checkmate::check_number(lambda, lower = 0),
"Argument {.arg lambda} must be a single non-negative {.cls numeric} value."
)
if (store_data) {
model$data <- data
}
out <- fit_mnhmm(model, inits, init_sd, restarts, ...)
out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method, ...)

attr(out, "call") <- call
out
Expand Down
16 changes: 14 additions & 2 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
#' of the regression coefficients to zero, use `init_sd = 0`.
#' @param restarts Number of times to run optimization using random starting
#' values (in addition to the final run). Default is 0.
#' @param lambda Penalization factor `lambda` for penalized log-likelihood, where the
#' penalization is `lambda * sum(parameters^2)/(2 * n_obs)`, where `n_obs` is
#' the number of non-missing observations.
#' @param method Optimization method used. Default is `"EM"` which uses EM
#' algorithm with L-BFGS in the M-step. Another option is `"DNM"` which uses
#' direct maximization of the log-likelihood using [nloptr::nloptr()].
#' @param store_data If `TRUE` (default), original data frame passed as `data`
#' is stored to the model object. For large datasets, this can be set to
#' `FALSE`, in which case you might need to pass the data separately to some
Expand Down Expand Up @@ -81,11 +87,13 @@ estimate_nhmm <- function(
observations, n_states, initial_formula = ~1,
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL,
inits = "random", init_sd = 2, restarts = 0L,
inits = "random", init_sd = 2, restarts = 0L, lambda = 0, method = "EM",
store_data = TRUE, ...) {

call <- match.call()

method <- match.arg(method, c("DNM", "EM"))

model <- build_nhmm(
observations, n_states, initial_formula,
transition_formula, emission_formula, data, time, id, state_names,
Expand All @@ -95,10 +103,14 @@ estimate_nhmm <- function(
checkmate::test_flag(x = store_data),
"Argument {.arg store_data} must be a single {.cls logical} value."
)
stopifnot_(
checkmate::check_number(lambda, lower = 0),
"Argument {.arg lambda} must be a single non-negative {.cls numeric} value."
)
if (store_data) {
model$data <- data
}
out <- fit_nhmm(model, inits, init_sd, restarts, ...)
out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, ...)
attr(out, "call") <- call
out
}
69 changes: 28 additions & 41 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
#' Estimate a Mixture Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_mnhmm <- function(model, inits, init_sd, restarts,
save_all_solutions = FALSE, ...) {
fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
save_all_solutions = FALSE,
control_restart = list(), control_mstep = list(), ...) {
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
)
control <- utils::modifyList(
list(
ftol_abs = 1e-8,
ftol_rel = 1e-8,
xtol_abs = 1e-4,
xtol_rel = 1e-4,
maxeval = 1e4,
print_level = 0,
algorithm = "NLOPT_LD_LBFGS"
),
list(...)
)
control_restart <- utils::modifyList(control, control_restart)
control_mstep <- utils::modifyList(control, control_mstep)

M <- model$n_symbols
S <- model$n_states
D <- model$n_clusters
Expand Down Expand Up @@ -47,9 +63,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
K_omega <- nrow(X_omega)
Ti <- model$sequence_lengths
n_obs <- nobs(model)
dots <- list(...)

if (isTRUE(dots$maxeval < 0)) {
if (isTRUE(control$maxeval < 0)) {
pars <- unlist(create_initial_values(
inits, S, M, init_sd, K_pi, K_A, K_B, K_omega, D
))
Expand Down Expand Up @@ -88,22 +103,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
)
return(model)
}
if (is.null(dots$algorithm))
dots$algorithm <- "NLOPT_LD_LBFGS"
need_grad <- grepl("NLOPT_LD_", dots$algorithm)
if (is.null(dots$maxeval))
dots$maxeval <- 10000L
if (is.null(dots$xtol_abs))
dots$xtol_abs <- rep(1e-8, attr(model, "df"))
if (is.null(dots$xtol_rel))
dots$xtol_rel <- 0
if (is.null(dots$ftol_abs))
dots$ftol_abs <- 1e-8
if (is.null(dots$ftol_rel))
dots$ftol_rel <- 1e-8
if (is.null(dots$check_derivatives))
dots$check_derivatives <- FALSE

need_grad <- grepl("NLOPT_LD_", control$algorithm)
if (C == 1L) {
if (need_grad) {
objectivef <- function(pars) {
Expand All @@ -120,8 +120,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
obs, iv_omega, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti
)
list(
objective = - out$loglik / n_obs,
gradient = - unlist(out[-1]) / n_obs
objective = - (out$loglik - 0.5 * lambda * sum(pars^2)) / n_obs,
gradient = - (unlist(out[-1]) - lambda * pars) / n_obs
)
}
} else {
Expand All @@ -139,7 +139,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
obs, iv_omega, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti
)

- sum(apply(out[, T_, ], 2, logSumExp)) / n_obs
- (sum(apply(out[, T_, ], 2, logSumExp)) - 0.5 * lambda * sum(pars^2)) / n_obs
}
}
} else {
Expand All @@ -164,8 +164,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
obs, iv_omega, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti
)
list(
objective = - out$loglik / n_obs,
gradient = - unlist(out[-1]) / n_obs
objective = - (out$loglik - 0.5 * lambda * sum(pars^2)) / n_obs,
gradient = - (unlist(out[-1]) - lambda * pars) / n_obs
)
}
} else {
Expand All @@ -190,33 +190,20 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
eta_omega, X_omega,
obs, M)

- sum(apply(out[, T_, ], 2, logSumExp)) / n_obs
- (sum(apply(out[, T_, ], 2, logSumExp)) - 0.5 * lambda * sum(pars^2)) / n_obs
}
}
}
all_solutions <- NULL
start_time <- proc.time()
if (restarts > 0L) {
dots$control_restart$algorithm <- dots$algorithm
if (is.null(dots$control_restart$maxeval))
dots$control_restart$maxeval <- dots$maxeval
if (is.null(dots$control_restart$print_level))
dots$control_restart$print_level <- 0
if (is.null(dots$control_restart$xtol_abs))
dots$control_restart$xtol_abs <-dots$xtol_abs
if (is.null(dots$control_restart$ftol_abs))
dots$control_restart$ftol_abs <- dots$ftol_abs
if (is.null(dots$control_restart$xtol_rel))
dots$control_restart$xtol_rel <- dots$xtol_rel
if (is.null(dots$control_restart$ftol_rel))
dots$control_restart$ftol_rel <- dots$ftol_rel
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- unlist(create_initial_values(
inits, S, M, init_sd, K_pi, K_A, K_B, K_omega, D
))
nloptr(
x0 = init, eval_f = objectivef,
opts = dots$control_restart
opts = control_restart
)
},
future.seed = TRUE)
Expand All @@ -236,7 +223,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts,
}
out <- nloptr(
x0 = init, eval_f = objectivef,
opts = dots
opts = control
)
end_time <- proc.time()
if (out$status < 0) {
Expand Down
Loading

0 comments on commit 99d6883

Please sign in to comment.