Skip to content

Commit

Permalink
[FEAT] Support intersect as a DataFrame API
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Oct 29, 2024
1 parent ced8c4b commit f0ed52b
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 0 deletions.
30 changes: 30 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,36 @@ def pivot(
builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names)
return DataFrame(builder)

@DataframePublicAPI
def intersect(self, other: "DataFrame") -> "DataFrame":
"""Returns the intersection of two DataFrames.
Example:
>>> import daft
>>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]})
>>> df1.intersect(df2).collect()
╭───────┬───────╮
│ a ┆ b │
│ --- ┆ --- │
│ Int64 ┆ Int64 │
╞═══════╪═══════╡
│ 1 ┆ 4 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 6 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 2 of 2 rows)
Args:
other (DataFrame): DataFrame to intersect with
Returns:
DataFrame: DataFrame with the intersection of the two DataFrames
"""
builder = self._builder.intersect(other._builder)
return DataFrame(builder)

def _materialize_results(self) -> None:
"""Materializes the results of for this DataFrame and hold a pointer to the results."""
context = get_context()
Expand Down
4 changes: 4 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: igno
builder = self._builder.concat(other._builder)
return LogicalPlanBuilder(builder)

def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.intersect(other._builder, False)
return LogicalPlanBuilder(builder)

def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder:
builder = self._builder.add_monotonically_increasing_id(column_name)
return LogicalPlanBuilder(builder)
Expand Down
11 changes: 11 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,13 @@ impl LogicalPlanBuilder {
Ok(self.with_new_plan(logical_plan))
}

pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
logical_ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)?
.to_optimized_join()?;
Ok(self.with_new_plan(logical_plan))
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
logical_ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into();
Expand Down Expand Up @@ -902,6 +909,10 @@ impl PyLogicalPlanBuilder {
Ok(self.builder.concat(&other.builder)?.into())
}

pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
Ok(self.builder.intersect(&other.builder, is_all)?.into())
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> {
Ok(self
.builder
Expand Down
2 changes: 2 additions & 0 deletions src/daft-plan/src/logical_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod pivot;
mod project;
mod repartition;
mod sample;
mod set_operations;
mod sink;
mod sort;
mod source;
Expand All @@ -29,6 +30,7 @@ pub use pivot::Pivot;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use set_operations::Intersect;
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
Expand Down
112 changes: 112 additions & 0 deletions src/daft-plan/src/logical_ops/set_operations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use std::sync::Arc;

use common_error::{DaftError, DaftResult};
use daft_core::join::JoinType;
use daft_dsl::{col, zero_lit, ExprRef};
use daft_schema::schema::Schema;
use itertools::Itertools;
use snafu::ResultExt;

use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan};

// null safe equal (a <> b) is equivalent to:
// nvl(a, zero_lit_of_type(a)) = nvl(b, zero_lit_of_type(b)) and is_null(a) = is_null(b)
fn to_null_safe_equal_join_keys(schema: &Schema) -> DaftResult<Vec<ExprRef>> {
schema
.names()
.iter()
.map(|k| {
let field = schema.get_field(k).unwrap();
// TODO: expr name should reflect the expression itself, a.k.a is_null(a)'s field name
// should be `is_null(a)`.
let col_or_zero = col(field.name.clone())
.fill_null(zero_lit(&field.dtype)?)
.alias(field.name.clone() + "__or_zero__");
let is_null = col(field.name.clone())
.is_null()
.alias(field.name.clone() + "__is_null__");
Ok(vec![col_or_zero, is_null])
})
.flatten_ok()
.collect()
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Intersect {
// Upstream nodes.
pub lhs: Arc<LogicalPlan>,
pub rhs: Arc<LogicalPlan>,
pub is_all: bool,
}

impl Intersect {
pub(crate) fn try_new(
lhs: Arc<LogicalPlan>,
rhs: Arc<LogicalPlan>,
is_all: bool,
) -> logical_plan::Result<Self> {
let lhs_schema = lhs.schema();
let rhs_schema = rhs.schema();
if lhs_schema.len() != rhs_schema.len() {
return Err(DaftError::SchemaMismatch(format!(
"Both plans must have the same num of fields to intersect, \
but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}",
lhs_schema.len(),
rhs_schema.len(),
lhs_schema,
rhs_schema
)))
.context(CreationSnafu);
}
// lhs and rhs should have the same type for each field to intersect
if lhs_schema
.fields
.values()
.zip(rhs_schema.fields.values())
.any(|(l, r)| l.dtype != r.dtype)
{
return Err(DaftError::SchemaMismatch(format!(
"Both plans' schemas should have the same type for each field to intersect, \
but got lhs schema: {}, rhs schema: {}",
lhs_schema, rhs_schema
)))
.context(CreationSnafu);
}
Ok(Self { lhs, rhs, is_all })
}

/// intersect distinct could be represented as a semi join + distinct
/// the following intersect operator:
/// ```sql
/// select a1, a2 from t1 intersect select b1, b2 from t2
/// ```
/// is the same as:
/// ```sql
/// select distinct a1, a2 from t1 left semi join t2
/// on t1.a1 <> t2.b1 and t1.a2 <> t2.b2
/// ```
/// TODO: Move this logical to logical optimization rules
pub(crate) fn to_optimized_join(&self) -> logical_plan::Result<LogicalPlan> {
if self.is_all {
Err(logical_plan::Error::CreationError {
source: DaftError::InternalError("intersect all is not supported yet".to_string()),
})
} else {
let left_on = to_null_safe_equal_join_keys(&self.lhs.schema())
.map_err(|e| logical_plan::Error::CreationError { source: e })?;
let right_on = to_null_safe_equal_join_keys(&self.rhs.schema())
.map_err(|e| logical_plan::Error::CreationError { source: e })?;
let join = logical_plan::Join::try_new(
self.lhs.clone(),
self.rhs.clone(),
left_on,
right_on,
JoinType::Semi,
None,
None,
None,
);
join.map(|j| logical_plan::Distinct::new(j.into()).into())
}
}
}
47 changes: 47 additions & 0 deletions tests/dataframe/test_intersect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import daft
from daft import col


def test_simple_intersect(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"bar": [2, 3, 4]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_intersect_with_duplicate(make_df):
df1 = make_df({"foo": [1, 2, 2, 3]})
df2 = make_df({"bar": [2, 3, 3]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_self_intersect(make_df):
df = make_df({"foo": [1, 2, 3]})
result = df.intersect(df).sort(by="foo")
assert result.to_pydict() == {"foo": [1, 2, 3]}


def test_intersect_empty(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64()))
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": []}


def test_intersect_with_nulls(make_df):
df1 = make_df({"foo": [1, 2, None]})
df1_without_mull = make_df({"foo": [1, 2]})
df2 = make_df({"bar": [2, 3, None]})
df2_without_null = make_df({"bar": [2, 3]})

result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, None]}

result = df1_without_mull.intersect(df2)
assert result.to_pydict() == {"foo": [2]}

result = df1.intersect(df2_without_null)
assert result.to_pydict() == {"foo": [2]}

0 comments on commit f0ed52b

Please sign in to comment.