Skip to content

Commit

Permalink
Fix quantile -> median, mean/cdf transformation (use intermediary sam…
Browse files Browse the repository at this point in the history
…ples)
  • Loading branch information
lshandross committed Oct 4, 2024
1 parent 054bffb commit cbe2acf
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 22 deletions.
74 changes: 52 additions & 22 deletions R/convert_output_types.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,36 +92,42 @@ convert_output_type <- function(model_out_tbl, new_output_type,
} else if (starting_output_type == "quantile") {
# if median output desired, and Q50 provided return exact value, otherwise
# estimate from samples
if (!("median" %in% new_output_type && 0.5 %in% starting_output_type_ids)) {
model_out_tbl <- get_samples_from_quantiles(model_out_tbl, task_id_cols, n_samples)
if (any(new_output_type != "median") || !(0.5 %in% starting_output_type_ids)) {
model_out_tbl <- model_out_tbl %>%
get_samples_from_quantiles(task_id_cols, n_samples) %>%
rbind(model_out_tbl)
}
}
# transform based on new_output_type
grouped_model_out_tbl <- model_out_tbl %>%
dplyr::group_by(.data[["model_id"]], dplyr::across(dplyr::all_of(task_id_cols)))
model_out_tbl_transform <- vector("list", length = length(new_output_type))
for (i in seq_along(new_output_type)) {
# first find new_output_type_id
new_output_type_id_tmp <- new_output_type_id
if (new_output_type[i] %in% c("mean", "median")) {
new_output_type_id_tmp <- NA
} else if (is.list(new_output_type_id)) {
new_output_type_id_tmp <- new_output_type_id[[new_output_type[i]]]
}
# if median output desired, and Q50 provided return exact value
if (new_output_type[i] == "median" && 0.5 %in% starting_output_type_ids) {
model_out_tbl_transform[[i]] <- model_out_tbl %>%
dplyr::filter(.data[["output_type_id"]] == 0.5) %>%
dplyr::filter(
.data[["output_type"]] != "sample",
.data[["output_type_id"]] == 0.5
) %>%
dplyr::mutate(
output_type = new_output_type[i],
output_type_id = NA
) %>%
as_model_out_tbl()
} else { # otherwise calculate new values
grouped_model_out_tbl <- model_out_tbl %>%
dplyr::filter(.data[["output_type"]] == "sample") %>%
dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", task_id_cols))))
model_out_tbl_transform[[i]] <- grouped_model_out_tbl %>%
convert_from_sample(new_output_type[i], new_output_type_id_tmp) %>%
dplyr::ungroup()
}
# otherwise calculate new values
# first find new_output_type_id
new_output_type_id_tmp <- new_output_type_id
if (new_output_type[i] %in% c("mean", "median")) {
new_output_type_id_tmp <- NA
} else if (is.list(new_output_type_id)) {
new_output_type_id_tmp <- new_output_type_id[[new_output_type[i]]]
}
model_out_tbl_transform[[i]] <- convert_from_sample(
grouped_model_out_tbl, new_output_type[i], new_output_type_id_tmp
)
}
return(dplyr::bind_rows(model_out_tbl_transform))
}
Expand Down Expand Up @@ -210,15 +216,27 @@ get_samples_from_quantiles <- function(model_out_tbl, task_id_cols, n_samples, .
)
)
}

samples <- model_out_tbl %>%
dplyr::group_by(.data[["model_id"]], dplyr::across(dplyr::all_of(task_id_cols))) %>%
dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", task_id_cols)))) %>%
dplyr::reframe(
value = distfromq::make_q_fn(
ps = as.numeric(.data[["output_type_id"]]),
qs = .data[["value"]], ...
)(stats::runif(n_samples, 0, 1))
)
return(samples)
) %>%
dplyr::ungroup()
split_samples <- split(samples, f = samples[[task_id_cols]])
formatted_samples <- split_samples %>%
purrr::map(.f = function(split_samples) {
dplyr::mutate(split_samples,
output_type = "sample",
output_type_id = as.numeric(dplyr::row_number()),
.before = "value")
}) %>%
purrr::list_rbind() %>%
as_model_out_tbl()
return(formatted_samples)
}

#' @noRd
Expand All @@ -231,15 +249,27 @@ get_samples_from_cdf <- function(model_out_tbl, task_id_cols, n_samples, ...) {
)
)
}

samples <- model_out_tbl %>%
dplyr::group_by(.data[["model_id"]], dplyr::across(dplyr::all_of(task_id_cols))) %>%
dplyr::group_by(dplyr::across(dplyr::all_of(c("model_id", task_id_cols)))) %>%
dplyr::reframe(
value = distfromq::make_q_fn(
ps = .data[["value"]],
qs = as.numeric(.data[["output_type_id"]]), ...
)(stats::runif(n_samples, 0, 1))
)
return(samples)
) %>%
dplyr::ungroup()
split_samples <- split(samples, f = samples[[task_id_cols]])
formatted_samples <- split_samples %>%
purrr::map(.f = function(split_samples) {
dplyr::mutate(split_samples,
output_type = "sample",
output_type_id = as.numeric(dplyr::row_number()),
.before = "value")
}) %>%
purrr::list_rbind() %>%
as_model_out_tbl()
return(formatted_samples)
}

#' @noRd
Expand Down
38 changes: 38 additions & 0 deletions tests/testthat/test-convert_output_types.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,44 @@ test_that("convert_output_type works (quantile >> cdf)", {
expect_equal(test, expected, tolerance = 1e-2)
})

test_that("convert_output_type works (quantile >> cdf, median)", {
ex_qs <- seq(0, 1, length.out = 500)[2:499]
model_out_tbl <- expand.grid(
grp1 = 1:2,
model_id = LETTERS[1:2],
output_type = "quantile",
output_type_id = ex_qs,
stringsAsFactors = FALSE
) %>%
dplyr::mutate(mean = grp1 * ifelse(model_id == "A", 1, 3),
value = qnorm(ex_qs, mean)) %>%
dplyr::select(-mean)
new_output_type <- c("cdf", "median")
new_output_type_id <- list(cdf = seq(-2, 2, 0.5), median = NA)
expected_median <- tibble::tibble(
grp1 = rep(1:2, 2), model_id = sort(rep(LETTERS[1:2], 2))
) %>%
dplyr::mutate(value = grp1 * ifelse(model_id == "A", 1, 3)) %>%
dplyr::mutate(output_type = new_output_type[2],
output_type_id = new_output_type_id[[2]]) %>%
as_model_out_tbl()
expected_cdf <- tibble::as_tibble(expand.grid(
grp1 = 1:2,
model_id = LETTERS[1:2],
output_type = new_output_type[1],
output_type_id = new_output_type_id[[1]],
KEEP.OUT.ATTRS = FALSE,
stringsAsFactors = FALSE
)) %>%
dplyr::mutate(value = pnorm(output_type_id, grp1 * ifelse(model_id == "A", 1, 3))) %>%
dplyr::arrange(model_id, grp1) %>%
as_model_out_tbl()
expected <- rbind(expected_cdf, expected_median)
set.seed(101)
test <- convert_output_type(model_out_tbl, new_output_type, new_output_type_id)
expect_equal(test, expected, tolerance = 1e-2)
})

test_that("convert_output_type works (cdf >> mean)", {
ex_ps <- seq(-2, 10, length.out = 500)[2:499]
model_out_tbl <- expand.grid(
Expand Down

0 comments on commit cbe2acf

Please sign in to comment.