From f154a9aa38e376de2cf7ac0df3618c28729f2f20 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 17 Feb 2023 21:46:53 +0800 Subject: [PATCH] Fix the potential bug of check_all_column_from_schema (#5287) * Fix the potential bug of check_all_column_from_schema * rename contain_column to is_column_from_schema --- datafusion/common/src/dfschema.rs | 70 +++++++++++++++---- datafusion/expr/src/utils.rs | 31 +++++--- .../optimizer/src/decorrelate_where_in.rs | 4 +- 3 files changed, 79 insertions(+), 26 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index e4a591ad37c1..982459ac658b 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -191,7 +191,7 @@ impl DFSchema { &self, qualifier: Option<&str>, name: &str, - ) -> Result { + ) -> Result> { let mut matches = self .fields .iter() @@ -221,13 +221,9 @@ impl DFSchema { }) .map(|(idx, _)| idx); match matches.next() { - None => Err(field_not_found( - qualifier.map(|s| s.to_string()), - name, - self, - )), + None => Ok(None), Some(idx) => match matches.next() { - None => Ok(idx), + None => Ok(Some(idx)), // found more than one matches Some(_) => Err(DataFusionError::Internal(format!( "Ambiguous reference to qualified field named '{}.{}'", @@ -240,7 +236,17 @@ impl DFSchema { /// Find the index of the column with the given qualifier and name pub fn index_of_column(&self, col: &Column) -> Result { + let qualifier = col.relation.as_deref(); + self.index_of_column_by_name(col.relation.as_deref(), &col.name)? + .ok_or_else(|| { + field_not_found(qualifier.map(|s| s.to_string()), &col.name, self) + }) + } + + /// Check if the column is in the current schema + pub fn is_column_from_schema(&self, col: &Column) -> Result { self.index_of_column_by_name(col.relation.as_deref(), &col.name) + .map(|idx| idx.is_some()) } /// Find the field with the given name @@ -293,7 +299,10 @@ impl DFSchema { qualifier: &str, name: &str, ) -> Result<&DFField> { - let idx = self.index_of_column_by_name(Some(qualifier), name)?; + let idx = self + .index_of_column_by_name(Some(qualifier), name)? + .ok_or_else(|| field_not_found(Some(qualifier.to_string()), name, self))?; + Ok(self.field(idx)) } @@ -663,9 +672,10 @@ mod tests { #[test] fn qualifier_in_name() -> Result<()> { + let col = Column::from_name("t1.c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; // lookup with unqualified name "t1.c0" - let err = schema.index_of_column_by_name(None, "t1.c0").err().unwrap(); + let err = schema.index_of_column(&col).err().unwrap(); assert_eq!( "Schema error: No field named 't1.c0'. Valid fields are 't1'.'c0', 't1'.'c1'.", &format!("{err}") @@ -829,14 +839,13 @@ mod tests { fn select_without_valid_fields() { let schema = DFSchema::empty(); - let err = schema - .index_of_column_by_name(Some("t1"), "c0") - .err() - .unwrap(); + let col = Column::from_qualified_name("t1.c0"); + let err = schema.index_of_column(&col).err().unwrap(); assert_eq!("Schema error: No field named 't1'.'c0'.", &format!("{err}")); // the same check without qualifier - let err = schema.index_of_column_by_name(None, "c0").err().unwrap(); + let col = Column::from_name("c0"); + let err = schema.index_of_column(&col).err().unwrap(); assert_eq!("Schema error: No field named 'c0'.", &format!("{err}")); } @@ -1123,6 +1132,39 @@ mod tests { assert_eq!(a_df.metadata(), a_arrow.metadata()) } + #[test] + fn test_contain_column() -> Result<()> { + // qualified exists + { + let col = Column::from_qualified_name("t1.c0"); + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + assert!(schema.is_column_from_schema(&col)?); + } + + // qualified not exists + { + let col = Column::from_qualified_name("t1.c2"); + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + assert!(!schema.is_column_from_schema(&col)?); + } + + // unqualified exists + { + let col = Column::from_name("c0"); + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + assert!(schema.is_column_from_schema(&col)?); + } + + // unqualified not exists + { + let col = Column::from_name("c2"); + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + assert!(!schema.is_column_from_schema(&col)?); + } + + Ok(()) + } + fn test_schema_2() -> Schema { Schema::new(vec![ Field::new("c100", DataType::Boolean, true), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8ce959e793f0..5706ef304d5a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -970,13 +970,18 @@ pub fn can_hash(data_type: &DataType) -> bool { } /// Check whether all columns are from the schema. -pub fn check_all_column_from_schema( +pub fn check_all_columns_from_schema( columns: &HashSet, schema: DFSchemaRef, -) -> bool { - columns - .iter() - .all(|column| schema.index_of_column(column).is_ok()) +) -> Result { + for col in columns.iter() { + let exist = schema.is_column_from_schema(col)?; + if !exist { + return Ok(false); + } + } + + Ok(true) } /// Give two sides of the equijoin predicate, return a valid join key pair. @@ -1003,18 +1008,24 @@ pub fn find_valid_equijoin_key_pair( } let l_is_left = - check_all_column_from_schema(&left_using_columns, left_schema.clone()); + check_all_columns_from_schema(&left_using_columns, left_schema.clone())?; let r_is_right = - check_all_column_from_schema(&right_using_columns, right_schema.clone()); + check_all_columns_from_schema(&right_using_columns, right_schema.clone())?; let r_is_left_and_l_is_right = || { - check_all_column_from_schema(&right_using_columns, left_schema.clone()) - && check_all_column_from_schema(&left_using_columns, right_schema.clone()) + let result = + check_all_columns_from_schema(&right_using_columns, left_schema.clone())? + && check_all_columns_from_schema( + &left_using_columns, + right_schema.clone(), + )?; + + Result::<_, DataFusionError>::Ok(result) }; let join_key_pair = match (l_is_left, r_is_right) { (true, true) => Some((left_key.clone(), right_key.clone())), - (_, _) if r_is_left_and_l_is_right() => { + (_, _) if r_is_left_and_l_is_right()? => { Some((right_key.clone(), left_key.clone())) } _ => None, diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 35164dcadcd6..7a9a75ff45bb 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -22,7 +22,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{context, Column, DataFusionError, Result}; use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; -use datafusion_expr::utils::check_all_column_from_schema; +use datafusion_expr::utils::check_all_columns_from_schema; use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::collections::{BTreeSet, HashMap}; @@ -229,7 +229,7 @@ fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec, Logica let mut subquery_filters: Vec = vec![]; for expr in subquery_filter_exprs { let cols = expr.to_columns()?; - if check_all_column_from_schema(&cols, input_schema.clone()) { + if check_all_columns_from_schema(&cols, input_schema.clone())? { subquery_filters.push(expr.clone()); } else { join_filters.push(expr.clone())