From 2fbf05f28499e2faa56492d6adf6182b857faaaf Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 17 Jul 2024 09:37:48 -0700 Subject: [PATCH 1/3] fix --- src/daft-plan/src/logical_plan.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index cb04907447..5f6d2ed4f4 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -199,6 +199,7 @@ impl LogicalPlan { Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), + Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), _ => panic!("Logical op {} has two inputs, but got one", self), }, [input1, input2] => match self { From e6934afe20b498212aab71f90f13598e0b0988e0 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 17 Jul 2024 09:44:28 -0700 Subject: [PATCH 2/3] remove catch all --- src/daft-plan/src/logical_plan.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 5f6d2ed4f4..8a37b50c2e 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -200,7 +200,8 @@ impl LogicalPlan { Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), - _ => panic!("Logical op {} has two inputs, but got one", self), + Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"), + Self::Join(_) => panic!("Join ops should never have only one input, but got one"), }, [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), From 22be5515f56bee39afbd82b05f65958aac71c52f Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 17 Jul 2024 12:49:09 -0700 Subject: [PATCH 3/3] add test --- tests/dataframe/test_sample.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/dataframe/test_sample.py b/tests/dataframe/test_sample.py index aec7af62f3..f59b6a4172 100644 --- a/tests/dataframe/test_sample.py +++ b/tests/dataframe/test_sample.py @@ -81,3 +81,19 @@ def test_sample_with_replacement(make_df, valid_data: list[dict[str, float]]) -> assert df.column_names == list(valid_data[0].keys()) # Check that the two rows are the same, which should be for this seed. assert all(col[0] == col[1] for col in df.to_pydict().values()) + + +def test_sample_with_concat(make_df, valid_data: list[dict[str, float]]) -> None: + df1 = make_df(valid_data) + df2 = make_df(valid_data) + + df1 = df1.sample(fraction=0.5, seed=42) + df2 = df2.sample(fraction=0.5, seed=42) + + df = df1.concat(df2) + df.collect() + + assert len(df) == 4 + assert df.column_names == list(valid_data[0].keys()) + # Check that the two rows are the same, which should be for this seed. + assert all(col[:2] == col[2:] for col in df.to_pydict().values())