Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Swordfish specific test fixtures #3164

Merged
merged 7 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/daft-local-execution/src/intermediate_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ pub mod explode;
pub mod filter;
pub mod inner_hash_join_probe;
pub mod intermediate_op;
pub mod pivot;
pub mod project;
pub mod sample;
pub mod unpivot;
57 changes: 0 additions & 57 deletions src/daft-local-execution/src/intermediate_ops/pivot.rs

This file was deleted.

13 changes: 8 additions & 5 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use crate::{
actor_pool_project::ActorPoolProjectOperator, aggregate::AggregateOperator,
anti_semi_hash_join_probe::AntiSemiProbeOperator, explode::ExplodeOperator,
filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator,
intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator,
sample::SampleOperator, unpivot::UnpivotOperator,
intermediate_op::IntermediateNode, project::ProjectOperator, sample::SampleOperator,
unpivot::UnpivotOperator,
},
sinks::{
aggregate::AggregateSink,
Expand All @@ -38,6 +38,7 @@ use crate::{
hash_join_build::HashJoinBuildSink,
limit::LimitSink,
outer_hash_join_probe::OuterHashJoinProbeSink,
pivot::PivotSink,
sort::SortSink,
streaming_sink::StreamingSinkNode,
write::{WriteFormat, WriteSink},
Expand Down Expand Up @@ -282,17 +283,19 @@ pub fn physical_plan_to_pipeline(
group_by,
pivot_column,
value_column,
aggregation,
names,
..
}) => {
let pivot_op = PivotOperator::new(
let child_node = physical_plan_to_pipeline(input, psets, cfg)?;
let pivot_sink = PivotSink::new(
group_by.clone(),
pivot_column.clone(),
value_column.clone(),
aggregation.clone(),
names.clone(),
);
let child_node = physical_plan_to_pipeline(input, psets, cfg)?;
IntermediateNode::new(Arc::new(pivot_op), vec![child_node]).boxed()
BlockingSinkNode::new(Arc::new(pivot_sink), child_node).boxed()
}
LocalPhysicalPlan::Sort(Sort {
input,
Expand Down
1 change: 1 addition & 0 deletions src/daft-local-execution/src/sinks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod concat;
pub mod hash_join_build;
pub mod limit;
pub mod outer_hash_join_probe;
pub mod pivot;
pub mod sort;
pub mod streaming_sink;
pub mod write;
126 changes: 126 additions & 0 deletions src/daft-local-execution/src/sinks/pivot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_dsl::{AggExpr, Expr, ExprRef};
use daft_micropartition::MicroPartition;
use tracing::instrument;

use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus};
use crate::{pipeline::PipelineResultType, NUM_CPUS};

/// Per-state buffer used by `PivotSink` while it collects its input.
enum PivotState {
/// Input partitions buffered so far; `PivotState::finalize` takes them out.
Accumulating(Vec<Arc<MicroPartition>>),
/// Set by `PivotState::finalize` once the buffered partitions have been taken.
Done,
}

impl PivotState {
/// Appends `part` to the buffered partitions.
///
/// # Panics
///
/// Panics if the state is not `Accumulating` (i.e. it is already `Done`).
fn push(&mut self, part: Arc<MicroPartition>) {
if let Self::Accumulating(ref mut parts) = self {
parts.push(part);
} else {
panic!("PivotSink should be in Accumulating state");

Check warning on line 21 in src/daft-local-execution/src/sinks/pivot.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/pivot.rs#L21

Added line #L21 was not covered by tests
}
}

/// Returns every buffered partition and switches the state to `Done`.
///
/// # Panics
///
/// Panics if the state is not `Accumulating`, so calling this a second
/// time on the same state panics.
fn finalize(&mut self) -> Vec<Arc<MicroPartition>> {
let res = if let Self::Accumulating(ref mut parts) = self {
// Move the Vec out without cloning; the empty Vec left behind is
// discarded when the state is overwritten with `Done` below.
std::mem::take(parts)
} else {
panic!("PivotSink should be in Accumulating state");

Check warning on line 29 in src/daft-local-execution/src/sinks/pivot.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/pivot.rs#L29

Added line #L29 was not covered by tests
};
*self = Self::Done;
res
}
}

/// Allows a boxed `dyn BlockingSinkState` to be downcast back to `PivotState`.
impl BlockingSinkState for PivotState {
    /// Exposes this state as `&mut dyn Any` so callers can `downcast_mut` it.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

/// Blocking sink that buffers all of its input and, on finalize, aggregates
/// and pivots it in a single step.
pub struct PivotSink {
/// Grouping expressions; together with `pivot_column` they form the
/// group-by keys of the aggregation, and they are also passed to `pivot`.
pub group_by: Vec<ExprRef>,
/// Pivot column expression; appended to `group_by` for the aggregation and
/// passed to `pivot`.
pub pivot_column: ExprRef,
/// Value column expression passed to `pivot`.
pub value_column: ExprRef,
/// Aggregation applied to the concatenated input before pivoting.
pub aggregation: AggExpr,
/// Names passed to `pivot` (presumably the output column names; confirm
/// against `MicroPartition::pivot`).
pub names: Vec<String>,
}

impl PivotSink {
    /// Creates a `PivotSink` that stores the given pivot parameters as-is.
    ///
    /// No validation or computation happens here; the arguments are only
    /// read later, when the sink is finalized.
    pub fn new(
        group_by: Vec<ExprRef>,
        pivot_column: ExprRef,
        value_column: ExprRef,
        aggregation: AggExpr,
        names: Vec<String>,
    ) -> Self {
        Self {
            group_by,
            pivot_column,
            value_column,
            aggregation,
            names,
        }
    }
}

impl BlockingSink for PivotSink {
/// Buffers `input` in the given state and asks for more input.
///
/// No aggregation or pivoting happens here; the partition is only appended
/// to the state's buffer.
///
/// # Panics
///
/// Panics if `state` is not a `PivotState`, or if it is no longer in the
/// `Accumulating` variant.
#[instrument(skip_all, name = "PivotSink::sink")]
fn sink(
&self,
input: &Arc<MicroPartition>,
mut state: Box<dyn BlockingSinkState>,
) -> DaftResult<BlockingSinkStatus> {
state
.as_any_mut()
.downcast_mut::<PivotState>()
.expect("PivotSink should have PivotState")
.push(input.clone());
Ok(BlockingSinkStatus::NeedMoreInput(state))
}

/// Drains every state, then concatenates, aggregates and pivots the
/// buffered partitions into a single result.
///
/// # Errors
///
/// Propagates any error returned by `MicroPartition::concat`, `agg` or
/// `pivot`.
///
/// # Panics
///
/// Panics if any state is not a `PivotState` in the `Accumulating` variant.
#[instrument(skip_all, name = "PivotSink::finalize")]
fn finalize(
&self,
states: Vec<Box<dyn BlockingSinkState>>,
) -> DaftResult<Option<PipelineResultType>> {
// Take the buffered partitions out of every state, in state order.
let all_parts = states.into_iter().flat_map(|mut state| {
state
.as_any_mut()
.downcast_mut::<PivotState>()
.expect("PivotSink should have PivotState")
.finalize()
});
let concated = MicroPartition::concat(all_parts)?;
colin-ho marked this conversation as resolved.
Show resolved Hide resolved
// The aggregation is keyed on the group-by expressions followed by the
// pivot column.
let group_by_with_pivot = self
.group_by
.iter()
.chain(std::iter::once(&self.pivot_column))
.cloned()
.collect::<Vec<_>>();
let agged = concated.agg(
&[Expr::Agg(self.aggregation.clone()).into()],
&group_by_with_pivot,
)?;
// Pivot the aggregated rows; only `group_by` (without the pivot column)
// is passed as the grouping here.
let pivoted = Arc::new(agged.pivot(
&self.group_by,
self.pivot_column.clone(),
self.value_column.clone(),
self.names.clone(),
)?);
Ok(Some(pivoted.into()))
}

/// Returns the fixed name of this sink, `"PivotSink"`.
fn name(&self) -> &'static str {
"PivotSink"
}

/// Returns `NUM_CPUS` as the maximum concurrency for this sink.
fn max_concurrency(&self) -> usize {
*NUM_CPUS
}

/// Creates a fresh `PivotState::Accumulating` state with an empty buffer.
fn make_state(&self) -> DaftResult<Box<dyn BlockingSinkState>> {
Ok(Box::new(PivotState::Accumulating(vec![])))
}
}
3 changes: 3 additions & 0 deletions src/daft-physical-plan/src/local_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ impl LocalPhysicalPlan {
group_by: Vec<ExprRef>,
pivot_column: ExprRef,
value_column: ExprRef,
aggregation: AggExpr,
names: Vec<String>,
schema: SchemaRef,
) -> LocalPhysicalPlanRef {
Expand All @@ -213,6 +214,7 @@ impl LocalPhysicalPlan {
group_by,
pivot_column,
value_column,
aggregation,
names,
schema,
plan_stats: PlanStats {},
Expand Down Expand Up @@ -438,6 +440,7 @@ pub struct Pivot {
pub group_by: Vec<ExprRef>,
pub pivot_column: ExprRef,
pub value_column: ExprRef,
pub aggregation: AggExpr,
pub names: Vec<String>,
pub schema: SchemaRef,
pub plan_stats: PlanStats,
Expand Down
23 changes: 3 additions & 20 deletions src/daft-physical-plan/src/translate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use common_error::{DaftError, DaftResult};
use daft_core::{join::JoinStrategy, prelude::Schema};
use daft_core::join::JoinStrategy;
use daft_dsl::ExprRef;
use daft_plan::{JoinType, LogicalPlan, LogicalPlanRef, SourceInfo};

Expand Down Expand Up @@ -91,29 +91,12 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
}
LogicalPlan::Pivot(pivot) => {
let input = translate(&pivot.input)?;
let groupby_with_pivot = pivot
.group_by
.iter()
.chain(std::iter::once(&pivot.pivot_column))
.cloned()
.collect::<Vec<_>>();
let aggregate_fields = groupby_with_pivot
.iter()
.map(|expr| expr.to_field(input.schema()))
.chain(std::iter::once(pivot.aggregation.to_field(input.schema())))
.collect::<DaftResult<Vec<_>>>()?;
let aggregate_schema = Schema::new(aggregate_fields)?;
let aggregate = LocalPhysicalPlan::hash_aggregate(
input,
vec![pivot.aggregation.clone(); 1],
groupby_with_pivot,
aggregate_schema.into(),
);
Ok(LocalPhysicalPlan::pivot(
aggregate,
input,
pivot.group_by.clone(),
pivot.pivot_column.clone(),
pivot.value_column.clone(),
pivot.aggregation.clone(),
pivot.names.clone(),
pivot.output_schema.clone(),
))
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import daft
import daft.context
from daft.table import MicroPartition

# import all conftest
Expand Down Expand Up @@ -170,3 +171,13 @@ def assert_df_equals(
except AssertionError:
print(f"Failed assertion for col: {col}")
raise


@pytest.fixture(
    scope="function",
    params=[1, None] if daft.context.get_context().daft_execution_config.enable_native_executor else [None],
)
def with_morsel_size(request):
    """Run the test body under each parametrized default morsel size.

    The parameter list is computed once, when this module is imported:
    ``[1, None]`` if the execution config has ``enable_native_executor`` set,
    otherwise just ``[None]``. The selected value is applied through
    ``daft.context.execution_config_ctx(default_morsel_size=...)`` for the
    duration of the test and is also yielded to the test.
    """
    size = request.param
    with daft.context.execution_config_ctx(default_morsel_size=size):
        yield size
5 changes: 4 additions & 1 deletion tests/cookbook/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def service_requests_csv_pd_df():
return pd.read_csv(COOKBOOK_DATA_CSV, keep_default_na=False)[COLUMNS]


@pytest.fixture(scope="module", params=[1, 2])
@pytest.fixture(
scope="module",
params=[1, 2] if daft.context.get_context().daft_execution_config.enable_native_executor is False else [1],
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM I think, but in the future it should really just be if daft.context.get_context().runner == "ray", since partitioning only makes sense for the ray runner.

def repartition_nparts(request):
"""Adds a `n_repartitions` parameter to test cases which provides the number of
partitions that the test case should repartition its dataset into for testing
Expand Down
Loading
Loading