Skip to content

Commit

Permalink
[FEAT] Add basic list aggregations (#2032)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang authored Apr 11, 2024
1 parent 3e73d74 commit 93fc6ca
Show file tree
Hide file tree
Showing 23 changed files with 700 additions and 114 deletions.
8 changes: 6 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -954,8 +954,12 @@ class PyExpr:
def image_resize(self, w: int, h: int) -> PyExpr: ...
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
# Called by the list `join` wrapper in daft/expressions/expressions.py.
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
# NOTE(review): removed side of this hunk; the Python wrappers now call `list_count(CountMode.All)`.
def list_lengths(self) -> PyExpr: ...
# Per-list element count; `mode` selects which elements (all / valid / null) are counted.
def list_count(self, mode: CountMode) -> PyExpr: ...
# Called by the list `get` wrapper in daft/expressions/expressions.py.
def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ...
# Per-list aggregations: each one reduces every list to a single value.
def list_sum(self) -> PyExpr: ...
def list_mean(self) -> PyExpr: ...
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
Expand Down Expand Up @@ -1056,7 +1060,7 @@ class PySeries:
def partitioning_years(self) -> PySeries: ...
def partitioning_iceberg_bucket(self, n: int) -> PySeries: ...
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
# NOTE(review): removed side of this hunk; `SeriesListNamespace.lengths` now calls `list_count(CountMode.All)`.
def list_lengths(self) -> PySeries: ...
# Per-list element count; `mode` selects which elements (all / valid / null) are counted.
def list_count(self, mode: CountMode) -> PySeries: ...
# Called by `SeriesListNamespace.get` in daft/series.py.
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def image_decode(self, raise_error_on_failure: bool) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
Expand Down
45 changes: 44 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,13 +1116,24 @@ def join(self, delimiter: str | Expression) -> Expression:
delimiter_expr = Expression._to_expression(delimiter)
return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr))

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
    """Counts the elements of each list according to ``mode``

    Args:
        mode: Which elements to count. Defaults to CountMode.Valid

    Returns:
        Expression: a UInt64 expression holding the element count of each list
    """
    counted = self._expr.list_count(mode)
    return Expression._from_pyexpr(counted)

def lengths(self) -> Expression:
"""Gets the length of each list (null elements included)
Returns:
Expression: a UInt64 expression which is the length of each list
"""
# NOTE(review): the two returns below are the removed and the added side of this diff hunk;
# the `list_count(CountMode.All)` call is the replacement for the `list_lengths()` call.
return Expression._from_pyexpr(self._expr.list_lengths())
return Expression._from_pyexpr(self._expr.list_count(CountMode.All))

