From 9081ba8ffbda2a99bff826b9166c3fd9fbdccd67 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 6 Nov 2024 13:48:28 -0800 Subject: [PATCH 1/6] rebase & add daft-connect --- Cargo.lock | 250 ++- Cargo.toml | 37 +- src/daft-connect/Cargo.toml | 38 + src/daft-connect/src/command.rs | 128 ++ src/daft-connect/src/config.rs | 212 +++ src/daft-connect/src/convert.rs | 45 + .../src/convert/data_conversion.rs | 59 + .../src/convert/data_conversion/range.rs | 42 + .../convert/data_conversion/show_string.rs | 59 + src/daft-connect/src/convert/expression.rs | 120 ++ src/daft-connect/src/convert/formatting.rs | 69 + .../src/convert/plan_conversion.rs | 134 ++ .../src/convert/schema_conversion.rs | 56 + src/daft-connect/src/lib.rs | 449 +++++ src/daft-connect/src/main.rs | 32 + src/daft-connect/src/session.rs | 48 + src/daft-connect/src/util.rs | 109 ++ src/daft-local-execution/src/lib.rs | 2 +- src/daft-table/src/lib.rs | 10 + tests/connect/__init__.py | 0 tests/connect/conf.py | 150 ++ tests/connect/test_client.py | 434 +++++ tests/connect/test_conf.py | 113 ++ tests/connect/test_config.py | 153 ++ tests/connect/test_connect.py | 127 ++ tests/connect/test_connect_basic.py | 1477 +++++++++++++++++ tests/connect/test_session.py | 265 +++ 27 files changed, 4584 insertions(+), 34 deletions(-) create mode 100644 src/daft-connect/Cargo.toml create mode 100644 src/daft-connect/src/command.rs create mode 100644 src/daft-connect/src/config.rs create mode 100644 src/daft-connect/src/convert.rs create mode 100644 src/daft-connect/src/convert/data_conversion.rs create mode 100644 src/daft-connect/src/convert/data_conversion/range.rs create mode 100644 src/daft-connect/src/convert/data_conversion/show_string.rs create mode 100644 src/daft-connect/src/convert/expression.rs create mode 100644 src/daft-connect/src/convert/formatting.rs create mode 100644 src/daft-connect/src/convert/plan_conversion.rs create mode 100644 src/daft-connect/src/convert/schema_conversion.rs create mode 100644 src/daft-connect/src/lib.rs create mode 100644 src/daft-connect/src/main.rs create mode 100644 src/daft-connect/src/session.rs create mode 100644 src/daft-connect/src/util.rs create mode 100644 tests/connect/__init__.py create mode 100644 tests/connect/conf.py create mode 100644 tests/connect/test_client.py create mode 100644 tests/connect/test_conf.py create mode 100644 tests/connect/test_config.py create mode 100644 tests/connect/test_connect.py create mode 100755 tests/connect/test_connect_basic.py create mode 100644 tests/connect/test_session.py diff --git a/Cargo.lock b/Cargo.lock index ec65979338..645373c596 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1054,6 +1054,9 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] [[package]] name = "block-buffer" @@ -1843,6 +1846,7 @@ dependencies = [ "common-tracing", "common-version", "daft-compression", + "daft-connect", "daft-core", "daft-csv", "daft-dsl", @@ -1883,6 +1887,35 @@ dependencies = [ "url", ] +[[package]] +name = "daft-connect" +version = "0.3.0-dev0" +dependencies = [ + "arrow2", + "common-daft-config", + "daft-core", + "daft-dsl", + "daft-local-execution", + "daft-physical-plan", + "daft-plan", + "daft-schema", + "daft-table", + "dashmap", + "eyre", + "futures", + "pyo3", + "ron", + "spark-connect", + "tempfile", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-subscriber", + 
"tracing-tracy", + "uuid 1.10.0", +] + [[package]] name = "daft-core" version = "0.3.0-dev0" @@ -2477,6 +2510,20 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deflate64" version = "0.1.9" @@ -2721,6 +2768,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -2825,9 +2882,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -2840,9 +2897,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -2850,15 +2907,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -2867,9 +2924,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -2888,9 +2945,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -2899,15 +2956,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -2917,9 +2974,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -2933,6 +2990,19 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb949699c3e4df3a183b1d2142cb24277057055ed23c68ed58894f76c517223" +dependencies = [ + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.58.0", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -3457,7 +3527,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -3513,6 +3583,12 @@ dependencies = [ "quick-error 2.0.1", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -3866,6 +3942,19 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lz4" version = "1.26.0" @@ -5138,6 +5227,19 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" +[[package]] +name = "ron" +version = "0.9.0-alpha.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c0bd893640cac34097a74f0c2389ddd54c62d6a3c635fa93cafe6b6bc19be6a" +dependencies = [ + "base64 0.21.7", + "bitflags 2.6.0", + "serde", + "serde_derive", + "unicode-ident", +] + [[package]] name = "rstest" version = "0.18.2" @@ -5298,6 +5400,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -5832,7 +5940,7 @@ dependencies = [ "ntapi", "once_cell", "rayon", - "windows", + "windows 0.52.0", ] [[package]] @@ -6091,9 +6199,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", @@ -6288,6 +6396,38 @@ dependencies = [ "tracing-log", ] +[[package]] +name = 
"tracing-tracy" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc775fdaf33c3dfd19dc354729e65e87914bc67dcdc390ca1210807b8bee5902" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracy-client", +] + +[[package]] +name = "tracy-client" +version = "0.17.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "746b078c6a09ebfd5594609049e07116735c304671eaab06ce749854d23435bc" +dependencies = [ + "loom", + "once_cell", + "tracy-client-sys", +] + +[[package]] +name = "tracy-client-sys" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3637e734239e12ab152cd269302500bd063f37624ee210cd04b4936ed671f3b1" +dependencies = [ + "cc", + "windows-targets 0.52.6", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -6702,7 +6842,17 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ - "windows-core", + "windows-core 0.52.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", "windows-targets 0.52.6", ] @@ -6715,6 +6865,60 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index 5204ac2f81..68313903c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,11 @@ common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} daft-compression = {path = "src/daft-compression", default-features = false} +daft-connect = {path = "src/daft-connect", optional = true} daft-core = {path = "src/daft-core", default-features = false} daft-csv = {path = "src/daft-csv", default-features = false} daft-dsl = {path 
= "src/daft-dsl", default-features = false} -daft-functions = {path = "src/daft-functions", default-features = false} +daft-functions = {path = "src/daft-functions"} daft-functions-json = {path = "src/daft-functions-json", default-features = false} daft-hash = {path = "src/daft-hash", default-features = false} daft-image = {path = "src/daft-image", default-features = false} @@ -44,6 +45,11 @@ python = [ "common-display/python", "common-resource-request/python", "common-system-info/python", + "common-daft-config/python", + "common-display/python", + "common-resource-request/python", + "common-system-info/python", + "daft-connect/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", @@ -61,14 +67,10 @@ python = [ "daft-scheduler/python", "daft-sql/python", "daft-stats/python", - "daft-table/python", - "daft-functions/python", - "daft-functions-json/python", + "daft-stats/python", "daft-writers/python", - "common-daft-config/python", - "common-system-info/python", - "common-display/python", - "common-resource-request/python", + "daft-table/python", + "dep:daft-connect", "dep:pyo3", "dep:pyo3-log" ] @@ -154,6 +156,7 @@ members = [ "src/daft-table", "src/daft-writers", "src/hyperloglog", + "src/daft-connect", "src/parquet2", # "src/spark-connect-script", "src/generated/spark-connect" @@ -161,6 +164,7 @@ members = [ [workspace.dependencies] ahash = "0.8.11" +anyhow = "1.0.89" approx = "0.5.1" async-compat = "0.2.3" async-compression = {version = "0.4.12", features = [ @@ -174,8 +178,20 @@ bytes = "1.6.0" chrono = "0.4.38" chrono-tz = "0.8.4" comfy-table = "7.1.1" +common-daft-config = {path = "src/common/daft-config"} +common-display = {path = "src/common/display"} common-error = {path = "src/common/error", default-features = false} +daft-connect = {path = "src/daft-connect", default-features = false} +daft-core = {path = "src/daft-core"} +daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} +daft-local-execution = {path = "src/daft-local-execution"} +daft-micropartition = {path = "src/daft-micropartition"} +daft-physical-plan = {path = "src/daft-physical-plan"} +daft-plan = {path = "src/daft-plan"} +daft-schema = {path = "src/daft-schema"} +daft-sql = {path = "src/daft-sql"} +daft-table = {path = "src/daft-table"} derivative = "2.2.0" derive_builder = "0.20.2" divan = "0.1.14" @@ -204,6 +220,7 @@ serde_json = "1.0.116" sha1 = "0.11.0-pre.4" sketches-ddsketch = {version = "0.2.2", features = ["use_serde"]} snafu = {version = "0.7.4", features = ["futures"]} +spark-connect = {path = "src/spark-connect", default-features = false} sqlparser = "0.51.0" sysinfo = "0.30.12" tango-bench = "0.6.0" @@ -233,7 +250,7 @@ path = "src/arrow2" version = "1.3.3" [workspace.dependencies.derive_more] -features = ["display"] +features = ["display", "from", "constructor"] version = "1.0.0" [workspace.dependencies.lazy_static] @@ -321,7 +338,7 @@ uninlined_format_args = "allow" unnecessary_wraps = "allow" unnested_or_patterns = "allow" unreadable_literal = "allow" -# todo: remove? 
+# todo: remove this at some point unsafe_derive_deserialize = "allow" unused_async = "allow" # used_underscore_items = "allow" # REMOVE diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml new file mode 100644 index 0000000000..e82be71bbc --- /dev/null +++ b/src/daft-connect/Cargo.toml @@ -0,0 +1,38 @@ +[dependencies] +dashmap = "6.1.0" +# papaya = "0.1.3" +eyre = "0.6.12" +futures = "0.3.31" +pyo3 = {workspace = true, optional = true} +ron = "0.9.0-alpha.0" +tokio = {version = "1.40.0", features = ["full"]} +tokio-stream = "0.1.16" +tonic = "0.12.3" +tracing-subscriber = {version = "0.3.18", features = ["env-filter"]} +tracing-tracy = "0.11.3" +uuid = {version = "1.10.0", features = ["v4"]} +arrow2.workspace = true +common-daft-config.workspace = true +daft-core.workspace = true +daft-dsl.workspace = true +daft-local-execution.workspace = true +daft-physical-plan.workspace = true +daft-plan.workspace = true +daft-schema.workspace = true +daft-table.workspace = true +spark-connect.workspace = true +tracing.workspace = true + +[dev-dependencies] +tempfile = "3.4.0" + +[features] +python = ["dep:pyo3"] + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-connect" +version = {workspace = true} diff --git a/src/daft-connect/src/command.rs b/src/daft-connect/src/command.rs new file mode 100644 index 0000000000..c999e2bac0 --- /dev/null +++ b/src/daft-connect/src/command.rs @@ -0,0 +1,128 @@ +// Stream of Result + +use std::thread; + +use arrow2::io::ipc::write::StreamWriter; +use daft_table::Table; +use eyre::Context; +use futures::TryStreamExt; +use spark_connect::{ + execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, + spark_connect_service_server::SparkConnectService, + ExecutePlanResponse, Relation, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::Status; +use uuid::Uuid; + +use crate::{convert::convert_data, DaftSparkConnectService, Session}; + +type DaftStream = ::ExecutePlanStream; + +struct ExecutablePlanChannel { + session_id: String, + server_side_session_id: String, + operation_id: String, + tx: tokio::sync::mpsc::UnboundedSender>, +} + +pub trait ConcreteDataChannel { + fn send_table(&mut self, table: &Table) -> eyre::Result<()>; +} + +impl ConcreteDataChannel for ExecutablePlanChannel { + fn send_table(&mut self, table: &Table) -> eyre::Result<()> { + let mut data = Vec::new(); + + let mut writer = StreamWriter::new( + &mut data, + arrow2::io::ipc::write::WriteOptions { compression: None }, + ); + + let row_count = table.num_rows(); + + let schema = table + .schema + .to_arrow() + .wrap_err("Failed to convert Daft schema to Arrow schema")?; + + writer + .start(&schema, None) + .wrap_err("Failed to start Arrow stream writer with schema")?; + + let arrays = table.get_inner_arrow_arrays().collect(); + let chunk = arrow2::chunk::Chunk::new(arrays); + + writer + .write(&chunk, None) + .wrap_err("Failed to write Arrow chunk to stream writer")?; + + let response = ExecutePlanResponse { + session_id: self.session_id.to_string(), + server_side_session_id: self.server_side_session_id.to_string(), + operation_id: self.operation_id.to_string(), + response_id: Uuid::new_v4().to_string(), // todo: implement this + metrics: None, // todo: implement this + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ArrowBatch(ArrowBatch { + row_count: row_count as i64, + data, + start_offset: None, + })), + }; + + self.tx + .send(Ok(response)) + .wrap_err("Error sending response to client")?; 
+ + Ok(()) + } +} + +impl Session { + pub async fn handle_root_command( + &self, + command: Relation, + operation_id: String, + ) -> Result { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let mut channel = ExecutablePlanChannel { + session_id: self.id().to_string(), + server_side_session_id: self.server_side_session_id().to_string(), + operation_id: operation_id.clone(), + tx: tx.clone(), + }; + + thread::spawn({ + let session_id = self.id().to_string(); + let server_side_session_id = self.server_side_session_id().to_string(); + move || { + let result = convert_data(command, &mut channel); + + if let Err(e) = result { + tx.send(Err(e)).unwrap(); + } else { + let finished = ExecutePlanResponse { + session_id, + server_side_session_id, + operation_id: operation_id.to_string(), + response_id: Uuid::new_v4().to_string(), + metrics: None, + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ResultComplete(ResultComplete {})), + }; + + tx.send(Ok(finished)).unwrap(); + } + } + }); + + let recv_stream = + UnboundedReceiverStream::new(rx).map_err(|e| Status::internal(e.to_string())); + + Ok(Box::pin(recv_stream)) + } +} diff --git a/src/daft-connect/src/config.rs b/src/daft-connect/src/config.rs new file mode 100644 index 0000000000..863a800b2a --- /dev/null +++ b/src/daft-connect/src/config.rs @@ -0,0 +1,212 @@ +use std::collections::BTreeMap; + +use spark_connect::{ + config_request::{Get, GetAll, GetOption, GetWithDefault, IsModifiable, Set, Unset}, + ConfigResponse, KeyValue, +}; +use tonic::Status; + +use crate::Session; + +impl Session { + fn config_response(&self) -> ConfigResponse { + ConfigResponse { + session_id: self.id().to_string(), + server_side_session_id: self.server_side_session_id().to_string(), + pairs: vec![], + warnings: vec![], + } + } + + pub fn set(&mut self, operation: Set) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("set", session_id = %self.id(), ?operation); + let _enter = span.enter(); + + for KeyValue { key, value } in operation.pairs { + let Some(value) = value else { + let msg = format!("Missing value for key {key}. If you want to unset a value use the Unset operation"); + response.warnings.push(msg); + continue; + }; + + let previous = self.config_values_mut().insert(key.clone(), value.clone()); + if previous.is_some() { + tracing::info!("Updated existing configuration value"); + } else { + tracing::info!("Set new configuration value"); + } + } + + Ok(response) + } + + pub fn get(&self, operation: Get) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get", session_id = %self.id()); + let _enter = span.enter(); + + for key in operation.keys { + let value = self.config_values().get(&key).cloned(); + response.pairs.push(KeyValue { key, value }); + } + + Ok(response) + } + + pub fn get_with_default(&self, operation: GetWithDefault) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get_with_default", session_id = %self.id()); + let _enter = span.enter(); + + for KeyValue { + key, + value: default_value, + } in operation.pairs + { + let value = self.config_values().get(&key).cloned().or(default_value); + response.pairs.push(KeyValue { key, value }); + } + + Ok(response) + } + + /// Needs to be fixed so it has different behavior than [`Session::get`]. Not entirely + /// sure how it should work yet. 
+ pub fn get_option(&self, operation: GetOption) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get_option", session_id = %self.id()); + let _enter = span.enter(); + + for key in operation.keys { + let value = self.config_values().get(&key).cloned(); + response.pairs.push(KeyValue { key, value }); + } + + Ok(response) + } + + pub fn get_all(&self, operation: GetAll) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get_all", session_id = %self.id()); + let _enter = span.enter(); + + let Some(prefix) = operation.prefix else { + for (key, value) in self.config_values() { + response.pairs.push(KeyValue { + key: key.clone(), + value: Some(value.clone()), + }); + } + return Ok(response); + }; + + for (k, v) in prefix_search(self.config_values(), &prefix) { + response.pairs.push(KeyValue { + key: k.clone(), + value: Some(v.clone()), + }); + } + + Ok(response) + } + + pub fn unset(&mut self, operation: Unset) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("unset", session_id = %self.id()); + let _enter = span.enter(); + + for key in operation.keys { + if self.config_values_mut().remove(&key).is_none() { + let msg = format!("Key {key} not found"); + response.warnings.push(msg); + } else { + tracing::info!("Unset configuration value"); + } + } + + Ok(response) + } + + pub fn is_modifiable(&self, _operation: IsModifiable) -> Result { + let response = self.config_response(); + + let span = tracing::info_span!("is_modifiable", session_id = %self.id()); + let _enter = span.enter(); + + tracing::warn!(session_id = %self.id(), "is_modifiable operation not yet implemented"); + // todo: need to implement this + Ok(response) + } +} + +fn prefix_search<'a, V>( + map: &'a BTreeMap, + prefix: &'a str, +) -> impl Iterator { + let start = map.range(prefix.to_string()..); + start.take_while(move |(k, _)| k.starts_with(prefix)) +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use super::*; + + #[test] + fn test_prefix_search() { + let mut map = BTreeMap::new(); + map.insert("apple".to_string(), 1); + map.insert("application".to_string(), 2); + map.insert("banana".to_string(), 3); + map.insert("app".to_string(), 4); + map.insert("apricot".to_string(), 5); + + // Test with prefix "app" + let result: Vec<_> = prefix_search(&map, "app").collect(); + assert_eq!( + result, + vec![ + (&"app".to_string(), &4), + (&"apple".to_string(), &1), + (&"application".to_string(), &2), + ] + ); + + // Test with prefix "b" + let result: Vec<_> = prefix_search(&map, "b").collect(); + assert_eq!(result, vec![(&"banana".to_string(), &3),]); + + // Test with prefix that doesn't match any keys + let result: Vec<_> = prefix_search(&map, "z").collect(); + assert_eq!(result, vec![]); + + // Test with empty prefix (should return all items) + let result: Vec<_> = prefix_search(&map, "").collect(); + assert_eq!( + result, + vec![ + (&"app".to_string(), &4), + (&"apple".to_string(), &1), + (&"application".to_string(), &2), + (&"apricot".to_string(), &5), + (&"banana".to_string(), &3), + ] + ); + + // Test with prefix that matches a complete key + let result: Vec<_> = prefix_search(&map, "apple").collect(); + assert_eq!(result, vec![(&"apple".to_string(), &1),]); + + // Test with case sensitivity + let result: Vec<_> = prefix_search(&map, "App").collect(); + assert_eq!(result, vec![]); + } +} diff --git a/src/daft-connect/src/convert.rs b/src/daft-connect/src/convert.rs new file mode 100644 
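// Aside on the config handlers above: Set/Get/GetWithDefault/GetAll all reduce
// to operations on the session's ordered `BTreeMap<String, String>`. A
// self-contained sketch of the Get vs. GetWithDefault resolution order
// (function and key names are made up for illustration):
use std::collections::BTreeMap;

fn resolve(values: &BTreeMap<String, String>, key: &str, default: Option<String>) -> Option<String> {
    // Get returns None for a missing key; GetWithDefault falls back to the
    // client-supplied default instead.
    values.get(key).cloned().or(default)
}

fn resolve_demo() {
    let mut values = BTreeMap::new();
    values.insert("spark.sql.shuffle.partitions".to_string(), "8".to_string());
    assert_eq!(resolve(&values, "spark.sql.shuffle.partitions", None).as_deref(), Some("8"));
    assert_eq!(resolve(&values, "spark.not.set", Some("200".into())).as_deref(), Some("200"));
    assert_eq!(resolve(&values, "spark.not.set", None), None);
}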
index 0000000000..98cfa9e4d8 --- /dev/null +++ b/src/daft-connect/src/convert.rs @@ -0,0 +1,45 @@ +mod data_conversion; +mod expression; +mod formatting; +mod plan_conversion; +mod schema_conversion; + +use std::{ + collections::HashMap, + ops::{ControlFlow, Try}, + sync::Arc, +}; + +use common_daft_config::DaftExecutionConfig; +use daft_plan::LogicalPlanRef; +use daft_table::Table; +pub use data_conversion::convert_data; +use eyre::Context; +pub use plan_conversion::to_logical_plan; +pub use schema_conversion::connect_schema; + +pub fn map_to_tables( + logical_plan: &LogicalPlanRef, + mut f: impl FnMut(&Table) -> T, + default: impl FnOnce() -> T, +) -> eyre::Result { + let physical_plan = daft_physical_plan::translate(logical_plan)?; + let cfg = Arc::new(DaftExecutionConfig::default()); + let psets = HashMap::new(); + + let stream = daft_local_execution::run_local(&physical_plan, psets, cfg, None) + .wrap_err("running local execution")?; + + for elem in stream { + let elem = elem?; + let tables = elem.get_tables()?; + + for table in tables.as_slice() { + if let ControlFlow::Break(x) = f(table).branch() { + return Ok(T::from_residual(x)); + } + } + } + + Ok(default()) +} diff --git a/src/daft-connect/src/convert/data_conversion.rs b/src/daft-connect/src/convert/data_conversion.rs new file mode 100644 index 0000000000..de6b0ed71e --- /dev/null +++ b/src/daft-connect/src/convert/data_conversion.rs @@ -0,0 +1,59 @@ +//! Relation handling for Spark Connect protocol. +//! +//! A Relation represents a structured dataset or transformation in Spark Connect. +//! It can be either a base relation (direct data source) or derived relation +//! (result of operations on other relations). +//! +//! The protocol represents relations as trees of operations where: +//! - Each node is a Relation with metadata and an operation type +//! - Operations can reference other relations, forming a DAG +//! - The tree describes how to derive the final result +//! +//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age +//! +//! ```text +//! Aggregate (grouping by age) +//! ↳ Filter (department = 'Engineering') +//! ↳ Read (employees table) +//! ``` +//! +//! Relations abstract away: +//! - Physical storage details +//! - Distributed computation +//! - Query optimization +//! - Data source specifics +//! +//! This allows Spark to optimize and execute queries efficiently across a cluster +//! while providing a consistent API regardless of the underlying data source. +//! ```mermaid +//! +//! 
``` + +use eyre::{eyre, Context}; +use spark_connect::{relation::RelType, Relation}; +use tracing::trace; + +use crate::{command::ConcreteDataChannel, convert::formatting::RelTypeExt}; + +mod show_string; +use show_string::show_string; + +mod range; +use range::range; + +pub fn convert_data(plan: Relation, encoder: &mut impl ConcreteDataChannel) -> eyre::Result<()> { + // First check common fields if needed + if let Some(common) = &plan.common { + // contains metadata shared across all relation types + // Log or handle common fields if necessary + trace!("Processing relation with plan_id: {:?}", common.plan_id); + } + + let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?; + + match rel_type { + RelType::ShowString(input) => show_string(*input, encoder).wrap_err("parsing ShowString"), + RelType::Range(input) => range(input, encoder).wrap_err("parsing Range"), + other => Err(eyre!("Unsupported top-level relation: {}", other.name())), + } +} diff --git a/src/daft-connect/src/convert/data_conversion/range.rs b/src/daft-connect/src/convert/data_conversion/range.rs new file mode 100644 index 0000000000..1afb710a81 --- /dev/null +++ b/src/daft-connect/src/convert/data_conversion/range.rs @@ -0,0 +1,42 @@ +use daft_core::prelude::Series; +use daft_schema::prelude::Schema; +use daft_table::Table; +use eyre::{ensure, Context}; +use spark_connect::Range; + +use crate::command::ConcreteDataChannel; + +pub fn range(range: Range, channel: &mut impl ConcreteDataChannel) -> eyre::Result<()> { + let Range { + start, + end, + step, + num_partitions, + } = range; + + let start = start.unwrap_or(0); + + ensure!(num_partitions.is_none(), "num_partitions is not supported"); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect(); + let len = arrow_array.len(); + + let singleton_series = Series::try_from(( + "range", + Box::new(arrow_array) as Box, + )) + .wrap_err("creating singleton series")?; + + let singleton_table = Table::new_with_size( + Schema::new(vec![singleton_series.field().clone()])?, + vec![singleton_series], + len, + )?; + + channel.send_table(&singleton_table)?; + + Ok(()) +} diff --git a/src/daft-connect/src/convert/data_conversion/show_string.rs b/src/daft-connect/src/convert/data_conversion/show_string.rs new file mode 100644 index 0000000000..35c5e3d602 --- /dev/null +++ b/src/daft-connect/src/convert/data_conversion/show_string.rs @@ -0,0 +1,59 @@ +use daft_core::prelude::Series; +use daft_schema::prelude::Schema; +use daft_table::Table; +use eyre::{ensure, eyre, Context}; +use spark_connect::ShowString; + +use crate::{ + command::ConcreteDataChannel, + convert::{map_to_tables, plan_conversion::to_logical_plan}, +}; + +pub fn show_string( + show_string: ShowString, + channel: &mut impl ConcreteDataChannel, +) -> eyre::Result<()> { + let ShowString { + input, + num_rows, + truncate, + vertical, + } = show_string; + + ensure!(num_rows > 0, "num_rows must be positive, got {num_rows}"); + ensure!(truncate > 0, "truncate must be positive, got {truncate}"); + ensure!(!vertical, "vertical is not yet supported"); + + let input = *input.ok_or_else(|| eyre!("input is None"))?; + + let logical_plan = to_logical_plan(input)?.build(); + + map_to_tables( + &logical_plan, + |table| -> eyre::Result<()> { + let display = format!("{table}"); + + let arrow_array: arrow2::array::Utf8Array = + 
std::iter::once(display.as_str()).map(Some).collect(); + + let singleton_series = Series::try_from(( + "show_string", + Box::new(arrow_array) as Box, + )) + .wrap_err("creating singleton series")?; + + let singleton_table = Table::new_with_size( + Schema::new(vec![singleton_series.field().clone()])?, + vec![singleton_series], + 1, + )?; + + channel.send_table(&singleton_table)?; + + Ok(()) + }, + || Ok(()), + )??; + + Ok(()) +} diff --git a/src/daft-connect/src/convert/expression.rs b/src/daft-connect/src/convert/expression.rs new file mode 100644 index 0000000000..f79a7bf5a8 --- /dev/null +++ b/src/daft-connect/src/convert/expression.rs @@ -0,0 +1,120 @@ +use daft_dsl::{Expr as DaftExpr, Operator}; +use eyre::{bail, ensure, eyre, Result}; +use spark_connect::{expression, expression::literal::LiteralType, Expression}; + +pub fn convert_expression(expr: Expression) -> Result { + match expr.expr_type { + Some(expression::ExprType::Literal(lit)) => Ok(DaftExpr::Literal(convert_literal(lit)?)), + + Some(expression::ExprType::UnresolvedAttribute(attr)) => { + Ok(DaftExpr::Column(attr.unparsed_identifier.into())) + } + + Some(expression::ExprType::Alias(alias)) => { + let expression::Alias { + expr, + name, + metadata, + } = *alias; + let expr = *expr.ok_or_else(|| eyre!("expr is None"))?; + + // Convert alias + let expr = convert_expression(expr)?; + + if let Some(metadata) = metadata + && !metadata.is_empty() + { + bail!("Metadata is not yet supported"); + } + + // ignore metadata for now + + let [name] = name.as_slice() else { + bail!("Alias name must have exactly one element"); + }; + + Ok(DaftExpr::Alias(expr.into(), name.as_str().into())) + } + + Some(expression::ExprType::UnresolvedFunction(expression::UnresolvedFunction { + function_name, + arguments, + is_distinct, + is_user_defined_function, + })) => { + ensure!(!is_distinct, "Distinct is not yet supported"); + ensure!( + !is_user_defined_function, + "User-defined functions are not yet supported" + ); + + let op = function_name.as_str(); + match op { + ">" | "<" | "<=" | ">=" | "+" | "-" | "*" | "/" => { + let arr: [Expression; 2] = arguments + .try_into() + .map_err(|_| eyre!("Expected 2 arguments"))?; + let [left, right] = arr; + + let left = convert_expression(left)?; + let right = convert_expression(right)?; + + let op = match op { + ">" => Operator::Gt, + "<" => Operator::Lt, + "<=" => Operator::LtEq, + ">=" => Operator::GtEq, + "+" => Operator::Plus, + "-" => Operator::Minus, + "*" => Operator::Multiply, + "/" => Operator::FloorDivide, // todo is this what we want? + _ => unreachable!(), + }; + + Ok(DaftExpr::BinaryOp { + left: left.into(), + op, + right: right.into(), + }) + } + other => bail!("Unsupported function name: {other}"), + } + } + + // Handle other expression types... + _ => Err(eyre!("Unsupported expression type")), + } +} + +// Helper functions to convert literals, function names, operators etc. + +fn convert_literal(lit: expression::Literal) -> Result { + let literal_type = lit + .literal_type + .ok_or_else(|| eyre!("literal_type is None"))?; + + let result = match literal_type { + LiteralType::Null(..) 
=> daft_dsl::LiteralValue::Null, + LiteralType::Binary(input) => daft_dsl::LiteralValue::Binary(input), + LiteralType::Boolean(input) => daft_dsl::LiteralValue::Boolean(input), + LiteralType::Byte(input) => daft_dsl::LiteralValue::Int32(input), + LiteralType::Short(input) => daft_dsl::LiteralValue::Int32(input), + LiteralType::Integer(input) => daft_dsl::LiteralValue::Int32(input), + LiteralType::Long(input) => daft_dsl::LiteralValue::Int64(input), + LiteralType::Float(input) => daft_dsl::LiteralValue::Float64(f64::from(input)), + LiteralType::Double(input) => daft_dsl::LiteralValue::Float64(input), + LiteralType::String(input) => daft_dsl::LiteralValue::Utf8(input), + LiteralType::Date(input) => daft_dsl::LiteralValue::Date(input), + LiteralType::Decimal(_) + | LiteralType::Timestamp(_) + | LiteralType::TimestampNtz(_) + | LiteralType::CalendarInterval(_) + | LiteralType::YearMonthInterval(_) + | LiteralType::DayTimeInterval(_) + | LiteralType::Array(_) + | LiteralType::Map(_) + | LiteralType::Struct(_) => bail!("unimplemented"), + }; + + Ok(result) +} diff --git a/src/daft-connect/src/convert/formatting.rs b/src/daft-connect/src/convert/formatting.rs new file mode 100644 index 0000000000..3310a918fb --- /dev/null +++ b/src/daft-connect/src/convert/formatting.rs @@ -0,0 +1,69 @@ +use spark_connect::relation::RelType; + +/// Extension trait for RelType to add a `name` method. +pub trait RelTypeExt { + /// Returns the name of the RelType as a string. + fn name(&self) -> &'static str; +} + +impl RelTypeExt for RelType { + fn name(&self) -> &'static str { + match self { + Self::Read(_) => "Read", + Self::Project(_) => "Project", + Self::Filter(_) => "Filter", + Self::Join(_) => "Join", + Self::SetOp(_) => "SetOp", + Self::Sort(_) => "Sort", + Self::Limit(_) => "Limit", + Self::Aggregate(_) => "Aggregate", + Self::Sql(_) => "Sql", + Self::LocalRelation(_) => "LocalRelation", + Self::Sample(_) => "Sample", + Self::Offset(_) => "Offset", + Self::Deduplicate(_) => "Deduplicate", + Self::Range(_) => "Range", + Self::SubqueryAlias(_) => "SubqueryAlias", + Self::Repartition(_) => "Repartition", + Self::ToDf(_) => "ToDf", + Self::WithColumnsRenamed(_) => "WithColumnsRenamed", + Self::ShowString(_) => "ShowString", + Self::Drop(_) => "Drop", + Self::Tail(_) => "Tail", + Self::WithColumns(_) => "WithColumns", + Self::Hint(_) => "Hint", + Self::Unpivot(_) => "Unpivot", + Self::ToSchema(_) => "ToSchema", + Self::RepartitionByExpression(_) => "RepartitionByExpression", + Self::MapPartitions(_) => "MapPartitions", + Self::CollectMetrics(_) => "CollectMetrics", + Self::Parse(_) => "Parse", + Self::GroupMap(_) => "GroupMap", + Self::CoGroupMap(_) => "CoGroupMap", + Self::WithWatermark(_) => "WithWatermark", + Self::ApplyInPandasWithState(_) => "ApplyInPandasWithState", + Self::HtmlString(_) => "HtmlString", + Self::CachedLocalRelation(_) => "CachedLocalRelation", + Self::CachedRemoteRelation(_) => "CachedRemoteRelation", + Self::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction", + Self::AsOfJoin(_) => "AsOfJoin", + Self::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource", + Self::WithRelations(_) => "WithRelations", + Self::Transpose(_) => "Transpose", + Self::FillNa(_) => "FillNa", + Self::DropNa(_) => "DropNa", + Self::Replace(_) => "Replace", + Self::Summary(_) => "Summary", + Self::Crosstab(_) => "Crosstab", + Self::Describe(_) => "Describe", + Self::Cov(_) => "Cov", + Self::Corr(_) => "Corr", + Self::ApproxQuantile(_) => "ApproxQuantile", + 
Self::FreqItems(_) => "FreqItems", + Self::SampleBy(_) => "SampleBy", + Self::Catalog(_) => "Catalog", + Self::Extension(_) => "Extension", + Self::Unknown(_) => "Unknown", + } + } +} diff --git a/src/daft-connect/src/convert/plan_conversion.rs b/src/daft-connect/src/convert/plan_conversion.rs new file mode 100644 index 0000000000..8fa9b8a7b3 --- /dev/null +++ b/src/daft-connect/src/convert/plan_conversion.rs @@ -0,0 +1,134 @@ +use std::{collections::HashSet, sync::Arc}; + +use daft_plan::{LogicalPlanBuilder, ParquetScanBuilder}; +use eyre::{bail, eyre, Result, WrapErr}; +use spark_connect::{ + expression::Alias, + read::{DataSource, ReadType}, + relation::RelType, + Filter, Read, Relation, WithColumns, +}; +use tracing::warn; + +use crate::convert::expression; + +pub fn to_logical_plan(plan: Relation) -> Result { + let scope = std::thread::spawn(|| { + let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?; + + match rel_type { + RelType::ShowString(..) => { + bail!("ShowString is only supported as a top-level relation") + } + RelType::Filter(filter) => parse_filter(*filter).wrap_err("parsing Filter"), + RelType::WithColumns(with_columns) => { + parse_with_columns(*with_columns).wrap_err("parsing WithColumns") + } + RelType::Read(read) => parse_read(read), + _ => bail!("Unsupported relation type: {rel_type:?}"), + } + }); + + scope.join().unwrap() +} + +fn parse_filter(filter: Filter) -> Result { + let Filter { input, condition } = filter; + let input = *input.ok_or_else(|| eyre!("input is None"))?; + let input_plan = to_logical_plan(input).wrap_err("parsing input")?; + + let condition = condition.ok_or_else(|| eyre!("condition is None"))?; + let condition = + expression::convert_expression(condition).wrap_err("converting to daft expression")?; + let condition = Arc::new(condition); + + input_plan.filter(condition).wrap_err("applying filter") +} + +fn parse_with_columns(with_columns: WithColumns) -> Result { + let WithColumns { input, aliases } = with_columns; + let input = *input.ok_or_else(|| eyre!("input is None"))?; + let input_plan = to_logical_plan(input).wrap_err("parsing input")?; + + let mut new_exprs = Vec::new(); + let mut existing_columns: HashSet<_> = input_plan.schema().names().into_iter().collect(); + + for alias in aliases { + let Alias { + expr, + name, + metadata, + } = alias; + + if name.len() != 1 { + bail!("Alias name must have exactly one element"); + } + let name = name[0].as_str(); + + if metadata.is_some() { + bail!("Metadata is not yet supported"); + } + + let expr = expr.ok_or_else(|| eyre!("expression is None"))?; + let expr = + expression::convert_expression(*expr).wrap_err("converting to daft expression")?; + let expr = Arc::new(expr); + + new_exprs.push(expr.alias(name)); + + if existing_columns.contains(name) { + existing_columns.remove(name); + } + } + + // Add remaining existing columns + for col_name in existing_columns { + new_exprs.push(daft_dsl::col(col_name)); + } + + input_plan + .select(new_exprs) + .wrap_err("selecting new expressions") +} + +fn parse_read(read: Read) -> Result { + let Read { + is_streaming, + read_type, + } = read; + + warn!("Ignoring is_streaming: {is_streaming}"); + + let read_type = read_type.ok_or_else(|| eyre!("type is None"))?; + + match read_type { + ReadType::NamedTable(_) => bail!("Named tables are not yet supported"), + ReadType::DataSource(data_source) => parse_data_source(data_source), + } +} + +fn parse_data_source(data_source: DataSource) -> Result { + let DataSource { + format, + options, + paths, + 
predicates, + .. + } = data_source; + + let format = format.ok_or_else(|| eyre!("format is None"))?; + if format != "parquet" { + bail!("Only parquet is supported; got {format}"); + } + + if !options.is_empty() { + bail!("Options are not yet supported"); + } + if !predicates.is_empty() { + bail!("Predicates are not yet supported"); + } + + ParquetScanBuilder::new(paths) + .finish() + .wrap_err("creating ParquetScanBuilder") +} diff --git a/src/daft-connect/src/convert/schema_conversion.rs b/src/daft-connect/src/convert/schema_conversion.rs new file mode 100644 index 0000000000..dcce376b94 --- /dev/null +++ b/src/daft-connect/src/convert/schema_conversion.rs @@ -0,0 +1,56 @@ +use spark_connect::{ + data_type::{Kind, Long, Struct, StructField}, + relation::RelType, + DataType, Relation, +}; + +#[tracing::instrument(skip_all)] +pub fn connect_schema(input: Relation) -> Result { + if input.common.is_some() { + tracing::warn!("We do not currently look at common fields"); + } + + let result = match input + .rel_type + .ok_or_else(|| tonic::Status::internal("rel_type is None"))? + { + RelType::Range(spark_connect::Range { num_partitions, .. }) => { + if num_partitions.is_some() { + return Err(tonic::Status::unimplemented( + "num_partitions is not supported", + )); + } + + let long = Long { + type_variation_reference: 0, + }; + + let id_field = StructField { + name: "id".to_string(), + data_type: Some(DataType { + kind: Some(Kind::Long(long)), + }), + nullable: false, + metadata: None, + }; + + let fields = vec![id_field]; + + let strct = Struct { + fields, + type_variation_reference: 0, + }; + + DataType { + kind: Some(Kind::Struct(strct)), + } + } + other => { + return Err(tonic::Status::unimplemented(format!( + "Unsupported relation type: {other:?}" + ))) + } + }; + + Ok(result) +} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs new file mode 100644 index 0000000000..8e021cd66d --- /dev/null +++ b/src/daft-connect/src/lib.rs @@ -0,0 +1,449 @@ +#![feature(iterator_try_collect)] +#![feature(let_chains)] +#![feature(try_trait_v2)] +#![feature(coroutines)] +#![feature(iter_from_coroutine)] +#![feature(stmt_expr_attributes)] +#![feature(try_trait_v2_residual)] + +use std::ops::ControlFlow; + +use dashmap::DashMap; +use eyre::Context; +#[cfg(feature = "python")] +use pyo3::types::PyModuleMethods; +use ron::extensions::Extensions; +use spark_connect::{ + analyze_plan_response, + command::CommandType, + plan::OpType, + spark_connect_service_server::{SparkConnectService, SparkConnectServiceServer}, + AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, + ArtifactStatusesRequest, ArtifactStatusesResponse, ConfigRequest, ConfigResponse, + ExecutePlanRequest, ExecutePlanResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse, + InterruptRequest, InterruptResponse, Plan, ReattachExecuteRequest, ReleaseExecuteRequest, + ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, +}; +use tonic::{transport::Server, Request, Response, Status}; +use tracing::{info, warn}; +use uuid::Uuid; + +use crate::{convert::map_to_tables, session::Session}; + +mod command; +mod config; +mod convert; +mod session; +pub mod util; + +pub fn start(addr: &str) -> eyre::Result<()> { + info!("Daft-Connect server listening on {addr}"); + let addr = util::parse_spark_connect_address(addr)?; + + let service = DaftSparkConnectService::default(); + + info!("Daft-Connect server listening on {addr}"); + + std::thread::spawn(move || { + let runtime = 
tokio::runtime::Runtime::new().unwrap(); + let result = runtime + .block_on(async { + Server::builder() + .add_service(SparkConnectServiceServer::new(service)) + .serve(addr) + .await + }) + .wrap_err_with(|| format!("Failed to start server on {addr}")); + + if let Err(e) = result { + eprintln!("Daft-Connect server error: {e:?}"); + } + + println!("done with runtime"); + + eyre::Result::<_>::Ok(()) + }); + + Ok(()) +} + +#[derive(Default)] +pub struct DaftSparkConnectService { + client_to_session: DashMap, // To track session data +} + +impl DaftSparkConnectService { + fn get_session( + &self, + session_id: &str, + ) -> Result, Status> { + let Ok(uuid) = Uuid::parse_str(session_id) else { + return Err(Status::invalid_argument( + "Invalid session_id format, must be a UUID", + )); + }; + + let res = self + .client_to_session + .entry(uuid) + .or_insert_with(|| Session::new(session_id.to_string())); + + Ok(res) + } +} + +fn pretty_config() -> ron::ser::PrettyConfig { + ron::ser::PrettyConfig::default() + .extensions( + Extensions::IMPLICIT_SOME + | Extensions::UNWRAP_NEWTYPES + | Extensions::UNWRAP_VARIANT_NEWTYPES, + ) + .indentor(" ".to_string()) +} + +#[tonic::async_trait] +impl SparkConnectService for DaftSparkConnectService { + type ExecutePlanStream = std::pin::Pin< + Box< + dyn futures::Stream> + Send + Sync + 'static, + >, + >; + type ReattachExecuteStream = std::pin::Pin< + Box< + dyn futures::Stream> + Send + Sync + 'static, + >, + >; + + #[tracing::instrument(skip_all)] + async fn execute_plan( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let session = self.get_session(&request.session_id)?; + + let operation = request + .operation_id + .ok_or_else(|| invalid_argument!("Operation ID is required"))?; + + // Proceed with executing the plan... + let plan = request + .plan + .ok_or_else(|| invalid_argument!("Plan is required"))?; + let plan = plan + .op_type + .ok_or_else(|| invalid_argument!("Plan operation is required"))?; + + use spark_connect::plan::OpType; + + match plan { + OpType::Root(relation) => { + let result = session.handle_root_command(relation, operation).await?; + return Ok(Response::new(result)); + } + OpType::Command(command) => { + let command = command + .command_type + .ok_or_else(|| invalid_argument!("Command type is required"))?; + + match command { + CommandType::RegisterFunction(_) => { + Err(unimplemented!("RegisterFunction not implemented")) + } + CommandType::WriteOperation(_) => { + Err(unimplemented!("WriteOperation not implemented")) + } + CommandType::CreateDataframeView(_) => { + Err(unimplemented!("CreateDataframeView not implemented")) + } + CommandType::WriteOperationV2(_) => { + Err(unimplemented!("WriteOperationV2 not implemented")) + } + CommandType::SqlCommand(..) 
=> { + Err(unimplemented!("SQL execution not yet implemented")) + } + CommandType::WriteStreamOperationStart(_) => { + Err(unimplemented!("WriteStreamOperationStart not implemented")) + } + CommandType::StreamingQueryCommand(_) => { + Err(unimplemented!("StreamingQueryCommand not implemented")) + } + CommandType::GetResourcesCommand(_) => { + Err(unimplemented!("GetResourcesCommand not implemented")) + } + CommandType::StreamingQueryManagerCommand(_) => Err(unimplemented!( + "StreamingQueryManagerCommand not implemented" + )), + CommandType::RegisterTableFunction(_) => { + Err(unimplemented!("RegisterTableFunction not implemented")) + } + CommandType::StreamingQueryListenerBusCommand(_) => Err(unimplemented!( + "StreamingQueryListenerBusCommand not implemented" + )), + CommandType::RegisterDataSource(_) => { + Err(unimplemented!("RegisterDataSource not implemented")) + } + CommandType::CreateResourceProfileCommand(_) => Err(unimplemented!( + "CreateResourceProfileCommand not implemented" + )), + CommandType::CheckpointCommand(_) => { + Err(unimplemented!("CheckpointCommand not implemented")) + } + CommandType::RemoveCachedRemoteRelationCommand(_) => Err(unimplemented!( + "RemoveCachedRemoteRelationCommand not implemented" + )), + CommandType::MergeIntoTableCommand(_) => { + Err(unimplemented!("MergeIntoTableCommand not implemented")) + } + CommandType::Extension(_) => Err(unimplemented!("Extension not implemented")), + } + } + }?; + + Err(unimplemented!("Unsupported plan type")) + } + + #[tracing::instrument(skip_all)] + async fn config( + &self, + request: Request, + ) -> Result, Status> { + println!("got config"); + let request = request.into_inner(); + + let mut session = self.get_session(&request.session_id)?; + + let Some(operation) = request.operation.and_then(|op| op.op_type) else { + return Err(Status::invalid_argument("Missing operation")); + }; + + use spark_connect::config_request::operation::OpType; + + let response = match operation { + OpType::Set(op) => session.set(op), + OpType::Get(op) => session.get(op), + OpType::GetWithDefault(op) => session.get_with_default(op), + OpType::GetOption(op) => session.get_option(op), + OpType::GetAll(op) => session.get_all(op), + OpType::Unset(op) => session.unset(op), + OpType::IsModifiable(op) => session.is_modifiable(op), + }?; + + info!("Response: {response:?}"); + + Ok(Response::new(response)) + } + + #[tracing::instrument(skip_all)] + async fn add_artifacts( + &self, + _request: Request>, + ) -> Result, Status> { + Err(unimplemented!( + "add_artifacts operation is not yet implemented" + )) + } + + #[tracing::instrument(skip_all)] + async fn analyze_plan( + &self, + request: Request, + ) -> Result, Status> { + use spark_connect::analyze_plan_request::*; + let request = request.into_inner(); + + let AnalyzePlanRequest { + session_id, + analyze, + .. 
+        } = request;
+
+        let Some(analyze) = analyze else {
+            return Err(Status::invalid_argument("analyze is required"));
+        };
+
+        match analyze {
+            Analyze::Schema(Schema { plan }) => {
+                let Some(Plan { op_type }) = plan else {
+                    return Err(Status::invalid_argument("plan is required"));
+                };
+
+                let Some(OpType::Root(relation)) = op_type else {
+                    return Err(Status::invalid_argument("op_type is required to be root"));
+                };
+
+                let result = convert::connect_schema(relation)?;
+
+                let schema = analyze_plan_response::DdlParse {
+                    parsed: Some(result),
+                };
+
+                let response = AnalyzePlanResponse {
+                    session_id,
+                    server_side_session_id: String::new(),
+                    result: Some(analyze_plan_response::Result::DdlParse(schema)),
+                };
+
+                println!("response: {response:#?}");
+
+                Ok(Response::new(response))
+            }
+            Analyze::TreeString(tree_string) => {
+                if let Some(level) = tree_string.level {
+                    warn!("Ignoring level {level} in TreeString");
+                }
+
+                let Some(plan) = tree_string.plan else {
+                    return Err(invalid_argument!("TreeString must have a plan"));
+                };
+
+                let Some(op_type) = plan.op_type else {
+                    return Err(invalid_argument!("plan must have an op_type"));
+                };
+
+                println!("op_type: {op_type:?}");
+
+                let OpType::Root(plan) = op_type else {
+                    return Err(invalid_argument!("Only op_type Root is supported"));
+                };
+
+                let logical_plan = match convert::to_logical_plan(plan) {
+                    Ok(lp) => lp,
+                    Err(e) => {
+                        return Err(invalid_argument!(
+                            "Failed to convert to logical plan: {e:?}"
+                        ));
+                    }
+                };
+
+                let logical_plan = logical_plan.build();
+
+                let res = std::thread::spawn(move || {
+                    let result = map_to_tables(
+                        &logical_plan,
+                        |table| {
+                            let table = format!("{table}");
+                            ControlFlow::Break(table)
+                        },
+                        || ControlFlow::Continue(()),
+                    )
+                    .unwrap();
+
+                    let result = match result {
+                        ControlFlow::Break(x) => Some(x),
+                        ControlFlow::Continue(()) => None,
+                    }
+                    .unwrap();
+
+                    AnalyzePlanResponse {
+                        session_id,
+                        server_side_session_id: String::new(),
+                        result: Some(analyze_plan_response::Result::TreeString(
+                            analyze_plan_response::TreeString {
+                                tree_string: result,
+                            },
+                        )),
+                    }
+                });
+
+                let res = res.join().unwrap();
+
+                let response = Response::new(res);
+                Ok(response)
+            }
+            _ => Err(unimplemented!(
+                "Analyze plan operation is not yet implemented"
+            )),
+        }
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn artifact_status(
+        &self,
+        _request: Request<ArtifactStatusesRequest>,
+    ) -> Result<Response<ArtifactStatusesResponse>, Status> {
+        println!("got artifact status");
+        Err(unimplemented!(
+            "artifact_status operation is not yet implemented"
+        ))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn interrupt(
+        &self,
+        _request: Request<InterruptRequest>,
+    ) -> Result<Response<InterruptResponse>, Status> {
+        println!("got interrupt");
+        Err(unimplemented!("interrupt operation is not yet implemented"))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn reattach_execute(
+        &self,
+        request: Request<ReattachExecuteRequest>,
+    ) -> Result<Response<Self::ReattachExecuteStream>, Status> {
+        warn!("reattach_execute operation is not yet implemented");
+
+        let singleton_stream = futures::stream::once(async {
+            Err(Status::unimplemented(
+                "reattach_execute operation is not yet implemented",
+            ))
+        });
+
+        Ok(Response::new(Box::pin(singleton_stream)))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn release_execute(
+        &self,
+        request: Request<ReleaseExecuteRequest>,
+    ) -> Result<Response<ReleaseExecuteResponse>, Status> {
+        let request = request.into_inner();
+
+        let session = self.get_session(&request.session_id)?;
+
+        let response = ReleaseExecuteResponse {
+            session_id: session.id().to_string(),
+            server_side_session_id: session.server_side_session_id().to_string(),
+            operation_id: Some(request.operation_id), // todo: impl properly
+        };
+
+        Ok(Response::new(response))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn release_session(
+        &self,
+        _request: Request<ReleaseSessionRequest>,
+    ) -> Result<Response<ReleaseSessionResponse>, Status> {
+        println!("got release session");
+        Err(unimplemented!(
+            "release_session operation is not yet implemented"
+        ))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn fetch_error_details(
+        &self,
+        _request: Request<FetchErrorDetailsRequest>,
+    ) -> Result<Response<FetchErrorDetailsResponse>, Status> {
+        println!("got fetch error details");
+        Err(unimplemented!(
+            "fetch_error_details operation is not yet implemented"
+        ))
+    }
+}
+#[cfg(feature = "python")]
+#[pyo3::pyfunction]
+#[pyo3(name = "connect_start")]
+pub fn py_connect_start(addr: &str) -> pyo3::PyResult<()> {
+    start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}")))
+}
+
+#[cfg(feature = "python")]
+pub fn register_modules(parent: &pyo3::Bound<pyo3::types::PyModule>) -> pyo3::PyResult<()> {
+    parent.add_function(pyo3::wrap_pyfunction_bound!(py_connect_start, parent)?)?;
+    Ok(())
+}
diff --git a/src/daft-connect/src/main.rs b/src/daft-connect/src/main.rs
new file mode 100644
index 0000000000..6dbe5dac6c
--- /dev/null
+++ b/src/daft-connect/src/main.rs
@@ -0,0 +1,32 @@
+use daft_connect::DaftSparkConnectService;
+use spark_connect::spark_connect_service_server::SparkConnectServiceServer;
+use tonic::transport::Server;
+use tracing_subscriber::layer::SubscriberExt;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing::subscriber::set_global_default(
+        tracing_subscriber::registry()
+            .with(tracing_subscriber::EnvFilter::new("info"))
+            .with(tracing_tracy::TracyLayer::default()),
+    )
+    .expect("setup tracy layer");
+
+    let addr = "[::1]:50051".parse()?;
+    let service = DaftSparkConnectService::default();
+
+    println!("Daft-Connect server listening on {}", addr);
+
+    tokio::select! {
+        result = Server::builder()
+            .add_service(SparkConnectServiceServer::new(service))
+            .serve(addr) => {
+            result?;
+        }
+        _ = tokio::signal::ctrl_c() => {
+            println!("\nReceived Ctrl-C, gracefully shutting down server");
+        }
+    }
+
+    Ok(())
+}
diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs
new file mode 100644
index 0000000000..1be05b3948
--- /dev/null
+++ b/src/daft-connect/src/session.rs
@@ -0,0 +1,48 @@
+use std::collections::{BTreeMap, HashMap};
+
+use uuid::Uuid;
+
+pub struct Session {
+    /// so order is preserved, and so we can efficiently do a prefix search
+    ///
+    /// Also,
+    config_values: BTreeMap<String, String>,
+
+    #[expect(
+        unused,
+        reason = "this will be used in the future especially to pass spark connect tests"
+    )]
+    tables_by_name: HashMap,
+
+    id: String,
+    server_side_session_id: String,
+}
+
+impl Session {
+    pub fn config_values(&self) -> &BTreeMap<String, String> {
+        &self.config_values
+    }
+
+    pub fn config_values_mut(&mut self) -> &mut BTreeMap<String, String> {
+        &mut self.config_values
+    }
+
+    pub fn new(id: String) -> Self {
+        let server_side_session_id = Uuid::new_v4();
+        let server_side_session_id = server_side_session_id.to_string();
+        Self {
+            config_values: Default::default(),
+            tables_by_name: Default::default(),
+            id,
+            server_side_session_id,
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        &self.id
+    }
+
+    pub fn server_side_session_id(&self) -> &str {
+        &self.server_side_session_id
+    }
+}
diff --git a/src/daft-connect/src/util.rs b/src/daft-connect/src/util.rs
new file mode 100644
index 0000000000..29a593f342
--- /dev/null
+++ b/src/daft-connect/src/util.rs
@@ -0,0 +1,109 @@
+use std::net::ToSocketAddrs;
+
+#[macro_export]
+macro_rules! invalid_argument {
+    ($arg: tt) => {{
+        let msg = format!($arg);
+        ::tonic::Status::invalid_argument(msg)
+    }};
+}
+
+#[macro_export]
+macro_rules! unimplemented {
+    ($arg: tt) => {{
+        let msg = format!($arg);
+        ::tonic::Status::unimplemented(msg)
+    }};
+}
+
+pub fn parse_spark_connect_address(addr: &str) -> eyre::Result<std::net::SocketAddr> {
+    // Check if address starts with "sc://"
+    if !addr.starts_with("sc://") {
+        return Err(eyre::eyre!("Address must start with 'sc://'"));
+    }
+
+    // Remove the "sc://" prefix
+    let addr = addr.trim_start_matches("sc://");
+
+    // Resolve the hostname using the standard library's blocking resolver
+    let addrs = addr.to_socket_addrs()?;
+
+    // Take the first resolved address
+    addrs
+        .into_iter()
+        .next()
+        .ok_or_else(|| eyre::eyre!("No addresses found for hostname"))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_spark_connect_address_valid() {
+        let addr = "sc://localhost:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_missing_prefix() {
+        let addr = "localhost:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("must start with 'sc://'"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_invalid_port() {
+        let addr = "sc://localhost:invalid";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_missing_port() {
+        let addr = "sc://localhost";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_empty() {
+        let addr = "";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("must start with 'sc://'"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_only_prefix() {
+        let addr = "sc://";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Invalid address format"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_ipv4() {
+        let addr = "sc://127.0.0.1:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_ipv6() {
+        let addr = "sc://[::1]:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+}
diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs
index 553ad18b40..719da409c4 100644
--- a/src/daft-local-execution/src/lib.rs
+++ b/src/daft-local-execution/src/lib.rs
@@ -11,7 +11,7 @@ mod sources;
 
 use common_error::{DaftError, DaftResult};
 use lazy_static::lazy_static;
-pub use run::NativeExecutor;
+pub use run::{run_local, NativeExecutor};
 use snafu::{futures::TryFutureExt, Snafu};
 
 lazy_static! {
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index 0a450ace70..c68546f318 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -99,6 +99,12 @@ impl Table {
         Ok(Self::new_unchecked(schema, columns?, num_rows))
     }
 
+    pub fn get_inner_arrow_arrays(
+        &self,
+    ) -> impl Iterator<Item = Box<dyn arrow2::array::Array>> + '_ {
+        self.columns.iter().map(|s| s.to_arrow())
+    }
+
     /// Create a new [`Table`] and validate against `num_rows`
     ///
     /// Note that this function is slow. You might instead be looking for [`Table::new_unchecked`] if you've already performed your own validation logic.
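The `sc://` scheme enforced by `parse_spark_connect_address` above is the same connection string a Spark Connect client uses. A minimal client-side sketch, assuming the daft-connect server from main.rs is already running and reachable at localhost:50051:

    # Minimal sketch, assuming a daft-connect server is listening on localhost:50051.
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder.appName("daft-connect-example")
        # must use the sc:// scheme checked by parse_spark_connect_address
        .remote("sc://localhost:50051")
        .getOrCreate()
    )

    spark.range(5).show()  # Range and show_string conversions exist in this crate
    spark.stop()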
@@ -194,6 +200,10 @@ impl Table { self.num_rows } + pub fn num_rows(&self) -> usize { + self.num_rows + } + pub fn is_empty(&self) -> bool { self.len() == 0 } diff --git a/tests/connect/__init__.py b/tests/connect/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/connect/conf.py b/tests/connect/conf.py new file mode 100644 index 0000000000..2ebd63ede6 --- /dev/null +++ b/tests/connect/conf.py @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) + +import warnings +from typing import Any, Dict, Optional, Union, cast + +from pyspark import _NoValue +from pyspark._globals import _NoValueType +from pyspark.sql.conf import RuntimeConfig as PySparkRuntimeConfig +from pyspark.sql.connect import proto +from pyspark.sql.connect.client import SparkConnectClient + + +class RuntimeConf: + def __init__(self, client: SparkConnectClient) -> None: + """Create a new RuntimeConfig.""" + self._client = client + + __init__.__doc__ = PySparkRuntimeConfig.__init__.__doc__ + + def set(self, key: str, value: Union[str, int, bool]) -> None: + if isinstance(value, bool): + value = "true" if value else "false" + elif isinstance(value, int): + value = str(value) + op_set = proto.ConfigRequest.Set(pairs=[proto.KeyValue(key=key, value=value)]) + operation = proto.ConfigRequest.Operation(set=op_set) + result = self._client.config(operation) + for warn in result.warnings: + warnings.warn(warn) + + set.__doc__ = PySparkRuntimeConfig.set.__doc__ + + def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> Optional[str]: + self._checkType(key, "key") + if default is _NoValue: + op_get = proto.ConfigRequest.Get(keys=[key]) + operation = proto.ConfigRequest.Operation(get=op_get) + else: + if default is not None: + self._checkType(default, "default") + op_get_with_default = proto.ConfigRequest.GetWithDefault( + pairs=[proto.KeyValue(key=key, value=cast(Optional[str], default))] + ) + operation = proto.ConfigRequest.Operation(get_with_default=op_get_with_default) + result = self._client.config(operation) + return result.pairs[0][1] + + get.__doc__ = PySparkRuntimeConfig.get.__doc__ + + @property + def getAll(self) -> Dict[str, str]: + op_get_all = proto.ConfigRequest.GetAll() + operation = proto.ConfigRequest.Operation(get_all=op_get_all) + result = self._client.config(operation) + confs: Dict[str, str] = dict() + for key, value in result.pairs: + assert value is not None + confs[key] = value + return confs + + getAll.__doc__ = PySparkRuntimeConfig.getAll.__doc__ + + def unset(self, key: str) -> None: + op_unset = proto.ConfigRequest.Unset(keys=[key]) + operation = 
proto.ConfigRequest.Operation(unset=op_unset) + result = self._client.config(operation) + for warn in result.warnings: + warnings.warn(warn) + + unset.__doc__ = PySparkRuntimeConfig.unset.__doc__ + + def isModifiable(self, key: str) -> bool: + op_is_modifiable = proto.ConfigRequest.IsModifiable(keys=[key]) + operation = proto.ConfigRequest.Operation(is_modifiable=op_is_modifiable) + result = self._client.config(operation).pairs[0][1] + if result == "true": + return True + elif result == "false": + return False + else: + raise PySparkValueError( + errorClass="VALUE_NOT_ALLOWED", + messageParameters={"arg_name": "result", "allowed_values": "'true' or 'false'"}, + ) + + isModifiable.__doc__ = PySparkRuntimeConfig.isModifiable.__doc__ + + def _checkType(self, obj: Any, identifier: str) -> None: + """Assert that an object is of type str.""" + if not isinstance(obj, str): + raise PySparkTypeError( + errorClass="NOT_STR", + messageParameters={ + "arg_name": identifier, + "arg_type": type(obj).__name__, + }, + ) + + +RuntimeConf.__doc__ = PySparkRuntimeConfig.__doc__ + + +def _test() -> None: + import doctest + import sys + + import pyspark.sql.connect.conf + from pyspark.sql import SparkSession as PySparkSession + + globs = pyspark.sql.connect.conf.__dict__.copy() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.conf tests") + # .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .remote("127.0.0.1:50051") + .getOrCreate() + ) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.connect.conf, + globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL, + ) + + globs["spark"].stop() + + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/tests/connect/test_client.py b/tests/connect/test_client.py new file mode 100644 index 0000000000..f0528b466d --- /dev/null +++ b/tests/connect/test_client.py @@ -0,0 +1,434 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
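Each RuntimeConf method above maps onto a single ConfigRequest operation handled by the server's config endpoint. A short sketch of that round trip, assuming `spark` is an active Spark Connect session backed by daft-connect (the same calls are exercised by test_conf.py below):

    # Sketch of the config round trip; `spark` is assumed to be a live Connect session.
    spark.conf.set("bogo", "sipeo")                             # ConfigRequest.Set
    assert spark.conf.get("bogo") == "sipeo"                    # ConfigRequest.Get
    assert spark.conf.get("not.set", "fallback") == "fallback"  # ConfigRequest.GetWithDefault
    spark.conf.unset("bogo")                                    # ConfigRequest.Unset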
+# + +import unittest +import uuid +from collections.abc import Generator +from typing import Any, Optional, Union + +from pyspark.testing.connectutils import connect_requirement_message, should_test_connect +from pyspark.testing.utils import eventually + +if should_test_connect: + import grpc + import pandas as pd + import pyarrow as pa + import pyspark.sql.connect.proto as proto + from google.rpc import status_pb2 + from pyspark.errors import PySparkRuntimeError, RetriesExceeded + from pyspark.sql.connect.client import DefaultChannelBuilder, SparkConnectClient + from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator + from pyspark.sql.connect.client.retries import ( + DefaultPolicy, + Retrying, + ) + + class TestPolicy(DefaultPolicy): + def __init__(self): + super().__init__( + max_retries=3, + backoff_multiplier=4.0, + initial_backoff=10, + max_backoff=10, + jitter=10, + min_jitter_threshold=10, + ) + + class TestException(grpc.RpcError, grpc.Call): + """Exception mock to test retryable exceptions.""" + + def __init__( + self, + msg, + code=grpc.StatusCode.INTERNAL, + trailing_status: Union[status_pb2.Status, None] = None, + ): + self.msg = msg + self._code = code + self._trailer: dict[str, Any] = {} + if trailing_status is not None: + self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString() + + def code(self): + return self._code + + def __str__(self): + return self.msg + + def details(self): + return self.msg + + def trailing_metadata(self): + return None if not self._trailer else self._trailer.items() + + class ResponseGenerator(Generator): + """This class is used to generate values that are returned by the streaming + iterator of the GRPC stub.""" + + def __init__(self, funs): + self._funs = funs + self._iterator = iter(self._funs) + + def send(self, value: Any) -> proto.ExecutePlanResponse: + val = next(self._iterator) + if callable(val): + return val() + else: + return val + + def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any: + super().throw(type, value, traceback) + + def close(self) -> None: + return super().close() + + class MockSparkConnectStub: + """Simple mock class for the GRPC stub used by the re-attachable execution.""" + + def __init__(self, execute_ops=None, attach_ops=None): + self._execute_ops = execute_ops + self._attach_ops = attach_ops + # Call counters + self.execute_calls = 0 + self.release_calls = 0 + self.release_until_calls = 0 + self.attach_calls = 0 + + def ExecutePlan(self, *args, **kwargs): + self.execute_calls += 1 + return self._execute_ops + + def ReattachExecute(self, *args, **kwargs): + self.attach_calls += 1 + return self._attach_ops + + def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): + if req.HasField("release_all"): + self.release_calls += 1 + elif req.HasField("release_until"): + print("increment") + self.release_until_calls += 1 + + class MockService: + # Simplest mock of the SparkConnectService. + # If this needs more complex logic, it needs to be replaced with Python mocking. 
+ + req: Optional[proto.ExecutePlanRequest] + + def __init__(self, session_id: str): + self._session_id = session_id + self.req = None + + def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): + self.req = req + resp = proto.ExecutePlanResponse() + resp.session_id = self._session_id + + pdf = pd.DataFrame(data={"col1": [1, 2]}) + schema = pa.Schema.from_pandas(pdf) + table = pa.Table.from_pandas(pdf) + sink = pa.BufferOutputStream() + + writer = pa.ipc.new_stream(sink, schema=schema) + writer.write(table) + writer.close() + + buf = sink.getvalue() + resp.arrow_batch.data = buf.to_pybytes() + resp.arrow_batch.row_count = 2 + return [resp] + + def Interrupt(self, req: proto.InterruptRequest, metadata): + self.req = req + resp = proto.InterruptResponse() + resp.session_id = self._session_id + return resp + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class SparkConnectClientTestCase(unittest.TestCase): + def test_user_agent_passthrough(self): + client = SparkConnectClient("sc://foo/;user_agent=bar", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + command = proto.Command() + client.execute_command(command) + + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") + + def test_user_agent_default(self): + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + command = proto.Command() + client.execute_command(command) + + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertRegex(mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$") + + def test_properties(self): + client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) + self.assertEqual(client.token, "bar") + self.assertEqual(client.host, "foo") + + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) + self.assertIsNone(client.token) + + def test_channel_builder(self): + class CustomChannelBuilder(DefaultChannelBuilder): + @property + def userId(self) -> Optional[str]: + return "abc" + + client = SparkConnectClient(CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False) + + self.assertEqual(client._user_id, "abc") + + def test_interrupt_all(self): + client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + client.interrupt_all() + self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") + + def test_is_closed(self): + client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) + + self.assertFalse(client.is_closed) + client.close() + self.assertTrue(client.is_closed) + + def test_retry(self): + client = SparkConnectClient("sc://foo/;token=bar") + + total_sleep = 0 + + def sleep(t): + nonlocal total_sleep + total_sleep += t + + try: + for attempt in Retrying(client._retry_policies, sleep=sleep): + with attempt: + raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE) + except RetriesExceeded: + pass + + # tolerated at least 10 mins of fails + self.assertGreaterEqual(total_sleep, 600) + + def test_retry_client_unit(self): + client = SparkConnectClient("sc://foo/;token=bar") + + policyA = TestPolicy() + policyB = DefaultPolicy() + + client.set_retry_policies([policyA, policyB]) + + 
self.assertEqual(client.get_retry_policies(), [policyA, policyB]) + + def test_channel_builder_with_session(self): + dummy = str(uuid.uuid4()) + chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}") + client = SparkConnectClient(chan) + self.assertEqual(client._session_id, chan.session_id) + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class SparkConnectClientReattachTestCase(unittest.TestCase): + def setUp(self) -> None: + self.request = proto.ExecutePlanRequest() + self.retrying = lambda: Retrying(TestPolicy()) + self.response = proto.ExecutePlanResponse( + response_id="1", + ) + self.finished = proto.ExecutePlanResponse( + result_complete=proto.ExecutePlanResponse.ResultComplete(), + response_id="2", + ) + + def _stub_with(self, execute=None, attach=None): + return MockSparkConnectStub( + execute_ops=ResponseGenerator(execute) if execute is not None else None, + attach_ops=ResponseGenerator(attach) if attach is not None else None, + ) + + def test_basic_flow(self): + stub = self._stub_with([self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for b in ite: + pass + + def check_all(): + self.assertEqual(0, stub.attach_calls) + self.assertEqual(1, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + eventually(timeout=1, catch_assertions=True)(check_all)() + + def test_fail_during_execute(self): + def fatal(): + raise TestException("Fatal") + + stub = self._stub_with([self.response, fatal]) + with self.assertRaises(TestException): + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for b in ite: + pass + + def check(): + self.assertEqual(0, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.release_until_calls) + self.assertEqual(1, stub.execute_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + + def test_fail_and_retry_during_execute(self): + def non_fatal(): + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + + stub = self._stub_with([self.response, non_fatal], [self.response, self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for b in ite: + pass + + def check(): + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(3, stub.release_until_calls) + self.assertEqual(1, stub.execute_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + + def test_fail_and_retry_during_reattach(self): + count = 0 + + def non_fatal(): + nonlocal count + if count < 2: + count += 1 + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + else: + return proto.ExecutePlanResponse() + + stub = self._stub_with([self.response, non_fatal], [self.response, non_fatal, self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for b in ite: + pass + + def check(): + self.assertEqual(2, stub.attach_calls) + self.assertEqual(3, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + + def test_not_found_recovers(self): + """SPARK-48056: Assert that the client recovers from session or operation not + found error if no partial responses were previously received. 
+ """ + + def not_found_recovers(error_code: str): + def not_found(): + raise TestException( + error_code, + grpc.StatusCode.UNAVAILABLE, + trailing_status=status_pb2.Status(code=14, message=error_code, details=""), + ) + + stub = self._stub_with([not_found, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + + for _ in ite: + pass + + def checks(): + self.assertEqual(2, stub.execute_calls) + self.assertEqual(0, stub.attach_calls) + self.assertEqual(0, stub.release_calls) + self.assertEqual(0, stub.release_until_calls) + + eventually(timeout=1, catch_assertions=True)(checks)() + + parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] + for b in parameters: + not_found_recovers(b) + + def test_not_found_fails(self): + """SPARK-48056: Assert that the client fails from session or operation not found error + if a partial response was previously received. + """ + + def not_found_fails(error_code: str): + def not_found(): + raise TestException( + error_code, + grpc.StatusCode.UNAVAILABLE, + trailing_status=status_pb2.Status(code=14, message=error_code, details=""), + ) + + stub = self._stub_with([self.response], [not_found]) + + with self.assertRaises(PySparkRuntimeError) as e: + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for _ in ite: + pass + + self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage()) + + def checks(): + self.assertEqual(1, stub.execute_calls) + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.release_until_calls) + + eventually(timeout=1, catch_assertions=True)(checks)() + + parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] + for b in parameters: + not_found_fails(b) + + def test_observed_session_id(self): + stub = self._stub_with([self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + session_id = "test-session-id" + + reattach = ite._create_reattach_execute_request() + self.assertEqual(reattach.client_observed_server_side_session_id, "") + + self.request.client_observed_server_side_session_id = session_id + reattach = ite._create_reattach_execute_request() + self.assertEqual(reattach.client_observed_server_side_session_id, session_id) + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.client.test_client import * + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_conf.py b/tests/connect/test_conf.py new file mode 100644 index 0000000000..214e1833a1 --- /dev/null +++ b/tests/connect/test_conf.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from decimal import Decimal + +from pyspark.errors import IllegalArgumentException, PySparkTypeError +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ConfTestsMixin: + def test_conf(self): + spark = self.spark + spark.conf.set("bogo", "sipeo") + self.assertEqual(spark.conf.get("bogo"), "sipeo") + spark.conf.set("bogo", "ta") + self.assertEqual(spark.conf.get("bogo"), "ta") + self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") + self.assertEqual(spark.conf.get("not.set", "ta"), "ta") + self.assertRaisesRegex(Exception, "not.set", lambda: spark.conf.get("not.set")) + spark.conf.unset("bogo") + self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + + self.assertTrue(spark.conf.isModifiable("spark.sql.execution.arrow.maxRecordsPerBatch")) + self.assertFalse(spark.conf.isModifiable("spark.sql.warehouse.dir")) + + def test_conf_with_python_objects(self): + spark = self.spark + + try: + for value, expected in [(True, "true"), (False, "false")]: + spark.conf.set("foo", value) + self.assertEqual(spark.conf.get("foo"), expected) + + spark.conf.set("foo", 1) + self.assertEqual(spark.conf.get("foo"), "1") + + with self.assertRaises(IllegalArgumentException): + spark.conf.set("foo", None) + + with self.assertRaises(Exception): + spark.conf.set("foo", Decimal(1)) + + with self.assertRaises(PySparkTypeError) as pe: + spark.conf.get(123) + + self.check_error( + exception=pe.exception, + errorClass="NOT_STR", + messageParameters={ + "arg_name": "key", + "arg_type": "int", + }, + ) + finally: + spark.conf.unset("foo") + + def test_get_all(self): + spark = self.spark + all_confs = spark.conf.getAll + + self.assertTrue(len(all_confs) > 0) + self.assertNotIn("foo", all_confs) + + try: + spark.conf.set("foo", "bar") + updated = spark.conf.getAll + + self.assertEqual(len(updated), len(all_confs) + 1) + self.assertIn("foo", updated) + finally: + spark.conf.unset("foo") + + +class ConfTests(ConfTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + import unittest + + from pyspark.sql.tests.test_conf import * + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_config.py b/tests/connect/test_config.py new file mode 100644 index 0000000000..c1fa2f1f45 --- /dev/null +++ b/tests/connect/test_config.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import pandas as ps +from pyspark.pandas import config +from pyspark.pandas.config import DictWrapper, Option +from pyspark.testing.pandasutils import PandasOnSparkTestCase + + +class ConfigTestsMixin: + def setUp(self): + config._options_dict["test.config"] = Option(key="test.config", doc="", default="default") + + config._options_dict["test.config.list"] = Option(key="test.config.list", doc="", default=[], types=list) + config._options_dict["test.config.float"] = Option(key="test.config.float", doc="", default=1.2, types=float) + + config._options_dict["test.config.int"] = Option( + key="test.config.int", + doc="", + default=1, + types=int, + check_func=(lambda v: v > 0, "bigger then 0"), + ) + config._options_dict["test.config.int.none"] = Option( + key="test.config.int", doc="", default=None, types=(int, type(None)) + ) + + def tearDown(self): + ps.reset_option("test.config") + del config._options_dict["test.config"] + del config._options_dict["test.config.list"] + del config._options_dict["test.config.float"] + del config._options_dict["test.config.int"] + del config._options_dict["test.config.int.none"] + + def test_get_set_reset_option(self): + self.assertEqual(ps.get_option("test.config"), "default") + + ps.set_option("test.config", "value") + self.assertEqual(ps.get_option("test.config"), "value") + + ps.reset_option("test.config") + self.assertEqual(ps.get_option("test.config"), "default") + + def test_get_set_reset_option_different_types(self): + ps.set_option("test.config.list", [1, 2, 3, 4]) + self.assertEqual(ps.get_option("test.config.list"), [1, 2, 3, 4]) + + ps.set_option("test.config.float", 5.0) + self.assertEqual(ps.get_option("test.config.float"), 5.0) + + ps.set_option("test.config.int", 123) + self.assertEqual(ps.get_option("test.config.int"), 123) + + self.assertEqual(ps.get_option("test.config.int.none"), None) # default None + ps.set_option("test.config.int.none", 123) + self.assertEqual(ps.get_option("test.config.int.none"), 123) + ps.set_option("test.config.int.none", None) + self.assertEqual(ps.get_option("test.config.int.none"), None) + + def test_different_types(self): + with self.assertRaisesRegex(TypeError, "was "): + ps.set_option("test.config.list", 1) + + with self.assertRaisesRegex(TypeError, "however, expected types are"): + ps.set_option("test.config.float", "abc") + + with self.assertRaisesRegex(TypeError, "[]"): + ps.set_option("test.config.int", "abc") + + with self.assertRaisesRegex(TypeError, "(, )"): + ps.set_option("test.config.int.none", "abc") + + def test_check_func(self): + with self.assertRaisesRegex(ValueError, "bigger then 0"): + ps.set_option("test.config.int", -1) + + def test_unknown_option(self): + with self.assertRaisesRegex(config.OptionError, "No such option"): + ps.get_option("unknown") + + with self.assertRaisesRegex(config.OptionError, 
"Available options"): + ps.set_option("unknown", "value") + + with self.assertRaisesRegex(config.OptionError, "test.config"): + ps.reset_option("unknown") + + def test_namespace_access(self): + try: + self.assertEqual(ps.options.compute.max_rows, ps.get_option("compute.max_rows")) + ps.options.compute.max_rows = 0 + self.assertEqual(ps.options.compute.max_rows, 0) + self.assertTrue(isinstance(ps.options.compute, DictWrapper)) + + wrapper = ps.options.compute + self.assertEqual(wrapper.max_rows, ps.get_option("compute.max_rows")) + wrapper.max_rows = 1000 + self.assertEqual(ps.options.compute.max_rows, 1000) + + self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compu) + self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compute.max) + self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.max_rows1) + + with self.assertRaisesRegex(config.OptionError, "No such option"): + ps.options.compute.max = 0 + with self.assertRaisesRegex(config.OptionError, "No such option"): + ps.options.compute = 0 + with self.assertRaisesRegex(config.OptionError, "No such option"): + ps.options.com = 0 + finally: + ps.reset_option("compute.max_rows") + + def test_dir_options(self): + self.assertTrue("compute.default_index_type" in dir(ps.options)) + self.assertTrue("plotting.sample_ratio" in dir(ps.options)) + + self.assertTrue("default_index_type" in dir(ps.options.compute)) + self.assertTrue("sample_ratio" not in dir(ps.options.compute)) + + self.assertTrue("default_index_type" not in dir(ps.options.plotting)) + self.assertTrue("sample_ratio" in dir(ps.options.plotting)) + + +class ConfigTests(ConfigTestsMixin, PandasOnSparkTestCase): + pass + + +if __name__ == "__main__": + import unittest + + from pyspark.pandas.tests.test_config import * + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_connect.py b/tests/connect/test_connect.py new file mode 100644 index 0000000000..f6afa28876 --- /dev/null +++ b/tests/connect/test_connect.py @@ -0,0 +1,127 @@ +# def test_apply_lambda +# def test_apply_module_func +# def test_apply_inline_func +# def test_apply_lambda_pyobj + +from __future__ import annotations + +import time + +from pyspark.sql import SparkSession +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import col + + +def test_simple(): + print("Starting Daft-Connect server") + # connect_start("sc://localhost:50052") + + print("Created spark connect server") + + # Create a Spark session using Spark Connect + spark: SparkSession = ( + SparkSession.builder.appName("SparkConnectExample").remote("sc://localhost:50051").getOrCreate() + ) + + print("Spark session created") + + # Read the Parquet file back into a DataFrame + df: DataFrame = spark.read.parquet("/Users/andrewgazelka/Projects/simple-spark-connect/increasing_id_data.parquet") + print("DataFrame read from Parquet file") + # The DataFrame remains unchanged: + # +---+ + # | id| + # +---+ + # | 0| + # | 1| + # | 2| + # | 3| + # | 4| + # +---+ + + print("DataFrame schema:") + df.printSchema() + # root + # |-- id: long (nullable = false) + + print("\nDataFrame content:") + df.show() + + print("done showing") + + # Perform operations on the DataFrame + # 1. filter(col("id") > 2): Select only rows where 'id' is greater than 2 + # 2. 
withColumn("id2", col("id") + 2): Add a new column 'id2' that is 'id' plus 2 + result: DataFrame = df.filter(col("id") > 2).withColumn("id2", col("id") + 2) + + print("\nFiltered and transformed DataFrame:") + result.show() + + # result_pandas = result.toPandas() + # The resulting DataFrame looks like this: + # +---+---+ + # | id|id2| + # +---+---+ + # | 3| 5| + # | 4| 6| + # +---+---+ + # Explanation: + # 1. Only rows with id > 2 are kept (3 and 4) + # 2. A new column 'id2' is added with values id + 2 + + # Stop the Spark session + # spark.sql("select * from increasing_id_data").show() + + spark.stop() + print("Spark session stopped") + + # Waiting for 10 seconds + time.sleep(2) + + print("End of main function") + + +# from daft. +# +# +# def add_1(x): +# return x + 1 +# +# +# def test_apply_module_func(): +# df = daft.from_pydict({"a": [1, 2, 3]}) +# df = df.with_column("a_plus_1", df["a"].apply(add_1, return_dtype=DataType.int32())) +# assert df.to_pydict() == {"a": [1, 2, 3], "a_plus_1": [2, 3, 4]} +# +# +# def test_apply_lambda(): +# df = daft.from_pydict({"a": [1, 2, 3]}) +# df = df.with_column("a_plus_1", df["a"].apply(lambda x: x + 1, return_dtype=DataType.int32())) +# assert df.to_pydict() == {"a": [1, 2, 3], "a_plus_1": [2, 3, 4]} +# +# +# def test_apply_inline_func(): +# def inline_add_1(x): +# return x + 1 +# +# df = daft.from_pydict({"a": [1, 2, 3]}) +# df = df.with_column("a_plus_1", df["a"].apply(inline_add_1, return_dtype=DataType.int32())) +# assert df.to_pydict() == {"a": [1, 2, 3], "a_plus_1": [2, 3, 4]} +# +# +# @dataclasses.dataclass +# class MyObj: +# x: int +# +# +# def test_apply_obj(): +# df = daft.from_pydict({"obj": [MyObj(x=0), MyObj(x=0), MyObj(x=0)]}) +# +# def inline_mutate_obj(obj): +# obj.x = 1 +# return obj +# +# df = df.with_column("mut_obj", df["obj"].apply(inline_mutate_obj, return_dtype=DataType.python())) +# result = df.to_pydict() +# for mut_obj in result["mut_obj"]: +# assert mut_obj.x == 1 diff --git a/tests/connect/test_connect_basic.py b/tests/connect/test_connect_basic.py new file mode 100755 index 0000000000..0f55d2e78f --- /dev/null +++ b/tests/connect/test_connect_basic.py @@ -0,0 +1,1477 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import datetime +import gc +import io +import os +import shutil +import tempfile +import unittest +from contextlib import redirect_stdout + +# from pyspark.util import is_remote_only +from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.errors.exceptions.connect import ( + AnalysisException, + SparkConnectException, +) +from pyspark.sql import Row +from pyspark.sql import SparkSession as PySparkSession +from pyspark.sql.types import ( + ArrayType, + IntegerType, + LongType, + MapType, + Row, + StringType, + StructField, + StructType, +) +from pyspark.testing.connectutils import ( + ReusedConnectTestCase, + should_test_connect, +) +from pyspark.testing.pandasutils import PandasOnSparkTestUtils +from pyspark.testing.sqlutils import SQLTestUtils +from pyspark.testing.utils import eventually + +if should_test_connect: + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + from pyspark.sql.connect.column import Column + from pyspark.sql.connect.dataframe import DataFrame as CDataFrame + from pyspark.sql.connect.proto import Expression as ProtoExpression + from pyspark.sql.dataframe import DataFrame + + +def is_remote_only(): + return False + + +@unittest.skipIf(is_remote_only(), "Requires JVM access") +class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils): + """Parent test fixture class for all Spark Connect related + test cases.""" + + @classmethod + def setUpClass(cls): + super(SparkConnectSQLTestCase, cls).setUpClass() + # Disable the shared namespace so pyspark.sql.functions, etc point the regular + # PySpark libraries. + os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1" + + cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session. + cls.spark = PySparkSession._instantiatedSession + assert cls.spark is not None + + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.testDataStr = [Row(key=str(i)) for i in range(100)] + cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF() + cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF() + + cls.tbl_name = "test_connect_basic_table_1" + cls.tbl_name2 = "test_connect_basic_table_2" + cls.tbl_name3 = "test_connect_basic_table_3" + cls.tbl_name4 = "test_connect_basic_table_4" + cls.tbl_name_empty = "test_connect_basic_table_empty" + + # Cleanup test data + cls.spark_connect_clean_up_test_data() + # Load test data + cls.spark_connect_load_test_data() + + @classmethod + def tearDownClass(cls): + try: + cls.spark_connect_clean_up_test_data() + # Stopping Spark Connect closes the session in JVM at the server. + cls.spark = cls.connect + del os.environ["PYSPARK_NO_NAMESPACE_SHARE"] + finally: + super(SparkConnectSQLTestCase, cls).tearDownClass() + + @classmethod + def spark_connect_load_test_data(cls): + df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"]) + # Since we might create multiple Spark sessions, we need to create global temporary view + # that is specifically maintained in the "global_temp" schema. 
+ df.write.saveAsTable(cls.tbl_name) + df2 = cls.spark.createDataFrame([(x, f"{x}", 2 * x) for x in range(100)], ["col1", "col2", "col3"]) + df2.write.saveAsTable(cls.tbl_name2) + df3 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "test\n_column"]) + df3.write.saveAsTable(cls.tbl_name3) + df4 = cls.spark.createDataFrame( + [(x, {"a": x}, [x, x * 2]) for x in range(100)], ["id", "map_column", "array_column"] + ) + df4.write.saveAsTable(cls.tbl_name4) + empty_table_schema = StructType( + [ + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), + ] + ) + emptyRDD = cls.spark.sparkContext.emptyRDD() + empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema) + empty_df.write.saveAsTable(cls.tbl_name_empty) + + @classmethod + def spark_connect_clean_up_test_data(cls): + cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name}") + cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name2}") + cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name3}") + cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name4}") + cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name_empty}") + + +class SparkConnectBasicTests(SparkConnectSQLTestCase): + def test_serialization(self): + from pyspark.cloudpickle import dumps, loads + + cdf = self.connect.range(10) + data = dumps(cdf) + cdf2 = loads(data) + self.assertEqual(cdf.collect(), cdf2.collect()) + + def test_df_getattr_behavior(self): + cdf = self.connect.range(10) + sdf = self.spark.range(10) + + sdf._simple_extension = 10 + cdf._simple_extension = 10 + + self.assertEqual(sdf._simple_extension, cdf._simple_extension) + self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension)) + + self.assertTrue(hasattr(cdf, "_simple_extension")) + self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit")) + + def test_df_get_item(self): + # SPARK-41779: test __getitem__ + + query = """ + SELECT * FROM VALUES + (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) + AS tab(a, b, c) + """ + + # +-----+----+----+ + # | a| b| c| + # +-----+----+----+ + # | true| 1|NULL| + # |false|NULL| 2.0| + # | NULL| 3| 3.0| + # +-----+----+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # filter + self.assert_eq( + cdf[cdf.a].toPandas(), + sdf[sdf.a].toPandas(), + ) + self.assert_eq( + cdf[cdf.b.isin(2, 3)].toPandas(), + sdf[sdf.b.isin(2, 3)].toPandas(), + ) + self.assert_eq( + cdf[cdf.c > 1.5].toPandas(), + sdf[sdf.c > 1.5].toPandas(), + ) + + # select + self.assert_eq( + cdf[[cdf.a, "b", cdf.c]].toPandas(), + sdf[[sdf.a, "b", sdf.c]].toPandas(), + ) + self.assert_eq( + cdf[(cdf.a, "b", cdf.c)].toPandas(), + sdf[(sdf.a, "b", sdf.c)].toPandas(), + ) + + # select by index + self.assertTrue(isinstance(cdf[0], Column)) + self.assertTrue(isinstance(cdf[1], Column)) + self.assertTrue(isinstance(cdf[2], Column)) + + self.assert_eq( + cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(), + sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(), + ) + + # check error + with self.assertRaises(PySparkTypeError) as pe: + cdf[1.5] + + self.check_error( + exception=pe.exception, + errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", + messageParameters={ + "arg_name": "item", + "arg_type": "float", + }, + ) + + with self.assertRaises(PySparkTypeError) as pe: + cdf[None] + + self.check_error( + exception=pe.exception, + errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", + messageParameters={ + "arg_name": "item", + "arg_type": "NoneType", + }, + ) + + with 
self.assertRaises(PySparkTypeError) as pe: + cdf[cdf] + + self.check_error( + exception=pe.exception, + errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", + messageParameters={ + "arg_name": "item", + "arg_type": "DataFrame", + }, + ) + + def test_join_condition_column_list_columns(self): + left_connect_df = self.connect.read.table(self.tbl_name) + right_connect_df = self.connect.read.table(self.tbl_name2) + left_spark_df = self.spark.read.table(self.tbl_name) + right_spark_df = self.spark.read.table(self.tbl_name2) + joined_plan = left_connect_df.join( + other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner" + ) + joined_plan2 = left_spark_df.join(other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner") + self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas()) + + joined_plan3 = left_connect_df.join( + other=right_connect_df, + on=[ + left_connect_df.id == right_connect_df.col1, + left_connect_df.name == right_connect_df.col2, + ], + how="inner", + ) + joined_plan4 = left_spark_df.join( + other=right_spark_df, + on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2], + how="inner", + ) + self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas()) + + def test_join_ambiguous_cols(self): + # SPARK-41812: test join with ambiguous columns + data1 = [Row(id=1, value="foo"), Row(id=2, value=None)] + cdf1 = self.connect.createDataFrame(data1) + sdf1 = self.spark.createDataFrame(data1) + + data2 = [Row(value="bar"), Row(value=None), Row(value="foo")] + cdf2 = self.connect.createDataFrame(data2) + sdf2 = self.spark.createDataFrame(data2) + + cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]) + sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]) + + self.assertEqual(cdf3.schema, sdf3.schema) + self.assertEqual(cdf3.collect(), sdf3.collect()) + + cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"])) + sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"])) + + self.assertEqual(cdf4.schema, sdf4.schema) + self.assertEqual(cdf4.collect(), sdf4.collect()) + + cdf5 = cdf1.join(cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))) + sdf5 = sdf1.join(sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))) + + self.assertEqual(cdf5.schema, sdf5.schema) + self.assertEqual(cdf5.collect(), sdf5.collect()) + + cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value) + sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value) + + self.assertEqual(cdf6.schema, sdf6.schema) + self.assertEqual(cdf6.collect(), sdf6.collect()) + + cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value) + sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value) + + self.assertEqual(cdf7.schema, sdf7.schema) + self.assertEqual(cdf7.collect(), sdf7.collect()) + + def test_join_with_cte(self): + cte_query = "with dt as (select 1 as ida) select ida as id from dt" + + sdf1 = self.spark.range(10) + sdf2 = self.spark.sql(cte_query) + sdf3 = sdf1.join(sdf2, sdf1.id == sdf2.id) + + cdf1 = self.connect.range(10) + cdf2 = self.connect.sql(cte_query) + cdf3 = cdf1.join(cdf2, cdf1.id == cdf2.id) + + self.assertEqual(sdf3.schema, cdf3.schema) + self.assertEqual(sdf3.collect(), cdf3.collect()) + + def test_with_columns_renamed(self): + # SPARK-41312: test DataFrame.withColumnsRenamed() + self.assertEqual( + self.connect.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema, + 
self.spark.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema, + ) + self.assertEqual( + self.connect.read.table(self.tbl_name).withColumnsRenamed({"id": "id_new", "name": "name_new"}).schema, + self.spark.read.table(self.tbl_name).withColumnsRenamed({"id": "id_new", "name": "name_new"}).schema, + ) + + def test_simple_explain_string(self): + df = self.connect.read.table(self.tbl_name).limit(10) + result = df._explain_string() + self.assertGreater(len(result), 0) + + def _check_print_schema(self, query: str): + with io.StringIO() as buf, redirect_stdout(buf): + self.spark.sql(query).printSchema() + print1 = buf.getvalue() + with io.StringIO() as buf, redirect_stdout(buf): + self.connect.sql(query).printSchema() + print2 = buf.getvalue() + self.assertEqual(print1, print2, query) + + for level in [-1, 0, 1, 2, 3, 4]: + with io.StringIO() as buf, redirect_stdout(buf): + self.spark.sql(query).printSchema(level) + print1 = buf.getvalue() + with io.StringIO() as buf, redirect_stdout(buf): + self.connect.sql(query).printSchema(level) + print2 = buf.getvalue() + self.assertEqual(print1, print2, query) + + def test_schema(self): + schema = self.connect.read.table(self.tbl_name).schema + self.assertEqual( + StructType([StructField("id", LongType(), True), StructField("name", StringType(), True)]), + schema, + ) + + # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType + query = """ + SELECT * FROM VALUES + (float(1.0), double(1.0), 1.0, "1", true, NULL), + (float(2.0), double(2.0), 2.0, "2", false, NULL), + (float(3.0), double(3.0), NULL, "3", false, NULL) + AS tab(a, b, c, d, e, f) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + self._check_print_schema(query) + + # test TimestampType, DateType + query = """ + SELECT * FROM VALUES + (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')), + (TIMESTAMP('2019-04-12 15:50:00'), NULL), + (NULL, DATE('2022-02-22')) + AS tab(a, b) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + self._check_print_schema(query) + + # test DayTimeIntervalType + query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + self._check_print_schema(query) + + # test MapType + query = """ + SELECT * FROM VALUES + (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)), + (MAP('x', 'yz'), MAP('x', NULL), NULL), + (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4)) + AS tab(a, b, c) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + self._check_print_schema(query) + + # test ArrayType + query = """ + SELECT * FROM VALUES + (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)), + (ARRAY('x', NULL), NULL, ARRAY(1, 3)), + (NULL, ARRAY(-1, -2, -3), Array()) + AS tab(a, b, c) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + self._check_print_schema(query) + + # test StructType + query = """ + SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES + (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)), + (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)), + (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL) + AS tab(a, b, c, d, e, f, g, h) + """ + self.assertEqual( + self.spark.sql(query).schema, + self.connect.sql(query).schema, + ) + 
self._check_print_schema(query) + + def test_to(self): + # SPARK-41464: test DataFrame.to() + + cdf = self.connect.read.table(self.tbl_name) + df = self.spark.read.table(self.tbl_name) + + def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): + cdf_to = cdf.to(schema) + df_to = df.to(schema) + self.assertEqual(cdf_to.schema, df_to.schema) + self.assert_eq(cdf_to.toPandas(), df_to.toPandas()) + + # The schema has not changed + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("name", StringType(), True), + ] + ) + + assert_eq_schema(cdf, df, schema) + + # Change schema with struct + schema2 = StructType([StructField("struct", schema, False)]) + + cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2) + df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2) + + self.assertEqual(cdf_to.schema, df_to.schema) + + # Change the column name + schema = StructType( + [ + StructField("col1", IntegerType(), True), + StructField("col2", StringType(), True), + ] + ) + + assert_eq_schema(cdf, df, schema) + + # Change the column data type + schema = StructType( + [ + StructField("id", StringType(), True), + StructField("name", StringType(), True), + ] + ) + + assert_eq_schema(cdf, df, schema) + + # Reduce the column quantity and change data type + schema = StructType( + [ + StructField("id", LongType(), True), + ] + ) + + assert_eq_schema(cdf, df, schema) + + # incompatible field nullability + schema = StructType([StructField("id", LongType(), False)]) + self.assertRaisesRegex( + AnalysisException, + "NULLABLE_COLUMN_OR_FIELD", + lambda: cdf.to(schema).toPandas(), + ) + + # field cannot upcast + schema = StructType([StructField("name", LongType())]) + self.assertRaisesRegex( + AnalysisException, + "INVALID_COLUMN_OR_FIELD_DATA_TYPE", + lambda: cdf.to(schema).toPandas(), + ) + + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("name", IntegerType(), True), + ] + ) + self.assertRaisesRegex( + AnalysisException, + "INVALID_COLUMN_OR_FIELD_DATA_TYPE", + lambda: cdf.to(schema).toPandas(), + ) + + # Test map type and array type + schema = StructType( + [ + StructField("id", StringType(), True), + StructField("my_map", MapType(StringType(), IntegerType(), False), True), + StructField("my_array", ArrayType(IntegerType(), False), True), + ] + ) + cdf = self.connect.read.table(self.tbl_name4) + df = self.spark.read.table(self.tbl_name4) + + assert_eq_schema(cdf, df, schema) + + def test_toDF(self): + # SPARK-41310: test DataFrame.toDF() + self.assertEqual( + self.connect.read.table(self.tbl_name).toDF("col1", "col2").schema, + self.spark.read.table(self.tbl_name).toDF("col1", "col2").schema, + ) + + def test_print_schema(self): + # SPARK-41216: Test print schema + tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y").schema.treeString() + # root + # |-- X: integer (nullable = false) + # |-- Y: integer (nullable = false) + expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n" + self.assertEqual(tree_str, expected) + + def test_is_local(self): + # SPARK-41216: Test is local + self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal()) + self.assertFalse(self.connect.read.table(self.tbl_name).isLocal()) + + def test_is_streaming(self): + # SPARK-41216: Test is streaming + self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming) + self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming) + + def test_input_files(self): + # SPARK-41216: Test 
input files + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + try: + self.df_text.write.text(tmpPath) + + input_files_list1 = self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles() + input_files_list2 = self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles() + + self.assertTrue(len(input_files_list1) > 0) + self.assertEqual(len(input_files_list1), len(input_files_list2)) + for file_path in input_files_list2: + self.assertTrue(file_path in input_files_list1) + finally: + shutil.rmtree(tmpPath) + + def test_limit_offset(self): + df = self.connect.read.table(self.tbl_name) + pd = df.limit(10).offset(1).toPandas() + self.assertEqual(9, len(pd.index)) + pd2 = df.offset(98).limit(10).toPandas() + self.assertEqual(2, len(pd2.index)) + + def test_tail(self): + df = self.connect.read.table(self.tbl_name) + df2 = self.spark.read.table(self.tbl_name) + self.assertEqual(df.tail(10), df2.tail(10)) + + def test_sql(self): + pdf = self.connect.sql("SELECT 1").toPandas() + self.assertEqual(1, len(pdf.index)) + + def test_sql_with_named_args(self): + sqlText = "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id > :minId" + df = self.connect.sql(sqlText, args={"minId": 7, "m": CF.create_map(CF.lit("a"), CF.lit(1))}) + df2 = self.spark.sql(sqlText, args={"minId": 7, "m": SF.create_map(SF.lit("a"), SF.lit(1))}) + self.assert_eq(df.toPandas(), df2.toPandas()) + + def test_namedargs_with_global_limit(self): + sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val) + where val = :val""" + df = self.connect.sql(sqlText, args={"val": 1}) + df2 = self.spark.sql(sqlText, args={"val": 1}) + self.assert_eq(df.toPandas(), df2.toPandas()) + + self.assert_eq(df.first()[0], datetime.datetime(2022, 12, 25, 10, 30)) + self.assert_eq(df.first().date, datetime.datetime(2022, 12, 25, 10, 30)) + self.assert_eq(df.first()[1], 1) + self.assert_eq(df.first().val, 1) + + def test_sql_with_pos_args(self): + sqlText = "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?" + df = self.connect.sql(sqlText, args=[CF.array(CF.lit(1)), 7]) + df2 = self.spark.sql(sqlText, args=[SF.array(SF.lit(1)), 7]) + self.assert_eq(df.toPandas(), df2.toPandas()) + + def test_sql_with_invalid_args(self): + sqlText = "SELECT ?, ?, ?" + for session in [self.connect, self.spark]: + with self.assertRaises(PySparkTypeError) as pe: + session.sql(sqlText, args={1, 2, 3}) + + self.check_error( + exception=pe.exception, + errorClass="INVALID_TYPE", + messageParameters={"arg_name": "args", "arg_type": "set"}, + ) + + def test_deduplicate(self): + # SPARK-41326: test distinct and dropDuplicates. 
+ df = self.connect.read.table(self.tbl_name) + df2 = self.spark.read.table(self.tbl_name) + self.assert_eq(df.distinct().toPandas(), df2.distinct().toPandas()) + self.assert_eq(df.dropDuplicates().toPandas(), df2.dropDuplicates().toPandas()) + self.assert_eq(df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()) + + def test_drop(self): + # SPARK-41169: test drop + query = """ + SELECT * FROM VALUES + (false, 1, NULL), (false, NULL, 2), (NULL, 3, 3) + AS tab(a, b, c) + """ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + self.assert_eq( + cdf.drop("a").toPandas(), + sdf.drop("a").toPandas(), + ) + self.assert_eq( + cdf.drop("a", "b").toPandas(), + sdf.drop("a", "b").toPandas(), + ) + self.assert_eq( + cdf.drop("a", "x").toPandas(), + sdf.drop("a", "x").toPandas(), + ) + self.assert_eq( + cdf.drop(cdf.a, "x").toPandas(), + sdf.drop(sdf.a, "x").toPandas(), + ) + + def test_subquery_alias(self) -> None: + # SPARK-40938: test subquery alias. + plan_text = self.connect.read.table(self.tbl_name).alias("special_alias")._explain_string(extended=True) + self.assertTrue("special_alias" in plan_text) + + def test_sort(self): + # SPARK-41332: test sort + query = """ + SELECT * FROM VALUES + (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) + AS tab(a, b, c) + """ + # +-----+----+----+ + # | a| b| c| + # +-----+----+----+ + # |false| 1|NULL| + # |false|NULL| 2.0| + # | NULL| 3| 3.0| + # +-----+----+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + self.assert_eq( + cdf.sort("a").toPandas(), + sdf.sort("a").toPandas(), + ) + self.assert_eq( + cdf.sort("c").toPandas(), + sdf.sort("c").toPandas(), + ) + self.assert_eq( + cdf.sort("b").toPandas(), + sdf.sort("b").toPandas(), + ) + self.assert_eq( + cdf.sort(cdf.c, "b").toPandas(), + sdf.sort(sdf.c, "b").toPandas(), + ) + self.assert_eq( + cdf.sort(cdf.c.desc(), "b").toPandas(), + sdf.sort(sdf.c.desc(), "b").toPandas(), + ) + self.assert_eq( + cdf.sort(cdf.c.desc(), cdf.a.asc()).toPandas(), + sdf.sort(sdf.c.desc(), sdf.a.asc()).toPandas(), + ) + + def test_range(self): + self.assert_eq( + self.connect.range(start=0, end=10).toPandas(), + self.spark.range(start=0, end=10).toPandas(), + ) + self.assert_eq( + self.connect.range(start=0, end=10, step=3).toPandas(), + self.spark.range(start=0, end=10, step=3).toPandas(), + ) + self.assert_eq( + self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(), + self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(), + ) + # SPARK-41301 + self.assert_eq(self.connect.range(10).toPandas(), self.connect.range(start=0, end=10).toPandas()) + + def test_create_global_temp_view(self): + # SPARK-41127: test global temp view creation. + with self.tempView("view_1"): + self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") + self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1") + self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) + + # Test when creating a view which is already exists but + self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) + with self.assertRaises(AnalysisException): + self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") + + def test_create_session_local_temp_view(self): + # SPARK-41372: test session local temp view creation. 
+ with self.tempView("view_local_temp"): + self.connect.sql("SELECT 1 AS X").createTempView("view_local_temp") + self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 1) + self.connect.sql("SELECT 1 AS X LIMIT 0").createOrReplaceTempView("view_local_temp") + self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0) + + # Test when creating a view which is already exists but + with self.assertRaises(AnalysisException): + self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp") + + def test_select_expr(self): + # SPARK-41201: test selectExpr API. + self.assert_eq( + self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(), + self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(), + ) + self.assert_eq( + self.connect.read.table(self.tbl_name).selectExpr(["id * 2", "cast(name as long) as name"]).toPandas(), + self.spark.read.table(self.tbl_name).selectExpr(["id * 2", "cast(name as long) as name"]).toPandas(), + ) + + self.assert_eq( + self.connect.read.table(self.tbl_name).selectExpr("id * 2", "cast(name as long) as name").toPandas(), + self.spark.read.table(self.tbl_name).selectExpr("id * 2", "cast(name as long) as name").toPandas(), + ) + + def test_select_star(self): + data = [Row(a=1, b=Row(c=2, d=Row(e=3)))] + + # +---+--------+ + # | a| b| + # +---+--------+ + # | 1|{2, {3}}| + # +---+--------+ + + cdf = self.connect.createDataFrame(data=data) + sdf = self.spark.createDataFrame(data=data) + + self.assertEqual( + cdf.select("*").collect(), + sdf.select("*").collect(), + ) + self.assertEqual( + cdf.select("a", "*").collect(), + sdf.select("a", "*").collect(), + ) + self.assertEqual( + cdf.select("a", "b").collect(), + sdf.select("a", "b").collect(), + ) + self.assertEqual( + cdf.select("a", "b.*").collect(), + sdf.select("a", "b.*").collect(), + ) + + def test_union_by_name(self): + # SPARK-41832: Test unionByName + data1 = [(1, 2, 3)] + data2 = [(6, 2, 5)] + df1_connect = self.connect.createDataFrame(data1, ["a", "b", "c"]) + df2_connect = self.connect.createDataFrame(data2, ["a", "b", "c"]) + union_df_connect = df1_connect.unionByName(df2_connect) + + df1_spark = self.spark.createDataFrame(data1, ["a", "b", "c"]) + df2_spark = self.spark.createDataFrame(data2, ["a", "b", "c"]) + union_df_spark = df1_spark.unionByName(df2_spark) + + self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas()) + + df2_connect = self.connect.createDataFrame(data2, ["a", "B", "C"]) + union_df_connect = df1_connect.unionByName(df2_connect, allowMissingColumns=True) + + df2_spark = self.spark.createDataFrame(data2, ["a", "B", "C"]) + union_df_spark = df1_spark.unionByName(df2_spark, allowMissingColumns=True) + + self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas()) + + def test_observe(self): + # SPARK-41527: test DataFrame.observe() + observation_name = "my_metric" + + self.assert_eq( + self.connect.read.table(self.tbl_name) + .filter("id > 3") + .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id")) + .toPandas(), + self.spark.read.table(self.tbl_name) + .filter("id > 3") + .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id")) + .toPandas(), + ) + + from pyspark.sql.connect.observation import Observation as ConnectObservation + from pyspark.sql.observation import Observation + + cobservation = ConnectObservation(observation_name) + observation = Observation(observation_name) + + cdf = ( + self.connect.read.table(self.tbl_name) + .filter("id > 3") + 
.observe(cobservation, CF.min("id"), CF.max("id"), CF.sum("id")) + .toPandas() + ) + df = ( + self.spark.read.table(self.tbl_name) + .filter("id > 3") + .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id")) + .toPandas() + ) + + self.assert_eq(cdf, df) + + self.assertEqual(cobservation.get, observation.get) + + observed_metrics = cdf.attrs["observed_metrics"] + self.assert_eq(len(observed_metrics), 1) + self.assert_eq(observed_metrics[0].name, observation_name) + self.assert_eq(len(observed_metrics[0].metrics), 3) + for metric in observed_metrics[0].metrics: + self.assertIsInstance(metric, ProtoExpression.Literal) + values = list(map(lambda metric: metric.long, observed_metrics[0].metrics)) + self.assert_eq(values, [4, 99, 4944]) + + with self.assertRaises(PySparkValueError) as pe: + self.connect.read.table(self.tbl_name).observe(observation_name) + + self.check_error( + exception=pe.exception, + errorClass="CANNOT_BE_EMPTY", + messageParameters={"item": "exprs"}, + ) + + with self.assertRaises(PySparkTypeError) as pe: + self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id") + + self.check_error( + exception=pe.exception, + errorClass="NOT_LIST_OF_COLUMN", + messageParameters={"arg_name": "exprs"}, + ) + + def test_with_columns(self): + # SPARK-41256: test withColumn(s). + self.assert_eq( + self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(), + self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(), + ) + + self.assert_eq( + self.connect.read.table(self.tbl_name) + .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)}) + .toPandas(), + self.spark.read.table(self.tbl_name) + .withColumns( + { + "id": SF.lit(False), + "col_not_exist": SF.lit(False), + } + ) + .toPandas(), + ) + + def test_hint(self): + # SPARK-41349: Test hint + self.assert_eq( + self.connect.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(), + self.spark.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(), + ) + + # Hint with unsupported name will be ignored + self.assert_eq( + self.connect.read.table(self.tbl_name).hint("illegal").toPandas(), + self.spark.read.table(self.tbl_name).hint("illegal").toPandas(), + ) + + # Hint with all supported parameter values + such_a_nice_list = ["itworks1", "itworks2", "itworks3"] + self.assert_eq( + self.connect.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(), + self.spark.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(), + ) + + # Hint with unsupported parameter values + with self.assertRaises(AnalysisException): + self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas() + + # Hint with unsupported parameter types + with self.assertRaises(TypeError): + self.connect.read.table(self.tbl_name).hint("REPARTITION", range(5)).toPandas() + + # Hint with unsupported parameter types + with self.assertRaises(TypeError): + self.connect.read.table(self.tbl_name).hint( + "my awesome hint", 1.2345, 2, such_a_nice_list, range(6) + ).toPandas() + + # Hint with wrong combination + with self.assertRaises(AnalysisException): + self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas() + + def test_join_hint(self): + cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) + cdf2 = self.connect.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) + + self.assertTrue("BroadcastHashJoin" in cdf1.join(cdf2.hint("BROADCAST"), "name")._explain_string()) 
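+        # The MERGE and SHUFFLE_HASH hints should likewise surface as SortMergeJoin and ShuffledHashJoin in the explain output.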
+ self.assertTrue("SortMergeJoin" in cdf1.join(cdf2.hint("MERGE"), "name")._explain_string()) + self.assertTrue("ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()) + + def test_extended_hint_types(self): + cdf = self.connect.range(100).toDF("id") + + cdf.hint( + "my awesome hint", + 1.2345, + "what", + ["itworks1", "itworks2", "itworks3"], + ).show() + + with self.assertRaises(PySparkTypeError) as pe: + cdf.hint( + "my awesome hint", + 1.2345, + "what", + {"itworks1": "itworks2"}, + ).show() + + self.check_error( + exception=pe.exception, + errorClass="INVALID_ITEM_FOR_CONTAINER", + messageParameters={ + "arg_name": "parameters", + "allowed_types": "str, float, int, Column, list[str], list[float], list[int]", + "item_type": "dict", + }, + ) + + def test_empty_dataset(self): + # SPARK-41005: Test arrow based collection with empty dataset. + self.assertTrue( + self.connect.sql("SELECT 1 AS X LIMIT 0") + .toPandas() + .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas()) + ) + pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas() + self.assertEqual(0, len(pdf)) # empty dataset + self.assertEqual(1, len(pdf.columns)) # one column + self.assertEqual("X", pdf.columns[0]) + + def test_is_empty(self): + # SPARK-41212: Test is empty + self.assertFalse(self.connect.sql("SELECT 1 AS X").isEmpty()) + self.assertTrue(self.connect.sql("SELECT 1 AS X LIMIT 0").isEmpty()) + + def test_is_empty_with_unsupported_types(self): + df = self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval") + self.assertEqual(df.count(), 1) + self.assertFalse(df.isEmpty()) + + def test_session(self): + self.assertEqual(self.connect, self.connect.sql("SELECT 1").sparkSession) + + def test_show(self): + # SPARK-41111: Test the show method + show_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._show_string() + # +---+---+ + # | X| Y| + # +---+---+ + # | 1| 2| + # +---+---+ + expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n" + self.assertEqual(show_str, expected) + + def test_repr(self): + # SPARK-41213: Test the __repr__ method + query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)""" + self.assertEqual( + self.connect.sql(query).__repr__(), + self.spark.sql(query).__repr__(), + ) + + def test_explain_string(self): + # SPARK-41122: test explain API. + plan_str = self.connect.sql("SELECT 1")._explain_string(extended=True) + self.assertTrue("Parsed Logical Plan" in plan_str) + self.assertTrue("Analyzed Logical Plan" in plan_str) + self.assertTrue("Optimized Logical Plan" in plan_str) + self.assertTrue("Physical Plan" in plan_str) + + with self.assertRaises(PySparkValueError) as pe: + self.connect.sql("SELECT 1")._explain_string(mode="unknown") + self.check_error( + exception=pe.exception, + errorClass="UNKNOWN_EXPLAIN_MODE", + messageParameters={"explain_mode": "unknown"}, + ) + + def test_count(self) -> None: + # SPARK-41308: test count() API. + self.assertEqual( + self.connect.read.table(self.tbl_name).count(), + self.spark.read.table(self.tbl_name).count(), + ) + + def test_simple_transform(self) -> None: + """SPARK-41203: Support DF.transform""" + + def transform_df(input_df: CDataFrame) -> CDataFrame: + return input_df.select((CF.col("id") + CF.lit(10)).alias("id")) + + df = self.connect.range(1, 100) + result_left = df.transform(transform_df).collect() + result_right = self.connect.range(11, 110).collect() + self.assertEqual(result_right, result_left) + + # Check assertion. 
+ with self.assertRaises(AssertionError): + df.transform(lambda x: 2) # type: ignore + + def test_alias(self) -> None: + """Testing supported and unsupported alias""" + col0 = self.connect.range(1, 10).select(CF.col("id").alias("name", metadata={"max": 99})).schema.names[0] + self.assertEqual("name", col0) + + with self.assertRaises(SparkConnectException) as exc: + self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect() + self.assertIn("(this, is, not)", str(exc.exception)) + + def test_column_regexp(self) -> None: + # SPARK-41438: test dataframe.colRegex() + ndf = self.connect.read.table(self.tbl_name3) + df = self.spark.read.table(self.tbl_name3) + + self.assert_eq( + ndf.select(ndf.colRegex("`tes.*\n.*mn`")).toPandas(), + df.select(df.colRegex("`tes.*\n.*mn`")).toPandas(), + ) + + def test_repartition(self) -> None: + # SPARK-41354: test dataframe.repartition(numPartitions) + self.assert_eq( + self.connect.read.table(self.tbl_name).repartition(10).toPandas(), + self.spark.read.table(self.tbl_name).repartition(10).toPandas(), + ) + + self.assert_eq( + self.connect.read.table(self.tbl_name).coalesce(10).toPandas(), + self.spark.read.table(self.tbl_name).coalesce(10).toPandas(), + ) + + def test_repartition_by_expression(self) -> None: + # SPARK-41354: test dataframe.repartition(expressions) + self.assert_eq( + self.connect.read.table(self.tbl_name).repartition(10, "id").toPandas(), + self.spark.read.table(self.tbl_name).repartition(10, "id").toPandas(), + ) + + self.assert_eq( + self.connect.read.table(self.tbl_name).repartition("id").toPandas(), + self.spark.read.table(self.tbl_name).repartition("id").toPandas(), + ) + + # repartition with unsupported parameter values + with self.assertRaises(AnalysisException): + self.connect.read.table(self.tbl_name).repartition("id+1").toPandas() + + def test_repartition_by_range(self) -> None: + # SPARK-41354: test dataframe.repartitionByRange(expressions) + cdf = self.connect.read.table(self.tbl_name) + sdf = self.spark.read.table(self.tbl_name) + + self.assert_eq( + cdf.repartitionByRange(10, "id").toPandas(), + sdf.repartitionByRange(10, "id").toPandas(), + ) + + self.assert_eq( + cdf.repartitionByRange("id").toPandas(), + sdf.repartitionByRange("id").toPandas(), + ) + + self.assert_eq( + cdf.repartitionByRange(cdf.id.desc()).toPandas(), + sdf.repartitionByRange(sdf.id.desc()).toPandas(), + ) + + # repartitionByRange with unsupported parameter values + with self.assertRaises(AnalysisException): + self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas() + + def test_crossjoin(self): + # SPARK-41227: Test CrossJoin + connect_df = self.connect.read.table(self.tbl_name) + spark_df = self.spark.read.table(self.tbl_name) + self.assert_eq( + set(connect_df.select("id").join(other=connect_df.select("name"), how="cross").toPandas()), + set(spark_df.select("id").join(other=spark_df.select("name"), how="cross").toPandas()), + ) + self.assert_eq( + set(connect_df.select("id").crossJoin(other=connect_df.select("name")).toPandas()), + set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()), + ) + + def test_self_join(self): + # SPARK-47713: this query fails in classic spark + df1 = self.connect.createDataFrame([(1, "a")], schema=["i", "j"]) + df1_filter = df1.filter(df1.i > 0) + df2 = df1.join(df1_filter, df1.i == 1) + self.assertEqual(df2.count(), 1) + self.assertEqual(df2.columns, ["i", "j", "i", "j"]) + self.assertEqual(list(df2.first()), [1, "a", 1, "a"]) + + def 
test_with_metadata(self): + cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"]) + self.assertEqual(cdf.schema["age"].metadata, {}) + self.assertEqual(cdf.schema["name"].metadata, {}) + + cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5}) + self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5}) + + cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]}) + self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]}) + + with self.assertRaises(PySparkTypeError) as pe: + cdf.withMetadata(columnName="name", metadata=["magic"]) + + self.check_error( + exception=pe.exception, + errorClass="NOT_DICT", + messageParameters={ + "arg_name": "metadata", + "arg_type": "list", + }, + ) + + def test_version(self): + self.assertEqual( + self.connect.version, + self.spark.version, + ) + + def test_same_semantics(self): + plan = self.connect.sql("SELECT 1") + other = self.connect.sql("SELECT 1") + self.assertTrue(plan.sameSemantics(other)) + + def test_semantic_hash(self): + plan = self.connect.sql("SELECT 1") + other = self.connect.sql("SELECT 1") + self.assertEqual( + plan.semanticHash(), + other.semanticHash(), + ) + + def test_sql_with_command(self): + # SPARK-42705: spark.sql should return values from the command. + self.assertEqual(self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()) + + def test_df_caache(self): + df = self.connect.range(10) + df.cache() + self.assert_eq(10, df.count()) + self.assertTrue(df.is_cached) + + def test_parse_col_name(self): + from pyspark.sql.connect.types import parse_attr_name + + self.assert_eq(parse_attr_name(""), [""]) + + self.assert_eq(parse_attr_name("a"), ["a"]) + self.assert_eq(parse_attr_name("`a`"), ["a"]) + self.assert_eq(parse_attr_name("`a"), None) + self.assert_eq(parse_attr_name("a`"), None) + + self.assert_eq(parse_attr_name("`a`.b"), ["a", "b"]) + self.assert_eq(parse_attr_name("`a`.`b`"), ["a", "b"]) + self.assert_eq(parse_attr_name("`a```.b"), ["a`", "b"]) + self.assert_eq(parse_attr_name("`a``.b"), None) + + self.assert_eq(parse_attr_name("a.b.c"), ["a", "b", "c"]) + self.assert_eq(parse_attr_name("`a`.`b`.`c`"), ["a", "b", "c"]) + self.assert_eq(parse_attr_name("a.`b`.c"), ["a", "b", "c"]) + + self.assert_eq(parse_attr_name("`a.b.c`"), ["a.b.c"]) + self.assert_eq(parse_attr_name("a.`b.c`"), ["a", "b.c"]) + self.assert_eq(parse_attr_name("`a.b`.c"), ["a.b", "c"]) + self.assert_eq(parse_attr_name("`a.b.c"), None) + self.assert_eq(parse_attr_name("a.b.c`"), None) + self.assert_eq(parse_attr_name("`a.`b.`c"), None) + self.assert_eq(parse_attr_name("a`.b`.c`"), None) + + self.assert_eq(parse_attr_name("`ab..c`e.f"), None) + + def test_verify_col_name(self): + from pyspark.sql.connect.types import verify_col_name + + cdf = ( + self.connect.range(10) + .withColumn("v", CF.lit(123)) + .withColumn("s", CF.struct("id", "v")) + .withColumn("m", CF.struct("s", "v")) + .withColumn("a", CF.array("s")) + ) + + # root + # |-- id: long (nullable = false) + # |-- v: integer (nullable = false) + # |-- s: struct (nullable = false) + # | |-- id: long (nullable = false) + # | |-- v: integer (nullable = false) + # |-- m: struct (nullable = false) + # | |-- s: struct (nullable = false) + # | | |-- id: long (nullable = false) + # | | |-- v: integer (nullable = false) + # | |-- v: integer (nullable = false) + # |-- a: array (nullable = false) + # | |-- element: struct (containsNull = false) + # | | |-- id: long (nullable = false) + # | | 
|-- v: integer (nullable = false) + + self.assertTrue(verify_col_name("id", cdf.schema)) + self.assertTrue(verify_col_name("`id`", cdf.schema)) + + self.assertTrue(verify_col_name("v", cdf.schema)) + self.assertTrue(verify_col_name("`v`", cdf.schema)) + + self.assertFalse(verify_col_name("x", cdf.schema)) + self.assertFalse(verify_col_name("`x`", cdf.schema)) + + self.assertTrue(verify_col_name("s", cdf.schema)) + self.assertTrue(verify_col_name("`s`", cdf.schema)) + self.assertTrue(verify_col_name("s.id", cdf.schema)) + self.assertTrue(verify_col_name("s.`id`", cdf.schema)) + self.assertTrue(verify_col_name("`s`.id", cdf.schema)) + self.assertTrue(verify_col_name("`s`.`id`", cdf.schema)) + self.assertFalse(verify_col_name("`s.id`", cdf.schema)) + + self.assertTrue(verify_col_name("m", cdf.schema)) + self.assertTrue(verify_col_name("`m`", cdf.schema)) + self.assertTrue(verify_col_name("m.s.id", cdf.schema)) + self.assertTrue(verify_col_name("m.s.`id`", cdf.schema)) + self.assertTrue(verify_col_name("m.`s`.id", cdf.schema)) + self.assertTrue(verify_col_name("`m`.`s`.`id`", cdf.schema)) + self.assertFalse(verify_col_name("m.`s.id`", cdf.schema)) + + self.assertTrue(verify_col_name("a", cdf.schema)) + self.assertTrue(verify_col_name("`a`", cdf.schema)) + self.assertTrue(verify_col_name("a.`v`", cdf.schema)) + self.assertTrue(verify_col_name("a.`v`", cdf.schema)) + self.assertTrue(verify_col_name("`a`.v", cdf.schema)) + self.assertTrue(verify_col_name("`a`.`v`", cdf.schema)) + self.assertFalse(verify_col_name("`a`.`x`", cdf.schema)) + + cdf = ( + self.connect.range(10) + .withColumn("v", CF.lit(123)) + .withColumn("s.s", CF.struct("id", "v")) + .withColumn("m`", CF.struct("`s.s`", "v")) + ) + + # root + # |-- id: long (nullable = false) + # |-- v: string (nullable = false) + # |-- s.s: struct (nullable = false) + # | |-- id: long (nullable = false) + # | |-- v: string (nullable = false) + # |-- m`: struct (nullable = false) + # | |-- s.s: struct (nullable = false) + # | | |-- id: long (nullable = false) + # | | |-- v: string (nullable = false) + # | |-- v: string (nullable = false) + + self.assertFalse(verify_col_name("s", cdf.schema)) + self.assertFalse(verify_col_name("`s`", cdf.schema)) + self.assertFalse(verify_col_name("s.s", cdf.schema)) + self.assertFalse(verify_col_name("s.`s`", cdf.schema)) + self.assertFalse(verify_col_name("`s`.s", cdf.schema)) + self.assertTrue(verify_col_name("`s.s`", cdf.schema)) + + self.assertFalse(verify_col_name("m", cdf.schema)) + self.assertFalse(verify_col_name("`m`", cdf.schema)) + self.assertTrue(verify_col_name("`m```", cdf.schema)) + + self.assertFalse(verify_col_name("`m```.s", cdf.schema)) + self.assertFalse(verify_col_name("`m```.`s`", cdf.schema)) + self.assertFalse(verify_col_name("`m```.s.s", cdf.schema)) + self.assertFalse(verify_col_name("`m```.s.`s`", cdf.schema)) + self.assertTrue(verify_col_name("`m```.`s.s`", cdf.schema)) + + self.assertFalse(verify_col_name("`m```.s.s.v", cdf.schema)) + self.assertFalse(verify_col_name("`m```.s.`s`.v", cdf.schema)) + self.assertTrue(verify_col_name("`m```.`s.s`.v", cdf.schema)) + self.assertTrue(verify_col_name("`m```.`s.s`.`v`", cdf.schema)) + + def test_truncate_message(self): + cdf1 = self.connect.createDataFrame( + [ + ("a B c"), + ("X y Z"), + ], + ["a" * 4096], + ) + plan1 = cdf1._plan.to_proto(self.connect._client) + + proto_string_1 = self.connect._client._proto_to_string(plan1, False) + self.assertTrue(len(proto_string_1) > 10000, len(proto_string_1)) + proto_string_truncated_1 = 
self.connect._client._proto_to_string(plan1, True) + self.assertTrue(len(proto_string_truncated_1) < 4000, len(proto_string_truncated_1)) + + cdf2 = cdf1.select("a" * 4096, "a" * 4096, "a" * 4096) + plan2 = cdf2._plan.to_proto(self.connect._client) + + proto_string_2 = self.connect._client._proto_to_string(plan2, False) + self.assertTrue(len(proto_string_2) > 20000, len(proto_string_2)) + proto_string_truncated_2 = self.connect._client._proto_to_string(plan2, True) + self.assertTrue(len(proto_string_truncated_2) < 8000, len(proto_string_truncated_2)) + + cdf3 = cdf1.select("a" * 4096) + for _ in range(64): + cdf3 = cdf3.select("a" * 4096) + plan3 = cdf3._plan.to_proto(self.connect._client) + + proto_string_3 = self.connect._client._proto_to_string(plan3, False) + self.assertTrue(len(proto_string_3) > 128000, len(proto_string_3)) + proto_string_truncated_3 = self.connect._client._proto_to_string(plan3, True) + self.assertTrue(len(proto_string_truncated_3) < 64000, len(proto_string_truncated_3)) + + +class SparkConnectGCTests(SparkConnectSQLTestCase): + @classmethod + def setUpClass(cls): + cls.origin = os.getenv("USER", None) + os.environ["USER"] = "SparkConnectGCTests" + super(SparkConnectGCTests, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + super(SparkConnectGCTests, cls).tearDownClass() + if cls.origin is not None: + os.environ["USER"] = cls.origin + else: + del os.environ["USER"] + + def test_garbage_collection_checkpoint(self): + # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state + # in Spark Connect server + df = self.connect.range(10).localCheckpoint() + self.assertIsNotNone(df._plan._relation_id) + cached_remote_relation_id = df._plan._relation_id + + jvm = self.spark._jvm + session_holder = getattr( + getattr( + jvm.org.apache.spark.sql.connect.service, + "SparkConnectService$", + ), + "MODULE$", + ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id) + + # Check the state exists. + self.assertIsNotNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) + + del df + gc.collect() + + def condition(): + # Check the state was removed up on garbage-collection. + self.assertIsNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) + + eventually(catch_assertions=True)(condition)() + + def test_garbage_collection_derived_checkpoint(self): + # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist + df = self.connect.range(10).localCheckpoint() + self.assertIsNotNone(df._plan._relation_id) + derived = df.repartition(10) + cached_remote_relation_id = df._plan._relation_id + + jvm = self.spark._jvm + session_holder = getattr( + getattr( + jvm.org.apache.spark.sql.connect.service, + "SparkConnectService$", + ), + "MODULE$", + ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id) + + # Check the state exists. 
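+        # While the derived DataFrame is still alive, the cached relation should stay in the server-side cache; it is only evicted once both DataFrames are garbage-collected.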
+ self.assertIsNotNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) + + del df + gc.collect() + + def condition(): + self.assertIsNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) + + # Should not remove the cache + with self.assertRaises(AssertionError): + eventually(catch_assertions=True, timeout=5)(condition)() + + del derived + gc.collect() + + eventually(catch_assertions=True)(condition)() + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_connect_basic import * + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_session.py b/tests/connect/test_session.py new file mode 100644 index 0000000000..6fdea1e7b4 --- /dev/null +++ b/tests/connect/test_session.py @@ -0,0 +1,265 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import threading +import time +import unittest +from typing import Optional + +from pyspark import InheritableThread, inheritable_thread_target +from pyspark.sql.connect.client import DefaultChannelBuilder +from pyspark.sql.connect.session import SparkSession as RemoteSparkSession +from pyspark.testing.connectutils import should_test_connect + +if should_test_connect: + from pyspark.testing.connectutils import ReusedConnectTestCase + + +class CustomChannelBuilder(DefaultChannelBuilder): + @property + def userId(self) -> Optional[str]: + return "abc" + + +class SparkSessionTestCase(unittest.TestCase): + def test_fails_to_create_session_without_remote_and_channel_builder(self): + with self.assertRaises(ValueError): + RemoteSparkSession.builder.getOrCreate() + + def test_fails_to_create_when_both_remote_and_channel_builder_are_specified(self): + with self.assertRaises(ValueError): + ( + RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder("sc://localhost")) + .remote("sc://localhost") + .getOrCreate() + ) + + def test_creates_session_with_channel_builder(self): + test_session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder("sc://other")).getOrCreate() + host = test_session.client.host + test_session.stop() + + self.assertEqual("other", host) + + def test_creates_session_with_remote(self): + test_session = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + host = test_session.client.host + test_session.stop() + + self.assertEqual("other", host) + + def test_session_stop(self): + session = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertFalse(session.is_stopped) + session.stop() + self.assertTrue(session.is_stopped) + + def test_session_create_sets_active_session(self): + session = 
RemoteSparkSession.builder.remote("sc://abc").create() + session2 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertIs(session, session2) + session.stop() + + def test_active_session_expires_when_client_closes(self): + s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + s2 = RemoteSparkSession.getActiveSession() + + self.assertIs(s1, s2) + + # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator + s1._client._closed = True + + self.assertIsNone(RemoteSparkSession.getActiveSession()) + s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertIsNot(s1, s3) + + def test_default_session_expires_when_client_closes(self): + s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + s2 = RemoteSparkSession.getDefaultSession() + + self.assertIs(s1, s2) + + # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator + s1._client._closed = True + + self.assertIsNone(RemoteSparkSession.getDefaultSession()) + s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() + + self.assertIsNot(s1, s3) + + +class JobCancellationTests(ReusedConnectTestCase): + def test_tags(self): + self.spark.clearTags() + self.spark.addTag("a") + self.assertEqual(self.spark.getTags(), {"a"}) + self.spark.addTag("b") + self.spark.removeTag("a") + self.assertEqual(self.spark.getTags(), {"b"}) + self.spark.addTag("c") + self.spark.clearTags() + self.assertEqual(self.spark.getTags(), set()) + self.spark.clearTags() + + def test_tags_multithread(self): + output1 = None + output2 = None + + def tag1(): + nonlocal output1 + + self.spark.addTag("tag1") + output1 = self.spark.getTags() + + def tag2(): + nonlocal output2 + + self.spark.addTag("tag2") + output2 = self.spark.getTags() + + t1 = threading.Thread(target=tag1) + t1.start() + t1.join() + t2 = threading.Thread(target=tag2) + t2.start() + t2.join() + + self.assertIsNotNone(output1) + self.assertEqual(output1, {"tag1"}) + self.assertIsNotNone(output2) + self.assertEqual(output2, {"tag2"}) + + def test_interrupt_tag(self): + thread_ids = range(4) + self.check_job_cancellation( + lambda job_group: self.spark.addTag(job_group), + lambda job_group: self.spark.interruptTag(job_group), + thread_ids, + [i for i in thread_ids if i % 2 == 0], + [i for i in thread_ids if i % 2 != 0], + ) + self.spark.clearTags() + + def test_interrupt_all(self): + thread_ids = range(4) + self.check_job_cancellation( + lambda job_group: None, + lambda job_group: self.spark.interruptAll(), + thread_ids, + thread_ids, + [], + ) + self.spark.clearTags() + + def check_job_cancellation(self, setter, canceller, thread_ids, thread_ids_to_cancel, thread_ids_to_run): + job_id_a = "job_ids_to_cancel" + job_id_b = "job_ids_to_run" + threads = [] + + # A list which records whether job is cancelled. + # The index of the array is the thread index which job run in. + is_job_cancelled = [False for _ in thread_ids] + + def run_job(job_id, index): + """ + Executes a job with the group ``job_group``. Each job waits for 3 seconds + and then exits. + """ + try: + setter(job_id) + + def func(itr): + for pdf in itr: + time.sleep(pdf._1.iloc[0]) + yield pdf + + self.spark.createDataFrame([[20]]).repartition(1).mapInPandas(func, schema="_1 LONG").collect() + is_job_cancelled[index] = False + except Exception: + # Assume that exception means job cancellation. + is_job_cancelled[index] = True + + # Test if job succeeded when not cancelled. 
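+        # Sanity check: an uncancelled run should complete and leave is_job_cancelled[0] as False before the threaded runs start.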
+ run_job(job_id_a, 0) + self.assertFalse(is_job_cancelled[0]) + self.spark.clearTags() + + # Run jobs + for i in thread_ids_to_cancel: + t = threading.Thread(target=run_job, args=(job_id_a, i)) + t.start() + threads.append(t) + + for i in thread_ids_to_run: + t = threading.Thread(target=run_job, args=(job_id_b, i)) + t.start() + threads.append(t) + + # Wait to make sure all jobs are executed. + time.sleep(10) + # And then, cancel one job group. + canceller(job_id_a) + + # Wait until all threads launching jobs are finished. + for t in threads: + t.join() + + for i in thread_ids_to_cancel: + self.assertTrue(is_job_cancelled[i], f"Thread {i}: Job in group A was not cancelled.") + + for i in thread_ids_to_run: + self.assertFalse(is_job_cancelled[i], f"Thread {i}: Job in group B did not succeeded.") + + def test_inheritable_tags(self): + self.check_inheritable_tags(create_thread=lambda target, session: InheritableThread(target, session=session)) + self.check_inheritable_tags( + create_thread=lambda target, session: threading.Thread(target=inheritable_thread_target(session)(target)) + ) + + # Test decorator usage + @inheritable_thread_target(self.spark) + def func(target): + return target() + + self.check_inheritable_tags(create_thread=lambda target, session: threading.Thread(target=func, args=(target,))) + + def check_inheritable_tags(self, create_thread): + spark = self.spark + spark.addTag("a") + first = set() + second = set() + + def get_inner_local_prop(): + spark.addTag("c") + second.update(spark.getTags()) + + def get_outer_local_prop(): + spark.addTag("b") + first.update(spark.getTags()) + t2 = create_thread(target=get_inner_local_prop, session=spark) + t2.start() + t2.join() + + t1 = create_thread(target=get_outer_local_prop, session=spark) + t1.start() + t1.join() + + self.assertEqual(spark.getTags(), {"a"}) + self.assertEqual(first, {"a", "b"}) + self.assertEqual(second, {"a", "b", "c"}) From e708a594635c4e1ab03e050b0bd93d0014a1441e Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 6 Nov 2024 13:54:17 -0800 Subject: [PATCH 2/6] remove a bunch of uneeded tests --- tests/connect/conf.py | 150 -- tests/connect/test_client.py | 434 ----- tests/connect/test_conf.py | 113 -- tests/connect/test_config.py | 153 -- tests/connect/test_connect_basic.py | 1477 ----------------- ...test_connect.py => test_parquet_simple.py} | 2 +- tests/connect/test_session.py | 265 --- 7 files changed, 1 insertion(+), 2593 deletions(-) delete mode 100644 tests/connect/conf.py delete mode 100644 tests/connect/test_client.py delete mode 100644 tests/connect/test_conf.py delete mode 100644 tests/connect/test_config.py delete mode 100755 tests/connect/test_connect_basic.py rename tests/connect/{test_connect.py => test_parquet_simple.py} (98%) delete mode 100644 tests/connect/test_session.py diff --git a/tests/connect/conf.py b/tests/connect/conf.py deleted file mode 100644 index 2ebd63ede6..0000000000 --- a/tests/connect/conf.py +++ /dev/null @@ -1,150 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from pyspark.errors import PySparkTypeError, PySparkValueError -from pyspark.sql.connect.utils import check_dependencies - -check_dependencies(__name__) - -import warnings -from typing import Any, Dict, Optional, Union, cast - -from pyspark import _NoValue -from pyspark._globals import _NoValueType -from pyspark.sql.conf import RuntimeConfig as PySparkRuntimeConfig -from pyspark.sql.connect import proto -from pyspark.sql.connect.client import SparkConnectClient - - -class RuntimeConf: - def __init__(self, client: SparkConnectClient) -> None: - """Create a new RuntimeConfig.""" - self._client = client - - __init__.__doc__ = PySparkRuntimeConfig.__init__.__doc__ - - def set(self, key: str, value: Union[str, int, bool]) -> None: - if isinstance(value, bool): - value = "true" if value else "false" - elif isinstance(value, int): - value = str(value) - op_set = proto.ConfigRequest.Set(pairs=[proto.KeyValue(key=key, value=value)]) - operation = proto.ConfigRequest.Operation(set=op_set) - result = self._client.config(operation) - for warn in result.warnings: - warnings.warn(warn) - - set.__doc__ = PySparkRuntimeConfig.set.__doc__ - - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> Optional[str]: - self._checkType(key, "key") - if default is _NoValue: - op_get = proto.ConfigRequest.Get(keys=[key]) - operation = proto.ConfigRequest.Operation(get=op_get) - else: - if default is not None: - self._checkType(default, "default") - op_get_with_default = proto.ConfigRequest.GetWithDefault( - pairs=[proto.KeyValue(key=key, value=cast(Optional[str], default))] - ) - operation = proto.ConfigRequest.Operation(get_with_default=op_get_with_default) - result = self._client.config(operation) - return result.pairs[0][1] - - get.__doc__ = PySparkRuntimeConfig.get.__doc__ - - @property - def getAll(self) -> Dict[str, str]: - op_get_all = proto.ConfigRequest.GetAll() - operation = proto.ConfigRequest.Operation(get_all=op_get_all) - result = self._client.config(operation) - confs: Dict[str, str] = dict() - for key, value in result.pairs: - assert value is not None - confs[key] = value - return confs - - getAll.__doc__ = PySparkRuntimeConfig.getAll.__doc__ - - def unset(self, key: str) -> None: - op_unset = proto.ConfigRequest.Unset(keys=[key]) - operation = proto.ConfigRequest.Operation(unset=op_unset) - result = self._client.config(operation) - for warn in result.warnings: - warnings.warn(warn) - - unset.__doc__ = PySparkRuntimeConfig.unset.__doc__ - - def isModifiable(self, key: str) -> bool: - op_is_modifiable = proto.ConfigRequest.IsModifiable(keys=[key]) - operation = proto.ConfigRequest.Operation(is_modifiable=op_is_modifiable) - result = self._client.config(operation).pairs[0][1] - if result == "true": - return True - elif result == "false": - return False - else: - raise PySparkValueError( - errorClass="VALUE_NOT_ALLOWED", - messageParameters={"arg_name": "result", "allowed_values": "'true' or 'false'"}, - ) - - isModifiable.__doc__ = PySparkRuntimeConfig.isModifiable.__doc__ - - def _checkType(self, obj: Any, identifier: str) -> None: - """Assert that an object is of 
type str.""" - if not isinstance(obj, str): - raise PySparkTypeError( - errorClass="NOT_STR", - messageParameters={ - "arg_name": identifier, - "arg_type": type(obj).__name__, - }, - ) - - -RuntimeConf.__doc__ = PySparkRuntimeConfig.__doc__ - - -def _test() -> None: - import doctest - import sys - - import pyspark.sql.connect.conf - from pyspark.sql import SparkSession as PySparkSession - - globs = pyspark.sql.connect.conf.__dict__.copy() - globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.conf tests") - # .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) - .remote("127.0.0.1:50051") - .getOrCreate() - ) - - (failure_count, test_count) = doctest.testmod( - pyspark.sql.connect.conf, - globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL, - ) - - globs["spark"].stop() - - if failure_count: - sys.exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/tests/connect/test_client.py b/tests/connect/test_client.py deleted file mode 100644 index f0528b466d..0000000000 --- a/tests/connect/test_client.py +++ /dev/null @@ -1,434 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import unittest -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union - -from pyspark.testing.connectutils import connect_requirement_message, should_test_connect -from pyspark.testing.utils import eventually - -if should_test_connect: - import grpc - import pandas as pd - import pyarrow as pa - import pyspark.sql.connect.proto as proto - from google.rpc import status_pb2 - from pyspark.errors import PySparkRuntimeError, RetriesExceeded - from pyspark.sql.connect.client import DefaultChannelBuilder, SparkConnectClient - from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator - from pyspark.sql.connect.client.retries import ( - DefaultPolicy, - Retrying, - ) - - class TestPolicy(DefaultPolicy): - def __init__(self): - super().__init__( - max_retries=3, - backoff_multiplier=4.0, - initial_backoff=10, - max_backoff=10, - jitter=10, - min_jitter_threshold=10, - ) - - class TestException(grpc.RpcError, grpc.Call): - """Exception mock to test retryable exceptions.""" - - def __init__( - self, - msg, - code=grpc.StatusCode.INTERNAL, - trailing_status: Union[status_pb2.Status, None] = None, - ): - self.msg = msg - self._code = code - self._trailer: dict[str, Any] = {} - if trailing_status is not None: - self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString() - - def code(self): - return self._code - - def __str__(self): - return self.msg - - def details(self): - return self.msg - - def trailing_metadata(self): - return None if not self._trailer else self._trailer.items() - - class ResponseGenerator(Generator): - """This class is used to generate values that are returned by the streaming - iterator of the GRPC stub.""" - - def __init__(self, funs): - self._funs = funs - self._iterator = iter(self._funs) - - def send(self, value: Any) -> proto.ExecutePlanResponse: - val = next(self._iterator) - if callable(val): - return val() - else: - return val - - def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any: - super().throw(type, value, traceback) - - def close(self) -> None: - return super().close() - - class MockSparkConnectStub: - """Simple mock class for the GRPC stub used by the re-attachable execution.""" - - def __init__(self, execute_ops=None, attach_ops=None): - self._execute_ops = execute_ops - self._attach_ops = attach_ops - # Call counters - self.execute_calls = 0 - self.release_calls = 0 - self.release_until_calls = 0 - self.attach_calls = 0 - - def ExecutePlan(self, *args, **kwargs): - self.execute_calls += 1 - return self._execute_ops - - def ReattachExecute(self, *args, **kwargs): - self.attach_calls += 1 - return self._attach_ops - - def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): - if req.HasField("release_all"): - self.release_calls += 1 - elif req.HasField("release_until"): - print("increment") - self.release_until_calls += 1 - - class MockService: - # Simplest mock of the SparkConnectService. - # If this needs more complex logic, it needs to be replaced with Python mocking. 
- - req: Optional[proto.ExecutePlanRequest] - - def __init__(self, session_id: str): - self._session_id = session_id - self.req = None - - def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): - self.req = req - resp = proto.ExecutePlanResponse() - resp.session_id = self._session_id - - pdf = pd.DataFrame(data={"col1": [1, 2]}) - schema = pa.Schema.from_pandas(pdf) - table = pa.Table.from_pandas(pdf) - sink = pa.BufferOutputStream() - - writer = pa.ipc.new_stream(sink, schema=schema) - writer.write(table) - writer.close() - - buf = sink.getvalue() - resp.arrow_batch.data = buf.to_pybytes() - resp.arrow_batch.row_count = 2 - return [resp] - - def Interrupt(self, req: proto.InterruptRequest, metadata): - self.req = req - resp = proto.InterruptResponse() - resp.session_id = self._session_id - return resp - - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class SparkConnectClientTestCase(unittest.TestCase): - def test_user_agent_passthrough(self): - client = SparkConnectClient("sc://foo/;user_agent=bar", use_reattachable_execute=False) - mock = MockService(client._session_id) - client._stub = mock - - command = proto.Command() - client.execute_command(command) - - self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") - self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") - - def test_user_agent_default(self): - client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) - mock = MockService(client._session_id) - client._stub = mock - - command = proto.Command() - client.execute_command(command) - - self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") - self.assertRegex(mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$") - - def test_properties(self): - client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) - self.assertEqual(client.token, "bar") - self.assertEqual(client.host, "foo") - - client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) - self.assertIsNone(client.token) - - def test_channel_builder(self): - class CustomChannelBuilder(DefaultChannelBuilder): - @property - def userId(self) -> Optional[str]: - return "abc" - - client = SparkConnectClient(CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False) - - self.assertEqual(client._user_id, "abc") - - def test_interrupt_all(self): - client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) - mock = MockService(client._session_id) - client._stub = mock - - client.interrupt_all() - self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") - - def test_is_closed(self): - client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) - - self.assertFalse(client.is_closed) - client.close() - self.assertTrue(client.is_closed) - - def test_retry(self): - client = SparkConnectClient("sc://foo/;token=bar") - - total_sleep = 0 - - def sleep(t): - nonlocal total_sleep - total_sleep += t - - try: - for attempt in Retrying(client._retry_policies, sleep=sleep): - with attempt: - raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE) - except RetriesExceeded: - pass - - # tolerated at least 10 mins of fails - self.assertGreaterEqual(total_sleep, 600) - - def test_retry_client_unit(self): - client = SparkConnectClient("sc://foo/;token=bar") - - policyA = TestPolicy() - policyB = DefaultPolicy() - - client.set_retry_policies([policyA, policyB]) - - 
self.assertEqual(client.get_retry_policies(), [policyA, policyB]) - - def test_channel_builder_with_session(self): - dummy = str(uuid.uuid4()) - chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}") - client = SparkConnectClient(chan) - self.assertEqual(client._session_id, chan.session_id) - - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class SparkConnectClientReattachTestCase(unittest.TestCase): - def setUp(self) -> None: - self.request = proto.ExecutePlanRequest() - self.retrying = lambda: Retrying(TestPolicy()) - self.response = proto.ExecutePlanResponse( - response_id="1", - ) - self.finished = proto.ExecutePlanResponse( - result_complete=proto.ExecutePlanResponse.ResultComplete(), - response_id="2", - ) - - def _stub_with(self, execute=None, attach=None): - return MockSparkConnectStub( - execute_ops=ResponseGenerator(execute) if execute is not None else None, - attach_ops=ResponseGenerator(attach) if attach is not None else None, - ) - - def test_basic_flow(self): - stub = self._stub_with([self.response, self.finished]) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - for b in ite: - pass - - def check_all(): - self.assertEqual(0, stub.attach_calls) - self.assertEqual(1, stub.release_until_calls) - self.assertEqual(1, stub.release_calls) - self.assertEqual(1, stub.execute_calls) - - eventually(timeout=1, catch_assertions=True)(check_all)() - - def test_fail_during_execute(self): - def fatal(): - raise TestException("Fatal") - - stub = self._stub_with([self.response, fatal]) - with self.assertRaises(TestException): - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - for b in ite: - pass - - def check(): - self.assertEqual(0, stub.attach_calls) - self.assertEqual(1, stub.release_calls) - self.assertEqual(1, stub.release_until_calls) - self.assertEqual(1, stub.execute_calls) - - eventually(timeout=1, catch_assertions=True)(check)() - - def test_fail_and_retry_during_execute(self): - def non_fatal(): - raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) - - stub = self._stub_with([self.response, non_fatal], [self.response, self.response, self.finished]) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - for b in ite: - pass - - def check(): - self.assertEqual(1, stub.attach_calls) - self.assertEqual(1, stub.release_calls) - self.assertEqual(3, stub.release_until_calls) - self.assertEqual(1, stub.execute_calls) - - eventually(timeout=1, catch_assertions=True)(check)() - - def test_fail_and_retry_during_reattach(self): - count = 0 - - def non_fatal(): - nonlocal count - if count < 2: - count += 1 - raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) - else: - return proto.ExecutePlanResponse() - - stub = self._stub_with([self.response, non_fatal], [self.response, non_fatal, self.response, self.finished]) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - for b in ite: - pass - - def check(): - self.assertEqual(2, stub.attach_calls) - self.assertEqual(3, stub.release_until_calls) - self.assertEqual(1, stub.release_calls) - self.assertEqual(1, stub.execute_calls) - - eventually(timeout=1, catch_assertions=True)(check)() - - def test_not_found_recovers(self): - """SPARK-48056: Assert that the client recovers from session or operation not - found error if no partial responses were previously received. 
- """ - - def not_found_recovers(error_code: str): - def not_found(): - raise TestException( - error_code, - grpc.StatusCode.UNAVAILABLE, - trailing_status=status_pb2.Status(code=14, message=error_code, details=""), - ) - - stub = self._stub_with([not_found, self.finished]) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - - for _ in ite: - pass - - def checks(): - self.assertEqual(2, stub.execute_calls) - self.assertEqual(0, stub.attach_calls) - self.assertEqual(0, stub.release_calls) - self.assertEqual(0, stub.release_until_calls) - - eventually(timeout=1, catch_assertions=True)(checks)() - - parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] - for b in parameters: - not_found_recovers(b) - - def test_not_found_fails(self): - """SPARK-48056: Assert that the client fails from session or operation not found error - if a partial response was previously received. - """ - - def not_found_fails(error_code: str): - def not_found(): - raise TestException( - error_code, - grpc.StatusCode.UNAVAILABLE, - trailing_status=status_pb2.Status(code=14, message=error_code, details=""), - ) - - stub = self._stub_with([self.response], [not_found]) - - with self.assertRaises(PySparkRuntimeError) as e: - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - for _ in ite: - pass - - self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage()) - - def checks(): - self.assertEqual(1, stub.execute_calls) - self.assertEqual(1, stub.attach_calls) - self.assertEqual(1, stub.release_calls) - self.assertEqual(1, stub.release_until_calls) - - eventually(timeout=1, catch_assertions=True)(checks)() - - parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] - for b in parameters: - not_found_fails(b) - - def test_observed_session_id(self): - stub = self._stub_with([self.response, self.finished]) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) - session_id = "test-session-id" - - reattach = ite._create_reattach_execute_request() - self.assertEqual(reattach.client_observed_server_side_session_id, "") - - self.request.client_observed_server_side_session_id = session_id - reattach = ite._create_reattach_execute_request() - self.assertEqual(reattach.client_observed_server_side_session_id, session_id) - - -if __name__ == "__main__": - from pyspark.sql.tests.connect.client.test_client import * - - try: - import xmlrunner # type: ignore - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_conf.py b/tests/connect/test_conf.py deleted file mode 100644 index 214e1833a1..0000000000 --- a/tests/connect/test_conf.py +++ /dev/null @@ -1,113 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from decimal import Decimal - -from pyspark.errors import IllegalArgumentException, PySparkTypeError -from pyspark.testing.sqlutils import ReusedSQLTestCase - - -class ConfTestsMixin: - def test_conf(self): - spark = self.spark - spark.conf.set("bogo", "sipeo") - self.assertEqual(spark.conf.get("bogo"), "sipeo") - spark.conf.set("bogo", "ta") - self.assertEqual(spark.conf.get("bogo"), "ta") - self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") - self.assertEqual(spark.conf.get("not.set", "ta"), "ta") - self.assertRaisesRegex(Exception, "not.set", lambda: spark.conf.get("not.set")) - spark.conf.unset("bogo") - self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") - - self.assertEqual(spark.conf.get("hyukjin", None), None) - - # This returns 'STATIC' because it's the default value of - # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in - # `spark.conf.get` is unset. - self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") - - # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but - # `defaultValue` in `spark.conf.get` is set to None. - self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) - - self.assertTrue(spark.conf.isModifiable("spark.sql.execution.arrow.maxRecordsPerBatch")) - self.assertFalse(spark.conf.isModifiable("spark.sql.warehouse.dir")) - - def test_conf_with_python_objects(self): - spark = self.spark - - try: - for value, expected in [(True, "true"), (False, "false")]: - spark.conf.set("foo", value) - self.assertEqual(spark.conf.get("foo"), expected) - - spark.conf.set("foo", 1) - self.assertEqual(spark.conf.get("foo"), "1") - - with self.assertRaises(IllegalArgumentException): - spark.conf.set("foo", None) - - with self.assertRaises(Exception): - spark.conf.set("foo", Decimal(1)) - - with self.assertRaises(PySparkTypeError) as pe: - spark.conf.get(123) - - self.check_error( - exception=pe.exception, - errorClass="NOT_STR", - messageParameters={ - "arg_name": "key", - "arg_type": "int", - }, - ) - finally: - spark.conf.unset("foo") - - def test_get_all(self): - spark = self.spark - all_confs = spark.conf.getAll - - self.assertTrue(len(all_confs) > 0) - self.assertNotIn("foo", all_confs) - - try: - spark.conf.set("foo", "bar") - updated = spark.conf.getAll - - self.assertEqual(len(updated), len(all_confs) + 1) - self.assertIn("foo", updated) - finally: - spark.conf.unset("foo") - - -class ConfTests(ConfTestsMixin, ReusedSQLTestCase): - pass - - -if __name__ == "__main__": - import unittest - - from pyspark.sql.tests.test_conf import * - - try: - import xmlrunner - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_config.py b/tests/connect/test_config.py deleted file mode 100644 index c1fa2f1f45..0000000000 --- a/tests/connect/test_config.py +++ /dev/null @@ -1,153 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from pyspark import pandas as ps -from pyspark.pandas import config -from pyspark.pandas.config import DictWrapper, Option -from pyspark.testing.pandasutils import PandasOnSparkTestCase - - -class ConfigTestsMixin: - def setUp(self): - config._options_dict["test.config"] = Option(key="test.config", doc="", default="default") - - config._options_dict["test.config.list"] = Option(key="test.config.list", doc="", default=[], types=list) - config._options_dict["test.config.float"] = Option(key="test.config.float", doc="", default=1.2, types=float) - - config._options_dict["test.config.int"] = Option( - key="test.config.int", - doc="", - default=1, - types=int, - check_func=(lambda v: v > 0, "bigger then 0"), - ) - config._options_dict["test.config.int.none"] = Option( - key="test.config.int", doc="", default=None, types=(int, type(None)) - ) - - def tearDown(self): - ps.reset_option("test.config") - del config._options_dict["test.config"] - del config._options_dict["test.config.list"] - del config._options_dict["test.config.float"] - del config._options_dict["test.config.int"] - del config._options_dict["test.config.int.none"] - - def test_get_set_reset_option(self): - self.assertEqual(ps.get_option("test.config"), "default") - - ps.set_option("test.config", "value") - self.assertEqual(ps.get_option("test.config"), "value") - - ps.reset_option("test.config") - self.assertEqual(ps.get_option("test.config"), "default") - - def test_get_set_reset_option_different_types(self): - ps.set_option("test.config.list", [1, 2, 3, 4]) - self.assertEqual(ps.get_option("test.config.list"), [1, 2, 3, 4]) - - ps.set_option("test.config.float", 5.0) - self.assertEqual(ps.get_option("test.config.float"), 5.0) - - ps.set_option("test.config.int", 123) - self.assertEqual(ps.get_option("test.config.int"), 123) - - self.assertEqual(ps.get_option("test.config.int.none"), None) # default None - ps.set_option("test.config.int.none", 123) - self.assertEqual(ps.get_option("test.config.int.none"), 123) - ps.set_option("test.config.int.none", None) - self.assertEqual(ps.get_option("test.config.int.none"), None) - - def test_different_types(self): - with self.assertRaisesRegex(TypeError, "was "): - ps.set_option("test.config.list", 1) - - with self.assertRaisesRegex(TypeError, "however, expected types are"): - ps.set_option("test.config.float", "abc") - - with self.assertRaisesRegex(TypeError, "[]"): - ps.set_option("test.config.int", "abc") - - with self.assertRaisesRegex(TypeError, "(, )"): - ps.set_option("test.config.int.none", "abc") - - def test_check_func(self): - with self.assertRaisesRegex(ValueError, "bigger then 0"): - ps.set_option("test.config.int", -1) - - def test_unknown_option(self): - with self.assertRaisesRegex(config.OptionError, "No such option"): - ps.get_option("unknown") - - with self.assertRaisesRegex(config.OptionError, 
"Available options"): - ps.set_option("unknown", "value") - - with self.assertRaisesRegex(config.OptionError, "test.config"): - ps.reset_option("unknown") - - def test_namespace_access(self): - try: - self.assertEqual(ps.options.compute.max_rows, ps.get_option("compute.max_rows")) - ps.options.compute.max_rows = 0 - self.assertEqual(ps.options.compute.max_rows, 0) - self.assertTrue(isinstance(ps.options.compute, DictWrapper)) - - wrapper = ps.options.compute - self.assertEqual(wrapper.max_rows, ps.get_option("compute.max_rows")) - wrapper.max_rows = 1000 - self.assertEqual(ps.options.compute.max_rows, 1000) - - self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compu) - self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compute.max) - self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.max_rows1) - - with self.assertRaisesRegex(config.OptionError, "No such option"): - ps.options.compute.max = 0 - with self.assertRaisesRegex(config.OptionError, "No such option"): - ps.options.compute = 0 - with self.assertRaisesRegex(config.OptionError, "No such option"): - ps.options.com = 0 - finally: - ps.reset_option("compute.max_rows") - - def test_dir_options(self): - self.assertTrue("compute.default_index_type" in dir(ps.options)) - self.assertTrue("plotting.sample_ratio" in dir(ps.options)) - - self.assertTrue("default_index_type" in dir(ps.options.compute)) - self.assertTrue("sample_ratio" not in dir(ps.options.compute)) - - self.assertTrue("default_index_type" not in dir(ps.options.plotting)) - self.assertTrue("sample_ratio" in dir(ps.options.plotting)) - - -class ConfigTests(ConfigTestsMixin, PandasOnSparkTestCase): - pass - - -if __name__ == "__main__": - import unittest - - from pyspark.pandas.tests.test_config import * - - try: - import xmlrunner - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_connect_basic.py b/tests/connect/test_connect_basic.py deleted file mode 100755 index 0f55d2e78f..0000000000 --- a/tests/connect/test_connect_basic.py +++ /dev/null @@ -1,1477 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import datetime -import gc -import io -import os -import shutil -import tempfile -import unittest -from contextlib import redirect_stdout - -# from pyspark.util import is_remote_only -from pyspark.errors import PySparkTypeError, PySparkValueError -from pyspark.errors.exceptions.connect import ( - AnalysisException, - SparkConnectException, -) -from pyspark.sql import Row -from pyspark.sql import SparkSession as PySparkSession -from pyspark.sql.types import ( - ArrayType, - IntegerType, - LongType, - MapType, - Row, - StringType, - StructField, - StructType, -) -from pyspark.testing.connectutils import ( - ReusedConnectTestCase, - should_test_connect, -) -from pyspark.testing.pandasutils import PandasOnSparkTestUtils -from pyspark.testing.sqlutils import SQLTestUtils -from pyspark.testing.utils import eventually - -if should_test_connect: - from pyspark.sql import functions as SF - from pyspark.sql.connect import functions as CF - from pyspark.sql.connect.column import Column - from pyspark.sql.connect.dataframe import DataFrame as CDataFrame - from pyspark.sql.connect.proto import Expression as ProtoExpression - from pyspark.sql.dataframe import DataFrame - - -def is_remote_only(): - return False - - -@unittest.skipIf(is_remote_only(), "Requires JVM access") -class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils): - """Parent test fixture class for all Spark Connect related - test cases.""" - - @classmethod - def setUpClass(cls): - super(SparkConnectSQLTestCase, cls).setUpClass() - # Disable the shared namespace so pyspark.sql.functions, etc point the regular - # PySpark libraries. - os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1" - - cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session. - cls.spark = PySparkSession._instantiatedSession - assert cls.spark is not None - - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.testDataStr = [Row(key=str(i)) for i in range(100)] - cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF() - cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF() - - cls.tbl_name = "test_connect_basic_table_1" - cls.tbl_name2 = "test_connect_basic_table_2" - cls.tbl_name3 = "test_connect_basic_table_3" - cls.tbl_name4 = "test_connect_basic_table_4" - cls.tbl_name_empty = "test_connect_basic_table_empty" - - # Cleanup test data - cls.spark_connect_clean_up_test_data() - # Load test data - cls.spark_connect_load_test_data() - - @classmethod - def tearDownClass(cls): - try: - cls.spark_connect_clean_up_test_data() - # Stopping Spark Connect closes the session in JVM at the server. - cls.spark = cls.connect - del os.environ["PYSPARK_NO_NAMESPACE_SHARE"] - finally: - super(SparkConnectSQLTestCase, cls).tearDownClass() - - @classmethod - def spark_connect_load_test_data(cls): - df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"]) - # Since we might create multiple Spark sessions, we need to create global temporary view - # that is specifically maintained in the "global_temp" schema. 
- df.write.saveAsTable(cls.tbl_name) - df2 = cls.spark.createDataFrame([(x, f"{x}", 2 * x) for x in range(100)], ["col1", "col2", "col3"]) - df2.write.saveAsTable(cls.tbl_name2) - df3 = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "test\n_column"]) - df3.write.saveAsTable(cls.tbl_name3) - df4 = cls.spark.createDataFrame( - [(x, {"a": x}, [x, x * 2]) for x in range(100)], ["id", "map_column", "array_column"] - ) - df4.write.saveAsTable(cls.tbl_name4) - empty_table_schema = StructType( - [ - StructField("firstname", StringType(), True), - StructField("middlename", StringType(), True), - StructField("lastname", StringType(), True), - ] - ) - emptyRDD = cls.spark.sparkContext.emptyRDD() - empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema) - empty_df.write.saveAsTable(cls.tbl_name_empty) - - @classmethod - def spark_connect_clean_up_test_data(cls): - cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name}") - cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name2}") - cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name3}") - cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name4}") - cls.spark.sql(f"DROP TABLE IF EXISTS {cls.tbl_name_empty}") - - -class SparkConnectBasicTests(SparkConnectSQLTestCase): - def test_serialization(self): - from pyspark.cloudpickle import dumps, loads - - cdf = self.connect.range(10) - data = dumps(cdf) - cdf2 = loads(data) - self.assertEqual(cdf.collect(), cdf2.collect()) - - def test_df_getattr_behavior(self): - cdf = self.connect.range(10) - sdf = self.spark.range(10) - - sdf._simple_extension = 10 - cdf._simple_extension = 10 - - self.assertEqual(sdf._simple_extension, cdf._simple_extension) - self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension)) - - self.assertTrue(hasattr(cdf, "_simple_extension")) - self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit")) - - def test_df_get_item(self): - # SPARK-41779: test __getitem__ - - query = """ - SELECT * FROM VALUES - (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) - AS tab(a, b, c) - """ - - # +-----+----+----+ - # | a| b| c| - # +-----+----+----+ - # | true| 1|NULL| - # |false|NULL| 2.0| - # | NULL| 3| 3.0| - # +-----+----+----+ - - cdf = self.connect.sql(query) - sdf = self.spark.sql(query) - - # filter - self.assert_eq( - cdf[cdf.a].toPandas(), - sdf[sdf.a].toPandas(), - ) - self.assert_eq( - cdf[cdf.b.isin(2, 3)].toPandas(), - sdf[sdf.b.isin(2, 3)].toPandas(), - ) - self.assert_eq( - cdf[cdf.c > 1.5].toPandas(), - sdf[sdf.c > 1.5].toPandas(), - ) - - # select - self.assert_eq( - cdf[[cdf.a, "b", cdf.c]].toPandas(), - sdf[[sdf.a, "b", sdf.c]].toPandas(), - ) - self.assert_eq( - cdf[(cdf.a, "b", cdf.c)].toPandas(), - sdf[(sdf.a, "b", sdf.c)].toPandas(), - ) - - # select by index - self.assertTrue(isinstance(cdf[0], Column)) - self.assertTrue(isinstance(cdf[1], Column)) - self.assertTrue(isinstance(cdf[2], Column)) - - self.assert_eq( - cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(), - sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(), - ) - - # check error - with self.assertRaises(PySparkTypeError) as pe: - cdf[1.5] - - self.check_error( - exception=pe.exception, - errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", - messageParameters={ - "arg_name": "item", - "arg_type": "float", - }, - ) - - with self.assertRaises(PySparkTypeError) as pe: - cdf[None] - - self.check_error( - exception=pe.exception, - errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", - messageParameters={ - "arg_name": "item", - "arg_type": "NoneType", - }, - ) - - with 
self.assertRaises(PySparkTypeError) as pe: - cdf[cdf] - - self.check_error( - exception=pe.exception, - errorClass="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE", - messageParameters={ - "arg_name": "item", - "arg_type": "DataFrame", - }, - ) - - def test_join_condition_column_list_columns(self): - left_connect_df = self.connect.read.table(self.tbl_name) - right_connect_df = self.connect.read.table(self.tbl_name2) - left_spark_df = self.spark.read.table(self.tbl_name) - right_spark_df = self.spark.read.table(self.tbl_name2) - joined_plan = left_connect_df.join( - other=right_connect_df, on=left_connect_df.id == right_connect_df.col1, how="inner" - ) - joined_plan2 = left_spark_df.join(other=right_spark_df, on=left_spark_df.id == right_spark_df.col1, how="inner") - self.assert_eq(joined_plan.toPandas(), joined_plan2.toPandas()) - - joined_plan3 = left_connect_df.join( - other=right_connect_df, - on=[ - left_connect_df.id == right_connect_df.col1, - left_connect_df.name == right_connect_df.col2, - ], - how="inner", - ) - joined_plan4 = left_spark_df.join( - other=right_spark_df, - on=[left_spark_df.id == right_spark_df.col1, left_spark_df.name == right_spark_df.col2], - how="inner", - ) - self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas()) - - def test_join_ambiguous_cols(self): - # SPARK-41812: test join with ambiguous columns - data1 = [Row(id=1, value="foo"), Row(id=2, value=None)] - cdf1 = self.connect.createDataFrame(data1) - sdf1 = self.spark.createDataFrame(data1) - - data2 = [Row(value="bar"), Row(value=None), Row(value="foo")] - cdf2 = self.connect.createDataFrame(data2) - sdf2 = self.spark.createDataFrame(data2) - - cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]) - sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]) - - self.assertEqual(cdf3.schema, sdf3.schema) - self.assertEqual(cdf3.collect(), sdf3.collect()) - - cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"])) - sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"])) - - self.assertEqual(cdf4.schema, sdf4.schema) - self.assertEqual(cdf4.collect(), sdf4.collect()) - - cdf5 = cdf1.join(cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))) - sdf5 = sdf1.join(sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))) - - self.assertEqual(cdf5.schema, sdf5.schema) - self.assertEqual(cdf5.collect(), sdf5.collect()) - - cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value) - sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value) - - self.assertEqual(cdf6.schema, sdf6.schema) - self.assertEqual(cdf6.collect(), sdf6.collect()) - - cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value) - sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value) - - self.assertEqual(cdf7.schema, sdf7.schema) - self.assertEqual(cdf7.collect(), sdf7.collect()) - - def test_join_with_cte(self): - cte_query = "with dt as (select 1 as ida) select ida as id from dt" - - sdf1 = self.spark.range(10) - sdf2 = self.spark.sql(cte_query) - sdf3 = sdf1.join(sdf2, sdf1.id == sdf2.id) - - cdf1 = self.connect.range(10) - cdf2 = self.connect.sql(cte_query) - cdf3 = cdf1.join(cdf2, cdf1.id == cdf2.id) - - self.assertEqual(sdf3.schema, cdf3.schema) - self.assertEqual(sdf3.collect(), cdf3.collect()) - - def test_with_columns_renamed(self): - # SPARK-41312: test DataFrame.withColumnsRenamed() - self.assertEqual( - self.connect.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema, - 
self.spark.read.table(self.tbl_name).withColumnRenamed("id", "id_new").schema, - ) - self.assertEqual( - self.connect.read.table(self.tbl_name).withColumnsRenamed({"id": "id_new", "name": "name_new"}).schema, - self.spark.read.table(self.tbl_name).withColumnsRenamed({"id": "id_new", "name": "name_new"}).schema, - ) - - def test_simple_explain_string(self): - df = self.connect.read.table(self.tbl_name).limit(10) - result = df._explain_string() - self.assertGreater(len(result), 0) - - def _check_print_schema(self, query: str): - with io.StringIO() as buf, redirect_stdout(buf): - self.spark.sql(query).printSchema() - print1 = buf.getvalue() - with io.StringIO() as buf, redirect_stdout(buf): - self.connect.sql(query).printSchema() - print2 = buf.getvalue() - self.assertEqual(print1, print2, query) - - for level in [-1, 0, 1, 2, 3, 4]: - with io.StringIO() as buf, redirect_stdout(buf): - self.spark.sql(query).printSchema(level) - print1 = buf.getvalue() - with io.StringIO() as buf, redirect_stdout(buf): - self.connect.sql(query).printSchema(level) - print2 = buf.getvalue() - self.assertEqual(print1, print2, query) - - def test_schema(self): - schema = self.connect.read.table(self.tbl_name).schema - self.assertEqual( - StructType([StructField("id", LongType(), True), StructField("name", StringType(), True)]), - schema, - ) - - # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType - query = """ - SELECT * FROM VALUES - (float(1.0), double(1.0), 1.0, "1", true, NULL), - (float(2.0), double(2.0), 2.0, "2", false, NULL), - (float(3.0), double(3.0), NULL, "3", false, NULL) - AS tab(a, b, c, d, e, f) - """ - self.assertEqual( - self.spark.sql(query).schema, - self.connect.sql(query).schema, - ) - self._check_print_schema(query) - - # test TimestampType, DateType - query = """ - SELECT * FROM VALUES - (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')), - (TIMESTAMP('2019-04-12 15:50:00'), NULL), - (NULL, DATE('2022-02-22')) - AS tab(a, b) - """ - self.assertEqual( - self.spark.sql(query).schema, - self.connect.sql(query).schema, - ) - self._check_print_schema(query) - - # test DayTimeIntervalType - query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """ - self.assertEqual( - self.spark.sql(query).schema, - self.connect.sql(query).schema, - ) - self._check_print_schema(query) - - # test MapType - query = """ - SELECT * FROM VALUES - (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)), - (MAP('x', 'yz'), MAP('x', NULL), NULL), - (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4)) - AS tab(a, b, c) - """ - self.assertEqual( - self.spark.sql(query).schema, - self.connect.sql(query).schema, - ) - self._check_print_schema(query) - - # test ArrayType - query = """ - SELECT * FROM VALUES - (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)), - (ARRAY('x', NULL), NULL, ARRAY(1, 3)), - (NULL, ARRAY(-1, -2, -3), Array()) - AS tab(a, b, c) - """ - self.assertEqual( - self.spark.sql(query).schema, - self.connect.sql(query).schema, - ) - self._check_print_schema(query) - - # test StructType - query = """ - SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES - (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)), - (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)), - (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL) - AS tab(a, b, c, d, e, f, g, h) - """ - self.assertEqual( - self.spark.sql(query).schema, - self.connect.sql(query).schema, - ) - 
self._check_print_schema(query) - - def test_to(self): - # SPARK-41464: test DataFrame.to() - - cdf = self.connect.read.table(self.tbl_name) - df = self.spark.read.table(self.tbl_name) - - def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): - cdf_to = cdf.to(schema) - df_to = df.to(schema) - self.assertEqual(cdf_to.schema, df_to.schema) - self.assert_eq(cdf_to.toPandas(), df_to.toPandas()) - - # The schema has not changed - schema = StructType( - [ - StructField("id", IntegerType(), True), - StructField("name", StringType(), True), - ] - ) - - assert_eq_schema(cdf, df, schema) - - # Change schema with struct - schema2 = StructType([StructField("struct", schema, False)]) - - cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2) - df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2) - - self.assertEqual(cdf_to.schema, df_to.schema) - - # Change the column name - schema = StructType( - [ - StructField("col1", IntegerType(), True), - StructField("col2", StringType(), True), - ] - ) - - assert_eq_schema(cdf, df, schema) - - # Change the column data type - schema = StructType( - [ - StructField("id", StringType(), True), - StructField("name", StringType(), True), - ] - ) - - assert_eq_schema(cdf, df, schema) - - # Reduce the column quantity and change data type - schema = StructType( - [ - StructField("id", LongType(), True), - ] - ) - - assert_eq_schema(cdf, df, schema) - - # incompatible field nullability - schema = StructType([StructField("id", LongType(), False)]) - self.assertRaisesRegex( - AnalysisException, - "NULLABLE_COLUMN_OR_FIELD", - lambda: cdf.to(schema).toPandas(), - ) - - # field cannot upcast - schema = StructType([StructField("name", LongType())]) - self.assertRaisesRegex( - AnalysisException, - "INVALID_COLUMN_OR_FIELD_DATA_TYPE", - lambda: cdf.to(schema).toPandas(), - ) - - schema = StructType( - [ - StructField("id", IntegerType(), True), - StructField("name", IntegerType(), True), - ] - ) - self.assertRaisesRegex( - AnalysisException, - "INVALID_COLUMN_OR_FIELD_DATA_TYPE", - lambda: cdf.to(schema).toPandas(), - ) - - # Test map type and array type - schema = StructType( - [ - StructField("id", StringType(), True), - StructField("my_map", MapType(StringType(), IntegerType(), False), True), - StructField("my_array", ArrayType(IntegerType(), False), True), - ] - ) - cdf = self.connect.read.table(self.tbl_name4) - df = self.spark.read.table(self.tbl_name4) - - assert_eq_schema(cdf, df, schema) - - def test_toDF(self): - # SPARK-41310: test DataFrame.toDF() - self.assertEqual( - self.connect.read.table(self.tbl_name).toDF("col1", "col2").schema, - self.spark.read.table(self.tbl_name).toDF("col1", "col2").schema, - ) - - def test_print_schema(self): - # SPARK-41216: Test print schema - tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y").schema.treeString() - # root - # |-- X: integer (nullable = false) - # |-- Y: integer (nullable = false) - expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n" - self.assertEqual(tree_str, expected) - - def test_is_local(self): - # SPARK-41216: Test is local - self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal()) - self.assertFalse(self.connect.read.table(self.tbl_name).isLocal()) - - def test_is_streaming(self): - # SPARK-41216: Test is streaming - self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming) - self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming) - - def test_input_files(self): - # SPARK-41216: Test 
input files - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - try: - self.df_text.write.text(tmpPath) - - input_files_list1 = self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles() - input_files_list2 = self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles() - - self.assertTrue(len(input_files_list1) > 0) - self.assertEqual(len(input_files_list1), len(input_files_list2)) - for file_path in input_files_list2: - self.assertTrue(file_path in input_files_list1) - finally: - shutil.rmtree(tmpPath) - - def test_limit_offset(self): - df = self.connect.read.table(self.tbl_name) - pd = df.limit(10).offset(1).toPandas() - self.assertEqual(9, len(pd.index)) - pd2 = df.offset(98).limit(10).toPandas() - self.assertEqual(2, len(pd2.index)) - - def test_tail(self): - df = self.connect.read.table(self.tbl_name) - df2 = self.spark.read.table(self.tbl_name) - self.assertEqual(df.tail(10), df2.tail(10)) - - def test_sql(self): - pdf = self.connect.sql("SELECT 1").toPandas() - self.assertEqual(1, len(pdf.index)) - - def test_sql_with_named_args(self): - sqlText = "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id > :minId" - df = self.connect.sql(sqlText, args={"minId": 7, "m": CF.create_map(CF.lit("a"), CF.lit(1))}) - df2 = self.spark.sql(sqlText, args={"minId": 7, "m": SF.create_map(SF.lit("a"), SF.lit(1))}) - self.assert_eq(df.toPandas(), df2.toPandas()) - - def test_namedargs_with_global_limit(self): - sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val) - where val = :val""" - df = self.connect.sql(sqlText, args={"val": 1}) - df2 = self.spark.sql(sqlText, args={"val": 1}) - self.assert_eq(df.toPandas(), df2.toPandas()) - - self.assert_eq(df.first()[0], datetime.datetime(2022, 12, 25, 10, 30)) - self.assert_eq(df.first().date, datetime.datetime(2022, 12, 25, 10, 30)) - self.assert_eq(df.first()[1], 1) - self.assert_eq(df.first().val, 1) - - def test_sql_with_pos_args(self): - sqlText = "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?" - df = self.connect.sql(sqlText, args=[CF.array(CF.lit(1)), 7]) - df2 = self.spark.sql(sqlText, args=[SF.array(SF.lit(1)), 7]) - self.assert_eq(df.toPandas(), df2.toPandas()) - - def test_sql_with_invalid_args(self): - sqlText = "SELECT ?, ?, ?" - for session in [self.connect, self.spark]: - with self.assertRaises(PySparkTypeError) as pe: - session.sql(sqlText, args={1, 2, 3}) - - self.check_error( - exception=pe.exception, - errorClass="INVALID_TYPE", - messageParameters={"arg_name": "args", "arg_type": "set"}, - ) - - def test_deduplicate(self): - # SPARK-41326: test distinct and dropDuplicates. 
- df = self.connect.read.table(self.tbl_name) - df2 = self.spark.read.table(self.tbl_name) - self.assert_eq(df.distinct().toPandas(), df2.distinct().toPandas()) - self.assert_eq(df.dropDuplicates().toPandas(), df2.dropDuplicates().toPandas()) - self.assert_eq(df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas()) - - def test_drop(self): - # SPARK-41169: test drop - query = """ - SELECT * FROM VALUES - (false, 1, NULL), (false, NULL, 2), (NULL, 3, 3) - AS tab(a, b, c) - """ - - cdf = self.connect.sql(query) - sdf = self.spark.sql(query) - self.assert_eq( - cdf.drop("a").toPandas(), - sdf.drop("a").toPandas(), - ) - self.assert_eq( - cdf.drop("a", "b").toPandas(), - sdf.drop("a", "b").toPandas(), - ) - self.assert_eq( - cdf.drop("a", "x").toPandas(), - sdf.drop("a", "x").toPandas(), - ) - self.assert_eq( - cdf.drop(cdf.a, "x").toPandas(), - sdf.drop(sdf.a, "x").toPandas(), - ) - - def test_subquery_alias(self) -> None: - # SPARK-40938: test subquery alias. - plan_text = self.connect.read.table(self.tbl_name).alias("special_alias")._explain_string(extended=True) - self.assertTrue("special_alias" in plan_text) - - def test_sort(self): - # SPARK-41332: test sort - query = """ - SELECT * FROM VALUES - (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0) - AS tab(a, b, c) - """ - # +-----+----+----+ - # | a| b| c| - # +-----+----+----+ - # |false| 1|NULL| - # |false|NULL| 2.0| - # | NULL| 3| 3.0| - # +-----+----+----+ - - cdf = self.connect.sql(query) - sdf = self.spark.sql(query) - self.assert_eq( - cdf.sort("a").toPandas(), - sdf.sort("a").toPandas(), - ) - self.assert_eq( - cdf.sort("c").toPandas(), - sdf.sort("c").toPandas(), - ) - self.assert_eq( - cdf.sort("b").toPandas(), - sdf.sort("b").toPandas(), - ) - self.assert_eq( - cdf.sort(cdf.c, "b").toPandas(), - sdf.sort(sdf.c, "b").toPandas(), - ) - self.assert_eq( - cdf.sort(cdf.c.desc(), "b").toPandas(), - sdf.sort(sdf.c.desc(), "b").toPandas(), - ) - self.assert_eq( - cdf.sort(cdf.c.desc(), cdf.a.asc()).toPandas(), - sdf.sort(sdf.c.desc(), sdf.a.asc()).toPandas(), - ) - - def test_range(self): - self.assert_eq( - self.connect.range(start=0, end=10).toPandas(), - self.spark.range(start=0, end=10).toPandas(), - ) - self.assert_eq( - self.connect.range(start=0, end=10, step=3).toPandas(), - self.spark.range(start=0, end=10, step=3).toPandas(), - ) - self.assert_eq( - self.connect.range(start=0, end=10, step=3, numPartitions=2).toPandas(), - self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas(), - ) - # SPARK-41301 - self.assert_eq(self.connect.range(10).toPandas(), self.connect.range(start=0, end=10).toPandas()) - - def test_create_global_temp_view(self): - # SPARK-41127: test global temp view creation. - with self.tempView("view_1"): - self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") - self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1") - self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) - - # Test when creating a view which is already exists but - self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) - with self.assertRaises(AnalysisException): - self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") - - def test_create_session_local_temp_view(self): - # SPARK-41372: test session local temp view creation. 
- with self.tempView("view_local_temp"): - self.connect.sql("SELECT 1 AS X").createTempView("view_local_temp") - self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 1) - self.connect.sql("SELECT 1 AS X LIMIT 0").createOrReplaceTempView("view_local_temp") - self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0) - - # Test when creating a view which is already exists but - with self.assertRaises(AnalysisException): - self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp") - - def test_select_expr(self): - # SPARK-41201: test selectExpr API. - self.assert_eq( - self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(), - self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(), - ) - self.assert_eq( - self.connect.read.table(self.tbl_name).selectExpr(["id * 2", "cast(name as long) as name"]).toPandas(), - self.spark.read.table(self.tbl_name).selectExpr(["id * 2", "cast(name as long) as name"]).toPandas(), - ) - - self.assert_eq( - self.connect.read.table(self.tbl_name).selectExpr("id * 2", "cast(name as long) as name").toPandas(), - self.spark.read.table(self.tbl_name).selectExpr("id * 2", "cast(name as long) as name").toPandas(), - ) - - def test_select_star(self): - data = [Row(a=1, b=Row(c=2, d=Row(e=3)))] - - # +---+--------+ - # | a| b| - # +---+--------+ - # | 1|{2, {3}}| - # +---+--------+ - - cdf = self.connect.createDataFrame(data=data) - sdf = self.spark.createDataFrame(data=data) - - self.assertEqual( - cdf.select("*").collect(), - sdf.select("*").collect(), - ) - self.assertEqual( - cdf.select("a", "*").collect(), - sdf.select("a", "*").collect(), - ) - self.assertEqual( - cdf.select("a", "b").collect(), - sdf.select("a", "b").collect(), - ) - self.assertEqual( - cdf.select("a", "b.*").collect(), - sdf.select("a", "b.*").collect(), - ) - - def test_union_by_name(self): - # SPARK-41832: Test unionByName - data1 = [(1, 2, 3)] - data2 = [(6, 2, 5)] - df1_connect = self.connect.createDataFrame(data1, ["a", "b", "c"]) - df2_connect = self.connect.createDataFrame(data2, ["a", "b", "c"]) - union_df_connect = df1_connect.unionByName(df2_connect) - - df1_spark = self.spark.createDataFrame(data1, ["a", "b", "c"]) - df2_spark = self.spark.createDataFrame(data2, ["a", "b", "c"]) - union_df_spark = df1_spark.unionByName(df2_spark) - - self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas()) - - df2_connect = self.connect.createDataFrame(data2, ["a", "B", "C"]) - union_df_connect = df1_connect.unionByName(df2_connect, allowMissingColumns=True) - - df2_spark = self.spark.createDataFrame(data2, ["a", "B", "C"]) - union_df_spark = df1_spark.unionByName(df2_spark, allowMissingColumns=True) - - self.assert_eq(union_df_connect.toPandas(), union_df_spark.toPandas()) - - def test_observe(self): - # SPARK-41527: test DataFrame.observe() - observation_name = "my_metric" - - self.assert_eq( - self.connect.read.table(self.tbl_name) - .filter("id > 3") - .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id")) - .toPandas(), - self.spark.read.table(self.tbl_name) - .filter("id > 3") - .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id")) - .toPandas(), - ) - - from pyspark.sql.connect.observation import Observation as ConnectObservation - from pyspark.sql.observation import Observation - - cobservation = ConnectObservation(observation_name) - observation = Observation(observation_name) - - cdf = ( - self.connect.read.table(self.tbl_name) - .filter("id > 3") - 
.observe(cobservation, CF.min("id"), CF.max("id"), CF.sum("id")) - .toPandas() - ) - df = ( - self.spark.read.table(self.tbl_name) - .filter("id > 3") - .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id")) - .toPandas() - ) - - self.assert_eq(cdf, df) - - self.assertEqual(cobservation.get, observation.get) - - observed_metrics = cdf.attrs["observed_metrics"] - self.assert_eq(len(observed_metrics), 1) - self.assert_eq(observed_metrics[0].name, observation_name) - self.assert_eq(len(observed_metrics[0].metrics), 3) - for metric in observed_metrics[0].metrics: - self.assertIsInstance(metric, ProtoExpression.Literal) - values = list(map(lambda metric: metric.long, observed_metrics[0].metrics)) - self.assert_eq(values, [4, 99, 4944]) - - with self.assertRaises(PySparkValueError) as pe: - self.connect.read.table(self.tbl_name).observe(observation_name) - - self.check_error( - exception=pe.exception, - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "exprs"}, - ) - - with self.assertRaises(PySparkTypeError) as pe: - self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id") - - self.check_error( - exception=pe.exception, - errorClass="NOT_LIST_OF_COLUMN", - messageParameters={"arg_name": "exprs"}, - ) - - def test_with_columns(self): - # SPARK-41256: test withColumn(s). - self.assert_eq( - self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(), - self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(), - ) - - self.assert_eq( - self.connect.read.table(self.tbl_name) - .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)}) - .toPandas(), - self.spark.read.table(self.tbl_name) - .withColumns( - { - "id": SF.lit(False), - "col_not_exist": SF.lit(False), - } - ) - .toPandas(), - ) - - def test_hint(self): - # SPARK-41349: Test hint - self.assert_eq( - self.connect.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(), - self.spark.read.table(self.tbl_name).hint("COALESCE", 3000).toPandas(), - ) - - # Hint with unsupported name will be ignored - self.assert_eq( - self.connect.read.table(self.tbl_name).hint("illegal").toPandas(), - self.spark.read.table(self.tbl_name).hint("illegal").toPandas(), - ) - - # Hint with all supported parameter values - such_a_nice_list = ["itworks1", "itworks2", "itworks3"] - self.assert_eq( - self.connect.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(), - self.spark.read.table(self.tbl_name).hint("my awesome hint", 1.2345, 2).toPandas(), - ) - - # Hint with unsupported parameter values - with self.assertRaises(AnalysisException): - self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas() - - # Hint with unsupported parameter types - with self.assertRaises(TypeError): - self.connect.read.table(self.tbl_name).hint("REPARTITION", range(5)).toPandas() - - # Hint with unsupported parameter types - with self.assertRaises(TypeError): - self.connect.read.table(self.tbl_name).hint( - "my awesome hint", 1.2345, 2, such_a_nice_list, range(6) - ).toPandas() - - # Hint with wrong combination - with self.assertRaises(AnalysisException): - self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas() - - def test_join_hint(self): - cdf1 = self.connect.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - cdf2 = self.connect.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) - - self.assertTrue("BroadcastHashJoin" in cdf1.join(cdf2.hint("BROADCAST"), "name")._explain_string()) 
- self.assertTrue("SortMergeJoin" in cdf1.join(cdf2.hint("MERGE"), "name")._explain_string()) - self.assertTrue("ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), "name")._explain_string()) - - def test_extended_hint_types(self): - cdf = self.connect.range(100).toDF("id") - - cdf.hint( - "my awesome hint", - 1.2345, - "what", - ["itworks1", "itworks2", "itworks3"], - ).show() - - with self.assertRaises(PySparkTypeError) as pe: - cdf.hint( - "my awesome hint", - 1.2345, - "what", - {"itworks1": "itworks2"}, - ).show() - - self.check_error( - exception=pe.exception, - errorClass="INVALID_ITEM_FOR_CONTAINER", - messageParameters={ - "arg_name": "parameters", - "allowed_types": "str, float, int, Column, list[str], list[float], list[int]", - "item_type": "dict", - }, - ) - - def test_empty_dataset(self): - # SPARK-41005: Test arrow based collection with empty dataset. - self.assertTrue( - self.connect.sql("SELECT 1 AS X LIMIT 0") - .toPandas() - .equals(self.spark.sql("SELECT 1 AS X LIMIT 0").toPandas()) - ) - pdf = self.connect.sql("SELECT 1 AS X LIMIT 0").toPandas() - self.assertEqual(0, len(pdf)) # empty dataset - self.assertEqual(1, len(pdf.columns)) # one column - self.assertEqual("X", pdf.columns[0]) - - def test_is_empty(self): - # SPARK-41212: Test is empty - self.assertFalse(self.connect.sql("SELECT 1 AS X").isEmpty()) - self.assertTrue(self.connect.sql("SELECT 1 AS X LIMIT 0").isEmpty()) - - def test_is_empty_with_unsupported_types(self): - df = self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval") - self.assertEqual(df.count(), 1) - self.assertFalse(df.isEmpty()) - - def test_session(self): - self.assertEqual(self.connect, self.connect.sql("SELECT 1").sparkSession) - - def test_show(self): - # SPARK-41111: Test the show method - show_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._show_string() - # +---+---+ - # | X| Y| - # +---+---+ - # | 1| 2| - # +---+---+ - expected = "+---+---+\n| X| Y|\n+---+---+\n| 1| 2|\n+---+---+\n" - self.assertEqual(show_str, expected) - - def test_repr(self): - # SPARK-41213: Test the __repr__ method - query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)""" - self.assertEqual( - self.connect.sql(query).__repr__(), - self.spark.sql(query).__repr__(), - ) - - def test_explain_string(self): - # SPARK-41122: test explain API. - plan_str = self.connect.sql("SELECT 1")._explain_string(extended=True) - self.assertTrue("Parsed Logical Plan" in plan_str) - self.assertTrue("Analyzed Logical Plan" in plan_str) - self.assertTrue("Optimized Logical Plan" in plan_str) - self.assertTrue("Physical Plan" in plan_str) - - with self.assertRaises(PySparkValueError) as pe: - self.connect.sql("SELECT 1")._explain_string(mode="unknown") - self.check_error( - exception=pe.exception, - errorClass="UNKNOWN_EXPLAIN_MODE", - messageParameters={"explain_mode": "unknown"}, - ) - - def test_count(self) -> None: - # SPARK-41308: test count() API. - self.assertEqual( - self.connect.read.table(self.tbl_name).count(), - self.spark.read.table(self.tbl_name).count(), - ) - - def test_simple_transform(self) -> None: - """SPARK-41203: Support DF.transform""" - - def transform_df(input_df: CDataFrame) -> CDataFrame: - return input_df.select((CF.col("id") + CF.lit(10)).alias("id")) - - df = self.connect.range(1, 100) - result_left = df.transform(transform_df).collect() - result_right = self.connect.range(11, 110).collect() - self.assertEqual(result_right, result_left) - - # Check assertion. 
- with self.assertRaises(AssertionError): - df.transform(lambda x: 2) # type: ignore - - def test_alias(self) -> None: - """Testing supported and unsupported alias""" - col0 = self.connect.range(1, 10).select(CF.col("id").alias("name", metadata={"max": 99})).schema.names[0] - self.assertEqual("name", col0) - - with self.assertRaises(SparkConnectException) as exc: - self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect() - self.assertIn("(this, is, not)", str(exc.exception)) - - def test_column_regexp(self) -> None: - # SPARK-41438: test dataframe.colRegex() - ndf = self.connect.read.table(self.tbl_name3) - df = self.spark.read.table(self.tbl_name3) - - self.assert_eq( - ndf.select(ndf.colRegex("`tes.*\n.*mn`")).toPandas(), - df.select(df.colRegex("`tes.*\n.*mn`")).toPandas(), - ) - - def test_repartition(self) -> None: - # SPARK-41354: test dataframe.repartition(numPartitions) - self.assert_eq( - self.connect.read.table(self.tbl_name).repartition(10).toPandas(), - self.spark.read.table(self.tbl_name).repartition(10).toPandas(), - ) - - self.assert_eq( - self.connect.read.table(self.tbl_name).coalesce(10).toPandas(), - self.spark.read.table(self.tbl_name).coalesce(10).toPandas(), - ) - - def test_repartition_by_expression(self) -> None: - # SPARK-41354: test dataframe.repartition(expressions) - self.assert_eq( - self.connect.read.table(self.tbl_name).repartition(10, "id").toPandas(), - self.spark.read.table(self.tbl_name).repartition(10, "id").toPandas(), - ) - - self.assert_eq( - self.connect.read.table(self.tbl_name).repartition("id").toPandas(), - self.spark.read.table(self.tbl_name).repartition("id").toPandas(), - ) - - # repartition with unsupported parameter values - with self.assertRaises(AnalysisException): - self.connect.read.table(self.tbl_name).repartition("id+1").toPandas() - - def test_repartition_by_range(self) -> None: - # SPARK-41354: test dataframe.repartitionByRange(expressions) - cdf = self.connect.read.table(self.tbl_name) - sdf = self.spark.read.table(self.tbl_name) - - self.assert_eq( - cdf.repartitionByRange(10, "id").toPandas(), - sdf.repartitionByRange(10, "id").toPandas(), - ) - - self.assert_eq( - cdf.repartitionByRange("id").toPandas(), - sdf.repartitionByRange("id").toPandas(), - ) - - self.assert_eq( - cdf.repartitionByRange(cdf.id.desc()).toPandas(), - sdf.repartitionByRange(sdf.id.desc()).toPandas(), - ) - - # repartitionByRange with unsupported parameter values - with self.assertRaises(AnalysisException): - self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas() - - def test_crossjoin(self): - # SPARK-41227: Test CrossJoin - connect_df = self.connect.read.table(self.tbl_name) - spark_df = self.spark.read.table(self.tbl_name) - self.assert_eq( - set(connect_df.select("id").join(other=connect_df.select("name"), how="cross").toPandas()), - set(spark_df.select("id").join(other=spark_df.select("name"), how="cross").toPandas()), - ) - self.assert_eq( - set(connect_df.select("id").crossJoin(other=connect_df.select("name")).toPandas()), - set(spark_df.select("id").crossJoin(other=spark_df.select("name")).toPandas()), - ) - - def test_self_join(self): - # SPARK-47713: this query fails in classic spark - df1 = self.connect.createDataFrame([(1, "a")], schema=["i", "j"]) - df1_filter = df1.filter(df1.i > 0) - df2 = df1.join(df1_filter, df1.i == 1) - self.assertEqual(df2.count(), 1) - self.assertEqual(df2.columns, ["i", "j", "i", "j"]) - self.assertEqual(list(df2.first()), [1, "a", 1, "a"]) - - def 
test_with_metadata(self): - cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - self.assertEqual(cdf.schema["age"].metadata, {}) - self.assertEqual(cdf.schema["name"].metadata, {}) - - cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5}) - self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5}) - - cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]}) - self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]}) - - with self.assertRaises(PySparkTypeError) as pe: - cdf.withMetadata(columnName="name", metadata=["magic"]) - - self.check_error( - exception=pe.exception, - errorClass="NOT_DICT", - messageParameters={ - "arg_name": "metadata", - "arg_type": "list", - }, - ) - - def test_version(self): - self.assertEqual( - self.connect.version, - self.spark.version, - ) - - def test_same_semantics(self): - plan = self.connect.sql("SELECT 1") - other = self.connect.sql("SELECT 1") - self.assertTrue(plan.sameSemantics(other)) - - def test_semantic_hash(self): - plan = self.connect.sql("SELECT 1") - other = self.connect.sql("SELECT 1") - self.assertEqual( - plan.semanticHash(), - other.semanticHash(), - ) - - def test_sql_with_command(self): - # SPARK-42705: spark.sql should return values from the command. - self.assertEqual(self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()) - - def test_df_caache(self): - df = self.connect.range(10) - df.cache() - self.assert_eq(10, df.count()) - self.assertTrue(df.is_cached) - - def test_parse_col_name(self): - from pyspark.sql.connect.types import parse_attr_name - - self.assert_eq(parse_attr_name(""), [""]) - - self.assert_eq(parse_attr_name("a"), ["a"]) - self.assert_eq(parse_attr_name("`a`"), ["a"]) - self.assert_eq(parse_attr_name("`a"), None) - self.assert_eq(parse_attr_name("a`"), None) - - self.assert_eq(parse_attr_name("`a`.b"), ["a", "b"]) - self.assert_eq(parse_attr_name("`a`.`b`"), ["a", "b"]) - self.assert_eq(parse_attr_name("`a```.b"), ["a`", "b"]) - self.assert_eq(parse_attr_name("`a``.b"), None) - - self.assert_eq(parse_attr_name("a.b.c"), ["a", "b", "c"]) - self.assert_eq(parse_attr_name("`a`.`b`.`c`"), ["a", "b", "c"]) - self.assert_eq(parse_attr_name("a.`b`.c"), ["a", "b", "c"]) - - self.assert_eq(parse_attr_name("`a.b.c`"), ["a.b.c"]) - self.assert_eq(parse_attr_name("a.`b.c`"), ["a", "b.c"]) - self.assert_eq(parse_attr_name("`a.b`.c"), ["a.b", "c"]) - self.assert_eq(parse_attr_name("`a.b.c"), None) - self.assert_eq(parse_attr_name("a.b.c`"), None) - self.assert_eq(parse_attr_name("`a.`b.`c"), None) - self.assert_eq(parse_attr_name("a`.b`.c`"), None) - - self.assert_eq(parse_attr_name("`ab..c`e.f"), None) - - def test_verify_col_name(self): - from pyspark.sql.connect.types import verify_col_name - - cdf = ( - self.connect.range(10) - .withColumn("v", CF.lit(123)) - .withColumn("s", CF.struct("id", "v")) - .withColumn("m", CF.struct("s", "v")) - .withColumn("a", CF.array("s")) - ) - - # root - # |-- id: long (nullable = false) - # |-- v: integer (nullable = false) - # |-- s: struct (nullable = false) - # | |-- id: long (nullable = false) - # | |-- v: integer (nullable = false) - # |-- m: struct (nullable = false) - # | |-- s: struct (nullable = false) - # | | |-- id: long (nullable = false) - # | | |-- v: integer (nullable = false) - # | |-- v: integer (nullable = false) - # |-- a: array (nullable = false) - # | |-- element: struct (containsNull = false) - # | | |-- id: long (nullable = false) - # | | 
|-- v: integer (nullable = false) - - self.assertTrue(verify_col_name("id", cdf.schema)) - self.assertTrue(verify_col_name("`id`", cdf.schema)) - - self.assertTrue(verify_col_name("v", cdf.schema)) - self.assertTrue(verify_col_name("`v`", cdf.schema)) - - self.assertFalse(verify_col_name("x", cdf.schema)) - self.assertFalse(verify_col_name("`x`", cdf.schema)) - - self.assertTrue(verify_col_name("s", cdf.schema)) - self.assertTrue(verify_col_name("`s`", cdf.schema)) - self.assertTrue(verify_col_name("s.id", cdf.schema)) - self.assertTrue(verify_col_name("s.`id`", cdf.schema)) - self.assertTrue(verify_col_name("`s`.id", cdf.schema)) - self.assertTrue(verify_col_name("`s`.`id`", cdf.schema)) - self.assertFalse(verify_col_name("`s.id`", cdf.schema)) - - self.assertTrue(verify_col_name("m", cdf.schema)) - self.assertTrue(verify_col_name("`m`", cdf.schema)) - self.assertTrue(verify_col_name("m.s.id", cdf.schema)) - self.assertTrue(verify_col_name("m.s.`id`", cdf.schema)) - self.assertTrue(verify_col_name("m.`s`.id", cdf.schema)) - self.assertTrue(verify_col_name("`m`.`s`.`id`", cdf.schema)) - self.assertFalse(verify_col_name("m.`s.id`", cdf.schema)) - - self.assertTrue(verify_col_name("a", cdf.schema)) - self.assertTrue(verify_col_name("`a`", cdf.schema)) - self.assertTrue(verify_col_name("a.`v`", cdf.schema)) - self.assertTrue(verify_col_name("a.`v`", cdf.schema)) - self.assertTrue(verify_col_name("`a`.v", cdf.schema)) - self.assertTrue(verify_col_name("`a`.`v`", cdf.schema)) - self.assertFalse(verify_col_name("`a`.`x`", cdf.schema)) - - cdf = ( - self.connect.range(10) - .withColumn("v", CF.lit(123)) - .withColumn("s.s", CF.struct("id", "v")) - .withColumn("m`", CF.struct("`s.s`", "v")) - ) - - # root - # |-- id: long (nullable = false) - # |-- v: string (nullable = false) - # |-- s.s: struct (nullable = false) - # | |-- id: long (nullable = false) - # | |-- v: string (nullable = false) - # |-- m`: struct (nullable = false) - # | |-- s.s: struct (nullable = false) - # | | |-- id: long (nullable = false) - # | | |-- v: string (nullable = false) - # | |-- v: string (nullable = false) - - self.assertFalse(verify_col_name("s", cdf.schema)) - self.assertFalse(verify_col_name("`s`", cdf.schema)) - self.assertFalse(verify_col_name("s.s", cdf.schema)) - self.assertFalse(verify_col_name("s.`s`", cdf.schema)) - self.assertFalse(verify_col_name("`s`.s", cdf.schema)) - self.assertTrue(verify_col_name("`s.s`", cdf.schema)) - - self.assertFalse(verify_col_name("m", cdf.schema)) - self.assertFalse(verify_col_name("`m`", cdf.schema)) - self.assertTrue(verify_col_name("`m```", cdf.schema)) - - self.assertFalse(verify_col_name("`m```.s", cdf.schema)) - self.assertFalse(verify_col_name("`m```.`s`", cdf.schema)) - self.assertFalse(verify_col_name("`m```.s.s", cdf.schema)) - self.assertFalse(verify_col_name("`m```.s.`s`", cdf.schema)) - self.assertTrue(verify_col_name("`m```.`s.s`", cdf.schema)) - - self.assertFalse(verify_col_name("`m```.s.s.v", cdf.schema)) - self.assertFalse(verify_col_name("`m```.s.`s`.v", cdf.schema)) - self.assertTrue(verify_col_name("`m```.`s.s`.v", cdf.schema)) - self.assertTrue(verify_col_name("`m```.`s.s`.`v`", cdf.schema)) - - def test_truncate_message(self): - cdf1 = self.connect.createDataFrame( - [ - ("a B c"), - ("X y Z"), - ], - ["a" * 4096], - ) - plan1 = cdf1._plan.to_proto(self.connect._client) - - proto_string_1 = self.connect._client._proto_to_string(plan1, False) - self.assertTrue(len(proto_string_1) > 10000, len(proto_string_1)) - proto_string_truncated_1 = 
self.connect._client._proto_to_string(plan1, True) - self.assertTrue(len(proto_string_truncated_1) < 4000, len(proto_string_truncated_1)) - - cdf2 = cdf1.select("a" * 4096, "a" * 4096, "a" * 4096) - plan2 = cdf2._plan.to_proto(self.connect._client) - - proto_string_2 = self.connect._client._proto_to_string(plan2, False) - self.assertTrue(len(proto_string_2) > 20000, len(proto_string_2)) - proto_string_truncated_2 = self.connect._client._proto_to_string(plan2, True) - self.assertTrue(len(proto_string_truncated_2) < 8000, len(proto_string_truncated_2)) - - cdf3 = cdf1.select("a" * 4096) - for _ in range(64): - cdf3 = cdf3.select("a" * 4096) - plan3 = cdf3._plan.to_proto(self.connect._client) - - proto_string_3 = self.connect._client._proto_to_string(plan3, False) - self.assertTrue(len(proto_string_3) > 128000, len(proto_string_3)) - proto_string_truncated_3 = self.connect._client._proto_to_string(plan3, True) - self.assertTrue(len(proto_string_truncated_3) < 64000, len(proto_string_truncated_3)) - - -class SparkConnectGCTests(SparkConnectSQLTestCase): - @classmethod - def setUpClass(cls): - cls.origin = os.getenv("USER", None) - os.environ["USER"] = "SparkConnectGCTests" - super(SparkConnectGCTests, cls).setUpClass() - - @classmethod - def tearDownClass(cls): - super(SparkConnectGCTests, cls).tearDownClass() - if cls.origin is not None: - os.environ["USER"] = cls.origin - else: - del os.environ["USER"] - - def test_garbage_collection_checkpoint(self): - # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state - # in Spark Connect server - df = self.connect.range(10).localCheckpoint() - self.assertIsNotNone(df._plan._relation_id) - cached_remote_relation_id = df._plan._relation_id - - jvm = self.spark._jvm - session_holder = getattr( - getattr( - jvm.org.apache.spark.sql.connect.service, - "SparkConnectService$", - ), - "MODULE$", - ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id) - - # Check the state exists. - self.assertIsNotNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) - - del df - gc.collect() - - def condition(): - # Check the state was removed up on garbage-collection. - self.assertIsNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) - - eventually(catch_assertions=True)(condition)() - - def test_garbage_collection_derived_checkpoint(self): - # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist - df = self.connect.range(10).localCheckpoint() - self.assertIsNotNone(df._plan._relation_id) - derived = df.repartition(10) - cached_remote_relation_id = df._plan._relation_id - - jvm = self.spark._jvm - session_holder = getattr( - getattr( - jvm.org.apache.spark.sql.connect.service, - "SparkConnectService$", - ), - "MODULE$", - ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id) - - # Check the state exists. 
- self.assertIsNotNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) - - del df - gc.collect() - - def condition(): - self.assertIsNone(session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)) - - # Should not remove the cache - with self.assertRaises(AssertionError): - eventually(catch_assertions=True, timeout=5)(condition)() - - del derived - gc.collect() - - eventually(catch_assertions=True)(condition)() - - -if __name__ == "__main__": - from pyspark.sql.tests.connect.test_connect_basic import * - - try: - import xmlrunner - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/tests/connect/test_connect.py b/tests/connect/test_parquet_simple.py similarity index 98% rename from tests/connect/test_connect.py rename to tests/connect/test_parquet_simple.py index f6afa28876..3b3ddf42d0 100644 --- a/tests/connect/test_connect.py +++ b/tests/connect/test_parquet_simple.py @@ -14,7 +14,7 @@ def test_simple(): print("Starting Daft-Connect server") - # connect_start("sc://localhost:50052") + connect_start("sc://localhost:50052") print("Created spark connect server") diff --git a/tests/connect/test_session.py b/tests/connect/test_session.py deleted file mode 100644 index 6fdea1e7b4..0000000000 --- a/tests/connect/test_session.py +++ /dev/null @@ -1,265 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import threading -import time -import unittest -from typing import Optional - -from pyspark import InheritableThread, inheritable_thread_target -from pyspark.sql.connect.client import DefaultChannelBuilder -from pyspark.sql.connect.session import SparkSession as RemoteSparkSession -from pyspark.testing.connectutils import should_test_connect - -if should_test_connect: - from pyspark.testing.connectutils import ReusedConnectTestCase - - -class CustomChannelBuilder(DefaultChannelBuilder): - @property - def userId(self) -> Optional[str]: - return "abc" - - -class SparkSessionTestCase(unittest.TestCase): - def test_fails_to_create_session_without_remote_and_channel_builder(self): - with self.assertRaises(ValueError): - RemoteSparkSession.builder.getOrCreate() - - def test_fails_to_create_when_both_remote_and_channel_builder_are_specified(self): - with self.assertRaises(ValueError): - ( - RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder("sc://localhost")) - .remote("sc://localhost") - .getOrCreate() - ) - - def test_creates_session_with_channel_builder(self): - test_session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder("sc://other")).getOrCreate() - host = test_session.client.host - test_session.stop() - - self.assertEqual("other", host) - - def test_creates_session_with_remote(self): - test_session = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - host = test_session.client.host - test_session.stop() - - self.assertEqual("other", host) - - def test_session_stop(self): - session = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - - self.assertFalse(session.is_stopped) - session.stop() - self.assertTrue(session.is_stopped) - - def test_session_create_sets_active_session(self): - session = RemoteSparkSession.builder.remote("sc://abc").create() - session2 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - - self.assertIs(session, session2) - session.stop() - - def test_active_session_expires_when_client_closes(self): - s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - s2 = RemoteSparkSession.getActiveSession() - - self.assertIs(s1, s2) - - # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator - s1._client._closed = True - - self.assertIsNone(RemoteSparkSession.getActiveSession()) - s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - - self.assertIsNot(s1, s3) - - def test_default_session_expires_when_client_closes(self): - s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - s2 = RemoteSparkSession.getDefaultSession() - - self.assertIs(s1, s2) - - # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator - s1._client._closed = True - - self.assertIsNone(RemoteSparkSession.getDefaultSession()) - s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate() - - self.assertIsNot(s1, s3) - - -class JobCancellationTests(ReusedConnectTestCase): - def test_tags(self): - self.spark.clearTags() - self.spark.addTag("a") - self.assertEqual(self.spark.getTags(), {"a"}) - self.spark.addTag("b") - self.spark.removeTag("a") - self.assertEqual(self.spark.getTags(), {"b"}) - self.spark.addTag("c") - self.spark.clearTags() - self.assertEqual(self.spark.getTags(), set()) - self.spark.clearTags() - - def test_tags_multithread(self): - output1 = None - output2 = None - - def tag1(): - nonlocal output1 - - self.spark.addTag("tag1") - output1 = self.spark.getTags() - - def tag2(): - nonlocal output2 - - self.spark.addTag("tag2") - 
output2 = self.spark.getTags() - - t1 = threading.Thread(target=tag1) - t1.start() - t1.join() - t2 = threading.Thread(target=tag2) - t2.start() - t2.join() - - self.assertIsNotNone(output1) - self.assertEqual(output1, {"tag1"}) - self.assertIsNotNone(output2) - self.assertEqual(output2, {"tag2"}) - - def test_interrupt_tag(self): - thread_ids = range(4) - self.check_job_cancellation( - lambda job_group: self.spark.addTag(job_group), - lambda job_group: self.spark.interruptTag(job_group), - thread_ids, - [i for i in thread_ids if i % 2 == 0], - [i for i in thread_ids if i % 2 != 0], - ) - self.spark.clearTags() - - def test_interrupt_all(self): - thread_ids = range(4) - self.check_job_cancellation( - lambda job_group: None, - lambda job_group: self.spark.interruptAll(), - thread_ids, - thread_ids, - [], - ) - self.spark.clearTags() - - def check_job_cancellation(self, setter, canceller, thread_ids, thread_ids_to_cancel, thread_ids_to_run): - job_id_a = "job_ids_to_cancel" - job_id_b = "job_ids_to_run" - threads = [] - - # A list which records whether job is cancelled. - # The index of the array is the thread index which job run in. - is_job_cancelled = [False for _ in thread_ids] - - def run_job(job_id, index): - """ - Executes a job with the group ``job_group``. Each job waits for 3 seconds - and then exits. - """ - try: - setter(job_id) - - def func(itr): - for pdf in itr: - time.sleep(pdf._1.iloc[0]) - yield pdf - - self.spark.createDataFrame([[20]]).repartition(1).mapInPandas(func, schema="_1 LONG").collect() - is_job_cancelled[index] = False - except Exception: - # Assume that exception means job cancellation. - is_job_cancelled[index] = True - - # Test if job succeeded when not cancelled. - run_job(job_id_a, 0) - self.assertFalse(is_job_cancelled[0]) - self.spark.clearTags() - - # Run jobs - for i in thread_ids_to_cancel: - t = threading.Thread(target=run_job, args=(job_id_a, i)) - t.start() - threads.append(t) - - for i in thread_ids_to_run: - t = threading.Thread(target=run_job, args=(job_id_b, i)) - t.start() - threads.append(t) - - # Wait to make sure all jobs are executed. - time.sleep(10) - # And then, cancel one job group. - canceller(job_id_a) - - # Wait until all threads launching jobs are finished. 
- for t in threads: - t.join() - - for i in thread_ids_to_cancel: - self.assertTrue(is_job_cancelled[i], f"Thread {i}: Job in group A was not cancelled.") - - for i in thread_ids_to_run: - self.assertFalse(is_job_cancelled[i], f"Thread {i}: Job in group B did not succeeded.") - - def test_inheritable_tags(self): - self.check_inheritable_tags(create_thread=lambda target, session: InheritableThread(target, session=session)) - self.check_inheritable_tags( - create_thread=lambda target, session: threading.Thread(target=inheritable_thread_target(session)(target)) - ) - - # Test decorator usage - @inheritable_thread_target(self.spark) - def func(target): - return target() - - self.check_inheritable_tags(create_thread=lambda target, session: threading.Thread(target=func, args=(target,))) - - def check_inheritable_tags(self, create_thread): - spark = self.spark - spark.addTag("a") - first = set() - second = set() - - def get_inner_local_prop(): - spark.addTag("c") - second.update(spark.getTags()) - - def get_outer_local_prop(): - spark.addTag("b") - first.update(spark.getTags()) - t2 = create_thread(target=get_inner_local_prop, session=spark) - t2.start() - t2.join() - - t1 = create_thread(target=get_outer_local_prop, session=spark) - t1.start() - t1.join() - - self.assertEqual(spark.getTags(), {"a"}) - self.assertEqual(first, {"a", "b"}) - self.assertEqual(second, {"a", "b", "c"}) From a1be27548f91ecb95c31045dc3085c0f208fe506 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 6 Nov 2024 14:02:21 -0800 Subject: [PATCH 3/6] feat(connect): Add Python connect_start binding and remove RON dependency - Add connect_start method binding to Python interface - Remove RON dependency and related code - Register connect modules in Python bindings - Clean up unused RON serialization config code This change exposes the connect_start functionality to Python while simplifying dependencies by removing the unused RON serialization library. --- Cargo.lock | 17 ----------------- daft/daft/__init__.pyi | 2 ++ src/daft-connect/Cargo.toml | 2 +- src/daft-connect/src/lib.rs | 13 +------------ src/lib.rs | 1 + tests/connect/test_parquet_simple.py | 3 ++- 6 files changed, 7 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 645373c596..57c2a15b90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1054,9 +1054,6 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" -dependencies = [ - "serde", -] [[package]] name = "block-buffer" @@ -1904,7 +1901,6 @@ dependencies = [ "eyre", "futures", "pyo3", - "ron", "spark-connect", "tempfile", "tokio", @@ -5227,19 +5223,6 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" -[[package]] -name = "ron" -version = "0.9.0-alpha.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c0bd893640cac34097a74f0c2389ddd54c62d6a3c635fa93cafe6b6bc19be6a" -dependencies = [ - "base64 0.21.7", - "bitflags 2.6.0", - "serde", - "serde_derive", - "unicode-ident", -] - [[package]] name = "rstest" version = "0.18.2" diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6dad8b6f56..8917eeb2cc 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1232,6 +1232,8 @@ def list_sql_functions() -> list[SQLFunctionStub]: ... 
def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... +def connect_start(addr: str) -> None: ... + # expr numeric ops def abs(expr: PyExpr) -> PyExpr: ... def cbrt(expr: PyExpr) -> PyExpr: ... diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index e82be71bbc..bd9788ba76 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -4,7 +4,7 @@ dashmap = "6.1.0" eyre = "0.6.12" futures = "0.3.31" pyo3 = {workspace = true, optional = true} -ron = "0.9.0-alpha.0" +#ron = "0.9.0-alpha.0" tokio = {version = "1.40.0", features = ["full"]} tokio-stream = "0.1.16" tonic = "0.12.3" diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 8e021cd66d..8ae14c3a4e 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -12,7 +12,6 @@ use dashmap::DashMap; use eyre::Context; #[cfg(feature = "python")] use pyo3::types::PyModuleMethods; -use ron::extensions::Extensions; use spark_connect::{ analyze_plan_response, command::CommandType, @@ -92,16 +91,6 @@ impl DaftSparkConnectService { } } -fn pretty_config() -> ron::ser::PrettyConfig { - ron::ser::PrettyConfig::default() - .extensions( - Extensions::IMPLICIT_SOME - | Extensions::UNWRAP_NEWTYPES - | Extensions::UNWRAP_VARIANT_NEWTYPES, - ) - .indentor(" ".to_string()) -} - #[tonic::async_trait] impl SparkConnectService for DaftSparkConnectService { type ExecutePlanStream = std::pin::Pin< @@ -382,7 +371,7 @@ impl SparkConnectService for DaftSparkConnectService { #[tracing::instrument(skip_all)] async fn reattach_execute( &self, - request: Request, + _request: Request, ) -> Result, Status> { warn!("reattach_execute operation is not yet implemented"); diff --git a/src/lib.rs b/src/lib.rs index a7a1382538..201f6b229f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,6 +118,7 @@ pub mod pylib { daft_sql::register_modules(m)?; daft_functions::register_modules(m)?; daft_functions_json::register_modules(m)?; + daft_connect::register_modules(m)?; m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(build_type))?; diff --git a/tests/connect/test_parquet_simple.py b/tests/connect/test_parquet_simple.py index 3b3ddf42d0..1c3bf26eb0 100644 --- a/tests/connect/test_parquet_simple.py +++ b/tests/connect/test_parquet_simple.py @@ -10,11 +10,12 @@ from pyspark.sql import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import col +from daft.daft import connect_start def test_simple(): print("Starting Daft-Connect server") - connect_start("sc://localhost:50052") + # connect_start("sc://localhost:50051") print("Created spark connect server") From 5f84fe1b763fe78c2af51d0b6b543ac4fbddcf96 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 6 Nov 2024 14:16:02 -0800 Subject: [PATCH 4/6] minify test --- tests/connect/test_parquet_simple.py | 123 +++++---------------------- 1 file changed, 19 insertions(+), 104 deletions(-) diff --git a/tests/connect/test_parquet_simple.py b/tests/connect/test_parquet_simple.py index 1c3bf26eb0..b1e99778bf 100644 --- a/tests/connect/test_parquet_simple.py +++ b/tests/connect/test_parquet_simple.py @@ -1,128 +1,43 @@ -# def test_apply_lambda -# def test_apply_module_func -# def test_apply_inline_func -# def test_apply_lambda_pyobj - from __future__ import annotations +import pathlib import time +import pyarrow as pa +import pyarrow.parquet as papq from pyspark.sql import SparkSession from 
pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import col + from daft.daft import connect_start -def test_simple(): - print("Starting Daft-Connect server") - # connect_start("sc://localhost:50051") +def test_read_parquet(tmpdir): + # Create a temporary directory for the parquet file + tmpdir = pathlib.Path(tmpdir) + parquet_path = tmpdir / "increasing_id_data.parquet" + + # Create test data with PyArrow + data = pa.Table.from_pydict({"id": [0, 1, 2, 3, 4]}) + + # Write parquet file using PyArrow + papq.write_table(data, parquet_path) - print("Created spark connect server") + # todo: have env variable to control whether we embed the server or not + connect_start("sc://localhost:50051") # Create a Spark session using Spark Connect - spark: SparkSession = ( - SparkSession.builder.appName("SparkConnectExample").remote("sc://localhost:50051").getOrCreate() - ) + spark: SparkSession = SparkSession.builder.appName("DaftParquetTest").remote("sc://localhost:50051").getOrCreate() print("Spark session created") - # Read the Parquet file back into a DataFrame - df: DataFrame = spark.read.parquet("/Users/andrewgazelka/Projects/simple-spark-connect/increasing_id_data.parquet") + # Read the Parquet file back into a DataFrame using Spark Connect + df: DataFrame = spark.read.parquet(str(parquet_path)) print("DataFrame read from Parquet file") - # The DataFrame remains unchanged: - # +---+ - # | id| - # +---+ - # | 0| - # | 1| - # | 2| - # | 3| - # | 4| - # +---+ print("DataFrame schema:") df.printSchema() - # root - # |-- id: long (nullable = false) - - print("\nDataFrame content:") - df.show() - - print("done showing") - - # Perform operations on the DataFrame - # 1. filter(col("id") > 2): Select only rows where 'id' is greater than 2 - # 2. withColumn("id2", col("id") + 2): Add a new column 'id2' that is 'id' plus 2 - result: DataFrame = df.filter(col("id") > 2).withColumn("id2", col("id") + 2) - - print("\nFiltered and transformed DataFrame:") - result.show() - - # result_pandas = result.toPandas() - # The resulting DataFrame looks like this: - # +---+---+ - # | id|id2| - # +---+---+ - # | 3| 5| - # | 4| 6| - # +---+---+ - # Explanation: - # 1. Only rows with id > 2 are kept (3 and 4) - # 2. A new column 'id2' is added with values id + 2 - - # Stop the Spark session - # spark.sql("select * from increasing_id_data").show() spark.stop() print("Spark session stopped") - - # Waiting for 10 seconds time.sleep(2) - print("End of main function") - - -# from daft. 
-# -# -# def add_1(x): -# return x + 1 -# -# -# def test_apply_module_func(): -# df = daft.from_pydict({"a": [1, 2, 3]}) -# df = df.with_column("a_plus_1", df["a"].apply(add_1, return_dtype=DataType.int32())) -# assert df.to_pydict() == {"a": [1, 2, 3], "a_plus_1": [2, 3, 4]} -# -# -# def test_apply_lambda(): -# df = daft.from_pydict({"a": [1, 2, 3]}) -# df = df.with_column("a_plus_1", df["a"].apply(lambda x: x + 1, return_dtype=DataType.int32())) -# assert df.to_pydict() == {"a": [1, 2, 3], "a_plus_1": [2, 3, 4]} -# -# -# def test_apply_inline_func(): -# def inline_add_1(x): -# return x + 1 -# -# df = daft.from_pydict({"a": [1, 2, 3]}) -# df = df.with_column("a_plus_1", df["a"].apply(inline_add_1, return_dtype=DataType.int32())) -# assert df.to_pydict() == {"a": [1, 2, 3], "a_plus_1": [2, 3, 4]} -# -# -# @dataclasses.dataclass -# class MyObj: -# x: int -# -# -# def test_apply_obj(): -# df = daft.from_pydict({"obj": [MyObj(x=0), MyObj(x=0), MyObj(x=0)]}) -# -# def inline_mutate_obj(obj): -# obj.x = 1 -# return obj -# -# df = df.with_column("mut_obj", df["obj"].apply(inline_mutate_obj, return_dtype=DataType.python())) -# result = df.to_pydict() -# for mut_obj in result["mut_obj"]: -# assert mut_obj.x == 1 From bab1620e4a25502190889f56731903c3d687553b Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 6 Nov 2024 16:00:44 -0800 Subject: [PATCH 5/6] test passes yay --- Cargo.lock | 1 + Cargo.toml | 1 + src/daft-connect/Cargo.toml | 3 +- src/daft-connect/src/command.rs | 137 +++++++++++++++++- src/daft-connect/src/config.rs | 21 +-- src/daft-connect/src/convert.rs | 2 +- .../convert/data_conversion/show_string.rs | 4 +- src/daft-connect/src/lib.rs | 69 +++++---- src/daft-connect/src/session.rs | 2 +- src/daft-connect/src/util.rs | 2 +- tests/connect/test_parquet_simple.py | 44 +++--- 11 files changed, 216 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 57c2a15b90..03fbe4ebe3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1890,6 +1890,7 @@ version = "0.3.0-dev0" dependencies = [ "arrow2", "common-daft-config", + "common-file-formats", "daft-core", "daft-dsl", "daft-local-execution", diff --git a/Cargo.toml b/Cargo.toml index 68313903c1..5e8e08665b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -181,6 +181,7 @@ comfy-table = "7.1.1" common-daft-config = {path = "src/common/daft-config"} common-display = {path = "src/common/display"} common-error = {path = "src/common/error", default-features = false} +common-file-formats = {path = "src/common/file-formats"} daft-connect = {path = "src/daft-connect", default-features = false} daft-core = {path = "src/daft-core"} daft-dsl = {path = "src/daft-dsl"} diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index bd9788ba76..10e06329fe 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -4,7 +4,7 @@ dashmap = "6.1.0" eyre = "0.6.12" futures = "0.3.31" pyo3 = {workspace = true, optional = true} -#ron = "0.9.0-alpha.0" +# ron = "0.9.0-alpha.0" tokio = {version = "1.40.0", features = ["full"]} tokio-stream = "0.1.16" tonic = "0.12.3" @@ -13,6 +13,7 @@ tracing-tracy = "0.11.3" uuid = {version = "1.10.0", features = ["v4"]} arrow2.workspace = true common-daft-config.workspace = true +common-file-formats.workspace = true daft-core.workspace = true daft-dsl.workspace = true daft-local-execution.workspace = true diff --git a/src/daft-connect/src/command.rs b/src/daft-connect/src/command.rs index c999e2bac0..c0a835ca5e 100644 --- a/src/daft-connect/src/command.rs +++ 
b/src/daft-connect/src/command.rs @@ -1,21 +1,26 @@ // Stream of Result -use std::thread; +use std::{ops::ControlFlow, thread}; use arrow2::io::ipc::write::StreamWriter; +use common_file_formats::FileFormat; use daft_table::Table; use eyre::Context; use futures::TryStreamExt; use spark_connect::{ execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, spark_connect_service_server::SparkConnectService, - ExecutePlanResponse, Relation, + write_operation::{SaveMode, SaveType}, + ExecutePlanResponse, Relation, WriteOperation, }; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::Status; use uuid::Uuid; -use crate::{convert::convert_data, DaftSparkConnectService, Session}; +use crate::{ + convert::{convert_data, run_local, to_logical_plan}, + invalid_argument, unimplemented_err, DaftSparkConnectService, Session, +}; type DaftStream = ::ExecutePlanStream; @@ -89,14 +94,14 @@ impl Session { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let mut channel = ExecutablePlanChannel { - session_id: self.id().to_string(), + session_id: self.client_side_session_id().to_string(), server_side_session_id: self.server_side_session_id().to_string(), operation_id: operation_id.clone(), tx: tx.clone(), }; thread::spawn({ - let session_id = self.id().to_string(); + let session_id = self.client_side_session_id().to_string(); let server_side_session_id = self.server_side_session_id().to_string(); move || { let result = convert_data(command, &mut channel); @@ -125,4 +130,126 @@ impl Session { Ok(Box::pin(recv_stream)) } + + pub fn handle_write_operation( + &self, + operation: WriteOperation, + operation_id: String, + ) -> Result { + let mode = operation.mode(); + + let WriteOperation { + input, + source, + sort_column_names, + partitioning_columns, + bucket_by, + options, + clustering_columns, + save_type, + mode: _, + } = operation; + + let input = input.ok_or_else(|| invalid_argument!("input is None"))?; + + let source = source.unwrap_or_else(|| "parquet".to_string()); + if source != "parquet" { + return Err(unimplemented_err!( + "Only writing parquet is supported for now but got {source}" + )); + } + + match mode { + SaveMode::Unspecified => {} + SaveMode::Append => { + return Err(unimplemented_err!("Append mode is not yet supported")); + } + SaveMode::Overwrite => { + return Err(unimplemented_err!("Overwrite mode is not yet supported")); + } + SaveMode::ErrorIfExists => { + return Err(unimplemented_err!( + "ErrorIfExists mode is not yet supported" + )); + } + SaveMode::Ignore => { + return Err(unimplemented_err!("Ignore mode is not yet supported")); + } + } + + if !sort_column_names.is_empty() { + return Err(unimplemented_err!("Sort by columns is not yet supported")); + } + + if !partitioning_columns.is_empty() { + return Err(unimplemented_err!( + "Partitioning columns is not yet supported" + )); + } + + if bucket_by.is_some() { + return Err(unimplemented_err!("Bucket by columns is not yet supported")); + } + + if !options.is_empty() { + return Err(unimplemented_err!("Options are not yet supported")); + } + + if !clustering_columns.is_empty() { + return Err(unimplemented_err!( + "Clustering columns is not yet supported" + )); + } + + let save_type = save_type.ok_or_else(|| invalid_argument!("save_type is required"))?; + + let save_path = match save_type { + SaveType::Path(path) => path, + SaveType::Table(_) => { + return Err(unimplemented_err!("Save type table is not yet supported")); + } + }; + + std::thread::scope(|scope| { + let res = scope.spawn(|| { + let plan = 
to_logical_plan(input) + .map_err(|_| Status::internal("Failed to convert to logical plan"))?; + + // todo: assuming this is parquet + // todo: is save_path right? + let plan = plan + .table_write(&save_path, FileFormat::Parquet, None, None, None) + .map_err(|_| Status::internal("Failed to write table"))?; + + let plan = plan.build(); + + run_local( + &plan, + |_table| ControlFlow::Continue(()), + || ControlFlow::Break(()), + ) + .map_err(|e| Status::internal(format!("Failed to write table: {e}")))?; + + Result::<(), Status>::Ok(()) + }); + + res.join().unwrap() + })?; + + let session_id = self.client_side_session_id().to_string(); + let server_side_session_id = self.server_side_session_id().to_string(); + + Ok(Box::pin(futures::stream::once(async { + Ok(ExecutePlanResponse { + session_id, + server_side_session_id, + operation_id, + response_id: "abcxyz".to_string(), + metrics: None, + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ResultComplete(ResultComplete {})), + }) + }))) + } } diff --git a/src/daft-connect/src/config.rs b/src/daft-connect/src/config.rs index 863a800b2a..21f71a379f 100644 --- a/src/daft-connect/src/config.rs +++ b/src/daft-connect/src/config.rs @@ -11,7 +11,7 @@ use crate::Session; impl Session { fn config_response(&self) -> ConfigResponse { ConfigResponse { - session_id: self.id().to_string(), + session_id: self.client_side_session_id().to_string(), server_side_session_id: self.server_side_session_id().to_string(), pairs: vec![], warnings: vec![], @@ -21,7 +21,8 @@ impl Session { pub fn set(&mut self, operation: Set) -> Result { let mut response = self.config_response(); - let span = tracing::info_span!("set", session_id = %self.id(), ?operation); + let span = + tracing::info_span!("set", session_id = %self.client_side_session_id(), ?operation); let _enter = span.enter(); for KeyValue { key, value } in operation.pairs { @@ -45,7 +46,7 @@ impl Session { pub fn get(&self, operation: Get) -> Result { let mut response = self.config_response(); - let span = tracing::info_span!("get", session_id = %self.id()); + let span = tracing::info_span!("get", session_id = %self.client_side_session_id()); let _enter = span.enter(); for key in operation.keys { @@ -59,7 +60,8 @@ impl Session { pub fn get_with_default(&self, operation: GetWithDefault) -> Result { let mut response = self.config_response(); - let span = tracing::info_span!("get_with_default", session_id = %self.id()); + let span = + tracing::info_span!("get_with_default", session_id = %self.client_side_session_id()); let _enter = span.enter(); for KeyValue { @@ -79,7 +81,7 @@ impl Session { pub fn get_option(&self, operation: GetOption) -> Result { let mut response = self.config_response(); - let span = tracing::info_span!("get_option", session_id = %self.id()); + let span = tracing::info_span!("get_option", session_id = %self.client_side_session_id()); let _enter = span.enter(); for key in operation.keys { @@ -93,7 +95,7 @@ impl Session { pub fn get_all(&self, operation: GetAll) -> Result { let mut response = self.config_response(); - let span = tracing::info_span!("get_all", session_id = %self.id()); + let span = tracing::info_span!("get_all", session_id = %self.client_side_session_id()); let _enter = span.enter(); let Some(prefix) = operation.prefix else { @@ -119,7 +121,7 @@ impl Session { pub fn unset(&mut self, operation: Unset) -> Result { let mut response = self.config_response(); - let span = tracing::info_span!("unset", session_id = %self.id()); + let span = 
tracing::info_span!("unset", session_id = %self.client_side_session_id()); let _enter = span.enter(); for key in operation.keys { @@ -137,10 +139,11 @@ impl Session { pub fn is_modifiable(&self, _operation: IsModifiable) -> Result { let response = self.config_response(); - let span = tracing::info_span!("is_modifiable", session_id = %self.id()); + let span = + tracing::info_span!("is_modifiable", session_id = %self.client_side_session_id()); let _enter = span.enter(); - tracing::warn!(session_id = %self.id(), "is_modifiable operation not yet implemented"); + tracing::warn!(session_id = %self.client_side_session_id(), "is_modifiable operation not yet implemented"); // todo: need to implement this Ok(response) } diff --git a/src/daft-connect/src/convert.rs b/src/daft-connect/src/convert.rs index 98cfa9e4d8..ddb3530669 100644 --- a/src/daft-connect/src/convert.rs +++ b/src/daft-connect/src/convert.rs @@ -18,7 +18,7 @@ use eyre::Context; pub use plan_conversion::to_logical_plan; pub use schema_conversion::connect_schema; -pub fn map_to_tables( +pub fn run_local( logical_plan: &LogicalPlanRef, mut f: impl FnMut(&Table) -> T, default: impl FnOnce() -> T, diff --git a/src/daft-connect/src/convert/data_conversion/show_string.rs b/src/daft-connect/src/convert/data_conversion/show_string.rs index 35c5e3d602..bed2952a21 100644 --- a/src/daft-connect/src/convert/data_conversion/show_string.rs +++ b/src/daft-connect/src/convert/data_conversion/show_string.rs @@ -6,7 +6,7 @@ use spark_connect::ShowString; use crate::{ command::ConcreteDataChannel, - convert::{map_to_tables, plan_conversion::to_logical_plan}, + convert::{plan_conversion::to_logical_plan, run_local}, }; pub fn show_string( @@ -28,7 +28,7 @@ pub fn show_string( let logical_plan = to_logical_plan(input)?.build(); - map_to_tables( + run_local( &logical_plan, |table| -> eyre::Result<()> { let display = format!("{table}"); diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 8ae14c3a4e..33466183e8 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -5,6 +5,7 @@ #![feature(iter_from_coroutine)] #![feature(stmt_expr_attributes)] #![feature(try_trait_v2_residual)] +#![warn(unused)] use std::ops::ControlFlow; @@ -27,7 +28,7 @@ use tonic::{transport::Server, Request, Response, Status}; use tracing::{info, warn}; use uuid::Uuid; -use crate::{convert::map_to_tables, session::Session}; +use crate::{convert::run_local, session::Session}; mod command; mod config; @@ -139,59 +140,65 @@ impl SparkConnectService for DaftSparkConnectService { match command { CommandType::RegisterFunction(_) => { - Err(unimplemented!("RegisterFunction not implemented")) + Err(unimplemented_err!("RegisterFunction not implemented")) } - CommandType::WriteOperation(_) => { - Err(unimplemented!("WriteOperation not implemented")) + CommandType::WriteOperation(op) => { + println!("WriteOperation: {:#2?}", op); + + let result = session.handle_write_operation(op, operation)?; + + return Ok(Response::new(result)); } CommandType::CreateDataframeView(_) => { - Err(unimplemented!("CreateDataframeView not implemented")) + Err(unimplemented_err!("CreateDataframeView not implemented")) } CommandType::WriteOperationV2(_) => { - Err(unimplemented!("WriteOperationV2 not implemented")) + Err(unimplemented_err!("WriteOperationV2 not implemented")) } CommandType::SqlCommand(..) 
=> { - Err(unimplemented!("SQL execution not yet implemented")) - } - CommandType::WriteStreamOperationStart(_) => { - Err(unimplemented!("WriteStreamOperationStart not implemented")) + Err(unimplemented_err!("SQL execution not yet implemented")) } + CommandType::WriteStreamOperationStart(_) => Err(unimplemented_err!( + "WriteStreamOperationStart not implemented" + )), CommandType::StreamingQueryCommand(_) => { - Err(unimplemented!("StreamingQueryCommand not implemented")) + Err(unimplemented_err!("StreamingQueryCommand not implemented")) } CommandType::GetResourcesCommand(_) => { - Err(unimplemented!("GetResourcesCommand not implemented")) + Err(unimplemented_err!("GetResourcesCommand not implemented")) } - CommandType::StreamingQueryManagerCommand(_) => Err(unimplemented!( + CommandType::StreamingQueryManagerCommand(_) => Err(unimplemented_err!( "StreamingQueryManagerCommand not implemented" )), CommandType::RegisterTableFunction(_) => { - Err(unimplemented!("RegisterTableFunction not implemented")) + Err(unimplemented_err!("RegisterTableFunction not implemented")) } - CommandType::StreamingQueryListenerBusCommand(_) => Err(unimplemented!( + CommandType::StreamingQueryListenerBusCommand(_) => Err(unimplemented_err!( "StreamingQueryListenerBusCommand not implemented" )), CommandType::RegisterDataSource(_) => { - Err(unimplemented!("RegisterDataSource not implemented")) + Err(unimplemented_err!("RegisterDataSource not implemented")) } - CommandType::CreateResourceProfileCommand(_) => Err(unimplemented!( + CommandType::CreateResourceProfileCommand(_) => Err(unimplemented_err!( "CreateResourceProfileCommand not implemented" )), CommandType::CheckpointCommand(_) => { - Err(unimplemented!("CheckpointCommand not implemented")) + Err(unimplemented_err!("CheckpointCommand not implemented")) } - CommandType::RemoveCachedRemoteRelationCommand(_) => Err(unimplemented!( + CommandType::RemoveCachedRemoteRelationCommand(_) => Err(unimplemented_err!( "RemoveCachedRemoteRelationCommand not implemented" )), CommandType::MergeIntoTableCommand(_) => { - Err(unimplemented!("MergeIntoTableCommand not implemented")) + Err(unimplemented_err!("MergeIntoTableCommand not implemented")) + } + CommandType::Extension(_) => { + Err(unimplemented_err!("Extension not implemented")) } - CommandType::Extension(_) => Err(unimplemented!("Extension not implemented")), } } }?; - Err(unimplemented!("Unsupported plan type")) + Err(unimplemented_err!("Unsupported plan type")) } #[tracing::instrument(skip_all)] @@ -230,7 +237,7 @@ impl SparkConnectService for DaftSparkConnectService { &self, _request: Request>, ) -> Result, Status> { - Err(unimplemented!( + Err(unimplemented_err!( "add_artifacts operation is not yet implemented" )) } @@ -310,7 +317,7 @@ impl SparkConnectService for DaftSparkConnectService { let logical_plan = logical_plan.build(); let res = std::thread::spawn(move || { - let result = map_to_tables( + let result = run_local( &logical_plan, |table| { let table = format!("{table}"); @@ -342,7 +349,7 @@ impl SparkConnectService for DaftSparkConnectService { let response = Response::new(res); Ok(response) } - _ => Err(unimplemented!( + _ => Err(unimplemented_err!( "Analyze plan operation is not yet implemented" )), } @@ -354,7 +361,7 @@ impl SparkConnectService for DaftSparkConnectService { _request: Request, ) -> Result, Status> { println!("got artifact status"); - Err(unimplemented!( + Err(unimplemented_err!( "artifact_status operation is not yet implemented" )) } @@ -365,7 +372,9 @@ impl SparkConnectService for 
DaftSparkConnectService { _request: Request, ) -> Result, Status> { println!("got interrupt"); - Err(unimplemented!("interrupt operation is not yet implemented")) + Err(unimplemented_err!( + "interrupt operation is not yet implemented" + )) } #[tracing::instrument(skip_all)] @@ -394,7 +403,7 @@ impl SparkConnectService for DaftSparkConnectService { let session = self.get_session(&request.session_id)?; let response = ReleaseExecuteResponse { - session_id: session.id().to_string(), + session_id: session.client_side_session_id().to_string(), server_side_session_id: session.server_side_session_id().to_string(), operation_id: Some(request.operation_id), // todo: impl properly }; @@ -408,7 +417,7 @@ impl SparkConnectService for DaftSparkConnectService { _request: Request, ) -> Result, Status> { println!("got release session"); - Err(unimplemented!( + Err(unimplemented_err!( "release_session operation is not yet implemented" )) } @@ -419,7 +428,7 @@ impl SparkConnectService for DaftSparkConnectService { _request: Request, ) -> Result, Status> { println!("got fetch error details"); - Err(unimplemented!( + Err(unimplemented_err!( "fetch_error_details operation is not yet implemented" )) } diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs index 1be05b3948..72b477478f 100644 --- a/src/daft-connect/src/session.rs +++ b/src/daft-connect/src/session.rs @@ -38,7 +38,7 @@ impl Session { } } - pub fn id(&self) -> &str { + pub fn client_side_session_id(&self) -> &str { &self.id } diff --git a/src/daft-connect/src/util.rs b/src/daft-connect/src/util.rs index 29a593f342..0cdaffff54 100644 --- a/src/daft-connect/src/util.rs +++ b/src/daft-connect/src/util.rs @@ -9,7 +9,7 @@ macro_rules! invalid_argument { } #[macro_export] -macro_rules! unimplemented { +macro_rules! 
unimplemented_err { ($arg: tt) => {{ let msg = format!($arg); ::tonic::Status::unimplemented(msg) diff --git a/tests/connect/test_parquet_simple.py b/tests/connect/test_parquet_simple.py index b1e99778bf..cb3ba9f1b1 100644 --- a/tests/connect/test_parquet_simple.py +++ b/tests/connect/test_parquet_simple.py @@ -12,32 +12,36 @@ def test_read_parquet(tmpdir): - # Create a temporary directory for the parquet file - tmpdir = pathlib.Path(tmpdir) - parquet_path = tmpdir / "increasing_id_data.parquet" + # Convert tmpdir to Path object + test_dir = pathlib.Path(tmpdir) + input_parquet_path = test_dir / "input.parquet" - # Create test data with PyArrow - data = pa.Table.from_pydict({"id": [0, 1, 2, 3, 4]}) + # Create sample data with sequential IDs + sample_data = pa.Table.from_pydict({"id": [0, 1, 2, 3, 4]}) - # Write parquet file using PyArrow - papq.write_table(data, parquet_path) + # Write sample data to input parquet file + papq.write_table(sample_data, input_parquet_path) - # todo: have env variable to control whether we embed the server or not + # Start Daft Connect server + # TODO: Add env var to control server embedding connect_start("sc://localhost:50051") - # Create a Spark session using Spark Connect - spark: SparkSession = SparkSession.builder.appName("DaftParquetTest").remote("sc://localhost:50051").getOrCreate() + # Initialize Spark Connect session + spark_session: SparkSession = ( + SparkSession.builder.appName("DaftParquetReadWriteTest").remote("sc://localhost:50051").getOrCreate() + ) - print("Spark session created") + # Read input parquet with Spark Connect + spark_df: DataFrame = spark_session.read.parquet(str(input_parquet_path)) - # Read the Parquet file back into a DataFrame using Spark Connect - df: DataFrame = spark.read.parquet(str(parquet_path)) - print("DataFrame read from Parquet file") + # Write DataFrame to output parquet + output_parquet_path = test_dir / "output.parquet" + spark_df.write.parquet(str(output_parquet_path)) - print("DataFrame schema:") - df.printSchema() + # Verify output matches input + output_data = papq.read_table(output_parquet_path) + assert output_data.equals(sample_data) - spark.stop() - print("Spark session stopped") - time.sleep(2) - print("End of main function") + # Clean up Spark session + spark_session.stop() + time.sleep(2) # Allow time for session cleanup From afefdcef4f87df10cbb1cbb5ef6ce195bd8ffaf1 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 6 Nov 2024 16:07:24 -0800 Subject: [PATCH 6/6] remove comment --- src/daft-connect/src/command.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/daft-connect/src/command.rs b/src/daft-connect/src/command.rs index c0a835ca5e..83e1522f5a 100644 --- a/src/daft-connect/src/command.rs +++ b/src/daft-connect/src/command.rs @@ -1,5 +1,3 @@ -// Stream of Result - use std::{ops::ControlFlow, thread}; use arrow2::io::ipc::write::StreamWriter;
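
A minimal end-to-end usage sketch of the new `connect_start` binding, mirroring tests/connect/test_parquet_simple.py from the patches above. The `sc://localhost:50051` address and the embedded-server behavior are assumptions carried over from that test, not a stable API contract:

    import pathlib
    import tempfile

    import pyarrow as pa
    import pyarrow.parquet as papq
    from pyspark.sql import SparkSession

    from daft.daft import connect_start

    # Start the embedded Daft Connect server (address assumed, as in the test above).
    connect_start("sc://localhost:50051")

    # Point a Spark Connect session at the embedded server.
    spark = (
        SparkSession.builder.appName("DaftConnectSketch")
        .remote("sc://localhost:50051")
        .getOrCreate()
    )

    # Round-trip a small parquet file through the server, as the test does.
    tmp = pathlib.Path(tempfile.mkdtemp())
    papq.write_table(pa.Table.from_pydict({"id": [0, 1, 2, 3, 4]}), tmp / "in.parquet")
    df = spark.read.parquet(str(tmp / "in.parquet"))
    df.write.parquet(str(tmp / "out.parquet"))

    # The write path currently supports parquet only; other formats and save
    # modes return "not yet implemented" errors per handle_write_operation.
    assert papq.read_table(tmp / "out.parquet").equals(papq.read_table(tmp / "in.parquet"))

    spark.stop()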