Skip to content

Commit

Permalink
Minor: rename GetFieldAccessCharacteristic and add docs (apache#7220)
Browse files Browse the repository at this point in the history
* Minor: rename `GetFieldAccessCharacteristic` and add docs

* Update datafusion/expr/src/field_util.rs
  • Loading branch information
alamb authored Aug 8, 2023
1 parent 3d917a0 commit 627abd7
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 128 deletions.
68 changes: 26 additions & 42 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::expr::{
GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort,
TryCast, WindowFunction,
};
use crate::field_util::{get_indexed_field, GetFieldAccessCharacteristic};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
use crate::{LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema, Result,
};
Expand Down Expand Up @@ -157,26 +157,7 @@ impl ExprSchemable for Expr {
Ok(DataType::Null)
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
let expr_dt = expr.get_type(schema)?;
let field_ch = match field {
GetFieldAccess::NamedStructField { name } => {
GetFieldAccessCharacteristic::NamedStructField {
name: name.clone(),
}
}
GetFieldAccess::ListIndex { key } => {
GetFieldAccessCharacteristic::ListIndex {
key_dt: key.get_type(schema)?,
}
}
GetFieldAccess::ListRange { start, stop } => {
GetFieldAccessCharacteristic::ListRange {
start_dt: start.get_type(schema)?,
stop_dt: stop.get_type(schema)?,
}
}
};
get_indexed_field(&expr_dt, &field_ch).map(|x| x.data_type().clone())
field_for_index(expr, field, schema).map(|x| x.data_type().clone())
}
}
}
Expand Down Expand Up @@ -285,26 +266,7 @@ impl ExprSchemable for Expr {
.to_owned(),
)),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
let expr_dt = expr.get_type(input_schema)?;
let field_ch = match field {
GetFieldAccess::NamedStructField { name } => {
GetFieldAccessCharacteristic::NamedStructField {
name: name.clone(),
}
}
GetFieldAccess::ListIndex { key } => {
GetFieldAccessCharacteristic::ListIndex {
key_dt: key.get_type(input_schema)?,
}
}
GetFieldAccess::ListRange { start, stop } => {
GetFieldAccessCharacteristic::ListRange {
start_dt: start.get_type(input_schema)?,
stop_dt: stop.get_type(input_schema)?,
}
}
};
get_indexed_field(&expr_dt, &field_ch).map(|x| x.is_nullable())
field_for_index(expr, field, input_schema).map(|x| x.is_nullable())
}
Expr::GroupingSet(_) => {
// grouping sets do not really have the concept of nullable and do not appear
Expand Down Expand Up @@ -373,6 +335,28 @@ impl ExprSchemable for Expr {
}
}

/// return the schema [`Field`] for the type referenced by `get_indexed_field`
fn field_for_index<S: ExprSchema>(
expr: &Expr,
field: &GetFieldAccess,
schema: &S,
) -> Result<Field> {
let expr_dt = expr.get_type(schema)?;
match field {
GetFieldAccess::NamedStructField { name } => {
GetFieldAccessSchema::NamedStructField { name: name.clone() }
}
GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
key_dt: key.get_type(schema)?,
},
GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange {
start_dt: start.get_type(schema)?,
stop_dt: stop.get_type(schema)?,
},
}
.get_accessed_field(&expr_dt)
}

/// cast subquery in InSubquery/ScalarSubquery to a given type.
pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
if subquery.subquery.schema().field(0).data_type() == cast_to_type {
Expand Down
92 changes: 46 additions & 46 deletions datafusion/expr/src/field_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,64 +20,64 @@
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue};

