Skip to content

Commit

Permalink
Fix the potential bug of check_all_column_from_schema (#5287)
Browse files Browse the repository at this point in the history
* Fix the potential bug of check_all_column_from_schema

* rename contain_column to is_column_from_schema
  • Loading branch information
ygf11 authored Feb 17, 2023
1 parent fed4019 commit f154a9a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 26 deletions.
70 changes: 56 additions & 14 deletions datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl DFSchema {
&self,
qualifier: Option<&str>,
name: &str,
) -> Result<usize> {
) -> Result<Option<usize>> {
let mut matches = self
.fields
.iter()
Expand Down Expand Up @@ -221,13 +221,9 @@ impl DFSchema {
})
.map(|(idx, _)| idx);
match matches.next() {
None => Err(field_not_found(
qualifier.map(|s| s.to_string()),
name,
self,
)),
None => Ok(None),
Some(idx) => match matches.next() {
None => Ok(idx),
None => Ok(Some(idx)),
// found more than one matches
Some(_) => Err(DataFusionError::Internal(format!(
"Ambiguous reference to qualified field named '{}.{}'",
Expand All @@ -240,7 +236,17 @@ impl DFSchema {

/// Find the index of the column with the given qualifier and name
pub fn index_of_column(&self, col: &Column) -> Result<usize> {
let qualifier = col.relation.as_deref();
self.index_of_column_by_name(col.relation.as_deref(), &col.name)?
.ok_or_else(|| {
field_not_found(qualifier.map(|s| s.to_string()), &col.name, self)
})
}

/// Check if the column is in the current schema
pub fn is_column_from_schema(&self, col: &Column) -> Result<bool> {
self.index_of_column_by_name(col.relation.as_deref(), &col.name)
.map(|idx| idx.is_some())
}

/// Find the field with the given name
Expand Down Expand Up @@ -293,7 +299,10 @@ impl DFSchema {
qualifier: &str,
name: &str,
) -> Result<&DFField> {
let idx = self.index_of_column_by_name(Some(qualifier), name)?;
let idx = self
.index_of_column_by_name(Some(qualifier), name)?
.ok_or_else(|| field_not_found(Some(qualifier.to_string()), name, self))?;

Ok(self.field(idx))
}

Expand Down Expand Up @@ -663,9 +672,10 @@ mod tests {

#[test]
fn qualifier_in_name() -> Result<()> {
let col = Column::from_name("t1.c0");
let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
// lookup with unqualified name "t1.c0"
let err = schema.index_of_column_by_name(None, "t1.c0").err().unwrap();
let err = schema.index_of_column(&col).err().unwrap();
assert_eq!(
"Schema error: No field named 't1.c0'. Valid fields are 't1'.'c0', 't1'.'c1'.",
&format!("{err}")
Expand Down Expand Up @@ -829,14 +839,13 @@ mod tests {
fn select_without_valid_fields() {
let schema = DFSchema::empty();

let err = schema
.index_of_column_by_name(Some("t1"), "c0")
.err()
.unwrap();
let col = Column::from_qualified_name("t1.c0");
let err = schema.index_of_column(&col).err().unwrap();
assert_eq!("Schema error: No field named 't1'.'c0'.", &format!("{err}"));

// the same check without qualifier
let err = schema.index_of_column_by_name(None, "c0").err().unwrap();
let col = Column::from_name("c0");
let err = schema.index_of_column(&col).err().unwrap();
assert_eq!("Schema error: No field named 'c0'.", &format!("{err}"));
}

Expand Down Expand Up @@ -1123,6 +1132,39 @@ mod tests {
assert_eq!(a_df.metadata(), a_arrow.metadata())
}

#[test]
fn test_contain_column() -> Result<()> {
// qualified exists
{
let col = Column::from_qualified_name("t1.c0");
let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
assert!(schema.is_column_from_schema(&col)?);
}

// qualified not exists
{
let col = Column::from_qualified_name("t1.c2");
let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
assert!(!schema.is_column_from_schema(&col)?);
}

// unqualified exists
{
let col = Column::from_name("c0");
let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
assert!(schema.is_column_from_schema(&col)?);
}

// unqualified not exists
{
let col = Column::from_name("c2");
let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
assert!(!schema.is_column_from_schema(&col)?);
}

Ok(())
}

fn test_schema_2() -> Schema {
Schema::new(vec![
Field::new("c100", DataType::Boolean, true),
Expand Down
31 changes: 21 additions & 10 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -970,13 +970,18 @@ pub fn can_hash(data_type: &DataType) -> bool {
}

/// Check whether all columns are from the schema.
pub fn check_all_column_from_schema(
pub fn check_all_columns_from_schema(
columns: &HashSet<Column>,
schema: DFSchemaRef,
) -> bool {
columns
.iter()
.all(|column| schema.index_of_column(column).is_ok())
) -> Result<bool> {
for col in columns.iter() {
let exist = schema.is_column_from_schema(col)?;
if !exist {
return Ok(false);
}
}

Ok(true)
}

/// Give two sides of the equijoin predicate, return a valid join key pair.
Expand All @@ -1003,18 +1008,24 @@ pub fn find_valid_equijoin_key_pair(
}

let l_is_left =
check_all_column_from_schema(&left_using_columns, left_schema.clone());
check_all_columns_from_schema(&left_using_columns, left_schema.clone())?;
let r_is_right =
check_all_column_from_schema(&right_using_columns, right_schema.clone());
check_all_columns_from_schema(&right_using_columns, right_schema.clone())?;

let r_is_left_and_l_is_right = || {
check_all_column_from_schema(&right_using_columns, left_schema.clone())
&& check_all_column_from_schema(&left_using_columns, right_schema.clone())
let result =
check_all_columns_from_schema(&right_using_columns, left_schema.clone())?
&& check_all_columns_from_schema(
&left_using_columns,
right_schema.clone(),
)?;

Result::<_, DataFusionError>::Ok(result)
};

let join_key_pair = match (l_is_left, r_is_right) {
(true, true) => Some((left_key.clone(), right_key.clone())),
(_, _) if r_is_left_and_l_is_right() => {
(_, _) if r_is_left_and_l_is_right()? => {
Some((right_key.clone(), left_key.clone()))
}
_ => None,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/decorrelate_where_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{context, Column, DataFusionError, Result};
use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col};
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
use datafusion_expr::utils::check_all_column_from_schema;
use datafusion_expr::utils::check_all_columns_from_schema;
use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder};
use log::debug;
use std::collections::{BTreeSet, HashMap};
Expand Down Expand Up @@ -229,7 +229,7 @@ fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>, Logica
let mut subquery_filters: Vec<Expr> = vec![];
for expr in subquery_filter_exprs {
let cols = expr.to_columns()?;
if check_all_column_from_schema(&cols, input_schema.clone()) {
if check_all_columns_from_schema(&cols, input_schema.clone())? {
subquery_filters.push(expr.clone());
} else {
join_filters.push(expr.clone())
Expand Down

0 comments on commit f154a9a

Please sign in to comment.