Skip to content

Commit

Permalink
Fix #1199, fix #1200
Browse files Browse the repository at this point in the history
  • Loading branch information
wlandau-lilly committed Mar 3, 2020
1 parent 597f032 commit a62401b
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 4 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Version 7.11.0.9000

## Bug fixes

* Restrict static transforms so they only use the upstream part of the plan (#1199, #1200, @bart1).

# Version 7.11.0

Expand Down
9 changes: 9 additions & 0 deletions R/igraph.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ downstream_nodes <- function(graph, from) {
)
}

upstream_nodes <- function(graph, from) {
nbhd_vertices(
graph = graph,
vertices = from,
mode = "in",
order = igraph::gorder(graph)
)
}

nbhd_graph <- function(graph, vertices, mode, order) {
vertices <- nbhd_vertices(
graph = graph,
Expand Down
16 changes: 12 additions & 4 deletions R/transform_plan.R
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,27 @@ transform_plan_ <- function(
plan$transform <- lapply(plan$transform, parse_transform)
graph <- dsl_graph(plan)
order <- igraph::topo_sort(graph)$name
subplans <- split(plan, f = plan$target)
for (target in order) {
index <- which(target == plan$target)
rows <- transform_row(index, plan, graph, max_expand)
plan <- sub_in_plan(plan, rows, index)
old_cols(plan) <- old_cols
upstream_plan <- dsl_upstream_plan(target, graph, subplans)
index <- which(target == upstream_plan$target)
old_cols(upstream_plan) <- old_cols
subplans[[target]] <- transform_row(index, upstream_plan, graph, max_expand)
}
plan <- drake_bind_rows(subplans)
old_cols(plan) <- old_cols
plan <- dsl_trace(plan = plan, trace = trace)
old_cols(plan) <- plan$transform <- NULL
plan <- dsl_tidy_eval(plan = plan, tidy_eval = tidy_eval, envir = envir)
plan <- dsl_sanitize(plan = plan, sanitize = sanitize, envir = envir)
plan
}

dsl_upstream_plan <- function(target, graph, subplans) {
upstream_targets <- upstream_nodes(graph, target)
drake_bind_rows(subplans[upstream_targets])
}

dsl_trace <- function(plan, trace) {
if (!trace) {
keep <- as.character(intersect(colnames(plan), old_cols(plan)))
Expand Down
124 changes: 124 additions & 0 deletions tests/testthat/test-7-dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -3078,3 +3078,127 @@ test_with_dir("NAs removed from old grouping vars grid (#1010)", {
)
equivalent_plans(out, exp)
})

test_with_dir("static transforms use only upstream part of plan (#1199)", {
skip_on_cran()
radars <- c("radar1", "radar2")
seasons <- c("season1", "season2")
months <- c(1, 2)
radar_seasons <- expand.grid(
radar = radars,
season = seasons,
stringsAsFactors = FALSE
)
out <- drake_plan(
data = target(
get_data(radar, month),
transform = cross(radar = !!radars, month = !!months)
),
to_cross = target(
list(data),
transform = combine(data, .by = radar)
),
problem = target(
list(to_cross, season),
transform = cross(to_cross, season = !!seasons)
),
separate = target(
list(radar, season),
transform = map(.data = !!radar_seasons)
),
trace = TRUE
)
exp <- drake_plan(
data_radar1_1 = target(
command = get_data("radar1", 1),
radar = "\"radar1\"",
month = "1",
data = "data_radar1_1"
),
data_radar2_1 = target(
command = get_data("radar2", 1),
radar = "\"radar2\"",
month = "1",
data = "data_radar2_1"
),
data_radar1_2 = target(
command = get_data("radar1", 2),
radar = "\"radar1\"",
month = "2",
data = "data_radar1_2"
),
data_radar2_2 = target(
command = get_data("radar2", 2),
radar = "\"radar2\"",
month = "2",
data = "data_radar2_2"
),
problem_season1_to_cross_radar1 = target(
command = list(to_cross_radar1, "season1"),
radar = "\"radar1\"",
season = "\"season1\"",
separate = "separate_radar1_season1",
to_cross = "to_cross_radar1",
problem = "problem_season1_to_cross_radar1"
),
problem_season2_to_cross_radar1 = target(
command = list(to_cross_radar1, "season2"),
radar = "\"radar1\"",
season = "\"season2\"",
separate = "separate_radar1_season2",
to_cross = "to_cross_radar1",
problem = "problem_season2_to_cross_radar1"
),
problem_season1_to_cross_radar2 = target(
command = list(to_cross_radar2, "season1"),
radar = "\"radar1\"",
season = "\"season1\"",
separate = "separate_radar1_season1",
to_cross = "to_cross_radar2",
problem = "problem_season1_to_cross_radar2"
),
problem_season2_to_cross_radar2 = target(
command = list(to_cross_radar2, "season2"),
radar = "\"radar1\"",
season = "\"season2\"",
separate = "separate_radar1_season2",
to_cross = "to_cross_radar2",
problem = "problem_season2_to_cross_radar2"
),
separate_radar1_season1 = target(
command = list("radar1", "season1"),
radar = "\"radar1\"",
season = "\"season1\"",
separate = "separate_radar1_season1"
),
separate_radar2_season1 = target(
command = list("radar2", "season1"),
radar = "\"radar2\"",
season = "\"season1\"",
separate = "separate_radar2_season1"
),
separate_radar1_season2 = target(
command = list("radar1", "season2"),
radar = "\"radar1\"",
season = "\"season2\"",
separate = "separate_radar1_season2"
),
separate_radar2_season2 = target(
command = list("radar2", "season2"),
radar = "\"radar2\"",
season = "\"season2\"",
separate = "separate_radar2_season2"
),
to_cross_radar1 = target(
command = list(data_radar1_1, data_radar1_2),
radar = "\"radar1\"",
to_cross = "to_cross_radar1"
),
to_cross_radar2 = target(
command = list(data_radar2_1, data_radar2_2),
radar = "\"radar2\"",
to_cross = "to_cross_radar2"
)
)
equivalent_plans(out, exp)
})

0 comments on commit a62401b

Please sign in to comment.