Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcasale committed Apr 26, 2024
1 parent 41db451 commit 8f8685a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 14 deletions.
22 changes: 11 additions & 11 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::ptype::{NativePType, PType};
use crate::scalar;
use crate::validity::{ArrayValidity, LogicalValidity, Validity, ValidityMetadata};
use crate::visitor::{AcceptArrayVisitor, ArrayVisitor};
use crate::{impl_encoding, ArrayDType, OwnedArray, ToStatic};
use crate::{impl_encoding, ArrayDType, OwnedArray};
use crate::{match_each_native_ptype, ArrayFlatten};

mod accessor;
Expand Down Expand Up @@ -195,23 +195,23 @@ impl<'a> Array<'a> {
impl EncodingCompression for PrimitiveEncoding {}

impl ScalarSubtractFn for PrimitiveArray<'_> {
fn scalar_subtract(&self, summand: Scalar) -> VortexResult<OwnedArray> {
if self.dtype() != summand.dtype() {
vortex_bail!("MismatchedTypes: {}, {}", self.dtype(), summand.dtype())
fn scalar_subtract(&self, to_subtract: Scalar) -> VortexResult<OwnedArray> {
if self.dtype() != to_subtract.dtype() {
vortex_bail!(MismatchedTypes: self.dtype(), to_subtract.dtype())
}
match summand.dtype() {
match to_subtract.dtype() {
DType::Int(..) => {}
DType::Decimal(..) => {}
DType::Float(..) => {}
DType::Utf8(_) => {}
_ => vortex_bail!(InvalidArgument: "summand must be a numeric type"),
_ => vortex_bail!(InvalidArgument: "Can only subtract numeric types"),
}

let summed = match_each_native_ptype!(self.ptype(), |$T| {
let summand = <scalar::Scalar as TryInto<$T>>::try_into(summand)?;
let sum_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - summand).collect_vec();
PrimitiveArray::from(sum_vec)
let result = match_each_native_ptype!(self.ptype(), |$T| {
let to_subtract = <scalar::Scalar as TryInto<$T>>::try_into(to_subtract)?;
let sub_vec : Vec<$T> = self.typed_data::<$T>().iter().map(|&v| v - to_subtract).collect_vec();
PrimitiveArray::from(sub_vec)
});
Ok(summed.to_array().to_static())
Ok(result.into_array())
}
}
73 changes: 70 additions & 3 deletions vortex-array/src/compute/scalar_subtract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::scalar::Scalar;
use crate::{Array, OwnedArray};

pub trait ScalarSubtractFn {
fn scalar_subtract(&self, summand: Scalar) -> VortexResult<OwnedArray>;
fn scalar_subtract(&self, to_subtract: Scalar) -> VortexResult<OwnedArray>;
}

pub fn scalar_subtract(array: &Array, summand: Scalar) -> VortexResult<OwnedArray> {
pub fn scalar_subtract(array: &Array, to_subtract: Scalar) -> VortexResult<OwnedArray> {
array.with_dyn(|c| {
c.scalar_subtract()
.map(|t| t.scalar_subtract(summand.clone()))
.map(|t| t.scalar_subtract(to_subtract.clone()))
.unwrap_or_else(|| {
Err(vortex_err!(
NotImplemented: "scalar_subtract",
Expand All @@ -19,3 +19,70 @@ pub fn scalar_subtract(array: &Array, summand: Scalar) -> VortexResult<OwnedArra
})
})
}

#[cfg(test)]
mod test {
use super::*;
use crate::IntoArray;

#[test]
fn test_scalar_subtract_unsigned() {
let values = vec![1u16, 2, 3].into_array();
let results = scalar_subtract(&values, 1u16.into())
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<u16>()
.to_vec();
assert_eq!(results, &[0u16, 1, 2]);
}

#[test]
#[should_panic]
fn test_scalar_subtract_unsigned_wrapping() {
let values = vec![0u16, 2, 3].into_array();
let _results = scalar_subtract(&values, 1u16.into()).unwrap();
}

#[test]
#[should_panic]
fn test_scalar_subtract_signed_wrapping() {
let values = vec![i16::MIN, 2, 3].into_array();
let _results = scalar_subtract(&values, 1u16.into()).unwrap();
}

#[test]
fn test_scalar_subtract_signed() {
let values = vec![1i64, 2, 3].into_array();
let to_subtract = -1i64;
let results = scalar_subtract(&values, to_subtract.into())
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<i64>()
.to_vec();
assert_eq!(results, &[2i64, 3, 4]);
}

#[test]
fn test_scalar_subtract_float() {
let values = vec![1.0f64, 2.0, 3.0].into_array();
let to_subtract = -1f64;
let results = scalar_subtract(&values, to_subtract.into())
.unwrap()
.flatten_primitive()
.unwrap()
.typed_data::<f64>()
.to_vec();
assert_eq!(results, &[2.0f64, 3.0, 4.0]);
}

#[test]
fn test_scalar_subtract_type_mismatch_fails() {
let values = vec![1.0f64, 2.0, 3.0].into_array();
// Subtracting non-equivalent dtypes should fail
let to_subtract = 1u64;
let _results =
scalar_subtract(&values, to_subtract.into()).expect_err("Expected type mismatch error");
}
}

0 comments on commit 8f8685a

Please sign in to comment.