Skip to content

Commit

Permalink
[FEAT] add list.value_counts() (#2902)
Browse files Browse the repository at this point in the history
This PR implements the `list.value_counts()` function and refactors Map types to use a more explicit key-value structure. It also includes various improvements and bug fixes across the codebase.

## Key changes

1. **Implemented list.value_counts() function**

2. **Refactored Map type representation**
   - Updated `DataType` enum in `src/daft-schema/src/dtype.rs` to use explicit key and value types

3. **Improved error handling and type checking**
   - Enhanced type checking in various parts of the codebase
   - Improved error messages for better debugging

4. **Performance optimizations**
   - Refactored some operations to be more efficient, especially in list and map operations

5. **Code cleanup and minor improvements**
   - Removed unnecessary clones and improved code readability
   - Updated comments and documentation
  • Loading branch information
andrewgazelka authored Oct 7, 2024
1 parent f5cf5af commit 98dbadb
Show file tree
Hide file tree
Showing 53 changed files with 1,302 additions and 420 deletions.
2 changes: 2 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[env]
PYO3_PYTHON = "./.venv/bin/python"
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ...
# ---
def explode(expr: PyExpr) -> PyExpr: ...
def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ...
# Counts the occurrences of each unique value in a list; surfaced to users as
# `Expression.list.value_counts()`.
def list_value_counts(expr: PyExpr) -> PyExpr: ...
def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ...
def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ...
def list_get(expr: PyExpr, idx: PyExpr, default: PyExpr) -> PyExpr: ...
Expand Down
64 changes: 49 additions & 15 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2930,6 +2930,40 @@ def join(self, delimiter: str | Expression) -> Expression:
delimiter_expr = Expression._to_expression(delimiter)
return Expression._from_pyexpr(native.list_join(self._expr, delimiter_expr._expr))

def value_counts(self) -> Expression:
    """Tallies how many times each distinct value occurs within every list.

    Returns:
        Expression: A Map<X, UInt64> expression. The keys are the unique elements (of
            type X) taken from the original list, and each value is the UInt64 number
            of times that element appears in the list.

    Note:
        Nested types are not supported. For example, this function will not produce a
        map with lists as keys.

    Example:
        >>> import daft
        >>> df = daft.from_pydict({"letters": [["a", "b", "a"], ["b", "c", "b", "c"]]})
        >>> df.with_column("value_counts", df["letters"].list.value_counts()).collect()
        ╭──────────────┬───────────────────╮
        │ letters      ┆ value_counts      │
        │ ---          ┆ ---               │
        │ List[Utf8]   ┆ Map[Utf8: UInt64] │
        ╞══════════════╪═══════════════════╡
        │ [a, b, a]    ┆ [{key: a,         │
        │              ┆ value: 2,         │
        │              ┆ }, {key: …        │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ [b, c, b, c] ┆ [{key: b,         │
        │              ┆ value: 2,         │
        │              ┆ }, {key: …        │
        ╰──────────────┴───────────────────╯
        <BLANKLINE>
        (Showing first 2 of 2 rows)
    """
    # Delegate to the native kernel, then wrap the raw PyExpr back into an Expression.
    counts_pyexpr = native.list_value_counts(self._expr)
    return Expression._from_pyexpr(counts_pyexpr)

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of elements in each list
Expand Down Expand Up @@ -3092,21 +3126,21 @@ def get(self, key: Expression) -> Expression:
>>> df = daft.from_arrow(pa.table({"map_col": pa_array}))
>>> df = df.with_column("a", df["map_col"].map.get("a"))
>>> df.show()
╭──────────────────────────────────────┬───────╮
│ map_col ┆ a │
│ --- ┆ --- │
│ Map[Struct[key: Utf8, value: Int64]] ┆ Int64 │
╞══════════════════════════════════════╪═══════╡
│ [{key: a, ┆ 1 │
│ value: 1, ┆ │
│ }] ┆ │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [] ┆ None │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [{key: b, ┆ None │
│ value: 2, ┆ │
│ }] ┆ │
╰──────────────────────────────────────┴───────╯
╭──────────────────┬───────╮
│ map_col ┆ a │
│ --- ┆ --- │
│ Map[Utf8: Int64] ┆ Int64 │
╞══════════════════╪═══════╡
│ [{key: a, ┆ 1 │
│ value: 1, ┆ │
│ }] ┆ │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [] ┆ None │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [{key: b, ┆ None │
│ value: 2, ┆ │
│ }] ┆ │
╰──────────────────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
Expand Down
10 changes: 8 additions & 2 deletions src/arrow2/src/array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,18 @@ impl<O: Offset> ListArray<O> {
if O::IS_LARGE {
match data_type.to_logical_type() {
DataType::LargeList(child) => Ok(child.as_ref()),
_ => Err(Error::oos("ListArray<i64> expects DataType::LargeList")),
got => {
let msg = format!("ListArray<i64> expects DataType::LargeList, but got {got:?}");
Err(Error::oos(msg))
},
}
} else {
match data_type.to_logical_type() {
DataType::List(child) => Ok(child.as_ref()),
_ => Err(Error::oos("ListArray<i32> expects DataType::List")),
got => {
let msg = format!("ListArray<i32> expects DataType::List, but got {got:?}");
Err(Error::oos(msg))
},
}
}
}
Expand Down
81 changes: 69 additions & 12 deletions src/arrow2/src/array/map/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use super::{new_empty_array, specification::try_check_offsets_bounds, Array, ListArray};
use crate::{
bitmap::Bitmap,
datatypes::{DataType, Field},
error::Error,
offset::OffsetsBuffer,
};

use super::{new_empty_array, specification::try_check_offsets_bounds, Array};

mod ffi;
pub(super) mod fmt;
mod iterator;
Expand Down Expand Up @@ -41,20 +40,27 @@ impl MapArray {
try_check_offsets_bounds(&offsets, field.len())?;

let inner_field = Self::try_get_field(&data_type)?;
if let DataType::Struct(inner) = inner_field.data_type() {
if inner.len() != 2 {
return Err(Error::InvalidArgumentError(
"MapArray's inner `Struct` must have 2 fields (keys and maps)".to_string(),
));
}
} else {

let inner_data_type = inner_field.data_type();
let DataType::Struct(inner) = inner_data_type else {
return Err(Error::InvalidArgumentError(
"MapArray expects `DataType::Struct` as its inner logical type".to_string(),
format!("MapArray expects `DataType::Struct` as its inner logical type, but found {inner_data_type:?}"),
));
};

let inner_len = inner.len();
if inner_len != 2 {
let msg = format!(
"MapArray's inner `Struct` must have 2 fields (keys and maps), but found {} fields",
inner_len
);
return Err(Error::InvalidArgumentError(msg));
}
if field.data_type() != inner_field.data_type() {

let field_data_type = field.data_type();
if field_data_type != inner_field.data_type() {
return Err(Error::InvalidArgumentError(
"MapArray expects `field.data_type` to match its inner DataType".to_string(),
format!("MapArray expects `field.data_type` to match its inner DataType, but found \n{field_data_type:?}\nvs\n\n\n{inner_field:?}"),
));
}

Expand Down Expand Up @@ -195,6 +201,57 @@ impl MapArray {
impl Array for MapArray {
impl_common_array!();

/// Re-expresses this map array under `target_data_type`.
///
/// Two kinds of target are handled:
/// - `DataType::Map`: the array is copied with `to_boxed` and the copy's type is replaced via
///   `change_type`.
/// - `DataType::LargeList`: a `ListArray` is assembled from a re-typed clone of the inner
///   field, the offsets cast to `i64`, and a clone of the validity.
///
/// # Panics
///
/// Panics when the target is neither a `Map` nor a `LargeList`, or when the physical type of
/// the `LargeList` child differs from the physical type of this map's inner field.
fn convert_logical_type(&self, target_data_type: DataType) -> Box<dyn Array> {
    let own_type = self.data_type();
    let DataType::Map(map_entries, _) = own_type else {
        unreachable!(
            "Expected MapArray to have Map data type, but found {:?}",
            own_type
        );
    };

    // Map -> Map keeps the same top-level representation, so it is enough to copy the
    // array and swap the data type on the copy.
    if matches!(target_data_type, DataType::Map { .. }) {
        let mut retyped = self.to_boxed();
        retyped.change_type(target_data_type);
        return retyped;
    }

    // Any other target must be a LargeList; the result is then built as a ListArray.
    let DataType::LargeList(list_child) = &target_data_type else {
        panic!("MapArray can only be converted to Map or LargeList, but target type is {target_data_type:?}");
    };

    let source_physical = map_entries.data_type.to_physical_type();
    let target_physical = list_child.data_type.to_physical_type();
    assert!(
        source_physical == target_physical,
        "Inner physical types must be equal for conversion. Current: {:?}, Target: {:?}",
        source_physical, target_physical
    );

    // Only the logical type of the inner values changes; the physical types were just
    // checked to be equal.
    let mut list_values = self.field.clone();
    list_values.change_type(list_child.data_type.clone());

    // SAFETY: NOTE(review) - `map_unchecked` presumably requires the mapped offsets to remain
    // a valid (monotonically non-decreasing) sequence; a plain integer cast to `i64` keeps
    // their relative order as long as the source offsets are no wider than 64 bits. Confirm
    // against `map_unchecked`'s documented contract.
    let wide_offsets = unsafe { self.offsets().clone().map_unchecked(|offset| offset as i64) };

    Box::new(ListArray::new(
        target_data_type,
        wide_offsets,
        list_values,
        self.validity.clone(),
    ))
}

fn validity(&self) -> Option<&Bitmap> {
self.validity.as_ref()
}
Expand Down
Loading

0 comments on commit 98dbadb

Please sign in to comment.