From d052a11ae8756715a1a46eee8bce4a83451a26fb Mon Sep 17 00:00:00 2001 From: wlandau Date: Mon, 4 Nov 2024 15:57:19 -0500 Subject: [PATCH] Fix unlisting in tar_map_rep() --- R/tar_map_rep_raw.R | 6 +- tests/testthat/test-tar_map_rep.R | 177 +++++++++++++++--------------- 2 files changed, 95 insertions(+), 88 deletions(-) diff --git a/R/tar_map_rep_raw.R b/R/tar_map_rep_raw.R index 4988cb7..4b4e5e4 100644 --- a/R/tar_map_rep_raw.R +++ b/R/tar_map_rep_raw.R @@ -121,8 +121,11 @@ tar_map_rep_raw <- function( delimiter = delimiter, unlist = unlist ) - )[[1L]] + ) ) + if (!unlist) { + target_static <- target_static[[1L]] + } target_combine <- if_any( is.null(values) || !combine, NULL, @@ -151,6 +154,7 @@ tar_map_rep_raw <- function( ) if (unlist) { out <- unlist(out, recursive = TRUE) + names(out) <- map_chr(out, ~.x$settings$name) } out } diff --git a/tests/testthat/test-tar_map_rep.R b/tests/testthat/test-tar_map_rep.R index bca8306..b1b778a 100644 --- a/tests/testthat/test-tar_map_rep.R +++ b/tests/testthat/test-tar_map_rep.R @@ -1,96 +1,99 @@ targets::tar_test("tar_map_rep(): combine, columns, static branches", { skip_if_not_installed("dplyr") - targets::tar_script({ - f <- function(sigma1, sigma2) { - tibble::tibble( - out = sigma1 + 1000 * sigma2, - length1 = length(sigma1), - length2 = length(sigma2), - random = sample.int(1e6, size = 1) + for (unlist in c(TRUE, FALSE)) { + targets::tar_script({ + f <- function(sigma1, sigma2) { + tibble::tibble( + out = sigma1 + 1000 * sigma2, + length1 = length(sigma1), + length2 = length(sigma2), + random = sample.int(1e6, size = 1) + ) + } + hyperparameters <- tibble::tibble( + scenario = c("tight", "medium", "diffuse"), + sigma1 = c(10, 50, 50), + sigma2 = c(10, 5, 10) + ) + tarchetypes::tar_map_rep( + x, + command = f(sigma1, sigma2), + values = hyperparameters, + names = tidyselect::any_of("scenario"), + descriptions = tidyselect::any_of(c("scenario", "sigma2")), + batches = 2, + reps = 3, + unlist = !!unlist + ) + }) + # manifest + out <- targets::tar_manifest(callr_function = NULL) + out <- out[order(out$name), ] + expect_equal( + sort(out$name), + sort( + paste0("x", c("_batch", "_tight", "_medium", "_diffuse", "")) ) - } - hyperparameters <- tibble::tibble( - scenario = c("tight", "medium", "diffuse"), - sigma1 = c(10, 50, 50), - sigma2 = c(10, 5, 10) ) - tarchetypes::tar_map_rep( - x, - command = f(sigma1, sigma2), - values = hyperparameters, - names = tidyselect::any_of("scenario"), - descriptions = tidyselect::any_of(c("scenario", "sigma2")), - batches = 2, - reps = 3 + expect_equal(out$command[out$name == "x_batch"], "seq_len(2)") + expect_equal( + grepl("diffuse|medium|tight", out$name), + grepl("tar_rep_run", out$command) ) - }) - # manifest - out <- targets::tar_manifest(callr_function = NULL) - out <- out[order(out$name), ] - expect_equal( - sort(out$name), - sort( - paste0("x", c("_batch", "_tight", "_medium", "_diffuse", "")) + expect_equal( + grepl("diffuse|medium|tight", out$name), + !is.na(out$pattern) ) - ) - expect_equal(out$command[out$name == "x_batch"], "seq_len(2)") - expect_equal( - grepl("diffuse|medium|tight", out$name), - grepl("tar_rep_run", out$command) - ) - expect_equal( - grepl("diffuse|medium|tight", out$name), - !is.na(out$pattern) - ) - expect_equal( - out$name == "x", - grepl("bind_rows", out$command) - ) - expect_equal( - out$description[out$name == "x_diffuse"], - "diffuse 10" - ) - expect_equal( - out$description[out$name == "x_medium"], - "medium 5" - ) - expect_equal( - out$description[out$name == "x_tight"], - "tight 10" - ) - # network - out <- targets::tar_network(callr_function = NULL)$edges - out <- dplyr::arrange(out, from, to) - exp <- tibble::tribble( - ~from, ~to, - "f", "x_diffuse", - "f", "x_medium", - "f", "x_tight", - "x_batch", "x_diffuse", - "x_batch", "x_medium", - "x_batch", "x_tight", - "x_diffuse", "x", - "x_medium", "x", - "x_tight", "x" - ) - exp <- dplyr::arrange(exp, from, to) - expect_equal(out, exp) - # output - targets::tar_make(callr_function = NULL) - out <- dplyr::arrange(targets::tar_read(x), tar_batch, tar_rep, scenario) - d <- dplyr::distinct(out, tar_group, tar_batch, tar_rep) - expect_equal(nrow(out), nrow(d)) - expect_equal(out$out, rep(c(10050, 5050, 10010), times = 6)) - expect_equal(out$sigma1, rep(c(50, 50, 10), times = 6)) - expect_equal(out$sigma2, rep(c(10, 5, 10), times = 6)) - scenarios <- sort(unique(out$scenario)) - expect_equal(out$scenario, rep(scenarios, times = 6)) - expect_true(all(out$length1 == 1L)) - expect_true(all(out$length2 == 1L)) - expect_equal(length(unique(out$random)), nrow(out)) - # metadata - meta <- targets::tar_meta(x_diffuse) - expect_equal(length(unlist(meta$children)), 2L) + expect_equal( + out$name == "x", + grepl("bind_rows", out$command) + ) + expect_equal( + out$description[out$name == "x_diffuse"], + "diffuse 10" + ) + expect_equal( + out$description[out$name == "x_medium"], + "medium 5" + ) + expect_equal( + out$description[out$name == "x_tight"], + "tight 10" + ) + # network + out <- targets::tar_network(callr_function = NULL)$edges + out <- dplyr::arrange(out, from, to) + exp <- tibble::tribble( + ~from, ~to, + "f", "x_diffuse", + "f", "x_medium", + "f", "x_tight", + "x_batch", "x_diffuse", + "x_batch", "x_medium", + "x_batch", "x_tight", + "x_diffuse", "x", + "x_medium", "x", + "x_tight", "x" + ) + exp <- dplyr::arrange(exp, from, to) + expect_equal(out, exp) + # output + targets::tar_make(callr_function = NULL) + out <- dplyr::arrange(targets::tar_read(x), tar_batch, tar_rep, scenario) + d <- dplyr::distinct(out, tar_group, tar_batch, tar_rep) + expect_equal(nrow(out), nrow(d)) + expect_equal(out$out, rep(c(10050, 5050, 10010), times = 6)) + expect_equal(out$sigma1, rep(c(50, 50, 10), times = 6)) + expect_equal(out$sigma2, rep(c(10, 5, 10), times = 6)) + scenarios <- sort(unique(out$scenario)) + expect_equal(out$scenario, rep(scenarios, times = 6)) + expect_true(all(out$length1 == 1L)) + expect_true(all(out$length2 == 1L)) + expect_equal(length(unique(out$random)), nrow(out)) + # metadata + meta <- targets::tar_meta(x_diffuse) + expect_equal(length(unlist(meta$children)), 2L) + } }) targets::tar_test("tar_map_rep(): no combine, 1 col, static branches", {