def get(self, idx: int | Expression, default: object = None) -> Expression:
"""Gets the element at an index in each list
Expand All @@ -1138,6 +1149,38 @@ def get(self, idx: int | Expression, default: object = None) -> Expression:
default_expr = lit(default)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))

def sum(self) -> Expression:
    """Sums each list. Empty lists and lists with all nulls yield null.

    Returns:
        Expression: an expression with the type of the list values
    """
    # NOTE(review): `try_sum_supertype` in agg_ops.rs widens small integers to 64 bits; if it is
    # applied to this expression the result dtype is not always the list value dtype -- confirm.
    summed = self._expr.list_sum()
    return Expression._from_pyexpr(summed)

def mean(self) -> Expression:
    """Calculates the mean of each list. Lists with no non-null values yield null.

    Returns:
        Expression: a Float64 expression holding the mean of each list
    """
    averaged = self._expr.list_mean()
    return Expression._from_pyexpr(averaged)

def min(self) -> Expression:
    """Calculates the minimum of each list. Lists with no non-null values yield null.

    Returns:
        Expression: an expression holding the minimum of each list
    """
    # NOTE(review): the previous docstring said both "Float64" and "the type of the list
    # values"; the element type looks right for a minimum -- confirm and document it.
    smallest = self._expr.list_min()
    return Expression._from_pyexpr(smallest)

def max(self) -> Expression:
    """Calculates the maximum of each list. Lists with no non-null values yield null.

    Returns:
        Expression: an expression holding the maximum of each list
    """
    # NOTE(review): the previous docstring said both "Float64" and "the type of the list
    # values"; the element type looks right for a maximum -- confirm and document it.
    largest = self._expr.list_max()
    return Expression._from_pyexpr(largest)


class ExpressionStructNamespace(ExpressionNamespace):
def get(self, name: str) -> Expression:
Expand Down
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def iceberg_truncate(self, w: int) -> Series:

class SeriesListNamespace(SeriesNamespace):
def lengths(self) -> Series:
"""Gets the length of each list as a new Series"""
# NOTE(review): the two returns below are the removed and the added side of this diff hunk;
# the `list_count(CountMode.All)` call is the replacement for the `list_lengths()` call.
return Series._from_pyseries(self._series.list_lengths())
return Series._from_pyseries(self._series.list_count(CountMode.All))

def get(self, idx: Series, default: Series) -> Series:
    """Gets the element at an index in each list.

    The underlying ``idx`` and ``default`` series are forwarded to ``PySeries.list_get``.
    """
    picked = self._series.list_get(idx._series, default._series)
    return Series._from_pyseries(picked)
Expand Down
41 changes: 41 additions & 0 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,47 @@ impl FixedSizeListArray {
}
}

/// Row-wise iteration over a borrowed [`FixedSizeListArray`].
///
/// Every item is `None` for a null row and `Some(series)` with that row's elements otherwise.
impl<'a> IntoIterator for &'a FixedSizeListArray {
    type IntoIter = FixedSizeListArrayIter<'a>;
    type Item = Option<Series>;

    fn into_iter(self) -> Self::IntoIter {
        // Iteration always begins at the first row.
        let first_row = 0;
        FixedSizeListArrayIter {
            idx: first_row,
            array: self,
        }
    }
}

pub struct FixedSizeListArrayIter<'a> {
array: &'a FixedSizeListArray,
idx: usize,
}

impl Iterator for FixedSizeListArrayIter<'_> {
type Item = Option<Series>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx < self.array.len() {
if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) {
self.idx += 1;
Some(None)
} else {
let step = self.array.fixed_element_len();

let start = self.idx * step;
let end = (self.idx + 1) * step;

self.idx += 1;
Some(Some(self.array.flat_child.slice(start, end).unwrap()))
}
} else {
None
}
}
}

#[cfg(test)]
mod tests {
use common_error::DaftResult;
Expand Down
39 changes: 39 additions & 0 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,42 @@ impl ListArray {
))
}
}

/// Row-wise iteration over a borrowed [`ListArray`].
///
/// Every item is `None` for a null row and `Some(series)` with that row's elements otherwise.
impl<'a> IntoIterator for &'a ListArray {
    type IntoIter = ListArrayIter<'a>;
    type Item = Option<Series>;

    fn into_iter(self) -> Self::IntoIter {
        // Iteration always begins at the first row.
        let first_row = 0;
        ListArrayIter {
            idx: first_row,
            array: self,
        }
    }
}

pub struct ListArrayIter<'a> {
array: &'a ListArray,
idx: usize,
}

impl Iterator for ListArrayIter<'_> {
type Item = Option<Series>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx < self.array.len() {
if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) {
self.idx += 1;
Some(None)
} else {
let start = *self.array.offsets().get(self.idx).unwrap() as usize;
let end = *self.array.offsets().get(self.idx + 1).unwrap() as usize;

self.idx += 1;
Some(Some(self.array.flat_child.slice(start, end).unwrap()))
}
} else {
None
}
}
}
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn grouped_count_arrow_bitmap(
.iter()
.map(|g| {
g.iter()
.map(|i| validity.get_bit(!*i as usize) as u64)
.map(|i| !validity.get_bit(*i as usize) as u64)
.sum()
})
.collect(),
Expand Down
131 changes: 103 additions & 28 deletions src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::iter::repeat;

use crate::array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
use crate::datatypes::{Int64Array, Utf8Array};
use crate::{
array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
},
datatypes::UInt64Array,
};
use crate::datatypes::{Int64Array, UInt64Array, Utf8Array};
use crate::DataType;
use crate::{CountMode, DataType};

use crate::series::Series;

Expand Down Expand Up @@ -42,11 +45,34 @@ fn join_arrow_list_of_utf8s(
}

impl ListArray {
pub fn lengths(&self) -> DaftResult<UInt64Array> {
let lengths = self.offsets().lengths().map(|l| Some(l as u64));
/// Counts the elements of every list according to `mode`.
///
/// - `CountMode::All`: the length of each list, taken from the offsets.
/// - `CountMode::Valid`: the number of child elements whose validity bit is set.
/// - `CountMode::Null`: the number of child elements whose validity bit is cleared.
///
/// The result reuses this array's own validity, so null lists stay null in the output.
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
// The validity bitmap of the flattened child (if any) tells which elements are null.
let counts = match (mode, self.flat_child.validity()) {
// Without a child bitmap every element is valid, so `Valid` is the same as `All`.
(CountMode::All, _) | (CountMode::Valid, None) => {
self.offsets().lengths().map(|l| l as u64).collect()
}
// Each pair of consecutive offsets spans one list; sum the set bits in that span.
(CountMode::Valid, Some(validity)) => self
.offsets()
.windows(2)
.map(|w| {
(w[0]..w[1])
.map(|i| validity.get_bit(i as usize) as u64)
.sum()
})
.collect(),
// No child bitmap means no null elements: one zero per consecutive offset pair.
(CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(),
// Same walk as the `Valid` arm, but counting the cleared bits.
(CountMode::Null, Some(validity)) => self
.offsets()
.windows(2)
.map(|w| {
(w[0]..w[1])
.map(|i| !validity.get_bit(i as usize) as u64)
.sum()
})
.collect(),
};

let array = Box::new(
// NOTE(review): the next two lines are the removed side of this diff hunk (they use the
// `lengths` binding of the old `lengths()` body); the `from_vec(counts)` line replaces them.
arrow2::array::PrimitiveArray::from_iter(lengths)
.with_validity(self.validity().cloned()),
arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()),
);
Ok(UInt64Array::from((self.name(), array)))
}
Expand Down Expand Up @@ -172,27 +198,33 @@ impl ListArray {
}

impl FixedSizeListArray {
pub fn lengths(&self) -> DaftResult<UInt64Array> {
/// Counts the elements of every fixed-size list according to `mode`.
///
/// - `CountMode::All`: the fixed element length, once per row.
/// - `CountMode::Valid`: the number of child elements in the row whose validity bit is set.
/// - `CountMode::Null`: the number of child elements in the row whose validity bit is cleared.
///
/// The result reuses this array's own validity, so null lists stay null in the output.
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let size = self.fixed_element_len();
// NOTE(review): the `match self.validity() { ... }` statement that starts here is the removed
// side of this diff hunk (the body of the old `lengths()`); it is shown interleaved with the
// `let counts = match ...` expression that replaces it.
match self.validity() {
None => Ok(UInt64Array::from((
self.name(),
repeat(size as u64)
.take(self.len())
.collect::<Vec<_>>()
.as_slice(),
))),
Some(validity) => {
let arrow_arr = arrow2::array::UInt64Array::from_iter(validity.iter().map(|v| {
if v {
Some(size as u64)
} else {
None
}
}));
Ok(UInt64Array::from((self.name(), Box::new(arrow_arr))))
// Added side of the hunk: the validity bitmap of the flattened child (if any) tells which
// elements are null.
let counts = match (mode, self.flat_child.validity()) {
// Without a child bitmap every element is valid, so `Valid` is the same as `All`.
(CountMode::All, _) | (CountMode::Valid, None) => {
repeat(size as u64).take(self.len()).collect()
}
}
// Row `i` covers child positions `i * size .. (i + 1) * size`; sum the set bits there.
(CountMode::Valid, Some(validity)) => (0..self.len())
.map(|i| {
(0..size)
.map(|j| validity.get_bit(i * size + j) as u64)
.sum()
})
.collect(),
// No child bitmap means no null elements: one zero per row.
(CountMode::Null, None) => repeat(0).take(self.len()).collect(),
// Same walk as the `Valid` arm, but counting the cleared bits.
(CountMode::Null, Some(validity)) => (0..self.len())
.map(|i| {
(0..size)
.map(|j| !validity.get_bit(i * size + j) as u64)
.sum()
})
.collect(),
};

let array = Box::new(
arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()),
);
Ok(UInt64Array::from((self.name(), array)))
}

pub fn explode(&self) -> DaftResult<Series> {
Expand Down Expand Up @@ -303,3 +335,46 @@ impl FixedSizeListArray {
}
}
}

/// Generates the per-list aggregations (`sum`, `mean`, `min`, `max`) for a list array type.
///
/// The type passed in must be iterable by reference with `Option<Series>` items and provide
/// `child_data_type()` and `name()`.
macro_rules! impl_aggs_list_array {
    ($la:ident) => {
        impl $la {
            /// Applies `op` to every list of the array and concatenates the per-list results
            /// into one `Series` named after this array.
            ///
            /// Null lists are aggregated as an empty `Series`.
            fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
            where
                T: Fn(&Series) -> DaftResult<Series>,
            {
                // TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder

                // Stand-in for null lists, built once up front rather than once per row.
                // Assumes `op` returns a null Series given an empty Series
                let empty = Series::empty("", self.child_data_type());

                let aggs = self
                    .into_iter()
                    .map(|list| op(list.as_ref().unwrap_or(&empty)))
                    .collect::<DaftResult<Vec<_>>>()?;

                let agg_refs: Vec<_> = aggs.iter().collect();

                // NOTE(review): if `Series::concat` rejects an empty slice, a zero-row array
                // makes this return an error instead of an empty Series -- confirm.
                Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name()))
            }

            /// Sum of each list, one output row per list.
            pub fn sum(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.sum(None))
            }

            /// Mean of each list, one output row per list.
            pub fn mean(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.mean(None))
            }

            /// Minimum of each list, one output row per list.
            pub fn min(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.min(None))
            }

            /// Maximum of each list, one output row per list.
            pub fn max(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.max(None))
            }
        }
    };
}

impl_aggs_list_array!(ListArray);
impl_aggs_list_array!(FixedSizeListArray);
31 changes: 31 additions & 0 deletions src/daft-core/src/datatypes/agg_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use common_error::{DaftError, DaftResult};

use super::DataType;

/// Get the data type that the sum of a column of the given data type should be casted to.
pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
match dtype {
Int8 | Int16 | Int32 | Int64 => Ok(Int64),
UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64),
Float32 => Ok(Float32),
Float64 => Ok(Float64),
other => Err(DaftError::TypeError(format!(
"Invalid argument to sum supertype: {}",
other
))),
}
}

/// Get the data type that the mean of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
if dtype.is_numeric() {
Ok(Float64)
} else {
Err(DaftError::TypeError(format!(
"Invalid argument to mean supertype: {}",
dtype
)))
}
}
Loading

0 comments on commit 93fc6ca

Please sign in to comment.