pub enum GetFieldAccessCharacteristic {
/// returns the field `struct[field]`. For example `struct["name"]`
/// Types of the field access expression of a nested type, such as `Field` or `List`
pub enum GetFieldAccessSchema {
/// Named field, For example `struct["name"]`
NamedStructField { name: ScalarValue },
/// single list index
// list[i]
/// Single list index, for example: `list[i]`
ListIndex { key_dt: DataType },
/// list range `list[i:j]`
/// List range, for example `list[i:j]`
ListRange {
start_dt: DataType,
stop_dt: DataType,
},
}

/// Returns the field access indexed by `key` and/or `extra_key` from a [`DataType::List`] or [`DataType::Struct`]
/// # Error
/// Errors if
/// * the `data_type` is not a Struct or a List,
/// * the `data_type` of extra key does not match with `data_type` of key
/// * there is no field key is not of the required index type
pub fn get_indexed_field(
data_type: &DataType,
field_characteristic: &GetFieldAccessCharacteristic,
) -> Result<Field> {
match field_characteristic {
GetFieldAccessCharacteristic::NamedStructField{ name } => {
match (data_type, name) {
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
if s.is_empty() {
plan_err!(
"Struct based indexed access requires a non empty string"
)
} else {
let field = fields.iter().find(|f| f.name() == s);
field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone())
impl GetFieldAccessSchema {
/// Returns the schema [`Field`] from a [`DataType::List`] or
/// [`DataType::Struct`] indexed by this structure
///
/// # Error
/// Errors if
/// * the `data_type` is not a Struct or a List,
/// * the `data_type` of the name/index/start-stop do not match a supported index type
pub fn get_accessed_field(&self, data_type: &DataType) -> Result<Field> {
match self {
Self::NamedStructField{ name } => {
match (data_type, name) {
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
if s.is_empty() {
plan_err!(
"Struct based indexed access requires a non empty string"
)
} else {
let field = fields.iter().find(|f| f.name() == s);
field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone())
}
}
(DataType::Struct(_), _) => plan_err!(
"Only utf8 strings are valid as an indexed field in a struct"
),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
}
(DataType::Struct(_), _) => plan_err!(
"Only utf8 strings are valid as an indexed field in a struct"
),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
}
}
GetFieldAccessCharacteristic::ListIndex{ key_dt } => {
match (data_type, key_dt) {
(DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)),
(DataType::List(_), _) => plan_err!(
"Only ints are valid as an indexed field in a list"
),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
Self::ListIndex{ key_dt } => {
match (data_type, key_dt) {
(DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)),
(DataType::List(_), _) => plan_err!(
"Only ints are valid as an indexed field in a list"
),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
}
}
}
GetFieldAccessCharacteristic::ListRange{ start_dt, stop_dt } => {
match (data_type, start_dt, stop_dt) {
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
(DataType::List(_), _, _) => plan_err!(
"Only ints are valid as an indexed field in a list"
),
(other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
Self::ListRange{ start_dt, stop_dt } => {
match (data_type, start_dt, stop_dt) {
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
(DataType::List(_), _, _) => plan_err!(
"Only ints are valid as an indexed field in a list"
),
(other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
}
}
}
}
Expand Down
64 changes: 24 additions & 40 deletions datafusion/physical-expr/src/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ use arrow::{
record_batch::RecordBatch,
};
use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue};
use datafusion_expr::{
field_util::{
get_indexed_field as get_data_type_field, GetFieldAccessCharacteristic,
},
ColumnarValue,
};
use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -120,6 +115,23 @@ impl GetIndexedFieldExpr {
pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
&self.arg
}

fn schema_access(&self, input_schema: &Schema) -> Result<GetFieldAccessSchema> {
Ok(match &self.field {
GetFieldAccessExpr::NamedStructField { name } => {
GetFieldAccessSchema::NamedStructField { name: name.clone() }
}
GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex {
key_dt: key.data_type(input_schema)?,
},
GetFieldAccessExpr::ListRange { start, stop } => {
GetFieldAccessSchema::ListRange {
start_dt: start.data_type(input_schema)?,
stop_dt: stop.data_type(input_schema)?,
}
}
})
}
}

impl std::fmt::Display for GetIndexedFieldExpr {
Expand All @@ -135,44 +147,16 @@ impl PhysicalExpr for GetIndexedFieldExpr {

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
let arg_dt = self.arg.data_type(input_schema)?;
let field_ch = match &self.field {
GetFieldAccessExpr::NamedStructField { name } => {
GetFieldAccessCharacteristic::NamedStructField { name: name.clone() }
}
GetFieldAccessExpr::ListIndex { key } => {
GetFieldAccessCharacteristic::ListIndex {
key_dt: key.data_type(input_schema)?,
}
}
GetFieldAccessExpr::ListRange { start, stop } => {
GetFieldAccessCharacteristic::ListRange {
start_dt: start.data_type(input_schema)?,
stop_dt: stop.data_type(input_schema)?,
}
}
};
get_data_type_field(&arg_dt, &field_ch).map(|f| f.data_type().clone())
self.schema_access(input_schema)?
.get_accessed_field(&arg_dt)
.map(|f| f.data_type().clone())
}

fn nullable(&self, input_schema: &Schema) -> Result<bool> {
let arg_dt = self.arg.data_type(input_schema)?;
let field_ch = match &self.field {
GetFieldAccessExpr::NamedStructField { name } => {
GetFieldAccessCharacteristic::NamedStructField { name: name.clone() }
}
GetFieldAccessExpr::ListIndex { key } => {
GetFieldAccessCharacteristic::ListIndex {
key_dt: key.data_type(input_schema)?,
}
}
GetFieldAccessExpr::ListRange { start, stop } => {
GetFieldAccessCharacteristic::ListRange {
start_dt: start.data_type(input_schema)?,
stop_dt: stop.data_type(input_schema)?,
}
}
};
get_data_type_field(&arg_dt, &field_ch).map(|f| f.is_nullable())
self.schema_access(input_schema)?
.get_accessed_field(&arg_dt)
.map(|f| f.is_nullable())
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
Expand Down

0 comments on commit 627abd7

Please sign in to comment.