Skip to content

Commit

Permalink
recursively apply
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcasale committed Apr 29, 2024
1 parent a0dcb70 commit d89a25f
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 130 deletions.
1 change: 1 addition & 0 deletions vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ walkdir = { workspace = true }
[dev-dependencies]
criterion = { workspace = true }


[[bench]]
name = "search_sorted"
harness = false
5 changes: 5 additions & 0 deletions vortex-array/src/array/chunked/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use vortex_error::VortexResult;
use crate::array::chunked::ChunkedArray;
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::scalar_subtract::ScalarSubtractFn;
use crate::compute::take::TakeFn;
use crate::compute::ArrayCompute;
use crate::scalar::Scalar;
Expand All @@ -22,6 +23,10 @@ impl ArrayCompute for ChunkedArray<'_> {
fn take(&self) -> Option<&dyn TakeFn> {
Some(self)
}

fn scalar_subtract(&self) -> Option<&dyn ScalarSubtractFn> {
Some(self)
}
}

impl AsContiguousFn for ChunkedArray<'_> {
Expand Down
39 changes: 38 additions & 1 deletion vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use vortex_error::{vortex_bail, VortexResult};
use vortex_schema::{IntWidth, Nullability, Signedness};

use crate::array::primitive::PrimitiveArray;
use crate::compute::as_contiguous::as_contiguous;
use crate::compute::scalar_at::scalar_at;
use crate::compute::scalar_subtract::{scalar_subtract, ScalarSubtractFn};
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::validity::Validity::NonNullable;
use crate::validity::{ArrayValidity, LogicalValidity};
Expand Down Expand Up @@ -143,13 +145,25 @@ impl ArrayValidity for ChunkedArray<'_> {

impl EncodingCompression for ChunkedEncoding {}

impl ScalarSubtractFn for ChunkedArray<'_> {
fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult<OwnedArray> {
as_contiguous(
&self
.chunks()
.map(|c| scalar_subtract(&c, to_subtract.clone()).unwrap())
.collect_vec(),
)
}
}

