diff --git a/src/daft-core/src/array/growable/arrow_growable.rs b/src/daft-core/src/array/growable/arrow_growable.rs index b63cc399c0..db8b4667ef 100644 --- a/src/daft-core/src/array/growable/arrow_growable.rs +++ b/src/daft-core/src/array/growable/arrow_growable.rs @@ -9,9 +9,8 @@ use crate::{ }, datatypes::{ BinaryType, BooleanType, DaftArrowBackedType, DaftDataType, ExtensionArray, Field, - FixedSizeListType, Float32Type, Float64Type, Int128Type, Int16Type, Int32Type, Int64Type, - Int8Type, ListType, NullType, StructType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - Utf8Type, + Float32Type, Float64Type, Int128Type, Int16Type, Int32Type, Int64Type, Int8Type, ListType, + NullType, StructType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, }, DataType, IntoSeries, Series, }; @@ -170,11 +169,6 @@ impl_arrow_backed_data_array_growable!( ListType, arrow2::array::growable::GrowableList<'a, i64> ); -impl_arrow_backed_data_array_growable!( - ArrowFixedSizeListGrowable, - FixedSizeListType, - arrow2::array::growable::GrowableFixedSizeList<'a> -); impl_arrow_backed_data_array_growable!( ArrowStructGrowable, StructType, diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index 28ffc7222a..e6ca4d8c5d 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -6,15 +6,17 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, }, - BinaryArray, BooleanArray, ExtensionArray, FixedSizeListArray, Float32Array, Float64Array, - Int128Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, - StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, + nested_arrays::FixedSizeListArray, + BinaryArray, BooleanArray, ExtensionArray, Float32Array, Float64Array, Int128Array, + Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, StructArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }, DataType, Series, }; mod arrow_growable; mod logical_growable; +mod nested_growable; #[cfg(feature = "python")] mod python_growable; @@ -120,7 +122,7 @@ impl_growable_array!(Utf8Array, arrow_growable::ArrowUtf8Growable<'a>); impl_growable_array!(ListArray, arrow_growable::ArrowListGrowable<'a>); impl_growable_array!( FixedSizeListArray, - arrow_growable::ArrowFixedSizeListGrowable<'a> + nested_growable::FixedSizeListGrowable<'a> ); impl_growable_array!(StructArray, arrow_growable::ArrowStructGrowable<'a>); impl_growable_array!(ExtensionArray, arrow_growable::ArrowExtensionGrowable<'a>); diff --git a/src/daft-core/src/array/growable/nested_growable.rs b/src/daft-core/src/array/growable/nested_growable.rs new file mode 100644 index 0000000000..62a530b281 --- /dev/null +++ b/src/daft-core/src/array/growable/nested_growable.rs @@ -0,0 +1,120 @@ +use std::mem::swap; + +use common_error::DaftResult; + +use crate::{ + datatypes::{nested_arrays::FixedSizeListArray, Field}, + with_match_daft_types, DataType, IntoSeries, Series, +}; + +use super::{Growable, GrowableArray}; + +pub struct ArrowBitmapGrowable<'a> { + bitmap_refs: Vec>, + mutable_bitmap: arrow2::bitmap::MutableBitmap, +} + +impl<'a> ArrowBitmapGrowable<'a> { + pub fn new(bitmap_refs: Vec>, capacity: usize) -> Self { + Self { + bitmap_refs, + mutable_bitmap: arrow2::bitmap::MutableBitmap::with_capacity(capacity), + } + } + + pub fn extend(&mut self, index: usize, start: usize, len: usize) { + let bm = self.bitmap_refs.get(index).unwrap(); + match bm { + None => self.mutable_bitmap.extend_constant(len, true), + Some(bm) => { + let (bm_data, bm_start, _bm_len) = bm.as_slice(); + self.mutable_bitmap + .extend_from_slice(bm_data, bm_start + start, len) + } + } + } + + fn add_nulls(&mut self, additional: usize) { + self.mutable_bitmap.extend_constant(additional, false) + } + + fn build(self) -> arrow2::bitmap::Bitmap { + self.mutable_bitmap.clone().into() + } +} + +pub struct FixedSizeListGrowable<'a> { + name: String, + dtype: DataType, + element_fixed_len: usize, + child_growable: Box, + growable_validity: ArrowBitmapGrowable<'a>, +} + +impl<'a> FixedSizeListGrowable<'a> { + pub fn new( + name: String, + dtype: &DataType, + arrays: Vec<&'a FixedSizeListArray>, + use_validity: bool, + capacity: usize, + ) -> Self { + match dtype { + DataType::FixedSizeList(child_field, element_fixed_len) => { + with_match_daft_types!(&child_field.dtype, |$T| { + let child_growable = <<$T as DaftDataType>::ArrayType as GrowableArray>::make_growable( + name.clone(), + &child_field.dtype, + arrays.iter().map(|a| a.flat_child.downcast::<<$T as DaftDataType>::ArrayType>().unwrap()).collect::>(), + use_validity, + capacity * element_fixed_len, + ); + let growable_validity = ArrowBitmapGrowable::new( + arrays.iter().map(|a| a.validity.as_ref()).collect(), + capacity, + ); + Self { + name, + dtype: dtype.clone(), + element_fixed_len: *element_fixed_len, + child_growable: Box::new(child_growable), + growable_validity, + } + }) + } + _ => panic!("Cannot create FixedSizeListGrowable from dtype: {}", dtype), + } + } +} + +impl<'a> Growable for FixedSizeListGrowable<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + self.child_growable.extend( + index, + start * self.element_fixed_len, + len * self.element_fixed_len, + ); + self.growable_validity.extend(index, start, len); + } + + fn add_nulls(&mut self, additional: usize) { + self.child_growable + .add_nulls(additional * self.element_fixed_len); + self.growable_validity.add_nulls(additional); + } + + fn build(&mut self) -> DaftResult { + // Swap out self.growable_validity so we can use the values and move it + let mut grown_validity = ArrowBitmapGrowable::new(vec![], 0); + swap(&mut self.growable_validity, &mut grown_validity); + + let built_child = self.child_growable.build()?; + let built_validity = grown_validity.build(); + Ok(FixedSizeListArray::new( + Field::new(self.name.clone(), self.dtype.clone()), + built_child, + Some(built_validity), + ) + .into_series()) + } +} diff --git a/src/daft-core/src/array/mod.rs b/src/daft-core/src/array/mod.rs index 9200d83fe8..a699e2d5cd 100644 --- a/src/daft-core/src/array/mod.rs +++ b/src/daft-core/src/array/mod.rs @@ -57,6 +57,10 @@ where self.data().len() } + pub fn data_type(&self) -> &DataType { + &self.field.dtype + } + pub fn is_empty(&self) -> bool { self.len() == 0 } @@ -92,10 +96,6 @@ where self.data.as_ref() } - pub fn data_type(&self) -> &DataType { - &self.field.dtype - } - pub fn name(&self) -> &str { self.field.name.as_str() } diff --git a/src/daft-core/src/array/ops/as_arrow.rs b/src/daft-core/src/array/ops/as_arrow.rs index 24d5b77f82..f3a4c309da 100644 --- a/src/daft-core/src/array/ops/as_arrow.rs +++ b/src/daft-core/src/array/ops/as_arrow.rs @@ -5,11 +5,9 @@ use crate::{ array::DataArray, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + DateArray, Decimal128Array, DurationArray, ImageArray, TensorArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, NullArray, - StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, ListArray, NullArray, StructArray, Utf8Array, }, }; @@ -64,7 +62,6 @@ impl_asarrow_dataarray!(Utf8Array, array::Utf8Array); impl_asarrow_dataarray!(BooleanArray, array::BooleanArray); impl_asarrow_dataarray!(BinaryArray, array::BinaryArray); impl_asarrow_dataarray!(ListArray, array::ListArray); -impl_asarrow_dataarray!(FixedSizeListArray, array::FixedSizeListArray); impl_asarrow_dataarray!(StructArray, array::StructArray); #[cfg(feature = "python")] @@ -74,8 +71,5 @@ impl_asarrow_logicalarray!(Decimal128Array, array::PrimitiveArray); impl_asarrow_logicalarray!(DateArray, array::PrimitiveArray); impl_asarrow_logicalarray!(DurationArray, array::PrimitiveArray); impl_asarrow_logicalarray!(TimestampArray, array::PrimitiveArray); -impl_asarrow_logicalarray!(EmbeddingArray, array::FixedSizeListArray); impl_asarrow_logicalarray!(ImageArray, array::StructArray); -impl_asarrow_logicalarray!(FixedShapeImageArray, array::FixedSizeListArray); impl_asarrow_logicalarray!(TensorArray, array::StructArray); -impl_asarrow_logicalarray!(FixedShapeTensorArray, array::FixedSizeListArray); diff --git a/src/daft-core/src/array/ops/broadcast.rs b/src/daft-core/src/array/ops/broadcast.rs index 711db5502e..9d74e20dd1 100644 --- a/src/daft-core/src/array/ops/broadcast.rs +++ b/src/daft-core/src/array/ops/broadcast.rs @@ -3,7 +3,7 @@ use crate::{ growable::{Growable, GrowableArray}, DataArray, }, - datatypes::{DaftArrayType, DaftPhysicalType, DataType}, + datatypes::{nested_arrays::FixedSizeListArray, DaftArrayType, DaftPhysicalType, DataType}, }; use common_error::{DaftError, DaftResult}; @@ -53,3 +53,24 @@ where } } } + +impl Broadcastable for FixedSizeListArray { + fn broadcast(&self, num: usize) -> DaftResult { + if self.len() != 1 { + return Err(DaftError::ValueError(format!( + "Attempting to broadcast non-unit length Array named: {}", + self.name() + ))); + } + + if self.is_valid(0) { + generic_growable_broadcast(self, num, self.name(), self.data_type()) + } else { + Ok(FixedSizeListArray::full_null( + self.name(), + self.data_type(), + num, + )) + } + } +} diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 676cbc6635..dbbae736c3 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1,3 +1,5 @@ +use std::{iter::repeat, sync::Arc}; + use super::as_arrow::AsArrow; use crate::{ array::{ @@ -11,12 +13,12 @@ use crate::{ FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray, TimestampArray, }, - DaftArrowBackedType, DaftLogicalType, DataType, Field, FixedSizeListArray, ImageMode, - StructArray, TimeUnit, Utf8Array, + nested_arrays::FixedSizeListArray, + DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, ListArray, StructArray, + TimeUnit, Utf8Array, }, series::{IntoSeries, Series}, - with_match_arrow_daft_types, with_match_daft_logical_primitive_types, - with_match_daft_logical_types, + with_match_arrow_daft_types, with_match_daft_logical_primitive_types, with_match_daft_types, }; use common_error::{DaftError, DaftResult}; @@ -26,13 +28,13 @@ use arrow2::{ self, cast::{can_cast_types, cast, CastOptions}, }, + offset::Offsets, }; -use std::sync::Arc; #[cfg(feature = "python")] use { crate::array::pseudo_arrow::PseudoArrowArray, - crate::datatypes::{ListArray, PythonArray}, + crate::datatypes::PythonArray, crate::ffi, crate::with_match_numeric_daft_types, log, @@ -137,15 +139,10 @@ where } }; - let new_field = Arc::new(Field::new(to_cast.name(), dtype.clone())); + let new_field = Field::new(to_cast.name(), dtype.clone()); - if dtype.is_logical() { - with_match_daft_logical_types!(dtype, |$T| { - return Ok(LogicalArray::<$T>::from_arrow(new_field.as_ref(), result_arrow_physical_array)?.into_series()) - }) - } - with_match_arrow_daft_types!(dtype, |$T| { - Ok(DataArray::<$T>::from_arrow(new_field.as_ref(), result_arrow_physical_array)?.into_series()) + with_match_daft_types!(dtype, |$T| { + return Ok(<$T as DaftDataType>::ArrayType::from_arrow(&new_field, result_arrow_physical_array)?.into_series()); }) } @@ -228,15 +225,10 @@ where ))); }; - let new_field = Arc::new(Field::new(to_cast.name(), dtype.clone())); + let new_field = Field::new(to_cast.name(), dtype.clone()); - if dtype.is_logical() { - with_match_daft_logical_types!(dtype, |$T| { - return Ok(LogicalArray::<$T>::from_arrow(new_field.as_ref(), result_array)?.into_series()); - }) - } - with_match_arrow_daft_types!(dtype, |$T| { - return Ok(DataArray::<$T>::from_arrow(new_field.as_ref(), result_array)?.into_series()); + with_match_daft_types!(dtype, |$T| { + Ok(<$T as DaftDataType>::ArrayType::from_arrow(&new_field, result_array)?.into_series()) }) } @@ -749,9 +741,7 @@ fn extract_python_like_to_fixed_size_list< Box::new(arrow2::array::PrimitiveArray::from_vec(values_vec)); let inner_field = child_field.to_arrow()?; - let list_dtype = arrow2::datatypes::DataType::FixedSizeList(Box::new(inner_field), list_size); - let daft_type = (&list_dtype).into(); let list_array = arrow2::array::FixedSizeListArray::new( @@ -760,8 +750,8 @@ fn extract_python_like_to_fixed_size_list< python_objects.as_arrow().validity().cloned(), ); - FixedSizeListArray::new( - Field::new(python_objects.name(), daft_type).into(), + FixedSizeListArray::from_arrow( + &Field::new(python_objects.name(), daft_type), Box::new(list_array), ) } @@ -1091,23 +1081,23 @@ impl EmbeddingArray { match (dtype, self.data_type()) { #[cfg(feature = "python")] (DataType::Python, DataType::Embedding(_, size)) => Python::with_gil(|py| { + let physical_arrow = self.physical.flat_child.to_arrow(); let shape = (self.len(), *size); let pyarrow = py.import("pyarrow")?; // Only go through FFI layer once instead of for every embedding. // We create an ndarray view on the entire embeddings array // buffer sans the validity mask, and then create a subndarray view // for each embedding ndarray in the PythonArray. - let py_array = - ffi::to_py_array(self.as_arrow().values().with_validity(None), py, pyarrow)? - .call_method1(py, pyo3::intern!(py, "to_numpy"), (false,))? - .call_method1(py, pyo3::intern!(py, "reshape"), (shape,))?; + let py_array = ffi::to_py_array(physical_arrow.with_validity(None), py, pyarrow)? + .call_method1(py, pyo3::intern!(py, "to_numpy"), (false,))? + .call_method1(py, pyo3::intern!(py, "reshape"), (shape,))?; let ndarrays = py_array .as_ref(py) .iter()? .map(|a| a.unwrap().to_object(py)) .collect::>(); let values_array = - PseudoArrowArray::new(ndarrays.into(), self.as_arrow().validity().cloned()); + PseudoArrowArray::new(ndarrays.into(), self.physical.validity.clone()); Ok(PythonArray::new( Field::new(self.name(), dtype.clone()).into(), values_array.to_boxed(), @@ -1252,6 +1242,7 @@ impl FixedShapeImageArray { match (dtype, self.data_type()) { #[cfg(feature = "python")] (DataType::Python, DataType::FixedShapeImage(mode, height, width)) => { + let physical_arrow = self.physical.flat_child.to_arrow(); pyo3::Python::with_gil(|py| { let shape = ( self.len(), @@ -1264,24 +1255,17 @@ impl FixedShapeImageArray { // We create an (N, H, W, C) ndarray view on the entire image array // buffer sans the validity mask, and then create a subndarray view // for each image ndarray in the PythonArray. - let py_array = ffi::to_py_array( - self.as_arrow().values().with_validity(None), - py, - pyarrow, - )? - .call_method1(py, pyo3::intern!(py, "to_numpy"), (false,))? - .call_method1( - py, - pyo3::intern!(py, "reshape"), - (shape,), - )?; + let py_array = + ffi::to_py_array(physical_arrow.with_validity(None), py, pyarrow)? + .call_method1(py, pyo3::intern!(py, "to_numpy"), (false,))? + .call_method1(py, pyo3::intern!(py, "reshape"), (shape,))?; let ndarrays = py_array .as_ref(py) .iter()? .map(|a| a.unwrap().to_object(py)) .collect::>(); let values_array = - PseudoArrowArray::new(ndarrays.into(), self.as_arrow().validity().cloned()); + PseudoArrowArray::new(ndarrays.into(), self.physical.validity.clone()); Ok(PythonArray::new( Field::new(self.name(), dtype.clone()).into(), values_array.to_boxed(), @@ -1388,11 +1372,8 @@ impl TensorArray { Default::default(), )?; let inner_field = Box::new(Field::new("data", *inner_dtype.clone())); - let new_field = Arc::new(Field::new( - "data", - DataType::FixedSizeList(inner_field, size), - )); - let result = FixedSizeListArray::new(new_field, new_da)?; + let new_field = Field::new("data", DataType::FixedSizeList(inner_field, size)); + let result = FixedSizeListArray::from_arrow(&new_field, new_da)?; let tensor_array = FixedShapeTensorArray::new(Field::new(self.name(), dtype.clone()), result); Ok(tensor_array.into_series()) @@ -1527,6 +1508,7 @@ impl FixedShapeTensorArray { match (dtype, self.data_type()) { #[cfg(feature = "python")] (DataType::Python, DataType::FixedShapeTensor(_, shape)) => { + let physical_arrow = self.physical.flat_child.to_arrow(); pyo3::Python::with_gil(|py| { let pyarrow = py.import("pyarrow")?; let mut np_shape: Vec = vec![self.len() as u64]; @@ -1535,24 +1517,17 @@ impl FixedShapeTensorArray { // We create an (N, [shape..]) ndarray view on the entire tensor array buffer // sans the validity mask, and then create a subndarray view for each ndarray // element in the PythonArray. - let py_array = ffi::to_py_array( - self.as_arrow().values().with_validity(None), - py, - pyarrow, - )? - .call_method1(py, pyo3::intern!(py, "to_numpy"), (false,))? - .call_method1( - py, - pyo3::intern!(py, "reshape"), - (np_shape,), - )?; + let py_array = + ffi::to_py_array(physical_arrow.with_validity(None), py, pyarrow)? + .call_method1(py, pyo3::intern!(py, "to_numpy"), (false,))? + .call_method1(py, pyo3::intern!(py, "reshape"), (np_shape,))?; let ndarrays = py_array .as_ref(py) .iter()? .map(|a| a.unwrap().to_object(py)) .collect::>(); let values_array = - PseudoArrowArray::new(ndarrays.into(), self.as_arrow().validity().cloned()); + PseudoArrowArray::new(ndarrays.into(), self.physical.validity.clone()); Ok(PythonArray::new( Field::new(self.name(), dtype.clone()).into(), values_array.to_boxed(), @@ -1572,19 +1547,18 @@ impl FixedShapeTensorArray { .step_by(ndim) .map(|v| v as i64) .collect::>(); - let physical_arr = self.as_arrow(); - let list_dtype = arrow2::datatypes::DataType::LargeList(Box::new( - arrow2::datatypes::Field::new("data", inner_dtype.to_arrow()?, true), - )); - let list_arr = cast( - physical_arr, - &list_dtype, - CastOptions { - wrapped: true, - partial: false, - }, - )?; - let validity = self.as_arrow().validity(); + + let physical_arr = &self.physical; + let validity = self.physical.validity.as_ref(); + + // FixedSizeList -> List + let list_arr = physical_arr.cast(&DataType::List(Box::new(Field::new( + "data", + inner_dtype.as_ref().clone(), + ))))?; + let list_arr = list_arr.downcast::()?.data(); + + // List -> Struct let shapes_dtype = arrow2::datatypes::DataType::LargeList(Box::new( arrow2::datatypes::Field::new( "shape", @@ -1599,15 +1573,14 @@ impl FixedShapeTensorArray { Box::new(arrow2::array::PrimitiveArray::from_vec(shapes)), validity.cloned(), )); - - let values: Vec> = vec![list_arr, shapes_array]; + let values: Vec> = + vec![list_arr.to_boxed(), shapes_array]; let physical_type = dtype.to_physical(); let struct_array = Box::new(arrow2::array::StructArray::new( physical_type.to_arrow()?, values, validity.cloned(), )); - let daft_struct_array = StructArray::new(Field::new(self.name(), physical_type).into(), struct_array)?; Ok( @@ -1621,6 +1594,95 @@ impl FixedShapeTensorArray { } } +impl FixedSizeListArray { + pub fn cast(&self, dtype: &DataType) -> DaftResult { + match dtype { + DataType::FixedSizeList(child, size) => { + if size != &self.fixed_element_len() { + return Err(DaftError::ValueError(format!( + "Cannot cast from FixedSizeListSeries with size {} to size: {}", + self.fixed_element_len(), + size + ))); + } + let casted_child = self.flat_child.cast(&child.dtype)?; + Ok(FixedSizeListArray::new( + Field::new(self.name().to_string(), dtype.clone()), + casted_child, + self.validity.clone(), + ) + .into_series()) + } + DataType::List(child) => { + let element_size = self.fixed_element_len(); + // TODO: This will be refactored when List is no longer arrow backed + let casted_child: Box = with_match_arrow_daft_types!(child.dtype, |$T| { + let casted_child_series = self.flat_child.cast(&child.dtype)?; + let downcasted = casted_child_series.downcast::<<$T as DaftDataType>::ArrayType>()?; + downcasted.data().to_boxed() + }); + let offsets: Offsets = match &self.validity { + None => Offsets::try_from_iter(repeat(element_size).take(self.len()))?, + Some(validity) => Offsets::try_from_iter(validity.iter().map(|v| { + if v { + element_size + } else { + 0 + } + }))?, + }; + let list_arrow_array = arrow2::array::ListArray::new( + dtype.to_arrow()?, + offsets.into(), + casted_child, + self.validity.clone(), + ); + Ok(ListArray::new( + Arc::new(Field::new(self.name().to_string(), dtype.clone())), + Box::new(list_arrow_array), + )? + .into_series()) + } + DataType::FixedShapeTensor(child_datatype, shape) => { + if child_datatype.as_ref() != self.child_data_type() { + return Err(DaftError::TypeError(format!( + "Cannot cast {} to {}: mismatched child type", + self.data_type(), + dtype + ))); + } + if shape.iter().product::() != (self.fixed_element_len() as u64) { + return Err(DaftError::TypeError(format!( + "Cannot cast {} to {}: mismatch in element sizes", + self.data_type(), + dtype + ))); + } + Ok(FixedShapeTensorArray::new( + Field::new(self.name().to_string(), dtype.clone()), + self.clone(), + ) + .into_series()) + } + DataType::FixedShapeImage(mode, h, w) => { + if (h * w * mode.num_channels() as u32) as u64 != self.fixed_element_len() as u64 { + return Err(DaftError::TypeError(format!( + "Cannot cast {} to {}: mismatch in element sizes", + self.data_type(), + dtype + ))); + } + Ok(FixedShapeImageArray::new( + Field::new(self.name().to_string(), dtype.clone()), + self.clone(), + ) + .into_series()) + } + _ => unimplemented!("FixedSizeList casting not implemented for dtype: {}", dtype), + } + } +} + #[cfg(feature = "python")] fn cast_logical_to_python_array(array: &LogicalArray, dtype: &DataType) -> DaftResult where diff --git a/src/daft-core/src/array/ops/compare_agg.rs b/src/daft-core/src/array/ops/compare_agg.rs index ec1f829943..2a7e8131f6 100644 --- a/src/daft-core/src/array/ops/compare_agg.rs +++ b/src/daft-core/src/array/ops/compare_agg.rs @@ -1,5 +1,6 @@ +use super::full::FullNull; use super::{DaftCompareAggable, GroupIndices}; -use crate::{array::ops::full::FullNull, array::DataArray, datatypes::*}; +use crate::{array::DataArray, datatypes::nested_arrays::FixedSizeListArray, datatypes::*}; use arrow2::array::PrimitiveArray; use arrow2::{self, array::Array}; diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 9bf61bed3e..97ecc0ec2e 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -1167,7 +1167,10 @@ impl DaftCompare<&str> for Utf8Array { #[cfg(test)] mod tests { - use crate::{array::ops::DaftCompare, datatypes::Int64Array}; + use crate::{ + array::ops::DaftCompare, + datatypes::{DaftArrayType, Int64Array}, + }; use common_error::DaftResult; #[test] diff --git a/src/daft-core/src/array/ops/count.rs b/src/daft-core/src/array/ops/count.rs index 5bcc2f8543..cbb601b6e1 100644 --- a/src/daft-core/src/array/ops/count.rs +++ b/src/daft-core/src/array/ops/count.rs @@ -1,12 +1,60 @@ -use std::sync::Arc; +use std::{iter::repeat, sync::Arc}; use arrow2; -use crate::{array::DataArray, count_mode::CountMode, datatypes::*}; +use crate::{ + array::DataArray, + count_mode::CountMode, + datatypes::{nested_arrays::FixedSizeListArray, *}, +}; use common_error::DaftResult; use super::{DaftCountAggable, GroupIndices}; +/// Helper to perform a grouped count on a validity map of type arrow2::bitmap::Bitmap +fn grouped_count_arrow_bitmap( + groups: &GroupIndices, + mode: &CountMode, + arrow_bitmap: Option<&arrow2::bitmap::Bitmap>, +) -> Vec { + match mode { + CountMode::All => groups.iter().map(|g| g.len() as u64).collect(), + CountMode::Valid => match arrow_bitmap { + None => groups.iter().map(|g| g.len() as u64).collect(), // Equivalent to CountMode::All + Some(validity) => groups + .iter() + .map(|g| g.iter().map(|i| validity.get_bit(*i as usize) as u64).sum()) + .collect(), + }, + CountMode::Null => match arrow_bitmap { + None => repeat(0).take(groups.len()).collect(), // None of the values are Null + Some(validity) => groups + .iter() + .map(|g| g.iter().map(|i| validity.get_bit(*i as usize) as u64).sum()) + .collect(), + }, + } +} + +/// Helper to perform a count on a validity map of type arrow2::bitmap::Bitmap +fn count_arrow_bitmap( + mode: &CountMode, + arrow_bitmap: Option<&arrow2::bitmap::Bitmap>, + arr_len: usize, +) -> u64 { + match mode { + CountMode::All => arr_len as u64, + CountMode::Valid => match arrow_bitmap { + None => arr_len as u64, + Some(validity) => validity.into_iter().map(|b| b as u64).sum(), + }, + CountMode::Null => match arrow_bitmap { + None => 0, + Some(validity) => validity.into_iter().map(|b| !b as u64).sum(), + }, + } +} + impl DaftCountAggable for &DataArray where T: DaftPhysicalType, @@ -14,48 +62,53 @@ where type Output = DaftResult>; fn count(&self, mode: CountMode) -> Self::Output { - let arrow_array = &self.data; - let count = match mode { - CountMode::All => arrow_array.len(), - CountMode::Valid => arrow_array.len() - arrow_array.null_count(), - CountMode::Null => arrow_array.null_count(), + let count = if self.data_type() == &DataType::Null { + match &mode { + CountMode::All => self.len() as u64, + CountMode::Valid => 0u64, + CountMode::Null => self.len() as u64, + } + } else { + count_arrow_bitmap(&mode, self.data().validity(), self.len()) }; - let result_arrow_array = - Box::new(arrow2::array::PrimitiveArray::from([Some(count as u64)])); + let result_arrow_array = Box::new(arrow2::array::PrimitiveArray::from([Some(count)])); DataArray::::new( Arc::new(Field::new(self.field.name.clone(), DataType::UInt64)), result_arrow_array, ) } fn grouped_count(&self, groups: &GroupIndices, mode: CountMode) -> Self::Output { - let arrow_array = self.data.as_ref(); - - let counts_per_group: Vec<_> = match mode { - CountMode::All => groups.iter().map(|g| g.len() as u64).collect(), - CountMode::Valid => { - if arrow_array.null_count() > 0 { - groups - .iter() - .map(|g| { - let null_count = g - .iter() - .fold(0u64, |acc, v| acc + arrow_array.is_null(*v as usize) as u64); - (g.len() as u64) - null_count - }) - .collect() - } else { - groups.iter().map(|g| g.len() as u64).collect() - } + let counts_per_group: Vec = if self.data_type() == &DataType::Null { + match &mode { + CountMode::All => groups.iter().map(|g| g.len() as u64).collect(), + CountMode::Valid => repeat(0).take(groups.len()).collect(), + CountMode::Null => groups.iter().map(|g| g.len() as u64).collect(), } - CountMode::Null => groups - .iter() - .map(|g| { - g.iter() - .fold(0u64, |acc, v| acc + arrow_array.is_null(*v as usize) as u64) - }) - .collect(), + } else { + grouped_count_arrow_bitmap(groups, &mode, self.data().validity()) }; + Ok(DataArray::::from(( + self.field.name.as_ref(), + counts_per_group, + ))) + } +} + +impl DaftCountAggable for &FixedSizeListArray { + type Output = DaftResult>; + fn count(&self, mode: CountMode) -> Self::Output { + let count = count_arrow_bitmap(&mode, self.validity.as_ref(), self.len()); + let result_arrow_array = Box::new(arrow2::array::PrimitiveArray::from([Some(count)])); + DataArray::::new( + Arc::new(Field::new(self.field.name.clone(), DataType::UInt64)), + result_arrow_array, + ) + } + + fn grouped_count(&self, groups: &GroupIndices, mode: CountMode) -> Self::Output { + let counts_per_group: Vec<_> = + grouped_count_arrow_bitmap(groups, &mode, self.validity.as_ref()); Ok(DataArray::::from(( self.field.name.as_ref(), counts_per_group, diff --git a/src/daft-core/src/array/ops/filter.rs b/src/daft-core/src/array/ops/filter.rs index 6d2b95be0a..b315f458e5 100644 --- a/src/daft-core/src/array/ops/filter.rs +++ b/src/daft-core/src/array/ops/filter.rs @@ -1,3 +1,5 @@ +use std::iter::repeat; + use crate::{ array::DataArray, datatypes::{BooleanArray, DaftArrowBackedType}, @@ -69,3 +71,29 @@ impl crate::datatypes::PythonArray { DataArray::::new(self.field().clone().into(), arrow_array) } } + +impl crate::datatypes::nested_arrays::FixedSizeListArray { + pub fn filter(&self, mask: &BooleanArray) -> DaftResult { + let size = self.fixed_element_len(); + let expanded_filter: Vec = mask + .into_iter() + .flat_map(|pred| repeat(pred.unwrap_or(false)).take(size)) + .collect(); + let expanded_filter = BooleanArray::from(("", expanded_filter.as_slice())); + let filtered_child = self.flat_child.filter(&expanded_filter)?; + let filtered_validity = self.validity.as_ref().map(|validity| { + arrow2::bitmap::Bitmap::from_iter(mask.into_iter().zip(validity.iter()).filter_map( + |(keep, valid)| match keep { + None => None, + Some(false) => None, + Some(true) => Some(valid), + }, + )) + }); + Ok(Self::new( + self.field.clone(), + filtered_child, + filtered_validity, + )) + } +} diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index e4b3ce8ed7..a7673c956c 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -1,8 +1,15 @@ -use common_error::DaftResult; +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; use crate::{ array::DataArray, - datatypes::{logical::LogicalArray, DaftDataType, DaftLogicalType, DaftPhysicalType, Field}, + datatypes::{ + logical::LogicalArray, nested_arrays::FixedSizeListArray, DaftDataType, DaftLogicalType, + DaftPhysicalType, Field, + }, + series::IntoSeries, + with_match_daft_types, DataType, }; /// Arrays that implement [`FromArrow`] can be instantiated from a Box @@ -33,3 +40,27 @@ where Ok(LogicalArray::::new(field.clone(), physical)) } } + +impl FromArrow for FixedSizeListArray { + fn from_arrow(field: &Field, arrow_arr: Box) -> DaftResult { + match (&field.dtype, arrow_arr.data_type()) { + (DataType::FixedSizeList(daft_child_field, daft_size), arrow2::datatypes::DataType::FixedSizeList(_arrow_child_field, arrow_size)) => { + if daft_size != arrow_size { + return Err(DaftError::TypeError(format!("Attempting to create Daft FixedSizeListArray with element length {} from Arrow FixedSizeList array with element length {}", daft_size, arrow_size))); + } + + let arrow_arr = arrow_arr.as_ref().as_any().downcast_ref::().unwrap(); + let arrow_child_array = arrow_arr.values(); + let child_series = with_match_daft_types!(daft_child_field.dtype, |$T| { + <$T as DaftDataType>::ArrayType::from_arrow(daft_child_field.as_ref(), arrow_child_array.clone())?.into_series() + }); + Ok(FixedSizeListArray::new( + Arc::new(field.clone()), + child_series, + arrow_arr.validity().cloned(), + )) + } + (d, a) => Err(DaftError::TypeError(format!("Attempting to create Daft FixedSizeListArray with type {:?} from arrow array with type {}", a, d))) + } + } +} diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index 6e1244f1fe..ab2a326666 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{iter::repeat, sync::Arc}; #[cfg(feature = "python")] use pyo3::Python; @@ -6,8 +6,10 @@ use pyo3::Python; use crate::{ array::{pseudo_arrow::PseudoArrowArray, DataArray}, datatypes::{ - logical::LogicalArray, DaftDataType, DaftLogicalType, DaftPhysicalType, DataType, Field, + logical::LogicalArray, nested_arrays::FixedSizeListArray, DaftDataType, DaftLogicalType, + DaftPhysicalType, DataType, Field, }, + with_match_daft_types, IntoSeries, }; pub trait FullNull { @@ -83,3 +85,27 @@ where Self::new(field, physical) } } + +impl FullNull for FixedSizeListArray { + fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { + let empty = Self::empty(name, dtype); + let validity = arrow2::bitmap::Bitmap::from_iter(repeat(false).take(length)); + Self::new(empty.field, empty.flat_child, Some(validity)) + } + + fn empty(name: &str, dtype: &DataType) -> Self { + match dtype { + DataType::FixedSizeList(child, _) => { + let field = Field::new(name, dtype.clone()); + let empty_child = with_match_daft_types!(&child.dtype, |$T| { + <$T as DaftDataType>::ArrayType::empty(name, &child.dtype).into_series() + }); + Self::new(field, empty_child, None) + } + _ => panic!( + "Cannot create empty FixedSizeListArray with dtype: {}", + dtype + ), + } + } +} diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index b72d13186a..f1fed5dfdf 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -2,12 +2,13 @@ use crate::{ array::DataArray, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - ImageArray, TimestampArray, + DateArray, Decimal128Array, DurationArray, ImageArray, LogicalArrayImpl, TimestampArray, }, - BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, ListArray, + nested_arrays::FixedSizeListArray, + BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, ExtensionArray, ListArray, NullArray, StructArray, Utf8Array, }, + Series, }; use super::as_arrow::AsArrow; @@ -34,7 +35,7 @@ where } // Default implementations of get ops for DataArray and LogicalArray. -macro_rules! impl_array_get { +macro_rules! impl_array_arrow_get { ($ArrayT:ty, $output:ty) => { impl $ArrayT { #[inline] @@ -56,17 +57,21 @@ macro_rules! impl_array_get { }; } -impl_array_get!(Utf8Array, &str); -impl_array_get!(BooleanArray, bool); -impl_array_get!(BinaryArray, &[u8]); -impl_array_get!(ListArray, Box); -impl_array_get!(FixedSizeListArray, Box); -impl_array_get!(Decimal128Array, i128); -impl_array_get!(DateArray, i32); -impl_array_get!(DurationArray, i64); -impl_array_get!(TimestampArray, i64); -impl_array_get!(EmbeddingArray, Box); -impl_array_get!(FixedShapeImageArray, Box); +impl LogicalArrayImpl { + #[inline] + pub fn get(&self, idx: usize) -> Option { + self.physical.get(idx) + } +} + +impl_array_arrow_get!(Utf8Array, &str); +impl_array_arrow_get!(BooleanArray, bool); +impl_array_arrow_get!(BinaryArray, &[u8]); +impl_array_arrow_get!(ListArray, Box); +impl_array_arrow_get!(Decimal128Array, i128); +impl_array_arrow_get!(DateArray, i32); +impl_array_arrow_get!(DurationArray, i64); +impl_array_arrow_get!(TimestampArray, i64); impl NullArray { #[inline] @@ -163,3 +168,110 @@ impl ImageArray { } } } + +impl FixedSizeListArray { + #[inline] + pub fn get(&self, idx: usize) -> Option { + if idx >= self.len() { + panic!("Out of bounds: {} vs len: {}", idx, self.len()) + } + let fixed_len = self.fixed_element_len(); + let valid = self.is_valid(idx); + if valid { + Some( + self.flat_child + .slice(idx * fixed_len, (idx + 1) * fixed_len) + .unwrap(), + ) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use common_error::DaftResult; + + use crate::{ + datatypes::{nested_arrays::FixedSizeListArray, BooleanArray, Field, Int32Array}, + DataType, IntoSeries, + }; + + #[test] + fn test_fixed_size_list_get_all_valid() -> DaftResult<()> { + let field = Field::new( + "foo", + DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3), + ); + let flat_child = Int32Array::from(("foo", (0..9).collect::>())); + let validity = None; + let arr = FixedSizeListArray::new(field, flat_child.into_series(), validity); + assert_eq!(arr.len(), 3); + + for i in 0..3 { + let element = arr.get(i); + assert!(element.is_some()); + + let element = element.unwrap(); + assert_eq!(element.len(), 3); + assert_eq!(element.data_type(), &DataType::Int32); + + let element = element.i32()?; + let data = element + .into_iter() + .map(|x| x.map(|v| *v)) + .collect::>>(); + let expected = ((i * 3) as i32..((i + 1) * 3) as i32) + .map(|x| Some(x)) + .collect::>>(); + assert_eq!(data, expected); + } + + Ok(()) + } + + #[test] + fn test_fixed_size_list_get_some_valid() -> DaftResult<()> { + let field = Field::new( + "foo", + DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3), + ); + let flat_child = Int32Array::from(("foo", (0..9).collect::>())); + let raw_validity = vec![true, false, true]; + let validity = Some(arrow2::bitmap::Bitmap::from(raw_validity.as_slice())); + let arr = FixedSizeListArray::new(field, flat_child.into_series(), validity); + assert_eq!(arr.len(), 3); + + let element = arr.get(0); + assert!(element.is_some()); + let element = element.unwrap(); + assert_eq!(element.len(), 3); + assert_eq!(element.data_type(), &DataType::Int32); + let element = element.i32()?; + let data = element + .into_iter() + .map(|x| x.map(|v| *v)) + .collect::>>(); + let expected = vec![Some(0), Some(1), Some(2)]; + assert_eq!(data, expected); + + let element = arr.get(1); + assert!(element.is_none()); + + let element = arr.get(2); + assert!(element.is_some()); + let element = element.unwrap(); + assert_eq!(element.len(), 3); + assert_eq!(element.data_type(), &DataType::Int32); + let element = element.i32()?; + let data = element + .into_iter() + .map(|x| x.map(|v| *v)) + .collect::>>(); + let expected = vec![Some(6), Some(7), Some(8)]; + assert_eq!(data, expected); + + Ok(()) + } +} diff --git a/src/daft-core/src/array/ops/if_else.rs b/src/daft-core/src/array/ops/if_else.rs index 22bf07e9dd..51173fb7d8 100644 --- a/src/daft-core/src/array/ops/if_else.rs +++ b/src/daft-core/src/array/ops/if_else.rs @@ -1,6 +1,7 @@ use crate::array::growable::{Growable, GrowableArray}; use crate::array::ops::full::FullNull; use crate::array::DataArray; +use crate::datatypes::nested_arrays::FixedSizeListArray; use crate::datatypes::{BooleanArray, DaftPhysicalType}; use crate::{DataType, IntoSeries, Series}; use arrow2::array::Array; @@ -131,3 +132,23 @@ where .map(|arr| arr.clone()) } } + +impl<'a> FixedSizeListArray { + pub fn if_else( + &'a self, + other: &'a FixedSizeListArray, + predicate: &BooleanArray, + ) -> DaftResult { + generic_if_else( + predicate, + self.name(), + self, + other, + self.data_type(), + self.len(), + other.len(), + )? + .downcast::() + .map(|arr| arr.clone()) + } +} diff --git a/src/daft-core/src/array/ops/image.rs b/src/daft-core/src/array/ops/image.rs index 70a0a14f5e..2d94026f05 100644 --- a/src/daft-core/src/array/ops/image.rs +++ b/src/daft-core/src/array/ops/image.rs @@ -4,9 +4,10 @@ use std::vec; use image::{ColorType, DynamicImage, ImageBuffer}; -use crate::datatypes::FixedSizeListArray; +use crate::datatypes::UInt8Array; use crate::datatypes::{ logical::{DaftImageryType, FixedShapeImageArray, ImageArray, LogicalArray}, + nested_arrays::FixedSizeListArray, BinaryArray, DataType, Field, ImageFormat, ImageMode, StructArray, }; use common_error::{DaftError, DaftResult}; @@ -21,7 +22,7 @@ use std::ops::Deref; pub struct BBox(u32, u32, u32, u32); impl BBox { - pub fn from_u32_arrow_array(arr: Box) -> Self { + pub fn from_u32_arrow_array(arr: &dyn arrow2::array::Array) -> Self { assert!(arr.len() == 4); let mut iter = arr .as_any() @@ -495,19 +496,15 @@ impl ImageArray { pub fn crop(&self, bboxes: &FixedSizeListArray) -> DaftResult { let mut bboxes_iterator: Box>> = if bboxes.len() == 1 { - Box::new(std::iter::repeat( - bboxes - .as_arrow() - .get(0) - .map(|bbox| BBox::from_u32_arrow_array(bbox)), - )) + Box::new(std::iter::repeat(bboxes.get(0).map(|bbox| { + BBox::from_u32_arrow_array(bbox.u32().unwrap().data()) + }))) } else { - Box::new( + Box::new((0..bboxes.len()).map(|i| { bboxes - .as_arrow() - .iter() - .map(|bbox| bbox.map(|bbox| BBox::from_u32_arrow_array(bbox))), - ) + .get(i) + .map(|bbox| BBox::from_u32_arrow_array(bbox.u32().unwrap().data())) + })) }; let result = crop_images(self, &mut bboxes_iterator); Self::from_daft_image_buffers(self.name(), result.as_slice(), self.image_mode()) @@ -714,19 +711,15 @@ impl FixedShapeImageArray { pub fn crop(&self, bboxes: &FixedSizeListArray) -> DaftResult { let mut bboxes_iterator: Box>> = if bboxes.len() == 1 { - Box::new(std::iter::repeat( - bboxes - .as_arrow() - .get(0) - .map(|bbox| BBox::from_u32_arrow_array(bbox)), - )) + Box::new(std::iter::repeat(bboxes.get(0).map(|bbox| { + BBox::from_u32_arrow_array(bbox.u32().unwrap().data()) + }))) } else { - Box::new( + Box::new((0..bboxes.len()).map(|i| { bboxes - .as_arrow() - .iter() - .map(|bbox| bbox.map(|bbox| BBox::from_u32_arrow_array(bbox))), - ) + .get(i) + .map(|bbox| BBox::from_u32_arrow_array(bbox.u32().unwrap().data())) + })) }; let result = crop_images(self, &mut bboxes_iterator); ImageArray::from_daft_image_buffers(self.name(), result.as_slice(), &Some(self.mode())) @@ -750,7 +743,7 @@ impl AsImageObj for FixedShapeImageArray { match self.data_type() { DataType::FixedShapeImage(mode, height, width) => { - let arrow_array = self.as_arrow().values().as_any().downcast_ref::().unwrap(); + let arrow_array = self.physical.flat_child.downcast::().unwrap().as_arrow(); let num_channels = mode.num_channels(); let size = height * width * num_channels as u32; let start = idx * size as usize; diff --git a/src/daft-core/src/array/ops/len.rs b/src/daft-core/src/array/ops/len.rs index 0017ac6414..414493d52e 100644 --- a/src/daft-core/src/array/ops/len.rs +++ b/src/daft-core/src/array/ops/len.rs @@ -1,4 +1,7 @@ -use crate::{array::DataArray, datatypes::DaftArrowBackedType}; +use crate::{ + array::DataArray, + datatypes::{nested_arrays::FixedSizeListArray, DaftArrowBackedType}, +}; use common_error::DaftResult; #[cfg(feature = "python")] @@ -35,3 +38,14 @@ impl PythonArray { }) } } + +/// From arrow2 private method (arrow2::compute::aggregate::validity_size) +fn validity_size(validity: Option<&arrow2::bitmap::Bitmap>) -> usize { + validity.as_ref().map(|b| b.as_slice().0.len()).unwrap_or(0) +} + +impl FixedSizeListArray { + pub fn size_bytes(&self) -> DaftResult { + Ok(self.flat_child.size_bytes()? + validity_size(self.validity.as_ref())) + } +} diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 9888d6c8d9..c10c0beea4 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,4 +1,9 @@ -use crate::datatypes::{FixedSizeListArray, ListArray, UInt64Array, Utf8Array}; +use std::iter::repeat; + +use crate::array::growable::{Growable, GrowableArray}; +use crate::datatypes::{nested_arrays::FixedSizeListArray, ListArray, UInt64Array, Utf8Array}; +use crate::datatypes::{DaftDataType, Utf8Type}; +use crate::{with_match_daft_types, DataType}; use crate::series::Series; @@ -10,7 +15,7 @@ use common_error::DaftResult; use super::as_arrow::AsArrow; fn join_arrow_list_of_utf8s( - list_element: Option>, + list_element: Option<&dyn arrow2::array::Array>, delimiter_str: &str, ) -> Option { list_element @@ -93,9 +98,9 @@ impl ListArray { if delimiter.len() == 1 { let delimiter_str = delimiter.get(0).unwrap(); - let result = list_array - .iter() - .map(|list_element| join_arrow_list_of_utf8s(list_element, delimiter_str)); + let result = list_array.iter().map(|list_element| { + join_arrow_list_of_utf8s(list_element.as_ref().map(|b| b.as_ref()), delimiter_str) + }); Ok(Utf8Array::from(( self.name(), Box::new(arrow2::array::Utf8Array::from_iter(result)), @@ -105,7 +110,10 @@ impl ListArray { let result = list_array.iter().zip(delimiter.as_arrow().iter()).map( |(list_element, delimiter_element)| { let delimiter_str = delimiter_element.unwrap_or(""); - join_arrow_list_of_utf8s(list_element, delimiter_str) + join_arrow_list_of_utf8s( + list_element.as_ref().map(|b| b.as_ref()), + delimiter_str, + ) }, ); Ok(Utf8Array::from(( @@ -118,72 +126,84 @@ impl ListArray { impl FixedSizeListArray { pub fn lengths(&self) -> DaftResult { - let list_array = self.as_arrow(); - let list_size = list_array.size(); - let lens = (0..self.len()) - .map(|_| list_size as u64) - .collect::>(); - let array = Box::new( - arrow2::array::PrimitiveArray::from_vec(lens) - .with_validity(list_array.validity().cloned()), - ); - Ok(UInt64Array::from((self.name(), array))) + let size = self.fixed_element_len(); + match &self.validity { + None => Ok(UInt64Array::from(( + self.name(), + repeat(size as u64) + .take(self.len()) + .collect::>() + .as_slice(), + ))), + Some(validity) => { + let arrow_arr = arrow2::array::UInt64Array::from_iter(validity.iter().map(|v| { + if v { + Some(size as u64) + } else { + None + } + })); + Ok(UInt64Array::from((self.name(), Box::new(arrow_arr)))) + } + } } pub fn explode(&self) -> DaftResult { - let list_array = self.as_arrow(); - let child_array = list_array.values().as_ref(); - - let list_size = list_array.size(); - - let mut total_capacity: i64 = - (list_size * (list_array.len() - list_array.null_count())) as i64; - - if list_size == 0 { - total_capacity = list_array.len() as i64; - } - - let mut growable = - arrow2::array::growable::make_growable(&[child_array], true, total_capacity as usize); + let list_size = self.fixed_element_len(); + let total_capacity = if list_size == 0 { + self.len() + } else { + let null_count = self.validity.as_ref().map(|v| v.unset_bits()).unwrap_or(0); + list_size * (self.len() - null_count) + }; + + let mut child_growable: Box = with_match_daft_types!(self.child_data_type(), |$T| { + Box::new(<<$T as DaftDataType>::ArrayType as GrowableArray>::make_growable( + self.name().to_string(), + self.child_data_type(), + vec![self.flat_child.downcast::<<$T as DaftDataType>::ArrayType>()?], + true, + total_capacity, + )) + }); - for i in 0..list_array.len() { - let is_valid = list_array.is_valid(i) && (list_size > 0); + for i in 0..self.len() { + let is_valid = self.is_valid(i) && (list_size > 0); match is_valid { - false => growable.extend_validity(1), - true => growable.extend(0, i * list_size, list_size), + false => child_growable.add_nulls(1), + true => child_growable.extend(0, i * list_size, list_size), } } - Series::try_from((self.field.name.as_ref(), growable.as_box())) + child_growable.build() } pub fn join(&self, delimiter: &Utf8Array) -> DaftResult { - let list_array = self.as_arrow(); - assert_eq!( - list_array.values().data_type(), - &arrow2::datatypes::DataType::LargeUtf8 - ); + assert_eq!(self.child_data_type(), &DataType::Utf8,); - if delimiter.len() == 1 { - let delimiter_str = delimiter.get(0).unwrap(); - let result = list_array - .iter() - .map(|list_element| join_arrow_list_of_utf8s(list_element, delimiter_str)); - Ok(Utf8Array::from(( - self.name(), - Box::new(arrow2::array::Utf8Array::from_iter(result)), - ))) + let delimiter_iter: Box>> = if delimiter.len() == 1 { + Box::new(repeat(delimiter.get(0)).take(self.len())) } else { assert_eq!(delimiter.len(), self.len()); - let result = list_array.iter().zip(delimiter.as_arrow().iter()).map( - |(list_element, delimiter_element)| { - let delimiter_str = delimiter_element.unwrap_or(""); - join_arrow_list_of_utf8s(list_element, delimiter_str) - }, - ); - Ok(Utf8Array::from(( - self.name(), - Box::new(arrow2::array::Utf8Array::from_iter(result)), - ))) - } + Box::new(delimiter.as_arrow().iter()) + }; + let self_iter = (0..self.len()).map(|i| self.get(i)); + + let result = self_iter + .zip(delimiter_iter) + .map(|(list_element, delimiter)| { + join_arrow_list_of_utf8s( + list_element.as_ref().map(|l| { + l.downcast::<::ArrayType>() + .unwrap() + .as_arrow() as &dyn arrow2::array::Array + }), + delimiter.unwrap_or(""), + ) + }); + + Ok(Utf8Array::from(( + self.name(), + Box::new(arrow2::array::Utf8Array::from_iter(result)), + ))) } } diff --git a/src/daft-core/src/array/ops/list_agg.rs b/src/daft-core/src/array/ops/list_agg.rs index 7fc097acac..ac430107c9 100644 --- a/src/daft-core/src/array/ops/list_agg.rs +++ b/src/daft-core/src/array/ops/list_agg.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::{ array::DataArray, - datatypes::{DaftArrowBackedType, ListArray}, + datatypes::{nested_arrays::FixedSizeListArray, DaftArrowBackedType, ListArray}, }; use common_error::DaftResult; @@ -97,3 +97,17 @@ impl DaftListAggable for crate::datatypes::PythonArray { Self::new(self.field().clone().into(), Box::new(arrow_array)) } } + +impl DaftListAggable for FixedSizeListArray { + type Output = DaftResult; + + fn list(&self) -> Self::Output { + // TODO(FixedSizeList) + todo!("Requires new ListArrays for implementation") + } + + fn grouped_list(&self, _groups: &GroupIndices) -> Self::Output { + // TODO(FixedSizeList) + todo!("Requires new ListArrays for implementation") + } +} diff --git a/src/daft-core/src/array/ops/null.rs b/src/daft-core/src/array/ops/null.rs index 9481c5b9de..b78a3ca6dd 100644 --- a/src/daft-core/src/array/ops/null.rs +++ b/src/daft-core/src/array/ops/null.rs @@ -1,8 +1,11 @@ -use std::sync::Arc; +use std::{iter::repeat, sync::Arc}; use arrow2; -use crate::{array::DataArray, datatypes::*}; +use crate::{ + array::DataArray, + datatypes::{nested_arrays::FixedSizeListArray, *}, +}; use common_error::DaftResult; use super::DaftIsNull; @@ -35,6 +38,26 @@ where } } +impl DaftIsNull for FixedSizeListArray { + type Output = DaftResult>; + + fn is_null(&self) -> Self::Output { + match &self.validity { + None => Ok(BooleanArray::from(( + self.name(), + repeat(false) + .take(self.len()) + .collect::>() + .as_slice(), + ))), + Some(validity) => Ok(BooleanArray::from(( + self.name(), + validity.into_iter().collect::>().as_slice(), + ))), + } + } +} + impl DataArray where T: DaftPhysicalType, @@ -44,3 +67,13 @@ where self.data.is_valid(idx) } } + +impl FixedSizeListArray { + #[inline] + pub fn is_valid(&self, idx: usize) -> bool { + match &self.validity { + None => true, + Some(validity) => validity.get(idx).unwrap(), + } + } +} diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 53b1fbaae1..14d59c6676 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -5,11 +5,13 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - ImageArray, TimestampArray, + FixedShapeTensorArray, ImageArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, - ImageFormat, ListArray, NullArray, StructArray, Utf8Array, + nested_arrays::FixedSizeListArray, + BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, ImageFormat, ListArray, + NullArray, StructArray, Utf8Array, }, + with_match_daft_types, }; use common_error::DaftResult; @@ -32,13 +34,9 @@ macro_rules! impl_array_str_value { impl_array_str_value!(BooleanArray, "{}"); impl_array_str_value!(ListArray, "{:?}"); -impl_array_str_value!(FixedSizeListArray, "{:?}"); impl_array_str_value!(StructArray, "{:?}"); impl_array_str_value!(ExtensionArray, "{:?}"); impl_array_str_value!(DurationArray, "{}"); -impl_array_str_value!(EmbeddingArray, "{:?}"); -impl_array_str_value!(ImageArray, "{:?}"); -impl_array_str_value!(FixedShapeImageArray, "{:?}"); fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult { /// influenced by pythons bytes repr @@ -207,6 +205,65 @@ impl Decimal128Array { } } +impl FixedSizeListArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let val = self.get(idx); + match val { + None => Ok("None".to_string()), + Some(v) => { + with_match_daft_types!(self.child_data_type(), |$T| { + let arr = v.downcast::<<$T as DaftDataType>::ArrayType>()?; + let mut s = String::new(); + s += "["; + s += (0..v.len()).map(|i| arr.str_value(i)).collect::>>()?.join(", ").as_str(); + s += "]"; + Ok(s) + }) + } + } + } +} + +impl EmbeddingArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + if self.physical.is_valid(idx) { + Ok("".to_string()) + } else { + Ok("None".to_string()) + } + } +} + +impl ImageArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + if self.physical.is_valid(idx) { + Ok("".to_string()) + } else { + Ok("None".to_string()) + } + } +} + +impl FixedShapeImageArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + if self.physical.is_valid(idx) { + Ok("".to_string()) + } else { + Ok("None".to_string()) + } + } +} + +impl FixedShapeTensorArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + if self.physical.is_valid(idx) { + Ok("".to_string()) + } else { + Ok("None".to_string()) + } + } +} + // Default implementation of html_value: html escape the str_value. macro_rules! impl_array_html_value { ($ArrayT:ty) => { @@ -324,3 +381,12 @@ impl FixedShapeImageArray { } } } + +impl FixedShapeTensorArray { + pub fn html_value(&self, idx: usize) -> String { + let str_value = self.str_value(idx).unwrap(); + html_escape::encode_text(&str_value) + .into_owned() + .replace('\n', "
") + } +} diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index 14ab313861..6d423e9335 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -5,9 +5,9 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, - FixedSizeListArray, Float32Array, Float64Array, ListArray, NullArray, StructArray, - Utf8Array, + nested_arrays::FixedSizeListArray, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, Float32Array, + Float64Array, ListArray, NullArray, StructArray, Utf8Array, }, kernels::search_sorted::{build_compare_with_nulls, cmp_float}, series::Series, diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index 6e8e692fbd..ac6981f35f 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -1,3 +1,5 @@ +use std::iter::repeat; + use crate::{ array::DataArray, datatypes::{ @@ -5,11 +7,14 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, - FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array, + nested_arrays::FixedSizeListArray, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, ListArray, + NullArray, StructArray, UInt64Array, Utf8Array, }, + DataType, IntoSeries, }; use common_error::DaftResult; +use num_traits::ToPrimitive; use super::as_arrow::AsArrow; @@ -62,7 +67,6 @@ impl_dataarray_take!(Utf8Array); impl_dataarray_take!(BooleanArray); impl_dataarray_take!(BinaryArray); impl_dataarray_take!(ListArray); -impl_dataarray_take!(FixedSizeListArray); impl_dataarray_take!(NullArray); impl_dataarray_take!(StructArray); impl_dataarray_take!(ExtensionArray); @@ -85,7 +89,6 @@ impl crate::datatypes::PythonArray { use crate::datatypes::PythonType; use arrow2::array::Array; - use arrow2::types::Index; use pyo3::prelude::*; let indices = idx.as_arrow(); @@ -99,7 +102,7 @@ impl crate::datatypes::PythonArray { indices .iter() .map(|maybe_idx| match maybe_idx { - Some(idx) => old_values[idx.to_usize()].clone(), + Some(idx) => old_values[arrow2::types::Index::to_usize(idx)].clone(), None => py_none.clone(), }) .collect() @@ -139,6 +142,46 @@ impl crate::datatypes::PythonArray { } } +impl FixedSizeListArray { + pub fn take(&self, idx: &DataArray) -> DaftResult + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + let size = self.fixed_element_len() as u64; + let idx_as_u64 = idx.cast(&DataType::UInt64)?; + let expanded_child_idx: Vec> = idx_as_u64 + .u64()? + .into_iter() + .flat_map(|i| { + let x: Box>> = match &i { + None => Box::new(repeat(None).take(size as usize)), + Some(i) => Box::new((*i * size..(*i + 1) * size).map(Some)), + }; + x + }) + .collect(); + let child_idx = UInt64Array::from(( + "", + Box::new(arrow2::array::UInt64Array::from_iter( + expanded_child_idx.iter(), + )), + )) + .into_series(); + let taken_validity = self.validity.as_ref().map(|v| { + arrow2::bitmap::Bitmap::from_iter(idx.into_iter().map(|i| match i { + None => false, + Some(i) => v.get_bit(i.to_usize().unwrap()), + })) + }); + Ok(Self::new( + self.field.clone(), + self.flat_child.take(&child_idx)?, + taken_validity, + )) + } +} + impl TensorArray { #[inline] pub fn get(&self, idx: usize) -> Option> { @@ -185,22 +228,6 @@ impl TensorArray { } impl FixedShapeTensorArray { - #[inline] - pub fn get(&self, idx: usize) -> Option> { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } - let arrow_array = self.as_arrow(); - let is_valid = arrow_array - .validity() - .map_or(true, |validity| validity.get_bit(idx)); - if is_valid { - Some(unsafe { arrow_array.value_unchecked(idx) }) - } else { - None - } - } - pub fn take(&self, idx: &DataArray) -> DaftResult where I: DaftIntegerType, @@ -209,19 +236,4 @@ impl FixedShapeTensorArray { let new_array = self.physical.take(idx)?; Ok(Self::new(self.field.clone(), new_array)) } - - pub fn str_value(&self, idx: usize) -> DaftResult { - let val = self.get(idx); - match val { - None => Ok("None".to_string()), - Some(v) => Ok(format!("{v:?}")), - } - } - - pub fn html_value(&self, idx: usize) -> String { - let str_value = self.str_value(idx).unwrap(); - html_escape::encode_text(&str_value) - .into_owned() - .replace('\n', "
") - } } diff --git a/src/daft-core/src/datatypes/dtype.rs b/src/daft-core/src/datatypes/dtype.rs index 4307a40775..24c650d0ca 100644 --- a/src/daft-core/src/datatypes/dtype.rs +++ b/src/daft-core/src/datatypes/dtype.rs @@ -171,7 +171,11 @@ impl DataType { ); logical_extension.to_arrow() } - _ => Err(DaftError::TypeError(format!( + #[cfg(feature = "python")] + DataType::Python => Err(DaftError::TypeError(format!( + "Can not convert {self:?} into arrow type" + ))), + DataType::Unknown => Err(DaftError::TypeError(format!( "Can not convert {self:?} into arrow type" ))), } diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index 48c85c3bec..8103d86df6 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -7,8 +7,9 @@ use crate::{ use common_error::DaftResult; use super::{ - DaftArrayType, DaftDataType, DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, - FixedShapeImageType, FixedShapeTensorType, ImageType, TensorType, TimestampType, + nested_arrays::FixedSizeListArray, DaftArrayType, DaftDataType, DataArray, DataType, + Decimal128Type, DurationType, EmbeddingType, FixedShapeImageType, FixedShapeTensorType, + ImageType, TensorType, TimestampType, }; /// A LogicalArray is a wrapper on top of some underlying array, applying the semantic meaning of its @@ -118,6 +119,18 @@ impl LogicalArrayImpl> { } } +/// Implementation for a LogicalArray that wraps a FixedSizeListArray +impl LogicalArrayImpl { + impl_logical_type!(FixedSizeListArray); + + pub fn to_arrow(&self) -> Box { + let mut fixed_size_list_arrow_array = self.physical.to_arrow(); + let arrow_logical_type = self.data_type().to_arrow().unwrap(); + fixed_size_list_arrow_array.change_type(arrow_logical_type); + fixed_size_list_arrow_array + } +} + pub type LogicalArray = LogicalArrayImpl::PhysicalType as DaftDataType>::ArrayType>; pub type Decimal128Array = LogicalArray; diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index 09b71fabfc..05ba1ad724 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -13,7 +13,7 @@ macro_rules! with_match_daft_types {( Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, Int64 => __with_ty__! { Int64Type }, - Int128(..) => __with_ty__! { Int128Type }, + Int128 => __with_ty__! { Int128Type }, UInt8 => __with_ty__! { UInt8Type }, UInt16 => __with_ty__! { UInt16Type }, UInt32 => __with_ty__! { UInt32Type }, @@ -22,7 +22,6 @@ macro_rules! with_match_daft_types {( Float64 => __with_ty__! { Float64Type }, Timestamp(_, _) => __with_ty__! { TimestampType }, Date => __with_ty__! { DateType }, - Time(_) => __with_ty__! { TimeType }, Duration(_) => __with_ty__! { DurationType }, Binary => __with_ty__! { BinaryType }, Utf8 => __with_ty__! { Utf8Type }, @@ -38,7 +37,8 @@ macro_rules! with_match_daft_types {( Tensor(..) => __with_ty__! { TensorType }, FixedShapeTensor(..) => __with_ty__! { FixedShapeTensorType }, Decimal128(..) => __with_ty__! { Decimal128Type }, - Float16 => unimplemented!("Array for Float16 DataType not implemented"), + Time(_) => unimplemented!("Array for Time DataType not implemented"), + // Float16 => unimplemented!("Array for Float16 DataType not implemented"), Unknown => unimplemented!("Array for Unknown DataType not implemented"), // NOTE: We should not implement a default for match here, because this is meant to be @@ -108,7 +108,6 @@ macro_rules! with_match_arrow_daft_types {( // Date => __with_ty__! { DateType }, // Timestamp(_, _) => __with_ty__! { TimestampType }, List(_) => __with_ty__! { ListType }, - FixedSizeList(..) => __with_ty__! { FixedSizeListType }, Struct(_) => __with_ty__! { StructType }, Extension(_, _, _) => __with_ty__! { ExtensionType }, Utf8 => __with_ty__! { Utf8Type }, @@ -144,31 +143,6 @@ macro_rules! with_match_comparable_daft_types {( } })} -#[macro_export] -macro_rules! with_match_numeric_and_utf_daft_types {( - $key_type:expr, | $_:tt $T:ident | $($body:tt)* -) => ({ - macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} - use $crate::datatypes::DataType::*; - use $crate::datatypes::*; - - match $key_type { - Int8 => __with_ty__! { Int8Type }, - Int16 => __with_ty__! { Int16Type }, - Int32 => __with_ty__! { Int32Type }, - Int64 => __with_ty__! { Int64Type }, - UInt8 => __with_ty__! { UInt8Type }, - UInt16 => __with_ty__! { UInt16Type }, - UInt32 => __with_ty__! { UInt32Type }, - UInt64 => __with_ty__! { UInt64Type }, - // Float16 => __with_ty__! { Float16Type }, - Float32 => __with_ty__! { Float32Type }, - Float64 => __with_ty__! { Float64Type }, - Utf8 => __with_ty__! { Utf8Type }, - _ => panic!("{:?} not implemented", $key_type) - } -})} - #[macro_export] macro_rules! with_match_numeric_daft_types {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* @@ -231,29 +205,6 @@ macro_rules! with_match_float_and_null_daft_types {( } })} -#[macro_export] -macro_rules! with_match_daft_logical_types {( - $key_type:expr, | $_:tt $T:ident | $($body:tt)* -) => ({ - macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} - use $crate::datatypes::DataType::*; - #[allow(unused_imports)] - use $crate::datatypes::*; - - match $key_type { - Decimal128(..) => __with_ty__! { Decimal128Type }, - Date => __with_ty__! { DateType }, - Duration(..) => __with_ty__! { DurationType }, - Timestamp(..) => __with_ty__! { TimestampType }, - Embedding(..) => __with_ty__! { EmbeddingType }, - Image(..) => __with_ty__! { ImageType }, - FixedShapeImage(..) => __with_ty__! { FixedShapeImageType }, - Tensor(..) => __with_ty__! { TensorType }, - FixedShapeTensor(..) => __with_ty__! { FixedShapeTensorType }, - _ => panic!("{:?} not implemented for with_match_daft_logical_types", $key_type) - } -})} - #[macro_export] macro_rules! with_match_daft_logical_primitive_types {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 6ec758ed6b..8af987eb61 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -21,7 +21,10 @@ pub use image_format::ImageFormat; pub use image_mode::ImageMode; use num_traits::{Bounded, Float, FromPrimitive, Num, NumCast, ToPrimitive, Zero}; pub use time_unit::TimeUnit; + +use self::nested_arrays::FixedSizeListArray; pub mod logical; +pub mod nested_arrays; /// Trait that is implemented by all Array types /// @@ -89,7 +92,7 @@ macro_rules! impl_daft_non_arrow_datatype { }; } -macro_rules! impl_daft_logical_datatype { +macro_rules! impl_daft_logical_data_array_datatype { ($ca:ident, $variant:ident, $physical_type:ident) => { #[derive(Clone)] pub struct $ca {} @@ -109,6 +112,26 @@ macro_rules! impl_daft_logical_datatype { }; } +macro_rules! impl_daft_logical_fixed_size_list_datatype { + ($ca:ident, $variant:ident) => { + #[derive(Clone)] + pub struct $ca {} + + impl DaftDataType for $ca { + #[inline] + fn get_dtype() -> DataType { + DataType::$variant + } + + type ArrayType = logical::LogicalArray<$ca>; + } + + impl DaftLogicalType for $ca { + type PhysicalType = FixedSizeListType; + } + }; +} + impl_daft_arrow_datatype!(NullType, Null); impl_daft_arrow_datatype!(BooleanType, Boolean); impl_daft_arrow_datatype!(Int8Type, Int8); @@ -125,25 +148,37 @@ impl_daft_arrow_datatype!(Float32Type, Float32); impl_daft_arrow_datatype!(Float64Type, Float64); impl_daft_arrow_datatype!(BinaryType, Binary); impl_daft_arrow_datatype!(Utf8Type, Utf8); -impl_daft_arrow_datatype!(FixedSizeListType, Unknown); impl_daft_arrow_datatype!(ListType, Unknown); impl_daft_arrow_datatype!(StructType, Unknown); impl_daft_arrow_datatype!(ExtensionType, Unknown); +#[derive(Clone)] +pub struct FixedSizeListType {} + +impl DaftDataType for FixedSizeListType { + #[inline] + fn get_dtype() -> DataType { + DataType::Unknown + } + + type ArrayType = FixedSizeListArray; +} +impl DaftPhysicalType for FixedSizeListType {} + +impl_daft_logical_data_array_datatype!(Decimal128Type, Unknown, Int128Type); +impl_daft_logical_data_array_datatype!(TimestampType, Unknown, Int64Type); +impl_daft_logical_data_array_datatype!(DateType, Date, Int32Type); +// impl_daft_logical_data_array_datatype!(TimeType, Unknown, Int64Type); +impl_daft_logical_data_array_datatype!(DurationType, Unknown, Int64Type); +impl_daft_logical_data_array_datatype!(ImageType, Unknown, StructType); +impl_daft_logical_data_array_datatype!(TensorType, Unknown, StructType); +impl_daft_logical_fixed_size_list_datatype!(EmbeddingType, Unknown); +impl_daft_logical_fixed_size_list_datatype!(FixedShapeImageType, Unknown); +impl_daft_logical_fixed_size_list_datatype!(FixedShapeTensorType, Unknown); + #[cfg(feature = "python")] impl_daft_non_arrow_datatype!(PythonType, Python); -impl_daft_logical_datatype!(Decimal128Type, Unknown, Int128Type); -impl_daft_logical_datatype!(TimestampType, Unknown, Int64Type); -impl_daft_logical_datatype!(DateType, Date, Int32Type); -impl_daft_logical_datatype!(TimeType, Unknown, Int64Type); -impl_daft_logical_datatype!(DurationType, Unknown, Int64Type); -impl_daft_logical_datatype!(EmbeddingType, Unknown, FixedSizeListType); -impl_daft_logical_datatype!(ImageType, Unknown, StructType); -impl_daft_logical_datatype!(FixedShapeImageType, Unknown, FixedSizeListType); -impl_daft_logical_datatype!(TensorType, Unknown, StructType); -impl_daft_logical_datatype!(FixedShapeTensorType, Unknown, FixedSizeListType); - pub trait NumericNative: PartialOrd + NativeType @@ -285,7 +320,6 @@ pub type Float32Array = DataArray; pub type Float64Array = DataArray; pub type BinaryArray = DataArray; pub type Utf8Array = DataArray; -pub type FixedSizeListArray = DataArray; pub type ListArray = DataArray; pub type StructArray = DataArray; pub type ExtensionArray = DataArray; diff --git a/src/daft-core/src/datatypes/nested_arrays.rs b/src/daft-core/src/datatypes/nested_arrays.rs new file mode 100644 index 0000000000..0a8c94210b --- /dev/null +++ b/src/daft-core/src/datatypes/nested_arrays.rs @@ -0,0 +1,195 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; + +use crate::array::growable::{Growable, GrowableArray}; +use crate::datatypes::{DaftArrayType, Field}; +use crate::series::Series; +use crate::DataType; + +#[derive(Clone)] +pub struct FixedSizeListArray { + pub field: Arc, + pub flat_child: Series, + pub validity: Option, +} + +impl DaftArrayType for FixedSizeListArray {} + +impl FixedSizeListArray { + pub fn new>>( + field: F, + flat_child: Series, + validity: Option, + ) -> Self { + let field: Arc = field.into(); + match &field.as_ref().dtype { + DataType::FixedSizeList(_, size) => { + if let Some(validity) = validity.as_ref() && (validity.len() * size) != flat_child.len() { + panic!( + "FixedSizeListArray::new received values with len {} but expected it to match len of validity * size: {}", + flat_child.len(), + (validity.len() * size), + ) + } + } + _ => panic!( + "FixedSizeListArray::new expected FixedSizeList datatype, but received field: {}", + field + ) + } + FixedSizeListArray { + field, + flat_child, + validity, + } + } + + pub fn concat(arrays: &[&Self]) -> DaftResult { + if arrays.is_empty() { + return Err(DaftError::ValueError( + "Need at least 1 FixedSizeListArray to concat".to_string(), + )); + } + + let first_array = arrays.get(0).unwrap(); + let mut growable = ::make_growable( + first_array.field.name.clone(), + &first_array.field.dtype, + arrays.to_vec(), + arrays + .iter() + .map(|a| a.validity.as_ref().map_or(0usize, |v| v.unset_bits())) + .sum::() + > 0, + arrays.iter().map(|a| a.len()).sum(), + ); + + for (i, arr) in arrays.iter().enumerate() { + growable.extend(i, 0, arr.len()); + } + + growable + .build() + .map(|s| s.downcast::().unwrap().clone()) + } + + pub fn len(&self) -> usize { + self.flat_child.len() / self.fixed_element_len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn name(&self) -> &str { + &self.field.name + } + + pub fn data_type(&self) -> &DataType { + &self.field.dtype + } + + pub fn child_data_type(&self) -> &DataType { + match &self.field.dtype { + DataType::FixedSizeList(child, _) => &child.dtype, + _ => unreachable!("FixedSizeListArray must have DataType::FixedSizeList(..)"), + } + } + + pub fn rename(&self, name: &str) -> Self { + Self::new( + Field::new(name, self.data_type().clone()), + self.flat_child.rename(name), + self.validity.clone(), + ) + } + + pub fn slice(&self, start: usize, end: usize) -> DaftResult { + if start > end { + return Err(DaftError::ValueError(format!( + "Trying to slice array with negative length, start: {start} vs end: {end}" + ))); + } + let size = self.fixed_element_len(); + Ok(Self::new( + self.field.clone(), + self.flat_child.slice(start * size, end * size)?, + self.validity + .as_ref() + .map(|v| v.clone().sliced(start, end - start)), + )) + } + + pub fn to_arrow(&self) -> Box { + let arrow_dtype = self.data_type().to_arrow().unwrap(); + Box::new(arrow2::array::FixedSizeListArray::new( + arrow_dtype, + self.flat_child.to_arrow(), + self.validity.clone(), + )) + } + + pub fn fixed_element_len(&self) -> usize { + let dtype = &self.field.as_ref().dtype; + match dtype { + DataType::FixedSizeList(_, s) => *s, + _ => unreachable!("FixedSizeListArray should always have FixedSizeList datatype"), + } + } +} + +#[cfg(test)] +mod tests { + use common_error::DaftResult; + + use crate::{ + datatypes::{Field, Int32Array}, + DataType, IntoSeries, + }; + + use super::FixedSizeListArray; + + /// Helper that returns a FixedSizeListArray, with each list element at len=3 + fn get_i32_fixed_size_list_array(validity: &[bool]) -> FixedSizeListArray { + let field = Field::new( + "foo", + DataType::FixedSizeList(Box::new(Field::new("foo", DataType::Int32)), 3), + ); + let flat_child = Int32Array::from(( + "foo", + (0i32..(validity.len() * 3) as i32).collect::>(), + )); + FixedSizeListArray::new( + field, + flat_child.into_series(), + Some(arrow2::bitmap::Bitmap::from(validity)), + ) + } + + #[test] + fn test_rename() -> DaftResult<()> { + let arr = get_i32_fixed_size_list_array(vec![true, true, false].as_slice()); + let renamed_arr = arr.rename("bar"); + + assert_eq!(renamed_arr.name(), "bar"); + assert_eq!(renamed_arr.flat_child.len(), arr.flat_child.len()); + assert_eq!( + renamed_arr + .flat_child + .i32()? + .into_iter() + .collect::>(), + arr.flat_child.i32()?.into_iter().collect::>() + ); + assert_eq!( + renamed_arr + .validity + .unwrap() + .into_iter() + .collect::>(), + arr.validity.unwrap().into_iter().collect::>() + ); + Ok(()) + } +} diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index 1f3cfb5183..8f869bb9fb 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -14,9 +14,9 @@ use crate::datatypes::logical::{ ImageArray, TensorArray, TimestampArray, }; use crate::datatypes::{ - BinaryArray, BooleanArray, ExtensionArray, FixedSizeListArray, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, StructArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, Utf8Array, + nested_arrays::FixedSizeListArray, BinaryArray, BooleanArray, ExtensionArray, Float32Array, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, StructArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }; use super::{ArrayWrapper, IntoSeries, Series}; diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index b0c26d9ea2..b6d3cab256 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -13,9 +13,9 @@ use crate::series::array_impl::binary_ops::SeriesBinaryOps; use crate::series::Field; use crate::{ datatypes::{ - BinaryArray, BooleanArray, ExtensionArray, FixedSizeListArray, Float32Array, Float64Array, - Int128Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, - StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, + BinaryArray, BooleanArray, ExtensionArray, Float32Array, Float64Array, Int128Array, + Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, StructArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }, series::series_like::SeriesLike, with_match_integer_daft_types, @@ -61,35 +61,6 @@ fn logical_to_arrow<'a>( } } } - DataType::FixedSizeList(child_field, _size) => { - let downcasted = arr - .as_ref() - .as_any() - .downcast_ref::() - .unwrap(); - let values = Cow::Borrowed(downcasted.values()); - let new_values = logical_to_arrow(values, child_field.as_ref()); - match new_values { - Cow::Borrowed(..) => arr, - Cow::Owned(new_arr) => { - let new_child_field = arrow2::datatypes::Field::new( - field.name.clone(), - new_arr.data_type().clone(), - true, - ); - let new_datatype = - arrow2::datatypes::DataType::LargeList(Box::new(new_child_field)); - Cow::Owned( - arrow2::array::FixedSizeListArray::new( - new_datatype, - new_arr, - arr.validity().cloned(), - ) - .boxed(), - ) - } - } - } DataType::Struct(fields) => { let downcasted = arr .as_ref() @@ -340,7 +311,6 @@ impl_series_like_for_data_array!(UInt64Array); impl_series_like_for_data_array!(Float32Array); impl_series_like_for_data_array!(Float64Array); impl_series_like_for_data_array!(Utf8Array); -impl_series_like_for_data_array!(FixedSizeListArray); impl_series_like_for_data_array!(ListArray); impl_series_like_for_data_array!(StructArray); impl_series_like_for_data_array!(ExtensionArray); diff --git a/src/daft-core/src/series/array_impl/mod.rs b/src/daft-core/src/series/array_impl/mod.rs index f06a669b97..c440c4126a 100644 --- a/src/daft-core/src/series/array_impl/mod.rs +++ b/src/daft-core/src/series/array_impl/mod.rs @@ -1,6 +1,7 @@ pub mod binary_ops; pub mod data_array; pub mod logical_array; +pub mod nested_array; use super::Series; diff --git a/src/daft-core/src/series/array_impl/nested_array.rs b/src/daft-core/src/series/array_impl/nested_array.rs new file mode 100644 index 0000000000..f6a6bab08a --- /dev/null +++ b/src/daft-core/src/series/array_impl/nested_array.rs @@ -0,0 +1,178 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; + +use crate::array::ops::broadcast::Broadcastable; +use crate::array::ops::{DaftIsNull, GroupIndices}; +use crate::datatypes::Field; +use crate::datatypes::{nested_arrays::FixedSizeListArray, BooleanArray}; +use crate::series::{array_impl::binary_ops::SeriesBinaryOps, IntoSeries, Series, SeriesLike}; +use crate::{with_match_integer_daft_types, DataType}; + +use super::ArrayWrapper; + +impl IntoSeries for FixedSizeListArray { + fn into_series(self) -> Series { + Series { + inner: Arc::new(ArrayWrapper(self)), + } + } +} + +impl SeriesLike for ArrayWrapper { + fn into_series(&self) -> Series { + self.0.clone().into_series() + } + + fn to_arrow(&self) -> Box { + self.0.to_arrow() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn min(&self, _groups: Option<&GroupIndices>) -> DaftResult { + Err(DaftError::ValueError( + "FixedSizeList does not support min".to_string(), + )) + } + + fn max(&self, _groups: Option<&GroupIndices>) -> DaftResult { + Err(DaftError::ValueError( + "FixedSizeList does not support max".to_string(), + )) + } + + fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { + use crate::array::ops::DaftListAggable; + + match groups { + Some(groups) => Ok(self.0.grouped_list(groups)?.into_series()), + None => Ok(self.0.list()?.into_series()), + } + } + + fn broadcast(&self, num: usize) -> DaftResult { + Ok(self.0.broadcast(num)?.into_series()) + } + + fn cast(&self, datatype: &DataType) -> DaftResult { + self.0.cast(datatype) + } + + fn filter(&self, mask: &BooleanArray) -> DaftResult { + Ok(self.0.filter(mask)?.into_series()) + } + + fn if_else(&self, other: &Series, predicate: &Series) -> DaftResult { + Ok(self + .0 + .if_else(other.downcast()?, predicate.bool()?)? + .into_series()) + } + + fn data_type(&self) -> &DataType { + self.0.data_type() + } + + fn field(&self) -> &Field { + &self.0.field + } + + fn len(&self) -> usize { + self.0.len() + } + + fn name(&self) -> &str { + self.0.name() + } + + fn rename(&self, name: &str) -> Series { + self.0.rename(name).into_series() + } + + fn size_bytes(&self) -> DaftResult { + self.0.size_bytes() + } + + fn is_null(&self) -> DaftResult { + Ok(self.0.is_null()?.into_series()) + } + + fn sort(&self, _descending: bool) -> DaftResult { + Err(DaftError::ValueError( + "Cannot sort a FixedSizeListArray".to_string(), + )) + } + + fn head(&self, num: usize) -> DaftResult { + self.slice(0, num) + } + + fn slice(&self, start: usize, end: usize) -> DaftResult { + Ok(self.0.slice(start, end)?.into_series()) + } + + fn take(&self, idx: &Series) -> DaftResult { + with_match_integer_daft_types!(idx.data_type(), |$S| { + Ok(self + .0 + .take(idx.downcast::<<$S as DaftDataType>::ArrayType>()?)? + .into_series()) + }) + } + + fn str_value(&self, idx: usize) -> DaftResult { + self.0.str_value(idx) + } + + fn html_value(&self, idx: usize) -> String { + self.0.html_value(idx) + } + + fn add(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::add(self, rhs) + } + fn sub(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::sub(self, rhs) + } + fn mul(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::mul(self, rhs) + } + fn div(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::div(self, rhs) + } + fn rem(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::rem(self, rhs) + } + + fn and(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::and(self, rhs) + } + fn or(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::or(self, rhs) + } + fn xor(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::xor(self, rhs) + } + + fn equal(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::equal(self, rhs) + } + fn not_equal(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::not_equal(self, rhs) + } + fn lt(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::lt(self, rhs) + } + fn lte(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::lte(self, rhs) + } + fn gt(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::gt(self, rhs) + } + fn gte(&self, rhs: &Series) -> DaftResult { + SeriesBinaryOps::gte(self, rhs) + } +} diff --git a/src/daft-core/src/series/from.rs b/src/daft-core/src/series/from.rs index 602d4a10c8..7b533cb233 100644 --- a/src/daft-core/src/series/from.rs +++ b/src/daft-core/src/series/from.rs @@ -1,9 +1,8 @@ use std::sync::Arc; use crate::{ - datatypes::{logical::LogicalArray, DataType, Field}, - with_match_daft_logical_primitive_types, with_match_daft_logical_types, - with_match_physical_daft_types, + datatypes::{DataType, Field}, + with_match_daft_types, }; use common_error::{DaftError, DaftResult}; @@ -21,45 +20,15 @@ impl TryFrom<(&str, Box)> for Series { let dtype: DataType = source_arrow_type.into(); let field = Arc::new(Field::new(name, dtype.clone())); + // TODO(Nested): Refactor this out with nested logical types in StructArray and ListArray + // Corner-case nested logical types that have not yet been migrated to new Array formats + // to hold only casted physical arrow arrays. let physical_type = dtype.to_physical(); - - if dtype.is_logical() { - let arrow_physical_type = physical_type.to_arrow()?; - - use DataType::*; - let physical_arrow_array = match dtype { - // Primitive wrapper types: change the arrow2 array's type field to primitive - Decimal128(..) | Date | Timestamp(..) | Duration(..) => { - with_match_daft_logical_primitive_types!(dtype, |$P| { - use arrow2::array::Array; - array - .as_any() - .downcast_ref::>() - .unwrap() - .clone() - .to(arrow_physical_type) - .to_boxed() - }) - } - // Otherwise, use an Arrow cast to drop Extension types. - _ => arrow2::compute::cast::cast( - array.as_ref(), - &arrow_physical_type, - arrow2::compute::cast::CastOptions { - wrapped: true, - partial: false, - }, - )?, - }; - - let res = with_match_daft_logical_types!(dtype, |$T| { - LogicalArray::<$T>::from_arrow(field.as_ref(), physical_arrow_array)?.into_series() - }); - return Ok(res); - } - - // is not logical but contains one - if physical_type != dtype { + if (matches!(dtype, DataType::List(..)) + || matches!(dtype, DataType::Struct(..)) + || dtype.is_extension()) + && physical_type != dtype + { let arrow_physical_type = physical_type.to_arrow()?; let casted_array = arrow2::compute::cast::cast( array.as_ref(), @@ -70,12 +39,12 @@ impl TryFrom<(&str, Box)> for Series { }, )?; return Ok( - with_match_physical_daft_types!(physical_type, |$T| DataArray::<$T>::from_arrow(field.as_ref(), casted_array)?.into_series()), + with_match_daft_types!(physical_type, |$T| <$T as DaftDataType>::ArrayType::from_arrow(field.as_ref(), casted_array)?.into_series()), ); } - Ok( - with_match_physical_daft_types!(dtype, |$T| DataArray::<$T>::from_arrow(field.as_ref(), array.into())?.into_series()), - ) + with_match_daft_types!(dtype, |$T| { + Ok(<$T as DaftDataType>::ArrayType::from_arrow(&field, array)?.into_series()) + }) } } diff --git a/src/daft-core/src/series/ops/concat.rs b/src/daft-core/src/series/ops/concat.rs index 5812342bdf..3341533099 100644 --- a/src/daft-core/src/series/ops/concat.rs +++ b/src/daft-core/src/series/ops/concat.rs @@ -1,8 +1,5 @@ -use crate::datatypes::logical::LogicalArray; -use crate::{ - series::{IntoSeries, Series}, - with_match_daft_logical_types, with_match_physical_daft_types, -}; +use crate::series::{IntoSeries, Series}; +use crate::with_match_daft_types; use common_error::{DaftError, DaftResult}; impl Series { @@ -27,16 +24,10 @@ impl Series { ))); } } - if first_dtype.is_logical() { - return Ok(with_match_daft_logical_types!(first_dtype, |$T| { - let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::>>()?; - LogicalArray::<$T>::concat(downcasted.as_slice())?.into_series() - })); - } - with_match_physical_daft_types!(first_dtype, |$T| { + with_match_daft_types!(first_dtype, |$T| { let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::>>()?; - Ok(DataArray::<$T>::concat(downcasted.as_slice())?.into_series()) + Ok(<$T as DaftDataType>::ArrayType::concat(downcasted.as_slice())?.into_series()) }) } } diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index 2082130861..7558fee98b 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -1,8 +1,7 @@ use std::marker::PhantomData; use crate::datatypes::*; - -use crate::datatypes::logical::{FixedShapeImageArray, ImageArray}; +use crate::datatypes::{logical::FixedShapeImageArray, nested_arrays::FixedSizeListArray}; use crate::series::array_impl::ArrayWrapper; use crate::series::Series; use common_error::DaftResult; @@ -74,30 +73,6 @@ impl Series { self.downcast() } - // pub fn timestamp(&self) -> DaftResult<&TimestampArray> { - // use crate::datatypes::DataType::*; - // match self.data_type() { - // Timestamp(..) => Ok(self.inner.as_any().downcast_ref().unwrap()), - // t => Err(DaftError::SchemaMismatch(format!("{t:?} not timestamp"))), - // } - // } - - // pub fn date(&self) -> DaftResult<&DateArray> { - // use crate::datatypes::DataType::*; - // match self.data_type() { - // Date => Ok(self.inner.as_any().downcast_ref().unwrap()), - // t => Err(DaftError::SchemaMismatch(format!("{t:?} not date"))), - // } - // } - - // pub fn time(&self) -> DaftResult<&TimeArray> { - // use crate::datatypes::DataType::*; - // match self.data_type() { - // Time(..) => Ok(self.inner.as_any().downcast_ref().unwrap()), - // t => Err(DaftError::SchemaMismatch(format!("{t:?} not time"))), - // } - // } - pub fn binary(&self) -> DaftResult<&BinaryArray> { self.downcast() } @@ -118,10 +93,6 @@ impl Series { self.downcast() } - pub fn image(&self) -> DaftResult<&ImageArray> { - self.downcast() - } - pub fn fixed_size_image(&self) -> DaftResult<&FixedShapeImageArray> { self.downcast() } diff --git a/src/daft-core/src/series/ops/image.rs b/src/daft-core/src/series/ops/image.rs index 8f8791a93e..098e873941 100644 --- a/src/daft-core/src/series/ops/image.rs +++ b/src/daft-core/src/series/ops/image.rs @@ -63,7 +63,10 @@ impl Series { let bbox = bbox.fixed_size_list()?; match &self.data_type() { - DataType::Image(_) => self.image()?.crop(bbox).map(|arr| arr.into_series()), + DataType::Image(_) => self + .downcast::()? + .crop(bbox) + .map(|arr| arr.into_series()), DataType::FixedShapeImage(..) => self .fixed_size_image()? .crop(bbox) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 752db9e070..06a6d8cfeb 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -4,16 +4,16 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter, Result}; use daft_core::array::ops::full::FullNull; +use daft_core::with_match_daft_types; use num_traits::ToPrimitive; use daft_core::array::ops::GroupIndices; use common_error::{DaftError, DaftResult}; -use daft_core::datatypes::logical::LogicalArray; use daft_core::datatypes::{BooleanArray, DataType, Field, UInt64Array}; use daft_core::schema::{Schema, SchemaRef}; use daft_core::series::{IntoSeries, Series}; -use daft_core::{with_match_daft_logical_types, with_match_physical_daft_types}; + use daft_dsl::functions::FunctionEvaluator; use daft_dsl::{col, null_lit, AggExpr, Expr}; #[cfg(feature = "python")] @@ -73,15 +73,10 @@ impl Table { Some(schema) => { let mut columns: Vec = Vec::with_capacity(schema.names().len()); for (field_name, field) in schema.fields.iter() { - if field.dtype.is_logical() { - with_match_daft_logical_types!(field.dtype, |$T| { - columns.push(LogicalArray::<$T>::empty(field_name, &field.dtype).into_series()) - }) - } else { - with_match_physical_daft_types!(field.dtype, |$T| { - columns.push(DataArray::<$T>::empty(field_name, &field.dtype).into_series()) - }) - } + let series = with_match_daft_types!(&field.dtype, |$T| { + <$T as DaftDataType>::ArrayType::empty(field_name, &field.dtype).into_series() + }); + columns.push(series) } Ok(Table { schema, columns }) } diff --git a/tests/series/test_slice.py b/tests/series/test_slice.py index df63137362..63cec2e9df 100644 --- a/tests/series/test_slice.py +++ b/tests/series/test_slice.py @@ -28,7 +28,6 @@ def test_series_slice_list_array(fixed) -> None: data = pa.array([[10, 20], [33, None], [43, 45], None, [50, 52], None], type=dtype) s = Series.from_arrow(data) - result = s.slice(2, 4) assert result.datatype() == s.datatype() assert len(result) == 2 diff --git a/tests/table/test_broadcasts.py b/tests/table/test_broadcasts.py index 29133e3932..196cc64dd4 100644 --- a/tests/table/test_broadcasts.py +++ b/tests/table/test_broadcasts.py @@ -2,12 +2,22 @@ import pytest +import daft from daft.expressions import col, lit from daft.table import Table -@pytest.mark.parametrize("data", [1, "a", True, b"Y", 0.5, None, object()]) +@pytest.mark.parametrize("data", [1, "a", True, b"Y", 0.5, None, [1, 2, 3], object()]) def test_broadcast(data): table = Table.from_pydict({"x": [1, 2, 3]}) new_table = table.eval_expression_list([col("x"), lit(data)]) assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]} + + +def test_broadcast_fixed_size_list(): + data = [1, 2, 3] + table = Table.from_pydict({"x": [1, 2, 3]}) + new_table = table.eval_expression_list( + [col("x"), lit(data).cast(daft.DataType.fixed_size_list("foo", daft.DataType.int64(), 3))] + ) + assert new_table.to_pydict() == {"x": [1, 2, 3], "literal": [data for _ in range(3)]} diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index fef7fbb038..a621ca2d99 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -536,7 +536,6 @@ def test_grouped_concat_aggs(dtype) -> None: input = [[x] for x in input] + [None] groups = [1, 2, 3, 4, 5, 6, 7] - daft_table = Table.from_pydict({"groups": groups, "input": input}).eval_expression_list( [col("groups"), col("input").cast(DataType.list("item", dtype))] ) diff --git a/tests/table/test_take.py b/tests/table/test_take.py index 7ba637db6e..03ec2b1dcf 100644 --- a/tests/table/test_take.py +++ b/tests/table/test_take.py @@ -154,3 +154,43 @@ def test_table_take_pyobject() -> None: assert taken.column_names() == ["objs"] assert taken.to_pydict()["objs"] == [objects[3], objects[2], objects[2], objects[2], objects[3]] + + +@pytest.mark.parametrize("idx_dtype", daft_int_types) +def test_table_take_fixed_size_list(idx_dtype) -> None: + pa_table = pa.Table.from_pydict( + { + "a": pa.array([[1, 2], [3, None], None, [None, None]], type=pa.list_(pa.int64(), 2)), + "b": pa.array([[4, 5], [6, None], None, [None, None]], type=pa.list_(pa.int64(), 2)), + } + ) + daft_table = Table.from_arrow(pa_table) + assert len(daft_table) == 4 + assert daft_table.column_names() == ["a", "b"] + + indices = Series.from_pylist([0, 1]).cast(idx_dtype) + + taken = daft_table.take(indices) + assert len(taken) == 2 + assert taken.column_names() == ["a", "b"] + + assert taken.to_pydict() == {"a": [[1, 2], [3, None]], "b": [[4, 5], [6, None]]} + + indices = Series.from_pylist([3, 2]).cast(idx_dtype) + + taken = daft_table.take(indices) + assert len(taken) == 2 + assert taken.column_names() == ["a", "b"] + + assert taken.to_pydict() == {"a": [[None, None], None], "b": [[None, None], None]} + + indices = Series.from_pylist([3, 2, 2, 2, 3]).cast(idx_dtype) + + taken = daft_table.take(indices) + assert len(taken) == 5 + assert taken.column_names() == ["a", "b"] + + assert taken.to_pydict() == { + "a": [[None, None], None, None, None, [None, None]], + "b": [[None, None], None, None, None, [None, None]], + }