Skip to content

Commit

Permalink
use IntervalCompound instead of interval-month-day-nano UDT
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed Aug 20, 2024
1 parent 5347f30 commit a639feb
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 122 deletions.
59 changes: 50 additions & 9 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@ use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};
#[allow(deprecated)]
use crate::variation_const::{
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF,
TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF,
TIMESTAMP_SECOND_TYPE_VARIATION_REF,
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME,
INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
};
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
Expand All @@ -69,10 +68,10 @@ use datafusion::{
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::{
IntervalDayToSecond, IntervalYearToMonth, UserDefined,
interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth,
UserDefined,
};
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
Expand Down Expand Up @@ -1505,8 +1504,13 @@ fn from_substrait_type(
Ok(DataType::Interval(IntervalUnit::YearMonth))
}
r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)),
r#type::Kind::IntervalCompound(_) => {
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
}
r#type::Kind::UserDefined(u) => {
// Kept for backwards compatibility, use IntervalCompound instead
if let Some(name) = extensions.types.get(&u.type_reference) {
#[allow(deprecated)]
match name.as_ref() {
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
_ => not_impl_err!(
Expand All @@ -1516,7 +1520,7 @@ fn from_substrait_type(
),
}
} else {
// Kept for backwards compatibility, new plans should include the extension instead
// Kept for backwards compatibility, use IntervalCompound instead
#[allow(deprecated)]
match u.type_reference {
// Kept for backwards compatibility, use IntervalYear instead
Expand Down Expand Up @@ -1946,6 +1950,7 @@ fn from_substrait_literal(
subseconds,
precision_mode,
})) => {
use interval_day_to_second::PrecisionMode;
// DF only supports millisecond precision, so for any more granular type we lose precision
let milliseconds = match precision_mode {
Some(PrecisionMode::Microseconds(ms)) => ms / 1000,
Expand All @@ -1965,6 +1970,39 @@ fn from_substrait_literal(
Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => {
ScalarValue::new_interval_ym(*years, *months)
}
Some(LiteralType::IntervalCompound(IntervalCompound {
interval_year_to_month,
interval_day_to_second,
})) => match (interval_year_to_month, interval_day_to_second) {
(
Some(IntervalYearToMonth { years, months }),
Some(IntervalDayToSecond {
days,
seconds,
subseconds,
precision_mode:
Some(interval_day_to_second::PrecisionMode::Precision(p)),
}),
) => {
let nanos = match p {
0 => *subseconds * 1_000_000_000,
3 => *subseconds * 1_000_000,
6 => *subseconds * 1_000,
9 => *subseconds,
_ => {
return not_impl_err!(
"Unsupported Substrait interval day to second precision mode"
)
}
};
ScalarValue::new_interval_mdn(
*years * 12 + months,
*days,
*seconds as i64 * 1_000_000_000 + nanos,
)
}
_ => return not_impl_err!("Unsupported Substrait compound interval literal"),
},
Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())),
Some(LiteralType::UserDefined(user_defined)) => {
// Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed
Expand Down Expand Up @@ -1995,6 +2033,8 @@ fn from_substrait_literal(

if let Some(name) = extensions.types.get(&user_defined.type_reference) {
match name.as_ref() {
// Kept for backwards compatibility - new plans should use IntervalCompound instead
#[allow(deprecated)]
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => {
interval_month_day_nano(user_defined)?
}
Expand Down Expand Up @@ -2045,6 +2085,7 @@ fn from_substrait_literal(
milliseconds,
}))
}
// Kept for backwards compatibility, use IntervalCompound instead
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
interval_month_day_nano(user_defined)?
}
Expand Down
119 changes: 24 additions & 95 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use itertools::Itertools;
use std::ops::Deref;
use std::sync::Arc;

use arrow_buffer::ToByteSlice;
use datafusion::arrow::datatypes::IntervalUnit;
use datafusion::logical_expr::{
CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
Expand All @@ -37,8 +36,7 @@ use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};
use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::common::{
Expand All @@ -56,8 +54,8 @@ use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::map::KeyValue;
use substrait::proto::expression::literal::{
user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map,
PrecisionTimestamp, Struct, UserDefined,
IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map,
PrecisionTimestamp, Struct,
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
Expand Down Expand Up @@ -1429,16 +1427,14 @@ fn to_substrait_type(
})),
}),
IntervalUnit::MonthDayNano => {
// Substrait doesn't currently support this type, so we represent it as a UDT
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
type_reference: extensions.register_type(
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(),
),
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability,
type_parameters: vec![],
})),
kind: Some(r#type::Kind::IntervalCompound(
r#type::IntervalCompound {
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability,
precision: 9, // nanos
},
)),
})
}
}
Expand Down Expand Up @@ -1880,23 +1876,21 @@ fn to_substrait_literal(
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::IntervalMonthDayNano(Some(i)) => {
// IntervalMonthDayNano is internally represented as a 128-bit integer, containing
// months (32bit), days (32bit), and nanoseconds (64bit)
let bytes = i.to_byte_slice();
(
LiteralType::UserDefined(UserDefined {
type_reference: extensions
.register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()),
type_parameters: vec![],
val: Some(user_defined::Val::Value(ProtoAny {
type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(),
value: bytes.to_vec().into(),
})),
ScalarValue::IntervalMonthDayNano(Some(i)) => (
LiteralType::IntervalCompound(IntervalCompound {
interval_year_to_month: Some(IntervalYearToMonth {
years: 0,
months: i.months,
}),
DEFAULT_TYPE_VARIATION_REF,
)
}
interval_day_to_second: Some(IntervalDayToSecond {
days: i.days,
seconds: 0,
subseconds: i.nanoseconds,
precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds
}),
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::IntervalDayTime(Some(i)) => (
LiteralType::IntervalDayToSecond(IntervalDayToSecond {
days: i.days,
Expand Down Expand Up @@ -2161,7 +2155,6 @@ mod test {
};
use datafusion::arrow::datatypes::Field;
use datafusion::common::scalar::ScalarStructBuilder;
use std::collections::HashMap;

#[test]
fn round_trip_literals() -> Result<()> {
Expand Down Expand Up @@ -2292,39 +2285,6 @@ mod test {
Ok(())
}

#[test]
fn custom_type_literal_extensions() -> Result<()> {
let mut extensions = Extensions::default();
// IntervalMonthDayNano is represented as a custom type in Substrait
let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new(
17, 25, 1234567890,
)));
let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?;
let roundtrip_scalar =
from_substrait_literal_without_names(&substrait_literal, &extensions)?;
assert_eq!(scalar, roundtrip_scalar);

assert_eq!(
extensions,
Extensions {
functions: HashMap::new(),
types: HashMap::from([(
0,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()
)]),
type_variations: HashMap::new(),
}
);

// Check we fail if we don't propagate extensions
assert!(from_substrait_literal_without_names(
&substrait_literal,
&Extensions::default()
)
.is_err());
Ok(())
}

#[test]
fn round_trip_types() -> Result<()> {
round_trip_type(DataType::Boolean)?;
Expand Down Expand Up @@ -2403,35 +2363,4 @@ mod test {
assert_eq!(dt, roundtrip_dt);
Ok(())
}

#[test]
fn custom_type_extensions() -> Result<()> {
let mut extensions = Extensions::default();
// IntervalMonthDayNano is represented as a custom type in Substrait
let dt = DataType::Interval(IntervalUnit::MonthDayNano);

let substrait = to_substrait_type(&dt, true, &mut extensions)?;
let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?;
assert_eq!(dt, roundtrip_dt);

assert_eq!(
extensions,
Extensions {
functions: HashMap::new(),
types: HashMap::from([(
0,
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()
)]),
type_variations: HashMap::new(),
}
);

// Check we fail if we don't propagate extensions
assert!(
from_substrait_type_without_names(&substrait, &Extensions::default())
.is_err()
);

Ok(())
}
}
6 changes: 5 additions & 1 deletion datafusion/substrait/src/variation_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;
/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
#[deprecated(
since = "41.0.0",
note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead"
note = "Use Substrait `IntervalCompound` type instead"
)]
pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;

/// Name of the legacy Substrait user-defined type that stood in for
/// [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
///
/// The consumer still recognizes this name so that plans which registered the
/// user-defined type can be read; the producer emits Substrait's native
/// `IntervalCompound` type instead.
///
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
#[deprecated(
since = "42.0.0",
note = "Use Substrait `IntervalCompound` type instead"
)]
pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano";
17 changes: 0 additions & 17 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,6 @@ async fn select_with_reused_functions() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_udt_extensions() -> Result<()> {
let ctx = create_context().await?;
let proto =
roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx)
.await?;
let expected_type = SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionType(ExtensionType {
extension_uri_reference: u32::MAX,
type_anchor: 0,
name: "interval-month-day-nano".to_string(),
})),
};
assert_eq!(proto.extensions, vec![expected_type]);
Ok(())
}

#[tokio::test]
async fn select_with_filter_date() -> Result<()> {
roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await
Expand Down

0 comments on commit a639feb

Please sign in to comment.