From 64727def442d888aaed297cc7d1f0ac9311b9c85 Mon Sep 17 00:00:00 2001 From: tanruixiang Date: Wed, 28 Jun 2023 11:21:29 +0800 Subject: [PATCH] fix: missing dict attribute --- datafusion/common/src/dfschema.rs | 66 ++++++- datafusion/expr/src/expr_schema.rs | 308 +++++++++++++++++++++++++++-- 2 files changed, 360 insertions(+), 14 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 8105ada59f5c..754694838ab2 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -586,6 +586,12 @@ pub trait ExprSchema: std::fmt::Debug { /// What is the datatype of this column? fn data_type(&self, col: &Column) -> Result<&DataType>; + + /// Is this column reference dict_is_ordered? + fn dict_is_ordered(&self, col: &Column) -> Result; + + /// What is the dict_id of this column? + fn dict_id(&self, col: &Column) -> Result; } // Implement `ExprSchema` for `Arc` @@ -593,20 +599,40 @@ impl + std::fmt::Debug> ExprSchema for P { fn nullable(&self, col: &Column) -> Result { self.as_ref().nullable(col) } - fn data_type(&self, col: &Column) -> Result<&DataType> { self.as_ref().data_type(col) } + + fn dict_is_ordered(&self, col: &Column) -> Result { + self.as_ref().dict_is_ordered(col) + } + + fn dict_id(&self, col: &Column) -> Result { + self.as_ref().dict_id(col) + } } impl ExprSchema for DFSchema { fn nullable(&self, col: &Column) -> Result { Ok(self.field_from_column(col)?.is_nullable()) } - fn data_type(&self, col: &Column) -> Result<&DataType> { Ok(self.field_from_column(col)?.data_type()) } + + fn dict_is_ordered(&self, col: &Column) -> Result { + match self.field_from_column(col)?.field().dict_is_ordered() { + Some(dict_id_ordered) => Ok(dict_id_ordered), + _ => Ok(false), + } + } + + fn dict_id(&self, col: &Column) -> Result { + match self.field_from_column(col)?.field().dict_id() { + Some(dict_id_ordered) => Ok(dict_id_ordered), + _ => Ok(0), + } + } } /// DFField wraps an Arrow field and adds an optional qualifier @@ -635,6 +661,42 @@ impl DFField { /// Convenience method for creating new `DFField` without a qualifier pub fn new_unqualified(name: &str, data_type: DataType, nullable: bool) -> Self { DFField { + None, + field: Arc::new(Field::new(name, data_type, nullable)), + } + } + /// Creates a new `DFField` with dict + pub fn new_dict( + qualifier: Option<&str>, + name: &str, + data_type: DataType, + nullable: bool, + dict_id: i64, + dict_is_ordered: bool, + ) -> Self { + DFField { + qualifier: qualifier.map(|s| s.into()), + field: Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered), + } + } + + /// Convenience method for creating new `DFField` without a qualifier + pub fn new_unqualified_dict( + name: &str, + data_type: DataType, + nullable: bool, + dict_id: i64, + dict_is_ordered: bool, + ) -> Self { + DFField { + qualifier: None, + field: Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered), + } + } + + /// Create an unqualified field from an existing Arrow field + pub fn from(field: Field) -> Self { + Self { qualifier: None, field: Arc::new(Field::new(name, data_type, nullable)), } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ba37cf6d45b8..f0cd07c31609 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -43,6 +43,12 @@ pub trait ExprSchemable { /// cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; + + /// given a schema, return the dict id of the expr + fn get_dict_id(&self, schema: &S) -> Result; + + /// given a schema, return the dict_is_ordered of the expr + fn dict_is_ordered(&self, input_schema: &S) -> Result; } impl ExprSchemable for Expr { @@ -262,23 +268,246 @@ impl ExprSchemable for Expr { } } + /// Returns the nullability of the expression based on [ExprSchema]. + /// + /// Note: [DFSchema] implements [ExprSchema]. + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// nullability. This happens when the expression refers to a + /// column that does not exist in the schema. + fn dict_is_ordered(&self, input_schema: &S) -> Result { + match self { + Expr::Column(c) => input_schema.dict_is_ordered(c), + _ => Ok(false), + } + // match self { + // Expr::Alias(expr, _) + // | Expr::Not(expr) + // | Expr::Negative(expr) + // | Expr::Sort(Sort { expr, .. }) + // | Expr::InList(InList { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::Between(Between { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::Column(c) => input_schema.dict_is_ordered(c), + // Expr::OuterReferenceColumn(_, _) => Ok(false), + // Expr::Literal(value) => Ok(false), + // Expr::Case(case) => { + // // this expression is nullable if any of the input expressions are nullable + // let then_nullable = case + // .when_then_expr + // .iter() + // .map(|(_, t)| t.dict_is_ordered(input_schema)) + // .collect::>>()?; + // if then_nullable.contains(&true) { + // Ok(true) + // } else if let Some(e) = &case.else_expr { + // e.dict_is_ordered(input_schema) + // } else { + // // CASE produces NULL if there is no `else` expr + // // (aka when none of the `when_then_exprs` match) + // Ok(true) + // } + // } + // Expr::Cast(Cast { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::ScalarVariable(_, _) + // | Expr::TryCast { .. } + // | Expr::ScalarFunction(..) + // | Expr::ScalarUDF(..) + // | Expr::WindowFunction { .. } + // | Expr::AggregateFunction { .. } + // | Expr::AggregateUDF { .. } + // | Expr::Placeholder(_) => Ok(true), + // Expr::IsNull(_) + // | Expr::IsNotNull(_) + // | Expr::IsTrue(_) + // | Expr::IsFalse(_) + // | Expr::IsUnknown(_) + // | Expr::IsNotTrue(_) + // | Expr::IsNotFalse(_) + // | Expr::IsNotUnknown(_) + // | Expr::Exists { .. } => Ok(false), + // Expr::InSubquery(InSubquery { expr, .. }) => { + // expr.dict_is_ordered(input_schema) + // } + // Expr::ScalarSubquery(subquery) => Ok(subquery + // .subquery + // .schema() + // .field(0) + // .field() + // .dict_is_ordered() + // .unwrap_or(false)), + // Expr::BinaryExpr(BinaryExpr { + // ref left, + // ref right, + // .. + // }) => Ok(left.dict_is_ordered(input_schema)? + // || right.dict_is_ordered(input_schema)?), + // Expr::Like(Like { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::ILike(Like { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::SimilarTo(Like { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::Wildcard => Err(DataFusionError::Internal( + // "Wildcard expressions are not valid in a logical query plan".to_owned(), + // )), + // Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( + // "QualifiedWildcard expressions are not valid in a logical query plan" + // .to_owned(), + // )), + // Expr::GetIndexedField(GetIndexedField { key, expr }) => { + // let data_type = expr.get_type(input_schema)?; + // get_indexed_field(&data_type, key) + // .map(|x| x.dict_is_ordered().unwrap_or(false)) + // } + // Expr::GroupingSet(_) => { + // // grouping sets do not really have the concept of nullable and do not appear + // // in projections + // Ok(true) + // } + // } + } + + fn get_dict_id(&self, schema: &S) -> Result { + match self { + Expr::Column(c) => schema.dict_id(c), + _ => Ok(0), + } + // match self { + // Expr::Alias(expr, _) + // | Expr::Not(expr) + // | Expr::Negative(expr) + // | Expr::Sort(Sort { expr, .. }) + // | Expr::InList(InList { expr, .. }) => expr.get_dict_id(input_schema), + // Expr::Between(Between { expr, .. }) => expr.get_dict_id(input_schema), + // Expr::Column(c) => input_schema.get_dict_id(c), + // Expr::OuterReferenceColumn(_, _) => Ok(false), + // Expr::Literal(value) => Ok(false), + // Expr::Case(case) => { + // // this expression is nullable if any of the input expressions are nullable + // let then_nullable = case + // .when_then_expr + // .iter() + // .map(|(_, t)| t.get_dict_id(input_schema)) + // .collect::>>()?; + // if then_nullable.contains(&true) { + // Ok(true) + // } else if let Some(e) = &case.else_expr { + // e.get_dict_id(input_schema) + // } else { + // // CASE produces NULL if there is no `else` expr + // // (aka when none of the `when_then_exprs` match) + // Ok(true) + // } + // } + // Expr::Cast(Cast { expr, .. }) => expr.get_dict_id(input_schema), + // Expr::ScalarVariable(_, _) + // | Expr::TryCast { .. } + // | Expr::ScalarFunction(..) + // | Expr::ScalarUDF(..) + // | Expr::WindowFunction { .. } + // | Expr::AggregateFunction { .. } + // | Expr::AggregateUDF { .. } + // | Expr::Placeholder(_) + // | Expr::IsNull(_) + // | Expr::IsNotNull(_) + // | Expr::IsTrue(_) + // | Expr::IsFalse(_) + // | Expr::IsUnknown(_) + // | Expr::IsNotTrue(_) + // | Expr::IsNotFalse(_) + // | Expr::IsNotUnknown(_) + // | Expr::Exists { .. } => Ok(0), + // Expr::InSubquery(InSubquery { expr, .. }) => expr.get_dict_id(input_schema), + // Expr::ScalarSubquery(subquery) => Ok(subquery + // .subquery + // .schema() + // .field(0) + // .field() + // .get_dict_id() + // .unwrap_or(0)), + // Expr::BinaryExpr(BinaryExpr { + // ref left, + // ref right, + // .. + // }) => Ok(left.dict_is_ordered(input_schema)? + // || right.dict_is_ordered(input_schema)?), + // Expr::Like(Like { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::ILike(Like { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::SimilarTo(Like { expr, .. }) => expr.dict_is_ordered(input_schema), + // Expr::Wildcard => Err(DataFusionError::Internal( + // "Wildcard expressions are not valid in a logical query plan".to_owned(), + // )), + // Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( + // "QualifiedWildcard expressions are not valid in a logical query plan" + // .to_owned(), + // )), + // Expr::GetIndexedField(GetIndexedField { key, expr }) => { + // let data_type = expr.get_type(input_schema)?; + // get_indexed_field(&data_type, key) + // .map(|x| x.dict_is_ordered().unwrap_or(false)) + // } + // Expr::GroupingSet(_) => { + // // grouping sets do not really have the concept of nullable and do not appear + // // in projections + // Ok(true) + // } + // } + } + /// Returns a [arrow::datatypes::Field] compatible with this expression. /// /// So for example, a projected expression `col(c1) + col(c2)` is /// placed in an output field **named** col("c1 + c2") + // fn to_field(&self, input_schema: &DFSchema) -> Result { + // match self { + // Expr::Column(c) => Ok(DFField::new( + // c.relation.as_deref(), + // &c.name, + // self.get_type(input_schema)?, + // self.nullable(input_schema)?, + // )), + // _ => Ok(DFField::new( + // None, + // &self.display_name()?, + // self.get_type(input_schema)?, + // self.nullable(input_schema)?, + // )), + // } + // } + fn to_field(&self, input_schema: &DFSchema) -> Result { match self { - Expr::Column(c) => Ok(DFField::new( - c.relation.clone(), - &c.name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - _ => Ok(DFField::new_unqualified( - &self.display_name()?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), + Expr::Column(c) => Ok(match self.get_type(input_schema)? { + DataType::Dictionary(_, _) => DFField::new_dict( + c.relation.as_deref(), + &c.name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + self.get_dict_id(input_schema)?, + self.dict_is_ordered(input_schema)?, + ), + _ => DFField::new( + c.relation.as_deref(), + &c.name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ), + }), + _ => Ok(match self.get_type(input_schema)? { + DataType::Dictionary(_, _) => DFField::new_dict( + None, + &self.display_name()?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + self.get_dict_id(input_schema)?, + self.dict_is_ordered(input_schema)?, + ), + _ => DFField::new( + None, + &self.display_name()?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ), + }), } } @@ -347,11 +576,58 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result::new(), + ) + .unwrap(); + let expr = vec![col("dictionary_column1"), col("dictionary_column2")]; + for i in 0..dffield.len() { + assert_eq!(expr[i].to_field(&dfschema).unwrap(), dffield[i]); + assert_eq!( + expr[i].to_field(&dfschema).unwrap().field().dict_id(), + dffield[i].field().dict_id() + ); + assert_eq!( + expr[i] + .to_field(&dfschema) + .unwrap() + .field() + .dict_is_ordered(), + dffield[i].field().dict_is_ordered() + ); + } + } + #[test] fn expr_schema_nullability() { let expr = col("foo").eq(lit(1)); @@ -404,5 +680,13 @@ mod tests { fn data_type(&self, _col: &Column) -> Result<&DataType> { Ok(&self.data_type) } + + fn dict_id(&self, _col: &Column) -> Result { + Ok(0) + } + + fn dict_is_ordered(&self, _col: &Column) -> Result { + Ok(false) + } } }