From b2fc8e75101eb3d9ac98fa58f9a5c980c5908320 Mon Sep 17 00:00:00 2001 From: "yukkit.zhang" Date: Sun, 12 Nov 2023 14:07:00 +0800 Subject: [PATCH] feat: try to support user-defined types --- datafusion-examples/examples/rewrite_expr.rs | 3 +- datafusion/common/src/column.rs | 3 +- datafusion/common/src/dfschema.rs | 195 +++++---- datafusion/common/src/lib.rs | 1 + datafusion/common/src/logical_type.rs | 411 ++++++++++++++++++ datafusion/common/src/scalar.rs | 206 ++++++++- datafusion/core/src/dataframe/mod.rs | 17 +- .../core/src/datasource/listing/helpers.rs | 2 +- .../core/src/datasource/listing/table.rs | 5 +- datafusion/core/src/datasource/memory.rs | 2 +- .../physical_plan/parquet/row_groups.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 81 +++- .../core/src/physical_optimizer/pruning.rs | 60 +-- datafusion/core/src/test/variable.rs | 7 +- datafusion/core/tests/dataframe/mod.rs | 3 +- datafusion/core/tests/simplification.rs | 5 +- .../expr/src/conditional_expressions.rs | 9 +- datafusion/expr/src/expr.rs | 24 +- datafusion/expr/src/expr_fn.rs | 7 +- datafusion/expr/src/expr_rewriter/mod.rs | 6 +- datafusion/expr/src/expr_rewriter/order_by.rs | 9 +- datafusion/expr/src/expr_schema.rs | 151 +++++-- datafusion/expr/src/field_util.rs | 86 +++- datafusion/expr/src/logical_plan/builder.rs | 45 +- datafusion/expr/src/logical_plan/ddl.rs | 18 + datafusion/expr/src/logical_plan/display.rs | 20 +- datafusion/expr/src/logical_plan/mod.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 63 +-- datafusion/expr/src/type_coercion/mod.rs | 16 +- datafusion/expr/src/type_coercion/other.rs | 15 +- datafusion/expr/src/utils.rs | 50 +-- .../optimizer/src/analyzer/type_coercion.rs | 111 +++-- .../optimizer/src/common_subexpr_eliminate.rs | 10 +- .../src/decorrelate_predicate_subquery.rs | 81 ++-- .../optimizer/src/eliminate_outer_join.rs | 5 +- .../optimizer/src/push_down_projection.rs | 20 +- .../optimizer/src/scalar_subquery_to_join.rs | 3 +- .../src/simplify_expressions/context.rs | 11 +- .../simplify_expressions/expr_simplifier.rs | 93 ++-- .../src/simplify_expressions/utils.rs | 26 +- .../src/unwrap_cast_in_comparison.rs | 125 +++--- .../optimizer/tests/optimizer_integration.rs | 3 +- .../src/expressions/get_indexed_field.rs | 33 +- datafusion/physical-expr/src/planner.rs | 5 +- datafusion/physical-expr/src/var_provider.rs | 5 +- datafusion/proto/src/logical_plan/mod.rs | 5 +- datafusion/sql/examples/sql.rs | 3 +- datafusion/sql/src/expr/arrow_cast.rs | 2 +- datafusion/sql/src/expr/mod.rs | 9 +- datafusion/sql/src/expr/value.rs | 5 +- datafusion/sql/src/planner.rs | 82 ++-- datafusion/sql/src/relation/join.rs | 10 +- datafusion/sql/src/statement.rs | 111 +++-- datafusion/sql/src/utils.rs | 9 +- datafusion/sql/tests/sql_integration.rs | 29 +- 55 files changed, 1656 insertions(+), 664 deletions(-) create mode 100644 datafusion/common/src/logical_type.rs diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e6..b70ba2745edc 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -17,6 +17,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; +use datafusion_common::logical_type::LogicalType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ @@ -212,7 +213,7 @@ impl ContextProvider for MyContextProvider { None } - 
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
+    fn get_variable_type(&self, _variable_names: &[String]) -> Option<LogicalType> {
         None
     }
 
diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs
index 2e729c128e73..8821966f0360 100644
--- a/datafusion/common/src/column.rs
+++ b/datafusion/common/src/column.rs
@@ -357,7 +357,6 @@ impl fmt::Display for Column {
 mod tests {
     use super::*;
     use crate::DFField;
-    use arrow::datatypes::DataType;
     use std::collections::HashMap;
 
     fn create_schema(names: &[(Option<&str>, &str)]) -> Result<DFSchema> {
@@ -367,7 +366,7 @@ mod tests {
             DFField::new(
                 qualifier.to_owned().map(|s| s.to_string()),
                 name,
-                DataType::Boolean,
+                LogicalType::Boolean,
                 true,
             )
         })
diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index d8cd103a4777..303d386b4827 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -19,14 +19,14 @@
 //! fields with optional relation names.
 
 use std::collections::{HashMap, HashSet};
-use std::convert::TryFrom;
 use std::fmt::{Display, Formatter};
-use std::hash::Hash;
+use std::hash::{Hash, Hasher};
 use std::sync::Arc;
 
 use crate::error::{
     unqualified_field_not_found, DataFusionError, Result, SchemaError, _plan_err,
 };
+use crate::logical_type::{ExtensionType, LogicalType};
 use crate::{
     field_not_found, Column, FunctionalDependencies, OwnedTableReference, TableReference,
 };
@@ -116,7 +116,7 @@ impl DFSchema {
         qualifier: impl Into<TableReference<'a>>,
         schema: &Schema,
     ) -> Result<Self> {
-        let qualifier = qualifier.into();
+        let qualifier: TableReference<'a> = qualifier.into();
         Self::new_with_metadata(
             schema
                 .fields()
@@ -127,6 +127,22 @@ impl DFSchema {
         )
     }
 
+    /// Create a `DFSchema` from an existing `DFSchema`, applying the given qualifier
+    pub fn try_from_qualified_dfschema<'a>(
+        qualifier: impl Into<TableReference<'a>>,
+        schema: &DFSchema,
+    ) -> Result<Self> {
+        let qualifier: TableReference<'a> = qualifier.into();
+        Self::new_with_metadata(
+            schema
+                .fields()
+                .iter()
+                .map(|f| f.clone().with_qualifier(qualifier.to_owned_reference()))
+                .collect(),
+            schema.metadata().clone(),
+        )
+    }
+
     /// Assigns functional dependencies.
     pub fn with_functional_dependencies(
         mut self,
@@ -427,7 +443,7 @@ impl DFSchema {
         self_fields.zip(other_fields).all(|(f1, f2)| {
             f1.qualifier() == f2.qualifier()
                 && f1.name() == f2.name()
-                && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type())
+                && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type())
         })
     }
 
     /// than datatype_is_semantically_equal in that a Dictionary<K,V> type is logically
     /// equal to a plain V type, but not semantically equal. Dictionary<K1, V1>
     /// is also logically equal to Dictionary<K2, V1>.
- fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { - // check nested fields - match (dt1, dt2) { - (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref() == v2.as_ref() - } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, - (DataType::List(f1), DataType::List(f2)) - | (DataType::LargeList(f1), DataType::LargeList(f2)) - | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) - | (DataType::Map(f1, _), DataType::Map(f2, _)) => { - Self::field_is_logically_equal(f1, f2) - } - (DataType::Struct(fields1), DataType::Struct(fields2)) => { - let iter1 = fields1.iter(); - let iter2 = fields2.iter(); - fields1.len() == fields2.len() && - // all fields have to be the same - iter1 - .zip(iter2) - .all(|(f1, f2)| Self::field_is_logically_equal(f1, f2)) - } - (DataType::Union(fields1, _), DataType::Union(fields2, _)) => { - let iter1 = fields1.iter(); - let iter2 = fields2.iter(); - fields1.len() == fields2.len() && - // all fields have to be the same - iter1 - .zip(iter2) - .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_logically_equal(f1, f2)) - } - _ => dt1 == dt2, - } + fn datatype_is_logically_equal(dt1: &LogicalType, dt2: &LogicalType) -> bool { + dt1 == dt2 } /// Returns true of two [`DataType`]s are semantically equal (same @@ -518,11 +502,6 @@ impl DFSchema { } } - fn field_is_logically_equal(f1: &Field, f2: &Field) -> bool { - f1.name() == f2.name() - && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) - } - fn field_is_semantically_equal(f1: &Field, f2: &Field) -> bool { f1.name() == f2.name() && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) @@ -547,7 +526,7 @@ impl DFSchema { fields: self .fields .into_iter() - .map(|f| DFField::from_qualified(qualifier.clone(), f.field)) + .map(|f| f.with_qualifier(qualifier.clone())) .collect(), ..self } @@ -575,7 +554,14 @@ impl DFSchema { impl From for Schema { /// Convert DFSchema into a Schema fn from(df_schema: DFSchema) -> Self { - let fields: Fields = df_schema.fields.into_iter().map(|f| f.field).collect(); + let fields: Fields = df_schema + .fields + .into_iter() + .map(|f| { + Field::new(f.name, f.data_type.physical_type(), f.nullable) + .with_metadata(f.metadata) + }) + .collect(); Schema::new_with_metadata(fields, df_schema.metadata) } } @@ -583,7 +569,14 @@ impl From for Schema { impl From<&DFSchema> for Schema { /// Convert DFSchema reference into a Schema fn from(df_schema: &DFSchema) -> Self { - let fields: Fields = df_schema.fields.iter().map(|f| f.field.clone()).collect(); + let fields: Fields = df_schema + .fields + .iter() + .map(|f| { + Field::new(f.name(), f.data_type().physical_type(), f.is_nullable()) + .with_metadata(f.metadata.clone()) + }) + .collect(); Schema::new_with_metadata(fields, df_schema.metadata.clone()) } } @@ -596,7 +589,14 @@ impl TryFrom for DFSchema { schema .fields() .iter() - .map(|f| DFField::from(f.clone())) + .map(|f| { + DFField::new_unqualified( + f.name(), + f.data_type().into(), + f.is_nullable(), + ) + .with_metadata(f.metadata().clone()) + }) .collect(), schema.metadata().clone(), ) @@ -679,7 +679,7 @@ pub trait ExprSchema: std::fmt::Debug { fn nullable(&self, col: &Column) -> Result; /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&LogicalType>; /// Returns the column's optional metadata. 
fn metadata(&self, col: &Column) -> Result<&HashMap>; @@ -691,7 +691,7 @@ impl + std::fmt::Debug> ExprSchema for P { self.as_ref().nullable(col) } - fn data_type(&self, col: &Column) -> Result<&DataType> { + fn data_type(&self, col: &Column) -> Result<&LogicalType> { self.as_ref().data_type(col) } @@ -705,7 +705,7 @@ impl ExprSchema for DFSchema { Ok(self.field_from_column(col)?.is_nullable()) } - fn data_type(&self, col: &Column) -> Result<&DataType> { + fn data_type(&self, col: &Column) -> Result<&LogicalType> { Ok(self.field_from_column(col)?.data_type()) } @@ -715,12 +715,34 @@ impl ExprSchema for DFSchema { } /// DFField wraps an Arrow field and adds an optional qualifier -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct DFField { /// Optional qualifier (usually a table or relation name) qualifier: Option, /// Arrow field definition - field: FieldRef, + // field: FieldRef, + name: String, + data_type: LogicalType, + nullable: bool, + /// A map of key-value pairs containing additional custom meta data. + metadata: HashMap, +} + +impl Hash for DFField { + fn hash(&self, state: &mut H) { + self.qualifier.hash(state); + self.name.hash(state); + self.data_type.hash(state); + self.nullable.hash(state); + + // ensure deterministic key order + let mut keys: Vec<&String> = self.metadata.keys().collect(); + keys.sort(); + for k in keys { + k.hash(state); + self.metadata.get(k).expect("key valid").hash(state); + } + } } impl DFField { @@ -728,20 +750,26 @@ impl DFField { pub fn new>( qualifier: Option, name: &str, - data_type: DataType, + data_type: LogicalType, nullable: bool, ) -> Self { DFField { qualifier: qualifier.map(|s| s.into()), - field: Arc::new(Field::new(name, data_type, nullable)), + name: name.to_string(), + data_type, + nullable, + metadata: Default::default(), } } /// Convenience method for creating new `DFField` without a qualifier - pub fn new_unqualified(name: &str, data_type: DataType, nullable: bool) -> Self { + pub fn new_unqualified(name: &str, data_type: LogicalType, nullable: bool) -> Self { DFField { qualifier: None, - field: Arc::new(Field::new(name, data_type, nullable)), + name: name.to_string(), + data_type, + nullable, + metadata: Default::default(), } } @@ -750,37 +778,41 @@ impl DFField { qualifier: impl Into>, field: impl Into, ) -> Self { - Self { - qualifier: Some(qualifier.into().to_owned_reference()), - field: field.into(), - } + let field: FieldRef = field.into(); + Self::new( + Some(qualifier.into().to_owned_reference()), + field.name(), + field.data_type().into(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) } /// Returns an immutable reference to the `DFField`'s unqualified name pub fn name(&self) -> &String { - self.field.name() + &self.name } /// Returns an immutable reference to the `DFField`'s data-type - pub fn data_type(&self) -> &DataType { - self.field.data_type() + pub fn data_type(&self) -> &LogicalType { + &self.data_type } /// Indicates whether this `DFField` supports null values pub fn is_nullable(&self) -> bool { - self.field.is_nullable() + self.nullable } pub fn metadata(&self) -> &HashMap { - self.field.metadata() + &self.metadata } /// Returns a string to the `DFField`'s qualified name pub fn qualified_name(&self) -> String { if let Some(qualifier) = &self.qualifier { - format!("{}.{}", qualifier, self.field.name()) + format!("{}.{}", qualifier, self.name) } else { - self.field.name().to_owned() + self.name.clone() } } @@ -788,7 +820,7 @@ impl DFField { pub fn 
qualified_column(&self) -> Column { Column { relation: self.qualifier.clone(), - name: self.field.name().to_string(), + name: self.name.clone(), } } @@ -796,7 +828,7 @@ impl DFField { pub fn unqualified_column(&self) -> Column { Column { relation: None, - name: self.field.name().to_string(), + name: self.name.clone(), } } @@ -805,11 +837,6 @@ impl DFField { self.qualifier.as_ref() } - /// Get the arrow field - pub fn field(&self) -> &FieldRef { - &self.field - } - /// Return field with qualifier stripped pub fn strip_qualifier(mut self) -> Self { self.qualifier = None; @@ -818,25 +845,27 @@ impl DFField { /// Return field with nullable specified pub fn with_nullable(mut self, nullable: bool) -> Self { - let f = self.field().as_ref().clone().with_nullable(nullable); - self.field = f.into(); + self.nullable = nullable; self } /// Return field with new metadata pub fn with_metadata(mut self, metadata: HashMap) -> Self { - let f = self.field().as_ref().clone().with_metadata(metadata); - self.field = f.into(); + self.metadata = metadata; + self + } + + /// Return field with new qualifier + pub fn with_qualifier(mut self, qualifier: impl Into) -> Self { + self.qualifier = Some(qualifier.into()); self } } impl From for DFField { fn from(value: FieldRef) -> Self { - Self { - qualifier: None, - field: value, - } + Self::new_unqualified(value.name(), value.data_type().into(), value.is_nullable()) + .with_metadata(value.metadata().clone()) } } @@ -890,7 +919,7 @@ impl SchemaExt for Schema { .zip(other.fields().iter()) .all(|(f1, f2)| { f1.name() == f2.name() - && DFSchema::datatype_is_logically_equal( + && DFSchema::datatype_is_semantically_equal( f1.data_type(), f2.data_type(), ) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 53c3cfddff8d..dc40fae61007 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -33,6 +33,7 @@ pub mod display; pub mod file_options; pub mod format; pub mod hash_utils; +pub mod logical_type; pub mod parsers; pub mod scalar; pub mod stats; diff --git a/datafusion/common/src/logical_type.rs b/datafusion/common/src/logical_type.rs new file mode 100644 index 000000000000..1f1cf9c52006 --- /dev/null +++ b/datafusion/common/src/logical_type.rs @@ -0,0 +1,411 @@ +use std::{borrow::Cow, fmt::Display, sync::Arc}; + +use crate::error::Result; +use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; + +#[derive(Clone, Debug)] +pub enum LogicalType { + Null, + Boolean, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64, + String, + LargeString, + Date32, + Date64, + Time32(TimeUnit), + Time64(TimeUnit), + Timestamp(TimeUnit, Option>), + Duration(TimeUnit), + Interval(IntervalUnit), + Binary, + FixedSizeBinary(i32), + LargeBinary, + Utf8, + LargeUtf8, + List(Box), + FixedSizeList(Box, i32), + LargeList(Box), + Struct(Fields), + Map(NamedLogicalTypeRef, bool), + // union + Decimal128(u8, i8), + Decimal256(u8, i8), + Extension(ExtensionTypeRef), +} + +impl PartialEq for LogicalType { + fn eq(&self, other: &Self) -> bool { + self.type_signature() == other.type_signature() + } +} + +impl Eq for LogicalType {} + +impl std::hash::Hash for LogicalType { + fn hash(&self, state: &mut H) { + self.type_signature().hash(state) + } +} + +impl Display for LogicalType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display_name()) + } +} + +pub type Fields = Arc<[NamedLogicalTypeRef]>; +pub type NamedLogicalTypeRef = Arc; + 
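Taken together, the `PartialEq`/`Hash` impls above mean two `LogicalType`s are the same type exactly when their `type_signature()` matches, independent of physical encoding. A minimal sketch of the resulting semantics (hypothetical demo code assuming this patch is applied; not part of the diff):

use arrow_schema::DataType;
use datafusion_common::logical_type::LogicalType;

fn signature_equality_demo() {
    // `String` and `Utf8` both produce the "string" signature in
    // `type_signature()` below, so they compare and hash as equal.
    assert_eq!(LogicalType::String, LogicalType::Utf8);

    // Converting from a physical type folds dictionary encoding away,
    // preserving the Dictionary-vs-plain-V equivalence that the removed
    // `datatype_is_logically_equal` match arms used to spell out.
    let dict =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    assert_eq!(LogicalType::from(dict), LogicalType::Utf8);
}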
+#[derive(Clone, PartialEq, Eq, Hash, Debug)]
+pub struct NamedLogicalType {
+    name: String,
+    data_type: LogicalType,
+}
+
+impl NamedLogicalType {
+    pub fn new(name: impl Into<String>, data_type: LogicalType) -> Self {
+        Self {
+            name: name.into(),
+            data_type,
+        }
+    }
+
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    pub fn data_type(&self) -> &LogicalType {
+        &self.data_type
+    }
+}
+
+pub type OwnedTypeSignature = TypeSignature<'static>;
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct TypeSignature<'a> {
+    // **func_name**(p1, p2)
+    name: Cow<'a, str>,
+    // func_name(**p1**, **p2**)
+    params: Vec<Cow<'a, str>>,
+}
+
+impl<'a> TypeSignature<'a> {
+    pub fn new(name: impl Into<Cow<'a, str>>) -> Self {
+        Self::new_with_params(name, vec![])
+    }
+
+    pub fn new_with_params(
+        name: impl Into<Cow<'a, str>>,
+        params: Vec<Cow<'a, str>>,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            params,
+        }
+    }
+
+    pub fn to_owned_type_signature(&self) -> OwnedTypeSignature {
+        OwnedTypeSignature {
+            name: self.name.to_string().into(),
+            params: self.params.iter().map(|p| p.to_string().into()).collect(),
+        }
+    }
+}
+
+pub type ExtensionTypeRef = Arc<dyn ExtensionType + Send + Sync>;
+
+pub trait ExtensionType: std::fmt::Debug {
+    fn display_name(&self) -> &str;
+    fn type_signature(&self) -> TypeSignature;
+    fn physical_type(&self) -> DataType;
+
+    fn is_comparable(&self) -> bool;
+    fn is_orderable(&self) -> bool;
+    fn is_numeric(&self) -> bool;
+}
+
+pub trait TypeManager {
+    fn register_data_type(
+        &mut self,
+        signature: impl Into<OwnedTypeSignature>,
+        extension_type: ExtensionTypeRef,
+    ) -> Result<()>;
+
+    fn data_type(&self, signature: &TypeSignature) -> Result<Option<ExtensionTypeRef>>;
+}
+
+impl ExtensionType for LogicalType {
+    fn display_name(&self) -> &str {
+        match self {
+            Self::Null => "NULL",
+            Self::Boolean => "BOOLEAN",
+            Self::Int8 => "INT8",
+            Self::Int16 => "INT16",
+            Self::Int32 => "INT32",
+            Self::Int64 => "INT64",
+            Self::UInt8 => "UINT8",
+            Self::UInt16 => "UINT16",
+            Self::UInt32 => "UINT32",
+            Self::UInt64 => "UINT64",
+            Self::Float16 => "FLOAT16",
+            Self::Float32 => "Float32",
+            Self::Float64 => "Float64",
+            Self::String => "String",
+            Self::LargeString => "LargeString",
+            Self::Date32 => "Date32",
+            Self::Date64 => "Date64",
+            Self::Time32(_) => "Time32",
+            Self::Time64(_) => "Time64",
+            Self::Timestamp(_, _) => "Timestamp",
+            Self::Duration(_) => "Duration",
+            Self::Interval(_) => "Interval",
+            Self::Binary => "Binary",
+            Self::FixedSizeBinary(_) => "FixedSizeBinary",
+            Self::LargeBinary => "LargeBinary",
+            Self::Utf8 => "Utf8",
+            Self::LargeUtf8 => "LargeUtf8",
+            Self::List(_) => "List",
+            Self::FixedSizeList(_, _) => "FixedSizeList",
+            Self::LargeList(_) => "LargeList",
+            Self::Struct(_) => "Struct",
+            Self::Map(_, _) => "Map",
+            Self::Decimal128(_, _) => "Decimal128",
+            Self::Decimal256(_, _) => "Decimal256",
+            Self::Extension(ext) => ext.display_name(),
+        }
+    }
+
+    fn type_signature(&self) -> TypeSignature {
+        match self {
+            Self::Boolean => TypeSignature::new("boolean"),
+            Self::Int32 => TypeSignature::new("int32"),
+            Self::Int64 => TypeSignature::new("int64"),
+            Self::UInt64 => TypeSignature::new("uint64"),
+            Self::Float32 => TypeSignature::new("float32"),
+            Self::Float64 => TypeSignature::new("float64"),
+            Self::String => TypeSignature::new("string"),
+            Self::Timestamp(tu, zone) => {
+                let tu = match tu {
+                    TimeUnit::Second => "second",
+                    TimeUnit::Millisecond => "millisecond",
+                    TimeUnit::Microsecond => "microsecond",
+                    TimeUnit::Nanosecond => "nanosecond",
+                };
+
+                let params = if let Some(zone) = zone {
+                    vec![tu.into(), zone.as_ref().into()]
+                } else {
+                    vec![tu.into()]
+                };
+
TypeSignature::new_with_params("timestamp", params) + } + Self::Binary => TypeSignature::new("binary"), + Self::Utf8 => TypeSignature::new("string"), + Self::Extension(ext) => ext.type_signature(), + Self::Struct(fields) => { + let params = fields.iter().map(|f| f.name().into()).collect(); + TypeSignature::new_with_params("struct", params) + } + other => panic!("not implemented: {other:?}"), + } + } + + fn physical_type(&self) -> DataType { + match self { + Self::Boolean => DataType::Boolean, + Self::Int32 => DataType::Int32, + Self::Int64 => DataType::Int64, + Self::UInt64 => DataType::UInt64, + Self::Float32 => DataType::Float32, + Self::Float64 => DataType::Float64, + Self::String => DataType::Utf8, + Self::Timestamp(tu, zone) => DataType::Timestamp(tu.clone(), zone.clone()), + Self::Binary => DataType::Binary, + Self::Utf8 => DataType::Utf8, + Self::Extension(ext) => ext.physical_type(), + Self::Struct(fields) => { + let fields = fields + .iter() + .map(|f| { + let name = f.name(); + let data_type = f.physical_type(); + Arc::new(Field::new(name, data_type, true)) + }) + .collect::>(); + DataType::Struct(fields.into()) + } + other => panic!("not implemented {other:?}"), + } + } + + fn is_comparable(&self) -> bool { + match self { + Self::Null + | Self::Boolean + | Self::Int8 + | Self::Int16 + | Self::Int32 + | Self::Int64 + | Self::UInt8 + | Self::UInt16 + | Self::UInt32 + | Self::UInt64 + | Self::Float16 + | Self::Float32 + | Self::Float64 + | Self::String + | Self::LargeString + | Self::Date32 + | Self::Date64 + | Self::Time32(_) + | Self::Time64(_) + | Self::Timestamp(_, _) + | Self::Duration(_) + | Self::Interval(_) + | Self::Binary + | Self::FixedSizeBinary(_) + | Self::LargeBinary + | Self::Utf8 + | Self::LargeUtf8 + | Self::Decimal128(_, _) + | Self::Decimal256(_, _) => true, + Self::List(_) => false, + Self::FixedSizeList(_, _) => false, + Self::LargeList(_) => false, + Self::Struct(_) => false, + Self::Map(_, _) => false, + Self::Extension(ext) => ext.is_comparable(), + } + } + + fn is_orderable(&self) -> bool { + todo!() + } + + /// Returns true if this type is numeric: (UInt*, Int*, Float*, Decimal*). 
+ #[inline] + fn is_numeric(&self) -> bool { + use LogicalType::*; + match self { + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) => true, + Extension(t) => t.is_numeric(), + _ => false, + } + } +} + +impl From<&DataType> for LogicalType { + fn from(value: &DataType) -> Self { + // TODO + value.clone().into() + } +} + +impl From for LogicalType { + fn from(value: DataType) -> Self { + match value { + DataType::Null => LogicalType::Null, + DataType::Boolean => LogicalType::Boolean, + DataType::Int8 => LogicalType::Int8, + DataType::Int16 => LogicalType::Int16, + DataType::Int32 => LogicalType::Int32, + DataType::Int64 => LogicalType::Int64, + DataType::UInt8 => LogicalType::UInt8, + DataType::UInt16 => LogicalType::UInt16, + DataType::UInt32 => LogicalType::UInt32, + DataType::UInt64 => LogicalType::UInt64, + DataType::Float16 => LogicalType::Float16, + DataType::Float32 => LogicalType::Float32, + DataType::Float64 => LogicalType::Float64, + DataType::Timestamp(tu, z) => LogicalType::Timestamp(tu, z), + DataType::Date32 => LogicalType::Date32, + DataType::Date64 => LogicalType::Date64, + DataType::Time32(tu) => LogicalType::Time32(tu), + DataType::Time64(tu) => LogicalType::Time64(tu), + DataType::Duration(tu) => LogicalType::Duration(tu), + DataType::Interval(iu) => LogicalType::Interval(iu), + DataType::Binary => LogicalType::Binary, + DataType::FixedSizeBinary(len) => LogicalType::FixedSizeBinary(len), + DataType::LargeBinary => LogicalType::LargeBinary, + DataType::Utf8 => LogicalType::Utf8, + DataType::LargeUtf8 => LogicalType::LargeUtf8, + DataType::List(f) => LogicalType::List(Box::new(f.data_type().into())), + DataType::FixedSizeList(f, len) => { + LogicalType::FixedSizeList(Box::new(f.data_type().into()), len) + } + DataType::LargeList(f) => { + LogicalType::LargeList(Box::new(f.data_type().into())) + } + DataType::Struct(fields) => { + let fields = fields + .into_iter() + .map(|f| { + let name = f.name(); + let logical_type = f.data_type().into(); + Arc::new(NamedLogicalType::new(name, logical_type)) + }) + .collect::>(); + LogicalType::Struct(fields.into()) + } + DataType::Union(_, _) => unimplemented!(), + DataType::Dictionary(_, dt) => dt.as_ref().into(), + DataType::Decimal128(p, s) => LogicalType::Decimal128(p, s), + DataType::Decimal256(p, s) => LogicalType::Decimal256(p, s), + DataType::Map(data, sorted) => { + let field = + Arc::new(NamedLogicalType::new(data.name(), data.data_type().into())); + LogicalType::Map(field, sorted) + } + DataType::RunEndEncoded(_, f) => f.data_type().into(), + } + } +} + +impl ExtensionType for NamedLogicalType { + fn display_name(&self) -> &str { + &self.name + } + + fn type_signature(&self) -> TypeSignature { + TypeSignature::new(self.name()) + } + + fn physical_type(&self) -> DataType { + self.data_type.physical_type() + } + + fn is_comparable(&self) -> bool { + self.data_type.is_comparable() + } + + fn is_orderable(&self) -> bool { + self.data_type.is_orderable() + } + + fn is_numeric(&self) -> bool { + self.data_type.is_numeric() + } +} diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index cdcc9aa4fbc5..dee0aef26131 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -22,7 +22,8 @@ use std::cmp::Ordering; use std::collections::HashSet; use std::convert::{Infallible, TryInto}; use std::str::FromStr; -use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +use 
std::sync::Arc; +use std::{convert::TryFrom, fmt, iter::repeat}; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, @@ -30,10 +31,11 @@ use crate::cast::{ }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; +use crate::logical_type::{ExtensionType, LogicalType, NamedLogicalType}; use crate::utils::array_into_list_array; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; -use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; +use arrow::datatypes::{i256, Fields, SchemaBuilder}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, @@ -48,6 +50,7 @@ use arrow::{ }; use arrow_array::cast::as_list_array; use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_schema::FieldRef; /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part to arrow's [`Array`]. @@ -844,6 +847,90 @@ impl ScalarValue { }) } + /// return the [`DataType`] of this `ScalarValue` + pub fn logical_type(&self) -> LogicalType { + match self { + ScalarValue::Boolean(_) => LogicalType::Boolean, + ScalarValue::UInt8(_) => LogicalType::UInt8, + ScalarValue::UInt16(_) => LogicalType::UInt16, + ScalarValue::UInt32(_) => LogicalType::UInt32, + ScalarValue::UInt64(_) => LogicalType::UInt64, + ScalarValue::Int8(_) => LogicalType::Int8, + ScalarValue::Int16(_) => LogicalType::Int16, + ScalarValue::Int32(_) => LogicalType::Int32, + ScalarValue::Int64(_) => LogicalType::Int64, + ScalarValue::Decimal128(_, precision, scale) => { + LogicalType::Decimal128(*precision, *scale) + } + ScalarValue::Decimal256(_, precision, scale) => { + LogicalType::Decimal256(*precision, *scale) + } + ScalarValue::TimestampSecond(_, tz_opt) => { + LogicalType::Timestamp(TimeUnit::Second, tz_opt.clone()) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + LogicalType::Timestamp(TimeUnit::Millisecond, tz_opt.clone()) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + LogicalType::Timestamp(TimeUnit::Microsecond, tz_opt.clone()) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + LogicalType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + } + ScalarValue::Float32(_) => LogicalType::Float32, + ScalarValue::Float64(_) => LogicalType::Float64, + ScalarValue::Utf8(_) => LogicalType::Utf8, + ScalarValue::LargeUtf8(_) => LogicalType::LargeUtf8, + ScalarValue::Binary(_) => LogicalType::Binary, + ScalarValue::FixedSizeBinary(sz, _) => LogicalType::FixedSizeBinary(*sz), + ScalarValue::LargeBinary(_) => LogicalType::LargeBinary, + ScalarValue::Fixedsizelist(_, field, length) => { + LogicalType::FixedSizeList(Box::new(field.data_type().into()), *length) + } + ScalarValue::List(arr) => arr.data_type().into(), + ScalarValue::Date32(_) => LogicalType::Date32, + ScalarValue::Date64(_) => LogicalType::Date64, + ScalarValue::Time32Second(_) => LogicalType::Time32(TimeUnit::Second), + ScalarValue::Time32Millisecond(_) => { + LogicalType::Time32(TimeUnit::Millisecond) + } + ScalarValue::Time64Microsecond(_) => { + LogicalType::Time64(TimeUnit::Microsecond) + } + ScalarValue::Time64Nanosecond(_) => LogicalType::Time64(TimeUnit::Nanosecond), + ScalarValue::IntervalYearMonth(_) => { + LogicalType::Interval(IntervalUnit::YearMonth) + } + ScalarValue::IntervalDayTime(_) => { + LogicalType::Interval(IntervalUnit::DayTime) + } + ScalarValue::IntervalMonthDayNano(_) => { + LogicalType::Interval(IntervalUnit::MonthDayNano) + } + ScalarValue::DurationSecond(_) 
=> LogicalType::Duration(TimeUnit::Second), + ScalarValue::DurationMillisecond(_) => { + LogicalType::Duration(TimeUnit::Millisecond) + } + ScalarValue::DurationMicrosecond(_) => { + LogicalType::Duration(TimeUnit::Microsecond) + } + ScalarValue::DurationNanosecond(_) => { + LogicalType::Duration(TimeUnit::Nanosecond) + } + ScalarValue::Struct(_, fields) => { + let fields = fields + .iter() + .map(|f| { + Arc::new(NamedLogicalType::new(f.name(), f.data_type().into())) + }) + .collect::>(); + LogicalType::Struct(fields.into()) + } + ScalarValue::Dictionary(_k, v) => v.data_type().into(), + ScalarValue::Null => LogicalType::Null, + } + } + /// return the [`DataType`] of this `ScalarValue` pub fn data_type(&self) -> DataType { match self { @@ -881,10 +968,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::Fixedsizelist(_, field, length) => DataType::FixedSizeList( - Arc::new(Field::new("item", field.data_type().clone(), true)), - *length, - ), + ScalarValue::Fixedsizelist(_, field, length) => { + DataType::FixedSizeList(field.clone(), *length) + } ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, @@ -2855,6 +2941,114 @@ impl TryFrom<&DataType> for ScalarValue { } } +impl TryFrom for ScalarValue { + type Error = DataFusionError; + + /// Create a Null instance of ScalarValue for this datatype + fn try_from(datatype: LogicalType) -> Result { + (&datatype).try_into() + } +} + +impl TryFrom<&LogicalType> for ScalarValue { + type Error = DataFusionError; + + /// Create a Null instance of ScalarValue for this LogicalType + fn try_from(data_type: &LogicalType) -> Result { + Ok(match data_type { + LogicalType::Boolean => ScalarValue::Boolean(None), + LogicalType::Float64 => ScalarValue::Float64(None), + LogicalType::Float32 => ScalarValue::Float32(None), + LogicalType::Int8 => ScalarValue::Int8(None), + LogicalType::Int16 => ScalarValue::Int16(None), + LogicalType::Int32 => ScalarValue::Int32(None), + LogicalType::Int64 => ScalarValue::Int64(None), + LogicalType::UInt8 => ScalarValue::UInt8(None), + LogicalType::UInt16 => ScalarValue::UInt16(None), + LogicalType::UInt32 => ScalarValue::UInt32(None), + LogicalType::UInt64 => ScalarValue::UInt64(None), + LogicalType::Decimal128(precision, scale) => { + ScalarValue::Decimal128(None, *precision, *scale) + } + LogicalType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } + LogicalType::Utf8 => ScalarValue::Utf8(None), + LogicalType::LargeUtf8 => ScalarValue::LargeUtf8(None), + LogicalType::Binary => ScalarValue::Binary(None), + LogicalType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), + LogicalType::LargeBinary => ScalarValue::LargeBinary(None), + LogicalType::Date32 => ScalarValue::Date32(None), + LogicalType::Date64 => ScalarValue::Date64(None), + LogicalType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), + LogicalType::Time32(TimeUnit::Millisecond) => { + ScalarValue::Time32Millisecond(None) + } + LogicalType::Time64(TimeUnit::Microsecond) => { + ScalarValue::Time64Microsecond(None) + } + LogicalType::Time64(TimeUnit::Nanosecond) => { + ScalarValue::Time64Nanosecond(None) + } + LogicalType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) + } + LogicalType::Timestamp(TimeUnit::Millisecond, 
tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) + } + LogicalType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) + } + LogicalType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + LogicalType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(None) + } + LogicalType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(None) + } + LogicalType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(None) + } + + LogicalType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), + LogicalType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(None) + } + LogicalType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(None) + } + LogicalType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(None) + } + + LogicalType::List(_) => ScalarValue::List(new_null_array(&DataType::Null, 0)), + + LogicalType::Struct(fields) => { + let fields = fields + .iter() + .map(|e| { + Arc::new(Field::new( + e.name(), + e.data_type().physical_type(), + false, + )) + }) + .collect::>(); + + ScalarValue::Struct(None, fields.into()) + } + LogicalType::Null => ScalarValue::Null, + _ => { + return _not_impl_err!( + "Can't create a scalar from logical_type \"{data_type:?}\"" + ); + } + }) + } +} + macro_rules! format_option { ($F:expr, $EXPR:expr) => {{ match $EXPR { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 89e82fa952bb..50dc28ff9fc3 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -30,6 +30,7 @@ use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::logical_type::{ExtensionType, LogicalType}; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions, @@ -464,7 +465,10 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + !matches!( + f.data_type(), + LogicalType::Binary | LogicalType::Boolean + ) }) .map(|f| min(col(f.name())).alias(f.name())) .collect::>(), @@ -475,7 +479,10 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + !matches!( + f.data_type(), + LogicalType::Binary | LogicalType::Boolean + ) }) .map(|f| max(col(f.name())).alias(f.name())) .collect::>(), @@ -1023,7 +1030,7 @@ impl DataFrame { table_name: &str, write_options: DataFrameWriteOptions, ) -> Result, DataFusionError> { - let arrow_schema = Schema::from(self.schema()); + let arrow_schema = self.schema().clone().into(); let plan = LogicalPlanBuilder::insert_into( self.plan, table_name.to_owned(), @@ -2252,7 +2259,7 @@ mod tests { .await? .select_columns(&["c2", "c3"])? .limit(0, Some(1))? - .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + .with_column("sum", cast(col("c2") + col("c3"), LogicalType::Int64))?; let df_results = df.clone().collect().await?; df.clone().show().await?; @@ -2343,7 +2350,7 @@ mod tests { .await? .select_columns(&["c2", "c3"])? .limit(0, Some(1))? 
- .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + .with_column("sum", cast(col("c2") + col("c3"), LogicalType::Int64))?; let cached_df = df.clone().cache().await?; diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 986e54ebbe85..e211e43ffefb 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -263,7 +263,7 @@ async fn prune_partitions( let df_schema = DFSchema::new_with_metadata( partition_cols .iter() - .map(|(n, d)| DFField::new_unqualified(n, d.clone(), true)) + .map(|(n, d)| DFField::new_unqualified(n, d.into(), true)) .collect(), Default::default(), )?; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index d26d417bd8b2..59b5d6670062 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -253,6 +253,7 @@ pub struct ListingOptions { pub format: Arc, /// The expected partition column names in the folder structure. /// See [Self::with_table_partition_cols] for details + /// TODO this maybe LogicalType pub table_partition_cols: Vec<(String, DataType)>, /// Set true to try to guess statistics from the files. /// This can add a lot of overhead as it will usually require files @@ -2012,7 +2013,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema.into(), false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() @@ -2227,7 +2228,7 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. 
// Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema.into(), false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 6bcaa97a408f..0b191ef0aa8c 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -561,7 +561,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema.into(), false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index dc6ef50bc101..867bc29543a6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -1348,7 +1348,7 @@ mod tests { None } - fn get_variable_type(&self, _variable_names: &[String]) -> Option { + fn get_variable_type(&self, _variable_names: &[String]) -> Option { None } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9c500ec07293..8a5b92ef4fca 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -37,13 +37,18 @@ use crate::{ }; use datafusion_common::{ alias::AliasGenerator, - exec_err, not_impl_err, plan_datafusion_err, plan_err, + exec_err, + logical_type::{ + ExtensionType, ExtensionTypeRef, LogicalType, NamedLogicalType, + OwnedTypeSignature, TypeManager, TypeSignature, + }, + not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + CreateType, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -57,7 +62,7 @@ use std::{ }; use std::{ops::ControlFlow, sync::Weak}; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use crate::catalog::{ @@ -483,6 +488,7 @@ impl SessionContext { self.create_catalog_schema(cmd).await } DdlStatement::CreateCatalog(cmd) => self.create_catalog(cmd).await, + DdlStatement::CreateType(cmd) => self.create_type(cmd).await, DdlStatement::DropTable(cmd) => self.drop_table(cmd).await, DdlStatement::DropView(cmd) => self.drop_view(cmd).await, DdlStatement::DropCatalogSchema(cmd) => self.drop_schema(cmd).await, @@ -661,6 +667,28 @@ impl SessionContext { } } + async fn create_type(&self, cmd: CreateType) -> Result { + let CreateType { + name, data_type, .. 
+ } = cmd; + + let extension_type = Arc::new(NamedLogicalType::new(name, data_type)); + let name = extension_type.display_name(); + let type_signature = extension_type.type_signature(); + + let exists_type = self.data_type(&type_signature)?; + if exists_type.is_some() { + return exec_err!("DataType '{name}' already exists"); + } + + self.register_data_type( + type_signature.to_owned_type_signature(), + extension_type, + )?; + + self.return_empty_dataframe() + } + async fn drop_table(&self, cmd: DropTable) -> Result { let DropTable { name, if_exists, .. @@ -1007,6 +1035,25 @@ impl SessionContext { self.state.read().catalog_list.catalog(name) } + /// TODO + pub fn data_type( + &self, + signature: &TypeSignature, + ) -> Result> { + self.state.read().data_type(signature) + } + + /// TODO + pub fn register_data_type( + &self, + signature: impl Into>, + extension_type: ExtensionTypeRef, + ) -> Result<()> { + self.state + .write() + .register_data_type(signature, extension_type) + } + /// Registers a [`TableProvider`] as a table that can be /// referenced from SQL statements executed against this context. /// @@ -1230,6 +1277,7 @@ pub struct SessionState { aggregate_functions: HashMap>, /// Window functions registered in the context window_functions: HashMap>, + data_types: HashMap, /// Deserializer registry for extensions. serializer_registry: Arc, /// Session configuration @@ -1325,6 +1373,7 @@ impl SessionState { scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), + data_types: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), config, execution_props: ExecutionProps::new(), @@ -1846,6 +1895,22 @@ impl SessionState { } } +impl TypeManager for SessionState { + fn register_data_type( + &mut self, + signature: impl Into, + extension_type: ExtensionTypeRef, + ) -> Result<()> { + self.data_types.insert(signature.into(), extension_type); + + Ok(()) + } + + fn data_type(&self, signature: &TypeSignature) -> Result> { + Ok(self.data_types.get(signature).cloned()) + } +} + struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, @@ -1872,7 +1937,7 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.window_functions().get(name).cloned() } - fn get_variable_type(&self, variable_names: &[String]) -> Option { + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; } @@ -1890,6 +1955,14 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names)) } + fn get_data_type(&self, signature: &TypeSignature) -> Option { + self.state + .data_type(signature) + .ok() + .flatten() + .map(LogicalType::Extension) + } + fn options(&self) -> &ConfigOptions { self.state.config_options() } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index de508327fade..b001b10dbc9c 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -20,11 +20,10 @@ //! //! 
[`Expr`]: crate::prelude::Expr use std::collections::HashSet; -use std::convert::TryFrom; use std::sync::Arc; use crate::{ - common::{Column, DFSchema}, + common::Column, error::{DataFusionError, Result}, logical_expr::Operator, physical_plan::{ColumnarValue, PhysicalExpr}, @@ -487,13 +486,8 @@ impl<'a> PruningExpressionBuilder<'a> { } }; - let df_schema = DFSchema::try_from(schema.clone())?; - let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable( - column_expr, - correct_operator, - scalar_expr, - df_schema, - )?; + let (column_expr, correct_operator, scalar_expr) = + rewrite_expr_to_prunable(column_expr, correct_operator, scalar_expr, schema)?; let column = columns.iter().next().unwrap().clone(); let field = match schema.column_with_name(column.name()) { Some((_, f)) => f, @@ -547,7 +541,7 @@ fn rewrite_expr_to_prunable( column_expr: &PhysicalExprRef, op: Operator, scalar_expr: &PhysicalExprRef, - schema: DFSchema, + arrow_schema: &Schema, ) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef)> { if !is_compare_op(op) { return plan_err!("rewrite_expr_to_prunable only support compare expression"); @@ -563,11 +557,10 @@ fn rewrite_expr_to_prunable( Ok((column_expr.clone(), op, scalar_expr.clone())) } else if let Some(cast) = column_expr_any.downcast_ref::() { // `cast(col) op lit()` - let arrow_schema: SchemaRef = schema.clone().into(); - let from_type = cast.expr().data_type(&arrow_schema)?; + let from_type = cast.expr().data_type(arrow_schema)?; verify_support_type_for_prune(&from_type, cast.cast_type())?; let (left, op, right) = - rewrite_expr_to_prunable(cast.expr(), op, scalar_expr, schema)?; + rewrite_expr_to_prunable(cast.expr(), op, scalar_expr, arrow_schema)?; let left = Arc::new(phys_expr::CastExpr::new( left, cast.cast_type().clone(), @@ -578,11 +571,10 @@ fn rewrite_expr_to_prunable( column_expr_any.downcast_ref::() { // `try_cast(col) op lit()` - let arrow_schema: SchemaRef = schema.clone().into(); - let from_type = try_cast.expr().data_type(&arrow_schema)?; + let from_type = try_cast.expr().data_type(arrow_schema)?; verify_support_type_for_prune(&from_type, try_cast.cast_type())?; let (left, op, right) = - rewrite_expr_to_prunable(try_cast.expr(), op, scalar_expr, schema)?; + rewrite_expr_to_prunable(try_cast.expr(), op, scalar_expr, arrow_schema)?; let left = Arc::new(phys_expr::TryCastExpr::new( left, try_cast.cast_type().clone(), @@ -591,7 +583,7 @@ fn rewrite_expr_to_prunable( } else if let Some(neg) = column_expr_any.downcast_ref::() { // `-col > lit()` --> `col < -lit()` let (left, op, right) = - rewrite_expr_to_prunable(neg.arg(), op, scalar_expr, schema)?; + rewrite_expr_to_prunable(neg.arg(), op, scalar_expr, arrow_schema)?; let right = Arc::new(phys_expr::NegativeExpr::new(right)); Ok((left, reverse_operator(op)?, right)) } else if let Some(not) = column_expr_any.downcast_ref::() { @@ -2390,20 +2382,15 @@ mod tests { #[test] fn test_rewrite_expr_to_prunable() { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let df_schema = DFSchema::try_from(schema.clone()).unwrap(); // column op lit let left_input = col("a"); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int32(Some(12))); let right_input = logical2physical(&right_input, &schema); - let (result_left, _, result_right) = rewrite_expr_to_prunable( - &left_input, - Operator::Eq, - &right_input, - df_schema.clone(), - ) - .unwrap(); + let (result_left, _, result_right) = + rewrite_expr_to_prunable(&left_input, Operator::Eq, 
&right_input, &schema) + .unwrap(); assert_eq!(result_left.to_string(), left_input.to_string()); assert_eq!(result_right.to_string(), right_input.to_string()); @@ -2412,13 +2399,9 @@ mod tests { let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3)); let right_input = logical2physical(&right_input, &schema); - let (result_left, _, result_right) = rewrite_expr_to_prunable( - &left_input, - Operator::Gt, - &right_input, - df_schema.clone(), - ) - .unwrap(); + let (result_left, _, result_right) = + rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, &schema) + .unwrap(); assert_eq!(result_left.to_string(), left_input.to_string()); assert_eq!(result_right.to_string(), right_input.to_string()); @@ -2428,7 +2411,7 @@ mod tests { let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); let (result_left, _, result_right) = - rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema) + rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, &schema) .unwrap(); assert_eq!(result_left.to_string(), left_input.to_string()); assert_eq!(result_right.to_string(), right_input.to_string()); @@ -2441,17 +2424,12 @@ mod tests { // cast string value to numeric value // this cast is not supported let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let df_schema = DFSchema::try_from(schema.clone()).unwrap(); let left_input = cast(col("a"), DataType::Int64); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); - let result = rewrite_expr_to_prunable( - &left_input, - Operator::Gt, - &right_input, - df_schema.clone(), - ); + let result = + rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, &schema); assert!(result.is_err()); // other expr @@ -2460,7 +2438,7 @@ mod tests { let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); let result = - rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema); + rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, &schema); assert!(result.is_err()); // TODO: add other negative test for other case and op } diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index a55513841561..54373f160b4c 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -21,6 +21,7 @@ use crate::error::Result; use crate::scalar::ScalarValue; use crate::variable::VarProvider; use arrow::datatypes::DataType; +use datafusion_common::logical_type::LogicalType; /// System variable #[derive(Default, Debug)] @@ -41,7 +42,7 @@ impl VarProvider for SystemVar { } fn get_type(&self, _: &[String]) -> Option { - Some(DataType::Utf8) + Some(LogicalType::Utf8) } } @@ -69,9 +70,9 @@ impl VarProvider for UserDefinedVar { fn get_type(&self, var_names: &[String]) -> Option { if var_names[0] != "@integer" { - Some(DataType::Utf8) + Some(LogicalType::Utf8) } else { - Some(DataType::Int32) + Some(LogicalType::Int32) } } } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 845d77581b59..77f6f8664ee1 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -29,6 +29,7 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_schema::ArrowError; +use 
datafusion_common::logical_type::LogicalType; use std::sync::Arc; use datafusion::dataframe::DataFrame; @@ -1551,7 +1552,7 @@ impl VarProvider for HardcodedIntProvider { } fn get_type(&self, _: &[String]) -> Option { - Some(DataType::Int64) + Some(LogicalType::Int64) } } diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index b6d856b2d9a0..a4f2508c23cd 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -20,6 +20,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; +use datafusion_common::logical_type::LogicalType; use datafusion_expr::{Expr, ExprSchemable}; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyInfo}; @@ -39,7 +40,7 @@ struct MyInfo { impl SimplifyInfo for MyInfo { fn is_boolean_type(&self, expr: &Expr) -> Result { - Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean)) + Ok(matches!(expr.get_type(&self.schema)?, LogicalType::Boolean)) } fn nullable(&self, expr: &Expr) -> Result { @@ -50,7 +51,7 @@ impl SimplifyInfo for MyInfo { &self.execution_props } - fn get_data_type(&self, expr: &Expr) -> Result { + fn get_data_type(&self, expr: &Expr) -> Result { expr.get_type(&self.schema) } } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index c31bd04eafa0..27bf6706c84b 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -19,6 +19,7 @@ use crate::expr::Case; use crate::{expr_schema::ExprSchemable, Expr}; use arrow::datatypes::DataType; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{plan_err, DFSchema, DataFusionError, Result}; use std::collections::HashSet; @@ -89,18 +90,18 @@ impl CaseBuilder { then_expr.push(e.as_ref().to_owned()); } - let then_types: Vec = then_expr + let then_types: Vec = then_expr .iter() .map(|e| match e { Expr::Literal(_) => e.get_type(&DFSchema::empty()), - _ => Ok(DataType::Null), + _ => Ok(LogicalType::Null), }) .collect::>>()?; - if then_types.contains(&DataType::Null) { + if then_types.contains(&LogicalType::Null) { // cannot verify types until execution type } else { - let unique_types: HashSet<&DataType> = then_types.iter().collect(); + let unique_types: HashSet<&LogicalType> = then_types.iter().collect(); if unique_types.len() != 1 { return plan_err!( "CASE expression 'then' values had multiple data types: {unique_types:?}" diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8929b21f4412..6e37227039f0 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -26,7 +26,7 @@ use crate::window_frame; use crate::window_function; use crate::Operator; use crate::{aggregate_function, ExprSchemable}; -use arrow::datatypes::DataType; +use datafusion_common::logical_type::LogicalType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; @@ -88,7 +88,7 @@ pub enum Expr { /// A named reference to a qualified filed in a schema. Column(Column), /// A named reference to a variable in a registry. - ScalarVariable(DataType, Vec), + ScalarVariable(LogicalType, Vec), /// A constant value. 
Literal(ScalarValue), /// A binary expression such as "age > 21" @@ -184,7 +184,7 @@ pub enum Expr { Placeholder(Placeholder), /// A place holder which hold a reference to a qualified field /// in the outer query, used for correlated sub queries. - OuterReferenceColumn(DataType, Column), + OuterReferenceColumn(LogicalType, Column), } /// Alias expression @@ -402,12 +402,12 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub data_type: LogicalType, } impl Cast { /// Create a new Cast expression - pub fn new(expr: Box, data_type: DataType) -> Self { + pub fn new(expr: Box, data_type: LogicalType) -> Self { Self { expr, data_type } } } @@ -418,12 +418,12 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub data_type: LogicalType, } impl TryCast { /// Create a new TryCast expression - pub fn new(expr: Box, data_type: DataType) -> Self { + pub fn new(expr: Box, data_type: LogicalType) -> Self { Self { expr, data_type } } } @@ -615,12 +615,12 @@ pub struct Placeholder { /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`) pub id: String, /// The type the parameter will be filled in with - pub data_type: Option, + pub data_type: Option, } impl Placeholder { /// Create a new Placeholder expression - pub fn new(id: String, data_type: Option) -> Self { + pub fn new(id: String, data_type: Option) -> Self { Self { id, data_type } } } @@ -1635,7 +1635,7 @@ mod test { use crate::expr::Cast; use crate::expr_fn::col; use crate::{case, lit, Expr}; - use arrow::datatypes::DataType; + use datafusion_common::logical_type::LogicalType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; @@ -1656,7 +1656,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), - data_type: DataType::Utf8, + data_type: LogicalType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); @@ -1685,7 +1685,7 @@ mod test { fn test_collect_expr() -> Result<()> { // single column { - let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); + let expr = &Expr::Cast(Cast::new(Box::new(col("a")), LogicalType::Float64)); let columns = expr.to_columns()?; assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 98cacc039228..9900cc2a13d3 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -30,6 +30,7 @@ use crate::{ ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; use arrow::datatypes::DataType; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{Column, Result}; use std::ops::Not; use std::sync::Arc; @@ -55,7 +56,7 @@ pub fn col(ident: impl Into) -> Expr { /// Create an out reference column which hold a reference that has been resolved to a field /// outside of the current plan. 
-pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
+pub fn out_ref_col(dt: LogicalType, ident: impl Into<Column>) -> Expr {
     Expr::OuterReferenceColumn(dt, ident.into())
 }
 
@@ -435,12 +436,12 @@ pub fn rollup(exprs: Vec<Expr>) -> Expr {
 }
 
 /// Create a cast expression
-pub fn cast(expr: Expr, data_type: DataType) -> Expr {
+pub fn cast(expr: Expr, data_type: LogicalType) -> Expr {
     Expr::Cast(Cast::new(Box::new(expr), data_type))
 }
 
 /// Create a try cast expression
-pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
+pub fn try_cast(expr: Expr, data_type: LogicalType) -> Expr {
     Expr::TryCast(TryCast::new(Box::new(expr), data_type))
 }
 
diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs
index 1f04c80833f0..8d489c56853c 100644
--- a/datafusion/expr/src/expr_rewriter/mod.rs
+++ b/datafusion/expr/src/expr_rewriter/mod.rs
@@ -262,7 +262,7 @@ mod test {
     use super::*;
     use crate::expr::Sort;
     use crate::{col, lit, Cast};
-    use arrow::datatypes::DataType;
+    use datafusion_common::logical_type::LogicalType;
     use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter};
     use datafusion_common::{DFField, DFSchema, ScalarValue};
     use std::ops::Add;
@@ -393,7 +393,7 @@
     }
 
     fn make_field(relation: &str, column: &str) -> DFField {
-        DFField::new(Some(relation.to_string()), column, DataType::Int8, false)
+        DFField::new(Some(relation.to_string()), column, LogicalType::Int8, false)
     }
 
     #[test]
@@ -423,7 +423,7 @@
         // cast data types
         test_rewrite(
             col("a"),
-            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
+            Expr::Cast(Cast::new(Box::new(col("a")), LogicalType::Int32)),
         );
 
         // change literal type from i32 to i64
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs
index c87a724d5646..55966e16cf28 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -152,6 +152,7 @@ mod test {
     use std::sync::Arc;
 
     use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::logical_type::LogicalType;
 
     use crate::{
         avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast,
@@ -268,13 +269,13 @@
         let cases = vec![
             TestCase {
                 desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
-                input: sort(cast(col("c2"), DataType::Int64)),
-                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
+                input: sort(cast(col("c2"), LogicalType::Int64)),
+                expected: sort(cast(col("c2").alias("c2"), LogicalType::Int64)),
             },
             TestCase {
                 desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
-                input: sort(try_cast(col("c2"), DataType::Int64)),
-                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
+                input: sort(try_cast(col("c2"), LogicalType::Int64)),
+                expected: sort(try_cast(col("c2").alias("c2"), LogicalType::Int64)),
             },
         ];
 
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 2889fac8c1ee..82aef5639c27 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -26,7 +26,7 @@
 use crate::type_coercion::binary::get_result_type;
 use crate::type_coercion::functions::data_types;
 use crate::{utils, LogicalPlan, Projection, Subquery};
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Field};
+use datafusion_common::logical_type::{ExtensionType, LogicalType};
 use datafusion_common::{
     internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema,
     DataFusionError, ExprSchema, Result,
@@ -37,7 +37,7 @@ use std::sync::Arc;
 
 /// Trait to allow an expr to be typed with respect to a schema
 pub trait ExprSchemable {
     /// given a schema, return the type of the expr
-    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
+    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<LogicalType>;
 
     /// given a schema, return the nullability of the expr
     fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
@@ -49,7 +49,11 @@ pub trait ExprSchemable {
     fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
 
     /// cast to a type with respect to a schema
-    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
+    fn cast_to<S: ExprSchema>(
+        self,
+        cast_to_type: &LogicalType,
+        schema: &S,
+    ) -> Result<Expr>;
 }
 
 impl ExprSchemable for Expr {
@@ -65,7 +69,7 @@
     /// expression refers to a column that does not exist in the
     /// schema, or when the expression is incorrectly typed
     /// (e.g. `[utf8] + [bool]`).
-    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
+    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<LogicalType> {
         match self {
             Expr::Alias(Alias { expr, name, .. }) => match &**expr {
                 Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
@@ -78,7 +82,7 @@
             Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
             Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
             Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
-            Expr::Literal(l) => Ok(l.data_type()),
+            Expr::Literal(l) => Ok(l.logical_type()),
             Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
             Expr::Cast(Cast { data_type, .. })
             | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
@@ -87,7 +91,14 @@
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
-                Ok((fun.return_type)(&data_types)?.as_ref().clone())
+
+                // TODO not convert to DataType
+                let data_types = data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
+
+                Ok((fun.return_type)(&data_types)?.as_ref().clone().into())
             }
             Expr::ScalarFunction(ScalarFunction { fun, args }) => {
                 let arg_data_types = args
@@ -95,6 +106,12 @@
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
 
+                // TODO not convert to DataType
+                let arg_data_types = arg_data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
+
                 // verify that input data types is consistent with function's `TypeSignature`
                 data_types(&arg_data_types, &fun.signature()).map_err(|_| {
                     plan_datafusion_err!(
@@ -107,28 +124,49 @@
                     )
                 })?;
 
-                fun.return_type(&arg_data_types)
+                fun.return_type(&arg_data_types).map(Into::into)
             }
             Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
                 let data_types = args
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
-                fun.return_type(&data_types)
+
+                // TODO not convert to DataType
+                let data_types = data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
+
+                fun.return_type(&data_types).map(Into::into)
             }
             Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => {
                 let data_types = args
                     .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
-                fun.return_type(&data_types)
+
+                // TODO not convert to DataType
+                let data_types = data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
+
+                fun.return_type(&data_types).map(Into::into)
            }
            Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => {
                let data_types = args
                    .iter()
                    .map(|e| e.get_type(schema))
                    .collect::<Result<Vec<_>>>()?;
-                Ok((fun.return_type)(&data_types)?.as_ref().clone())
+
+                // TODO not convert to DataType
+                let data_types = data_types
+                    .into_iter()
+                    .map(|e| e.physical_type())
+                    .collect::<Vec<_>>();
+
+                Ok((fun.return_type)(&data_types)?.as_ref().clone().into())
            }
            Expr::Not(_)
            | Expr::IsNull(_)
@@ -142,7 +180,7 @@
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
-            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
+            | Expr::IsNotUnknown(_) => Ok(LogicalType::Boolean),
            Expr::ScalarSubquery(subquery) => {
                Ok(subquery.subquery.schema().field(0).data_type().clone())
            }
@@ -150,8 +188,16 @@
                ref left,
                ref right,
                ref op,
-            }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?),
-            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
+            }) => {
+                // TODO not convert to physical DataType
+                let physical_type = get_result_type(
+                    &left.get_type(schema)?.physical_type(),
+                    op,
+                    &right.get_type(schema)?.physical_type(),
+                )?;
+                Ok(physical_type.into())
+            }
+            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(LogicalType::Boolean),
            Expr::Placeholder(Placeholder { data_type, .. }) => {
                data_type.clone().ok_or_else(|| {
                    plan_datafusion_err!("Placeholder type could not be resolved")
@@ -159,14 +205,14 @@
            }
            Expr::Wildcard => {
                // Wildcard does not really have a type and does not appear in projections
-                Ok(DataType::Null)
+                Ok(LogicalType::Null)
            }
            Expr::QualifiedWildcard { .. } => internal_err!(
                "QualifiedWildcard expressions are not valid in a logical query plan"
            ),
            Expr::GroupingSet(_) => {
                // grouping sets do not really have a type and do not appear in projections
-                Ok(DataType::Null)
+                Ok(LogicalType::Null)
            }
            Expr::GetIndexedField(GetIndexedField { expr, field }) => {
                field_for_index(expr, field, schema).map(|x| x.data_type().clone())
@@ -323,7 +369,11 @@
    ///
    /// This function errors when it is impossible to cast the
    /// expression to the target [arrow::datatypes::DataType].
-    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
+    fn cast_to<S: ExprSchema>(
+        self,
+        cast_to_type: &LogicalType,
+        schema: &S,
+    ) -> Result<Expr> {
        let this_type = self.get_type(schema)?;
        if this_type == *cast_to_type {
            return Ok(self);
@@ -332,8 +382,8 @@
        // TODO(kszucs): most of the operations do not validate the type correctness
        // like all of the binary expressions below. Perhaps Expr should track the
        // type of the expression?
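        // (Illustration, not from this patch: `lit(1i32).cast_to(&LogicalType::Int64, &schema)`
        // wraps the literal as `Expr::Cast(Cast::new(Box::new(lit(1i32)), LogicalType::Int64))`.)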
-
-        if can_cast_types(&this_type, cast_to_type) {
+        // TODO The basis for whether cast can be executed should be the logical type
+        if can_cast_types(&this_type.physical_type(), &cast_to_type.physical_type()) {
            match self {
                Expr::ScalarSubquery(subquery) => {
                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
                }
@@ -346,30 +396,48 @@
    }
}

-/// return the schema [`Field`] for the type referenced by `get_indexed_field`
+/// return the schema [`DFField`] for the type referenced by `get_indexed_field`
fn field_for_index<S: ExprSchema>(
    expr: &Expr,
    field: &GetFieldAccess,
    schema: &S,
-) -> Result<Field> {
+) -> Result<DFField> {
    let expr_dt = expr.get_type(schema)?;
    match field {
        GetFieldAccess::NamedStructField { name } => {
            GetFieldAccessSchema::NamedStructField { name: name.clone() }
        }
-        GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
-            key_dt: key.get_type(schema)?,
-        },
-        GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange {
-            start_dt: start.get_type(schema)?,
-            stop_dt: stop.get_type(schema)?,
-        },
+        GetFieldAccess::ListIndex { key } => {
+            let data_type = key.get_type(schema)?;
+            match data_type {
+                LogicalType::Int64 => {}
+                _ => {
+                    return plan_err!(
+                        "Only ints are valid as an indexed field in a list"
+                    );
+                }
+            }
+
+            GetFieldAccessSchema::ListIndex
+        }
+        GetFieldAccess::ListRange { start, stop } => {
+            match (start.get_type(schema)?, stop.get_type(schema)?) {
+                (LogicalType::Int64, LogicalType::Int64) => {}
+                _ => {
+                    return plan_err!(
+                        "Only ints are valid as an indexed field in a list"
+                    );
+                }
+            }
+
+            GetFieldAccessSchema::ListRange
+        }
    }
-    .get_accessed_field(&expr_dt)
+    .get_accessed_df_field(&expr_dt)
}

/// cast subquery in InSubquery/ScalarSubquery to a given type.
-pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
+pub fn cast_subquery(subquery: Subquery, cast_to_type: &LogicalType) -> Result<Subquery> {
    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
        return Ok(subquery);
    }
@@ -404,7 +472,6 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery>
        ,
    }
@@ -553,7 +620,7 @@
        fn new() -> Self {
            Self {
                nullable: false,
-                data_type: DataType::Null,
+                data_type: LogicalType::Null,
                error_on_nullable: false,
                metadata: HashMap::new(),
            }
@@ -564,7 +631,7 @@
            self
        }

-        fn with_data_type(mut self, data_type: DataType) -> Self {
+        fn with_data_type(mut self, data_type: LogicalType) -> Self {
            self.data_type = data_type;
            self
        }
@@ -589,7 +656,7 @@
            }
        }

-        fn data_type(&self, _col: &Column) -> Result<&DataType> {
+        fn data_type(&self, _col: &Column) -> Result<&LogicalType> {
            Ok(&self.data_type)
        }

diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs
index 3829a2086b26..88704d954992 100644
--- a/datafusion/expr/src/field_util.rs
+++ b/datafusion/expr/src/field_util.rs
@@ -19,7 +19,8 @@
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::{
-    plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue,
+    logical_type::LogicalType, plan_datafusion_err, plan_err, DFField, DataFusionError,
+    Result, ScalarValue,
 };
 
 /// Types of the field access expression of a nested type, such as `Field` or `List`
@@ -27,12 +28,9 @@ pub enum GetFieldAccessSchema {
     /// Named field, For example `struct["name"]`
     NamedStructField { name: ScalarValue },
     /// Single list index, for example: `list[i]`
-    ListIndex { key_dt: DataType },
+    ListIndex,
     /// List range, for example `list[i:j]`
-    ListRange {
-        start_dt: DataType,
-        stop_dt: DataType,
-    },
+    ListRange,
 }
 
 impl GetFieldAccessSchema {
@@ -76,22 +74,72 @@
                    (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"),
                }
            }
-            Self::ListIndex{ key_dt } => {
-                match (data_type, key_dt) {
-                    (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)),
-                    (DataType::List(_), _) => plan_err!(
-                        "Only ints are valid as an indexed field in a list"
-                    ),
-                    (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
+            Self::ListIndex => {
+                match data_type {
+                    DataType::List(lt) => Ok(Field::new("list", lt.data_type().clone(), true)),
+                    other => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
+                }
+            }
+            Self::ListRange => {
+                match data_type {
+                    DataType::List(_) => Ok(Field::new("list", data_type.clone(), true)),
+                    other => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
                }
            }
-            Self::ListRange{ start_dt, stop_dt } => {
-                match (data_type, start_dt, stop_dt) {
-                    (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
-                    (DataType::List(_), _, _) => plan_err!(
-                        "Only ints are valid as an indexed field in a list"
+        }
+    }
+    /// Returns the schema [`DFField`] from a [`LogicalType::List`] or
+    /// [`LogicalType::Struct`] indexed by this structure
+    ///
+    /// # Error
+    /// Errors if
+    /// * the `data_type` is not a Struct or a List,
+    /// * the `data_type` of the name/index/start-stop do not match a supported index type
+    pub fn get_accessed_df_field(&self, data_type: &LogicalType) -> Result<DFField> {
+        match self {
+            Self::NamedStructField{ name } => {
+                match (data_type, name) {
+                    (LogicalType::Map(fields, _), _) => {
+                        match fields.data_type() {
+                            LogicalType::Struct(fields) if fields.len() == 2 => {
+                                // Arrow's MapArray is essentially a ListArray of structs with two columns. They are
+                                // often named "key", and "value", but we don't require any specific naming here;
+                                // instead, we assume that the second column is the "value" column both here and in
+                                // execution.
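+                                // (Sketch, assuming the layout above: for a map typed as
+                                // Map(Struct("key": Utf8, "value": Int32)), `map["k"]` yields the Int32 value field.)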
+                                let value_field = fields.get(1).expect("fields should have exactly two members");
+                                Ok(DFField::new_unqualified("map", value_field.data_type().clone(), true))
+                            },
+                            _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
+                        }
+                    }
+                    (LogicalType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
+                        if s.is_empty() {
+                            plan_err!(
+                                "Struct based indexed access requires a non empty string"
+                            )
+                        } else {
+                            let field = fields.iter().find(|f| f.name() == s);
+                            field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| {
+                                DFField::new_unqualified(f.name(), f.data_type().clone(), true)
+                            })
+                        }
+                    }
+                    (LogicalType::Struct(_), _) => plan_err!(
+                        "Only utf8 strings are valid as an indexed field in a struct"
                    ),
-                    (other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
+                    (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"),
+                }
+            }
+            Self::ListIndex => {
+                match data_type {
+                    LogicalType::List(lt) => Ok(DFField::new_unqualified("list", lt.as_ref().clone(), true)),
+                    other => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
+                }
+            }
+            Self::ListRange => {
+                match data_type {
+                    LogicalType::List(_) => Ok(DFField::new_unqualified("list", data_type.clone(), true)),
+                    other => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
                }
            }
        }
    }
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 162a6a959e59..87c12201f948 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -47,8 +47,9 @@ use crate::{
     TableProviderFilterPushDown, TableSource, WriteOp,
 };
 
-use arrow::datatypes::{DataType, Schema, SchemaRef};
+use arrow::datatypes::{Schema, SchemaRef};
 use datafusion_common::display::ToStringifiedPlan;
+use datafusion_common::logical_type::{ExtensionType, LogicalType};
 use datafusion_common::{
     plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef,
     DataFusionError, FileType, OwnedTableReference, Result, ScalarValue, TableReference,
@@ -139,7 +140,7 @@
             return plan_err!("Values list cannot be zero length");
         }
         let empty_schema = DFSchema::empty();
-        let mut field_types: Vec<Option<DataType>> = Vec::with_capacity(n_cols);
+        let mut field_types: Vec<Option<LogicalType>> = Vec::with_capacity(n_cols);
         for _ in 0..n_cols {
             field_types.push(None);
         }
@@ -171,7 +172,7 @@
                         Ok(Some(data_type))
                    }
                })
-                .collect::<Result<Vec<Option<DataType>>>>()?;
+                .collect::<Result<Vec<Option<LogicalType>>>>()?;
        }
        let fields = field_types
            .iter()
@@ -181,7 +182,7 @@
                let name = &format!("column{}", j + 1);
                DFField::new_unqualified(
                    name,
-                    data_type.clone().unwrap_or(DataType::Utf8),
+                    data_type.clone().unwrap_or(LogicalType::Utf8),
                    true,
                )
            })
@@ -255,11 +256,10 @@
    pub fn insert_into(
        input: LogicalPlan,
        table_name: impl Into<OwnedTableReference>,
-        table_schema: &Schema,
+        table_schema: &DFSchemaRef,
        overwrite: bool,
    ) -> Result<Self> {
-        let table_schema = table_schema.clone().to_dfschema_ref()?;
-
+        let table_schema = table_schema.clone();
        let op = if overwrite {
            WriteOp::InsertOverwrite
        } else {
@@ -350,7 +350,7 @@
    }
 
    /// Make a builder for a prepare logical plan from the builder's plan
-    pub fn prepare(self, name: String, data_types: Vec<DataType>) -> Result<Self> {
+    pub fn prepare(self, name: String, data_types: Vec<LogicalType>) -> Result<Self> {
        Ok(Self::from(LogicalPlan::Prepare(Prepare {
            name,
            data_types,
@@ -905,8 +905,7 @@
    ///
    /// if `verbose` is true, prints out additional details.
    pub fn explain(self, verbose: bool, analyze: bool) -> Result<Self> {
-        let schema = LogicalPlan::explain_schema();
-        let schema = schema.to_dfschema_ref()?;
+        let schema = LogicalPlan::explain_schema()?;
 
        if analyze {
            Ok(Self::from(LogicalPlan::Analyze(Analyze {
@@ -1223,17 +1222,21 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan>
-            DataType::List(field)
-            | DataType::FixedSizeList(field, _)
-            | DataType::LargeList(field) => DFField::new(
+            LogicalType::List(dt)
+            | LogicalType::FixedSizeList(dt, _)
+            | LogicalType::LargeList(dt) => DFField::new(
                unnest_field.qualifier().cloned(),
                unnest_field.name(),
-                field.data_type().clone(),
+                dt.as_ref().clone(),
                unnest_field.is_nullable(),
            ),
            _ => {
@@ -1924,7 +1927,7 @@
            .schema()
            .field_with_name(Some(&TableReference::bare("test_table")), "strings")
            .unwrap();
-        assert_eq!(&DataType::Utf8, field.data_type());
+        assert_eq!(&LogicalType::Utf8, field.data_type());
 
        // Unnesting multiple fields.
        let plan = nested_table_scan("test_table")?
@@ -1943,7 +1946,7 @@
            .schema()
            .field_with_name(Some(&TableReference::bare("test_table")), "structs")
            .unwrap();
-        assert!(matches!(field.data_type(), DataType::Struct(_)));
+        assert!(matches!(field.data_type(), LogicalType::Struct(_)));
 
        // Unnesting missing column should fail.
        let plan = nested_table_scan("test_table")?.unnest_column("missing");
diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs
index 2c90a3aca754..645e539d66a6 100644
--- a/datafusion/expr/src/logical_plan/ddl.rs
+++ b/datafusion/expr/src/logical_plan/ddl.rs
@@ -24,6 +24,7 @@ use std::{
 
 use crate::{Expr, LogicalPlan};
 
+use datafusion_common::logical_type::LogicalType;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::{
     Constraints, DFSchemaRef, OwnedSchemaReference, OwnedTableReference,
@@ -42,6 +43,8 @@ pub enum DdlStatement {
     CreateCatalogSchema(CreateCatalogSchema),
     /// Creates a new catalog (aka "Database").
     CreateCatalog(CreateCatalog),
+    /// Creates a new user defined data type.
+    CreateType(CreateType),
     /// Drops a table.
     DropTable(DropTable),
     /// Drops a view.
@@ -63,6 +66,7 @@ impl DdlStatement {
                schema
            }
            DdlStatement::CreateCatalog(CreateCatalog { schema, .. }) => schema,
+            DdlStatement::CreateType(CreateType { schema, .. }) => schema,
            DdlStatement::DropTable(DropTable { schema, .. }) => schema,
            DdlStatement::DropView(DropView { schema, .. }) => schema,
            DdlStatement::DropCatalogSchema(DropCatalogSchema { schema, .. }) => schema,
@@ -78,6 +82,7 @@
            DdlStatement::CreateView(_) => "CreateView",
            DdlStatement::CreateCatalogSchema(_) => "CreateCatalogSchema",
            DdlStatement::CreateCatalog(_) => "CreateCatalog",
+            DdlStatement::CreateType(_) => "CreateType",
            DdlStatement::DropTable(_) => "DropTable",
            DdlStatement::DropView(_) => "DropView",
            DdlStatement::DropCatalogSchema(_) => "DropCatalogSchema",
@@ -94,6 +99,7 @@
                vec![input]
            }
            DdlStatement::CreateView(CreateView { input, .. }) => vec![input],
+            DdlStatement::CreateType(_) => vec![],
            DdlStatement::DropTable(_) => vec![],
            DdlStatement::DropView(_) => vec![],
            DdlStatement::DropCatalogSchema(_) => vec![],
@@ -138,6 +144,9 @@
                }) => {
                    write!(f, "CreateCatalog: {catalog_name:?}")
                }
+                DdlStatement::CreateType(CreateType { name, .. }) => {
+                    write!(f, "CreateType: {name}")
+                }
                DdlStatement::DropTable(DropTable {
                    name, if_exists, ..
                }) => {
@@ -265,6 +274,15 @@ pub struct CreateCatalogSchema {
    pub schema: DFSchemaRef,
}

+/// Creates a new user defined data type.
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct CreateType {
+    pub name: String,
+    pub data_type: LogicalType,
+    /// Empty schema
+    pub schema: DFSchemaRef,
+}
+
/// Drops a table.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct DropTable {
diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs
index 112dbf74dba1..543746f7f261 100644
--- a/datafusion/expr/src/logical_plan/display.rs
+++ b/datafusion/expr/src/logical_plan/display.rs
@@ -17,10 +17,9 @@
 //! This module provides logic for displaying LogicalPlans in various styles
 
 use crate::LogicalPlan;
-use arrow::datatypes::Schema;
 use datafusion_common::display::GraphvizBuilder;
 use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion};
-use datafusion_common::DataFusionError;
+use datafusion_common::{DataFusionError, DFSchema};
 use std::fmt;
 
 /// Formats plans with a single line per node. For example:
@@ -64,7 +63,7 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> {
             write!(
                 self.f,
                 " {}",
-                display_schema(&plan.schema().as_ref().to_owned().into())
+                display_schema(plan.schema().as_ref())
             )?;
         }
 
@@ -99,8 +98,8 @@
 ///   format!("{}", display_schema(&schema))
 /// );
 /// ```
-pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ {
-    struct Wrapper<'a>(&'a Schema);
+pub fn display_schema(schema: &DFSchema) -> impl fmt::Display + '_ {
+    struct Wrapper<'a>(&'a DFSchema);
 
     impl<'a> fmt::Display for Wrapper<'a> {
         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -185,7 +184,7 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> {
             format!(
                 r"{}\nSchema: {}",
                 plan.display(),
-                display_schema(&plan.schema().as_ref().to_owned().into())
+                display_schema(plan.schema().as_ref())
             )
         } else {
             format!("{}", plan.display())
@@ -222,20 +221,21 @@
 #[cfg(test)]
 mod tests {
     use arrow::datatypes::{DataType, Field};
+    use datafusion_common::{DFField, logical_type::LogicalType};
 
     use super::*;
 
     #[test]
     fn test_display_empty_schema() {
-        let schema = Schema::empty();
+        let schema = DFSchema::empty();
         assert_eq!("[]", format!("{}", display_schema(&schema)));
     }
 
     #[test]
     fn test_display_schema() {
-        let schema = Schema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("first_name", DataType::Utf8, true),
+        let schema = DFSchema::new(vec![
+            DFField::new_unqualified("id", LogicalType::Int32, false),
+            DFField::new_unqualified("first_name", LogicalType::Utf8, true),
         ]);
 
         assert_eq!(
diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs
index 8316417138bd..ed9eaa5eefe9 100644
--- a/datafusion/expr/src/logical_plan/mod.rs
+++ b/datafusion/expr/src/logical_plan/mod.rs
@@ -29,7 +29,7 @@ pub use builder::{
 };
 pub use ddl::{
     CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable,
-    CreateView, DdlStatement, DropCatalogSchema, DropTable, DropView,
+    CreateType, CreateView, DdlStatement, DropCatalogSchema, DropTable, DropView,
 };
 pub use dml::{DmlStatement, WriteOp};
 pub use plan::{
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index d62ac8926328..55516b8325d5 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -41,6 +41,7 @@ use crate::{
 };
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion_common::logical_type::LogicalType;
 use datafusion_common::tree_node::{
     RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor,
     VisitRecursion,
@@ -48,7 +49,7 @@ use datafusion_common::tree_node::{
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
     DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
-    OwnedTableReference, Result, ScalarValue, UnnestOptions,
+    OwnedTableReference, Result, ScalarValue, TableReference, UnnestOptions,
 };
 // backwards compatibility
 pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
@@ -254,20 +255,20 @@ impl LogicalPlan {
     }
 
     /// Returns the (fixed) output schema for explain plans
-    pub fn explain_schema() -> SchemaRef {
-        SchemaRef::new(Schema::new(vec![
-            Field::new("plan_type", DataType::Utf8, false),
-            Field::new("plan", DataType::Utf8, false),
-        ]))
+    pub fn explain_schema() -> Result<DFSchemaRef> {
+        Ok(DFSchemaRef::new(DFSchema::new_with_metadata(vec![
+            DFField::new_unqualified("plan_type", LogicalType::Utf8, false),
+            DFField::new_unqualified("plan", LogicalType::Utf8, false),
+        ], Default::default())?))
     }
 
     /// Returns the (fixed) output schema for `DESCRIBE` plans
-    pub fn describe_schema() -> Schema {
-        Schema::new(vec![
-            Field::new("column_name", DataType::Utf8, false),
-            Field::new("data_type", DataType::Utf8, false),
-            Field::new("is_nullable", DataType::Utf8, false),
-        ])
+    pub fn describe_schema() -> Result<DFSchemaRef> {
+        Ok(DFSchemaRef::new(DFSchema::new_with_metadata(vec![
+            DFField::new_unqualified("column_name", LogicalType::Utf8, false),
+            DFField::new_unqualified("data_type", LogicalType::Utf8, false),
+            DFField::new_unqualified("is_nullable", LogicalType::Utf8, false),
+        ], Default::default())?))
     }
 
     /// returns all expressions (non-recursively) in the current
@@ -969,11 +970,11 @@
             // Verify that the types of the params match the types of the values
             let iter = prepare_lp.data_types.iter().zip(param_values.iter());
             for (i, (param_type, value)) in iter.enumerate() {
-                if *param_type != value.data_type() {
+                if *param_type != value.logical_type() {
                     return plan_err!(
                         "Expected parameter of type {:?}, got {:?} at index {}",
                         param_type,
-                        value.data_type(),
+                        value.logical_type(),
                         i
                     );
                 }
@@ -1163,8 +1164,8 @@
     /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes
     pub fn get_parameter_types(
         &self,
-    ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> {
-        let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();
+    ) -> Result<HashMap<String, Option<LogicalType>>, DataFusionError> {
+        let mut param_types: HashMap<String, Option<LogicalType>> = HashMap::new();
 
         self.apply(&mut |plan| {
             plan.inspect_expressions(|expr| {
@@ -1218,7 +1219,7 @@
                     ))
                 })?;
                 // check if the data type of the value matches the data type of the placeholder
-                if Some(value.data_type()) != *data_type {
+                if Some(value.logical_type()) != *data_type {
                     return internal_err!(
                         "Placeholder value type mismatch: expected {:?}, got {:?}",
                         data_type,
@@ -1804,12 +1805,13 @@ impl SubqueryAlias {
         alias: impl Into<OwnedTableReference>,
     ) -> Result<Self> {
         let alias = alias.into();
-        let schema: Schema = plan.schema().as_ref().clone().into();
+        let schema = plan.schema().as_ref();
         // Since schema is the same, other than qualifier, we can use existing
         // functional dependencies:
         let func_dependencies = plan.schema().functional_dependencies().clone();
+
         let schema = DFSchemaRef::new(
-            DFSchema::try_from_qualified_schema(&alias, &schema)?
+            DFSchema::try_from_qualified_dfschema(&alias, schema)?
                .with_functional_dependencies(func_dependencies),
        );
        Ok(SubqueryAlias {
@@ -1848,7 +1850,7 @@ impl Filter {
        // construction (such as with correlated subqueries) so we make a best effort here and
        // ignore errors resolving the expression against the schema.
        if let Ok(predicate_type) = predicate.get_type(input.schema()) {
-            if predicate_type != DataType::Boolean {
+            if predicate_type != LogicalType::Boolean {
                return plan_err!(
                    "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}"
                );
@@ -1952,7 +1954,7 @@ impl TableScan {
        filters: Vec<Expr>,
        fetch: Option<usize>,
    ) -> Result<Self> {
-        let table_name = table_name.into();
+        let table_name: TableReference = table_name.into();
 
        if table_name.table().is_empty() {
            return plan_err!("table_name cannot be empty");
@@ -1970,10 +1972,14 @@
                DFSchema::new_with_metadata(
                    p.iter()
                        .map(|i| {
-                            DFField::from_qualified(
-                                table_name.clone(),
-                                schema.field(*i).clone(),
+                            let f = schema.field(*i);
+                            DFField::new(
+                                Some(table_name.clone()),
+                                f.name(),
+                                f.data_type().into(),
+                                f.is_nullable(),
                            )
+                            .with_metadata(f.metadata().clone())
                        })
                        .collect(),
                    schema.metadata().clone(),
@@ -1983,9 +1989,10 @@
                })
            })
            .unwrap_or_else(|| {
-                DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map(
-                    |df_schema| df_schema.with_functional_dependencies(func_dependencies),
-                )
+                DFSchema::try_from_qualified_schema(table_name.clone(), schema.as_ref())
+                    .map(|df_schema| {
+                        df_schema.with_functional_dependencies(func_dependencies)
+                    })
            })?;
        let projected_schema = Arc::new(projected_schema);
        Ok(Self {
@@ -2035,7 +2042,7 @@ pub struct Prepare {
    /// The name of the statement
    pub name: String,
    /// Data types of the parameters ([`Expr::Placeholder`])
-    pub data_types: Vec<DataType>,
+    pub data_types: Vec<LogicalType>,
    /// The logical plan of the statements
    pub input: Arc<LogicalPlan>,
}
diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs
index 86005da3dafa..08db7ab9613b 100644
--- a/datafusion/expr/src/type_coercion/mod.rs
+++ b/datafusion/expr/src/type_coercion/mod.rs
@@ -37,6 +37,7 @@ pub mod functions;
 pub mod other;
 
 use arrow::datatypes::DataType;
+use datafusion_common::logical_type::LogicalType;
 
 /// Determine whether the given data type `dt` represents signed numeric values.
 pub fn is_signed_numeric(dt: &DataType) -> bool {
     matches!(
@@ -69,19 +70,22 @@ pub fn is_interval(dt: &DataType) -> bool {
 }
 
 /// Determine whether the given data type `dt` is a `Date` or `Timestamp`.
-pub fn is_datetime(dt: &DataType) -> bool {
+pub fn is_datetime(dt: &LogicalType) -> bool {
     matches!(
         dt,
-        DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _)
+        LogicalType::Date32 | LogicalType::Date64 | LogicalType::Timestamp(_, _)
     )
 }
 
 /// Determine whether the given data type `dt` is a `Utf8` or `LargeUtf8`.
-pub fn is_utf8_or_large_utf8(dt: &DataType) -> bool {
-    matches!(dt, DataType::Utf8 | DataType::LargeUtf8)
+pub fn is_utf8_or_large_utf8(dt: &LogicalType) -> bool {
+    matches!(dt, LogicalType::Utf8 | LogicalType::LargeUtf8)
 }
 
 /// Determine whether the given data type `dt` is a `Decimal`.
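 /// (Sketch, not part of this patch: `is_decimal(&LogicalType::Decimal128(38, 10))`
 /// returns `true`, while `is_decimal(&LogicalType::Float64)` returns `false`.)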
-pub fn is_decimal(dt: &DataType) -> bool {
-    matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
+pub fn is_decimal(dt: &LogicalType) -> bool {
+    matches!(
+        dt,
+        LogicalType::Decimal128(_, _) | LogicalType::Decimal256(_, _)
+    )
 }
diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs
index 634558094ae7..1505f0019088 100644
--- a/datafusion/expr/src/type_coercion/other.rs
+++ b/datafusion/expr/src/type_coercion/other.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use arrow::datatypes::DataType;
+use datafusion_common::logical_type::{ExtensionType, LogicalType};
 
 use super::binary::comparison_coercion;
 
@@ -37,18 +38,22 @@ pub fn get_coerce_type_for_list(
 /// and the `case_or_else_type`, if specified.
 /// Returns the common data type for `when_or_then_types` and `case_or_else_type`
 pub fn get_coerce_type_for_case_expression(
-    when_or_then_types: &[DataType],
-    case_or_else_type: Option<&DataType>,
-) -> Option<DataType> {
+    when_or_then_types: &[LogicalType],
+    case_or_else_type: Option<&LogicalType>,
+) -> Option<LogicalType> {
     let case_or_else_type = match case_or_else_type {
         None => when_or_then_types[0].clone(),
         Some(data_type) => data_type.clone(),
-    };
+    }
+    .physical_type();
+    // FIXME comparison_coercion use LogicalType
     when_or_then_types
         .iter()
+        .map(|e| e.physical_type())
         .try_fold(case_or_else_type, |left_type, right_type| {
            // TODO: for now, just use the `equal` coercion rule for CASE WHEN;
            // revisit and refactor if issues are found.
-            comparison_coercion(&left_type, right_type)
+            comparison_coercion(&left_type, &right_type)
        })
+        .map(Into::into)
}
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 5fc5b5b3f9c7..95f9d573907a 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -22,6 +22,7 @@
 use crate::logical_plan::Aggregate;
 use crate::signature::{Signature, TypeSignature};
 use crate::{Cast, Expr, ExprSchemable, GroupingSet, LogicalPlan, TryCast};
 use arrow::datatypes::{DataType, TimeUnit};
+use datafusion_common::logical_type::LogicalType;
 use datafusion_common::tree_node::{TreeNode, VisitRecursion};
 use datafusion_common::{
     internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef,
@@ -886,39 +887,34 @@ pub(crate) fn find_column_indexes_referenced_by_expr(
 
 /// Can this data type be used in hash join equal conditions?
 /// Data types here come from the function `equal_rows`; if more data types are
 /// supported in `equal_rows` (hash join), add those data types here to generate the join logical plan.
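 /// (Sketch, not part of this patch: `can_hash(&LogicalType::Int32)` is `true`,
 /// while an unsupported type such as a `LogicalType::Map(..)` returns `false`.)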
-pub fn can_hash(data_type: &DataType) -> bool {
+pub fn can_hash(data_type: &LogicalType) -> bool {
     match data_type {
-        DataType::Null => true,
-        DataType::Boolean => true,
-        DataType::Int8 => true,
-        DataType::Int16 => true,
-        DataType::Int32 => true,
-        DataType::Int64 => true,
-        DataType::UInt8 => true,
-        DataType::UInt16 => true,
-        DataType::UInt32 => true,
-        DataType::UInt64 => true,
-        DataType::Float32 => true,
-        DataType::Float64 => true,
-        DataType::Timestamp(time_unit, None) => match time_unit {
+        LogicalType::Null => true,
+        LogicalType::Boolean => true,
+        LogicalType::Int8 => true,
+        LogicalType::Int16 => true,
+        LogicalType::Int32 => true,
+        LogicalType::Int64 => true,
+        LogicalType::UInt8 => true,
+        LogicalType::UInt16 => true,
+        LogicalType::UInt32 => true,
+        LogicalType::UInt64 => true,
+        LogicalType::Float32 => true,
+        LogicalType::Float64 => true,
+        LogicalType::Timestamp(time_unit, None) => match time_unit {
             TimeUnit::Second => true,
             TimeUnit::Millisecond => true,
             TimeUnit::Microsecond => true,
             TimeUnit::Nanosecond => true,
         },
-        DataType::Utf8 => true,
-        DataType::LargeUtf8 => true,
-        DataType::Decimal128(_, _) => true,
-        DataType::Date32 => true,
-        DataType::Date64 => true,
-        DataType::FixedSizeBinary(_) => true,
-        DataType::Dictionary(key_type, value_type)
-            if *value_type.as_ref() == DataType::Utf8 =>
-        {
-            DataType::is_dictionary_key_type(key_type)
-        }
-        DataType::List(_) => true,
-        DataType::LargeList(_) => true,
+        LogicalType::Utf8 => true,
+        LogicalType::LargeUtf8 => true,
+        LogicalType::Decimal128(_, _) => true,
+        LogicalType::Date32 => true,
+        LogicalType::Date64 => true,
+        LogicalType::FixedSizeBinary(_) => true,
+        LogicalType::List(_) => true,
+        LogicalType::LargeList(_) => true,
         _ => false,
     }
 }
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index bfdbec390199..276e610b5801 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -22,6 +22,7 @@ use std::sync::Arc;
 
 use arrow::datatypes::{DataType, IntervalUnit};
 use datafusion_common::config::ConfigOptions;
+use datafusion_common::logical_type::{ExtensionType, LogicalType};
 use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
 use datafusion_common::{
     exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
@@ -162,10 +163,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
                 let expr_type = expr.get_type(&self.schema)?;
                 let subquery_type = new_plan.schema().field(0).data_type();
-                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
+                // FIXME use logical type
+                let common_type = comparison_coercion(&expr_type.physical_type(), &subquery_type.physical_type()).ok_or(plan_datafusion_err!(
                         "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
                     ),
-                )?;
+                )?.into();
                 let new_subquery = Subquery {
                     subquery: Arc::new(new_plan),
                     outer_ref_columns: subquery.outer_ref_columns,
@@ -209,8 +211,9 @@
                 escape_char,
                 case_insensitive,
             }) => {
-                let left_type = expr.get_type(&self.schema)?;
-                let right_type = pattern.get_type(&self.schema)?;
+                // FIXME like_coercion use LogicalType
+                let left_type = expr.get_type(&self.schema)?.physical_type();
+                let right_type = pattern.get_type(&self.schema)?.physical_type();
                 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
                    let op_name = if case_insensitive {
                        "ILIKE"
@@ -220,7 +223,7 @@
                    plan_datafusion_err!(
                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
                    )
-                })?;
+                })?.into();
                let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?);
                let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?);
                let expr = Expr::Like(Like::new(
@@ -233,16 +236,17 @@
                Ok(expr)
            }
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                // FIXME get_input_types use LogicalType
                let (left_type, right_type) = get_input_types(
-                    &left.get_type(&self.schema)?,
+                    &left.get_type(&self.schema)?.physical_type(),
                    &op,
-                    &right.get_type(&self.schema)?,
+                    &right.get_type(&self.schema)?.physical_type(),
                )?;
 
                Ok(Expr::BinaryExpr(BinaryExpr::new(
-                    Box::new(left.cast_to(&left_type, &self.schema)?),
+                    Box::new(left.cast_to(&left_type.into(), &self.schema)?),
                    op,
-                    Box::new(right.cast_to(&right_type, &self.schema)?),
+                    Box::new(right.cast_to(&right_type.into(), &self.schema)?),
                )))
            }
            Expr::Between(Between {
@@ -251,8 +255,9 @@
                low,
                high,
            }) => {
-                let expr_type = expr.get_type(&self.schema)?;
-                let low_type = low.get_type(&self.schema)?;
+                // FIXME comparison_coercion use LogicalType
+                let expr_type = expr.get_type(&self.schema)?.physical_type();
+                let low_type = low.get_type(&self.schema)?.physical_type();
                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
                    .ok_or_else(|| {
                        DataFusionError::Internal(format!(
@@ -272,7 +277,7 @@
                        DataFusionError::Internal(format!(
                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
                        ))
-                    })?;
+                    })?.into();
                let expr = Expr::Between(Between::new(
                    Box::new(expr.cast_to(&coercion_type, &self.schema)?),
                    negated,
@@ -286,10 +291,13 @@
                list,
                negated,
            }) => {
-                let expr_data_type = expr.get_type(&self.schema)?;
+                // FIXME get_coerce_type_for_list use LogicalType
+                let expr_data_type = expr.get_type(&self.schema)?.physical_type();
                let list_data_types = list
                    .iter()
-                    .map(|list_expr| list_expr.get_type(&self.schema))
+                    .map(|list_expr| {
+                        list_expr.get_type(&self.schema).map(|e| e.physical_type())
+                    })
                    .collect::<Result<Vec<_>>>()?;
                let result_type =
                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
@@ -298,12 +306,13 @@
                        "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
                    ),
                    Some(coerced_type) => {
+                        let logical_coerced_type = coerced_type.into();
                        // find the coerced type
-                        let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
+                        let cast_expr = expr.cast_to(&logical_coerced_type, &self.schema)?;
                        let cast_list_expr = list
                            .into_iter()
                            .map(|list_expr| {
-                                list_expr.cast_to(&coerced_type, &self.schema)
+                                list_expr.cast_to(&logical_coerced_type, &self.schema)
                            })
                            .collect::<Result<Vec<_>>>()?;
                        let expr = Expr::InList(InList ::new(
@@ -436,15 +445,16 @@ fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue>
 fn coerce_scalar_range_aware(
-    target_type: &DataType,
+    target_type: &LogicalType,
     value: &ScalarValue,
 ) -> Result<ScalarValue> {
-    coerce_scalar(target_type, value).or_else(|err| {
+    let target_type = target_type.physical_type();
+    coerce_scalar(&target_type, value).or_else(|err| {
         // If type coercion fails, check if the largest type in family works:
-        if let Some(largest_type) = get_widest_type_in_family(target_type) {
+        if let Some(largest_type) = get_widest_type_in_family(&target_type) {
             coerce_scalar(largest_type, value).map_or_else(
                 |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
-                |_| ScalarValue::try_from(target_type),
+                |_| ScalarValue::try_from(&target_type),
             )
         } else {
             Err(err)
@@ -466,7 +476,7 @@ fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
 
 /// Coerces the given (window frame) `bound` to `target_type`.
 fn coerce_frame_bound(
-    target_type: &DataType,
+    target_type: &LogicalType,
     bound: &WindowFrameBound,
 ) -> Result<WindowFrameBound> {
     match bound {
@@ -498,7 +508,7 @@ fn coerce_window_frame(
             if col_type.is_numeric() || is_utf8_or_large_utf8(col_type) {
                 col_type
             } else if is_datetime(col_type) {
-                &DataType::Interval(IntervalUnit::MonthDayNano)
+                &LogicalType::Interval(IntervalUnit::MonthDayNano)
             } else {
                 return internal_err!(
                     "Cannot run range queries on datatype: {col_type:?}"
@@ -508,7 +518,7 @@
             return internal_err!("ORDER BY column cannot be empty");
         }
     }
-        WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64,
+        WindowFrameUnits::Rows | WindowFrameUnits::Groups => &LogicalType::UInt64,
     };
     window_frame.start_bound =
         coerce_frame_bound(target_type, &window_frame.start_bound)?;
@@ -520,8 +530,13 @@
 // The above op will be rewritten to the binary op when creating the physical op.
 fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
     let left_type = expr.get_type(schema)?;
-    get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
-    cast_expr(expr, &DataType::Boolean, schema)
+    // FIXME use logical type
+    get_input_types(
+        &left_type.physical_type(),
+        &Operator::IsDistinctFrom,
+        &DataType::Boolean,
+    )?;
+    cast_expr(expr, &LogicalType::Boolean, schema)
 }
 
 /// Returns `expressions` coerced to types compatible with
@@ -539,10 +554,13 @@ fn coerce_arguments_for_signature(
 
     let current_types = expressions
         .iter()
-        .map(|e| e.get_type(schema))
+        .map(|e| e.get_type(schema).map(|t| t.physical_type()))
         .collect::<Result<Vec<_>>>()?;
-
-    let new_types = data_types(&current_types, signature)?;
+    // FIXME data_types use logical type
+    let new_types = data_types(&current_types, signature)?
+        .into_iter()
+        .map(Into::into)
+        .collect::<Vec<_>>();
 
     expressions
         .iter()
@@ -568,9 +586,9 @@
             .into_iter()
             .map(|expr| {
                 let data_type = expr.get_type(schema).unwrap();
-                if let DataType::FixedSizeList(field, _) = data_type {
+                if let LogicalType::FixedSizeList(field, _) = data_type {
                     let field = field.as_ref().clone();
-                    let to_type = DataType::List(Arc::new(field));
+                    let to_type = LogicalType::List(Box::new(field));
                     expr.cast_to(&to_type, schema)
                 } else {
                     Ok(expr)
@@ -589,9 +607,15 @@
         let new_type = current_types
             .iter()
             .skip(1)
-            .fold(current_types.first().unwrap().clone(), |acc, x| {
-                comparison_coercion(&acc, x).unwrap_or(acc)
-            });
+            .map(|e| e.physical_type())
+            .fold(
+                current_types.first().unwrap().clone().physical_type(),
+                |acc, x| {
+                    // FIXME comparison_coercion use logical type
+                    comparison_coercion(&acc, &x).unwrap_or(acc)
+                },
+            )
+            .into();
 
         return expressions
             .iter()
@@ -603,18 +627,18 @@
 }
 
 /// Cast `expr` to the specified type, if possible
-fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
+fn cast_expr(expr: &Expr, to_type: &LogicalType, schema: &DFSchema) -> Result<Expr> {
     expr.clone().cast_to(to_type, schema)
 }
 
 /// Cast array `expr` to the specified type, if possible
 fn cast_array_expr(
     expr: &Expr,
-    from_type: &DataType,
-    to_type: &DataType,
+    from_type: &LogicalType,
+    to_type: &LogicalType,
     schema: &DFSchema,
 ) -> Result<Expr> {
-    if from_type.equals_datatype(&DataType::Null) {
+    if from_type == &LogicalType::Null {
         Ok(expr.clone())
     } else {
         cast_expr(expr, to_type, schema)
@@ -635,11 +659,14 @@
     }
     let current_types = input_exprs
         .iter()
-        .map(|e| e.get_type(schema))
+        .map(|e| e.get_type(schema).map(|e| e.physical_type()))
         .collect::<Result<Vec<_>>>()?;
-
+    // FIXME coerce_types use logical type
     let coerced_types =
-        type_coercion::aggregates::coerce_types(agg_fun, &current_types, signature)?;
+        type_coercion::aggregates::coerce_types(agg_fun, &current_types, signature)?
+            .into_iter()
+            .map(Into::into)
+            .collect::<Vec<_>>();
 
     input_exprs
         .iter()
@@ -736,7 +763,9 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Expr> {
         .when_then_expr
         .into_iter()
         .map(|(when, then)| {
-            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
+            let when_type = case_when_coerce_type
+                .as_ref()
+                .unwrap_or(&LogicalType::Boolean);
             let when = when.cast_to(when_type, &schema).map_err(|e| {
                 DataFusionError::Context(
                     format!(
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 68a6a5607a1d..4934b60530de 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;
 
 use crate::{utils, OptimizerConfig, OptimizerRule};
 
-use arrow::datatypes::DataType;
+use datafusion_common::logical_type::LogicalType;
 use datafusion_common::tree_node::{
     RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion,
 };
@@ -39,7 +39,7 @@ use datafusion_expr::{col, Expr, ExprSchemable};
 /// - the expression itself (cloned)
 /// - counter
 /// - DataType of this expression.
-type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
+type ExprSet = HashMap<Identifier, (Expr, usize, LogicalType)>;
 
 /// Identifier type. The current implementation uses the description of an expression
 /// (type String) as the identifier.
@@ -794,8 +794,8 @@ mod test {
         let schema = Arc::new(DFSchema::new_with_metadata(
             vec![
-                DFField::new_unqualified("a", DataType::Int64, false),
-                DFField::new_unqualified("c", DataType::Int64, false),
+                DFField::new_unqualified("a", LogicalType::Int64, false),
+                DFField::new_unqualified("c", LogicalType::Int64, false),
             ],
             Default::default(),
         )?);
@@ -1287,7 +1287,7 @@
     fn test_extract_expressions_from_col() -> Result<()> {
         let mut result = Vec::with_capacity(1);
         let schema = DFSchema::new_with_metadata(
-            vec![DFField::new_unqualified("a", DataType::Int32, false)],
+            vec![DFField::new_unqualified("a", LogicalType::Int32, false)],
             HashMap::default(),
         )?;
         extract_expressions(&col("a"), &schema, &mut result)?;
diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
index 96b46663d8e4..0745831450ba 100644
--- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
+++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -322,8 +322,7 @@ impl SubqueryInfo {
 mod tests {
     use super::*;
     use crate::test::*;
-    use arrow::datatypes::DataType;
-    use datafusion_common::Result;
+    use datafusion_common::{logical_type::LogicalType, Result};
     use datafusion_expr::{
         and, binary_expr, col, exists, in_subquery, lit,
         logical_plan::LogicalPlanBuilder, not_exists, not_in_subquery, or, out_ref_col,
@@ -523,7 +522,7 @@
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -564,7 +563,7 @@
             LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
                 .filter(
                     col("lineitem.l_orderkey")
-                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "orders.o_orderkey")),
                 )?
                 .project(vec![col("lineitem.l_orderkey")])?
                 .build()?,
@@ -575,7 +574,7 @@
                 .filter(
                     in_subquery(col("orders.o_orderkey"), lineitem).and(
                         col("orders.o_custkey")
-                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                            .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                     ),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -612,7 +611,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .and(col("o_orderkey").eq(lit(1))),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -647,8 +646,8 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -711,7 +710,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .not_eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -744,7 +743,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .lt(col("orders.o_custkey")),
                 )?
                .project(vec![col("orders.o_custkey")])?
@@ -777,7 +776,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .or(col("o_orderkey").eq(lit(1))),
                 )?
@@ -835,7 +834,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -868,7 +867,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey").add(lit(1))])?
@@ -901,7 +900,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
@@ -930,7 +929,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -967,7 +966,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1004,7 +1003,7 @@
     fn in_subquery_correlated() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
-                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
+                .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq.a")))?
                 .project(vec![col("c")])?
                 .build()?,
         );
@@ -1113,7 +1112,7 @@
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
-                out_ref_col(DataType::UInt32, "test.a")
+                out_ref_col(LogicalType::UInt32, "test.a")
                     .eq(col("sq.a"))
                     .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
             )?
@@ -1148,8 +1147,8 @@
         let subquery = LogicalPlanBuilder::from(subquery_scan)
             .filter(
-                out_ref_col(DataType::UInt32, "test.a")
-                    .add(out_ref_col(DataType::UInt32, "test.b"))
+                out_ref_col(LogicalType::UInt32, "test.a")
+                    .add(out_ref_col(LogicalType::UInt32, "test.b"))
                     .eq(col("sq.a").add(col("sq.b")))
                     .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
             )?
@@ -1184,12 +1183,12 @@
         let subquery_scan2 = test_table_scan_with_name("sq2")?;
 
         let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
-            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq1.a")))?
+            .filter(out_ref_col(LogicalType::UInt32, "test.a").gt(col("sq1.a")))?
             .project(vec![col("c") * lit(2u32)])?
             .build()?;
 
         let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
-            .filter(out_ref_col(DataType::UInt32, "test.a").gt(col("sq2.a")))?
+            .filter(out_ref_col(LogicalType::UInt32, "test.a").gt(col("sq2.a")))?
             .project(vec![col("c") * lit(2u32)])?
            .build()?;
@@ -1261,7 +1260,7 @@
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
                     col("orders.o_custkey")
-                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
@@ -1292,7 +1291,7 @@
             LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
                 .filter(
                     col("lineitem.l_orderkey")
-                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
+                        .eq(out_ref_col(LogicalType::Int64, "orders.o_orderkey")),
                 )?
                 .project(vec![col("lineitem.l_orderkey")])?
                 .build()?,
@@ -1303,7 +1302,7 @@
                 .filter(
                     exists(lineitem).and(
                         col("orders.o_custkey")
-                            .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+                            .eq(out_ref_col(LogicalType::Int64, "customer.c_custkey")),
                     ),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1334,7 +1333,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .and(col("o_orderkey").eq(lit(1))),
                 )?
@@ -1362,7 +1361,9 @@
     fn exists_subquery_no_cols() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
-                .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
+                .filter(
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey").eq(lit(1u32)),
+                )?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
         );
@@ -1407,7 +1408,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .not_eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1435,7 +1436,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .lt(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1463,7 +1464,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey"))
                         .or(col("o_orderkey").eq(lit(1))),
                 )?
@@ -1492,7 +1493,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .build()?,
@@ -1518,7 +1519,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey").add(lit(1))])?
@@ -1546,7 +1547,7 @@
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(
-                    out_ref_col(DataType::Int64, "customer.c_custkey")
+                    out_ref_col(LogicalType::Int64, "customer.c_custkey")
                         .eq(col("orders.o_custkey")),
                 )?
                 .project(vec![col("orders.o_custkey")])?
@@ -1600,7 +1601,7 @@
     fn exists_subquery_correlated() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
-                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
+ .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq.a")))? .project(vec![col("c")])? .build()?, ); @@ -1651,12 +1652,12 @@ mod tests { let subquery_scan2 = test_table_scan_with_name("sq2")?; let subquery1 = LogicalPlanBuilder::from(subquery_scan1) - .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))? + .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq1.a")))? .project(vec![col("c")])? .build()?; let subquery2 = LogicalPlanBuilder::from(subquery_scan2) - .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))? + .filter(out_ref_col(LogicalType::UInt32, "test.a").eq(col("sq2.a")))? .project(vec![col("c")])? .build()?; @@ -1690,7 +1691,7 @@ mod tests { let subquery = LogicalPlanBuilder::from(subquery_scan) .filter( (lit(1u32) + col("sq.a")) - .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)), + .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)), )? .project(vec![lit(1u32)])? .build()?; @@ -1742,7 +1743,7 @@ mod tests { let subquery = LogicalPlanBuilder::from(subquery_scan) .filter( (lit(1u32) + col("sq.a")) - .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)), + .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)), )? .project(vec![col("sq.c")])? .distinct()? @@ -1770,7 +1771,7 @@ mod tests { let subquery = LogicalPlanBuilder::from(subquery_scan) .filter( (lit(1u32) + col("sq.a")) - .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)), + .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)), )? .project(vec![col("sq.b") + col("sq.c")])? .distinct()? @@ -1798,7 +1799,7 @@ mod tests { let subquery = LogicalPlanBuilder::from(subquery_scan) .filter( (lit(1u32) + col("sq.a")) - .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)), + .gt(out_ref_col(LogicalType::UInt32, "test.a") * lit(2u32)), )? .project(vec![lit(1u32), col("sq.c")])? .distinct()? diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index e4d57f0209a4..90a43bbdaf09 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -298,7 +298,6 @@ fn extract_non_nullable_columns( mod tests { use super::*; use crate::test::*; - use arrow::datatypes::DataType; use datafusion_expr::{ binary_expr, cast, col, lit, logical_plan::builder::LogicalPlanBuilder, @@ -424,9 +423,9 @@ mod tests { None, )? .filter(binary_expr( - cast(col("t1.b"), DataType::Int64).gt(lit(10u32)), + cast(col("t1.b"), LogicalType::Int64).gt(lit(10u32)), And, - try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), + try_cast(col("t2.c"), LogicalType::Int64).lt(lit(20u32)), ))? 
.build()?; let expected = "\ diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index b05d811cb481..9b8e189c9bb2 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -704,9 +704,9 @@ mod tests { **optimized_join.schema(), DFSchema::new_with_metadata( vec![ - DFField::new(Some("test"), "a", DataType::UInt32, false), - DFField::new(Some("test"), "b", DataType::UInt32, false), - DFField::new(Some("test2"), "c1", DataType::UInt32, true), + DFField::new(Some("test"), "a", LogicalType::UInt32, false), + DFField::new(Some("test"), "b", LogicalType::UInt32, false), + DFField::new(Some("test2"), "c1", LogicalType::UInt32, true), ], HashMap::new(), )?, @@ -747,9 +747,9 @@ mod tests { **optimized_join.schema(), DFSchema::new_with_metadata( vec![ - DFField::new(Some("test"), "a", DataType::UInt32, false), - DFField::new(Some("test"), "b", DataType::UInt32, false), - DFField::new(Some("test2"), "c1", DataType::UInt32, true), + DFField::new(Some("test"), "a", LogicalType::UInt32, false), + DFField::new(Some("test"), "b", LogicalType::UInt32, false), + DFField::new(Some("test2"), "c1", LogicalType::UInt32, true), ], HashMap::new(), )?, @@ -788,9 +788,9 @@ mod tests { **optimized_join.schema(), DFSchema::new_with_metadata( vec![ - DFField::new(Some("test"), "a", DataType::UInt32, false), - DFField::new(Some("test"), "b", DataType::UInt32, false), - DFField::new(Some("test2"), "a", DataType::UInt32, true), + DFField::new(Some("test"), "a", LogicalType::UInt32, false), + DFField::new(Some("test"), "b", LogicalType::UInt32, false), + DFField::new(Some("test2"), "a", LogicalType::UInt32, true), ], HashMap::new(), )?, @@ -806,7 +806,7 @@ mod tests { let projection = LogicalPlanBuilder::from(table_scan) .project(vec![Expr::Cast(Cast::new( Box::new(col("c")), - DataType::Float64, + LogicalType::Float64, ))])? .build()?; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 7ac0c25119c3..9fa8e922e0f2 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -372,8 +372,7 @@ fn build_join( mod tests { use super::*; use crate::test::*; - use arrow::datatypes::DataType; - use datafusion_common::Result; + use datafusion_common::{logical_type::LogicalType as DataType, Result}; use datafusion_expr::{ col, lit, logical_plan::LogicalPlanBuilder, max, min, out_ref_col, scalar_subquery, sum, Between, diff --git a/datafusion/optimizer/src/simplify_expressions/context.rs b/datafusion/optimizer/src/simplify_expressions/context.rs index 34f3908c7e42..29c50ec7f477 100644 --- a/datafusion/optimizer/src/simplify_expressions/context.rs +++ b/datafusion/optimizer/src/simplify_expressions/context.rs @@ -17,8 +17,9 @@ //! Structs and traits to provide the information needed for expression simplification. 
-use arrow::datatypes::DataType; -use datafusion_common::{DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{ + logical_type::LogicalType, DFSchemaRef, DataFusionError, Result, +}; use datafusion_expr::{Expr, ExprSchemable}; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -40,7 +41,7 @@ pub trait SimplifyInfo { fn execution_props(&self) -> &ExecutionProps; /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result; + fn get_data_type(&self, expr: &Expr) -> Result; } /// Provides simplification information based on DFSchema and @@ -100,7 +101,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { /// returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result { for schema in &self.schema { - if let Ok(DataType::Boolean) = expr.get_type(schema) { + if let Ok(LogicalType::Boolean) = expr.get_type(schema) { return Ok(true); } } @@ -119,7 +120,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { } /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result { + fn get_data_type(&self, expr: &Expr) -> Result { let schema = self.schema.as_ref().ok_or_else(|| { DataFusionError::Internal( "attempt to get data type without schema".to_string(), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 04fdcca0a994..e6638a1c9777 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,7 +32,9 @@ use arrow::{ }; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, + logical_type::{ExtensionType, LogicalType}, tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, + DFField, }; use datafusion_common::{ exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -297,8 +299,15 @@ impl<'a> ConstEvaluator<'a> { // The dummy column name is unused and doesn't matter as only // expressions without column references can be evaluated static DUMMY_COL_NAME: &str = "."; - let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); - let input_schema = DFSchema::try_from(schema.clone())?; + let input_schema = DFSchema::new_with_metadata( + vec![DFField::new_unqualified( + DUMMY_COL_NAME, + LogicalType::Null, + true, + )], + Default::default(), + )?; + let schema = input_schema.clone().into(); // Need a single "input" row to produce a single output row let col = new_null_array(&DataType::Null, 1); let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?; @@ -753,7 +762,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Multiply, right, }) if !info.nullable(&left)? - && !info.get_data_type(&left)?.is_floating() + && !is_floating(&info.get_data_type(&left)?) && is_zero(&right) => { *right @@ -764,7 +773,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Multiply, right, }) if !info.nullable(&right)? - && !info.get_data_type(&right)?.is_floating() + && !is_floating(&info.get_data_type(&right)?) && is_zero(&left) => { *left @@ -799,7 +808,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Divide, right, }) if !info.nullable(&left)? - && !info.get_data_type(&left)?.is_floating() + && !is_floating(&info.get_data_type(&left)?) 
&& is_zero(&right) => { return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); @@ -827,7 +836,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Modulo, right, }) if !info.nullable(&left)? - && !info.get_data_type(&left)?.is_floating() + && !is_floating(&info.get_data_type(&left)?) && is_one(&right) => { lit(0) @@ -840,8 +849,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if !info.nullable(&left)? && is_zero(&right) => { match info.get_data_type(&left)? { - DataType::Float32 => lit(f32::NAN), - DataType::Float64 => lit(f64::NAN), + LogicalType::Float32 => lit(f32::NAN), + LogicalType::Float64 => lit(f64::NAN), _ => { return Err(DataFusionError::ArrowError( ArrowError::DivideByZero, @@ -888,7 +897,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?.physical_type(), + )?) } // A & !A -> 0 (if A not nullable) @@ -897,7 +908,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?.physical_type(), + )?) } // (..A..) & A --> (..A..) @@ -970,7 +983,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?.physical_type(), + )?) } // A | !A -> -1 (if A not nullable) @@ -979,7 +994,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?.physical_type(), + )?) } // (..A..) | A --> (..A..) @@ -1052,7 +1069,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?.physical_type(), + )?) } // A ^ !A -> -1 (if A not nullable) @@ -1061,7 +1080,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?.physical_type(), + )?) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1072,7 +1093,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&right)?.physical_type(), + )?) 
} else { expr } @@ -1086,7 +1109,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?.physical_type(), + )?) } else { expr } @@ -1293,6 +1318,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +fn is_floating(dt: &LogicalType) -> bool { + matches!( + dt, + LogicalType::Float16 | LogicalType::Float32 | LogicalType::Float64 + ) +} + #[cfg(test)] mod tests { use std::{ @@ -1313,7 +1345,10 @@ mod tests { datatypes::{DataType, Field, Schema}, }; use chrono::{DateTime, TimeZone, Utc}; - use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; + use datafusion_common::{ + assert_contains, cast::as_int32_array, logical_type::LogicalType, DFField, + ToDFSchema, + }; use datafusion_expr::*; use datafusion_physical_expr::{ execution_props::ExecutionProps, @@ -2799,14 +2834,14 @@ mod tests { Arc::new( DFSchema::new_with_metadata( vec![ - DFField::new_unqualified("c1", DataType::Utf8, true), - DFField::new_unqualified("c2", DataType::Boolean, true), - DFField::new_unqualified("c3", DataType::Int64, true), - DFField::new_unqualified("c4", DataType::UInt32, true), - DFField::new_unqualified("c1_non_null", DataType::Utf8, false), - DFField::new_unqualified("c2_non_null", DataType::Boolean, false), - DFField::new_unqualified("c3_non_null", DataType::Int64, false), - DFField::new_unqualified("c4_non_null", DataType::UInt32, false), + DFField::new_unqualified("c1", LogicalType::Utf8, true), + DFField::new_unqualified("c2", LogicalType::Boolean, true), + DFField::new_unqualified("c3", LogicalType::Int64, true), + DFField::new_unqualified("c4", LogicalType::UInt32, true), + DFField::new_unqualified("c1_non_null", LogicalType::Utf8, false), + DFField::new_unqualified("c2_non_null", LogicalType::Boolean, false), + DFField::new_unqualified("c3_non_null", LogicalType::Int64, false), + DFField::new_unqualified("c4_non_null", LogicalType::UInt32, false), ], HashMap::new(), ) @@ -2893,7 +2928,7 @@ mod tests { #[test] fn simplify_expr_eq() { let schema = expr_test_schema(); - assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); + assert_eq!(col("c2").get_type(&schema).unwrap(), LogicalType::Boolean); // true = true -> true assert_eq!(simplify(lit(true).eq(lit(true))), lit(true)); @@ -2917,7 +2952,7 @@ mod tests { // expression to non-boolean. // // Make sure c1 column to be used in tests is not boolean type - assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); + assert_eq!(col("c1").get_type(&schema).unwrap(), LogicalType::Utf8); // don't fold c1 = foo assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),); @@ -2927,7 +2962,7 @@ mod tests { fn simplify_expr_not_eq() { let schema = expr_test_schema(); - assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); + assert_eq!(col("c2").get_type(&schema).unwrap(), LogicalType::Boolean); // c2 != true -> !c2 assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),); @@ -2948,7 +2983,7 @@ mod tests { // when one of the operand is not of boolean type, folding the // other boolean constant will change return type of // expression to non-boolean. 
- assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); + assert_eq!(col("c1").get_type(&schema).unwrap(), LogicalType::Utf8); assert_eq!( simplify(col("c1").not_eq(lit("foo"))), diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 17e5d97c3006..07541f0891e4 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -18,6 +18,7 @@ //! Utility functions for expression simplification use crate::simplify_expressions::SimplifyInfo; +use datafusion_common::logical_type::ExtensionType; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ @@ -350,7 +351,9 @@ pub fn distribute_negation(expr: Expr) -> Expr { /// 3. Log(a, Power(a, b)) ===> b pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result { let mut number = ¤t_args[0]; - let mut base = &Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?); + let mut base = &Expr::Literal(ScalarValue::new_ten( + &info.get_data_type(number)?.physical_type(), + )?); if current_args.len() == 2 { base = ¤t_args[0]; number = ¤t_args[1]; @@ -358,10 +361,13 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result + if value + == &ScalarValue::new_one( + &info.get_data_type(number)?.physical_type(), + )? => { Ok(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(base)?, + &info.get_data_type(base)?.physical_type(), )?)) } Expr::ScalarFunction(ScalarFunction { @@ -371,7 +377,7 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result { if number == base { Ok(Expr::Literal(ScalarValue::new_one( - &info.get_data_type(number)?, + &info.get_data_type(number)?.physical_type(), )?)) } else { Ok(Expr::ScalarFunction(ScalarFunction::new( @@ -393,14 +399,20 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result + if value + == &ScalarValue::new_zero( + &info.get_data_type(exponent)?.physical_type(), + )? => { Ok(Expr::Literal(ScalarValue::new_one( - &info.get_data_type(base)?, + &info.get_data_type(base)?.physical_type(), )?)) } Expr::Literal(value) - if value == &ScalarValue::new_one(&info.get_data_type(exponent)?)? => + if value + == &ScalarValue::new_one( + &info.get_data_type(exponent)?.physical_type(), + )? 
=> { Ok(base.clone()) } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 907c12b7afb1..7e2c3876f0aa 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -22,9 +22,10 @@ use crate::optimizer::ApplyOrder; use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ - DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, + TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; +use datafusion_common::logical_type::LogicalType; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -284,27 +285,27 @@ fn is_comparison_op(op: &Operator) -> bool { ) } -fn is_support_data_type(data_type: &DataType) -> bool { +fn is_support_data_type(data_type: &LogicalType) -> bool { matches!( data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - | DataType::Timestamp(_, _) + LogicalType::UInt8 + | LogicalType::UInt16 + | LogicalType::UInt32 + | LogicalType::UInt64 + | LogicalType::Int8 + | LogicalType::Int16 + | LogicalType::Int32 + | LogicalType::Int64 + | LogicalType::Decimal128(_, _) + | LogicalType::Timestamp(_, _) ) } fn try_cast_literal_to_type( lit_value: &ScalarValue, - target_type: &DataType, + target_type: &LogicalType, ) -> Result> { - let lit_data_type = lit_value.data_type(); + let lit_data_type = lit_value.logical_type(); // the rule just support the signed numeric data type now if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) { return Ok(None); @@ -314,31 +315,31 @@ fn try_cast_literal_to_type( return Ok(Some(ScalarValue::try_from(target_type)?)); } let mul = match target_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => 1_i128, - DataType::Timestamp(_, _) => 1_i128, - DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + LogicalType::UInt8 + | LogicalType::UInt16 + | LogicalType::UInt32 + | LogicalType::UInt64 + | LogicalType::Int8 + | LogicalType::Int16 + | LogicalType::Int32 + | LogicalType::Int64 => 1_i128, + LogicalType::Timestamp(_, _) => 1_i128, + LogicalType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), other_type => { return internal_err!("Error target data type {other_type:?}"); } }; let (target_min, target_max) = match target_type { - DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), - DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), - DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), - DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), - DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), - DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), - DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), - DataType::Decimal128(precision, _) => ( + LogicalType::UInt8 => (u8::MIN as i128, u8::MAX as i128), + LogicalType::UInt16 => (u16::MIN as i128, u16::MAX as i128), + 
LogicalType::UInt32 => (u32::MIN as i128, u32::MAX as i128), + LogicalType::UInt64 => (u64::MIN as i128, u64::MAX as i128), + LogicalType::Int8 => (i8::MIN as i128, i8::MAX as i128), + LogicalType::Int16 => (i16::MIN as i128, i16::MAX as i128), + LogicalType::Int32 => (i32::MIN as i128, i32::MAX as i128), + LogicalType::Int64 => (i64::MIN as i128, i64::MAX as i128), + LogicalType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), + LogicalType::Decimal128(precision, _) => ( // Different precision for decimal128 can store different range of value. // For example, the precision is 3, the max of value is `999` and the min // value is `-999` @@ -393,47 +394,47 @@ fn try_cast_literal_to_type( // the value casted from lit to the target type is in the range of target type. // return the target type of scalar value let result_scalar = match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), - DataType::Timestamp(TimeUnit::Second, tz) => { + LogicalType::Int8 => ScalarValue::Int8(Some(value as i8)), + LogicalType::Int16 => ScalarValue::Int16(Some(value as i16)), + LogicalType::Int32 => ScalarValue::Int32(Some(value as i32)), + LogicalType::Int64 => ScalarValue::Int64(Some(value as i64)), + LogicalType::UInt8 => ScalarValue::UInt8(Some(value as u8)), + LogicalType::UInt16 => ScalarValue::UInt16(Some(value as u16)), + LogicalType::UInt32 => ScalarValue::UInt32(Some(value as u32)), + LogicalType::UInt64 => ScalarValue::UInt64(Some(value as u64)), + LogicalType::Timestamp(TimeUnit::Second, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Second, tz.clone()), + LogicalType::Timestamp(TimeUnit::Second, tz.clone()), value, ); ScalarValue::TimestampSecond(value, tz.clone()) } - DataType::Timestamp(TimeUnit::Millisecond, tz) => { + LogicalType::Timestamp(TimeUnit::Millisecond, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + LogicalType::Timestamp(TimeUnit::Millisecond, tz.clone()), value, ); ScalarValue::TimestampMillisecond(value, tz.clone()) } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { + LogicalType::Timestamp(TimeUnit::Microsecond, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + LogicalType::Timestamp(TimeUnit::Microsecond, tz.clone()), value, ); ScalarValue::TimestampMicrosecond(value, tz.clone()) } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + LogicalType::Timestamp(TimeUnit::Nanosecond, tz) => { let value = cast_between_timestamp( lit_data_type, - DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + LogicalType::Timestamp(TimeUnit::Nanosecond, tz.clone()), value, ); ScalarValue::TimestampNanosecond(value, tz.clone()) } - DataType::Decimal128(p, s) => { + LogicalType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) } other_type => { @@ -449,21 +450,25 @@ fn try_cast_literal_to_type( } /// Cast a timestamp value from one unit to another -fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { +fn 
cast_between_timestamp( + from: LogicalType, + to: LogicalType, + value: i128, +) -> Option { let value = value as i64; let from_scale = match from { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + LogicalType::Timestamp(TimeUnit::Second, _) => 1, + LogicalType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + LogicalType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + LogicalType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, _ => return Some(value), }; let to_scale = match to { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + LogicalType::Timestamp(TimeUnit::Second, _) => 1, + LogicalType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + LogicalType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + LogicalType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, _ => return Some(value), }; diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 872071e52fa7..7676fdbd7299 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -18,6 +18,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_optimizer::analyzer::Analyzer; @@ -404,7 +405,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_variable_type(&self, _variable_names: &[String]) -> Option { + fn get_variable_type(&self, _variable_names: &[String]) -> Option { None } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 7d5f16c454d6..580de48c7f7a 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -18,7 +18,7 @@ //! 
get field of a `ListArray` use crate::PhysicalExpr; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, plan_err}; use crate::array_expressions::{array_element, array_slice}; use crate::physical_expr::down_cast_any_ref; @@ -144,14 +144,33 @@ impl GetIndexedFieldExpr { GetFieldAccessExpr::NamedStructField { name } => { GetFieldAccessSchema::NamedStructField { name: name.clone() } } - GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex { - key_dt: key.data_type(input_schema)?, - }, + GetFieldAccessExpr::ListIndex { key } => { + let data_type = key.data_type(input_schema)?; + match data_type { + DataType::Int64 => {} + _ => { + return plan_err!( + "Only ints are valid as an indexed field in a list" + ); + } + } + + GetFieldAccessSchema::ListIndex + } GetFieldAccessExpr::ListRange { start, stop } => { - GetFieldAccessSchema::ListRange { - start_dt: start.data_type(input_schema)?, - stop_dt: stop.data_type(input_schema)?, + match ( + start.data_type(input_schema)?, + stop.data_type(input_schema)?, + ) { + (DataType::Int64, DataType::Int64) => {} + _ => { + return plan_err!( + "Only ints are valid as an indexed field in a list" + ); + } } + + GetFieldAccessSchema::ListRange } }) } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f318cd3b0f4d..939eec658392 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -25,6 +25,7 @@ use crate::{ PhysicalExpr, }; use arrow::datatypes::Schema; +use datafusion_common::logical_type::ExtensionType; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -278,12 +279,12 @@ pub fn create_physical_expr( Expr::Cast(Cast { expr, data_type }) => expressions::cast( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, input_schema, - data_type.clone(), + data_type.physical_type(), ), Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, input_schema, - data_type.clone(), + data_type.physical_type(), ), Expr::Not(expr) => expressions::not(create_physical_expr( expr, diff --git a/datafusion/physical-expr/src/var_provider.rs b/datafusion/physical-expr/src/var_provider.rs index e00cf7407237..86624a351ced 100644 --- a/datafusion/physical-expr/src/var_provider.rs +++ b/datafusion/physical-expr/src/var_provider.rs @@ -17,8 +17,7 @@ //! 
Variable provider -use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{logical_type::LogicalType, Result, ScalarValue}; /// Variable type, system/user defined #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,7 +34,7 @@ pub trait VarProvider: std::fmt::Debug { fn get_value(&self, var_names: Vec) -> Result; /// Return the type of the given variable - fn get_type(&self, var_names: &[String]) -> Option; + fn get_type(&self, var_names: &[String]) -> Option; } pub fn is_system_variables(variable_names: &[String]) -> bool { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e426c598523e..d8a4c854e80f 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -43,6 +43,7 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; +use datafusion_common::logical_type::LogicalType; use datafusion_common::plan_datafusion_err; use datafusion_common::{ context, internal_err, not_impl_err, parsers::CompressionTypeVariant, @@ -771,10 +772,12 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Prepare(prepare) => { let input: LogicalPlan = into_logical_plan!(prepare.input, ctx, extension_codec)?; - let data_types: Vec = prepare + // FIXME Use LogicalType in proto + let data_types: Vec = prepare .data_types .iter() .map(DataType::try_from) + .map(Into::into) .collect::>()?; LogicalPlanBuilder::from(input) .prepare(prepare.name.clone(), data_types)? diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 9df65b99a748..083ab721ce73 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -17,6 +17,7 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::WindowUDF; use datafusion_expr::{ @@ -120,7 +121,7 @@ impl ContextProvider for MyContextProvider { None } - fn get_variable_type(&self, _variable_names: &[String]) -> Option { + fn get_variable_type(&self, _variable_names: &[String]) -> Option { None } diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 8c0184b6d119..59b39a075bf1 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -69,7 +69,7 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result }; // do the actual lookup to the appropriate data type - let data_type = parse_data_type(&data_type_string)?; + let data_type = parse_data_type(&data_type_string)?.into(); arg0.cast_to(&data_type, schema) } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 1cf0fc133f04..4690035562e4 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -28,7 +28,7 @@ mod unary_op; mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow_schema::DataType; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, @@ -584,7 +584,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; - if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { + if pattern_type != 
LogicalType::Utf8 && pattern_type != LogicalType::Null { return plan_err!("Invalid pattern in LIKE expression"); } Ok(Expr::Like(Like::new( @@ -607,7 +607,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; - if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { + if pattern_type != LogicalType::Utf8 && pattern_type != LogicalType::Null { return plan_err!("Invalid pattern in SIMILAR TO expression"); } Ok(Expr::SimilarTo(Like::new( @@ -750,6 +750,7 @@ mod tests { use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::logical_type::LogicalType; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -799,7 +800,7 @@ mod tests { None } - fn get_variable_type(&self, _variable_names: &[String]) -> Option { + fn get_variable_type(&self, _variable_names: &[String]) -> Option { None } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3a06fdb158f7..19d9192ac986 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -20,6 +20,7 @@ use arrow::array::new_null_array; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; +use datafusion_common::logical_type::LogicalType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; @@ -35,7 +36,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( &self, value: Value, - param_data_types: &[DataType], + param_data_types: &[LogicalType], ) -> Result { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), @@ -96,7 +97,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on. 
fn create_placeholder_expr( param: String, - param_data_types: &[DataType], + param_data_types: &[LogicalType], ) -> Result { // Parse the placeholder as a number because it is the only support from sqlparser and postgres let index = param[1..].parse::(); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ca5e260aee05..b01ee3dfd915 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -23,6 +23,9 @@ use std::vec; use arrow_schema::*; use datafusion_common::field_not_found; use datafusion_common::internal_err; +use datafusion_common::logical_type::LogicalType; +use datafusion_common::logical_type::TypeSignature; +use datafusion_common::DFField; use datafusion_expr::WindowUDF; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; @@ -40,6 +43,7 @@ use datafusion_expr::utils::find_column_exprs; use datafusion_expr::TableSource; use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF}; +use crate::statement::object_name_to_string; use crate::utils::make_decimal_type; /// The ContextProvider trait allows the query planner to obtain meta-data about tables and @@ -58,7 +62,11 @@ pub trait ContextProvider { /// Getter for a UDWF fn get_window_meta(&self, name: &str) -> Option>; /// Getter for system/user-defined variable type - fn get_variable_type(&self, variable_names: &[String]) -> Option; + fn get_variable_type(&self, variable_names: &[String]) -> Option; + /// Getter for extension data type + fn get_data_type(&self, _name: &TypeSignature) -> Option { + None + } /// Get configuration options fn options(&self) -> &ConfigOptions; @@ -115,7 +123,7 @@ impl IdentNormalizer { pub struct PlannerContext { /// Data types for numbered parameters ($1, $2, etc), if supplied /// in `PREPARE` statement - prepare_param_data_types: Vec, + prepare_param_data_types: Vec, /// Map of CTE name to logical plan of the WITH clause. 
/// Use `Arc` to allow cheap cloning ctes: HashMap>, @@ -142,7 +150,7 @@ impl PlannerContext { /// Update the PlannerContext with provided prepare_param_data_types pub fn with_prepare_param_data_types( mut self, - prepare_param_data_types: Vec, + prepare_param_data_types: Vec, ) -> Self { self.prepare_param_data_types = prepare_param_data_types; self @@ -164,7 +172,7 @@ impl PlannerContext { } /// Return the types of parameters (`$1`, `$2`, etc) if known - pub fn prepare_param_data_types(&self) -> &[DataType] { + pub fn prepare_param_data_types(&self) -> &[LogicalType] { &self.prepare_param_data_types } @@ -211,7 +219,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - pub fn build_schema(&self, columns: Vec) -> Result { + pub fn build_schema(&self, columns: Vec) -> Result { let mut fields = Vec::with_capacity(columns.len()); for column in columns { @@ -220,14 +228,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .options .iter() .any(|x| x.option == ColumnOption::NotNull); - fields.push(Field::new( - self.normalizer.normalize(column.name), + fields.push(DFField::new_unqualified( + &self.normalizer.normalize(column.name), data_type, !not_nullable, )); } - Ok(Schema::new(fields)) + DFSchema::new_with_metadata(fields, Default::default()) } /// Apply the given TableAlias to the input plan @@ -295,15 +303,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) } - pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { + pub(crate) fn convert_data_type( + &self, + sql_type: &SQLDataType, + ) -> Result { match sql_type { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type)) => { let data_type = self.convert_simple_data_type(inner_sql_type)?; - Ok(DataType::List(Arc::new(Field::new( - "field", data_type, true, - )))) + Ok(LogicalType::List(Box::new(data_type))) } SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") @@ -312,26 +321,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result { + fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean), - SQLDataType::TinyInt(_) => Ok(DataType::Int8), - SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(DataType::Int16), - SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(DataType::Int32), - SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(DataType::Int64), - SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8), - SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(DataType::UInt16), + SQLDataType::Boolean | SQLDataType::Bool => Ok(LogicalType::Boolean), + SQLDataType::TinyInt(_) => Ok(LogicalType::Int8), + SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(LogicalType::Int16), + SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(LogicalType::Int32), + SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(LogicalType::Int64), + SQLDataType::UnsignedTinyInt(_) => Ok(LogicalType::UInt8), + SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(LogicalType::UInt16), SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => { - Ok(DataType::UInt32) + Ok(LogicalType::UInt32) } - SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64), - SQLDataType::Float(_) => 
Ok(DataType::Float32), - SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32), - SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64), + SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(LogicalType::UInt64), + SQLDataType::Float(_) => Ok(LogicalType::Float32), + SQLDataType::Real | SQLDataType::Float4 => Ok(LogicalType::Float32), + SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(LogicalType::Float64), SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text - | SQLDataType::String(_) => Ok(DataType::Utf8), + | SQLDataType::String(_) => Ok(LogicalType::Utf8), SQLDataType::Timestamp(None, tz_info) => { let tz = if matches!(tz_info, TimezoneInfo::Tz) || matches!(tz_info, TimezoneInfo::WithTimeZone) @@ -344,14 +353,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Timestamp Without Time zone None }; - Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into))) + Ok(LogicalType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into))) } - SQLDataType::Date => Ok(DataType::Date32), + SQLDataType::Date => Ok(LogicalType::Date32), SQLDataType::Time(None, tz_info) => { if matches!(tz_info, TimezoneInfo::None) || matches!(tz_info, TimezoneInfo::WithoutTimeZone) { - Ok(DataType::Time64(TimeUnit::Nanosecond)) + Ok(LogicalType::Time64(TimeUnit::Nanosecond)) } else { // We dont support TIMETZ and TIME WITH TIME ZONE for now not_impl_err!( @@ -370,11 +379,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; make_decimal_type(precision, scale) } - SQLDataType::Bytea => Ok(DataType::Binary), - SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + SQLDataType::Bytea => Ok(LogicalType::Binary), + SQLDataType::Interval => Ok(LogicalType::Interval(IntervalUnit::MonthDayNano)), // Explicitly list all other types so that if sqlparser // adds/changes the `SQLDataType` the compiler will tell us on upgrade // and avoid bugs like https://github.com/apache/arrow-datafusion/issues/3059 + SQLDataType::Custom(name, params) => { + let name = object_name_to_string(name); + let params = params.iter().map(Into::into).collect(); + let type_signature = TypeSignature::new_with_params(name, params); + if let Some(data_type) = self.context_provider.get_data_type(&type_signature) { + return Ok(data_type); + } + + plan_err!("User-Defined SQL type {sql_type:?} not found") + } SQLDataType::Nvarchar(_) | SQLDataType::JSON | SQLDataType::Uuid @@ -383,7 +402,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Blob(_) | SQLDataType::Datetime(_) | SQLDataType::Regclass - | SQLDataType::Custom(_, _) | SQLDataType::Array(_) | SQLDataType::Enum(_) | SQLDataType::Set(_) diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index b119672eae5f..50fc7ac9afd0 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -145,17 +145,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } JoinConstraint::Natural => { - let left_cols: HashSet<&String> = left - .schema() - .fields() - .iter() - .map(|f| f.field().name()) - .collect(); + let left_cols: HashSet<&String> = + left.schema().fields().iter().map(|f| f.name()).collect(); let keys: Vec = right .schema() .fields() .iter() - .map(|f| f.field().name()) + .map(|f| f.name()) .filter(|f| left_cols.contains(f)) .map(Column::from_name) .collect(); diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index ecc77b044223..dec5837e9201 
100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -27,8 +27,8 @@ use crate::planner::{ }; use crate::utils::normalize_ident; -use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; +use datafusion_common::logical_type::{LogicalType, NamedLogicalType}; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, @@ -42,14 +42,16 @@ use datafusion_expr::logical_plan::DdlStatement; use datafusion_expr::utils::expr_to_columns; use datafusion_expr::{ cast, col, Analyze, CreateCatalog, CreateCatalogSchema, - CreateExternalTable as PlanCreateExternalTable, CreateMemoryTable, CreateView, - DescribeTable, DmlStatement, DropCatalogSchema, DropTable, DropView, EmptyRelation, - Explain, ExprSchemable, Filter, LogicalPlan, LogicalPlanBuilder, PlanType, Prepare, - SetVariable, Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, - TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, - WriteOp, + CreateExternalTable as PlanCreateExternalTable, CreateMemoryTable, CreateType, + CreateView, DescribeTable, DmlStatement, DropCatalogSchema, DropTable, DropView, + EmptyRelation, Explain, ExprSchemable, Filter, LogicalPlan, LogicalPlanBuilder, + PlanType, Prepare, SetVariable, Statement as PlanStatement, ToStringifiedPlan, + TransactionAccessMode, TransactionConclusion, TransactionEnd, + TransactionIsolationLevel, TransactionStart, WriteOp, +}; +use sqlparser::ast::{ + self, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation, }; -use sqlparser::ast; use sqlparser::ast::{ Assignment, ColumnDef, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, @@ -61,7 +63,7 @@ fn ident_to_string(ident: &Ident) -> String { normalize_ident(ident.to_owned()) } -fn object_name_to_string(object_name: &ObjectName) -> String { +pub fn object_name_to_string(object_name: &ObjectName) -> String { object_name .0 .iter() @@ -210,13 +212,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let input_schema = plan.schema(); let plan = if !columns.is_empty() { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; + let schema = self.build_schema(columns)?; if schema.fields().len() != input_schema.fields().len() { return plan_err!( - "Mismatch: {} columns specified, but result has {} columns", - schema.fields().len(), - input_schema.fields().len() - ); + "Mismatch: {} columns specified, but result has {} columns", + schema.fields().len(), + input_schema.fields().len() + ); } let input_fields = input_schema.fields(); let project_exprs = schema @@ -255,7 +257,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; + let schema = Arc::new(self.build_schema(columns)?); let plan = EmptyRelation { produce_one_row: false, schema, @@ -323,6 +325,43 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: Arc::new(DFSchema::empty()), }, ))), + Statement::CreateType { + name, + representation, + } => { + let name = object_name_to_string(&name); + + let UserDefinedTypeRepresentation::Composite { attributes } = + representation; + let fields = attributes + .into_iter() + .map( + |UserDefinedTypeCompositeAttributeDef { + name, + data_type, + collation, + }| { + if collation.is_some() { + return not_impl_err!("Collation not supported"); + } + let field_name = 
ident_to_string(&name); + let field_data_type = self.convert_data_type(&data_type)?; + + Ok(Arc::new(NamedLogicalType::new( + field_name, + field_data_type, + ))) + }, + ) + .collect::>>()?; + let data_type = LogicalType::Struct(fields.into()); + + Ok(LogicalPlan::Ddl(DdlStatement::CreateType(CreateType { + name, + data_type, + schema: Arc::new(DFSchema::empty()), + }))) + } Statement::Drop { object_type, if_exists, @@ -383,10 +422,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { statement, } => { // Convert parser data types to DataFusion data types - let data_types: Vec = data_types + let data_types = data_types .into_iter() .map(|t| self.convert_data_type(&t)) - .collect::>()?; + .collect::>>()?; // Create planner context with parameters let mut planner_context = PlannerContext::new() @@ -629,11 +668,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let schema = table_source.schema(); - let output_schema = DFSchema::try_from(LogicalPlan::describe_schema()).unwrap(); + let output_schema = LogicalPlan::describe_schema()?; Ok(LogicalPlan::DescribeTable(DescribeTable { schema, - output_schema: Arc::new(output_schema), + output_schema, })) } @@ -750,19 +789,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; } - let schema = self.build_schema(columns)?; - let df_schema = schema.to_dfschema_ref()?; + let schema = Arc::new(self.build_schema(columns)?); let ordered_exprs = - self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; + self.build_order_by(order_exprs, &schema, &mut PlannerContext::new())?; // External tables do not support schemas at the moment, so the name is just a table name let name = OwnedTableReference::bare(name); let constraints = - Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; + Constraints::new_from_table_constraints(&all_constraints, &schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { - schema: df_schema, + schema, name, location, file_type, @@ -794,27 +832,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if matches!(plan, LogicalPlan::Explain(_)) { return plan_err!("Nested EXPLAINs are not supported"); } - let plan = Arc::new(plan); - let schema = LogicalPlan::explain_schema(); - let schema = schema.to_dfschema_ref()?; - if analyze { - Ok(LogicalPlan::Analyze(Analyze { - verbose, - input: plan, - schema, - })) - } else { - let stringified_plans = - vec![plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(LogicalPlan::Explain(Explain { - verbose, - plan, - stringified_plans, - schema, - logical_optimization_succeeded: false, - })) - } + LogicalPlanBuilder::from(plan) + .explain(verbose, analyze)? + .build() } fn show_variable_to_plan(&self, variable: &[Ident]) -> Result { @@ -1138,7 +1159,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { idx + 1 ) })?; - let dt = field.field().data_type().clone(); + let dt = field.data_type().clone(); let _ = prepare_param_data_types.insert(name, dt); } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 616a2fc74932..6484d188a2e3 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -18,8 +18,9 @@ //! 
SQL Utility Functions use arrow_schema::{ - DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; +use datafusion_common::logical_type::LogicalType; use datafusion_common::tree_node::{Transformed, TreeNode}; use sqlparser::ast::Ident; @@ -212,7 +213,7 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr pub(crate) fn make_decimal_type( precision: Option, scale: Option, -) -> Result { +) -> Result { // postgres like behavior let (precision, scale) = match (precision, scale) { (Some(p), Some(s)) => (p as u8, s as i8), @@ -233,9 +234,9 @@ pub(crate) fn make_decimal_type( } else if precision > DECIMAL128_MAX_PRECISION && precision <= DECIMAL256_MAX_PRECISION { - Ok(DataType::Decimal256(precision, scale)) + Ok(LogicalType::Decimal256(precision, scale)) } else { - Ok(DataType::Decimal128(precision, scale)) + Ok(LogicalType::Decimal128(precision, scale)) } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ff6dca7eef2a..beb129bd54cf 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use std::{sync::Arc, vec}; use arrow_schema::*; +use datafusion_common::logical_type::LogicalType; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use datafusion_common::plan_err; @@ -2813,7 +2814,7 @@ impl ContextProvider for MockContextProvider { self.udafs.get(name).map(Arc::clone) } - fn get_variable_type(&self, _: &[String]) -> Option { + fn get_variable_type(&self, _: &[String]) -> Option { unimplemented!() } @@ -3685,8 +3686,8 @@ fn test_prepare_statement_should_infer_types() { let plan = logical_plan(sql).unwrap(); let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int64)), + ("$1".to_string(), Some(LogicalType::Int32)), + ("$2".to_string(), Some(LogicalType::Int64)), ]); assert_eq!(actual_types, expected_types); } @@ -3699,7 +3700,7 @@ fn test_non_prepare_statement_should_infer_types() { let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ // constant 1 is inferred to be int64 - ("$1".to_string(), Some(DataType::Int64)), + ("$1".to_string(), Some(LogicalType::Int64)), ]); assert_eq!(actual_types, expected_types); } @@ -3874,7 +3875,7 @@ Projection: person.id, orders.order_id let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); + let expected_types = HashMap::from([("$1".to_string(), Some(LogicalType::Int32))]); assert_eq!(actual_types, expected_types); // replace params with values @@ -3906,7 +3907,7 @@ Projection: person.id, person.age let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); + let expected_types = HashMap::from([("$1".to_string(), Some(LogicalType::Int32))]); assert_eq!(actual_types, expected_types); // replace params with values @@ -3938,8 +3939,8 @@ Projection: person.id, person.age let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ - ("$1".to_string(), 
Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), + ("$1".to_string(), Some(LogicalType::Int32)), + ("$2".to_string(), Some(LogicalType::Int32)), ]); assert_eq!(actual_types, expected_types); @@ -3976,7 +3977,7 @@ Projection: person.id, person.age let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); + let expected_types = HashMap::from([("$1".to_string(), Some(LogicalType::UInt32))]); assert_eq!(actual_types, expected_types); // replace params with values @@ -4014,8 +4015,8 @@ Dml: op=[Update] table=[person] let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), + ("$1".to_string(), Some(LogicalType::Int32)), + ("$2".to_string(), Some(LogicalType::UInt32)), ]); assert_eq!(actual_types, expected_types); @@ -4049,9 +4050,9 @@ Dml: op=[Insert Into] table=[person] let actual_types = plan.get_parameter_types().unwrap(); let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), + ("$1".to_string(), Some(LogicalType::UInt32)), + ("$2".to_string(), Some(LogicalType::Utf8)), + ("$3".to_string(), Some(LogicalType::Utf8)), ]); assert_eq!(actual_types, expected_types);