Skip to content

Commit

Permalink
[FEAT] agg_concat doesn't work on strings (#2847)
Browse files Browse the repository at this point in the history
Solves #2768

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2024
1 parent 195dd00 commit d57433a
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 5 deletions.
77 changes: 73 additions & 4 deletions src/daft-core/src/array/ops/concat_agg.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
use arrow2::{bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index};
use arrow2::{
array::{Array, Utf8Array},
bitmap::utils::SlicesIterator,
offset::OffsetsBuffer,
types::Index,
};
use common_error::DaftResult;

use super::{as_arrow::AsArrow, DaftConcatAggable};
use crate::array::{
growable::{make_growable, Growable},
ListArray,
use crate::{
array::{
growable::{make_growable, Growable},
DataArray, ListArray,
},
prelude::Utf8Type,
};

#[cfg(feature = "python")]
Expand Down Expand Up @@ -146,6 +154,67 @@ impl DaftConcatAggable for ListArray {
}
}

/// Concat aggregation for UTF-8 string arrays: strings are joined back to back,
/// with no separator inserted between them.
impl DaftConcatAggable for DataArray<Utf8Type> {
    type Output = DaftResult<Self>;

    /// Concatenates every string in the array into a single-row array.
    ///
    /// The values buffer is passed through unchanged; only the offsets are
    /// rewritten so that one row spans the array's whole value range.
    ///
    /// The single output row is null only when a validity bitmap is present
    /// and every slot in it is unset; otherwise the row is valid.
    ///
    /// # Errors
    ///
    /// Returns an error if the new offsets buffer cannot be built or if
    /// `DataArray::new` rejects the resulting array.
    fn concat(&self) -> Self::Output {
        let new_validity = match self.validity() {
            Some(validity) if validity.unset_bits() == self.len() => {
                Some(arrow2::bitmap::Bitmap::from(vec![false]))
            }
            _ => None,
        };

        let arrow_array = self.as_arrow();
        let offsets = arrow_array.offsets();
        // Start at the array's own first offset rather than a hard-coded 0: the
        // values buffer is reused as-is, so for a sliced array (first offset > 0)
        // starting at 0 would pull in bytes that precede the slice.
        //
        // NOTE(review): this zero-copy path spans the whole value range, so it
        // relies on null slots occupying zero bytes in the values buffer.
        // Confirm that upstream producers uphold this.
        let new_offsets =
            OffsetsBuffer::<i64>::try_from(vec![*offsets.first(), *offsets.last()])?;
        let output = Utf8Array::new(
            arrow_array.data_type().clone(),
            new_offsets,
            arrow_array.values().clone(),
            new_validity,
        );

        let result_box = Box::new(output);
        DataArray::new(self.field().clone().into(), result_box)
    }

    /// Concatenates the strings of each group into one output row per group.
    ///
    /// `groups` holds, for each group, the row indices that belong to it.
    ///
    /// When the array contains nulls, null entries are skipped and a group
    /// with no non-null entries produces a null row. When the array has no
    /// nulls, every group produces a (possibly empty) string.
    ///
    /// The result carries the same field name as `self`.
    fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output {
        let arrow_array = self.as_arrow();
        let concat_per_group = if arrow_array.null_count() > 0 {
            // Null-aware path: `get` yields `None` for null slots, which
            // `filter_map` drops before joining.
            Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| {
                let to_concat = g
                    .iter()
                    .filter_map(|index| {
                        let idx = *index as usize;
                        arrow_array.get(idx)
                    })
                    .collect::<Vec<&str>>();
                if to_concat.is_empty() {
                    None
                } else {
                    Some(to_concat.concat())
                }
            })))
        } else {
            // No nulls anywhere: read values directly and collect each group
            // straight into a `String`.
            Box::new(Utf8Array::from_trusted_len_values_iter(groups.iter().map(
                |g| {
                    g.iter()
                        .map(|index| {
                            let idx = *index as usize;
                            arrow_array.value(idx)
                        })
                        .collect::<String>()
                },
            )))
        };

        Ok(DataArray::from((
            self.field.name.as_ref(),
            concat_per_group,
        )))
    }
}