#[cfg(test)]
mod test {
use vortex_schema::{DType, IntWidth, Nullability, Signedness};

use crate::array::chunked::{ChunkedArray, OwnedChunkedArray};
use crate::compute::scalar_subtract::scalar_subtract;
use crate::ptype::NativePType;
use crate::{Array, IntoArray};
use crate::{Array, IntoArray, ToArray, ToStatic};

#[allow(dead_code)]
fn chunked_array() -> OwnedChunkedArray {
Expand Down Expand Up @@ -179,6 +193,29 @@ mod test {
assert_eq!(values, slice);
}

#[test]
fn test_scalar_subtract() {
let chunk1 = vec![1.0f64, 2.0, 3.0].into_array();
let chunk2 = vec![4.0f64, 5.0, 6.0].into_array();
let to_subtract = -1f64;

let chunked = ChunkedArray::try_new(
vec![chunk1, chunk2],
DType::Float(64.into(), Nullability::NonNullable),
)
.unwrap()
.to_array()
.to_static();

let array = scalar_subtract(&chunked, to_subtract).unwrap();
let results = array
.flatten_primitive()
.unwrap()
.typed_data::<f64>()
.to_vec();
assert_eq!(results, &[2.0f64, 3.0, 4.0, 5.0, 6.0, 7.0]);
}

// FIXME(ngates): bring back when slicing is a compute function.
// #[test]
// pub fn slice_middle() {
Expand Down
135 changes: 124 additions & 11 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use vortex_error::{vortex_bail, VortexResult};

use crate::buffer::Buffer;
use crate::compute::scalar_subtract::ScalarSubtractFn;
use crate::match_each_integer_ptype;
use crate::ptype::{NativePType, PType};
use crate::stats::ArrayStatistics;
use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata};
use crate::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::{impl_encoding, ArrayDType, OwnedArray};
use crate::{match_each_integer_ptype, scalar};
use crate::{impl_encoding, match_each_float_ptype, ArrayDType, OwnedArray};
use crate::{match_each_native_ptype, ArrayFlatten};

mod accessor;
Expand Down Expand Up @@ -196,34 +196,36 @@ impl<'a> Array<'a> {
impl EncodingCompression for PrimitiveEncoding {}

impl ScalarSubtractFn for PrimitiveArray<'_> {
fn scalar_subtract(&self, to_subtract: Scalar) -> VortexResult<OwnedArray> {
fn scalar_subtract(&self, to_subtract: &Scalar) -> VortexResult<OwnedArray> {
if self.dtype() != to_subtract.dtype() {
vortex_bail!(MismatchedTypes: self.dtype(), to_subtract.dtype())
}

let result = match to_subtract.dtype() {
DType::Int(..) => {
match_each_integer_ptype!(self.ptype(), |$T| {
let to_subtract = <scalar::Scalar as TryInto<$T>>::try_into(to_subtract)?;
let maybe_min = self.statistics().compute_as_cast(Stat::Min);//.unwrap_or($T::MAX);
let to_subtract = $T::try_from(to_subtract)?;
let maybe_min = self.statistics().compute_as_cast(Stat::Min);

if maybe_min.is_some() {
let min: $T = maybe_min.unwrap();
let max: $T = self.statistics().compute_as_cast(Stat::Max).unwrap();
if let Some(min) = maybe_min {
let min: $T = min;
if let (min, true) = min.overflowing_sub(to_subtract) {
vortex_bail!("Integer subtraction over/underflow: {}, {}", min, to_subtract)
}
if let (max, true) = max.overflowing_sub(to_subtract) {
if let Some(max) = self.statistics().compute_as_cast(Stat::Max) {
let max: $T = max;
if let (max, true) = max.overflowing_sub(to_subtract) {
vortex_bail!("Integer subtraction over/underflow: {}, {}", max, to_subtract)
}
}
}
let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec();
PrimitiveArray::from(sub_vec)
})
}
DType::Decimal(..) | DType::Float(..) => {
match_each_native_ptype!(self.ptype(), |$T| {
let to_subtract = <scalar::Scalar as TryInto<$T>>::try_into(to_subtract)?;
match_each_float_ptype!(self.ptype(), |$T| {
let to_subtract = $T::try_from(to_subtract)?;
let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec();
PrimitiveArray::from(sub_vec)
})
Expand All @@ -234,3 +236,114 @@ impl ScalarSubtractFn for PrimitiveArray<'_> {
Ok(result.into_array())
}
}

#[cfg(test)]
mod test {
use crate::compute::scalar_subtract::scalar_subtract;
use crate::IntoArray;

#[test]
fn test_scalar_subtract_unsigned() {
let values = vec![1u16, 2, 3].into_array();
let results = scalar_subtract(&values, 1u16)
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<u16>()
.to_vec();
assert_eq!(results, &[0u16, 1, 2]);
}

#[test]
fn test_scalar_subtract_signed() {
let values = vec![1i64, 2, 3].into_array();
let results = scalar_subtract(&values, -1i64)
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<i64>()
.to_vec();
assert_eq!(results, &[2i64, 3, 4]);
}

#[test]
fn test_scalar_subtract_float() {
let values = vec![1.0f64, 2.0, 3.0].into_array();
let to_subtract = -1f64;
let results = scalar_subtract(&values, to_subtract)
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<f64>()
.to_vec();
assert_eq!(results, &[2.0f64, 3.0, 4.0]);
}

#[test]
fn test_scalar_subtract_int_from_float() {
let values = vec![3.0f64, 4.0, 5.0].into_array();
// Ints can be cast to floats, so there's no problem here
let results = scalar_subtract(&values, 1u64)
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<f64>()
.to_vec();
assert_eq!(results, &[2.0f64, 3.0, 4.0]);
}

#[test]
fn test_scalar_subtract_unsigned_underflow() {
let values = vec![u8::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1u8).expect_err("should fail with underflow");
let values = vec![u16::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1u16).expect_err("should fail with underflow");
let values = vec![u32::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1u32).expect_err("should fail with underflow");
let values = vec![u64::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1u64).expect_err("should fail with underflow");
}

#[test]
fn test_scalar_subtract_signed_overflow() {
let values = vec![i8::MAX, 2, 3].into_array();
let to_subtract = -1i8;
let _results =
scalar_subtract(&values, to_subtract).expect_err("should fail with overflow");
let values = vec![i16::MAX, 2, 3].into_array();
let _results =
scalar_subtract(&values, to_subtract).expect_err("should fail with overflow");
let values = vec![i32::MAX, 2, 3].into_array();
let _results =
scalar_subtract(&values, to_subtract).expect_err("should fail with overflow");
let values = vec![i64::MAX, 2, 3].into_array();
let _results =
scalar_subtract(&values, to_subtract).expect_err("should fail with overflow");
}

#[test]
fn test_scalar_subtract_signed_underflow() {
let values = vec![i8::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1i8).expect_err("should fail with underflow");
let values = vec![i16::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1i16).expect_err("should fail with underflow");
let values = vec![i32::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1i32).expect_err("should fail with underflow");
let values = vec![i64::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1i64).expect_err("should fail with underflow");
}

#[test]
fn test_scalar_subtract_float_underflow_is_ok() {
let values = vec![f32::MIN, 2.0, 3.0].into_array();
let _results = scalar_subtract(&values, 1.0f32).unwrap();
let _results = scalar_subtract(&values, f32::MAX).unwrap();
}

#[test]
fn test_scalar_subtract_type_mismatch_fails() {
let values = vec![1u64, 2, 3].into_array();
// Subtracting incompatible dtypes should fail
let _results = scalar_subtract(&values, 1.5f64).expect_err("Expected type mismatch error");
}
}
Loading

0 comments on commit d89a25f

Please sign in to comment.