Skip to content

Commit

Permalink
[BUG] With_new_children not implemented for sample (#2528)
Browse files Browse the repository at this point in the history
Addresses: #2510

Fixed:
```
import daft
x1 = daft.from_pydict({"foo": [1, 2], "bar": [1, 2]}) 
x2 = daft.from_pydict({"foo": [3, 4], "bar": [3, 4]}) 
x1 = x1.sample(0.5) 
x2 = x2.sample(0.5) 
xx = x1.concat(x2) 
xx = xx.select('foo') 
print(xx.collect())
```

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
3 people authored Jul 17, 2024
1 parent 346868e commit a226d00
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/daft-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ impl LogicalPlan {
Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()),
Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))),
Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }),
_ => panic!("Logical op {} has two inputs, but got one", self),
Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)),
Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"),
Self::Join(_) => panic!("Join ops should never have only one input, but got one"),
},
[input1, input2] => match self {
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Expand Down
16 changes: 16 additions & 0 deletions tests/dataframe/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,19 @@ def test_sample_with_replacement(make_df, valid_data: list[dict[str, float]]) ->
assert df.column_names == list(valid_data[0].keys())
# Check that the two rows are the same, which should be for this seed.
assert all(col[0] == col[1] for col in df.to_pydict().values())


def test_sample_with_concat(make_df, valid_data: list[dict[str, float]]) -> None:
    """Concatenate two identically seeded samples and collect the result.

    Builds two dataframes from the same ``valid_data``, samples each with
    ``fraction=0.5`` and ``seed=42``, concatenates them, and checks the row
    count, the column names, and that the two halves of every column match.
    """
    left = make_df(valid_data)
    right = make_df(valid_data)

    left = left.sample(fraction=0.5, seed=42)
    right = right.sample(fraction=0.5, seed=42)

    df = left.concat(right)
    df.collect()

    assert len(df) == 4
    assert df.column_names == list(valid_data[0].keys())
    # Both inputs come from the same data and are sampled with the same seed,
    # so the first two rows of each column (from the first sample) should
    # equal the last two rows (from the second sample).
    columns = df.to_pydict()
    assert all(values[:2] == values[2:] for values in columns.values())

0 comments on commit a226d00

Please sign in to comment.