Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix actor pool project splitting when column is not renamed #2998

Merged
merged 5 commits into from
Oct 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl SplitActorPoolProjects {
impl OptimizerRule for SplitActorPoolProjects {
fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
plan.transform_down(|node| match node.as_ref() {
LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone(), 0),
LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone()),
_ => Ok(Transformed::no(node)),
})
}
Expand Down Expand Up @@ -370,8 +370,34 @@ fn split_projection(
fn try_optimize_project(
projection: &Project,
plan: Arc<LogicalPlan>,
) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
// Add aliases to the expressions in the projection to preserve original names when splitting stateful UDFs.
// This is needed because when we split stateful UDFs, we create new names for intermediates, but we would like
// to have the same expression names as the original projection.
let aliased_projection_exprs = projection
.projection
.iter()
.map(|e| {
if has_stateful_udf(e) && !matches!(e.as_ref(), Expr::Alias(..)) {
e.alias(e.name())
} else {
e.clone()
}
})
.collect();

let aliased_projection = Project::try_new(projection.input.clone(), aliased_projection_exprs)?;

recursive_optimize_project(&aliased_projection, plan, 0)
}

fn recursive_optimize_project(
projection: &Project,
plan: Arc<LogicalPlan>,
recursive_count: usize,
) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
// TODO: eliminate the need for recursive calls by doing a post-order traversal of the plan tree.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent :)


// Base case: no stateful UDFs at all
let has_stateful_udfs = projection.projection.iter().any(has_stateful_udf);
if !has_stateful_udfs {
Expand Down Expand Up @@ -416,8 +442,11 @@ fn try_optimize_project(
// Recursively run the rule on the new child Project
let new_project = Project::try_new(projection.input.clone(), remaining)?;
let new_child_project = LogicalPlan::Project(new_project.clone()).arced();
let optimized_child_plan =
try_optimize_project(&new_project, new_child_project.clone(), recursive_count + 1)?;
let optimized_child_plan = recursive_optimize_project(
&new_project,
new_child_project.clone(),
recursive_count + 1,
)?;
optimized_child_plan.data.clone()
};

Expand Down Expand Up @@ -785,6 +814,67 @@ mod tests {
Ok(())
}

#[test]
fn test_multiple_with_column_serial_no_alias() -> DaftResult<()> {
let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]);
let scan_plan = dummy_scan_node(scan_op);
let stacked_stateful_project_expr =
create_stateful_udf(vec![create_stateful_udf(vec![col("a")])]);

// Add a Projection with StatefulUDF and resource request
let project_plan = scan_plan
.select(vec![stacked_stateful_project_expr.clone()])?
.build();

let intermediate_name = "__TruncateRootStatefulUDF_0-0-0__";

let expected = scan_plan.select(vec![col("a")])?.build();
let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
expected,
vec![
col("a"),
create_stateful_udf(vec![col("a")])
.clone()
.alias(intermediate_name),
],
)?)
.arced();
let expected =
LogicalPlan::Project(Project::try_new(expected, vec![col(intermediate_name)])?).arced();
let expected =
LogicalPlan::Project(Project::try_new(expected, vec![col(intermediate_name)])?).arced();
let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
expected,
vec![
col(intermediate_name),
create_stateful_udf(vec![col(intermediate_name)])
.clone()
.alias("a"),
],
)?)
.arced();
let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a")])?).arced();
assert_optimized_plan_eq(project_plan.clone(), expected.clone())?;

let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
scan_plan.build(),
vec![create_stateful_udf(vec![col("a")])
.clone()
.alias(intermediate_name)],
)?)
.arced();
let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
expected,
vec![create_stateful_udf(vec![col(intermediate_name)])
.clone()
.alias("a")],
)?)
.arced();
assert_optimized_plan_eq_with_projection_pushdown(project_plan.clone(), expected.clone())?;

Ok(())
}

#[test]
fn test_multiple_with_column_serial_multiarg() -> DaftResult<()> {
let scan_op = dummy_scan_operator(vec![
Expand Down
Loading