Skip to content

Commit

Permalink
[Minor] Clean up DecimalArray API Usage (apache#1869)
Browse files Browse the repository at this point in the history
* Clean up decimal array creation

* Refactor a bit more

* Update the next!

* cleanup

* update

* port over min/max

* Update sum

* Use max scale / precision from arrow

* fmt

* Fixup

* clippy
  • Loading branch information
alamb authored Mar 6, 2022
1 parent e8aff59 commit d082fa3
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 214 deletions.
4 changes: 1 addition & 3 deletions datafusion-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,4 @@ mod scalar;
pub use column::Column;
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
pub use error::{DataFusionError, Result};
pub use scalar::{
ScalarType, ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
};
pub use scalar::{ScalarType, ScalarValue};
45 changes: 10 additions & 35 deletions datafusion-common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::{
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
DECIMAL_MAX_PRECISION,
},
};
use ordered_float::OrderedFloat;
Expand All @@ -34,11 +35,6 @@ use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

// TODO may need to be moved to arrow-rs
/// The max precision and scale for decimal128
pub const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
pub const MAX_SCALE_FOR_DECIMAL128: usize = 38;

/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
Expand Down Expand Up @@ -542,7 +538,7 @@ impl ScalarValue {
scale: usize,
) -> Result<Self> {
// make sure the precision and scale is valid
if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision {
if precision <= DECIMAL_MAX_PRECISION && scale <= precision {
return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
}
return Err(DataFusionError::Internal(format!(
Expand Down Expand Up @@ -985,26 +981,15 @@ impl ScalarValue {
precision: &usize,
scale: &usize,
) -> Result<DecimalArray> {
// collect the value as Option<i128>
let array = scalars
.into_iter()
.map(|element: ScalarValue| match element {
ScalarValue::Decimal128(v1, _, _) => v1,
_ => unreachable!(),
})
.collect::<Vec<Option<i128>>>();

// build the decimal array using the Decimal Builder
let mut builder = DecimalBuilder::new(array.len(), *precision, *scale);
array.iter().for_each(|element| match element {
None => {
builder.append_null().unwrap();
}
Some(v) => {
builder.append_value(*v).unwrap();
}
});
Ok(builder.finish())
.collect::<DecimalArray>()
.with_precision_and_scale(*precision, *scale)?;
Ok(array)
}

fn iter_to_array_list(
Expand Down Expand Up @@ -1080,21 +1065,11 @@ impl ScalarValue {
scale: &usize,
size: usize,
) -> DecimalArray {
let mut builder = DecimalBuilder::new(size, *precision, *scale);
match value {
None => {
for _i in 0..size {
builder.append_null().unwrap();
}
}
Some(v) => {
let v = *v;
for _i in 0..size {
builder.append_value(v).unwrap();
}
}
};
builder.finish()
std::iter::repeat(value)
.take(size)
.collect::<DecimalArray>()
.with_precision_and_scale(*precision, *scale)
.unwrap()
}

/// Converts a scalar value into an array of `size` rows.
Expand Down
7 changes: 3 additions & 4 deletions datafusion-physical-expr/src/coercion_rule/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

//! Coercion rules for matching argument types for binary operators

use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};
use datafusion_expr::Operator;

/// Coercion rules for all binary operators. Returns the output type
Expand Down Expand Up @@ -261,8 +260,8 @@ fn mathematics_numerical_coercion(

fn create_decimal_type(precision: usize, scale: usize) -> DataType {
DataType::Decimal(
MAX_PRECISION_FOR_DECIMAL128.min(precision),
MAX_SCALE_FOR_DECIMAL128.min(scale),
DECIMAL_MAX_PRECISION.min(precision),
DECIMAL_MAX_SCALE.min(scale),
)
}

Expand Down
47 changes: 22 additions & 25 deletions datafusion-physical-expr/src/expressions/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@ use std::sync::Arc;

use crate::{AggregateExpr, PhysicalExpr};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::{
ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
};
use datafusion_expr::Accumulator;

use super::{format_state_name, sum};
Expand All @@ -50,8 +48,8 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
DataType::Decimal(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
let new_precision = DECIMAL_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal(new_precision, new_scale))
}
DataType::Int8
Expand Down Expand Up @@ -237,11 +235,12 @@ mod tests {
#[test]
fn avg_decimal() -> Result<()> {
// test agg
let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
for i in 1..7 {
decimal_builder.append_value(i as i128)?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
let array: ArrayRef = Arc::new(
(1..7)
.map(Some)
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?,
);

generic_test_op!(
array,
Expand All @@ -254,15 +253,12 @@ mod tests {

#[test]
fn avg_decimal_with_nulls() -> Result<()> {
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for i in 1..6 {
if i == 2 {
decimal_builder.append_null()?;
} else {
decimal_builder.append_value(i)?;
}
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
let array: ArrayRef = Arc::new(
(1..6)
.map(|i| if i == 2 { None } else { Some(i) })
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?,
);
generic_test_op!(
array,
DataType::Decimal(10, 0),
Expand All @@ -275,11 +271,12 @@ mod tests {
#[test]
fn avg_decimal_all_nulls() -> Result<()> {
// test agg
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for _i in 1..6 {
decimal_builder.append_null()?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
let array: ArrayRef = Arc::new(
std::iter::repeat(None)
.take(6)
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?,
);
generic_test_op!(
array,
DataType::Decimal(10, 0),
Expand Down
80 changes: 54 additions & 26 deletions datafusion-physical-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ mod tests {
use crate::expressions::col;
use arrow::{
array::{
Array, DecimalArray, DecimalBuilder, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, StringArray, Time64NanosecondArray,
Array, DecimalArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, StringArray, Time64NanosecondArray,
TimestampNanosecondArray, UInt32Array,
},
datatypes::*,
Expand Down Expand Up @@ -268,23 +268,16 @@ mod tests {
}};
}

/// Test helper: builds a `DecimalArray` with the given `precision` and
/// `scale` from the values in `array`.
///
/// Each element of `array` is appended as a non-null value, in order, and
/// then a single null is appended after them, so one more entry is appended
/// than `array` contains.
///
/// # Errors
/// Propagates, via `?`, any error returned by the builder's `append_value`
/// or `append_null` calls.
fn create_decimal_array(
array: &[i128],
precision: usize,
scale: usize,
) -> Result<DecimalArray> {
// The builder is constructed with `array.len()` as its first argument
// (presumably a capacity hint — confirm against the arrow `DecimalBuilder`
// docs); note that this is one less than the number of entries appended below.
let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale);
for value in array {
decimal_builder.append_value(*value)?
}
// Always finish with one null entry, regardless of the input values.
decimal_builder.append_null()?;
Ok(decimal_builder.finish())
}

#[test]
fn test_cast_decimal_to_decimal() -> Result<()> {
let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
let decimal_array = create_decimal_array(&array, 10, 3)?;
let array = vec![1234, 2222, 3, 4000, 5000];

let decimal_array = array
.iter()
.map(|v| Some(*v))
.collect::<DecimalArray>()
.with_precision_and_scale(10, 3)?;

generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
Expand All @@ -301,7 +294,12 @@ mod tests {
DEFAULT_DATAFUSION_CAST_OPTIONS
);

let decimal_array = create_decimal_array(&array, 10, 3)?;
let decimal_array = array
.iter()
.map(|v| Some(*v))
.collect::<DecimalArray>()
.with_precision_and_scale(10, 3)?;

generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
Expand All @@ -323,9 +321,12 @@ mod tests {

#[test]
fn test_cast_decimal_to_numeric() -> Result<()> {
let array: Vec<i128> = vec![1, 2, 3, 4, 5];
let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
// decimal to i8
let decimal_array = create_decimal_array(&array, 10, 0)?;
let decimal_array = array
.iter()
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
Expand All @@ -341,8 +342,12 @@ mod tests {
],
DEFAULT_DATAFUSION_CAST_OPTIONS
);

// decimal to i16
let decimal_array = create_decimal_array(&array, 10, 0)?;
let decimal_array = array
.iter()
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
Expand All @@ -358,8 +363,12 @@ mod tests {
],
DEFAULT_DATAFUSION_CAST_OPTIONS
);

// decimal to i32
let decimal_array = create_decimal_array(&array, 10, 0)?;
let decimal_array = array
.iter()
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
Expand All @@ -375,8 +384,12 @@ mod tests {
],
DEFAULT_DATAFUSION_CAST_OPTIONS
);

// decimal to i64
let decimal_array = create_decimal_array(&array, 10, 0)?;
let decimal_array = array
.iter()
.collect::<DecimalArray>()
.with_precision_and_scale(10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
Expand All @@ -392,9 +405,20 @@ mod tests {
],
DEFAULT_DATAFUSION_CAST_OPTIONS
);

// decimal to float32
let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
let decimal_array = create_decimal_array(&array, 10, 3)?;
let array = vec![
Some(1234),
Some(2222),
Some(3),
Some(4000),
Some(5000),
None,
];
let decimal_array = array
.iter()
.collect::<DecimalArray>()
.with_precision_and_scale(10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
Expand All @@ -410,8 +434,12 @@ mod tests {
],
DEFAULT_DATAFUSION_CAST_OPTIONS
);

// decimal to float64
let decimal_array = create_decimal_array(&array, 20, 6)?;
let decimal_array = array
.into_iter()
.collect::<DecimalArray>()
.with_precision_and_scale(20, 6)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(20, 6),
Expand Down
Loading

0 comments on commit d082fa3

Please sign in to comment.