From f0ed52b4d1e550a25f2a91d9f63508861406b544 Mon Sep 17 00:00:00 2001 From: advancedxy <807537+advancedxy@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:02:23 +0800 Subject: [PATCH] [FEAT] Support intersect as a DataFrame API --- daft/dataframe/dataframe.py | 30 +++++ daft/logical/builder.py | 4 + src/daft-plan/src/builder.rs | 11 ++ src/daft-plan/src/logical_ops/mod.rs | 2 + .../src/logical_ops/set_operations.rs | 112 ++++++++++++++++++ tests/dataframe/test_intersect.py | 47 ++++++++ 6 files changed, 206 insertions(+) create mode 100644 src/daft-plan/src/logical_ops/set_operations.rs create mode 100644 tests/dataframe/test_intersect.py diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 8dbb33111e..6caa5df319 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2457,6 +2457,36 @@ def pivot( builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names) return DataFrame(builder) + @DataframePublicAPI + def intersect(self, other: "DataFrame") -> "DataFrame": + """Returns the intersection of two DataFrames. 
+ + Example: + >>> import daft + >>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]}) + >>> df1.intersect(df2).collect() + ╭───────┬───────╮ + │ a ┆ b │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 6 │ + ╰───────┴───────╯ + + (Showing first 2 of 2 rows) + + Args: + other (DataFrame): DataFrame to intersect with + + Returns: + DataFrame: DataFrame with the intersection of the two DataFrames + """ + builder = self._builder.intersect(other._builder) + return DataFrame(builder) + def _materialize_results(self) -> None: """Materializes the results of for this DataFrame and hold a pointer to the results.""" context = get_context() diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 36e6ead37c..95f25ff83b 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -273,6 +273,10 @@ def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: igno builder = self._builder.concat(other._builder) return LogicalPlanBuilder(builder) + def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: builder = self._builder.intersect(other._builder, False) return LogicalPlanBuilder(builder) + def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: builder = self._builder.add_monotonically_increasing_id(column_name) return LogicalPlanBuilder(builder) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 64f4faf5c5..f694c7371c 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -474,6 +474,13 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } + pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> { let logical_plan: LogicalPlan = logical_ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)? 
+ .to_optimized_join()?; + Ok(self.with_new_plan(logical_plan)) + } + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult<Self> { let logical_plan: LogicalPlan = logical_ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into(); @@ -902,6 +909,10 @@ impl PyLogicalPlanBuilder { Ok(self.builder.concat(&other.builder)?.into()) } + pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> { Ok(self.builder.intersect(&other.builder, is_all)?.into()) } + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> { Ok(self .builder diff --git a/src/daft-plan/src/logical_ops/mod.rs b/src/daft-plan/src/logical_ops/mod.rs index 339589deea..ec0e47e0e7 100644 --- a/src/daft-plan/src/logical_ops/mod.rs +++ b/src/daft-plan/src/logical_ops/mod.rs @@ -11,6 +11,7 @@ mod pivot; mod project; mod repartition; mod sample; +mod set_operations; mod sink; mod sort; mod source; @@ -29,6 +30,7 @@ pub use pivot::Pivot; pub use project::Project; pub use repartition::Repartition; pub use sample::Sample; +pub use set_operations::Intersect; pub use sink::Sink; pub use sort::Sort; pub use source::Source; diff --git a/src/daft-plan/src/logical_ops/set_operations.rs b/src/daft-plan/src/logical_ops/set_operations.rs new file mode 100644 index 0000000000..31fd22c806 --- /dev/null +++ b/src/daft-plan/src/logical_ops/set_operations.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use daft_core::join::JoinType; +use daft_dsl::{col, zero_lit, ExprRef}; +use daft_schema::schema::Schema; +use itertools::Itertools; +use snafu::ResultExt; + +use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; + +// null safe equal (a <> b) is equivalent to: +// nvl(a, zero_lit_of_type(a)) = nvl(b, zero_lit_of_type(b)) and is_null(a) = is_null(b) +fn to_null_safe_equal_join_keys(schema: &Schema) -> DaftResult<Vec<ExprRef>> { + schema + .names() + .iter() + .map(|k| { + let field = 
schema.get_field(k).unwrap(); + // TODO: expr name should reflect the expression itself, a.k.a is_null(a)'s field name + // should be `is_null(a)`. + let col_or_zero = col(field.name.clone()) + .fill_null(zero_lit(&field.dtype)?) + .alias(field.name.clone() + "__or_zero__"); + let is_null = col(field.name.clone()) + .is_null() + .alias(field.name.clone() + "__is_null__"); + Ok(vec![col_or_zero, is_null]) + }) + .flatten_ok() + .collect() +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Intersect { + // Upstream nodes. + pub lhs: Arc<LogicalPlan>, + pub rhs: Arc<LogicalPlan>, + pub is_all: bool, +} + +impl Intersect { + pub(crate) fn try_new( + lhs: Arc<LogicalPlan>, + rhs: Arc<LogicalPlan>, + is_all: bool, + ) -> logical_plan::Result<Self> { + let lhs_schema = lhs.schema(); + let rhs_schema = rhs.schema(); + if lhs_schema.len() != rhs_schema.len() { + return Err(DaftError::SchemaMismatch(format!( + "Both plans must have the same num of fields to intersect, \ but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}", + lhs_schema.len(), + rhs_schema.len(), + lhs_schema, + rhs_schema + ))) + .context(CreationSnafu); + } + // lhs and rhs should have the same type for each field to intersect + if lhs_schema + .fields + .values() + .zip(rhs_schema.fields.values()) + .any(|(l, r)| l.dtype != r.dtype) + { + return Err(DaftError::SchemaMismatch(format!( + "Both plans' schemas should have the same type for each field to intersect, \ but got lhs schema: {}, rhs schema: {}", + lhs_schema, rhs_schema + ))) + .context(CreationSnafu); + } + Ok(Self { lhs, rhs, is_all }) + } + + /// intersect distinct could be represented as a semi join + distinct + /// the following intersect operator: + /// ```sql + /// select a1, a2 from t1 intersect select b1, b2 from t2 + /// ``` + /// is the same as: + /// ```sql + /// select distinct a1, a2 from t1 left semi join t2 + /// on t1.a1 <> t2.b1 and t1.a2 <> t2.b2 + /// ``` + /// TODO: Move this logic to logical optimization rules + pub(crate) fn to_optimized_join(&self) -> 
logical_plan::Result<LogicalPlan> { + if self.is_all { + Err(logical_plan::Error::CreationError { + source: DaftError::InternalError("intersect all is not supported yet".to_string()), + }) + } else { + let left_on = to_null_safe_equal_join_keys(&self.lhs.schema()) + .map_err(|e| logical_plan::Error::CreationError { source: e })?; + let right_on = to_null_safe_equal_join_keys(&self.rhs.schema()) + .map_err(|e| logical_plan::Error::CreationError { source: e })?; + let join = logical_plan::Join::try_new( + self.lhs.clone(), + self.rhs.clone(), + left_on, + right_on, + JoinType::Semi, + None, + None, + None, + ); + join.map(|j| logical_plan::Distinct::new(j.into()).into()) + } + } +} diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py new file mode 100644 index 0000000000..24330ec84e --- /dev/null +++ b/tests/dataframe/test_intersect.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import daft +from daft import col + + +def test_simple_intersect(make_df): + df1 = make_df({"foo": [1, 2, 3]}) + df2 = make_df({"bar": [2, 3, 4]}) + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": [2, 3]} + + +def test_intersect_with_duplicate(make_df): + df1 = make_df({"foo": [1, 2, 2, 3]}) + df2 = make_df({"bar": [2, 3, 3]}) + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": [2, 3]} + + +def test_self_intersect(make_df): + df = make_df({"foo": [1, 2, 3]}) + result = df.intersect(df).sort(by="foo") + assert result.to_pydict() == {"foo": [1, 2, 3]} + + +def test_intersect_empty(make_df): + df1 = make_df({"foo": [1, 2, 3]}) + df2 = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64())) + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": []} + + +def test_intersect_with_nulls(make_df): + df1 = make_df({"foo": [1, 2, None]}) + df1_without_null = make_df({"foo": [1, 2]}) + df2 = make_df({"bar": [2, 3, None]}) + df2_without_null = make_df({"bar": [2, 3]}) + + result = df1.intersect(df2) + assert 
result.to_pydict() == {"foo": [2, None]} + + result = df1_without_null.intersect(df2) + assert result.to_pydict() == {"foo": [2]} + + result = df1.intersect(df2_without_null) + assert result.to_pydict() == {"foo": [2]}