Skip to content

Commit

Permalink
Improve in-place primitive sorts by 13-67% (#4473)
Browse files Browse the repository at this point in the history
* Adding sort_primitives benchmark

* Adding sort_primitives improvements

* Fix lints

* Remove all unsafe code and handle offset cases

* Incorporate review comments

* Remove unneeded returns
  • Loading branch information
psvri authored Jul 4, 2023
1 parent 9ee36b2 commit aac3aa9
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 4 deletions.
72 changes: 68 additions & 4 deletions arrow-ord/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow_array::builder::BufferBuilder;
use arrow_array::cast::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::BooleanBufferBuilder;
use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer};
use arrow_data::ArrayData;
use arrow_data::ArrayDataBuilder;
Expand Down Expand Up @@ -57,11 +58,74 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
if let DataType::RunEndEncoded(_, _) = values.data_type() {
return sort_run(values, options, None);
downcast_primitive_array!(
values => sort_native_type(values, options),
DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
_ => {
let indices = sort_to_indices(values, options, None)?;
take(values, &indices, None)
}
)
}

/// Sorts a primitive array by sorting its native values directly, without
/// first computing an index permutation.
///
/// `options` falls back to `SortOptions::default()` when it is `None`.
///
/// Layout of the result:
/// * all null slots are grouped together, placed before the valid values when
///   `nulls_first` is set and after them otherwise;
/// * the valid values are ordered with `compare` and reversed when
///   `descending` is set.
///
/// The returned array carries the same `DataType` as `primitive_values`.
/// This function has no error path of its own; it always returns `Ok`.
fn sort_native_type<T>(
    primitive_values: &PrimitiveArray<T>,
    options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
{
    let sort_options = options.unwrap_or_default();

    let len = primitive_values.len();
    let nulls_count = primitive_values.null_count();
    let valid_count = len - nulls_count;

    // Output buffer, pre-filled with the type's default value so the slots
    // that end up null still hold a defined value.
    let mut mutable_buffer = vec![T::default_value(); len];

    let input_values = primitive_values.values().as_ref();

    // The validity bitmap is fully determined by the null count and the
    // requested null placement, so it can be built without inspecting values.
    let null_bit_buffer = if nulls_count > 0 {
        let mut validity_buffer = BooleanBufferBuilder::new(len);
        if sort_options.nulls_first {
            validity_buffer.append_n(nulls_count, false);
            validity_buffer.append_n(valid_count, true);
        } else {
            validity_buffer.append_n(valid_count, true);
            validity_buffer.append_n(nulls_count, false);
        }
        Some(validity_buffer.finish().into())
    } else {
        None
    };

    if let Some(nulls) = primitive_values.nulls().filter(|n| n.null_count() > 0) {
        // Region of the output reserved for the valid values; its length is
        // `valid_count` in both arms.
        let values_slice = if sort_options.nulls_first {
            &mut mutable_buffer[nulls_count..]
        } else {
            &mut mutable_buffer[..valid_count]
        };

        // Compact the valid values into that region, skipping null slots.
        for (slot, index) in values_slice.iter_mut().zip(nulls.valid_indices()) {
            *slot = primitive_values.value(index);
        }

        values_slice.sort_unstable_by(|a, b| a.compare(*b));
        if sort_options.descending {
            values_slice.reverse();
        }
    } else {
        // No nulls: copy everything in one go and sort the whole buffer.
        mutable_buffer.copy_from_slice(input_values);
        mutable_buffer.sort_unstable_by(|a, b| a.compare(*b));
        if sort_options.descending {
            mutable_buffer.reverse();
        }
    }

    Ok(Arc::new(
        PrimitiveArray::<T>::new(mutable_buffer.into(), null_bit_buffer)
            .with_data_type(primitive_values.data_type().clone()),
    ))
}

/// Sort the `ArrayRef` partially.
Expand Down
5 changes: 5 additions & 0 deletions arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ name = "sort_kernel"
harness = false
required-features = ["test_utils"]

[[bench]]
name = "sort_kernel_primitives"
harness = false
required-features = ["test_utils"]

[[bench]]
name = "partition_kernels"
harness = false
Expand Down
59 changes: 59 additions & 0 deletions arrow/benches/sort_kernel_primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[macro_use]
extern crate criterion;
use arrow_ord::sort::sort;
use criterion::Criterion;

use std::sync::Arc;

extern crate arrow;

use arrow::util::bench_util::*;
use arrow::{array::*, datatypes::Int64Type};

/// Builds an `Int64Type` benchmark input of `size` elements via
/// `create_primitive_array` and wraps it in an `Arc`.
///
/// The null density passed to the generator is `0.5` when `with_nulls` is
/// set and `0.0` otherwise.
fn create_i64_array(size: usize, with_nulls: bool) -> ArrayRef {
    let null_fraction = if with_nulls { 0.5 } else { 0.0 };
    Arc::new(create_primitive_array::<Int64Type>(size, null_fraction))
}

/// Runs one `sort` call on `array` with default options.
///
/// Both the input and the result go through `criterion::black_box`. The
/// `unwrap` panics if `sort` returns an error.
fn bench_sort(array: &ArrayRef) {
    let input = criterion::black_box(array);
    let sorted = sort(input, None).unwrap();
    criterion::black_box(sorted);
}

/// Registers the primitive sort benchmarks with Criterion.
///
/// Four cases are benchmarked, in this order: 2^10 and 2^12 elements without
/// nulls, then the same two sizes generated with `with_nulls = true`.
fn add_benchmark(c: &mut Criterion) {
    // (benchmark id, log2 of the element count, generate nulls?)
    const CASES: [(&str, u32, bool); 4] = [
        ("sort 2^10", 10, false),
        ("sort 2^12", 12, false),
        ("sort nulls 2^10", 10, true),
        ("sort nulls 2^12", 12, true),
    ];

    for &(label, exponent, with_nulls) in CASES.iter() {
        let input = create_i64_array(1usize << exponent, with_nulls);
        c.bench_function(label, |b| b.iter(|| bench_sort(&input)));
    }
}

// Declare the Criterion group `benches`, made up of `add_benchmark`.
criterion_group!(benches, add_benchmark);
// Hand the `benches` group to Criterion's generated entry point.
criterion_main!(benches);

0 comments on commit aac3aa9

Please sign in to comment.