Skip to content

Commit

Permalink
[FEAT] any_value groupby aggregation (#1941)
Browse files Browse the repository at this point in the history
This function is parameterized by `ignore_nulls`, which attempts to find
a non-null value in each group when true. However, usage of this
parameter in the aggregation function would require some changes to
`DataFrame._agg()` that I am going to save for later, since these
changes will probably not be needed anymore once global expressions can
be passed into GroupBy operations.

Also in this PR: fixes to the `count` aggregation function
  • Loading branch information
kevinzwang authored Feb 29, 2024
1 parent dff4933 commit 73a71e6
Show file tree
Hide file tree
Showing 14 changed files with 181 additions and 4 deletions.
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ class PyExpr:
def mean(self) -> PyExpr: ...
def min(self) -> PyExpr: ...
def max(self) -> PyExpr: ...
def any_value(self, ignore_nulls: bool) -> PyExpr: ...
def agg_list(self) -> PyExpr: ...
def agg_concat(self) -> PyExpr: ...
def explode(self) -> PyExpr: ...
Expand Down
12 changes: 12 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,18 @@ def max(self, *cols: ColumnInputType) -> "DataFrame":

return self.df._agg([(c, "max") for c in cols], group_by=self.group_by)

def any_value(self, *cols: ColumnInputType) -> "DataFrame":
    """Picks one arbitrary value per group for each requested column.

    The values chosen for different columns are not guaranteed to come from
    the same row.

    Args:
        *cols (Union[str, Expression]): columns to get

    Returns:
        DataFrame: DataFrame with any values.
    """
    aggregations = [(column, "any_value") for column in cols]
    return self.df._agg(aggregations, group_by=self.group_by)

def count(self) -> "DataFrame":
"""Performs grouped count on this GroupedDataFrame.
Expand Down
4 changes: 4 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,10 @@ def _max(self) -> Expression:
expr = self._expr.max()
return Expression._from_pyexpr(expr)

def _any_value(self, ignore_nulls=False) -> Expression:
    """Builds an ``any_value`` aggregation expression over this expression.

    Args:
        ignore_nulls: forwarded unchanged to the underlying ``PyExpr.any_value``.

    Returns:
        Expression: the aggregation wrapped as a Python-level Expression.
    """
    return Expression._from_pyexpr(self._expr.any_value(ignore_nulls))

def _agg_list(self) -> Expression:
expr = self._expr.agg_list()
return Expression._from_pyexpr(expr)
Expand Down
2 changes: 2 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def agg(
exprs.append(expr._max())
elif op == "mean":
exprs.append(expr._mean())
elif op == "any_value":
exprs.append(expr._any_value())
elif op == "list":
exprs.append(expr._agg_list())
elif op == "concat":
Expand Down
10 changes: 7 additions & 3 deletions src/daft-core/src/array/ops/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ fn grouped_count_arrow_bitmap(
None => repeat(0).take(groups.len()).collect(), // None of the values are Null
Some(validity) => groups
.iter()
.map(|g| g.iter().map(|i| validity.get_bit(*i as usize) as u64).sum())
.map(|g| {
    // Count the null rows of this group: a cleared validity bit marks a null.
    // The negation must apply to the looked-up bit, never to the row index
    // (`!*i` would bitwise-invert the index and read far out of bounds).
    g.iter()
        .map(|i| !validity.get_bit(*i as usize) as u64)
        .sum()
})
.collect(),
},
}
Expand All @@ -46,11 +50,11 @@ fn count_arrow_bitmap(
CountMode::All => arr_len as u64,
CountMode::Valid => match arrow_bitmap {
None => arr_len as u64,
Some(validity) => validity.into_iter().map(|b| b as u64).sum(),
Some(validity) => (validity.len() - validity.unset_bits()) as u64,
},
CountMode::Null => match arrow_bitmap {
None => 0,
Some(validity) => validity.into_iter().map(|b| !b as u64).sum(),
Some(validity) => validity.unset_bits() as u64,
},
}
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/array_impl/logical_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ macro_rules! impl_series_like_for_logical_array {
};
Ok($da::new(self.0.field.clone(), data_array).into_series())
}

fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
use crate::array::{ops::DaftListAggable, ListArray};
let data_array = match groups {
Expand Down
37 changes: 37 additions & 0 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::array::ListArray;
use crate::count_mode::CountMode;
use crate::series::IntoSeries;
use crate::{array::ops::GroupIndices, series::Series, with_match_physical_daft_types};
use arrow2::array::PrimitiveArray;
use common_error::{DaftError, DaftResult};

use crate::datatypes::*;
Expand Down Expand Up @@ -97,6 +98,42 @@ impl Series {
self.inner.max(groups)
}

/// Selects one arbitrary element per group, or a single element for the
/// whole series when `groups` is `None`.
///
/// The element is chosen by building an array of optional row indices and
/// calling `take` with it:
/// * a null-typed series yields an all-null index array (one slot per group,
///   or a single null slot without groups);
/// * when `ignore_nulls` is set and a validity bitmap exists, the first row
///   whose validity bit is set is chosen (no index if there is none);
/// * otherwise the first row of each group is chosen (row 0 without groups,
///   or no index when the series is empty).
pub fn any_value(
    &self,
    groups: Option<&GroupIndices>,
    ignore_nulls: bool,
) -> DaftResult<Series> {
    let null_typed = self.data_type().is_null();

    let indices = if let Some(groups) = groups {
        if null_typed {
            Box::new(PrimitiveArray::new_null(
                arrow2::datatypes::DataType::UInt64,
                groups.len(),
            ))
        } else {
            // The bitmap is only consulted when the caller asked to skip nulls.
            match ignore_nulls.then(|| self.validity()).flatten() {
                Some(validity) => Box::new(PrimitiveArray::from_trusted_len_iter(
                    groups.iter().map(|group| {
                        group
                            .iter()
                            .copied()
                            .find(|&row| validity.get_bit(row as usize))
                    }),
                )),
                None => Box::new(PrimitiveArray::from_trusted_len_iter(
                    groups.iter().map(|group| group.first().copied()),
                )),
            }
        }
    } else {
        let picked = if null_typed || self.is_empty() {
            None
        } else {
            match ignore_nulls.then(|| self.validity()).flatten() {
                Some(validity) => validity
                    .iter()
                    .position(|valid| valid)
                    .map(|pos| pos as u64),
                None => Some(0),
            }
        };

        Box::new(PrimitiveArray::from([picked]))
    };

    let index_series = Series::from_arrow(Field::new("", DataType::UInt64).into(), indices)?;
    self.take(&index_series)
}

pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
self.inner.agg_list(groups)
}
Expand Down
18 changes: 17 additions & 1 deletion src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub enum AggExpr {
Mean(ExprRef),
Min(ExprRef),
Max(ExprRef),
AnyValue(ExprRef, bool),
List(ExprRef),
Concat(ExprRef),
MapGroups {
Expand Down Expand Up @@ -87,6 +88,7 @@ impl AggExpr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => expr.name(),
MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
Expand Down Expand Up @@ -116,6 +118,12 @@ impl AggExpr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_max()"))
}
AnyValue(expr, ignore_nulls) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!(
"{child_id}.local_any_value(ignore_nulls={ignore_nulls})"
))
}
List(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_list()"))
Expand All @@ -136,6 +144,7 @@ impl AggExpr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => vec![expr.clone()],
MapGroups { func: _, inputs } => inputs.iter().map(|e| e.clone().into()).collect(),
Expand Down Expand Up @@ -196,7 +205,7 @@ impl AggExpr {
},
))
}
Min(expr) | Max(expr) => {
Min(expr) | Max(expr) | AnyValue(expr, _) => {
let field = expr.to_field(schema)?;
Ok(Field::new(field.name.as_str(), field.dtype))
}
Expand Down Expand Up @@ -279,6 +288,10 @@ impl Expr {
Expr::Agg(AggExpr::Max(self.clone().into()))
}

/// Wraps a clone of this expression in an `AggExpr::AnyValue` aggregation,
/// carrying the `ignore_nulls` flag along unchanged.
pub fn any_value(&self, ignore_nulls: bool) -> Self {
    let child = self.clone().into();
    Expr::Agg(AggExpr::AnyValue(child, ignore_nulls))
}

pub fn agg_list(&self) -> Self {
Expr::Agg(AggExpr::List(self.clone().into()))
}
Expand Down Expand Up @@ -604,6 +617,9 @@ impl Display for AggExpr {
Mean(expr) => write!(f, "mean({expr})"),
Min(expr) => write!(f, "min({expr})"),
Max(expr) => write!(f, "max({expr})"),
AnyValue(expr, ignore_nulls) => {
write!(f, "any_value({expr}, ignore_nulls={ignore_nulls})")
}
List(expr) => write!(f, "list({expr})"),
Concat(expr) => write!(f, "list({expr})"),
MapGroups { func, inputs } => function_display(f, func, inputs),
Expand Down
4 changes: 4 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ impl PyExpr {
Ok(self.expr.max().into())
}

/// Python-facing wrapper: builds the `any_value` aggregation on the wrapped
/// expression and converts the result back into a `PyExpr`.
pub fn any_value(&self, ignore_nulls: bool) -> PyResult<Self> {
    let aggregated = self.expr.any_value(ignore_nulls);
    Ok(aggregated.into())
}

pub fn agg_list(&self) -> PyResult<Self> {
Ok(self.expr.agg_list().into())
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-dsl/src/treenode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl TreeNode for Expr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => vec![expr.as_ref()],
MapGroups { func: _, inputs } => inputs.iter().collect::<Vec<_>>(),
Expand Down Expand Up @@ -65,6 +66,9 @@ impl TreeNode for Expr {
Mean(expr) => transform(expr.as_ref().clone())?.mean(),
Min(expr) => transform(expr.as_ref().clone())?.min(),
Max(expr) => transform(expr.as_ref().clone())?.max(),
AnyValue(expr, ignore_nulls) => {
transform(expr.as_ref().clone())?.any_value(ignore_nulls)
}
List(expr) => transform(expr.as_ref().clone())?.agg_list(),
Concat(expr) => transform(expr.as_ref().clone())?.agg_concat(),
MapGroups { func, inputs } => Expr::Agg(MapGroups {
Expand Down
6 changes: 6 additions & 0 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,12 @@ fn replace_column_with_semantic_id_aggexpr(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Max, |_| e.clone())
}
AggExpr::AnyValue(ref child, ignore_nulls) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls),
|_| e.clone(),
)
}
AggExpr::List(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::List, |_| e.clone())
Expand Down
17 changes: 17 additions & 0 deletions src/daft-plan/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,23 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc<DaftExecutionConfig>) -> DaftRe
.into()));
final_exprs.push(Column(max_of_max_id.clone()).alias(output_name));
}
AnyValue(e, ignore_nulls) => {
    // Two-stage plan: take any value within each partition first, then take
    // any value among those partial results.
    let any_id = agg_expr.semantic_id(&schema).id;
    let any_of_any_id =
        AnyValue(Column(any_id.clone()).into(), *ignore_nulls)
            .semantic_id(&schema)
            .id;
    first_stage_aggs.entry(any_id.clone()).or_insert(AnyValue(
        e.alias(any_id.clone()).clone().into(),
        *ignore_nulls,
    ));
    second_stage_aggs
        .entry(any_of_any_id.clone())
        .or_insert(AnyValue(
            Column(any_id.clone()).alias(any_of_any_id.clone()).into(),
            *ignore_nulls,
        ));
    // Expose the second-stage result under the user-visible output name,
    // as the sibling aggregation arms do; without this the aggregated
    // column never reaches the final projection.
    final_exprs.push(Column(any_of_any_id.clone()).alias(output_name));
}
List(e) => {
let list_id = agg_expr.semantic_id(&schema).id;
let concat_of_list_id = Concat(Column(list_id.clone()).into())
Expand Down
3 changes: 3 additions & 0 deletions src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ impl Table {
Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups),
Min(expr) => Series::min(&self.eval_expression(expr)?, groups),
Max(expr) => Series::max(&self.eval_expression(expr)?, groups),
AnyValue(expr, ignore_nulls) => {
Series::any_value(&self.eval_expression(expr)?, groups, *ignore_nulls)
}
List(expr) => Series::agg_list(&self.eval_expression(expr)?, groups),
Concat(expr) => Series::agg_concat(&self.eval_expression(expr)?, groups),
MapGroups { .. } => Err(DaftError::ValueError(
Expand Down
66 changes: 66 additions & 0 deletions tests/table/test_table_aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,72 @@ def test_table_sum_badtype() -> None:
daft_table = daft_table.eval_expression_list([col("a")._sum()])


# Parametrized cases for test_micropartition_any_value.
# Each tuple is (micropartition, expected_nulls, expected_no_nulls):
#   * expected_nulls maps each group key in "b" to whether a null result for "a"
#     is acceptable when any_value is called with ignore_nulls=False;
#   * expected_no_nulls is the same mapping for ignore_nulls=True.
# A value of False means the result for that group must not be None.
test_micropartition_any_value_cases = [
(
MicroPartition.from_pydict({"a": [None, 1, None, None], "b": ["a", "a", "b", "b"]}), # 1 table
{"a": True, "b": True},
{"a": False, "b": True},
),
(
MicroPartition.concat(
[
MicroPartition.from_pydict({"a": np.array([]).astype(np.int64), "b": pa.array([], type=pa.string())}),
MicroPartition.from_pydict({"a": [None, 3, None], "b": ["a", "b", "b"]}),
]
), # 2 tables
{"a": True, "b": True},
{"a": True, "b": False},
),
(
MicroPartition.concat(
[
MicroPartition.from_pydict({"a": np.array([]).astype(np.int64), "b": pa.array([], type=pa.string())}),
MicroPartition.from_pydict({"a": [None, 3, None], "b": ["a", "b", "b"]}),
MicroPartition.from_pydict({"a": [1], "b": ["a"]}),
]
), # 3 tables
{"a": True, "b": True},
{"a": False, "b": False},
),
]


@pytest.mark.parametrize("mp,expected_nulls,expected_no_nulls", test_micropartition_any_value_cases)
def test_micropartition_any_value(mp, expected_nulls, expected_no_nulls):
    """Checks grouped any_value on a MicroPartition, first keeping and then ignoring nulls.

    For every group key, the result may only be None when the matching
    expectation mapping marks a null as acceptable for that key.
    """
    for ignore_nulls, null_allowed in ((False, expected_nulls), (True, expected_no_nulls)):
        result = mp.agg([col("a")._any_value(ignore_nulls)], group_by=[col("b")]).to_pydict()
        assert len(result["b"]) == len(null_allowed)
        for key, value in zip(result["b"], result["a"]):
            assert null_allowed[key] or value is not None


# Parametrized cases for test_table_any_value.
# Each tuple is (input pydict, expected_nulls, expected_no_nulls): the two mappings
# say, per group key in "b", whether a None result for "a" is acceptable with
# ignore_nulls=False and ignore_nulls=True respectively.
test_table_any_value_cases = [
({"a": [1], "b": ["a"]}, {"a": False}, {"a": False}),
({"a": [None], "b": ["a"]}, {"a": True}, {"a": True}),
({"a": [None, 1], "b": ["a", "a"]}, {"a": True}, {"a": False}),
({"a": [1, None, 2], "b": ["a", "b", "b"]}, {"a": False, "b": True}, {"a": False, "b": False}),
]


@pytest.mark.parametrize("case,expected_nulls,expected_no_nulls", test_table_any_value_cases)
def test_table_any_value(case, expected_nulls, expected_no_nulls):
    """Checks grouped any_value on a single table built from a pydict.

    Runs the aggregation with ignore_nulls=False and then ignore_nulls=True;
    a group's result may only be None when the matching expectation allows it.
    """
    daft_table = MicroPartition.from_pydict(case)

    for ignore_nulls, null_allowed in ((False, expected_nulls), (True, expected_no_nulls)):
        result = daft_table.agg([col("a")._any_value(ignore_nulls)], group_by=[col("b")]).to_pydict()
        assert len(result["b"]) == len(null_allowed)
        for key, value in zip(result["b"], result["a"]):
            assert null_allowed[key] or value is not None


test_table_agg_global_cases = [
(
[],
Expand Down

0 comments on commit 73a71e6

Please sign in to comment.