Skip to content

Commit

Permalink
init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ericward-noaa committed Sep 19, 2022
0 parents commit 7304ab7
Show file tree
Hide file tree
Showing 50 changed files with 5,058 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
^.*\.Rproj$
^\.Rproj\.user$
^\.github$
1 change: 1 addition & 0 deletions .github/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.html
28 changes: 28 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: R-CMD-check

jobs:
R-CMD-check:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes
steps:
- uses: actions/checkout@v2

- uses: r-lib/actions/setup-r@v1
with:
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v1
with:
extra-packages: rcmdcheck

- uses: r-lib/actions/check-r-package@v1
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.Rproj.user
.Rhistory
.RData
.Ruserdata
src/*.o
src/*.so
src/*.dll
.DS_Store
project.Rproj
45 changes: 45 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
Package: mvdlm
Title: Multivariate Dynamic Linear Modelling With Stan
Version: 0.1.0
Authors@R:
c(person(given = "Eric J.",
family = "Ward",
role = c("aut", "cre"),
email = "[email protected]",
comment = c(ORCID = "0000-0002-4359-0296")))
Description: Fits multivariate dynamic linear models in a Bayesian framework using Stan.
License: GPL (>=3)
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
Biarch: true
URL: https://github.com/atsa-es/mvdlm
BugReports: https://github.com/atsa-es/mvdlm/issues
Depends:
R (>= 4.1.0)
Imports:
broom.mixed,
methods,
gtools,
compositions,
ggplot2,
MARSS,
Rcpp (>= 0.12.0),
RcppParallel (>= 5.0.1),
rstan (>= 2.18.1),
rstantools (>= 2.1.1)
Suggests:
testthat,
knitr,
rmarkdown,
parallel
LinkingTo:
BH (>= 1.66.0),
Rcpp (>= 0.12.0),
RcppEigen (>= 0.3.3.3.0),
RcppParallel (>= 5.0.1),
rstan (>= 2.18.1),
StanHeaders (>= 2.18.0)
SystemRequirements: GNU make
VignetteBuilder: knitr
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Generated by roxygen2: do not edit by hand

export(dlm_trends)
export(fit_dlm)
import(Rcpp)
import(ggplot2)
import(methods)
importFrom(broom.mixed,tidy)
importFrom(rstan,sampling)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,model.response)
useDynLib(mvdlm, .registration = TRUE)
57 changes: 57 additions & 0 deletions R/dlm_trends.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#' Summarize and plot time varying coefficients from the fitted model
#'
#' Extracts the posterior summaries of the `b_varying` parameters from a
#' fitted model, labels them with the coefficient name and time step, and
#' plots each coefficient's trajectory with an approximate 95% band
#' (estimate +/- 1.96 standard errors).
#'
#' @param fitted_model A fitted model object, as returned by [fit_dlm()]. The
#'   elements `fit`, `time_varying_pars` and `stan_data$nT` are used.
#' @export
#' @return A list with two elements: `plot`, the ggplot object, and
#'   `b_varying`, the data frame of tidied estimates used to draw it (with
#'   added `par` and `time` columns).
#'
#' @importFrom broom.mixed tidy
#' @import ggplot2
#'
#' @examples
#' \donttest{
#' set.seed(123)
#' N = 20
#' data = data.frame("y" = runif(N),
#'                   "cov1" = rnorm(N),
#'                   "cov2" = rnorm(N),
#'                   "year" = 1:N,
#'                   "season" = sample(c("A","B"), size=N, replace=TRUE))
#' b_1 = cumsum(rnorm(N))
#' b_2 = cumsum(rnorm(N))
#' data$y = data$cov1*b_1 + data$cov2*b_2
#' time_varying = y ~ cov1 + cov2
#' formula = NULL
#' fit <- fit_dlm(formula = formula,
#'                time_varying = time_varying,
#'                time = "year",
#'                est_df = FALSE,
#'                family = c("normal"),
#'                data, chains = 1, iter = 20)
#' dlm_trends(fit)
#' }
#'
dlm_trends <- function(fitted_model) {

  tidy_pars <- broom.mixed::tidy(fitted_model$fit)

  indx <- grep("b_varying", tidy_pars$term)
  if (length(indx) == 0) {
    stop("Error: time varying parameters not found")
  }

  b_varying <- tidy_pars[indx, ] # subset

  # Read the labels from the argument itself; the previous code looked these
  # up on a global object called `fit`, which only worked by accident.
  par_names <- fitted_model$time_varying_pars
  n_time <- fitted_model$stan_data$nT

  # NOTE(review): this labelling assumes the tidied rows are ordered with the
  # time index varying fastest within each coefficient -- confirm against the
  # declaration of `b_varying` in the Stan program.
  b_varying$par <- rep(par_names, each = n_time) # add names
  b_varying$time <- rep(seq_len(n_time), times = length(par_names))

  cols <- "#440154FF" # viridis::viridis(1)
  g <- ggplot(b_varying, aes(time, estimate)) +
    geom_ribbon(
      aes(
        ymin = estimate - 1.96 * std.error,
        ymax = estimate + 1.96 * std.error
      ),
      fill = cols, alpha = 0.5
    ) +
    geom_line(col = cols) +
    facet_wrap(~par, scales = "free_y") +
    ylab("Estimate") +
    xlab("Time") +
    theme_bw() +
    theme(strip.background = element_rect(fill = "white"))

  list(plot = g, b_varying = b_varying)
}
222 changes: 222 additions & 0 deletions R/fitting.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#' Fit a Bayesian dynamic linear model with Stan
#'
#' Fit a dynamic linear model that can include coefficients that are constant
#' through time (`formula`), coefficients that vary through time
#' (`time_varying`), or both.
#'
#' @param formula The model formula for the fixed effects; at least this formula or `time_varying` needs to have the response included
#' @param time_varying The model formula for the time-varying effects; at least this formula or `formula` needs to have the response included
#' @param time String describing the name of the variable corresponding to time, defaults to "year"
#' @param est_df Whether or not to estimate deviations of B as Student-t with estimated degrees of freedom, defaults to `FALSE`
#' @param family The name of the family used for the response; can be one of "normal","binomial","poisson","nbinom2","gamma","lognormal". Only the first element is used.
#' @param correlated_rw Whether to estimate time-varying parameters as correlated random walk, defaults to TRUE
#' @param data The data frame including response and covariates for all model components
#' @param chains Number of mcmc chains, defaults to 3
#' @param iter Number of mcmc iterations, defaults to 2000
#' @param warmup Number iterations for mcmc warmup, defaults to 1/2 of the iterations
#' @param ... Any other arguments to pass to [rstan::sampling()].
#' @export
#' @return A list containing the fitted model and arguments and data used
#' to fit the model. These include `fit` (the object returned by
#' [rstan::sampling()]), `fixed_pars` and `time_varying_pars` (the column
#' names of the two design matrices), `fixed_formula`, `time_varying_formula`,
#' `time`, `est_df`, `stan_data` (the list of data passed to Stan) and
#' `raw_data` (`data` with an added `(Intercept)` column).
#'
#' @importFrom rstan sampling
#' @importFrom stats model.frame model.matrix model.response
#' @import Rcpp
#'
#' @examples
#' \donttest{
#' set.seed(123)
#' N = 20
#' data = data.frame("y" = runif(N),
#'                   "cov1" = rnorm(N),
#'                   "cov2" = rnorm(N),
#'                   "year" = 1:N,
#'                   "season" = sample(c("A","B"), size=N, replace=TRUE))
#' b_1 = cumsum(rnorm(N))
#' b_2 = cumsum(rnorm(N))
#' data$y = data$cov1*b_1 + data$cov2*b_2
#' time_varying = y ~ cov1 + cov2
#' formula = NULL
#'
#' # fit a model with a time varying component
#' fit <- fit_dlm(formula = formula,
#'                time_varying = time_varying,
#'                time = "year",
#'                est_df = FALSE,
#'                family = c("normal"),
#'                data, chains = 1, iter = 20)
#'
#' # fit a model with a time varying and fixed component (here, fixed intercept)
#' fit <- fit_dlm(formula = y ~ 1,
#'                time_varying = y ~ -1 + cov1 + cov2,
#'                time = "year",
#'                est_df = FALSE,
#'                family = c("normal"),
#'                data, chains = 1, iter = 20)
#'
#' # fit a model with deviations modeled with a multivariate Student-t
#' fit <- fit_dlm(formula = y ~ 1,
#'                time_varying = y ~ -1 + cov1 + cov2,
#'                time = "year",
#'                est_df = TRUE,
#'                family = c("normal"),
#'                data, chains = 1, iter = 20)
#' }
#'
fit_dlm <- function(formula = NULL,
                    time_varying = NULL,
                    time = "year",
                    est_df = FALSE,
                    family = c("normal", "binomial", "poisson", "nbinom2", "gamma", "lognormal"),
                    correlated_rw = TRUE,
                    data,
                    chains = 3,
                    iter = 2000,
                    warmup = floor(iter / 2),
                    ...) {

  # add intercept column to data
  data$`(Intercept)` <- 1

  # Keep the family name for the checks below; Stan receives its position in
  # `recognized_families` as an integer code.
  recognized_families <- c("normal", "binomial", "poisson", "nbinom2", "gamma", "lognormal")
  family_name <- family[1]
  if (!family_name %in% recognized_families) {
    stop("Error: family not recognized")
  }
  family_code <- match(family_name, recognized_families)

  est_fixed_coef <- !is.null(formula)
  est_varying_coef <- !is.null(time_varying)
  if (!est_fixed_coef && !est_varying_coef) {
    stop("Error: at least one of formula or time_varying must be specified")
  }
  if (!time %in% names(data)) {
    stop("Error: time variable '", time, "' not found in data")
  }

  # parse formulas
  y <- NULL
  fixed_pars <- NULL
  tv_pars <- NULL

  if (est_fixed_coef) {
    fixed <- build_design_vectors(formula, data, time)
    y <- fixed$y
    fixed_pars <- fixed$par_names
  } else {
    # placeholders so that Stan still receives every data element
    fixed <- list(
      n_covars = 0, time_indx = c(0, 0), var_indx = c(0, 0),
      x = c(0, 0), N = 2, n_NAs = 0, NAs = c(0, 0)
    )
  }

  if (est_varying_coef) {
    varying <- build_design_vectors(time_varying, data, time)
    # the response from `formula` takes precedence when both supply one
    if (is.null(y)) y <- varying$y
    tv_pars <- varying$par_names
  } else {
    # NOTE(review): unlike the fixed-effects placeholders these are length 1
    # with N = 1. Kept as in the original; confirm the Stan data block accepts
    # length-1 inputs here when only `formula` is supplied.
    varying <- list(
      n_covars = 0, time_indx = 0, var_indx = 0,
      x = 0, N = 1, n_NAs = 0, NAs = c(0, 0)
    )
  }

  if (is.null(y)) {
    stop("Error: response not found; include it in formula or time_varying")
  }

  stan_data <- list(
    y = y,
    y_int = as.integer(y),
    N = length(y),
    nT = max(c(fixed$time_indx, varying$time_indx), na.rm = TRUE),
    est_fixed = as.numeric(est_fixed_coef),
    est_varying = as.numeric(est_varying_coef),
    n_fixed_covars = fixed$n_covars,
    fixed_N = fixed$N,
    fixed_time_indx = fixed$time_indx,
    fixed_var_indx = fixed$var_indx,
    fixed_x_value = fixed$x,
    n_varying_covars = varying$n_covars,
    varying_N = varying$N,
    varying_time_indx = varying$time_indx,
    varying_var_indx = varying$var_indx,
    varying_x_value = varying$x,
    est_df = as.numeric(est_df),
    family = family_code,
    n_fixed_NAs = fixed$n_NAs,
    fixed_NAs = fixed$NAs,
    n_varying_NAs = varying$n_NAs,
    varying_NAs = varying$NAs,
    correlated_rw = as.numeric(correlated_rw)
  )

  # parameters to monitor
  pars <- c("eta", "sigma", "log_lik", "lp__")
  if (est_varying_coef) pars <- c(pars, "b_varying")
  if (est_fixed_coef) pars <- c(pars, "b_fixed")
  # This test used to run against the integer family code (and spelled the
  # family "negbin2"), so it never matched and `phi` was never monitored.
  # NOTE(review): confirm the Stan program declares `phi` for these families.
  if (family_name %in% c("normal", "nbinom2", "gamma", "lognormal")) {
    pars <- c(pars, "phi")
  }
  if (est_df) pars <- c(pars, "nu")
  if (correlated_rw) pars <- c(pars, "R", "Sigma", "Lcorr")

  sampling_args <- list(
    object = stanmodels$dlm,
    chains = chains,
    iter = iter,
    warmup = warmup,
    pars = pars,
    data = stan_data, ...
  )
  fit <- do.call(sampling, sampling_args)

  list(
    fit = fit,
    "fixed_pars" = fixed_pars,
    "time_varying_pars" = tv_pars,
    fixed_formula = formula,
    time_varying_formula = time_varying,
    time = time,
    est_df = est_df,
    stan_data = stan_data,
    raw_data = data
  )
}

# Build the long-format (column-stacked) design vectors for one model formula.
#
# Returns a list with the response `y` (NULL when the formula has none), the
# design-matrix column names `par_names`, the number of columns `n_covars`,
# and the stacked vectors `time_indx`, `var_indx` and `x` of common length
# `N`. Missing covariate values are replaced by 0 in `x`; their positions are
# returned in `NAs` (padded with two trailing zeros) and counted in `n_NAs`.
build_design_vectors <- function(model_formula, data, time) {
  # na.pass keeps rows with missing values so positions line up with `data`
  model_frame <- model.frame(model_formula, data, na.action = na.pass)
  model_matrix <- model.matrix(model_formula, model_frame)
  n_covars <- ncol(model_matrix)
  n_obs <- nrow(model_matrix)

  # re-base time so that the earliest time step is 1
  time_values <- data[[time]]
  time_indx <- time_values - min(time_values) + 1

  # Stack the design matrix column by column. Taking the model matrix directly
  # (rather than dropping columns named "time") keeps a covariate that is
  # itself called "time".
  x <- c(model_matrix)
  na_indx <- which(is.na(x))
  x[na_indx] <- 0

  list(
    y = model.response(model_frame),
    par_names = colnames(model_matrix),
    n_covars = n_covars,
    time_indx = rep(time_indx, n_covars),
    var_indx = rep(seq_len(n_covars), each = n_obs),
    x = x,
    N = n_covars * n_obs,
    n_NAs = length(na_indx),
    NAs = c(na_indx, 0, 0)
  )
}
Loading

0 comments on commit 7304ab7

Please sign in to comment.