Skip to content

Commit

Permalink
gganno: only reshape matrix not data.frame
Browse files Browse the repository at this point in the history
  • Loading branch information
Yunuuuu committed Jul 11, 2024
1 parent 5a017b6 commit da2bb73
Show file tree
Hide file tree
Showing 24 changed files with 779 additions and 474 deletions.
126 changes: 68 additions & 58 deletions R/eanno.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
#' Manual](https://jokergoo.github.io/ComplexHeatmap-reference/book/heatmap-annotations.html#implement-new-annotation-functions)
#' for details.
#'
#' The function must have at least Four arguments: `index`, `k`, `n`, and
#' `matrix` (the names of the arguments can be arbitrary) where `k` and `n` are
#' optional. `index` corresponds to the indices of rows or columns of the
#' heatmap. The value of `index` is not necessarily to be the whole row indices
#' or column indices in the heatmap. It can also be a subset of the indices if
#' the annotation is split into slices according to the split of the heatmap.
#' `index` is reordered according to the reordering of heatmap rows or columns
#' (e.g. by clustering). So, `index` actually contains a list of row or column
#' indices for the current slice after row or column reordering. `matrix` will
#' contain the data passed into the argument `matrix`.
#' The function must have at least Four arguments: `index`, `k`, `n` (the names
#' of the arguments can be arbitrary) where `k` and `n` are optional. `index`
#' corresponds to the indices of rows or columns of the heatmap. The value of
#' `index` is not necessarily to be the whole row indices or column indices in
#' the heatmap. It can also be a subset of the indices if the annotation is
#' split into slices according to the split of the heatmap. `index` is
#' reordered according to the reordering of heatmap rows or columns (e.g. by
#' clustering). So, `index` actually contains a list of row or column indices
#' for the current slice after row or column reordering.
#'
#' k corresponds to the current slice and n corresponds to the total number of
#' slices.
#' `k` corresponds to the current slice and `n` corresponds to the total number
#' of slices.
#'
#' You can always use `self` to indicates the matrix attached in this
#' You can always use `self` to indicates the `data` attached in this
#' annotation.
#'
#' @param ... Additional arguments passed on to `draw_fn`. Only named arguments
#' can be subsettable.
#' @param matrix A matrix, if it is a simple vector, it will be converted to a
#' one-column matrix. Data.frame will also be coerced into matrix. If `NULL`,
#' the matrix from heatmap will be used. You can also provide a function to
#' transform the matrix.
#' @param data A `matrix` or `data.frame`, if it is a simple vector, it will be
#' converted to a one-column matrix. If `NULL`, the matrix from the heatmap will
#' be used. You can also provide a function to transform the matrix.
#' @inheritParams ComplexHeatmap::AnnotationFunction
#' @param subset_rule A list of function to subset variables in `...`.
#' @param fun_name Name of the annotation function, only used for message.
Expand Down Expand Up @@ -70,15 +68,15 @@
#' if (k == 1) grid.yaxis()
#' popViewport()
#' },
#' matrix = rnorm(10L), subset_rule = TRUE,
#' data = rnorm(10L), subset_rule = TRUE,
#' height = unit(2, "cm")
#' )
#' draw(anno)
#' draw(anno[1:2])
#' @seealso [AnnotationFunction][ComplexHeatmap::AnnotationFunction]
#' @return A `ExtendedAnnotation` object.
#' @export
eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
eanno <- function(draw_fn, ..., data = NULL, which = NULL, subset_rule = NULL,
width = NULL, height = NULL, show_name = TRUE,
legends_margin = NULL, legends_panel = NULL,
fun_name = NULL) {
Expand All @@ -91,14 +89,14 @@ eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
# package namespace can be used directly
draw_fn <- allow_lambda(draw_fn)
assert_(draw_fn, is.function, "a function")
matrix <- allow_lambda(matrix)
if (is.null(matrix)) {
data <- allow_lambda(data)
if (is.null(data)) {
n <- NA
} else if (is.function(matrix)) {
} else if (is.function(data)) {
n <- NA
} else {
matrix <- build_matrix(matrix)
n <- nrow(matrix)
data <- build_anno_data(data)
n <- nrow(data)
}
which <- eheat_which(which)

Expand All @@ -113,7 +111,7 @@ eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
if (!is_scalar(subset_rule)) {
cli::cli_abort("{.arg subset_rule} must be a single boolean value")
} else if (is.na(subset_rule)) {
cli::cli_abort("{.arg subset_rule} cannot be missing value")
cli::cli_abort("{.arg subset_rule} cannot be `NA`")
}

if (subsettable <- subset_rule) {
Expand Down Expand Up @@ -146,7 +144,7 @@ eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
# contruct ExtendedAnnotation -----------------------------
anno <- methods::new("ExtendedAnnotation")
anno@dots <- dots
anno@matrix <- matrix
anno@data <- data
anno@which <- which
anno@fun <- draw_fn
anno@fun_name <- fun_name %||% "eanno"
Expand Down Expand Up @@ -184,6 +182,8 @@ eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
if (!x@subsettable) {
cli::cli_abort("{.arg x} is not subsettable.")
}

# subset dots ---------------------------------------
rules <- x@subset_rule
x@dots[rlang::have_name(x@dots)] <- imap(
x@dots[rlang::have_name(x@dots)], function(var, nm) {
Expand All @@ -210,7 +210,15 @@ eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
}
}
)
if (is.matrix(x@matrix)) x@matrix <- x@matrix[i, , drop = FALSE]

# subset the annotation data ---------------------
if (inherits(x@data, c("tbl_df", "data.table"))) {
# For tibble and data.table, no `drop` argument
x@data <- x@data[i, ]
} else if (is.matrix(x@data) || is.data.frame(x@data)) {
# For matrix and data.frame
x@data <- x@data[i, , drop = FALSE]
}
if (is_scalar(x@n) && is.na(x@n)) return(x) # styler: off
if (is.logical(i)) {
x@n <- sum(i)
Expand All @@ -227,14 +235,14 @@ eanno <- function(draw_fn, ..., matrix = NULL, which = NULL, subset_rule = NULL,
methods::setClass(
"ExtendedAnnotation",
slots = list(
matrix = "ANY",
data = "ANY",
dots = "list",
legends_margin = "list",
legends_panel = "list",
initialized = "logical"
),
prototype = list(
matrix = NULL,
data = NULL,
dots = list(),
legends_margin = list(),
legends_panel = list(),
Expand All @@ -244,16 +252,20 @@ methods::setClass(
)

methods::setValidity("ExtendedAnnotation", function(object) {
matrix <- object@matrix
if (!is.null(matrix) && !is.function(matrix) && !is.matrix(matrix)) {
cli::cli_abort("{.code @matrix} must be a matrix or a function or NULL")
data <- object@data
if (!is.null(data) && !is.function(data) &&
!(is.matrix(data) || inherits(data, "data.frame"))) {
cli::cli_abort(paste(
"{.code @data} must be a",
"matrix or data.frame or a function or `NULL`"
))
}
TRUE
})

wrap_anno_fn <- function(object) {
# prepare annotation function --------------------------
matrix <- object@matrix
data <- object@data
dots <- object@dots
fn <- object@fun
args <- formals(fn)
Expand All @@ -262,7 +274,7 @@ wrap_anno_fn <- function(object) {
# also catches the case where there's a `self = NULL` argument.
if (!is.null(.subset2(args, "self")) || "self" %in% names(args)) {
function(index, k, n) {
rlang::inject(fn(index, k, n, !!!dots, self = matrix))
rlang::inject(fn(index, k, n, !!!dots, self = data))
}
} else {
function(index, k, n) {
Expand Down Expand Up @@ -290,48 +302,48 @@ methods::setMethod(
id <- sprintf("%s (%s)", object@fun_name, name)
}
# prepare ExtendedAnnotation matrix data ---------------------------
mat <- object@matrix
anno_data <- object@data
if (is.null(heatmap)) {
heat_matrix <- NULL
} else {
heat_matrix <- heatmap@matrix
}
if (is.null(heat_matrix) && (is.null(mat) || is.function(mat))) {
if (is.null(heat_matrix) &&
(is.null(anno_data) || is.function(anno_data))) {
cli::cli_abort(paste(
"You must provide a matrix in", id,
"You must provide data (matrix or data.frame) in", id,
"in order to draw {.cls {fclass(object)}} directly"
))
}
if (is.null(mat)) {
mat <- switch(which,
if (is.null(anno_data)) {
anno_data <- switch(which,
row = heat_matrix,
column = t(heat_matrix)
)
object@n <- nrow(mat)
} else if (is.function(mat)) {
data <- switch(which,
} else if (is.function(anno_data)) {
mat <- switch(which,
row = heat_matrix,
column = t(heat_matrix)
)
mat <- tryCatch(
build_matrix(mat(data)),
function(cnd) {
anno_data <- tryCatch(
build_anno_data(anno_data(mat)),
invalid_class = function(cnd) {
cli::cli_abort(paste(
"{.fn @matrix} of {id} must return a {.cls matrix},",
"{.fn @data} of {id} must return a {.cls matrix},",
"a simple vector, or a {.cls data.frame}."
))
}
)
if (nrow(mat) != nrow(data)) {
if (nrow(anno_data) != nrow(mat)) {
cli::cli_abort(paste(
"{.fn @matrix} of {id} must a {.cls matrix}",
"with {nrow(mat)} observation{?s}, but the heatmap",
"contain {nrow(data)} for {which} annotation."
"{.fn @data} of {id} return",
"{nrow(anno_data)} observation{?s}, but the heatmap",
"contain {nrow(mat)} for {which} annotation."
))
}
object@n <- nrow(mat)
}
object@matrix <- mat
object@n <- nrow(anno_data)
object@data <- anno_data

# call `eheat_prepare` to modify object after make_layout ----------
# for `eheat_prepare`, the actual geom matrix has been added
Expand Down Expand Up @@ -413,12 +425,10 @@ methods::setMethod(
}
if (missing(index)) {
if (is.na(object@n)) {
cli::cli_abort(
paste(
"You must provide {.arg index} to draw",
"{.cls {fclass(object)}} directly"
)
)
cli::cli_abort(paste(
"You must provide {.arg index} to draw",
"{.cls {fclass(object)}} directly"
))
}
index <- seq_len(object@n)
}
Expand Down
2 changes: 1 addition & 1 deletion R/eheat.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
#' @name eheat
eheat <- function(matrix, ...,
legends_margin = list(), legends_panel = list()) {
matrix <- build_matrix(matrix)
matrix <- build_heatmap_matrix(matrix)
out <- ComplexHeatmap::Heatmap(matrix = matrix, ...)
out <- methods::as(out, "ExtendedHeatmap")
out@legends_margin <- legends_margin
Expand Down
73 changes: 43 additions & 30 deletions R/gganno.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,40 @@
#' @section ggfn:
#'
#' `ggfn` accept a ggplot2 object with a default data and mapping created by
#' `ggplot(data, aes(.data$x))` / `ggplot(data, ggplot2::aes(y = .data$y))`.
#' The original matrix will be converted into a long-data.frame (`gganno` always
#' regard row as the observations) with following columns:
#' `ggplot(data, aes(.data$x))` / `ggplot(data, ggplot2::aes(y = .data$y))`.
#'
#' If the original data is a matrix, it'll be reshaped into a long-format
#' data frame in the `ggplot2` plot data. The final ggplot2 plot data will
#' contain following columns:
#' - `.slice`: the slice row (which = `"row"`) or column (which = `"column"`)
#' number.
#' - `.row_names` and `.column_names`: the row and column names of the original
#' matrix (only applicable when names exist).
#' - `.row_index` and `.column_index`: the row and column index of the original
#' matrix.
#' - `.row_names` and `.row_index`: the row names (only applicable when names
#' exist) and index of the original data.
#' - `.column_names` and `.column_index`: the column names (only applicable when
#' names exist) and index of the original data (`only applicable when
#' the original data is a matrix`).
#' - `x` / `y`: indicating the x-axis (or y-axis) coordinates. Don't use
#' [coord_flip][ggplot2::coord_flip] to flip coordinates as it may disrupt
#' internal operations.
#' - `value`: the actual matrix value of the annotation matrix.
#' - `value`: the actual matrix value of the annotation matrix (`only applicable
#' when the original data is a matrix`).
#'
#' @inherit ggheat
#' @seealso [eanno]
#' @examples
#' draw(gganno(function(p) {
#' p + geom_point(aes(y = value))
#' }, matrix = rnorm(10L), height = unit(10, "cm"), width = unit(0.7, "npc")))
#' }, data = rnorm(10L), height = unit(10, "cm"), width = unit(0.7, "npc")))
#' @return A `ggAnno` object.
#' @export
#' @name gganno
gganno <- function(ggfn, ..., matrix = NULL,
gganno <- function(ggfn, ..., data = NULL,
which = NULL, width = NULL, height = NULL) {
out <- eanno(
draw_fn = ggfn, ..., matrix = matrix,
draw_fn = ggfn, ..., data = data, subset_rule = NULL,
which = which, width = width, height = height,
show_name = FALSE, fun_name = "gganno"
show_name = FALSE, fun_name = "gganno",
legends_margin = NULL, legends_panel = NULL
)
out <- methods::as(out, "ggAnno")
out
Expand All @@ -65,9 +70,9 @@ eheat_prepare.ggAnno <- function(object, ..., viewport, heatmap, name) {
}
which <- object@which
# we always regard matrix row as the observations
matrix <- object@matrix
data <- object@data
if (is.null(heatmap)) {
order_list <- list(seq_len(nrow(matrix)))
order_list <- list(seq_len(nrow(data)))
} else {
order_list <- switch(which,
row = heatmap@row_order_list,
Expand All @@ -79,19 +84,25 @@ eheat_prepare.ggAnno <- function(object, ..., viewport, heatmap, name) {
} else {
with_slice <- FALSE
}
row_nms <- rownames(matrix)
col_nms <- colnames(matrix)
data <- as_tibble0(matrix, rownames = NULL) # nolint
colnames(data) <- seq_len(ncol(data))
data$.row_index <- seq_len(nrow(data))
data <- tidyr::pivot_longer(data,
cols = !".row_index",
names_to = ".column_index",
values_to = "value"
)
data$.column_index <- as.integer(data$.column_index)
if (!is.null(row_nms)) data$.row_names <- row_nms[data$.row_index]
if (!is.null(col_nms)) data$.column_names <- col_nms[data$.column_index]
if (is.matrix(data)) {
row_nms <- rownames(data)
col_nms <- colnames(data)
data <- as_tibble0(data, rownames = NULL) # nolint
colnames(data) <- seq_len(ncol(data))
data$.row_index <- seq_len(nrow(data))
data <- tidyr::pivot_longer(data,
cols = !".row_index",
names_to = ".column_index",
values_to = "value"
)
data$.column_index <- as.integer(data$.column_index)
if (!is.null(row_nms)) data$.row_names <- row_nms[data$.row_index]
if (!is.null(col_nms)) data$.column_names <- col_nms[data$.column_index]
} else {
row_nms <- rownames(data)
data <- as_tibble0(data, rownames = ".row_names")
data$.row_index <- seq_len(nrow(data))
}

coords <- data_frame0(
.slice = rep(
Expand All @@ -104,7 +115,7 @@ eheat_prepare.ggAnno <- function(object, ..., viewport, heatmap, name) {
data <- merge(coords, data, by = ".row_index", all = FALSE)
nms <- c(
".slice", ".row_names", ".column_names",
".row_index", ".column_index", "x", "y", "value"
".row_index", ".column_index", "x", "y"
)
if (which == "row") {
data <- rename(data, c(x = "y"))
Expand All @@ -118,9 +129,11 @@ eheat_prepare.ggAnno <- function(object, ..., viewport, heatmap, name) {
} else {
data$y <- reverse_trans(data$y)
}
p <- ggplot(data[intersect(nms, names(data))], aes(y = .data$y))
data <- data[union(intersect(nms, names(data)), names(data))]
p <- ggplot(data, aes(y = .data$y))
} else {
p <- ggplot(data[intersect(nms, names(data))], aes(x = .data$x))
data <- data[union(intersect(nms, names(data)), names(data))]
p <- ggplot(data, aes(x = .data$x))
}
p <- rlang::inject(object@fun(p, !!!object@dots))
object@dots <- list() # remove dots
Expand Down
Loading

0 comments on commit da2bb73

Please sign in to comment.