Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cast between Durations + between Durations and all numeric types #6452

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 153 additions & 21 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
| Time64(Microsecond)
| Time64(Nanosecond),
) => true,
(Int64, Duration(_)) => true,
(Duration(_), Int64) => true,
(_, Duration(_)) if from_type.is_numeric() => true,
(Duration(_), _) if to_type.is_numeric() => true,
(Duration(_), Duration(_)) => true,
(Interval(from_type), Int64) => {
match from_type {
YearMonth => true,
Expand Down Expand Up @@ -518,6 +519,15 @@ fn make_timestamp_array(
}
}

/// Wraps an `Int64` primitive array as a duration array of the requested `unit`.
///
/// Each arm calls `reinterpret_cast` to view the same values as the matching
/// duration primitive type, then returns the result as an `ArrayRef`.
fn make_duration_array(array: &PrimitiveArray<Int64Type>, unit: TimeUnit) -> ArrayRef {
    match unit {
        TimeUnit::Second => {
            let seconds = array.reinterpret_cast::<DurationSecondType>();
            Arc::new(seconds)
        }
        TimeUnit::Millisecond => {
            let millis = array.reinterpret_cast::<DurationMillisecondType>();
            Arc::new(millis)
        }
        TimeUnit::Microsecond => {
            let micros = array.reinterpret_cast::<DurationMicrosecondType>();
            Arc::new(micros)
        }
        TimeUnit::Nanosecond => {
            let nanos = array.reinterpret_cast::<DurationNanosecondType>();
            Arc::new(nanos)
        }
    }
}

fn as_time_res_with_timezone<T: ArrowPrimitiveType>(
v: i64,
tz: Option<Tz>,
Expand Down Expand Up @@ -2074,31 +2084,53 @@ pub fn cast_with_options(
.as_primitive::<Date32Type>()
.unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY),
)),
(Int64, Duration(TimeUnit::Second)) => {
cast_reinterpret_arrays::<Int64Type, DurationSecondType>(array)
}
(Int64, Duration(TimeUnit::Millisecond)) => {
cast_reinterpret_arrays::<Int64Type, DurationMillisecondType>(array)
}
(Int64, Duration(TimeUnit::Microsecond)) => {
cast_reinterpret_arrays::<Int64Type, DurationMicrosecondType>(array)

(_, Duration(unit)) if from_type.is_numeric() => {
let array = cast_with_options(array, &Int64, cast_options)?;
Ok(make_duration_array(array.as_primitive(), *unit))
}
(Int64, Duration(TimeUnit::Nanosecond)) => {
cast_reinterpret_arrays::<Int64Type, DurationNanosecondType>(array)
(Duration(TimeUnit::Second), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationSecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}

(Duration(TimeUnit::Second), Int64) => {
cast_reinterpret_arrays::<DurationSecondType, Int64Type>(array)
(Duration(TimeUnit::Millisecond), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationMillisecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Duration(TimeUnit::Millisecond), Int64) => {
cast_reinterpret_arrays::<DurationMillisecondType, Int64Type>(array)
(Duration(TimeUnit::Microsecond), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationMicrosecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Duration(TimeUnit::Microsecond), Int64) => {
cast_reinterpret_arrays::<DurationMicrosecondType, Int64Type>(array)
(Duration(TimeUnit::Nanosecond), _) if to_type.is_numeric() => {
let array = cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)?;
cast_with_options(&array, to_type, cast_options)
}
(Duration(TimeUnit::Nanosecond), Int64) => {
cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)

(Duration(from_unit), Duration(to_unit)) => {
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = array.as_primitive::<Int64Type>();
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// we either divide or multiply, depending on size of each unit
// units are never the same when the types are the same
let converted = match from_size.cmp(&to_size) {
Ordering::Greater => {
let divisor = from_size / to_size;
time_array.unary::<_, Int64Type>(|o| o / divisor)
}
Ordering::Equal => time_array.clone(),
Ordering::Less => {
let mul = to_size / from_size;
if cast_options.safe {
time_array.unary_opt::<_, Int64Type>(|o| o.checked_mul(mul))
} else {
time_array.try_unary::<_, Int64Type, _>(|o| o.mul_checked(mul))?
}
}
};
Ok(make_duration_array(&converted, *to_unit))
}

(Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationSecondType>(array, cast_options)
}
Expand Down Expand Up @@ -5254,6 +5286,106 @@ mod tests {
}
}

/// Exercises casts between every pair of distinct duration units, and between
/// durations and numeric types in both directions.
#[test]
fn test_cast_between_durations_and_numerics() {
    /// Casts a small `FromType` duration array to `ToType` and checks that the
    /// values were scaled by the ratio of the two units and that nulls survive.
    fn test_cast_between_durations<FromType, ToType>()
    where
        FromType: ArrowPrimitiveType<Native = i64>,
        ToType: ArrowPrimitiveType<Native = i64>,
        PrimitiveArray<FromType>: From<Vec<Option<i64>>>,
    {
        let from_unit = match FromType::DATA_TYPE {
            DataType::Duration(unit) => unit,
            _ => panic!("Expected a duration type"),
        };
        let to_unit = match ToType::DATA_TYPE {
            DataType::Duration(unit) => unit,
            _ => panic!("Expected a duration type"),
        };
        let from_size = time_unit_multiple(&from_unit);
        let to_size = time_unit_multiple(&to_unit);

        // Expected values are derived with the same divide-or-multiply rule the
        // cast applies: divide when going to a coarser unit, multiply otherwise.
        let (v1_before, v2_before) = (8640003005, 1696002001);
        let (v1_after, v2_after) = if from_size >= to_size {
            (
                v1_before / (from_size / to_size),
                v2_before / (from_size / to_size),
            )
        } else {
            (
                v1_before * (to_size / from_size),
                v2_before * (to_size / from_size),
            )
        };

        let array =
            PrimitiveArray::<FromType>::from(vec![Some(v1_before), Some(v2_before), None]);
        let b = cast(&array, &ToType::DATA_TYPE).unwrap();
        let c = b.as_primitive::<ToType>();
        assert_eq!(v1_after, c.value(0));
        assert_eq!(v2_after, c.value(1));
        assert!(c.is_null(2));
    }

    // between each individual duration type
    test_cast_between_durations::<DurationSecondType, DurationMillisecondType>();
    test_cast_between_durations::<DurationSecondType, DurationMicrosecondType>();
    test_cast_between_durations::<DurationSecondType, DurationNanosecondType>();
    test_cast_between_durations::<DurationMillisecondType, DurationSecondType>();
    test_cast_between_durations::<DurationMillisecondType, DurationMicrosecondType>();
    test_cast_between_durations::<DurationMillisecondType, DurationNanosecondType>();
    test_cast_between_durations::<DurationMicrosecondType, DurationSecondType>();
    test_cast_between_durations::<DurationMicrosecondType, DurationMillisecondType>();
    test_cast_between_durations::<DurationMicrosecondType, DurationNanosecondType>();
    test_cast_between_durations::<DurationNanosecondType, DurationSecondType>();
    test_cast_between_durations::<DurationNanosecondType, DurationMillisecondType>();
    test_cast_between_durations::<DurationNanosecondType, DurationMicrosecondType>();

    // overflow while scaling to a finer unit: `cast` still succeeds and the
    // values that do not fit in i64 after multiplication come back as null
    let array = DurationSecondArray::from(vec![
        Some(i64::MAX),
        Some(8640203410378005),
        Some(10241096),
        None,
    ]);
    let b = cast(&array, &DataType::Duration(TimeUnit::Nanosecond)).unwrap();
    let c = b.as_primitive::<DurationNanosecondType>();
    assert!(c.is_null(0));
    assert!(c.is_null(1));
    assert_eq!(10241096000000000, c.value(2));
    assert!(c.is_null(3));

    // durations to numerics
    let array = DurationSecondArray::from(vec![
        Some(i64::MAX),
        Some(8640203410378005),
        Some(10241096),
        None,
    ]);
    let b = cast(&array, &DataType::Int64).unwrap();
    let c = b.as_primitive::<Int64Type>();
    assert_eq!(i64::MAX, c.value(0));
    assert_eq!(8640203410378005, c.value(1));
    assert_eq!(10241096, c.value(2));
    assert!(c.is_null(3));

    // Narrowing to Int32: values outside the i32 range become null. Assert on
    // nullness rather than on `value()`, whose result for a null slot is
    // arbitrary and not part of the cast contract.
    let b = cast(&array, &DataType::Int32).unwrap();
    let c = b.as_primitive::<Int32Type>();
    assert!(c.is_null(0));
    assert!(c.is_null(1));
    assert_eq!(10241096, c.value(2));
    assert!(c.is_null(3));

    // numerics to durations
    let array = Int32Array::from(vec![Some(i32::MAX), Some(802034103), Some(10241096), None]);
    let b = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap();
    let c = b.as_any().downcast_ref::<DurationSecondArray>().unwrap();
    assert_eq!(i32::MAX as i64, c.value(0));
    assert_eq!(802034103, c.value(1));
    assert_eq!(10241096, c.value(2));
    assert!(c.is_null(3));
}

#[test]
fn test_cast_to_strings() {
let a = Int32Array::from(vec![1, 2, 3]);
Expand Down
Loading