Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add basic list aggregations #2032

Merged
merged 11 commits into from
Apr 11, 2024
Merged
8 changes: 6 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,12 @@ class PyExpr:
def image_resize(self, w: int, h: int) -> PyExpr: ...
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
def list_count(self, mode: CountMode) -> PyExpr: ...
def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ...
def list_sum(self) -> PyExpr: ...
def list_mean(self) -> PyExpr: ...
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
Expand Down Expand Up @@ -1037,7 +1041,7 @@ class PySeries:
def partitioning_years(self) -> PySeries: ...
def partitioning_iceberg_bucket(self, n: int) -> PySeries: ...
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
def list_lengths(self) -> PySeries: ...
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def image_decode(self) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
Expand Down
45 changes: 44 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,13 +868,24 @@ def join(self, delimiter: str | Expression) -> Expression:
delimiter_expr = Expression._to_expression(delimiter)
return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr))

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of elements in each list

Args:
mode: The mode to use for counting. Defaults to CountMode.Valid

Returns:
Expression: a UInt64 expression which is the length of each list
"""
return Expression._from_pyexpr(self._expr.list_count(mode))

def lengths(self) -> Expression:
"""Gets the length of each list

Returns:
Expression: a UInt64 expression which is the length of each list
"""
return Expression._from_pyexpr(self._expr.list_lengths())
return Expression._from_pyexpr(self._expr.list_count(CountMode.All))

def get(self, idx: int | Expression, default: object = None) -> Expression:
"""Gets the element at an index in each list
Expand All @@ -890,6 +901,38 @@ def get(self, idx: int | Expression, default: object = None) -> Expression:
default_expr = lit(default)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))

def sum(self) -> Expression:
"""Sums each list. Empty lists and lists with all nulls yield null.

Returns:
Expression: an expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_sum())

def mean(self) -> Expression:
"""Calculates the mean of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_mean())

def min(self) -> Expression:
"""Calculates the minimum of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_min())

def max(self) -> Expression:
"""Calculates the maximum of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_max())


class ExpressionStructNamespace(ExpressionNamespace):
def get(self, name: str) -> Expression:
Expand Down
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def iceberg_truncate(self, w: int) -> Series:

class SeriesListNamespace(SeriesNamespace):
def lengths(self) -> Series:
return Series._from_pyseries(self._series.list_lengths())
return Series._from_pyseries(self._series.list_count(CountMode.All))

def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))
Expand Down
41 changes: 41 additions & 0 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,47 @@ impl FixedSizeListArray {
}
}

impl<'a> IntoIterator for &'a FixedSizeListArray {
type Item = Option<Series>;

type IntoIter = FixedSizeListArrayIter<'a>;

fn into_iter(self) -> Self::IntoIter {
FixedSizeListArrayIter {
array: self,
idx: 0,
}
}
}

pub struct FixedSizeListArrayIter<'a> {
array: &'a FixedSizeListArray,
idx: usize,
}

impl Iterator for FixedSizeListArrayIter<'_> {
type Item = Option<Series>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx < self.array.len() {
if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) {
self.idx += 1;
Some(None)
} else {
let step = self.array.fixed_element_len();

let start = self.idx * step;
let end = (self.idx + 1) * step;

self.idx += 1;
Some(Some(self.array.flat_child.slice(start, end).unwrap()))
}
} else {
None
}
}
}

#[cfg(test)]
mod tests {
use common_error::DaftResult;
Expand Down
39 changes: 39 additions & 0 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,42 @@ impl ListArray {
))
}
}

impl<'a> IntoIterator for &'a ListArray {
type Item = Option<Series>;

type IntoIter = ListArrayIter<'a>;

fn into_iter(self) -> Self::IntoIter {
ListArrayIter {
array: self,
idx: 0,
}
}
}

pub struct ListArrayIter<'a> {
array: &'a ListArray,
idx: usize,
}

impl Iterator for ListArrayIter<'_> {
type Item = Option<Series>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx < self.array.len() {
if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) {
self.idx += 1;
Some(None)
} else {
let start = *self.array.offsets().get(self.idx).unwrap() as usize;
let end = *self.array.offsets().get(self.idx + 1).unwrap() as usize;

self.idx += 1;
Some(Some(self.array.flat_child.slice(start, end).unwrap()))
}
} else {
None
}
}
}
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn grouped_count_arrow_bitmap(
.iter()
.map(|g| {
g.iter()
.map(|i| validity.get_bit(!*i as usize) as u64)
.map(|i| !validity.get_bit(*i as usize) as u64)
.sum()
})
.collect(),
Expand Down
131 changes: 103 additions & 28 deletions src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::iter::repeat;

use crate::array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
use crate::datatypes::{Int64Array, Utf8Array};
use crate::{
array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
},
datatypes::UInt64Array,
};
use crate::datatypes::{Int64Array, UInt64Array, Utf8Array};
use crate::DataType;
use crate::{CountMode, DataType};