#[cfg(test)]
mod test {
use std::iter::repeat;
Expand Down
11 changes: 10 additions & 1 deletion src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,17 @@ impl Series {
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
}
}
DataType::Utf8 => {
let downcasted = self.downcast::<Utf8Array>()?;
match groups {
Some(groups) => {
Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series())
}
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
}
}
_ => Err(DaftError::TypeError(format!(
"concat aggregation is only valid for List or Python types, got {}",
"concat aggregation is only valid for List, Python types, or Utf8, got {}",
self.data_type()
))),
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ impl AggExpr {
let field = expr.to_field(schema)?;
match field.dtype {
DataType::List(..) => Ok(field),
DataType::Utf8 => Ok(field),
#[cfg(feature = "python")]
DataType::Python => Ok(field),
_ => Err(DaftError::TypeError(format!(
Expand Down
50 changes: 50 additions & 0 deletions tests/table/test_table_aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,53 @@ def test_groupby_struct(dtype) -> None:
expected = [[0, 1, 4], [2, 6], [3, 5]]
for lt in expected:
assert lt in res["b"]


def test_agg_concat_on_string() -> None:
    """A global agg_concat over a string column joins all rows into one string."""
    words = ["the", " quick", " brown", " fox"]
    table = from_pydict({"a": words})
    aggregated = table.agg(col("a").agg_concat())
    assert aggregated.to_pydict()["a"] == ["the quick brown fox"]


def test_agg_concat_on_string_groupby() -> None:
    """A grouped agg_concat joins the strings of each group, one row per group."""
    table = from_pydict({"a": ["the", " quick", " brown", " fox"], "b": [1, 2, 1, 2]})
    res = table.groupby("b").agg_concat("a").to_pydict()
    expected = ["the brown", " quick fox"]
    # Group order is not fixed, so compare order-insensitively. Comparing the
    # full sorted lists (rather than checking membership) also catches extra
    # or duplicated output rows.
    assert sorted(res["a"]) == sorted(expected)


def test_agg_concat_on_string_null() -> None:
    """Null entries are skipped when concatenating the whole column."""
    table = from_pydict({"a": ["the", " quick", None, " fox"]})
    aggregated = table.agg(col("a").agg_concat()).to_pydict()
    assert aggregated["a"] == ["the quick fox"]


def test_agg_concat_on_string_groupby_null() -> None:
    """Null entries are skipped within each group of a grouped agg_concat."""
    table = from_pydict({"a": ["the", " quick", None, " fox"], "b": [1, 2, 1, 2]})
    res = table.groupby("b").agg_concat("a").to_pydict()
    expected = ["the", " quick fox"]
    # Group order is not fixed, so compare order-insensitively. Comparing the
    # full sorted lists (rather than checking membership) also catches extra
    # or duplicated output rows.
    assert sorted(res["a"]) == sorted(expected)


def test_agg_concat_on_string_null_list() -> None:
    """An all-null string column aggregates to a single null row."""
    table = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column(
        "a", col("a").cast(DataType.string())
    )
    res = table.agg(col("a").agg_concat()).to_pydict()
    expected = [None]
    assert res["a"] == expected
    assert len(res["a"]) == 1


def test_agg_concat_on_string_groupby_null_list() -> None:
    """Every group of an all-null string column aggregates to a null row."""
    raw = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]})
    table = raw.with_column("a", col("a").cast(DataType.string()))
    grouped = table.groupby("b").agg_concat("a").to_pydict()
    want = [None, None]
    assert grouped["a"] == want
    assert len(grouped["a"]) == len(want)

0 comments on commit d57433a

Please sign in to comment.