From 73a71e621bd15a440e93608c5b5ed55c76d67b97 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Thu, 29 Feb 2024 11:47:41 -0800 Subject: [PATCH] [FEAT] `any_value` groupby aggregation (#1941) This function is parameterized by `ignore_nulls`, which attempts to find a non-null value in each group when true. However, usage of this parameter in the aggregation function would require some changes to `DataFrame._agg()` that I am going to save for later, since these changes will probably not be needed anymore once global expressions can be passed into GroupBy operations Also in this PR: fixes to the `count` aggregation function --- daft/daft.pyi | 1 + daft/dataframe/dataframe.py | 12 ++++ daft/expressions/expressions.py | 4 ++ daft/logical/builder.py | 2 + src/daft-core/src/array/ops/count.rs | 10 ++- .../src/series/array_impl/logical_array.rs | 1 + src/daft-core/src/series/ops/agg.rs | 37 +++++++++++ src/daft-dsl/src/expr.rs | 18 ++++- src/daft-dsl/src/python.rs | 4 ++ src/daft-dsl/src/treenode.rs | 4 ++ src/daft-plan/src/logical_ops/project.rs | 6 ++ src/daft-plan/src/planner.rs | 17 +++++ src/daft-table/src/lib.rs | 3 + tests/table/test_table_aggs.py | 66 +++++++++++++++++++ 14 files changed, 181 insertions(+), 4 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index eb30a512c3..038072bcab 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -878,6 +878,7 @@ class PyExpr: def mean(self) -> PyExpr: ... def min(self) -> PyExpr: ... def max(self) -> PyExpr: ... + def any_value(self, ignore_nulls: bool) -> PyExpr: ... def agg_list(self) -> PyExpr: ... def agg_concat(self) -> PyExpr: ... def explode(self) -> PyExpr: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 32f8bb726e..69994a0a1f 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1546,6 +1546,18 @@ def max(self, *cols: ColumnInputType) -> "DataFrame": return self.df._agg([(c, "max") for c in cols], group_by=self.group_by) + def any_value(self, *cols: ColumnInputType) -> "DataFrame": + """Returns an arbitrary value on this GroupedDataFrame. + Values for each column are not guaranteed to be from the same row. + + Args: + *cols (Union[str, Expression]): columns to get + + Returns: + DataFrame: DataFrame with any values. + """ + return self.df._agg([(c, "any_value") for c in cols], group_by=self.group_by) + def count(self) -> "DataFrame": """Performs grouped count on this GroupedDataFrame. diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 449ea2bb23..76031caaf0 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -352,6 +352,10 @@ def _max(self) -> Expression: expr = self._expr.max() return Expression._from_pyexpr(expr) + def _any_value(self, ignore_nulls=False) -> Expression: + expr = self._expr.any_value(ignore_nulls) + return Expression._from_pyexpr(expr) + def _agg_list(self) -> Expression: expr = self._expr.agg_list() return Expression._from_pyexpr(expr) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 996c52ec64..53116b7e5f 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -150,6 +150,8 @@ def agg( exprs.append(expr._max()) elif op == "mean": exprs.append(expr._mean()) + elif op == "any_value": + exprs.append(expr._any_value()) elif op == "list": exprs.append(expr._agg_list()) elif op == "concat": diff --git a/src/daft-core/src/array/ops/count.rs b/src/daft-core/src/array/ops/count.rs index 84595cf710..04f9fdad1d 100644 --- a/src/daft-core/src/array/ops/count.rs +++ b/src/daft-core/src/array/ops/count.rs @@ -30,7 +30,11 @@ fn grouped_count_arrow_bitmap( None => repeat(0).take(groups.len()).collect(), // None of the values are Null Some(validity) => groups .iter() - .map(|g| g.iter().map(|i| validity.get_bit(*i as usize) as u64).sum()) + .map(|g| { + g.iter() + .map(|i| validity.get_bit(!*i as usize) as u64) + .sum() + }) .collect(), }, } @@ -46,11 +50,11 @@ fn count_arrow_bitmap( CountMode::All => arr_len as u64, CountMode::Valid => match arrow_bitmap { None => arr_len as u64, - Some(validity) => validity.into_iter().map(|b| b as u64).sum(), + Some(validity) => (validity.len() - validity.unset_bits()) as u64, }, CountMode::Null => match arrow_bitmap { None => 0, - Some(validity) => validity.into_iter().map(|b| !b as u64).sum(), + Some(validity) => validity.unset_bits() as u64, }, } } diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index 8f4e82aab9..f8dc7a2e94 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -157,6 +157,7 @@ macro_rules! impl_series_like_for_logical_array { }; Ok($da::new(self.0.field.clone(), data_array).into_series()) } + fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::array::{ops::DaftListAggable, ListArray}; let data_array = match groups { diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index c55295cb5e..f5a6b86fe8 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -2,6 +2,7 @@ use crate::array::ListArray; use crate::count_mode::CountMode; use crate::series::IntoSeries; use crate::{array::ops::GroupIndices, series::Series, with_match_physical_daft_types}; +use arrow2::array::PrimitiveArray; use common_error::{DaftError, DaftResult}; use crate::datatypes::*; @@ -97,6 +98,42 @@ impl Series { self.inner.max(groups) } + pub fn any_value( + &self, + groups: Option<&GroupIndices>, + ignore_nulls: bool, + ) -> DaftResult { + let indices = match groups { + Some(groups) => { + if self.data_type().is_null() { + Box::new(PrimitiveArray::new_null(arrow2::datatypes::DataType::UInt64, groups.len())) + } else if ignore_nulls && let Some(validity) = self.validity() { + Box::new(PrimitiveArray::from_trusted_len_iter(groups.iter().map(|g| { + g.iter().find(|i| validity.get_bit(**i as usize)).copied() + }))) + } else { + Box::new(PrimitiveArray::from_trusted_len_iter(groups.iter().map(|g| g.first().cloned()))) + } + }, + None => { + let idx = if self.data_type().is_null() || self.is_empty(){ + None + } else if ignore_nulls && let Some(validity) = self.validity() { + validity.iter().position(|v| v).map(|i| i as u64) + } else { + Some(0) + }; + + Box::new(PrimitiveArray::from([idx])) + } + }; + + self.take(&Series::from_arrow( + Field::new("", DataType::UInt64).into(), + indices, + )?) + } + pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { self.inner.agg_list(groups) } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index a29a2b2871..d66ce7df55 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -58,6 +58,7 @@ pub enum AggExpr { Mean(ExprRef), Min(ExprRef), Max(ExprRef), + AnyValue(ExprRef, bool), List(ExprRef), Concat(ExprRef), MapGroups { @@ -87,6 +88,7 @@ impl AggExpr { | Mean(expr) | Min(expr) | Max(expr) + | AnyValue(expr, _) | List(expr) | Concat(expr) => expr.name(), MapGroups { func: _, inputs } => inputs.first().unwrap().name(), @@ -116,6 +118,12 @@ impl AggExpr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_max()")) } + AnyValue(expr, ignore_nulls) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!( + "{child_id}.local_any_value(ignore_nulls={ignore_nulls})" + )) + } List(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_list()")) @@ -136,6 +144,7 @@ impl AggExpr { | Mean(expr) | Min(expr) | Max(expr) + | AnyValue(expr, _) | List(expr) | Concat(expr) => vec![expr.clone()], MapGroups { func: _, inputs } => inputs.iter().map(|e| e.clone().into()).collect(), @@ -196,7 +205,7 @@ impl AggExpr { }, )) } - Min(expr) | Max(expr) => { + Min(expr) | Max(expr) | AnyValue(expr, _) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), field.dtype)) } @@ -279,6 +288,10 @@ impl Expr { Expr::Agg(AggExpr::Max(self.clone().into())) } + pub fn any_value(&self, ignore_nulls: bool) -> Self { + Expr::Agg(AggExpr::AnyValue(self.clone().into(), ignore_nulls)) + } + pub fn agg_list(&self) -> Self { Expr::Agg(AggExpr::List(self.clone().into())) } @@ -604,6 +617,9 @@ impl Display for AggExpr { Mean(expr) => write!(f, "mean({expr})"), Min(expr) => write!(f, "min({expr})"), Max(expr) => write!(f, "max({expr})"), + AnyValue(expr, ignore_nulls) => { + write!(f, "any_value({expr}, ignore_nulls={ignore_nulls})") + } List(expr) => write!(f, "list({expr})"), Concat(expr) => write!(f, "list({expr})"), MapGroups { func, inputs } => function_display(f, func, inputs), diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 0628b0b47c..a79beb9aef 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -184,6 +184,10 @@ impl PyExpr { Ok(self.expr.max().into()) } + pub fn any_value(&self, ignore_nulls: bool) -> PyResult { + Ok(self.expr.any_value(ignore_nulls).into()) + } + pub fn agg_list(&self) -> PyResult { Ok(self.expr.agg_list().into()) } diff --git a/src/daft-dsl/src/treenode.rs b/src/daft-dsl/src/treenode.rs index 81452bcdae..3f2d05453b 100644 --- a/src/daft-dsl/src/treenode.rs +++ b/src/daft-dsl/src/treenode.rs @@ -21,6 +21,7 @@ impl TreeNode for Expr { | Mean(expr) | Min(expr) | Max(expr) + | AnyValue(expr, _) | List(expr) | Concat(expr) => vec![expr.as_ref()], MapGroups { func: _, inputs } => inputs.iter().collect::>(), @@ -65,6 +66,9 @@ impl TreeNode for Expr { Mean(expr) => transform(expr.as_ref().clone())?.mean(), Min(expr) => transform(expr.as_ref().clone())?.min(), Max(expr) => transform(expr.as_ref().clone())?.max(), + AnyValue(expr, ignore_nulls) => { + transform(expr.as_ref().clone())?.any_value(ignore_nulls) + } List(expr) => transform(expr.as_ref().clone())?.agg_list(), Concat(expr) => transform(expr.as_ref().clone())?.agg_concat(), MapGroups { func, inputs } => Expr::Agg(MapGroups { diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 50ffb2fe78..124ff20f82 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -370,6 +370,12 @@ fn replace_column_with_semantic_id_aggexpr( replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::Max, |_| e.clone()) } + AggExpr::AnyValue(ref child, ignore_nulls) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( + |transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls), + |_| e.clone(), + ) + } AggExpr::List(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no(AggExpr::List, |_| e.clone()) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 6a507a72b7..fd7f212229 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -397,6 +397,23 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe .into())); final_exprs.push(Column(max_of_max_id.clone()).alias(output_name)); } + AnyValue(e, ignore_nulls) => { + let any_id = agg_expr.semantic_id(&schema).id; + let any_of_any_id = + AnyValue(Column(any_id.clone()).into(), *ignore_nulls) + .semantic_id(&schema) + .id; + first_stage_aggs.entry(any_id.clone()).or_insert(AnyValue( + e.alias(any_id.clone()).clone().into(), + *ignore_nulls, + )); + second_stage_aggs + .entry(any_of_any_id.clone()) + .or_insert(AnyValue( + Column(any_id.clone()).alias(any_of_any_id.clone()).into(), + *ignore_nulls, + )); + } List(e) => { let list_id = agg_expr.semantic_id(&schema).id; let concat_of_list_id = Concat(Column(list_id.clone()).into()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 0dcd16544c..e3ef21f3dd 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -325,6 +325,9 @@ impl Table { Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups), Min(expr) => Series::min(&self.eval_expression(expr)?, groups), Max(expr) => Series::max(&self.eval_expression(expr)?, groups), + AnyValue(expr, ignore_nulls) => { + Series::any_value(&self.eval_expression(expr)?, groups, *ignore_nulls) + } List(expr) => Series::agg_list(&self.eval_expression(expr)?, groups), Concat(expr) => Series::agg_concat(&self.eval_expression(expr)?, groups), MapGroups { .. } => Err(DaftError::ValueError( diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index b10bd28534..5671531e5d 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -232,6 +232,72 @@ def test_table_sum_badtype() -> None: daft_table = daft_table.eval_expression_list([col("a")._sum()]) +test_micropartition_any_value_cases = [ + ( + MicroPartition.from_pydict({"a": [None, 1, None, None], "b": ["a", "a", "b", "b"]}), # 1 table + {"a": True, "b": True}, + {"a": False, "b": True}, + ), + ( + MicroPartition.concat( + [ + MicroPartition.from_pydict({"a": np.array([]).astype(np.int64), "b": pa.array([], type=pa.string())}), + MicroPartition.from_pydict({"a": [None, 3, None], "b": ["a", "b", "b"]}), + ] + ), # 2 tables + {"a": True, "b": True}, + {"a": True, "b": False}, + ), + ( + MicroPartition.concat( + [ + MicroPartition.from_pydict({"a": np.array([]).astype(np.int64), "b": pa.array([], type=pa.string())}), + MicroPartition.from_pydict({"a": [None, 3, None], "b": ["a", "b", "b"]}), + MicroPartition.from_pydict({"a": [1], "b": ["a"]}), + ] + ), # 3 tables + {"a": True, "b": True}, + {"a": False, "b": False}, + ), +] + + +@pytest.mark.parametrize("mp,expected_nulls,expected_no_nulls", test_micropartition_any_value_cases) +def test_micropartition_any_value(mp, expected_nulls, expected_no_nulls): + any_values = mp.agg([col("a")._any_value(False)], group_by=[col("b")]).to_pydict() + assert len(any_values["b"]) == len(expected_nulls) + for k, v in zip(any_values["b"], any_values["a"]): + assert expected_nulls[k] or v is not None + + any_values = mp.agg([col("a")._any_value(True)], group_by=[col("b")]).to_pydict() + assert len(any_values["b"]) == len(expected_no_nulls) + for k, v in zip(any_values["b"], any_values["a"]): + assert expected_no_nulls[k] or v is not None + + +test_table_any_value_cases = [ + ({"a": [1], "b": ["a"]}, {"a": False}, {"a": False}), + ({"a": [None], "b": ["a"]}, {"a": True}, {"a": True}), + ({"a": [None, 1], "b": ["a", "a"]}, {"a": True}, {"a": False}), + ({"a": [1, None, 2], "b": ["a", "b", "b"]}, {"a": False, "b": True}, {"a": False, "b": False}), +] + + +@pytest.mark.parametrize("case,expected_nulls,expected_no_nulls", test_table_any_value_cases) +def test_table_any_value(case, expected_nulls, expected_no_nulls): + daft_table = MicroPartition.from_pydict(case) + + any_values = daft_table.agg([col("a")._any_value(False)], group_by=[col("b")]).to_pydict() + assert len(any_values["b"]) == len(expected_nulls) + for k, v in zip(any_values["b"], any_values["a"]): + assert expected_nulls[k] or v is not None + + any_values = daft_table.agg([col("a")._any_value(True)], group_by=[col("b")]).to_pydict() + assert len(any_values["b"]) == len(expected_no_nulls) + for k, v in zip(any_values["b"], any_values["a"]): + assert expected_no_nulls[k] or v is not None + + test_table_agg_global_cases = [ ( [],