diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0157caf8c296..0ef6532da477 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -42,14 +42,6 @@ runs: - name: Generate lockfile shell: bash run: cargo fetch - - name: Cache Rust dependencies - uses: actions/cache@v3 - with: - # these represent compiled steps of both dependencies and arrow - # and thus are specific for a particular OS, arch and rust version. - path: /github/home/target - key: ${{ runner.os }}-${{ runner.arch }}-target-cache3-${{ inputs.rust-version }}-${{ hashFiles('**/Cargo.lock') }} - restore-keys: ${{ runner.os }}-${{ runner.arch }}-target-cache3-${{ inputs.rust-version }}- - name: Install Build Dependencies shell: bash run: | diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml new file mode 100644 index 000000000000..ddf3c21dab31 --- /dev/null +++ b/.github/workflows/arrow_flight.yml @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +name: "Arrow Flight" + +on: + pull_request: + +jobs: + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Tests with default features + run: | + cargo test -p arrow-flight + - name: Tests with all features + run: | + cargo test -p arrow-flight --all-features diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml new file mode 100644 index 000000000000..9f9c6743cb9e --- /dev/null +++ b/.github/workflows/parquet_derive.yml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +name: "Parquet Derive" + +on: + pull_request: + +jobs: + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Test crate + run: | + cargo test -p parquet_derive diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 15e6741476b0..34b786719d20 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -73,9 +73,6 @@ jobs: cargo check -p arrow --all-targets cargo check -p arrow --no-default-features --all-targets cargo check -p arrow --no-default-features --all-targets --features test_utils - - name: Re-run tests on arrow-flight with all features - run: | - cargo test -p arrow-flight --all-features - name: Re-run tests on parquet crate with all features run: | cargo test -p parquet --all-features @@ -90,9 +87,6 @@ jobs: cargo check -p parquet --all-targets cargo check -p parquet --no-default-features --all-targets cargo check -p parquet --no-default-features --features arrow --all-targets - - name: Test compilation of parquet_derive macro with different feature combinations - run: | - cargo check -p parquet_derive # test the --features "simd" of the arrow crate. This requires nightly. linux-test-simd: diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 5344e160c09b..7733ce67a76e 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -24,8 +24,7 @@ //! use crate::array::*; -use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer}; -use crate::compute::binary_boolean_kernel; +use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer}; use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type, @@ -37,171 +36,74 @@ use crate::datatypes::{ use crate::error::{ArrowError, Result}; use crate::util::bit_util; use regex::{escape, Regex}; -use std::any::type_name; use std::collections::HashMap; -/// Helper function to perform boolean lambda function on values from two arrays, this +/// Helper function to perform boolean lambda function on values from two array accessors, this /// version does not attempt to use SIMD. -macro_rules! compare_op { - ($left: expr, $right:expr, $op:expr) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let null_bit_buffer = - combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?; - - // Safety: - // `i < $left.len()` and $left.len() == $right.len() - let comparison = (0..$left.len()) - .map(|i| unsafe { $op($left.value_unchecked(i), $right.value_unchecked(i)) }); - // same size as $left.len() and $right.len() - let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(buffer)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; -} +fn compare_op(left: T, right: T, op: F) -> Result +where + F: Fn(T::Item, T::Item) -> bool, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } -macro_rules! compare_op_primitive { - ($left: expr, $right:expr, $op:expr) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } + let null_bit_buffer = + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; - let null_bit_buffer = - combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?; - - let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8); - let lhs_chunks_iter = $left.values().chunks_exact(8); - let lhs_remainder = lhs_chunks_iter.remainder(); - let rhs_chunks_iter = $right.values().chunks_exact(8); - let rhs_remainder = rhs_chunks_iter.remainder(); - let chunks = $left.len() / 8; - - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .zip(rhs_chunks_iter) - .for_each(|((byte, lhs), rhs)| { - lhs.iter() - .zip(rhs.iter()) - .enumerate() - .for_each(|(i, (&lhs, &rhs))| { - *byte |= if $op(lhs, rhs) { 1 << i } else { 0 }; - }); - }); + // Safety: + // `i < $left.len()` and $left.len() == $right.len() + let comparison = (0..left.len()) + .map(|i| unsafe { op(left.value_unchecked(i), right.value_unchecked(i)) }); + // same size as $left.len() and $right.len() + let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder - .iter() - .zip(rhs_remainder.iter()) - .enumerate() - .for_each(|(i, (&lhs, &rhs))| { - *last |= if $op(lhs, rhs) { 1 << i } else { 0 }; - }); - }; - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(values)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ) + }; + Ok(BooleanArray::from(data)) } -macro_rules! compare_op_scalar { - ($left:expr, $op:expr) => {{ - let null_bit_buffer = $left - .data() - .null_buffer() - .map(|b| b.bit_slice($left.offset(), $left.len())); - - // Safety: - // `i < $left.len()` - let comparison = - (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) }); - // same as $left.len() - let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(buffer)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; -} +/// Helper function to perform boolean lambda function on values from array accessor, this +/// version does not attempt to use SIMD. +fn compare_op_scalar(left: T, op: F) -> Result +where + F: Fn(T::Item) -> bool, +{ + let null_bit_buffer = left + .data() + .null_buffer() + .map(|b| b.bit_slice(left.offset(), left.len())); -macro_rules! compare_op_scalar_primitive { - ($left: expr, $right:expr, $op:expr) => {{ - let null_bit_buffer = $left - .data() - .null_buffer() - .map(|b| b.bit_slice($left.offset(), $left.len())); - - let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8); - let lhs_chunks_iter = $left.values().chunks_exact(8); - let lhs_remainder = lhs_chunks_iter.remainder(); - let chunks = $left.len() / 8; - - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .for_each(|(byte, chunk)| { - chunk.iter().enumerate().for_each(|(i, &c_i)| { - *byte |= if $op(c_i, $right) { 1 << i } else { 0 }; - }); - }); - if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| { - *last |= if $op(lhs, $right) { 1 << i } else { 0 }; - }); - }; + // Safety: + // `i < $left.len()` + let comparison = (0..left.len()).map(|i| unsafe { op(left.value_unchecked(i)) }); + // same as $left.len() + let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(values)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ) + }; + Ok(BooleanArray::from(data)) } /// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified @@ -215,7 +117,7 @@ where T: ArrowNumericType, F: Fn(T::Native, T::Native) -> bool, { - compare_op_primitive!(left, right, op) + compare_op(left, right, op) } /// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using @@ -229,7 +131,7 @@ where T: ArrowNumericType, F: Fn(T::Native, T::Native) -> bool, { - compare_op_scalar_primitive!(left, right, op) + compare_op_scalar(left, |l| op(l, right)) } fn is_like_pattern(c: char) -> bool { @@ -769,7 +671,7 @@ pub fn eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a == b) + compare_op(left, right, |a, b| a == b) } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -777,66 +679,37 @@ pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a == right) -} - -#[inline] -fn binary_boolean_op( - left: &BooleanArray, - right: &BooleanArray, - op: F, -) -> Result -where - F: Copy + Fn(u64, u64) -> u64, -{ - binary_boolean_kernel( - left, - right, - |left: &Buffer, - left_offset_in_bits: usize, - right: &Buffer, - right_offset_in_bits: usize, - len_in_bits: usize| { - bitwise_bin_op_helper( - left, - left_offset_in_bits, - right, - right_offset_in_bits, - len_in_bits, - op, - ) - }, - ) + compare_op_scalar(left, |a| a == right) } /// Perform `left == right` operation on [`BooleanArray`] pub fn eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| !(a ^ b)) + compare_op(left, right, |a, b| !(a ^ b)) } /// Perform `left != right` operation on [`BooleanArray`] pub fn neq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| (a ^ b)) + compare_op(left, right, |a, b| (a ^ b)) } /// Perform `left < right` operation on [`BooleanArray`] pub fn lt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| ((!a) & b)) + compare_op(left, right, |a, b| ((!a) & b)) } /// Perform `left <= right` operation on [`BooleanArray`] pub fn lt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| !(a & (!b))) + compare_op(left, right, |a, b| !(a & (!b))) } /// Perform `left > right` operation on [`BooleanArray`] pub fn gt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| (a & (!b))) + compare_op(left, right, |a, b| (a & (!b))) } /// Perform `left >= right` operation on [`BooleanArray`] pub fn gt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| !((!a) & b)) + compare_op(left, right, |a, b| !((!a) & b)) } /// Perform `left == right` operation on [`BooleanArray`] and a scalar @@ -870,22 +743,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result /// Perform `left < right` operation on [`BooleanArray`] and a scalar pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a: bool| !a & right) + compare_op_scalar(left, |a: bool| !a & right) } /// Perform `left <= right` operation on [`BooleanArray`] and a scalar pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a| a <= right) + compare_op_scalar(left, |a| a <= right) } /// Perform `left > right` operation on [`BooleanArray`] and a scalar pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a: bool| a & !right) + compare_op_scalar(left, |a: bool| a & !right) } /// Perform `left >= right` operation on [`BooleanArray`] and a scalar pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a| a >= right) + compare_op_scalar(left, |a| a >= right) } /// Perform `left != right` operation on [`BooleanArray`] and a scalar @@ -898,7 +771,7 @@ pub fn eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a == b) + compare_op(left, right, |a, b| a == b) } /// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar @@ -906,7 +779,7 @@ pub fn eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a == right) + compare_op_scalar(left, |a| a == right) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -914,7 +787,7 @@ pub fn neq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a != b) + compare_op(left, right, |a, b| a != b) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -922,7 +795,7 @@ pub fn neq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a != right) + compare_op_scalar(left, |a| a != right) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -930,7 +803,7 @@ pub fn lt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a < b) + compare_op(left, right, |a, b| a < b) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -938,7 +811,7 @@ pub fn lt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a < right) + compare_op_scalar(left, |a| a < right) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -946,7 +819,7 @@ pub fn lt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a <= b) + compare_op(left, right, |a, b| a <= b) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -954,7 +827,7 @@ pub fn lt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a <= right) + compare_op_scalar(left, |a| a <= right) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -962,7 +835,7 @@ pub fn gt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a > b) + compare_op(left, right, |a, b| a > b) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -970,7 +843,7 @@ pub fn gt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a > right) + compare_op_scalar(left, |a| a > right) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -978,7 +851,7 @@ pub fn gt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a >= b) + compare_op(left, right, |a, b| a >= b) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -986,7 +859,7 @@ pub fn gt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a >= right) + compare_op_scalar(left, |a| a >= right) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -994,7 +867,7 @@ pub fn neq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a != b) + compare_op(left, right, |a, b| a != b) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1002,7 +875,7 @@ pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a != right) + compare_op_scalar(left, |a| a != right) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1010,7 +883,7 @@ pub fn lt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a < b) + compare_op(left, right, |a, b| a < b) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1018,7 +891,7 @@ pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a < right) + compare_op_scalar(left, |a| a < right) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1026,7 +899,7 @@ pub fn lt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a <= b) + compare_op(left, right, |a, b| a <= b) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1034,7 +907,7 @@ pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a <= right) + compare_op_scalar(left, |a| a <= right) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1042,7 +915,7 @@ pub fn gt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a > b) + compare_op(left, right, |a, b| a > b) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1050,7 +923,7 @@ pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a > right) + compare_op_scalar(left, |a| a > right) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1058,7 +931,7 @@ pub fn gt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a >= b) + compare_op(left, right, |a, b| a >= b) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1066,7 +939,7 @@ pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a >= right) + compare_op_scalar(left, |a| a >= right) } /// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. @@ -1931,177 +1804,107 @@ where Ok(BooleanArray::from(data)) } -macro_rules! typed_cmp { - ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{ - let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Left array cannot be cast to {}", - type_name::<$T>() - )) - })?; - let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Right array cannot be cast to {}", - type_name::<$T>(), - )) - })?; - $OP(left, right) - }}; - ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{ - let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Left array cannot be cast to {}", - type_name::<$T>() - )) - })?; - let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Right array cannot be cast to {}", - type_name::<$T>(), - )) - })?; - $OP::<$TT>(left, right) - }}; +fn cmp_primitive_array( + left: &dyn Array, + right: &dyn Array, + op: F, +) -> Result +where + F: Fn(T::Native, T::Native) -> bool, +{ + let left_array = as_primitive_array::(left); + let right_array = as_primitive_array::(right); + compare_op(left_array, right_array, op) } macro_rules! typed_compares { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident, $OP_BINARY: ident) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Boolean, DataType::Boolean) => { - typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL) + compare_op(as_boolean_array($LEFT), as_boolean_array($RIGHT), $OP_BOOL) } (DataType::Int8, DataType::Int8) => { - typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Int16, DataType::Int16) => { - typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Int32, DataType::Int32) => { - typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Int64, DataType::Int64) => { - typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt8, DataType::UInt8) => { - typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt16, DataType::UInt16) => { - typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt32, DataType::UInt32) => { - typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt64, DataType::UInt64) => { - typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Float32, DataType::Float32) => { - typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Float64, DataType::Float64) => { - typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Utf8, DataType::Utf8) => { - typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64) - } - (DataType::Binary, DataType::Binary) => { - typed_cmp!($LEFT, $RIGHT, BinaryArray, $OP_BINARY, i32) - } - (DataType::LargeBinary, DataType::LargeBinary) => { - typed_cmp!($LEFT, $RIGHT, LargeBinaryArray, $OP_BINARY, i64) + compare_op(as_string_array($LEFT), as_string_array($RIGHT), $OP) } + (DataType::LargeUtf8, DataType::LargeUtf8) => compare_op( + as_largestring_array($LEFT), + as_largestring_array($RIGHT), + $OP, + ), + (DataType::Binary, DataType::Binary) => compare_op( + as_generic_binary_array::($LEFT), + as_generic_binary_array::($RIGHT), + $OP, + ), + (DataType::LargeBinary, DataType::LargeBinary) => compare_op( + as_generic_binary_array::($LEFT), + as_generic_binary_array::($RIGHT), + $OP, + ), ( DataType::Timestamp(TimeUnit::Nanosecond, _), DataType::Timestamp(TimeUnit::Nanosecond, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampNanosecondArray, - $OP_PRIM, - TimestampNanosecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Timestamp(TimeUnit::Microsecond, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampMicrosecondArray, - $OP_PRIM, - TimestampMicrosecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Timestamp(TimeUnit::Millisecond, _), DataType::Timestamp(TimeUnit::Millisecond, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampMillisecondArray, - $OP_PRIM, - TimestampMillisecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Timestamp(TimeUnit::Second, _), DataType::Timestamp(TimeUnit::Second, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampSecondArray, - $OP_PRIM, - TimestampSecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), (DataType::Date32, DataType::Date32) => { - typed_cmp!($LEFT, $RIGHT, Date32Array, $OP_PRIM, Date32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Date64, DataType::Date64) => { - typed_cmp!($LEFT, $RIGHT, Date64Array, $OP_PRIM, Date64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } ( DataType::Interval(IntervalUnit::YearMonth), DataType::Interval(IntervalUnit::YearMonth), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - IntervalYearMonthArray, - $OP_PRIM, - IntervalYearMonthType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Interval(IntervalUnit::DayTime), DataType::Interval(IntervalUnit::DayTime), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - IntervalDayTimeArray, - $OP_PRIM, - IntervalDayTimeType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Interval(IntervalUnit::MonthDayNano), DataType::Interval(IntervalUnit::MonthDayNano), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - IntervalMonthDayNanoArray, - $OP_PRIM, - IntervalMonthDayNanoType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing arrays of type {} is not yet implemented", t1 @@ -2410,7 +2213,7 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) => { typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b) } - _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), + _ => typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b), } } @@ -2435,7 +2238,7 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) => { typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b) } - _ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary), + _ => typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b), } } @@ -2460,7 +2263,7 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) => { typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b) } - _ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary), + _ => typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b), } } @@ -2484,7 +2287,7 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) => { typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b) } - _ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary), + _ => typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b), } } @@ -2508,7 +2311,7 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) => { typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b) } - _ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary), + _ => typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b), } } @@ -2531,7 +2334,7 @@ pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) => { typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b) } - _ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary), + _ => typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b), } } @@ -2543,7 +2346,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::eq, |a, b| a == b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a == b); + return compare_op(left, right, |a, b| a == b); } /// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2554,7 +2357,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a == right); + return compare_op_scalar(left, |a| a == right); } /// Applies an unary and infallible comparison function to a primitive array. @@ -2563,7 +2366,7 @@ where T: ArrowNumericType, F: Fn(T::Native) -> bool, { - return compare_op_scalar!(left, op); + compare_op_scalar(left, op) } /// Perform `left != right` operation on two [`PrimitiveArray`]s. @@ -2574,7 +2377,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::ne, |a, b| a != b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a != b); + return compare_op(left, right, |a, b| a != b); } /// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2585,7 +2388,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a != right); + return compare_op_scalar(left, |a| a != right); } /// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2597,7 +2400,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::lt, |a, b| a < b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a < b); + return compare_op(left, right, |a, b| a < b); } /// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2609,7 +2412,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a < right); + return compare_op_scalar(left, |a| a < right); } /// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2624,7 +2427,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::le, |a, b| a <= b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a <= b); + return compare_op(left, right, |a, b| a <= b); } /// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2636,7 +2439,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a <= right); + return compare_op_scalar(left, |a| a <= right); } /// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2648,7 +2451,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::gt, |a, b| a > b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a > b); + return compare_op(left, right, |a, b| a > b); } /// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2660,7 +2463,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a > right); + return compare_op_scalar(left, |a| a > right); } /// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2675,7 +2478,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::ge, |a, b| a >= b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a >= b); + return compare_op(left, right, |a, b| a >= b); } /// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2687,7 +2490,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a >= right); + return compare_op_scalar(left, |a| a >= right); } /// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]