diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 89e3217f730b..1d26485b4e03 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,11 +21,11 @@ use crate::expr::{ GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, }; -use crate::field_util::{get_indexed_field, GetFieldAccessCharacteristic}; +use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; use crate::{LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{ plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema, Result, }; @@ -157,26 +157,7 @@ impl ExprSchemable for Expr { Ok(DataType::Null) } Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr_dt = expr.get_type(schema)?; - let field_ch = match field { - GetFieldAccess::NamedStructField { name } => { - GetFieldAccessCharacteristic::NamedStructField { - name: name.clone(), - } - } - GetFieldAccess::ListIndex { key } => { - GetFieldAccessCharacteristic::ListIndex { - key_dt: key.get_type(schema)?, - } - } - GetFieldAccess::ListRange { start, stop } => { - GetFieldAccessCharacteristic::ListRange { - start_dt: start.get_type(schema)?, - stop_dt: stop.get_type(schema)?, - } - } - }; - get_indexed_field(&expr_dt, &field_ch).map(|x| x.data_type().clone()) + field_for_index(expr, field, schema).map(|x| x.data_type().clone()) } } } @@ -285,26 +266,7 @@ impl ExprSchemable for Expr { .to_owned(), )), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr_dt = expr.get_type(input_schema)?; - let field_ch = match field { - GetFieldAccess::NamedStructField { name } => { - GetFieldAccessCharacteristic::NamedStructField { - name: name.clone(), - } - } - GetFieldAccess::ListIndex { key } => { - GetFieldAccessCharacteristic::ListIndex { - 
key_dt: key.get_type(input_schema)?, - } - } - GetFieldAccess::ListRange { start, stop } => { - GetFieldAccessCharacteristic::ListRange { - start_dt: start.get_type(input_schema)?, - stop_dt: stop.get_type(input_schema)?, - } - } - }; - get_indexed_field(&expr_dt, &field_ch).map(|x| x.is_nullable()) + field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { // grouping sets do not really have the concept of nullable and do not appear @@ -373,6 +335,28 @@ impl ExprSchemable for Expr { } } +/// Returns the schema [`Field`] that results from applying the `field` access to `expr`, resolving types against `schema` +fn field_for_index( + expr: &Expr, + field: &GetFieldAccess, + schema: &S, +) -> Result { + let expr_dt = expr.get_type(schema)?; + match field { + GetFieldAccess::NamedStructField { name } => { + GetFieldAccessSchema::NamedStructField { name: name.clone() } + } + GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex { + key_dt: key.get_type(schema)?, + }, + GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange { + start_dt: start.get_type(schema)?, + stop_dt: stop.get_type(schema)?, + }, + } + .get_accessed_field(&expr_dt) +} + /// cast subquery in InSubquery/ScalarSubquery to a given type. pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { if subquery.subquery.schema().field(0).data_type() == cast_to_type { diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 5b08e9c7d999..d405deb93772 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -20,64 +20,65 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; -pub enum GetFieldAccessCharacteristic { - /// returns the field `struct[field]`. 
For example `struct["name"]` +/// Types of the field access expression of a nested type, such as `Struct` or `List` +pub enum GetFieldAccessSchema { + /// Named field, for example `struct["name"]` NamedStructField { name: ScalarValue }, - /// single list index - // list[i] + /// Single list index, for example `list[i]` ListIndex { key_dt: DataType }, - /// list range `list[i:j]` + /// List range, for example `list[i:j]` ListRange { start_dt: DataType, stop_dt: DataType, }, } -/// Returns the field access indexed by `key` and/or `extra_key` from a [`DataType::List`] or [`DataType::Struct`] -/// # Error -/// Errors if -/// * the `data_type` is not a Struct or a List, -/// * the `data_type` of extra key does not match with `data_type` of key -/// * there is no field key is not of the required index type -pub fn get_indexed_field( - data_type: &DataType, - field_characteristic: &GetFieldAccessCharacteristic, -) -> Result { - match field_characteristic { - GetFieldAccessCharacteristic::NamedStructField{ name } => { - match (data_type, name) { - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - plan_err!( - "Struct based indexed access requires a non empty string" - ) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone()) +impl GetFieldAccessSchema { + /// Returns the schema [`Field`] from a [`DataType::List`] or + /// [`DataType::Struct`] indexed by this structure + /// + /// # Errors + /// Errors if + /// * the `data_type` is not a Struct or a List, + /// * a list index or list range bound is not of type `Int64`, + /// * the struct field name is not a non-empty utf8 string, or names no field of the struct + pub fn get_accessed_field(&self, data_type: &DataType) -> Result { + match self { + Self::NamedStructField{ name } => { + match (data_type, name) { + (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + if s.is_empty() { + 
plan_err!( + "Struct based indexed access requires a non empty string" + ) + } else { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone()) + } } + (DataType::Struct(_), _) => plan_err!( + "Only utf8 strings are valid as an indexed field in a struct" + ), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), } - (DataType::Struct(_), _) => plan_err!( - "Only utf8 strings are valid as an indexed field in a struct" - ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), } - } - GetFieldAccessCharacteristic::ListIndex{ key_dt } => { - match (data_type, key_dt) { - (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), - (DataType::List(_), _) => plan_err!( - "Only ints are valid as an indexed field in a list" - ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + Self::ListIndex{ key_dt } => { + match (data_type, key_dt) { + (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), + (DataType::List(_), _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } } - } - GetFieldAccessCharacteristic::ListRange{ start_dt, stop_dt } => { - match (data_type, start_dt, stop_dt) { - (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), - (DataType::List(_), _, _) => plan_err!( - "Only ints are valid as an indexed field in a list" - ), - (other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + Self::ListRange{ 
start_dt, stop_dt } => { + match (data_type, start_dt, stop_dt) { + (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), + (DataType::List(_), _, _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } } } } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index e414e594f908..f9d6b1908db1 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -27,12 +27,7 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{ - field_util::{ - get_indexed_field as get_data_type_field, GetFieldAccessCharacteristic, - }, - ColumnarValue, -}; +use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; @@ -96,6 +91,23 @@ impl GetIndexedFieldExpr { pub fn arg(&self) -> &Arc { &self.arg } + + fn schema_access(&self, input_schema: &Schema) -> Result { + Ok(match &self.field { + GetFieldAccessExpr::NamedStructField { name } => { + GetFieldAccessSchema::NamedStructField { name: name.clone() } + } + GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex { + key_dt: key.data_type(input_schema)?, + }, + GetFieldAccessExpr::ListRange { start, stop } => { + GetFieldAccessSchema::ListRange { + start_dt: start.data_type(input_schema)?, + stop_dt: stop.data_type(input_schema)?, + } + } + }) + } } impl std::fmt::Display for GetIndexedFieldExpr { @@ -111,44 +123,16 @@ impl PhysicalExpr for GetIndexedFieldExpr { fn data_type(&self, input_schema: &Schema) -> Result { let arg_dt = 
self.arg.data_type(input_schema)?; - let field_ch = match &self.field { - GetFieldAccessExpr::NamedStructField { name } => { - GetFieldAccessCharacteristic::NamedStructField { name: name.clone() } - } - GetFieldAccessExpr::ListIndex { key } => { - GetFieldAccessCharacteristic::ListIndex { - key_dt: key.data_type(input_schema)?, - } - } - GetFieldAccessExpr::ListRange { start, stop } => { - GetFieldAccessCharacteristic::ListRange { - start_dt: start.data_type(input_schema)?, - stop_dt: stop.data_type(input_schema)?, - } - } - }; - get_data_type_field(&arg_dt, &field_ch).map(|f| f.data_type().clone()) + self.schema_access(input_schema)? + .get_accessed_field(&arg_dt) + .map(|f| f.data_type().clone()) } fn nullable(&self, input_schema: &Schema) -> Result { let arg_dt = self.arg.data_type(input_schema)?; - let field_ch = match &self.field { - GetFieldAccessExpr::NamedStructField { name } => { - GetFieldAccessCharacteristic::NamedStructField { name: name.clone() } - } - GetFieldAccessExpr::ListIndex { key } => { - GetFieldAccessCharacteristic::ListIndex { - key_dt: key.data_type(input_schema)?, - } - } - GetFieldAccessExpr::ListRange { start, stop } => { - GetFieldAccessCharacteristic::ListRange { - start_dt: start.data_type(input_schema)?, - stop_dt: stop.data_type(input_schema)?, - } - } - }; - get_data_type_field(&arg_dt, &field_ch).map(|f| f.is_nullable()) + self.schema_access(input_schema)? + .get_accessed_field(&arg_dt) + .map(|f| f.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result {