use crate::series::Series;

Expand Down Expand Up @@ -42,11 +45,34 @@ fn join_arrow_list_of_utf8s(
}

impl ListArray {
pub fn lengths(&self) -> DaftResult<UInt64Array> {
let lengths = self.offsets().lengths().map(|l| Some(l as u64));
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let counts = match (mode, self.flat_child.validity()) {
(CountMode::All, _) | (CountMode::Valid, None) => {
self.offsets().lengths().map(|l| l as u64).collect()
}
(CountMode::Valid, Some(validity)) => self
.offsets()
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
.windows(2)
.map(|w| {
(w[0]..w[1])
.map(|i| validity.get_bit(i as usize) as u64)
.sum()
})
.collect(),
(CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(),
(CountMode::Null, Some(validity)) => self
.offsets()
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
.windows(2)
.map(|w| {
(w[0]..w[1])
.map(|i| !validity.get_bit(i as usize) as u64)
.sum()
})
.collect(),
};

let array = Box::new(
arrow2::array::PrimitiveArray::from_iter(lengths)
.with_validity(self.validity().cloned()),
arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()),
);
Ok(UInt64Array::from((self.name(), array)))
}
Expand Down Expand Up @@ -172,27 +198,33 @@ impl ListArray {
}

impl FixedSizeListArray {
pub fn lengths(&self) -> DaftResult<UInt64Array> {
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let size = self.fixed_element_len();
match self.validity() {
None => Ok(UInt64Array::from((
self.name(),
repeat(size as u64)
.take(self.len())
.collect::<Vec<_>>()
.as_slice(),
))),
Some(validity) => {
let arrow_arr = arrow2::array::UInt64Array::from_iter(validity.iter().map(|v| {
if v {
Some(size as u64)
} else {
None
}
}));
Ok(UInt64Array::from((self.name(), Box::new(arrow_arr))))
let counts = match (mode, self.flat_child.validity()) {
(CountMode::All, _) | (CountMode::Valid, None) => {
repeat(size as u64).take(self.len()).collect()
}
}
(CountMode::Valid, Some(validity)) => (0..self.len())
.map(|i| {
(0..size)
.map(|j| validity.get_bit(i * size + j) as u64)
.sum()
})
.collect(),
(CountMode::Null, None) => repeat(0).take(self.len()).collect(),
(CountMode::Null, Some(validity)) => (0..self.len())
.map(|i| {
(0..size)
.map(|j| !validity.get_bit(i * size + j) as u64)
.sum()
})
.collect(),
};

let array = Box::new(
arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()),
);
Ok(UInt64Array::from((self.name(), array)))
}

pub fn explode(&self) -> DaftResult<Series> {
Expand Down Expand Up @@ -303,3 +335,46 @@ impl FixedSizeListArray {
}
}
}

macro_rules! impl_aggs_list_array {
($la:ident) => {
impl $la {
fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
where
T: Fn(&Series) -> DaftResult<Series>,
{
// TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder

// Assumes `op`` returns a null Series given an empty Series
let aggs = self
.into_iter()
.map(|s| s.unwrap_or(Series::empty("", self.child_data_type())))
.map(|s| op(&s))
.collect::<DaftResult<Vec<_>>>()?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably want to have a concat that takes in an iterator of Series for this. materializing all the Series then concating it is going to be slow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tabled until we create an ArrayBuilder so that we don't have to fully materialize the Series objects to concat


let agg_refs: Vec<_> = aggs.iter().collect();

Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name()))
}

pub fn sum(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.sum(None))
}

pub fn mean(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.mean(None))
}

pub fn min(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.min(None))
}

pub fn max(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.max(None))
}
}
};
}

impl_aggs_list_array!(ListArray);
impl_aggs_list_array!(FixedSizeListArray);
31 changes: 31 additions & 0 deletions src/daft-core/src/datatypes/agg_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use common_error::{DaftError, DaftResult};

use super::DataType;

/// Get the data type that the sum of a column of the given data type should be casted to.
pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
match dtype {
Int8 | Int16 | Int32 | Int64 => Ok(Int64),
UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64),
Float32 => Ok(Float32),
Float64 => Ok(Float64),
other => Err(DaftError::TypeError(format!(
"Invalid argument to sum supertype: {}",
other
))),
}
}

/// Get the data type that the mean of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
if dtype.is_numeric() {
Ok(Float64)
} else {
Err(DaftError::TypeError(format!(
"Invalid argument to mean supertype: {}",
dtype
)))
}
}
Loading
Loading