diff --git a/daft/daft.pyi b/daft/daft.pyi
index c74ffe8afa..4b8140e6bd 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -89,6 +89,8 @@ class JoinType(Enum):
Left: int
Right: int
Outer: int
+ Semi: int
+ Anti: int
@staticmethod
def from_join_type_str(join_type: str) -> JoinType:
diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py
index 8821f31f4f..824e2ea302 100644
--- a/daft/dataframe/dataframe.py
+++ b/daft/dataframe/dataframe.py
@@ -1164,8 +1164,12 @@ def join(
if join_strategy == JoinStrategy.SortMerge and join_type != JoinType.Inner:
raise ValueError("Sort merge join only supports inner joins")
- if join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer:
+ elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer:
raise ValueError("Broadcast join does not support outer joins")
+ elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Anti:
+ raise ValueError("Broadcast join does not support Anti joins")
+ elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Semi:
+ raise ValueError("Broadcast join does not support Semi joins")
left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,))
right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,))
diff --git a/daft/hudi/hudi_scan.py b/daft/hudi/hudi_scan.py
index fcef509bff..8c1c98f298 100644
--- a/daft/hudi/hudi_scan.py
+++ b/daft/hudi/hudi_scan.py
@@ -115,8 +115,8 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
file=path,
file_format=file_format_config,
schema=self._schema._schema,
- num_rows=record_count,
storage_config=self._storage_config,
+ num_rows=record_count,
size_bytes=size_bytes,
pushdowns=pushdowns,
partition_values=partition_values,
diff --git a/src/daft-core/src/join.rs b/src/daft-core/src/join.rs
index 3d1385a1be..018fba15b9 100644
--- a/src/daft-core/src/join.rs
+++ b/src/daft-core/src/join.rs
@@ -21,6 +21,8 @@ pub enum JoinType {
Left,
Right,
Outer,
+ Anti,
+ Semi,
}
#[cfg(feature = "python")]
@@ -46,7 +48,7 @@ impl JoinType {
pub fn iterator() -> std::slice::Iter<'static, JoinType> {
use JoinType::*;
- static JOIN_TYPES: [JoinType; 4] = [Inner, Left, Right, Outer];
+ static JOIN_TYPES: [JoinType; 6] = [Inner, Left, Right, Outer, Anti, Semi];
JOIN_TYPES.iter()
}
}
@@ -62,6 +64,8 @@ impl FromStr for JoinType {
"left" => Ok(Left),
"right" => Ok(Right),
"outer" => Ok(Outer),
+ "anti" => Ok(Anti),
+ "semi" => Ok(Semi),
_ => Err(DaftError::TypeError(format!(
"Join type {} is not supported; only the following types are supported: {:?}",
join_type,
diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs
index 8c0c26b4a3..8cc5c3d078 100644
--- a/src/daft-micropartition/src/ops/join.rs
+++ b/src/daft-micropartition/src/ops/join.rs
@@ -23,15 +23,14 @@ impl MicroPartition {
where
F: FnOnce(&Table, &Table, &[ExprRef], &[ExprRef], JoinType) -> DaftResult<Table>
,
{
- let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;
-
+ let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on, how)?;
match (how, self.len(), right.len()) {
(JoinType::Inner, 0, _)
| (JoinType::Inner, _, 0)
| (JoinType::Left, 0, _)
| (JoinType::Right, _, 0)
| (JoinType::Outer, 0, 0) => {
- return Ok(Self::empty(Some(join_schema.into())));
+ return Ok(Self::empty(Some(join_schema)));
}
_ => {}
}
@@ -58,7 +57,7 @@ impl MicroPartition {
}
};
if let TruthValue::False = tv {
- return Ok(Self::empty(Some(join_schema.into())));
+ return Ok(Self::empty(Some(join_schema)));
}
}
@@ -67,11 +66,11 @@ impl MicroPartition {
let rt = right.concat_or_get(io_stats)?;
match (lt.as_slice(), rt.as_slice()) {
- ([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))),
+ ([], _) | (_, []) => Ok(Self::empty(Some(join_schema))),
([lt], [rt]) => {
let joined_table = table_join(lt, rt, left_on, right_on, how)?;
Ok(MicroPartition::new_loaded(
- join_schema.into(),
+ join_schema,
vec![joined_table].into(),
None,
))
diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs
index 6d3c57b8a4..8fe2bad6fb 100644
--- a/src/daft-plan/src/logical_ops/join.rs
+++ b/src/daft-plan/src/logical_ops/join.rs
@@ -82,8 +82,9 @@ impl Join {
.map(|(_, field)| field)
.cloned()
.chain(right.schema().fields.iter().filter_map(|(rname, rfield)| {
- if left_join_keys.contains(rname.as_str())
- && right_join_keys.contains(rname.as_str())
+ if (left_join_keys.contains(rname.as_str())
+ && right_join_keys.contains(rname.as_str()))
+ || matches!(join_type, JoinType::Anti | JoinType::Semi)
{
right_input_mapping.insert(rname.clone(), rname.clone());
None
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
index 26b64d5551..db91617d15 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
@@ -2,12 +2,12 @@ use std::{collections::HashMap, sync::Arc};
use common_error::DaftResult;
-use daft_core::schema::Schema;
+use daft_core::{schema::Schema, JoinType};
use daft_dsl::{col, optimization::replace_columns_with_expressions, Expr, ExprRef};
use indexmap::IndexSet;
use crate::{
- logical_ops::{Aggregate, Pivot, Project, Source},
+ logical_ops::{Aggregate, Join, Pivot, Project, Source},
source_info::SourceInfo,
LogicalPlan, ResourceRequest,
};
@@ -478,6 +478,52 @@ impl PushDownProjection {
}
}
+ fn try_optimize_join(
+ &self,
+ join: &Join,
+ plan: Arc<LogicalPlan>,
+ ) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+ // If this join prunes columns from its upstream,
+ // then explicitly create a projection to do so.
+ // this is the case for semi and anti joins.
+
+ if matches!(join.join_type, JoinType::Anti | JoinType::Semi) {
+ let required_cols = plan.required_columns();
+ let right_required_cols = required_cols
+ .get(1)
+ .expect("we expect 2 set of required columns for join");
+ let right_schema = join.right.schema();
+
+ if right_required_cols.len() < right_schema.fields.len() {
+ let new_subprojection: LogicalPlan = {
+ let pushdown_column_exprs = right_required_cols
+ .iter()
+ .map(|s| col(s.as_str()))
+ .collect::<Vec<_>>();
+
+ Project::try_new(
+ join.right.clone(),
+ pushdown_column_exprs,
+ Default::default(),
+ )?
+ .into()
+ };
+
+ let new_join = plan
+ .with_new_children(&[(join.left).clone(), new_subprojection.into()])
+ .arced();
+
+ Ok(self
+ .try_optimize(new_join.clone())?
+ .or(Transformed::Yes(new_join)))
+ } else {
+ Ok(Transformed::No(plan))
+ }
+ } else {
+ Ok(Transformed::No(plan))
+ }
+ }
+
fn try_optimize_pivot(
&self,
pivot: &Pivot,
@@ -524,6 +570,8 @@ impl OptimizerRule for PushDownProjection {
LogicalPlan::Aggregate(aggregation) => {
self.try_optimize_aggregation(aggregation, plan.clone())
}
+ // Joins also do column projection
+ LogicalPlan::Join(join) => self.try_optimize_join(join, plan.clone()),
// Pivots also do column projection
LogicalPlan::Pivot(pivot) => self.try_optimize_pivot(pivot, plan.clone()),
_ => Ok(Transformed::No(plan)),
diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs
index ec2511f867..072e308321 100644
--- a/src/daft-plan/src/physical_planner/translate.rs
+++ b/src/daft-plan/src/physical_planner/translate.rs
@@ -558,6 +558,16 @@ pub(super) fn translate_single_logical_node(
"Broadcast join does not support outer joins.".to_string(),
));
}
+ (JoinType::Anti, _) => {
+ return Err(common_error::DaftError::ValueError(
+ "Broadcast join does not support anti joins.".to_string(),
+ ));
+ }
+ (JoinType::Semi, _) => {
+ return Err(common_error::DaftError::ValueError(
+ "Broadcast join does not support semi joins.".to_string(),
+ ));
+ }
};
if is_swapped {
diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs
index 735fa01199..d71ec9d3d8 100644
--- a/src/daft-table/src/ops/hash.rs
+++ b/src/daft-table/src/ops/hash.rs
@@ -103,4 +103,47 @@ impl Table {
}
Ok(probe_table)
}
+
+ pub fn to_probe_hash_map_without_idx(
+ &self,
+ ) -> DaftResult<HashMap<IndexHash, (), IdentityBuildHasher>> {
+ let hashes = self.hash_rows()?;
+
+ const DEFAULT_SIZE: usize = 20;
+ let comparator = build_multi_array_is_equal(
+ self.columns.as_slice(),
+ self.columns.as_slice(),
+ true,
+ true,
+ )?;
+
+ let mut probe_table =
+ HashMap::<IndexHash, (), IdentityBuildHasher>::with_capacity_and_hasher(
+ DEFAULT_SIZE,
+ Default::default(),
+ );
+ // TODO(Sammy): Drop nulls using validity array if requested
+ for (i, h) in hashes.as_arrow().values_iter().enumerate() {
+ let entry = probe_table.raw_entry_mut().from_hash(*h, |other| {
+ (*h == other.hash) && {
+ let j = other.idx;
+ comparator(i, j as usize)
+ }
+ });
+ match entry {
+ RawEntryMut::Vacant(entry) => {
+ entry.insert_hashed_nocheck(
+ *h,
+ IndexHash {
+ idx: i as u64,
+ hash: *h,
+ },
+ (),
+ );
+ }
+ RawEntryMut::Occupied(_) => {}
+ }
+ }
+ Ok(probe_table)
+ }
}
diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs
index 6ae40a8f3a..1b8ad0bc62 100644
--- a/src/daft-table/src/ops/joins/hash_join.rs
+++ b/src/daft-table/src/ops/joins/hash_join.rs
@@ -4,7 +4,7 @@ use arrow2::{bitmap::MutableBitmap, types::IndexRange};
use daft_core::{
array::ops::{arrow2::comparison::build_multi_array_is_equal, full::FullNull},
datatypes::{BooleanArray, UInt64Array},
- DataType, IntoSeries,
+ DataType, IntoSeries, JoinType,
};
use daft_dsl::ExprRef;
@@ -21,7 +21,13 @@ pub(super) fn hash_inner_join(
left_on: &[ExprRef],
right_on: &[ExprRef],
) -> DaftResult {
- let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?;
+ let join_schema = infer_join_schema(
+ &left.schema,
+ &right.schema,
+ left_on,
+ right_on,
+ JoinType::Inner,
+ )?;
let lkeys = left.eval_expression_list(left_on)?;
let rkeys = right.eval_expression_list(right_on)?;
@@ -103,7 +109,13 @@ pub(super) fn hash_left_right_join(
right_on: &[ExprRef],
left_side: bool,
) -> DaftResult {
- let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?;
+ let join_schema = infer_join_schema(
+ &left.schema,
+ &right.schema,
+ left_on,
+ right_on,
+ JoinType::Right,
+ )?;
let lkeys = left.eval_expression_list(left_on)?;
let rkeys = right.eval_expression_list(right_on)?;
@@ -212,13 +224,80 @@ pub(super) fn hash_left_right_join(
Table::new(join_schema, join_series)
}
+pub(super) fn hash_semi_anti_join(
+ left: &Table,
+ right: &Table,
+ left_on: &[ExprRef],
+ right_on: &[ExprRef],
+ is_anti: bool,
+) -> DaftResult<Table> {
+ let lkeys = left.eval_expression_list(left_on)?;
+ let rkeys = right.eval_expression_list(right_on)?;
+
+ let (lkeys, rkeys) = match_types_for_tables(&lkeys, &rkeys)?;
+
+ let lidx = if lkeys.columns.iter().any(|s| s.data_type().is_null())
+ || rkeys.columns.iter().any(|s| s.data_type().is_null())
+ {
+ if is_anti {
+ // a null-typed key column means no rows can match, so an anti join keeps every left row
+ return Ok(left.clone());
+ } else {
+ UInt64Array::empty("left_indices", &DataType::UInt64).into_series()
+ }
+ } else {
+ let probe_table = rkeys.to_probe_hash_map_without_idx()?;
+
+ let l_hashes = lkeys.hash_rows()?;
+
+ let is_equal = build_multi_array_is_equal(
+ lkeys.columns.as_slice(),
+ rkeys.columns.as_slice(),
+ false,
+ false,
+ )?;
+ // at most every left row survives the filter, so size the buffer from the left side
+ let rows = lkeys.len();
+
+ drop(lkeys);
+ drop(rkeys);
+
+ let mut left_idx = Vec::with_capacity(rows);
+ let is_semi = !is_anti;
+ for (l_idx, h) in l_hashes.as_arrow().values_iter().enumerate() {
+ let is_match = probe_table
+ .raw_entry()
+ .from_hash(*h, |other| {
+ *h == other.hash && {
+ let r_idx = other.idx as usize;
+ is_equal(l_idx, r_idx)
+ }
+ })
+ .is_some();
+ if is_match == is_semi {
+ left_idx.push(l_idx as u64);
+ }
+ }
+
+ UInt64Array::from(("left_indices", left_idx)).into_series()
+ };
+
+ left.take(&lidx)
+}
+
+
pub(super) fn hash_outer_join(
left: &Table,
right: &Table,
left_on: &[ExprRef],
right_on: &[ExprRef],
) -> DaftResult {
- let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?;
+ let join_schema = infer_join_schema(
+ &left.schema,
+ &right.schema,
+ left_on,
+ right_on,
+ JoinType::Outer,
+ )?;
let lkeys = left.eval_expression_list(left_on)?;
let rkeys = right.eval_expression_list(right_on)?;
diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs
index d0805258f2..418b12a623 100644
--- a/src/daft-table/src/ops/joins/mod.rs
+++ b/src/daft-table/src/ops/joins/mod.rs
@@ -1,12 +1,18 @@
-use std::collections::{HashMap, HashSet};
+use std::{
+ collections::{HashMap, HashSet},
+ sync::Arc,
+};
use daft_core::{
- array::growable::make_growable, schema::Schema, utils::supertype::try_get_supertype, JoinType,
- Series,
+ array::growable::make_growable,
+ schema::{Schema, SchemaRef},
+ utils::supertype::try_get_supertype,
+ JoinType, Series,
};
use common_error::{DaftError, DaftResult};
use daft_dsl::ExprRef;
+use hash_join::hash_semi_anti_join;
use crate::Table;
@@ -36,11 +42,12 @@ fn match_types_for_tables(left: &Table, right: &Table) -> DaftResult<(Table, Tab
}
pub fn infer_join_schema(
- left: &Schema,
- right: &Schema,
+ left: &SchemaRef,
+ right: &SchemaRef,
left_on: &[ExprRef],
right_on: &[ExprRef],
-) -> DaftResult<Schema> {
+ how: JoinType,
+) -> DaftResult<SchemaRef> {
if left_on.len() != right_on.len() {
return Err(DaftError::ValueError(format!(
"Length of left_on does not match length of right_on for Join {} vs {}",
@@ -48,6 +55,9 @@ pub fn infer_join_schema(
right_on.len()
)));
}
+ if matches!(how, JoinType::Anti | JoinType::Semi) {
+ return Ok(left.clone());
+ }
let lfields = left_on
.iter()
@@ -104,8 +114,8 @@ pub fn infer_join_schema(
join_fields.push(field.rename(curr_name.clone()));
names_so_far.insert(curr_name.clone());
}
-
- Schema::new(join_fields)
+ let schema = Schema::new(join_fields)?;
+ Ok(Arc::new(schema))
}
fn add_non_join_key_columns(
@@ -199,6 +209,8 @@ impl Table {
JoinType::Left => hash_left_right_join(self, right, left_on, right_on, true),
JoinType::Right => hash_left_right_join(self, right, left_on, right_on, false),
JoinType::Outer => hash_outer_join(self, right, left_on, right_on),
+ JoinType::Semi => hash_semi_anti_join(self, right, left_on, right_on, false),
+ JoinType::Anti => hash_semi_anti_join(self, right, left_on, right_on, true),
}
}
@@ -239,7 +251,13 @@ impl Table {
return left.sort_merge_join(&right, left_on, right_on, true);
}
- let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;
+ let join_schema = infer_join_schema(
+ &self.schema,
+ &right.schema,
+ left_on,
+ right_on,
+ JoinType::Inner,
+ )?;
let ltable = self.eval_expression_list(left_on)?;
let rtable = right.eval_expression_list(right_on)?;
diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py
index 0c47513cd6..b416319637 100644
--- a/tests/dataframe/test_joins.py
+++ b/tests/dataframe/test_joins.py
@@ -13,6 +13,10 @@ def skip_invalid_join_strategies(join_strategy, join_type):
pytest.skip("Sort merge currently only supports inner joins")
elif join_strategy == "broadcast" and join_type == "outer":
pytest.skip("Broadcast join does not support outer joins")
+ elif join_strategy == "broadcast" and join_type == "anti":
+ pytest.skip("Broadcast join does not support anti joins")
+ elif join_strategy == "broadcast" and join_type == "semi":
+ pytest.skip("Broadcast join does not support semi joins")
def test_invalid_join_strategies(make_df):
@@ -720,3 +724,56 @@ def test_join_null_type_column(join_strategy, join_type, make_df):
with pytest.raises((ExpressionTypeError, ValueError)):
daft_df.join(daft_df2, on="id", how=join_type, strategy=join_strategy)
+
+
+@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
+@pytest.mark.parametrize(
+ "join_strategy",
+ [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"],
+ indirect=True,
+)
+@pytest.mark.parametrize(
+ "join_type,expected",
+ [
+ (
+ "semi",
+ {
+ "id": [2, 3],
+ "values_left": ["b1", "c1"],
+ },
+ ),
+ (
+ "anti",
+ {
+ "id": [1, None],
+ "values_left": ["a1", "d1"],
+ },
+ ),
+ ],
+)
+def test_join_semi_anti(join_strategy, join_type, expected, make_df, repartition_nparts):
+ skip_invalid_join_strategies(join_strategy, join_type)
+
+ daft_df1 = make_df(
+ {
+ "id": [1, 2, 3, None],
+ "values_left": ["a1", "b1", "c1", "d1"],
+ },
+ repartition=repartition_nparts,
+ )
+ daft_df2 = make_df(
+ {
+ "id": [2, 2, 3, 4],
+ "values_right": ["a2", "b2", "c2", "d2"],
+ },
+ repartition=repartition_nparts,
+ )
+ daft_df = (
+ daft_df1.with_column("id", daft_df1["id"].cast(DataType.int64()))
+ .join(daft_df2, on="id", how=join_type, strategy=join_strategy)
+ .sort(["id", "values_left"])
+ ).select("id", "values_left")
+
+ assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id") == sort_arrow_table(
+ pa.Table.from_pydict(expected), "id"
+ )