diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ec099d4..08b141b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,8 +12,8 @@ repos: - repo: local hooks: - - id: format-check - name: Format Check + - id: format + name: Format entry: cargo fmt types: [rust] language: system diff --git a/README.md b/README.md index 19fcb8c..569d609 100644 --- a/README.md +++ b/README.md @@ -17,14 +17,15 @@ To register the below JSON functions in your `SessionContext`. ## Done -* [x] `json_contains(json: str, *keys: str | int) -> bool` - true if a JSON object has a specific key -* [x] `json_get(json: str, *keys: str | int) -> JsonUnion` - Get a value from a JSON object by its "path" -* [x] `json_get_str(json: str, *keys: str | int) -> str` - Get a string value from a JSON object by its "path" -* [x] `json_get_int(json: str, *keys: str | int) -> int` - Get an integer value from a JSON object by its "path" -* [x] `json_get_float(json: str, *keys: str | int) -> float` - Get a float value from a JSON object by its "path" -* [x] `json_get_bool(json: str, *keys: str | int) -> bool` - Get a boolean value from a JSON object by its "path" -* [x] `json_get_json(json: str, *keys: str | int) -> str` - Get any value from a JSON object by its "path", represented as a string -* [x] `json_length(json: str, *keys: str | int) -> int` - get the length of a JSON object or array +* [x] `json_contains(json: str, *keys: str | int) -> bool` - true if a JSON string has a specific key (used for the `?` operator) +* [x] `json_get(json: str, *keys: str | int) -> JsonUnion` - Get a value from a JSON string by its "path" +* [x] `json_get_str(json: str, *keys: str | int) -> str` - Get a string value from a JSON string by its "path" +* [x] `json_get_int(json: str, *keys: str | int) -> int` - Get an integer value from a JSON string by its "path" +* [x] `json_get_float(json: str, *keys: str | int) -> float` - Get a float value from a JSON string by its "path" +* [x] `json_get_bool(json: str, *keys: str | int) -> bool` - Get a boolean value from a JSON string by its "path" +* [x] `json_get_json(json: str, *keys: str | int) -> str` - Get a nested raw JSON string from a JSON string by its "path" +* [x] `json_as_text(json: str, *keys: str | int) -> str` - Get any value from a JSON string by its "path", represented as a string (used for the `->>` operator) +* [x] `json_length(json: str, *keys: str | int) -> int` - get the length of a JSON string or array Cast expressions with `json_get` are rewritten to the appropriate method, e.g. @@ -38,7 +39,7 @@ select * from foo where json_get_str(attributes, 'bar')='ham' ## TODO (maybe, if they're actually useful) -* [ ] `json_keys(json: str, *keys: str | int) -> list[str]` - get the keys of a JSON object +* [ ] `json_keys(json: str, *keys: str | int) -> list[str]` - get the keys of a JSON string * [ ] `json_is_obj(json: str, *keys: str | int) -> bool` - true if the JSON is an object * [ ] `json_is_array(json: str, *keys: str | int) -> bool` - true if the JSON is an array * [ ] `json_valid(json: str) -> bool` - true if the JSON is valid diff --git a/src/json_as_text.rs b/src/json_as_text.rs new file mode 100644 index 0000000..9f02470 --- /dev/null +++ b/src/json_as_text.rs @@ -0,0 +1,84 @@ +use std::any::Any; +use std::sync::Arc; + +use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common_macros::make_udf_function; +use arrow::array::{ArrayRef, StringArray}; +use arrow_schema::DataType; +use datafusion_common::{Result as DataFusionResult, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use jiter::Peek; + +make_udf_function!( + JsonAsText, + json_as_text, + json_data path, + r#"Get any value from a JSON string by its "path", represented as a string"# +); + +#[derive(Debug)] +pub(super) struct JsonAsText { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonAsText { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_as_text".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonAsText { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + check_args(arg_types, self.name()).map(|()| DataType::Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { + invoke::( + args, + jiter_json_as_text, + |c| Ok(Arc::new(c) as ArrayRef), + ScalarValue::Utf8, + ) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath]) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { + match peek { + Peek::Null => { + jiter.known_null()?; + get_err!() + } + Peek::String => Ok(jiter.known_str()?.to_owned()), + _ => { + let start = jiter.current_index(); + jiter.known_skip(peek)?; + let object_slice = jiter.slice_to_current(start); + let object_string = std::str::from_utf8(object_slice)?; + Ok(object_string.to_owned()) + } + } + } else { + get_err!() + } +} diff --git a/src/json_get.rs b/src/json_get.rs index b212aa3..b1e4810 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -16,7 +16,7 @@ make_udf_function!( JsonGet, json_get, json_data path, - r#"Get a value from a JSON object by its "path""# + r#"Get a value from a JSON string by its "path""# ); // build_typed_get!(JsonGet, "json_get", Union, Float64Array, jiter_json_get_float); diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 4f9d2e1..92a4ac9 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -14,7 +14,7 @@ make_udf_function!( JsonGetBool, json_get_bool, json_data path, - r#"Get an boolean value from a JSON object by its "path""# + r#"Get an boolean value from a JSON string by its "path""# ); #[derive(Debug)] diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 4e0cec2..bed8c67 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -14,7 +14,7 @@ make_udf_function!( JsonGetFloat, json_get_float, json_data path, - r#"Get a float value from a JSON object by its "path""# + r#"Get a float value from a JSON string by its "path""# ); #[derive(Debug)] diff --git a/src/json_get_int.rs b/src/json_get_int.rs index 6f49380..4f80256 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -14,7 +14,7 @@ make_udf_function!( JsonGetInt, json_get_int, json_data path, - r#"Get an integer value from a JSON object by its "path""# + r#"Get an integer value from a JSON string by its "path""# ); #[derive(Debug)] diff --git a/src/json_get_json.rs b/src/json_get_json.rs index 02e3422..002702b 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -13,7 +13,7 @@ make_udf_function!( JsonGetJson, json_get_json, json_data path, - r#"Get any value from a JSON object by its "path", represented as a string"# + r#"Get a nested raw JSON string from a JSON string by its "path""# ); #[derive(Debug)] diff --git a/src/json_get_str.rs b/src/json_get_str.rs index a45f0ac..a6f4ad5 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -14,7 +14,7 @@ make_udf_function!( JsonGetStr, json_get_str, json_data path, - r#"Get a string value from a JSON object by its "path""# + r#"Get a string value from a JSON string by its "path""# ); #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index 94484aa..c576794 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ use std::sync::Arc; mod common; mod common_macros; mod common_union; +mod json_as_text; mod json_contains; mod json_get; mod json_get_bool; @@ -18,6 +19,7 @@ mod json_length; mod rewrite; pub mod functions { + pub use crate::json_as_text::json_as_text; pub use crate::json_contains::json_contains; pub use crate::json_get::json_get; pub use crate::json_get_bool::json_get_bool; @@ -29,6 +31,7 @@ pub mod functions { } pub mod udfs { + pub use crate::json_as_text::json_as_text_udf; pub use crate::json_contains::json_contains_udf; pub use crate::json_get::json_get_udf; pub use crate::json_get_bool::json_get_bool_udf; @@ -55,6 +58,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { json_get_float::json_get_float_udf(), json_get_int::json_get_int_udf(), json_get_json::json_get_json_udf(), + json_as_text::json_as_text_udf(), json_get_str::json_get_str_udf(), json_contains::json_contains_udf(), json_length::json_length_udf(), diff --git a/src/rewrite.rs b/src/rewrite.rs index 6fed579..0e600c4 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -89,7 +89,7 @@ impl ExprPlanner for JsonExprPlanner { fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result> { let (func, op_display) = match &expr.op { BinaryOperator::Arrow => (crate::json_get::json_get_udf(), "->"), - BinaryOperator::LongArrow => (crate::json_get_str::json_get_str_udf(), "->>"), + BinaryOperator::LongArrow => (crate::json_as_text::json_as_text_udf(), "->>"), BinaryOperator::Question => (crate::json_contains::json_contains_udf(), "?"), _ => return Ok(PlannerResult::Original(expr)), }; diff --git a/tests/main.rs b/tests/main.rs index efa916e..cb9733a 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -755,8 +755,8 @@ async fn test_long_arrow() { "| name | json_data ->> Utf8(\"foo\") |", "+------------------+---------------------------+", "| object_foo | abc |", - "| object_foo_array | |", - "| object_foo_obj | |", + "| object_foo_array | [1] |", + "| object_foo_obj | {} |", "| object_foo_null | |", "| object_bar | |", "| list_foo | |", @@ -771,7 +771,7 @@ async fn test_plan_long_arrow() { let lines = logical_plan(r#"explain select json_data->>'foo' from test"#).await; let expected = [ - "Projection: json_get_str(test.json_data, Utf8(\"foo\")) AS json_data ->> Utf8(\"foo\")", + "Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS json_data ->> Utf8(\"foo\")", " TableScan: test projection=[json_data]", ]; @@ -789,8 +789,8 @@ async fn test_long_arrow_eq_str() { "| name | json_data ->> Utf8(\"foo\") = Utf8(\"abc\") |", "+------------------+-----------------------------------------+", "| object_foo | true |", - "| object_foo_array | |", - "| object_foo_obj | |", + "| object_foo_array | false |", + "| object_foo_obj | false |", "| object_foo_null | |", "| object_bar | |", "| list_foo | |", @@ -933,7 +933,7 @@ async fn test_arrow_nested_double_columns() { async fn test_lexical_precedence_wrong() { let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#; let err = run_query(sql).await.unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_get_str' at position 2, expected string or int, got Boolean.") + assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean.") } #[tokio::test] @@ -1099,3 +1099,20 @@ async fn test_arrow_scalar_union_is_null() { ]; assert_batches_eq!(expected, &batches); } + +#[tokio::test] +async fn test_arrow_cast() { + let batches = run_query("select (json_data->>'foo')::int from other").await.unwrap(); + + let expected = [ + "+---------------------------+", + "| json_data ->> Utf8(\"foo\") |", + "+---------------------------+", + "| 42 |", + "| 42 |", + "| |", + "| |", + "+---------------------------+", + ]; + assert_batches_eq!(expected, &batches); +}