generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adapts #5 to include only R code (or #9 maybe) --------- Co-authored-by: Fuhan Yang <[email protected]> Co-authored-by: Fuhan-Yang <[email protected]>
- Loading branch information
1 parent
4bc7638
commit 0580abd
Showing
4 changed files
with
356 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
library(dplyr) | ||
library(ggplot2) | ||
|
||
### functions from bottom to top: | ||
# inv_z_scale() | ||
# data_prepare() | ||
# build_projection_model2() | ||
# | ||
# IS IN | ||
# | ||
# real_time_projection() -> plot_projection() | ||
# | ||
# IS IN | ||
# | ||
# plot_real_time_projection() | ||
|
||
# real_time_projection() -> get_mspe() | ||
# | ||
# IS IN | ||
# | ||
# evaluate_real_time_projection() | ||
|
||
|
||
#### inverse Z score ##### | ||
#' @description Inversely transform z-score standardized data to raw data | ||
#' using: raw = z*sd + mean | ||
#' | ||
#' @param old_data: the previous data that is used for z-score standardization | ||
#' @param scaled_data: current scaled data | ||
#' @return inversely z-score transformed data | ||
inv_z_scale <- function(old_data, scaled_data) { | ||
sd <- sd(old_data) | ||
m <- mean(old_data) | ||
|
||
scaled_data * sd + m | ||
} | ||
|
||
|
||
#' @description Data format preparation for model building. | ||
#' It includes: | ||
#' 1. Add `previous` predictor. | ||
#' 2. Z-score standardization of `previous`,`elapsed`, and `daily` | ||
#' @param data: the data frame need to be prepared. | ||
#' Must include columns:`previous`,`elapsed`, and `daily` | ||
#' @return Data frame with scaled `previous`,`elapsed`, and `daily` and ready to build model | ||
data_prepare <- function(data) { | ||
data$previous <- c(NA, data$daily[-nrow(data)]) # nolint | ||
data <- data[-1, ] | ||
data$previous_std <- scale(data$previous, center = T, scale = T) | ||
data$elapsed_std <- scale(data$elapsed, center = T, scale = T) | ||
data$daily_std <- scale(data$daily, center = T, scale = T) | ||
|
||
return(data) | ||
} | ||
|
||
#' @description This function uses `train_data` to train the linear regression | ||
#' and project to the last date in the `test_data`. | ||
#' It is embedded in `plot_real_time_projection` and 'evaluate_real_time_projection' | ||
#' @param train_data: data frame used to train linear regression | ||
#' @param test_data: data frame used to compare with the prediction generated by linear regression | ||
#' @return a list containing the predicted uptake and the corresponding | ||
#' observed rate under the `test_data` time frame, | ||
#' in both uptake rate `rate` and cumulative uptake `cumulative` | ||
|
||
real_time_projection <- function( | ||
train_data = nis_usa_2022, | ||
test_data = nis_usa_2023) { | ||
# transform the train data and test data to the format of model building and model prediction # | ||
scaled_train <- data_prepare(train_data) | ||
scaled_test <- data_prepare(test_data) | ||
|
||
# The first row is also removed for raw test data, to be comparable with the predicted values # | ||
test_data <- test_data[-1, ] | ||
|
||
# model building # | ||
model <- build_projection_model2(data = scaled_train) | ||
|
||
## get out-of-sample prediction ## | ||
scaled_pred <- brms::posterior_predict(model, newdata = scaled_test) | ||
|
||
## recover the prediction to the original format for comparision ## | ||
pred <- inv_z_scale(test_data$daily, scaled_pred) | ||
|
||
# predicted daily vaccine uptake # | ||
pred_summary <- data.frame( | ||
date = test_data$date, | ||
obs = test_data$daily, | ||
lower = apply(pred, 2, quantile, 0.025), | ||
upper = apply(pred, 2, quantile, 0.975), | ||
mean = colMeans(pred) | ||
) | ||
|
||
# predicted cumulative vaccine uptake # | ||
cumu <- t(apply(pred, 1, cumsum)) | ||
cumu_summary <- data.frame( | ||
mean = colMeans(cumu), | ||
obs = test_data$cumulative, | ||
lower = apply(cumu, 2, quantile, 0.025), | ||
upper = apply(cumu, 2, quantile, 0.975), | ||
date = test_data$date | ||
) | ||
|
||
return(list(rate = pred_summary, cumulative = cumu_summary)) | ||
} | ||
|
||
#' @description This function plots prediction with observed data | ||
#' This function is embeded in `plot_real_time_projection` | ||
#' @param option: which uptake to plot: 'rate' or 'cumulative' | ||
#' @param predicted: out-of-sample prediction from linear regression | ||
#' @param test: the observed data corresponding to the prediction | ||
#' @return a ggplot object | ||
|
||
plot_projection <- function(option, predicted, test) { | ||
if (option == "rate") { | ||
data_y_name <- "daily" | ||
} else if (option == "cumulative") { | ||
data_y_name <- "cumulative" | ||
} else { | ||
stop("Response variable in test is not found.") | ||
} | ||
|
||
ggplot(predicted) + | ||
geom_ribbon(aes(x = date, ymin = lower, ymax = upper), | ||
fill = "black", alpha = 0.25 | ||
) + | ||
geom_line(aes(x = date, y = mean), # Bug warning! colnames may change later # | ||
linewidth = 1.5 | ||
) + | ||
geom_line(aes(x = date, y = daily), | ||
data = test, | ||
linewidth = 1.5, color = "dodgerblue" | ||
) + | ||
xlab("Time") + | ||
ylab("% of Population") + | ||
ggtitle(paste("Forecast starts at", min(test$date))) + | ||
theme_bw() + | ||
theme(text = ggplot2::element_text(size = 20)) + | ||
annotate("text", | ||
x = test$date[round(0.9 * nrow(test))], | ||
y = 1.7 * diff(range(predicted$upper)), | ||
label = "Input", col = "dodgerblue", size = 5 | ||
) + | ||
annotate("text", | ||
x = test$date[round(0.9 * nrow(test))], | ||
y = 1.6 * diff(range(predicted$upper)), | ||
label = "Model", col = "black", size = 5 | ||
) | ||
} | ||
|
||
|
||
#' @description This function fits the linear regression and plots the prediction | ||
#' with the observed data in a time-varying manner | ||
#' @param data_option: which uptake to plot: 'rate' or 'cumulative' | ||
#' @param all_data: Data used to iteratively train and test, | ||
#' with each iteration moving forward one data point | ||
#' @param min_data_size: the minimum data size required for model training, | ||
#' So far, it is the same for model testing. | ||
#' @return a X*Y grid ggplot | ||
|
||
plot_real_time_projection <- function( | ||
data_option = "rate", | ||
all_data, min_data_size = 8) { | ||
start_date <- sort(all_data$date)[1 + min_data_size - 1] | ||
end_date <- sort(all_data$date)[nrow(all_data) - min_data_size + 1] | ||
|
||
date_series <- all_data %>% | ||
filter(date >= start_date, date <= end_date) %>% | ||
select(date) | ||
|
||
# convert single-column data.frame to vector # | ||
date_series <- as.vector(date_series$date) + as.Date("1970-01-01") | ||
ps <- list() | ||
i <- 1 | ||
|
||
for (split_date in date_series) { | ||
train_data <- all_data %>% | ||
filter(date < split_date) | ||
|
||
test_data <- all_data %>% | ||
filter(date >= split_date) | ||
|
||
output <- real_time_projection( | ||
train_data = train_data, | ||
test_data = test_data | ||
) | ||
|
||
if (data_option == "rate") { | ||
data_to_plot <- output$rate | ||
} else if (data_option == "cumulative") { | ||
data_to_plot <- output$cumulative | ||
} else { | ||
stop("Data_option is not valid.") | ||
} | ||
|
||
## plot ## | ||
plot_projection( | ||
option = data_option, | ||
predicted = data_to_plot, | ||
test = test_data | ||
) -> p | ||
|
||
ps[[i]] <- p | ||
i <- i + 1 | ||
} | ||
|
||
cowplot::plot_grid(plotlist = ps, nrow = ceiling(sqrt(length(date_series)))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
library(dplyr) | ||
|
||
# Retrospective evaluation of forecast | ||
# MSPE only in this version. | ||
# Update in the future: | ||
# 1. Public health oriented metrics: end-of-season totals, high demand period, etc. | ||
# 2. Probabilistic metrics: QS, WIS, etc. | ||
|
||
|
||
#' @description Get mean-squared prediction error | ||
#' @param data: observed data, can be a constant or a vector | ||
#' @param pred: the prediction output from models, can be a constant or a vector | ||
#' @return MSPE | ||
get_mspe <- function(data, pred) { | ||
mean((data - pred)^2) | ||
} | ||
|
||
#' @description Retrospectively evaluate the forecast given the test data | ||
#' @param data_option: which uptake data type to evaluate? 'rate' or 'cumulative' | ||
#' @param evaluate_option: which metrics to use? 'mspe' only so far. | ||
#' Adding this indicates changing the output from the embedded functions. | ||
#' @param all_data: Data used to iteratively train and test, | ||
#' with each iteration moving forward one data point | ||
#' @param min_data_size: the minimum data size required for model training, | ||
#' So far, this number is the same for model testing. | ||
#' @return a data frame with MSPE given different initializing date | ||
|
||
evaluate_real_time_projection <- function( | ||
data_option = "rate", | ||
evaluate_option = "mspe", | ||
all_data, min_data_size = 8) { | ||
start_date <- sort(all_data$date)[1 + min_data_size - 1] | ||
end_date <- sort(all_data$date)[nrow(all_data) - min_data_size + 1] | ||
|
||
date_series <- all_data %>% | ||
filter(date >= start_date, date <= end_date) %>% | ||
select(date) | ||
|
||
# convert single-column data.frame to vector # | ||
date_series <- as.vector(date_series$date) + as.Date("1970-01-01") | ||
metrics <- data.frame() | ||
|
||
for (split_date in date_series) { | ||
train_data <- all_data %>% | ||
filter(date < split_date) | ||
|
||
test_data <- all_data %>% | ||
filter(date >= split_date) | ||
|
||
output <- real_time_projection( | ||
train_data = train_data, | ||
test_data = test_data | ||
) | ||
|
||
if (data_option == "rate") { | ||
data_to_eval <- output$rate | ||
} else if (data_option == "cumulative") { | ||
data_to_eval <- output$cumulative | ||
} else { | ||
stop("Data_option is not valid.") | ||
} | ||
|
||
## evaluation ## | ||
|
||
# MSPE # | ||
# So far, only use MSPE. | ||
# Note: What can be the code structure if multiple metrics are evaluated simultaneously | ||
if (evaluate_option == "mspe") { | ||
metric <- get_mspe(data_to_eval$obs, data_to_eval$mean) | ||
} | ||
|
||
metrics <- rbind( | ||
metrics, | ||
data.frame( | ||
forecast_date = split_date, | ||
mspe = metric | ||
) | ||
) | ||
} | ||
|
||
return(metrics) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
rm(list = ls()) | ||
# read all the R files under R folder # | ||
R_path <- "R/" | ||
purrr::map(paste0(R_path, list.files("R/")), source) | ||
|
||
# load observed data # | ||
|
||
# Load 2022 NIS data for USA | ||
nis_usa_2022 <- get_uptake_data( | ||
"https://data.cdc.gov/api/views/akkj-j5ru/rows.csv?accessType=DOWNLOAD", | ||
state = "Geography", | ||
date = c("Time.Period", "Year"), | ||
cumulative = "Estimate....", | ||
state_key = "data/USA_Name_Key.csv", | ||
date_format = "%m/%d/%Y", | ||
start_date = "9/2/2022", | ||
filters = list(c( | ||
"Indicator.Category", | ||
"Received updated bivalent booster dose (among adults who completed primary series)" | ||
)) | ||
) | ||
|
||
# Load 2023 NIS data for USA | ||
nis_usa_2023 <- get_uptake_data( | ||
"data/NIS_2023-24.csv", | ||
state = "geography", | ||
date = "date", | ||
cumulative = "estimate", | ||
state_key = "data/USA_Name_Key.csv", | ||
date_format = "%m/%d/%Y", | ||
start_date = "9/12/2023", | ||
filters = list(c("time_type", "Weekly"), c("group_name", "Overall")) | ||
) | ||
|
||
|
||
# generate projections initiated at different time # | ||
all_data <- rbind(nis_usa_2022, nis_usa_2023) | ||
|
||
# make sure there at least 4 data points for model training: 30 min to run # | ||
plot_real_time_projection( | ||
data_option = "rate", | ||
all_data = all_data | ||
) | ||
|
||
# evaluate MSPE between projections and data initiated at different time: 30 min to run # | ||
evaluate_real_time_projection(evaluate_option = "mspe", data_option = "rate", all_data = all_data) -> mspe_df | ||
|
||
mspe_df %>% | ||
mutate(forecast_date = forecast_date + as.Date("1970-01-01")) %>% | ||
filter(mspe < 1) %>% | ||
ggplot() + | ||
geom_point(aes(x = forecast_date, y = mspe)) + | ||
xlab("Forecast date") + | ||
ylab("MSPE") + | ||
theme_bw() | ||
ggsave("mspe.jpeg", width = 4, height = 3, units = "in") | ||
# Note for future upgrade: seperate model fitting from plotting and evaluation. | ||
# Repeatedly fitting the model to plot and evaluate is not time-efficient. |