Skip to content

Commit

Permalink
Merge pull request #7 from nwfsc-cb/randomeffects
Browse files Browse the repository at this point in the history
add random effects (intercepts), bump version
  • Loading branch information
ericward-noaa authored Nov 16, 2023
2 parents c9a53a4 + f82134c commit ca0e3da
Show file tree
Hide file tree
Showing 37 changed files with 2,051 additions and 818 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: zoid
Title: Bayesian Zero-and-One Inflated Dirichlet Regression Modelling
Version: 1.2.0
Version: 1.3.0
Authors@R:
c(person(given = "Eric J.",
family = "Ward",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ importFrom(compositions,fitDirichlet)
importFrom(gtools,rdirichlet)
importFrom(rstan,extract)
importFrom(rstan,sampling)
importFrom(stats,as.formula)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
Expand Down
89 changes: 86 additions & 3 deletions R/fitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@
#' fit <- fit_zoid(data_matrix = y)
#'
#' # fit a model with 1 factor
#' design <- data.frame("y" = c(1, 1, 1), "fac" = c("spring", "spring", "fall"))
#' design <- data.frame("fac" = c("spring", "spring", "fall"))
#' fit <- fit_zoid(formula = ~fac, design_matrix = design, data_matrix = y)
#' }
#' # try a model with random effects
#' set.seed(123)
#' y <- matrix(runif(99,1,4), ncol=3)
#' design <- data.frame("fac" = sample(letters[1:5], size=nrow(y), replace=TRUE))
#' design$fac <- as.factor(design$fac)
#' fit <- fit_zoid(formula = ~(1|fac), design_matrix = design, data_matrix = y)
#'
fit_zoid <- function(formula = NULL,
design_matrix,
Expand All @@ -55,9 +61,23 @@ fit_zoid <- function(formula = NULL,
data_matrix <- matrix(data_matrix, nrow = 1)
}

# fill with dummy values
parsed_res <- list(design_matrix = matrix(0, nrow(data_matrix),ncol=1),
var_indx = 1,
n_re_by_group = 1,
tot_re = 1,
n_groups = 1)
est_re <- FALSE
if (!is.null(formula)) {
model_frame <- model.frame(formula, design_matrix)
model_matrix <- model.matrix(formula, model_frame)
# extract the random effects
res <- parse_re_formula(formula, design_matrix)
if(length(res$var_indx) > 0) {
parsed_res <- res # only update if REs are in formula
est_re <- TRUE
model_matrix <- res$fixed_design_matrix
}
} else {
model_matrix <- matrix(1, nrow = nrow(data_matrix))
colnames(model_matrix) <- "(Intercept)"
Expand All @@ -84,10 +104,17 @@ fit_zoid <- function(formula = NULL,
overdisp = ifelse(overdispersion == TRUE, 1, 0),
overdispersion_sd = overdispersion_sd,
postpred = ifelse(posterior_predict == TRUE, 1, 0),
prior_sd = sd_prior
prior_sd = sd_prior,
design_Z = parsed_res$design_matrix, # design matrix for Z (random int)
re_var_indx = c(parsed_res$var_indx, 1), # index of the group for each re
n_re_by_group = c(parsed_res$n_re_by_group, 1), # number of random ints per group
tot_re = parsed_res$tot_re, # total number of random ints, across all groups
n_groups = parsed_res$n_group,
est_re = as.numeric(est_re)
)

pars <- c("beta", "log_lik", "mu")
if(est_re == TRUE) pars <- c(pars, "zeta", "zeta_sds")
if (overdispersion == TRUE) pars <- c(pars, "phi")
if (posterior_predict == TRUE) pars <- c(pars, "ynew")
if (moment_match == TRUE) pars <- c(pars, "phi_inv", "beta_raw", "p_zero", "p_one")
Expand All @@ -111,6 +138,62 @@ fit_zoid <- function(formula = NULL,
data_matrix = data_matrix,
overdispersion = overdispersion,
overdispersion_prior = prior,
posterior_predict = posterior_predict
posterior_predict = posterior_predict,
stan_data = stan_data
))
}


#' Parse random-intercept terms out of a model formula
#'
#' Splits the formula into fixed-effect terms and random-intercept terms (any
#' term containing a bar, e.g. `(1 | group)`), then builds one design matrix
#' for the fixed effects and one combined indicator design matrix for the
#' random-intercept grouping variables.
#'
#' @param formula The model formula for the design matrix.
#' @param data The data matrix used to construct RE design matrix
#' @return A list with elements `design_matrix` (column-bound indicator
#'   matrices for all grouping variables; zero columns when the formula has no
#'   random effects), `var_indx` (for each column of `design_matrix`, the
#'   index of the grouping variable it belongs to), `n_re_by_group` (number of
#'   columns contributed by each grouping variable), `tot_re` (total number of
#'   columns), `n_groups` (number of grouping variables) and
#'   `fixed_design_matrix` (model matrix of the fixed-effect terms).
#' @importFrom stats model.matrix as.formula
parse_re_formula <- function(formula, data) {
  # as.character() on a formula yields the "~" operator plus the deparsed
  # side(s) of the formula as separate strings.
  formula_str <- as.character(formula)
  # Split into individual terms on '+' and '-' and trim whitespace
  formula_parts <- unlist(strsplit(formula_str, split = "[-+]", perl = TRUE))
  formula_parts <- trimws(formula_parts)

  # Terms containing a bar '|' are random effects; everything else is fixed
  random_effects <- grep("\\|", formula_parts, value = TRUE)
  fixed_effects <- setdiff(formula_parts, random_effects)
  # Drop the bare "~" operator and empty fragments (e.g. from a leading '-')
  # so they are not pasted back into the fixed-effects formula.
  fixed_effects <- fixed_effects[!fixed_effects %in% c("~", "")]

  # Create design matrix for fixed effects; fall back to an intercept-only
  # model when no fixed-effect terms remain.
  # NOTE: terms are re-joined with "+", so a "- 1" in the input becomes "+ 1"
  # (intercept removal is not carried through).
  if (length(fixed_effects) > 0) {
    fixed_formula_str <- paste("~", paste(fixed_effects, collapse = "+"))
  } else {
    fixed_formula_str <- "~ 1" # Only intercept
  }
  fixed_design_matrix <- model.matrix(as.formula(fixed_formula_str), data)

  # Grouping variable name = text after the '|', minus the closing parenthesis
  random_effect_group_names <- vapply(random_effects, function(part) {
    split_part <- strsplit(part, "\\|", perl = TRUE)[[1]]
    trimws(gsub("\\)", "", split_part[2]))
  }, character(1), USE.NAMES = FALSE)

  # One no-intercept model matrix per grouping variable
  group_matrices <- lapply(random_effect_group_names, function(group_name) {
    model.matrix(as.formula(paste("~", group_name, "-1")), data)
  })

  # Column counts are taken per group (before binding) so that the index
  # vector and the per-group counts line up with the combined matrix.
  n_re <- vapply(group_matrices, ncol, integer(1))
  var_indx <- rep(seq_along(group_matrices), times = n_re)

  if (length(group_matrices) > 0) {
    design_matrix <- do.call(cbind, group_matrices)
  } else {
    # No random effects: zero-column matrix and empty index vector
    design_matrix <- matrix(0, nrow = nrow(data), ncol = 0)
  }

  list(
    design_matrix = design_matrix,
    var_indx = var_indx,
    n_re_by_group = n_re,
    tot_re = sum(n_re),
    n_groups = length(group_matrices),
    fixed_design_matrix = fixed_design_matrix
  )
}
23 changes: 23 additions & 0 deletions R/get_pars.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,28 @@ get_pars <- function(fitted_model, conf_int = 0.05) {
par_list$phi <- phi
}

# include zetas (random group intercepts)
if (fitted_model$stan_data$est_re == 1) {
n_group <- dim(pars$zeta)[2]
n_cov <- dim(pars$zeta)[3]
zetas <- expand.grid(
"group" = seq(1, n_group),
"cov" = seq(1, n_cov),
"par" = NA,
"mean" = NA,
"median" = NA,
"lo" = NA,
"hi" = NA
)
for (i in 1:nrow(zetas)) {
zetas$mean[i] <- mean(pars$zeta[, zetas$group[i], zetas$cov[i]])
zetas$median[i] <- median(pars$zeta[, zetas$group[i], zetas$cov[i]])
zetas$lo[i] <- quantile(pars$zeta[, zetas$group[i], zetas$cov[i]], conf_int / 2.0)
zetas$hi[i] <- quantile(pars$zeta[, zetas$group[i], zetas$cov[i]], 1 - conf_int / 2.0)
zetas$par[i] <- fitted_model$par_names[zetas$cov[i]]
}
par_list$zetas <- zetas
}

return(par_list)
}
9 changes: 6 additions & 3 deletions docs/404.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ca0e3da

Please sign in to comment.