Skip to content

Commit

Permalink
[FEAT] Enable sample for swordfish (#3079)
Browse files Browse the repository at this point in the history
Adds sample as an intermediate operator. Unskips all the sample tests
(except one which depends on concat).

Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Oct 21, 2024
1 parent 23d4a1f commit 16665f2
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/daft-local-execution/src/intermediate_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pub mod filter;
pub mod hash_join_probe;
pub mod intermediate_op;
pub mod project;
pub mod sample;
47 changes: 47 additions & 0 deletions src/daft-local-execution/src/intermediate_ops/sample.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::sync::Arc;

use common_error::DaftResult;
use tracing::instrument;

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
};
use crate::pipeline::PipelineResultType;

/// Intermediate operator that samples each input it is given.
///
/// The three fields are the arguments passed to `sample_by_fraction` every
/// time the operator is executed.
pub struct SampleOperator {
    /// Fraction argument passed to `sample_by_fraction`.
    fraction: f64,
    /// Replacement flag passed to `sample_by_fraction`.
    with_replacement: bool,
    /// Optional seed passed to `sample_by_fraction`.
    seed: Option<u64>,
}

impl SampleOperator {
    /// Creates an operator that stores the given sampling parameters unchanged.
    pub fn new(fraction: f64, with_replacement: bool, seed: Option<u64>) -> Self {
        Self {
            fraction,
            with_replacement,
            seed,
        }
    }
}

