diff --git a/src/common.rs b/src/common.rs index 7144801..e4b1cd9 100644 --- a/src/common.rs +++ b/src/common.rs @@ -3,7 +3,9 @@ use std::str::Utf8Error; use datafusion::arrow::array::{ Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array, }; +use datafusion::arrow::compute::cast; use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::error::ArrowError; use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use jiter::{Jiter, JiterError, Peek}; @@ -20,12 +22,12 @@ pub fn return_type_check(args: &[DataType], fn_name: &str) -> DataFusionResult<( let Some(first) = args.first() else { return plan_err!("The '{fn_name}' function requires one or more arguments."); }; - if !(is_str(undict(first)) || is_json_union(first)) { + if !(is_str(unpack_dict_type(first)) || is_json_union(first)) { // if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) { return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}."); } args.iter().skip(1).enumerate().try_for_each(|(index, arg)| { - let t = undict(arg); + let t = unpack_dict_type(arg); if is_str(t) || is_int(t) { Ok(()) } else { @@ -46,7 +48,16 @@ fn is_int(d: &DataType) -> bool { matches!(d, DataType::UInt64 | DataType::Int64) } -fn undict(d: &DataType) -> &DataType { +/// Convert a dict array to a non-dict array. +fn unpack_dict_array(array: ArrayRef) -> Result { + match array.data_type() { + DataType::Dictionary(_, value_type) => cast(array.as_ref(), value_type), + _ => Ok(array), + } +} + +// if the type is a dict, return the value type, otherwise return the type +fn unpack_dict_type(d: &DataType) -> &DataType { if let DataType::Dictionary(_, value) = d { value.as_ref() } else { @@ -129,7 +140,8 @@ fn invoke_array> + 'static, I>( jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, ) -> DataFusionResult { if let Some(d) = needle_array.as_any_dictionary_opt() { - invoke_array(json_array, d.values(), to_array, jiter_find) + let values = invoke_array(json_array, d.values(), to_array, jiter_find)?; + unpack_dict_array(d.with_values(values)).map_err(Into::into) } else if let Some(str_path_array) = needle_array.as_any().downcast_ref::() { let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); zip_apply(json_array, paths, to_array, jiter_find, true) @@ -160,7 +172,8 @@ fn zip_apply<'a, P: Iterator>>, C: FromIterator() { @@ -226,7 +239,8 @@ fn scalar_apply>, I>( ) -> DataFusionResult { if let Some(d) = json_array.as_any_dictionary_opt() { // as above, don't return a dict - return scalar_apply(d.values(), path, to_array, jiter_find); + let values = scalar_apply(d.values(), path, to_array, jiter_find)?; + return unpack_dict_array(d.with_values(values)).map_err(Into::into); } let c = if let Some(string_array) = json_array.as_any().downcast_ref::() { diff --git a/tests/main.rs b/tests/main.rs index 9dd49b7..14849b6 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1,10 +1,14 @@ -use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, RecordBatch}; +use datafusion::arrow::datatypes::{Field, Int8Type, Schema}; +use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType}; use datafusion::assert_batches_eq; use datafusion::common::ScalarValue; use datafusion::logical_expr::ColumnarValue; - +use datafusion::prelude::SessionContext; use datafusion_functions_json::udfs::json_get_str_udf; -use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params}; +use utils::{create_context, display_val, logical_plan, run_query, run_query_large, run_query_params}; mod utils; @@ -197,11 +201,11 @@ async fn test_json_get_no_path() { let batches = run_query(r#"select json_get('"foo"')::string"#).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Utf8, "foo".to_string())); - let batches = run_query(r#"select json_get('123')::int"#).await.unwrap(); + let batches = run_query(r"select json_get('123')::int").await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Int64, "123".to_string())); - let batches = run_query(r#"select json_get('true')::int"#).await.unwrap(); - assert_eq!(display_val(batches).await, (DataType::Int64, "".to_string())); + let batches = run_query(r"select json_get('true')::int").await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Int64, String::new())); } #[tokio::test] @@ -350,7 +354,7 @@ async fn test_json_length_object() { let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::UInt64, "3".to_string())); - let sql = r#"select json_length('{}')"#; + let sql = r"select json_length('{}')"; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::UInt64, "0".to_string())); } @@ -359,7 +363,7 @@ async fn test_json_length_object() { async fn test_json_length_string() { let sql = r#"select json_length('"foobar"')"#; let batches = run_query(sql).await.unwrap(); - assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string())); + assert_eq!(display_val(batches).await, (DataType::UInt64, String::new())); } #[tokio::test] @@ -370,7 +374,7 @@ async fn test_json_length_object_nested() { let sql = r#"select json_length('{"a": 1, "b": 2, "c": []}', 'b')"#; let batches = run_query(sql).await.unwrap(); - assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string())); + assert_eq!(display_val(batches).await, (DataType::UInt64, String::new())); } #[tokio::test] @@ -455,7 +459,7 @@ async fn test_json_contains_large_both_params() { #[tokio::test] async fn test_json_length_vec() { - let sql = r#"select name, json_len(json_data) as len from test"#; + let sql = r"select name, json_len(json_data) as len from test"; let batches = run_query(sql).await.unwrap(); let expected = [ @@ -479,7 +483,7 @@ async fn test_json_length_vec() { #[tokio::test] async fn test_no_args() { - let err = run_query(r#"select json_len()"#).await.unwrap_err(); + let err = run_query(r"select json_len()").await.unwrap_err(); assert!(err .to_string() .contains("No function matches the given name and argument types 'json_length()'.")); @@ -562,10 +566,10 @@ async fn test_json_get_nested_collapsed() { #[tokio::test] async fn test_json_get_cte() { // avoid auto-un-nesting with a CTE - let sql = r#" + let sql = r" with t as (select name, json_get(json_data, 'foo') j from test) select name, json_get(j, 0) v from t - "#; + "; let expected = [ "+------------------+---------+", "| name | v |", @@ -587,11 +591,11 @@ async fn test_json_get_cte() { #[tokio::test] async fn test_plan_json_get_cte() { // avoid auto-unnesting with a CTE - let sql = r#" + let sql = r" explain with t as (select name, json_get(json_data, 'foo') j from test) select name, json_get(j, 0) v from t - "#; + "; let expected = [ "Projection: t.name, json_get(t.j, Int64(0)) AS v", " SubqueryAlias: t", @@ -751,7 +755,7 @@ async fn test_arrow() { #[tokio::test] async fn test_plan_arrow() { - let lines = logical_plan(r#"explain select json_data->'foo' from test"#).await; + let lines = logical_plan(r"explain select json_data->'foo' from test").await; let expected = [ "Projection: json_get(test.json_data, Utf8(\"foo\")) AS test.json_data -> Utf8(\"foo\")", @@ -783,7 +787,7 @@ async fn test_long_arrow() { #[tokio::test] async fn test_plan_long_arrow() { - let lines = logical_plan(r#"explain select json_data->>'foo' from test"#).await; + let lines = logical_plan(r"explain select json_data->>'foo' from test").await; let expected = [ "Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS test.json_data ->> Utf8(\"foo\")", @@ -834,7 +838,7 @@ async fn test_arrow_cast_int() { #[tokio::test] async fn test_plan_arrow_cast_int() { - let lines = logical_plan(r#"explain select (json_data->'foo')::int from test"#).await; + let lines = logical_plan(r"explain select (json_data->'foo')::int from test").await; let expected = [ "Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS test.json_data -> Utf8(\"foo\")", @@ -866,7 +870,7 @@ async fn test_arrow_double_nested() { #[tokio::test] async fn test_plan_arrow_double_nested() { - let lines = logical_plan(r#"explain select json_data->'foo'->0 from test"#).await; + let lines = logical_plan(r"explain select json_data->'foo'->0 from test").await; let expected = [ "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> Utf8(\"foo\") -> Int64(0)", @@ -900,7 +904,7 @@ async fn test_arrow_double_nested_cast() { #[tokio::test] async fn test_plan_arrow_double_nested_cast() { - let lines = logical_plan(r#"explain select (json_data->'foo'->0)::int from test"#).await; + let lines = logical_plan(r"explain select (json_data->'foo'->0)::int from test").await; let expected = [ "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> Utf8(\"foo\") -> Int64(0)", @@ -948,7 +952,7 @@ async fn test_arrow_nested_double_columns() { async fn test_lexical_precedence_wrong() { let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#; let err = run_query(sql).await.unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean.") + assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean."); } #[tokio::test] @@ -1261,3 +1265,81 @@ async fn test_dict_get_int() { let batches = run_query(sql).await.unwrap(); assert_batches_eq!(expected, &batches); } + +async fn build_dict_schema() -> SessionContext { + let mut builder = StringDictionaryBuilder::::new(); + builder.append(r#"{"foo": "bar"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append("nah").unwrap(); + builder.append(r#"{"baz": "abcd"}"#).unwrap(); + builder.append_null(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append_null(); + + let dict = builder.finish(); + let array = Arc::new(dict) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + )])); + + let data = RecordBatch::try_new(schema.clone(), vec![array]).unwrap(); + + let ctx = create_context().await.unwrap(); + ctx.register_batch("data", data).unwrap(); + ctx +} + +#[tokio::test] +async fn test_dict_filter() { + let ctx = build_dict_schema().await; + + let sql = "select json_get(x, 'baz') v from data"; + let expected = [ + "+------------+", + "| v |", + "+------------+", + "| {null=} |", + "| {str=fizz} |", + "| {null=} |", + "| {str=abcd} |", + "| {null=} |", + "| {str=fizz} |", + "| {str=fizz} |", + "| {str=fizz} |", + "| {str=fizz} |", + "| {null=} |", + "+------------+", + ]; + + let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_filter_is_not_null() { + let ctx = build_dict_schema().await; + let sql = "select x from data where json_get(x, 'baz') is not null"; + let expected = [ + "+-----------------+", + "| x |", + "+-----------------+", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"abcd\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "+-----------------+", + ]; + + let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + + assert_batches_eq!(expected, &batches); +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index cf3ff2e..46a871b 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -13,10 +13,15 @@ use datafusion::execution::context::SessionContext; use datafusion::prelude::SessionConfig; use datafusion_functions_json::register_all; -async fn create_test_table(large_utf8: bool) -> Result { +pub async fn create_context() -> Result { let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres"); let mut ctx = SessionContext::new_with_config(config); register_all(&mut ctx)?; + Ok(ctx) +} + +async fn create_test_table(large_utf8: bool) -> Result { + let ctx = create_context().await?; let test_data = [ ("object_foo", r#" {"foo": "abc"} "#), @@ -214,5 +219,5 @@ pub async fn logical_plan(sql: &str) -> Vec { let batches = run_query(sql).await.unwrap(); let plan_col = batches[0].column(1).as_any().downcast_ref::().unwrap(); let logical_plan = plan_col.value(0); - logical_plan.split('\n').map(|s| s.to_string()).collect() + logical_plan.split('\n').map(ToString::to_string).collect() }