diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index ecdb03e97ee3..07185b4d6527 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -415,6 +415,18 @@ impl PartialEq for InListExpr { } } +/// Checks if two types are logically equal, dictionary types are compared by their value types. +fn is_logically_eq(lhs: &DataType, rhs: &DataType) -> bool { + match (lhs, rhs) { + (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { + v1.as_ref().eq(v2.as_ref()) + } + (DataType::Dictionary(_, l), _) => l.as_ref().eq(rhs), + (_, DataType::Dictionary(_, r)) => lhs.eq(r.as_ref()), + _ => lhs.eq(rhs), + } +} + /// Creates a unary expression InList pub fn in_list( expr: Arc, @@ -426,7 +438,7 @@ pub fn in_list( let expr_data_type = expr.data_type(schema)?; for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; - if !expr_data_type.eq(&list_expr_data_type) { + if !is_logically_eq(&expr_data_type, &list_expr_data_type) { return internal_err!( "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" ); @@ -499,7 +511,21 @@ mod tests { macro_rules! in_list { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; - let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap(); + in_list_raw!( + $BATCH, + cast_list_exprs, + $NEGATED, + $EXPECTED, + cast_expr, + $SCHEMA + ); + }}; + } + + // applies the in_list expr to an input batch and list without cast + macro_rules! in_list_raw { + ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ + let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap(); let result = expr .evaluate(&$BATCH)? .into_array($BATCH.num_rows()) @@ -540,7 +566,7 @@ mod tests { &schema ); - // expression: "a not in ("a", "b")" + // expression: "a in ("a", "b", null)" let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; in_list!( batch, @@ -551,7 +577,7 @@ mod tests { &schema ); - // expression: "a not in ("a", "b")" + // expression: "a not in ("a", "b", null)" let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; in_list!( batch, @@ -1314,4 +1340,96 @@ mod tests { Ok(()) } + + #[test] + fn in_list_utf8_with_dict_types() -> Result<()> { + fn dict_lit(key_type: DataType, value: &str) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::new_utf8(value.to_string())), + )) + } + + fn null_dict_lit(key_type: DataType) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::Utf8(None)), + )) + } + + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + )]); + let a: UInt16DictionaryArray = + vec![Some("a"), Some("d"), None].into_iter().collect(); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let lists = [ + vec![lit("a"), lit("b")], + vec![ + dict_lit(DataType::Int8, "a"), + dict_lit(DataType::UInt16, "b"), + ], + ]; + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + col_a.clone(), + &schema + ); + } + + // expression: "a not in ("a", "b")" + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &true, + vec![Some(false), Some(true), None], + col_a.clone(), + &schema + ); + } + + // expression: "a in ("a", "b", null)" + let lists = [ + vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], + vec![ + dict_lit(DataType::Int8, "a"), + dict_lit(DataType::UInt16, "b"), + null_dict_lit(DataType::UInt16), + ], + ]; + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + col_a.clone(), + &schema + ); + } + + // expression: "a not in ("a", "b", null)" + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &true, + vec![Some(false), None, None], + col_a.clone(), + &schema + ); + } + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index af7bf5cb16e8..891a09fbc177 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -87,6 +87,22 @@ f3 Utf8 YES f4 Float64 YES time Timestamp(Nanosecond, None) YES +# in list with dictionary input +query BBB +SELECT + tag_id in ('1000'), '1000' in (tag_id, null), arrow_cast('999','Dictionary(Int32, Utf8)') in (tag_id, null) +FROM m1 +---- +true true NULL +true true NULL +true true NULL +true true NULL +true true NULL +true true NULL +true true NULL +true true NULL +true true NULL +true true NULL # Table m2 with a tag columns `tag_id` and `type`, a field column `f5`, and `time` statement ok @@ -165,6 +181,29 @@ order by date_bin('30 minutes', time) DESC 3 400 600 500 2023-12-04T00:30:00 3 100 300 200 2023-12-04T00:00:00 +# query with in list +query BBBBBBBB +SELECT + type in ('active', 'passive') + , 'active' in (type) + , 'active' in (type, null) + , arrow_cast('passive','Dictionary(Int8, Utf8)') in (type, null) + , tag_id in ('1000', '2000') + , tag_id in ('999') + , '1000' in (tag_id, null) + , arrow_cast('999','Dictionary(Int16, Utf8)') in (tag_id, null) +FROM m2 +---- +true true true NULL true false true NULL +true true true NULL true false true NULL +true true true NULL true false true NULL +true true true NULL true false true NULL +true true true NULL true false true NULL +true true true NULL true false true NULL +true false NULL true true false true NULL +true false NULL true true false true NULL +true false NULL true true false true NULL +true false NULL true true false true NULL # Reproducer for https://github.com/apache/arrow-datafusion/issues/8738