impl IntermediateOperator for SampleOperator {
    /// Samples the data obtained from `input` and returns it as this call's output.
    ///
    /// The data is read through `input.as_data()` and sampled with
    /// `sample_by_fraction`, using the `fraction`, `with_replacement` and `seed`
    /// stored on this operator. The sampled value is wrapped in an `Arc` and
    /// returned inside `IntermediateOperatorResult::NeedMoreInput`. `_idx` and
    /// `_state` are not used.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `sample_by_fraction`.
    #[instrument(skip_all, name = "SampleOperator::execute")]
    fn execute(
        &self,
        _idx: usize,
        input: &PipelineResultType,
        _state: Option<&mut Box<dyn IntermediateOperatorState>>,
    ) -> DaftResult<IntermediateOperatorResult> {
        let data = input.as_data();
        let sampled = data.sample_by_fraction(self.fraction, self.with_replacement, self.seed)?;
        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
            sampled,
        ))))
    }

    /// Returns the fixed name of this operator.
    fn name(&self) -> &'static str {
        "SampleOperator"
    }
}
16 changes: 14 additions & 2 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use daft_dsl::{col, join::get_common_join_keys, Expr};
use daft_micropartition::MicroPartition;
use daft_physical_plan::{
EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Project,
Sort, UnGroupedAggregate,
Sample, Sort, UnGroupedAggregate,
};
use daft_plan::{populate_aggregation_stages, JoinType};
use daft_table::{Probeable, Table};
Expand All @@ -23,7 +23,7 @@ use crate::{
intermediate_ops::{
aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator,
filter::FilterOperator, hash_join_probe::HashJoinProbeOperator,
intermediate_op::IntermediateNode, project::ProjectOperator,
intermediate_op::IntermediateNode, project::ProjectOperator, sample::SampleOperator,
},
sinks::{
aggregate::AggregateSink, blocking_sink::BlockingSinkNode,
Expand Down Expand Up @@ -125,6 +125,17 @@ pub fn physical_plan_to_pipeline(
let child_node = physical_plan_to_pipeline(input, psets)?;
IntermediateNode::new(Arc::new(proj_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Sample(Sample {
input,
fraction,
with_replacement,
seed,
..
}) => {
let sample_op = SampleOperator::new(*fraction, *with_replacement, *seed);
let child_node = physical_plan_to_pipeline(input, psets)?;
IntermediateNode::new(Arc::new(sample_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Filter(Filter {
input, predicate, ..
}) => {
Expand Down Expand Up @@ -233,6 +244,7 @@ pub fn physical_plan_to_pipeline(
let child_node = physical_plan_to_pipeline(input, psets)?;
BlockingSinkNode::new(sort_sink.boxed(), child_node).boxed()
}

LocalPhysicalPlan::HashJoin(HashJoin {
left,
right,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-physical-plan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ mod translate;

pub use local_plan::{
Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan,
LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Project, Sort, UnGroupedAggregate,
LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Project, Sample, Sort, UnGroupedAggregate,
};
pub use translate::translate;
31 changes: 30 additions & 1 deletion src/daft-physical-plan/src/local_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub enum LocalPhysicalPlan {
// Unpivot(Unpivot),
Sort(Sort),
// Split(Split),
// Sample(Sample),
Sample(Sample),
// MonotonicallyIncreasingId(MonotonicallyIncreasingId),
// Coalesce(Coalesce),
// Flatten(Flatten),
Expand Down Expand Up @@ -167,6 +167,24 @@ impl LocalPhysicalPlan {
.arced()
}

/// Builds a `Sample` plan node on top of `input`.
///
/// The node's schema is a clone of `input`'s schema; `fraction`,
/// `with_replacement` and `seed` are stored on the node unchanged, and
/// `plan_stats` is set to an empty `PlanStats {}`. The node is wrapped in
/// `Self::Sample` and returned via `arced()`.
pub(crate) fn sample(
    input: LocalPhysicalPlanRef,
    fraction: f64,
    with_replacement: bool,
    seed: Option<u64>,
) -> LocalPhysicalPlanRef {
    // Clone the schema first: `input` is moved into the node below.
    let schema = input.schema().clone();
    let node = Sample {
        input,
        fraction,
        with_replacement,
        seed,
        schema,
        plan_stats: PlanStats {},
    };
    Self::Sample(node).arced()
}

pub(crate) fn hash_join(
left: LocalPhysicalPlanRef,
right: LocalPhysicalPlanRef,
Expand Down Expand Up @@ -211,6 +229,7 @@ impl LocalPhysicalPlan {
| Self::UnGroupedAggregate(UnGroupedAggregate { schema, .. })
| Self::HashAggregate(HashAggregate { schema, .. })
| Self::Sort(Sort { schema, .. })
| Self::Sample(Sample { schema, .. })
| Self::HashJoin(HashJoin { schema, .. })
| Self::Concat(Concat { schema, .. }) => schema,
Self::InMemoryScan(InMemoryScan { info, .. }) => &info.source_schema,
Expand Down Expand Up @@ -271,6 +290,16 @@ pub struct Sort {
pub plan_stats: PlanStats,
}

/// Local physical plan node describing a sampling step over `input`.
#[derive(Debug)]
pub struct Sample {
/// Child plan whose output is sampled.
pub input: LocalPhysicalPlanRef,
/// Fraction argument for the sampling step.
pub fraction: f64,
/// Replacement flag for the sampling step.
pub with_replacement: bool,
/// Optional seed for the sampling step.
pub seed: Option<u64>,
/// Schema of this node; `LocalPhysicalPlan::sample` sets it to a clone of the input's schema.
pub schema: SchemaRef,
/// Plan statistics; `LocalPhysicalPlan::sample` fills this with an empty `PlanStats {}`.
pub plan_stats: PlanStats,
}

#[derive(Debug)]
pub struct UnGroupedAggregate {
pub input: LocalPhysicalPlanRef,
Expand Down
9 changes: 9 additions & 0 deletions src/daft-physical-plan/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
project.projected_schema.clone(),
))
}
LogicalPlan::Sample(sample) => {
let input = translate(&sample.input)?;
Ok(LocalPhysicalPlan::sample(
input,
sample.fraction,
sample.with_replacement,
sample.seed,
))
}
LogicalPlan::Aggregate(aggregate) => {
let input = translate(&aggregate.input)?;
if aggregate.groupby.is_empty() {
Expand Down
9 changes: 4 additions & 5 deletions tests/dataframe/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@

from daft import context

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)


def test_sample_fraction(make_df, valid_data: list[dict[str, float]]) -> None:
df = make_df(valid_data)
Expand Down Expand Up @@ -105,6 +100,10 @@ def test_sample_without_replacement(make_df, valid_data: list[dict[str, float]])
assert pylist[0] != pylist[1]


@pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for concat",
)
def test_sample_with_concat(make_df, valid_data: list[dict[str, float]]) -> None:
df1 = make_df(valid_data)
df2 = make_df(valid_data)
Expand Down

0 comments on commit 16665f2

Please sign in to comment.