Skip to content

Commit

Permalink
Merge pull request #100 from mayer79/normalization-check
Browse files Browse the repository at this point in the history
Revise postprocess()
  • Loading branch information
mayer79 authored Oct 30, 2023
2 parents df418d7 + 99b4ce5 commit a4bc1fb
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 39 deletions.
39 changes: 18 additions & 21 deletions R/utils_statistics.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,39 +70,36 @@ init_numerator <- function(x, way = 1L) {
#' @keywords internal
#'
#' @inheritParams H2_overall
#' @param num Matrix or vector of statistic.
#' @param denom Denominator of statistic (a matrix, number, or vector compatible with `num`).
#' @returns Matrix or vector of statistics. If length of output is 0, then `NULL`.
postprocess <- function(num, denom = 1, normalize = TRUE, squared = TRUE,
#' @param num Matrix with numerator statistics.
#' @param denom Vector or matrix with denominator statistics.
#' @returns Matrix of statistics, or `NULL`.
postprocess <- function(num, denom = rep(1, times = NCOL(num)),
normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE) {
out <- num
stopifnot(
is.matrix(num),
is.matrix(denom) || is.vector(denom), # already covered by the next condition
is.matrix(denom) && all(dim(num) == dim(denom)) || ## h2_pairwise/threeway
is.vector(denom) && length(denom) == ncol(num) ## all other stats
)
if (normalize) {
if (length(denom) == 1L || length(num) == length(denom)) {
out <- out / denom
} else if (length(denom) == ncol(num)) {
out <- sweep(out, MARGIN = 2L, STATS = denom, FUN = "/")
if (is.matrix(denom)) {
out <- num / denom
} else {
stop("Normalization error")
out <- sweep(num, MARGIN = 2L, STATS = denom, FUN = "/")
}
} else {
out <- num
}
if (!squared) {
out <- sqrt(out)
}
if (sort) {
if (is.matrix(out)) {
out <- out[order(-rowSums(out)), , drop = FALSE]
} else {
out <- sort(out, decreasing = TRUE)
}
out <- out[order(-rowSums(out)), , drop = FALSE]
}
if (!zero) {
if (is.matrix(out)) {
out <- out[rowSums(out) > 0, , drop = FALSE]
} else {
out <- out[out > 0]
}
out <- out[rowSums(out != 0) > 0, , drop = FALSE]
}
# out <- utils::head(out, n = top_m)
if (length(out) == 0L) NULL else out
}

Expand Down
4 changes: 3 additions & 1 deletion backlog/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ hist2 <- function(x, breaks = 17L, trim = c(0.01, 0.99),
r <- stats::quantile(x, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
xx <- x[x >= r[1L] & x <= r[2L]]
}
h <- hist(xx, breaks = breaks, include.lowest = include.lowest, right = right)
h <- hist(
xx, breaks = breaks, include.lowest = include.lowest, right = right, plot = FALSE
)
b <- h$breaks
ix <- findInterval(
x, vec = b, left.open = right, rightmost.closed = include.lowest, all.inside = TRUE
Expand Down
1 change: 1 addition & 0 deletions backlog/modeltuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ perm_importance(mm, X = iris, y = "Sepal.Length", w = "Petal.Width") |>
plot(H)
h2_pairwise(H, normalize = FALSE, squared = FALSE) |>
plot(swap_dim = TRUE)

37 changes: 20 additions & 17 deletions tests/testthat/test_statistics.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

test_that("poor_man_stack() works (test could be improved", {
test_that("poor_man_stack() works (test could be improved)", {
y <- c("a", "b", "c")
z <- c("aa", "bb", "cc")
X <- data.frame(x = 1:3, y = y, z = z)
Expand Down Expand Up @@ -36,7 +36,23 @@ test_that("mat2df() works (test could be improved)", {
expect_error(mat2df(1:4))
})

test_that("postprocess() works for matrix input", {
test_that("postprocess() works for matrix num with 1 column", {
num <- cbind(1:3)
denom <- cbind(1:3)

expect_equal(postprocess(num = num, sort = FALSE), num)
expect_equal(postprocess(num = num, denom = denom, sort = FALSE), num / denom)
expect_equal(postprocess(num = num), num[3:1, , drop = FALSE])
expect_equal(postprocess(num = num, squared = FALSE), sqrt(num[3:1, , drop = FALSE]))

expect_equal(postprocess(num = num, denom = 2, sort = FALSE), num / 2)

expect_equal(postprocess(num = cbind(0:1), zero = FALSE), rbind(1))
expect_equal(postprocess(num = cbind(0:-1), zero = FALSE), rbind(-1))
expect_null(postprocess(num = cbind(0), zero = FALSE))
})

test_that("postprocess() works for matrix num with > 1 column", {
num <- cbind(a = 1:3, b = c(1, 1, 1))
denom <- cbind(a = 1:3, b = 1:3)

Expand All @@ -45,30 +61,17 @@ test_that("postprocess() works for matrix input", {
expect_equal(postprocess(num = num), num[3:1, ])
expect_equal(postprocess(num = num, squared = FALSE), sqrt(num[3:1, ]))

expect_equal(postprocess(num = num, denom = 2, sort = FALSE), num / 2)
expect_equal(postprocess(num = num, denom = c(2, 2), sort = FALSE), num / 2)
expect_equal(
postprocess(num = num, denom = 1:2, sort = FALSE),
num / cbind(c(1, 1, 1), c(2, 2, 2))
)

expect_equal(postprocess(num = cbind(0:1, 0:1), zero = FALSE), rbind(c(1, 1)))
expect_equal(postprocess(num = cbind(0:-1, 0:-1), zero = FALSE), rbind(c(-1, -1)))
expect_null(postprocess(num = cbind(0, 0), zero = FALSE))
})

test_that("postprocess() works for vector input", {
num <- 1:3
denom <- c(2, 4, 6)

expect_equal(postprocess(num = num), 3:1)
expect_equal(postprocess(num = num, denom = denom), num / denom)
expect_equal(postprocess(num = num, sort = FALSE), num)
expect_equal(postprocess(num = num, squared = FALSE), sqrt(num[3:1]))

expect_equal(postprocess(num = 0:1, denom = c(2, 2), zero = FALSE), 0.5)
expect_null(postprocess(num = 0, zero = FALSE))
})


test_that(".zap_small() works for vector input", {
expect_equal(.zap_small(1:3), 1:3)
expect_equal(.zap_small(c(1:3, NA)), c(1:3, 0))
Expand Down

0 comments on commit a4bc1fb

Please sign in to comment.