From 93fc6ca0700032a1ee4426f7a0c48c0c0bf5ee3e Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Wed, 10 Apr 2024 17:35:24 -0700 Subject: [PATCH] [FEAT] Add basic list aggregations (#2032) --- daft/daft.pyi | 8 +- daft/expressions/expressions.py | 45 +++++- daft/series.py | 2 +- .../src/array/fixed_size_list_array.rs | 41 ++++++ src/daft-core/src/array/list_array.rs | 39 ++++++ src/daft-core/src/array/ops/count.rs | 2 +- src/daft-core/src/array/ops/list.rs | 131 ++++++++++++++---- src/daft-core/src/datatypes/agg_ops.rs | 31 +++++ src/daft-core/src/datatypes/mod.rs | 2 + src/daft-core/src/python/series.rs | 4 +- src/daft-core/src/series/ops/list.rs | 55 +++++++- src/daft-dsl/src/expr.rs | 40 +----- src/daft-dsl/src/functions/list/count.rs | 65 +++++++++ src/daft-dsl/src/functions/list/max.rs | 45 ++++++ src/daft-dsl/src/functions/list/mean.rs | 44 ++++++ src/daft-dsl/src/functions/list/min.rs | 45 ++++++ src/daft-dsl/src/functions/list/mod.rs | 57 +++++++- .../src/functions/list/{lengths.rs => sum.rs} | 27 ++-- src/daft-dsl/src/python.rs | 26 +++- src/daft-table/src/ops/explode.rs | 8 +- tests/table/list/test_list_count_lengths.py | 52 +++++++ tests/table/list/test_list_lengths.py | 10 -- tests/table/list/test_list_numeric_aggs.py | 35 +++++ 23 files changed, 700 insertions(+), 114 deletions(-) create mode 100644 src/daft-core/src/datatypes/agg_ops.rs create mode 100644 src/daft-dsl/src/functions/list/count.rs create mode 100644 src/daft-dsl/src/functions/list/max.rs create mode 100644 src/daft-dsl/src/functions/list/mean.rs create mode 100644 src/daft-dsl/src/functions/list/min.rs rename src/daft-dsl/src/functions/list/{lengths.rs => sum.rs} (51%) create mode 100644 tests/table/list/test_list_count_lengths.py delete mode 100644 tests/table/list/test_list_lengths.py create mode 100644 tests/table/list/test_list_numeric_aggs.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 4ff5dc008d..aa51d38dc7 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -954,8 +954,12 @@ class 
PyExpr: def image_resize(self, w: int, h: int) -> PyExpr: ... def image_crop(self, bbox: PyExpr) -> PyExpr: ... def list_join(self, delimiter: PyExpr) -> PyExpr: ... - def list_lengths(self) -> PyExpr: ... + def list_count(self, mode: CountMode) -> PyExpr: ... def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ... + def list_sum(self) -> PyExpr: ... + def list_mean(self) -> PyExpr: ... + def list_min(self) -> PyExpr: ... + def list_max(self) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... def url_download( self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig @@ -1056,7 +1060,7 @@ class PySeries: def partitioning_years(self) -> PySeries: ... def partitioning_iceberg_bucket(self, n: int) -> PySeries: ... def partitioning_iceberg_truncate(self, w: int) -> PySeries: ... - def list_lengths(self) -> PySeries: ... + def list_count(self, mode: CountMode) -> PySeries: ... def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ... def image_decode(self, raise_error_on_failure: bool) -> PySeries: ... def image_encode(self, image_format: ImageFormat) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 5309e01d4e..081fb453da 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1116,13 +1116,24 @@ def join(self, delimiter: str | Expression) -> Expression: delimiter_expr = Expression._to_expression(delimiter) return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr)) + def count(self, mode: CountMode = CountMode.Valid) -> Expression: + """Counts the number of elements in each list + + Args: + mode: The mode to use for counting. 
Defaults to CountMode.Valid + + Returns: + Expression: a UInt64 expression which is the number of counted elements in each list + """ + return Expression._from_pyexpr(self._expr.list_count(mode)) + def lengths(self) -> Expression: """Gets the length of each list Returns: Expression: a UInt64 expression which is the length of each list """ - return Expression._from_pyexpr(self._expr.list_lengths()) + return Expression._from_pyexpr(self._expr.list_count(CountMode.All)) def get(self, idx: int | Expression, default: object = None) -> Expression: """Gets the element at an index in each list @@ -1138,6 +1149,38 @@ def get(self, idx: int | Expression, default: object = None) -> Expression: default_expr = lit(default) return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr)) + def sum(self) -> Expression: + """Sums each list. Empty lists and lists with all nulls yield null. + + Returns: + Expression: an expression with the sum supertype of the list values (Int64/UInt64 for integers, unchanged for floats) + """ + return Expression._from_pyexpr(self._expr.list_sum()) + + def mean(self) -> Expression: + """Calculates the mean of each list. If no non-null values in a list, the result is null. + + Returns: + Expression: a Float64 expression which is the mean of each list + """ + return Expression._from_pyexpr(self._expr.list_mean()) + + def min(self) -> Expression: + """Calculates the minimum of each list. If no non-null values in a list, the result is null. + + Returns: + Expression: an expression with the type of the list values + """ + return Expression._from_pyexpr(self._expr.list_min()) + + def max(self) -> Expression: + """Calculates the maximum of each list. If no non-null values in a list, the result is null. 
+ + Returns: + Expression: an expression with the type of the list values + """ + return Expression._from_pyexpr(self._expr.list_max()) + class ExpressionStructNamespace(ExpressionNamespace): def get(self, name: str) -> Expression: diff --git a/daft/series.py b/daft/series.py index 7ab9fe1051..884980b6c1 100644 --- a/daft/series.py +++ b/daft/series.py @@ -716,7 +716,7 @@ def iceberg_truncate(self, w: int) -> Series: class SeriesListNamespace(SeriesNamespace): def lengths(self) -> Series: - return Series._from_pyseries(self._series.list_lengths()) + return Series._from_pyseries(self._series.list_count(CountMode.All)) def get(self, idx: Series, default: Series) -> Series: return Series._from_pyseries(self._series.list_get(idx._series, default._series)) diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index a8faba1a19..2784f5746c 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -166,6 +166,47 @@ impl FixedSizeListArray { } } +impl<'a> IntoIterator for &'a FixedSizeListArray { + type Item = Option; + + type IntoIter = FixedSizeListArrayIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FixedSizeListArrayIter { + array: self, + idx: 0, + } + } +} + +pub struct FixedSizeListArrayIter<'a> { + array: &'a FixedSizeListArray, + idx: usize, +} + +impl Iterator for FixedSizeListArrayIter<'_> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.idx < self.array.len() { + if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) { + self.idx += 1; + Some(None) + } else { + let step = self.array.fixed_element_len(); + + let start = self.idx * step; + let end = (self.idx + 1) * step; + + self.idx += 1; + Some(Some(self.array.flat_child.slice(start, end).unwrap())) + } + } else { + None + } + } +} + #[cfg(test)] mod tests { use common_error::DaftResult; diff --git a/src/daft-core/src/array/list_array.rs 
b/src/daft-core/src/array/list_array.rs index 766f6d5d6b..abe879b7bd 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -178,3 +178,42 @@ impl ListArray { )) } } + +impl<'a> IntoIterator for &'a ListArray { + type Item = Option; + + type IntoIter = ListArrayIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + ListArrayIter { + array: self, + idx: 0, + } + } +} + +pub struct ListArrayIter<'a> { + array: &'a ListArray, + idx: usize, +} + +impl Iterator for ListArrayIter<'_> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.idx < self.array.len() { + if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) { + self.idx += 1; + Some(None) + } else { + let start = *self.array.offsets().get(self.idx).unwrap() as usize; + let end = *self.array.offsets().get(self.idx + 1).unwrap() as usize; + + self.idx += 1; + Some(Some(self.array.flat_child.slice(start, end).unwrap())) + } + } else { + None + } + } +} diff --git a/src/daft-core/src/array/ops/count.rs b/src/daft-core/src/array/ops/count.rs index 04f9fdad1d..e71ff09e71 100644 --- a/src/daft-core/src/array/ops/count.rs +++ b/src/daft-core/src/array/ops/count.rs @@ -32,7 +32,7 @@ fn grouped_count_arrow_bitmap( .iter() .map(|g| { g.iter() - .map(|i| validity.get_bit(!*i as usize) as u64) + .map(|i| !validity.get_bit(*i as usize) as u64) .sum() }) .collect(), diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 20cdd6cead..bc93cc642a 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,11 +1,14 @@ use std::iter::repeat; -use crate::array::{ - growable::{make_growable, Growable}, - FixedSizeListArray, ListArray, +use crate::datatypes::{Int64Array, Utf8Array}; +use crate::{ + array::{ + growable::{make_growable, Growable}, + FixedSizeListArray, ListArray, + }, + datatypes::UInt64Array, }; -use crate::datatypes::{Int64Array, UInt64Array, Utf8Array}; -use 
crate::DataType; +use crate::{CountMode, DataType}; use crate::series::Series; @@ -42,11 +45,34 @@ fn join_arrow_list_of_utf8s( } impl ListArray { - pub fn lengths(&self) -> DaftResult { - let lengths = self.offsets().lengths().map(|l| Some(l as u64)); + pub fn count(&self, mode: CountMode) -> DaftResult { + let counts = match (mode, self.flat_child.validity()) { + (CountMode::All, _) | (CountMode::Valid, None) => { + self.offsets().lengths().map(|l| l as u64).collect() + } + (CountMode::Valid, Some(validity)) => self + .offsets() + .windows(2) + .map(|w| { + (w[0]..w[1]) + .map(|i| validity.get_bit(i as usize) as u64) + .sum() + }) + .collect(), + (CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(), + (CountMode::Null, Some(validity)) => self + .offsets() + .windows(2) + .map(|w| { + (w[0]..w[1]) + .map(|i| !validity.get_bit(i as usize) as u64) + .sum() + }) + .collect(), + }; + let array = Box::new( - arrow2::array::PrimitiveArray::from_iter(lengths) - .with_validity(self.validity().cloned()), + arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), ); Ok(UInt64Array::from((self.name(), array))) } @@ -172,27 +198,33 @@ impl ListArray { } impl FixedSizeListArray { - pub fn lengths(&self) -> DaftResult { + pub fn count(&self, mode: CountMode) -> DaftResult { let size = self.fixed_element_len(); - match self.validity() { - None => Ok(UInt64Array::from(( - self.name(), - repeat(size as u64) - .take(self.len()) - .collect::>() - .as_slice(), - ))), - Some(validity) => { - let arrow_arr = arrow2::array::UInt64Array::from_iter(validity.iter().map(|v| { - if v { - Some(size as u64) - } else { - None - } - })); - Ok(UInt64Array::from((self.name(), Box::new(arrow_arr)))) + let counts = match (mode, self.flat_child.validity()) { + (CountMode::All, _) | (CountMode::Valid, None) => { + repeat(size as u64).take(self.len()).collect() } - } + (CountMode::Valid, Some(validity)) => (0..self.len()) + .map(|i| { + (0..size) 
+ .map(|j| validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect(), + (CountMode::Null, None) => repeat(0).take(self.len()).collect(), + (CountMode::Null, Some(validity)) => (0..self.len()) + .map(|i| { + (0..size) + .map(|j| !validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect(), + }; + + let array = Box::new( + arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), + ); + Ok(UInt64Array::from((self.name(), array))) } pub fn explode(&self) -> DaftResult { @@ -303,3 +335,46 @@ impl FixedSizeListArray { } } } + +macro_rules! impl_aggs_list_array { + ($la:ident) => { + impl $la { + fn agg_helper(&self, op: T) -> DaftResult + where + T: Fn(&Series) -> DaftResult, + { + // TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder + + // Assumes `op` returns a null Series given an empty Series + let aggs = self + .into_iter() + .map(|s| s.unwrap_or(Series::empty("", self.child_data_type()))) + .map(|s| op(&s)) + .collect::>>()?; + + let agg_refs: Vec<_> = aggs.iter().collect(); + + Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) + } + + pub fn sum(&self) -> DaftResult { + self.agg_helper(|s| s.sum(None)) + } + + pub fn mean(&self) -> DaftResult { + self.agg_helper(|s| s.mean(None)) + } + + pub fn min(&self) -> DaftResult { + self.agg_helper(|s| s.min(None)) + } + + pub fn max(&self) -> DaftResult { + self.agg_helper(|s| s.max(None)) + } + } + }; +} + +impl_aggs_list_array!(ListArray); +impl_aggs_list_array!(FixedSizeListArray); diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs new file mode 100644 index 0000000000..48a89968b6 --- /dev/null +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -0,0 +1,31 @@ +use common_error::{DaftError, DaftResult}; + +use super::DataType; + +/// Get the data type that the sum of a column of the given data type 
should be casted to. +pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + match dtype { + Int8 | Int16 | Int32 | Int64 => Ok(Int64), + UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), + Float32 => Ok(Float32), + Float64 => Ok(Float64), + other => Err(DaftError::TypeError(format!( + "Invalid argument to sum supertype: {}", + other + ))), + } +} + +/// Get the data type that the mean of a column of the given data type should be casted to. +pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + if dtype.is_numeric() { + Ok(Float64) + } else { + Err(DaftError::TypeError(format!( + "Invalid argument to mean supertype: {}", + dtype + ))) + } +} diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 68d9c55464..3b937fbbc5 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -1,3 +1,4 @@ +mod agg_ops; mod binary_ops; mod dtype; mod field; @@ -8,6 +9,7 @@ mod time_unit; pub use crate::array::{DataArray, FixedSizeListArray}; use crate::array::{ListArray, StructArray}; +pub use agg_ops::{try_mean_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 2882049f74..c262f36612 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -392,8 +392,8 @@ impl PySeries { Ok(self.series.murmur3_32()?.into_series().into()) } - pub fn list_lengths(&self) -> PyResult { - Ok(self.series.list_lengths()?.into_series().into()) + pub fn list_count(&self, mode: CountMode) -> PyResult { + Ok(self.series.list_count(mode)?.into_series().into()) } pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index c48905fc7f..961e160d6c 100644 --- a/src/daft-core/src/series/ops/list.rs +++ 
b/src/daft-core/src/series/ops/list.rs @@ -1,5 +1,6 @@ use crate::datatypes::{DataType, UInt64Array, Utf8Array}; use crate::series::Series; +use crate::CountMode; use common_error::DaftError; use common_error::DaftResult; @@ -17,13 +18,13 @@ impl Series { } } - pub fn list_lengths(&self) -> DaftResult { + pub fn list_count(&self, mode: CountMode) -> DaftResult { use DataType::*; match self.data_type() { - List(_) => self.list()?.lengths(), - FixedSizeList(..) => self.fixed_size_list()?.lengths(), - Embedding(..) | FixedShapeImage(..) => self.as_physical()?.list_lengths(), + List(_) => self.list()?.count(mode), + FixedSizeList(..) => self.fixed_size_list()?.count(mode), + Embedding(..) | FixedShapeImage(..) => self.as_physical()?.list_count(mode), Image(..) => { let struct_array = self.as_physical()?; let data_array = struct_array.struct_()?.children[0].list().unwrap(); @@ -37,7 +38,7 @@ impl Series { Ok(UInt64Array::from((self.name(), array))) } dt => Err(DaftError::TypeError(format!( - "lengths not implemented for {}", + "Count not implemented for {}", dt ))), } @@ -67,4 +68,48 @@ impl Series { ))), } } + + pub fn list_sum(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.sum(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.sum(), + dt => Err(DaftError::TypeError(format!( + "Sum not implemented for {}", + dt + ))), + } + } + + pub fn list_mean(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.mean(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.mean(), + dt => Err(DaftError::TypeError(format!( + "Mean not implemented for {}", + dt + ))), + } + } + + pub fn list_min(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.min(), + DataType::FixedSizeList(..) 
=> self.fixed_size_list()?.min(), + dt => Err(DaftError::TypeError(format!( + "Min not implemented for {}", + dt + ))), + } + } + + pub fn list_max(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.max(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.max(), + dt => Err(DaftError::TypeError(format!( + "Max not implemented for {}", + dt + ))), + } + } } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index aaca8a6472..1201956dea 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -1,7 +1,6 @@ use daft_core::{ count_mode::CountMode, - datatypes::DataType, - datatypes::{Field, FieldID}, + datatypes::{try_mean_supertype, try_sum_supertype, DataType, Field, FieldID}, schema::Schema, utils::supertype::try_get_supertype, }; @@ -163,47 +162,14 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), - match &field.dtype { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::UInt64, - DataType::Float32 => DataType::Float32, - DataType::Float64 => DataType::Float64, - other => { - return Err(DaftError::TypeError(format!( - "Expected input to sum() to be numeric but received dtype {} for column \"{}\"", - other, field.name, - ))) - } - }, + try_sum_supertype(&field.dtype)?, )) } Mean(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), - match &field.dtype { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => DataType::Float64, - other => { - return Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for column \"{}\" of type {}", - field.name, other, - ))) - } - }, + try_mean_supertype(&field.dtype)?, )) } Min(expr) | 
Max(expr) | AnyValue(expr, _) => { diff --git a/src/daft-dsl/src/functions/list/count.rs b/src/daft-dsl/src/functions/list/count.rs new file mode 100644 index 0000000000..10c818bb24 --- /dev/null +++ b/src/daft-dsl/src/functions/list/count.rs @@ -0,0 +1,65 @@ +use crate::{functions::FunctionExpr, Expr}; +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, + IntoSeries, +}; + +use common_error::{DaftError, DaftResult}; + +use super::{super::FunctionEvaluator, ListExpr}; + +pub(super) struct CountEvaluator {} + +impl FunctionEvaluator for CountEvaluator { + fn fn_name(&self) -> &'static str { + "count" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, expr: &Expr) -> DaftResult { + match inputs { + [input] => { + let input_field = input.to_field(schema)?; + + match input_field.dtype { + DataType::List(_) | DataType::FixedSizeList(_, _) => match expr { + Expr::Function { + func: FunctionExpr::List(ListExpr::Count(_)), + inputs: _, + } => Ok(Field::new(input.name()?, DataType::UInt64)), + _ => panic!("Expected List Count Expr, got {expr}"), + }, + _ => Err(DaftError::TypeError(format!( + "Expected input to be a list type, received: {}", + input_field.dtype + ))), + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], expr: &Expr) -> DaftResult { + match inputs { + [input] => { + let mode = match expr { + Expr::Function { + func: FunctionExpr::List(ListExpr::Count(mode)), + inputs: _, + } => mode, + _ => panic!("Expected List Count Expr, got {expr}"), + }; + + Ok(input.list_count(*mode)?.into_series()) + } + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/max.rs b/src/daft-dsl/src/functions/list/max.rs new file mode 100644 index 0000000000..34d788f4aa --- /dev/null +++ b/src/daft-dsl/src/functions/list/max.rs @@ -0,0 
+1,45 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct MaxEvaluator {} + +impl FunctionEvaluator for MaxEvaluator { + fn fn_name(&self) -> &'static str { + "max" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?.to_exploded_field()?; + + if field.dtype.is_numeric() { + Ok(field) + } else { + Err(DaftError::TypeError(format!( + "Expected input to be numeric, got {}", + field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.list_max()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/mean.rs b/src/daft-dsl/src/functions/list/mean.rs new file mode 100644 index 0000000000..c7409093b3 --- /dev/null +++ b/src/daft-dsl/src/functions/list/mean.rs @@ -0,0 +1,44 @@ +use crate::Expr; +use daft_core::{ + datatypes::{try_mean_supertype, Field}, + schema::Schema, + series::Series, +}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct MeanEvaluator {} + +impl FunctionEvaluator for MeanEvaluator { + fn fn_name(&self) -> &'static str { + "mean" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => { + let inner_field = input.to_field(schema)?.to_exploded_field()?; + Ok(Field::new( + inner_field.name.as_str(), + try_mean_supertype(&inner_field.dtype)?, + )) + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> 
DaftResult { + match inputs { + [input] => Ok(input.list_mean()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/min.rs b/src/daft-dsl/src/functions/list/min.rs new file mode 100644 index 0000000000..a2e1988e3e --- /dev/null +++ b/src/daft-dsl/src/functions/list/min.rs @@ -0,0 +1,45 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct MinEvaluator {} + +impl FunctionEvaluator for MinEvaluator { + fn fn_name(&self) -> &'static str { + "min" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?.to_exploded_field()?; + + if field.dtype.is_numeric() { + Ok(field) + } else { + Err(DaftError::TypeError(format!( + "Expected input to be numeric, got {}", + field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.list_min()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs index 8279dbdb2e..f673406f9d 100644 --- a/src/daft-dsl/src/functions/list/mod.rs +++ b/src/daft-dsl/src/functions/list/mod.rs @@ -1,13 +1,22 @@ +mod count; mod explode; mod get; mod join; -mod lengths; +mod max; +mod mean; +mod min; +mod sum; +use count::CountEvaluator; +use daft_core::CountMode; use explode::ExplodeEvaluator; use get::GetEvaluator; use join::JoinEvaluator; -use lengths::LengthsEvaluator; +use max::MaxEvaluator; +use mean::MeanEvaluator; +use min::MinEvaluator; use serde::{Deserialize, Serialize}; +use 
sum::SumEvaluator; use crate::Expr; @@ -17,8 +26,12 @@ use super::FunctionEvaluator; pub enum ListExpr { Explode, Join, - Lengths, + Count(CountMode), Get, + Sum, + Mean, + Min, + Max, } impl ListExpr { @@ -28,8 +41,12 @@ impl ListExpr { match self { Explode => &ExplodeEvaluator {}, Join => &JoinEvaluator {}, - Lengths => &LengthsEvaluator {}, + Count(_) => &CountEvaluator {}, Get => &GetEvaluator {}, + Sum => &SumEvaluator {}, + Mean => &MeanEvaluator {}, + Min => &MinEvaluator {}, + Max => &MaxEvaluator {}, } } } @@ -48,9 +65,9 @@ pub fn join(input: &Expr, delimiter: &Expr) -> Expr { } } -pub fn lengths(input: &Expr) -> Expr { +pub fn count(input: &Expr, mode: CountMode) -> Expr { Expr::Function { - func: super::FunctionExpr::List(ListExpr::Lengths), + func: super::FunctionExpr::List(ListExpr::Count(mode)), inputs: vec![input.clone()], } } @@ -61,3 +78,31 @@ pub fn get(input: &Expr, idx: &Expr, default: &Expr) -> Expr { inputs: vec![input.clone(), idx.clone(), default.clone()], } } + +pub fn sum(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Sum), + inputs: vec![input.clone()], + } +} + +pub fn mean(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Mean), + inputs: vec![input.clone()], + } +} + +pub fn min(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Min), + inputs: vec![input.clone()], + } +} + +pub fn max(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Max), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/functions/list/lengths.rs b/src/daft-dsl/src/functions/list/sum.rs similarity index 51% rename from src/daft-dsl/src/functions/list/lengths.rs rename to src/daft-dsl/src/functions/list/sum.rs index b7858999bc..88ec0a56cc 100644 --- a/src/daft-dsl/src/functions/list/lengths.rs +++ b/src/daft-dsl/src/functions/list/sum.rs @@ -1,35 +1,30 @@ use crate::Expr; use daft_core::{ - 
datatypes::{DataType, Field}, + datatypes::{try_sum_supertype, Field}, schema::Schema, - series::{IntoSeries, Series}, + series::Series, }; use common_error::{DaftError, DaftResult}; use super::super::FunctionEvaluator; -pub(super) struct LengthsEvaluator {} +pub(super) struct SumEvaluator {} -impl FunctionEvaluator for LengthsEvaluator { +impl FunctionEvaluator for SumEvaluator { fn fn_name(&self) -> &'static str { - "lengths" + "sum" } fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { [input] => { - let input_field = input.to_field(schema)?; + let inner_field = input.to_field(schema)?.to_exploded_field()?; - match input_field.dtype { - DataType::List(_) | DataType::FixedSizeList(_, _) => { - Ok(Field::new(input.name()?, DataType::UInt64)) - } - _ => Err(DaftError::TypeError(format!( - "Expected input to be a list type, received: {}", - input_field.dtype - ))), - } + Ok(Field::new( + inner_field.name.as_str(), + try_sum_supertype(&inner_field.dtype)?, + )) } _ => Err(DaftError::SchemaMismatch(format!( "Expected 1 input arg, got {}", @@ -40,7 +35,7 @@ impl FunctionEvaluator for LengthsEvaluator { fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { match inputs { - [input] => Ok(input.list_lengths()?.into_series()), + [input] => Ok(input.list_sum()?), _ => Err(DaftError::ValueError(format!( "Expected 1 input arg, got {}", inputs.len() diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 4243eb2b60..bc68334b5b 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -517,9 +517,9 @@ impl PyExpr { Ok(join(&self.expr, &delimiter.expr).into()) } - pub fn list_lengths(&self) -> PyResult { - use crate::functions::list::lengths; - Ok(lengths(&self.expr).into()) + pub fn list_count(&self, mode: CountMode) -> PyResult { + use crate::functions::list::count; + Ok(count(&self.expr, mode).into()) } pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { @@ -527,6 +527,26 @@ 
impl PyExpr { Ok(get(&self.expr, &idx.expr, &default.expr).into()) } + pub fn list_sum(&self) -> PyResult { + use crate::functions::list::sum; + Ok(sum(&self.expr).into()) + } + + pub fn list_mean(&self) -> PyResult { + use crate::functions::list::mean; + Ok(mean(&self.expr).into()) + } + + pub fn list_min(&self) -> PyResult { + use crate::functions::list::min; + Ok(min(&self.expr).into()) + } + + pub fn list_max(&self) -> PyResult { + use crate::functions::list::max; + Ok(max(&self.expr).into()) + } + pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(&self.expr, name).into()) diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs index 45855805ed..bec57b8dfb 100644 --- a/src/daft-table/src/ops/explode.rs +++ b/src/daft-table/src/ops/explode.rs @@ -1,5 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::series::IntoSeries; +use daft_core::CountMode; use daft_core::{ array::ops::as_arrow::AsArrow, datatypes::{DataType, UInt64Array}, @@ -60,11 +61,14 @@ impl Table { } } } - let first_len = evaluated_columns.first().unwrap().list_lengths()?; + let first_len = evaluated_columns + .first() + .unwrap() + .list_count(CountMode::All)?; if evaluated_columns .iter() .skip(1) - .any(|c| c.list_lengths().unwrap().ne(&first_len)) + .any(|c| c.list_count(CountMode::All).unwrap().ne(&first_len)) { return Err(DaftError::ValueError( "In multicolumn explode, list length did not match".to_string(), diff --git a/tests/table/list/test_list_count_lengths.py b/tests/table/list/test_list_count_lengths.py new file mode 100644 index 0000000000..321219d6da --- /dev/null +++ b/tests/table/list/test_list_count_lengths.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import pytest + +from daft.daft import CountMode +from daft.datatype import DataType +from daft.expressions import col +from daft.table import MicroPartition + + +@pytest.fixture +def table(): + return MicroPartition.from_pydict({"col": 
[None, [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]]}) + + +@pytest.fixture +def fixed_table(): + table = MicroPartition.from_pydict({"col": [["a", "a"], ["a", "a"], ["a", None], [None, None], None]}) + fixed_dtype = DataType.fixed_size_list(DataType.string(), 2) + return table.eval_expression_list([col("col").cast(fixed_dtype)]) + + +def test_list_lengths(table): + result = table.eval_expression_list([col("col").list.lengths()]) + assert result.to_pydict() == {"col": [None, 0, 1, 1, 2, 2, 3]} + + +def test_fixed_list_lengths(fixed_table): + result = fixed_table.eval_expression_list([col("col").list.lengths()]) + assert result.to_pydict() == {"col": [2, 2, 2, 2, None]} + + +def test_list_count(table): + result = table.eval_expression_list([col("col").list.count(CountMode.All)]) + assert result.to_pydict() == {"col": [None, 0, 1, 1, 2, 2, 3]} + + result = table.eval_expression_list([col("col").list.count(CountMode.Valid)]) + assert result.to_pydict() == {"col": [None, 0, 1, 0, 2, 1, 2]} + + result = table.eval_expression_list([col("col").list.count(CountMode.Null)]) + assert result.to_pydict() == {"col": [None, 0, 0, 1, 0, 1, 1]} + + +def test_fixed_list_count(fixed_table): + result = fixed_table.eval_expression_list([col("col").list.count(CountMode.All)]) + assert result.to_pydict() == {"col": [2, 2, 2, 2, None]} + + result = fixed_table.eval_expression_list([col("col").list.count(CountMode.Valid)]) + assert result.to_pydict() == {"col": [2, 2, 1, 0, None]} + + result = fixed_table.eval_expression_list([col("col").list.count(CountMode.Null)]) + assert result.to_pydict() == {"col": [0, 0, 1, 2, None]} diff --git a/tests/table/list/test_list_lengths.py b/tests/table/list/test_list_lengths.py deleted file mode 100644 index a3520908dc..0000000000 --- a/tests/table/list/test_list_lengths.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from daft.expressions import col -from daft.table import MicroPartition - - -def 
test_list_lengths(): - table = MicroPartition.from_pydict({"col": [None, [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]]}) - result = table.eval_expression_list([col("col").list.lengths()]) - assert result.to_pydict() == {"col": [None, 0, 1, 1, 2, 2, 3]} diff --git a/tests/table/list/test_list_numeric_aggs.py b/tests/table/list/test_list_numeric_aggs.py new file mode 100644 index 0000000000..c77d2fe3dc --- /dev/null +++ b/tests/table/list/test_list_numeric_aggs.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import pytest + +from daft.datatype import DataType +from daft.expressions import col +from daft.table import MicroPartition + +table = MicroPartition.from_pydict({"a": [[1, 2], [3, 4], [5, None], [None, None], None]}) +fixed_dtype = DataType.fixed_size_list(DataType.int64(), 2) +fixed_table = table.eval_expression_list([col("a").cast(fixed_dtype)]) + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_sum(table): + result = table.eval_expression_list([col("a").list.sum()]) + assert result.to_pydict() == {"a": [3, 7, 5, None, None]} + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_mean(table): + result = table.eval_expression_list([col("a").list.mean()]) + assert result.to_pydict() == {"a": [1.5, 3.5, 5, None, None]} + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_min(table): + result = table.eval_expression_list([col("a").list.min()]) + assert result.to_pydict() == {"a": [1, 3, 5, None, None]} + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_max(table): + result = table.eval_expression_list([col("a").list.max()]) + assert result.to_pydict() == {"a": [2, 4, 5, None, None]}