diff --git a/Cargo.lock b/Cargo.lock index ec65979338..03fbe4ebe3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1843,6 +1843,7 @@ dependencies = [ "common-tracing", "common-version", "daft-compression", + "daft-connect", "daft-core", "daft-csv", "daft-dsl", @@ -1883,6 +1884,35 @@ dependencies = [ "url", ] +[[package]] +name = "daft-connect" +version = "0.3.0-dev0" +dependencies = [ + "arrow2", + "common-daft-config", + "common-file-formats", + "daft-core", + "daft-dsl", + "daft-local-execution", + "daft-physical-plan", + "daft-plan", + "daft-schema", + "daft-table", + "dashmap", + "eyre", + "futures", + "pyo3", + "spark-connect", + "tempfile", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-subscriber", + "tracing-tracy", + "uuid 1.10.0", +] + [[package]] name = "daft-core" version = "0.3.0-dev0" @@ -2477,6 +2507,20 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deflate64" version = "0.1.9" @@ -2721,6 +2765,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -2825,9 +2879,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -2840,9 +2894,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -2850,15 +2904,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -2867,9 +2921,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -2888,9 +2942,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" 
+version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -2899,15 +2953,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -2917,9 +2971,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -2933,6 +2987,19 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb949699c3e4df3a183b1d2142cb24277057055ed23c68ed58894f76c517223" +dependencies = [ + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.58.0", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -3457,7 +3524,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -3513,6 +3580,12 @@ dependencies = [ "quick-error 2.0.1", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -3866,6 +3939,19 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lz4" version = "1.26.0" @@ -5298,6 +5384,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -5832,7 +5924,7 @@ dependencies = [ "ntapi", "once_cell", "rayon", - "windows", + "windows 0.52.0", ] [[package]] @@ -6091,9 +6183,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = 
"145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", @@ -6288,6 +6380,38 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tracing-tracy" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc775fdaf33c3dfd19dc354729e65e87914bc67dcdc390ca1210807b8bee5902" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracy-client", +] + +[[package]] +name = "tracy-client" +version = "0.17.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "746b078c6a09ebfd5594609049e07116735c304671eaab06ce749854d23435bc" +dependencies = [ + "loom", + "once_cell", + "tracy-client-sys", +] + +[[package]] +name = "tracy-client-sys" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3637e734239e12ab152cd269302500bd063f37624ee210cd04b4936ed671f3b1" +dependencies = [ + "cc", + "windows-targets 0.52.6", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -6702,7 +6826,17 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ - "windows-core", + "windows-core 0.52.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", "windows-targets 0.52.6", ] @@ -6715,6 +6849,60 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index 5204ac2f81..5e8e08665b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,11 @@ common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} daft-compression = {path = "src/daft-compression", default-features = false} +daft-connect = 
{path = "src/daft-connect", optional = true} daft-core = {path = "src/daft-core", default-features = false} daft-csv = {path = "src/daft-csv", default-features = false} daft-dsl = {path = "src/daft-dsl", default-features = false} -daft-functions = {path = "src/daft-functions", default-features = false} +daft-functions = {path = "src/daft-functions"} daft-functions-json = {path = "src/daft-functions-json", default-features = false} daft-hash = {path = "src/daft-hash", default-features = false} daft-image = {path = "src/daft-image", default-features = false} @@ -44,6 +45,11 @@ python = [ "common-display/python", "common-resource-request/python", "common-system-info/python", + "common-daft-config/python", + "common-display/python", + "common-resource-request/python", + "common-system-info/python", + "daft-connect/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", @@ -61,14 +67,10 @@ python = [ "daft-scheduler/python", "daft-sql/python", "daft-stats/python", - "daft-table/python", - "daft-functions/python", - "daft-functions-json/python", + "daft-stats/python", "daft-writers/python", - "common-daft-config/python", - "common-system-info/python", - "common-display/python", - "common-resource-request/python", + "daft-table/python", + "dep:daft-connect", "dep:pyo3", "dep:pyo3-log" ] @@ -154,6 +156,7 @@ members = [ "src/daft-table", "src/daft-writers", "src/hyperloglog", + "src/daft-connect", "src/parquet2", # "src/spark-connect-script", "src/generated/spark-connect" @@ -161,6 +164,7 @@ members = [ [workspace.dependencies] ahash = "0.8.11" +anyhow = "1.0.89" approx = "0.5.1" async-compat = "0.2.3" async-compression = {version = "0.4.12", features = [ @@ -174,8 +178,21 @@ bytes = "1.6.0" chrono = "0.4.38" chrono-tz = "0.8.4" comfy-table = "7.1.1" +common-daft-config = {path = "src/common/daft-config"} +common-display = {path = "src/common/display"} common-error = {path = "src/common/error", default-features = false} +common-file-formats = {path = "src/common/file-formats"} +daft-connect = {path = "src/daft-connect", default-features = false} +daft-core = {path = "src/daft-core"} +daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} +daft-local-execution = {path = "src/daft-local-execution"} +daft-micropartition = {path = "src/daft-micropartition"} +daft-physical-plan = {path = "src/daft-physical-plan"} +daft-plan = {path = "src/daft-plan"} +daft-schema = {path = "src/daft-schema"} +daft-sql = {path = "src/daft-sql"} +daft-table = {path = "src/daft-table"} derivative = "2.2.0" derive_builder = "0.20.2" divan = "0.1.14" @@ -204,6 +221,7 @@ serde_json = "1.0.116" sha1 = "0.11.0-pre.4" sketches-ddsketch = {version = "0.2.2", features = ["use_serde"]} snafu = {version = "0.7.4", features = ["futures"]} +spark-connect = {path = "src/spark-connect", default-features = false} sqlparser = "0.51.0" sysinfo = "0.30.12" tango-bench = "0.6.0" @@ -233,7 +251,7 @@ path = "src/arrow2" version = "1.3.3" [workspace.dependencies.derive_more] -features = ["display"] +features = ["display", "from", "constructor"] version = "1.0.0" [workspace.dependencies.lazy_static] @@ -321,7 +339,7 @@ uninlined_format_args = "allow" unnecessary_wraps = "allow" unnested_or_patterns = "allow" unreadable_literal = "allow" -# todo: remove? 
+# todo: remove this at some point unsafe_derive_deserialize = "allow" unused_async = "allow" # used_underscore_items = "allow" # REMOVE diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6dad8b6f56..8917eeb2cc 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1232,6 +1232,8 @@ def list_sql_functions() -> list[SQLFunctionStub]: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... +def connect_start(addr: str) -> None: ... + # expr numeric ops def abs(expr: PyExpr) -> PyExpr: ... def cbrt(expr: PyExpr) -> PyExpr: ... diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml new file mode 100644 index 0000000000..10e06329fe --- /dev/null +++ b/src/daft-connect/Cargo.toml @@ -0,0 +1,39 @@ +[dependencies] +dashmap = "6.1.0" +# papaya = "0.1.3" +eyre = "0.6.12" +futures = "0.3.31" +pyo3 = {workspace = true, optional = true} +# ron = "0.9.0-alpha.0" +tokio = {version = "1.40.0", features = ["full"]} +tokio-stream = "0.1.16" +tonic = "0.12.3" +tracing-subscriber = {version = "0.3.18", features = ["env-filter"]} +tracing-tracy = "0.11.3" +uuid = {version = "1.10.0", features = ["v4"]} +arrow2.workspace = true +common-daft-config.workspace = true +common-file-formats.workspace = true +daft-core.workspace = true +daft-dsl.workspace = true +daft-local-execution.workspace = true +daft-physical-plan.workspace = true +daft-plan.workspace = true +daft-schema.workspace = true +daft-table.workspace = true +spark-connect.workspace = true +tracing.workspace = true + +[dev-dependencies] +tempfile = "3.4.0" + +[features] +python = ["dep:pyo3"] + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-connect" +version = {workspace = true} diff --git a/src/daft-connect/src/command.rs b/src/daft-connect/src/command.rs new file mode 100644 index 0000000000..83e1522f5a --- /dev/null +++ b/src/daft-connect/src/command.rs @@ -0,0 +1,253 @@ +use std::{ops::ControlFlow, thread}; + +use arrow2::io::ipc::write::StreamWriter; +use common_file_formats::FileFormat; +use daft_table::Table; +use eyre::Context; +use futures::TryStreamExt; +use spark_connect::{ + execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, + spark_connect_service_server::SparkConnectService, + write_operation::{SaveMode, SaveType}, + ExecutePlanResponse, Relation, WriteOperation, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::Status; +use uuid::Uuid; + +use crate::{ + convert::{convert_data, run_local, to_logical_plan}, + invalid_argument, unimplemented_err, DaftSparkConnectService, Session, +}; + +type DaftStream = ::ExecutePlanStream; + +struct ExecutablePlanChannel { + session_id: String, + server_side_session_id: String, + operation_id: String, + tx: tokio::sync::mpsc::UnboundedSender>, +} + +pub trait ConcreteDataChannel { + fn send_table(&mut self, table: &Table) -> eyre::Result<()>; +} + +impl ConcreteDataChannel for ExecutablePlanChannel { + fn send_table(&mut self, table: &Table) -> eyre::Result<()> { + let mut data = Vec::new(); + + let mut writer = StreamWriter::new( + &mut data, + arrow2::io::ipc::write::WriteOptions { compression: None }, + ); + + let row_count = table.num_rows(); + + let schema = table + .schema + .to_arrow() + .wrap_err("Failed to convert Daft schema to Arrow schema")?; + + writer + .start(&schema, None) + .wrap_err("Failed to start Arrow stream writer with schema")?; + + let arrays = 
table.get_inner_arrow_arrays().collect(); + let chunk = arrow2::chunk::Chunk::new(arrays); + + writer + .write(&chunk, None) + .wrap_err("Failed to write Arrow chunk to stream writer")?; + + let response = ExecutePlanResponse { + session_id: self.session_id.to_string(), + server_side_session_id: self.server_side_session_id.to_string(), + operation_id: self.operation_id.to_string(), + response_id: Uuid::new_v4().to_string(), // todo: implement this + metrics: None, // todo: implement this + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ArrowBatch(ArrowBatch { + row_count: row_count as i64, + data, + start_offset: None, + })), + }; + + self.tx + .send(Ok(response)) + .wrap_err("Error sending response to client")?; + + Ok(()) + } +} + +impl Session { + pub async fn handle_root_command( + &self, + command: Relation, + operation_id: String, + ) -> Result { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let mut channel = ExecutablePlanChannel { + session_id: self.client_side_session_id().to_string(), + server_side_session_id: self.server_side_session_id().to_string(), + operation_id: operation_id.clone(), + tx: tx.clone(), + }; + + thread::spawn({ + let session_id = self.client_side_session_id().to_string(); + let server_side_session_id = self.server_side_session_id().to_string(); + move || { + let result = convert_data(command, &mut channel); + + if let Err(e) = result { + tx.send(Err(e)).unwrap(); + } else { + let finished = ExecutePlanResponse { + session_id, + server_side_session_id, + operation_id: operation_id.to_string(), + response_id: Uuid::new_v4().to_string(), + metrics: None, + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ResultComplete(ResultComplete {})), + }; + + tx.send(Ok(finished)).unwrap(); + } + } + }); + + let recv_stream = + UnboundedReceiverStream::new(rx).map_err(|e| Status::internal(e.to_string())); + + Ok(Box::pin(recv_stream)) + } + + pub fn handle_write_operation( + &self, + operation: WriteOperation, + operation_id: String, + ) -> Result { + let mode = operation.mode(); + + let WriteOperation { + input, + source, + sort_column_names, + partitioning_columns, + bucket_by, + options, + clustering_columns, + save_type, + mode: _, + } = operation; + + let input = input.ok_or_else(|| invalid_argument!("input is None"))?; + + let source = source.unwrap_or_else(|| "parquet".to_string()); + if source != "parquet" { + return Err(unimplemented_err!( + "Only writing parquet is supported for now but got {source}" + )); + } + + match mode { + SaveMode::Unspecified => {} + SaveMode::Append => { + return Err(unimplemented_err!("Append mode is not yet supported")); + } + SaveMode::Overwrite => { + return Err(unimplemented_err!("Overwrite mode is not yet supported")); + } + SaveMode::ErrorIfExists => { + return Err(unimplemented_err!( + "ErrorIfExists mode is not yet supported" + )); + } + SaveMode::Ignore => { + return Err(unimplemented_err!("Ignore mode is not yet supported")); + } + } + + if !sort_column_names.is_empty() { + return Err(unimplemented_err!("Sort by columns is not yet supported")); + } + + if !partitioning_columns.is_empty() { + return Err(unimplemented_err!( + "Partitioning columns is not yet supported" + )); + } + + if bucket_by.is_some() { + return Err(unimplemented_err!("Bucket by columns is not yet supported")); + } + + if !options.is_empty() { + return Err(unimplemented_err!("Options are not yet supported")); + } + + if !clustering_columns.is_empty() { + return Err(unimplemented_err!( 
+ "Clustering columns is not yet supported" + )); + } + + let save_type = save_type.ok_or_else(|| invalid_argument!("save_type is required"))?; + + let save_path = match save_type { + SaveType::Path(path) => path, + SaveType::Table(_) => { + return Err(unimplemented_err!("Save type table is not yet supported")); + } + }; + + std::thread::scope(|scope| { + let res = scope.spawn(|| { + let plan = to_logical_plan(input) + .map_err(|_| Status::internal("Failed to convert to logical plan"))?; + + // todo: assuming this is parquet + // todo: is save_path right? + let plan = plan + .table_write(&save_path, FileFormat::Parquet, None, None, None) + .map_err(|_| Status::internal("Failed to write table"))?; + + let plan = plan.build(); + + run_local( + &plan, + |_table| ControlFlow::Continue(()), + || ControlFlow::Break(()), + ) + .map_err(|e| Status::internal(format!("Failed to write table: {e}")))?; + + Result::<(), Status>::Ok(()) + }); + + res.join().unwrap() + })?; + + let session_id = self.client_side_session_id().to_string(); + let server_side_session_id = self.server_side_session_id().to_string(); + + Ok(Box::pin(futures::stream::once(async { + Ok(ExecutePlanResponse { + session_id, + server_side_session_id, + operation_id, + response_id: "abcxyz".to_string(), + metrics: None, + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ResultComplete(ResultComplete {})), + }) + }))) + } +} diff --git a/src/daft-connect/src/config.rs b/src/daft-connect/src/config.rs new file mode 100644 index 0000000000..21f71a379f --- /dev/null +++ b/src/daft-connect/src/config.rs @@ -0,0 +1,215 @@ +use std::collections::BTreeMap; + +use spark_connect::{ + config_request::{Get, GetAll, GetOption, GetWithDefault, IsModifiable, Set, Unset}, + ConfigResponse, KeyValue, +}; +use tonic::Status; + +use crate::Session; + +impl Session { + fn config_response(&self) -> ConfigResponse { + ConfigResponse { + session_id: self.client_side_session_id().to_string(), + server_side_session_id: self.server_side_session_id().to_string(), + pairs: vec![], + warnings: vec![], + } + } + + pub fn set(&mut self, operation: Set) -> Result { + let mut response = self.config_response(); + + let span = + tracing::info_span!("set", session_id = %self.client_side_session_id(), ?operation); + let _enter = span.enter(); + + for KeyValue { key, value } in operation.pairs { + let Some(value) = value else { + let msg = format!("Missing value for key {key}. 
If you want to unset a value use the Unset operation"); + response.warnings.push(msg); + continue; + }; + + let previous = self.config_values_mut().insert(key.clone(), value.clone()); + if previous.is_some() { + tracing::info!("Updated existing configuration value"); + } else { + tracing::info!("Set new configuration value"); + } + } + + Ok(response) + } + + pub fn get(&self, operation: Get) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get", session_id = %self.client_side_session_id()); + let _enter = span.enter(); + + for key in operation.keys { + let value = self.config_values().get(&key).cloned(); + response.pairs.push(KeyValue { key, value }); + } + + Ok(response) + } + + pub fn get_with_default(&self, operation: GetWithDefault) -> Result { + let mut response = self.config_response(); + + let span = + tracing::info_span!("get_with_default", session_id = %self.client_side_session_id()); + let _enter = span.enter(); + + for KeyValue { + key, + value: default_value, + } in operation.pairs + { + let value = self.config_values().get(&key).cloned().or(default_value); + response.pairs.push(KeyValue { key, value }); + } + + Ok(response) + } + + /// Needs to be fixed so it has different behavior than [`Session::get`]. Not entirely + /// sure how it should work yet. + pub fn get_option(&self, operation: GetOption) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get_option", session_id = %self.client_side_session_id()); + let _enter = span.enter(); + + for key in operation.keys { + let value = self.config_values().get(&key).cloned(); + response.pairs.push(KeyValue { key, value }); + } + + Ok(response) + } + + pub fn get_all(&self, operation: GetAll) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("get_all", session_id = %self.client_side_session_id()); + let _enter = span.enter(); + + let Some(prefix) = operation.prefix else { + for (key, value) in self.config_values() { + response.pairs.push(KeyValue { + key: key.clone(), + value: Some(value.clone()), + }); + } + return Ok(response); + }; + + for (k, v) in prefix_search(self.config_values(), &prefix) { + response.pairs.push(KeyValue { + key: k.clone(), + value: Some(v.clone()), + }); + } + + Ok(response) + } + + pub fn unset(&mut self, operation: Unset) -> Result { + let mut response = self.config_response(); + + let span = tracing::info_span!("unset", session_id = %self.client_side_session_id()); + let _enter = span.enter(); + + for key in operation.keys { + if self.config_values_mut().remove(&key).is_none() { + let msg = format!("Key {key} not found"); + response.warnings.push(msg); + } else { + tracing::info!("Unset configuration value"); + } + } + + Ok(response) + } + + pub fn is_modifiable(&self, _operation: IsModifiable) -> Result { + let response = self.config_response(); + + let span = + tracing::info_span!("is_modifiable", session_id = %self.client_side_session_id()); + let _enter = span.enter(); + + tracing::warn!(session_id = %self.client_side_session_id(), "is_modifiable operation not yet implemented"); + // todo: need to implement this + Ok(response) + } +} + +fn prefix_search<'a, V>( + map: &'a BTreeMap, + prefix: &'a str, +) -> impl Iterator { + let start = map.range(prefix.to_string()..); + start.take_while(move |(k, _)| k.starts_with(prefix)) +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use super::*; + + #[test] + fn test_prefix_search() { + let mut map = 
BTreeMap::new(); + map.insert("apple".to_string(), 1); + map.insert("application".to_string(), 2); + map.insert("banana".to_string(), 3); + map.insert("app".to_string(), 4); + map.insert("apricot".to_string(), 5); + + // Test with prefix "app" + let result: Vec<_> = prefix_search(&map, "app").collect(); + assert_eq!( + result, + vec![ + (&"app".to_string(), &4), + (&"apple".to_string(), &1), + (&"application".to_string(), &2), + ] + ); + + // Test with prefix "b" + let result: Vec<_> = prefix_search(&map, "b").collect(); + assert_eq!(result, vec![(&"banana".to_string(), &3),]); + + // Test with prefix that doesn't match any keys + let result: Vec<_> = prefix_search(&map, "z").collect(); + assert_eq!(result, vec![]); + + // Test with empty prefix (should return all items) + let result: Vec<_> = prefix_search(&map, "").collect(); + assert_eq!( + result, + vec![ + (&"app".to_string(), &4), + (&"apple".to_string(), &1), + (&"application".to_string(), &2), + (&"apricot".to_string(), &5), + (&"banana".to_string(), &3), + ] + ); + + // Test with prefix that matches a complete key + let result: Vec<_> = prefix_search(&map, "apple").collect(); + assert_eq!(result, vec![(&"apple".to_string(), &1),]); + + // Test with case sensitivity + let result: Vec<_> = prefix_search(&map, "App").collect(); + assert_eq!(result, vec![]); + } +} diff --git a/src/daft-connect/src/convert.rs b/src/daft-connect/src/convert.rs new file mode 100644 index 0000000000..ddb3530669 --- /dev/null +++ b/src/daft-connect/src/convert.rs @@ -0,0 +1,45 @@ +mod data_conversion; +mod expression; +mod formatting; +mod plan_conversion; +mod schema_conversion; + +use std::{ + collections::HashMap, + ops::{ControlFlow, Try}, + sync::Arc, +}; + +use common_daft_config::DaftExecutionConfig; +use daft_plan::LogicalPlanRef; +use daft_table::Table; +pub use data_conversion::convert_data; +use eyre::Context; +pub use plan_conversion::to_logical_plan; +pub use schema_conversion::connect_schema; + +pub fn run_local( + logical_plan: &LogicalPlanRef, + mut f: impl FnMut(&Table) -> T, + default: impl FnOnce() -> T, +) -> eyre::Result { + let physical_plan = daft_physical_plan::translate(logical_plan)?; + let cfg = Arc::new(DaftExecutionConfig::default()); + let psets = HashMap::new(); + + let stream = daft_local_execution::run_local(&physical_plan, psets, cfg, None) + .wrap_err("running local execution")?; + + for elem in stream { + let elem = elem?; + let tables = elem.get_tables()?; + + for table in tables.as_slice() { + if let ControlFlow::Break(x) = f(table).branch() { + return Ok(T::from_residual(x)); + } + } + } + + Ok(default()) +} diff --git a/src/daft-connect/src/convert/data_conversion.rs b/src/daft-connect/src/convert/data_conversion.rs new file mode 100644 index 0000000000..de6b0ed71e --- /dev/null +++ b/src/daft-connect/src/convert/data_conversion.rs @@ -0,0 +1,59 @@ +//! Relation handling for Spark Connect protocol. +//! +//! A Relation represents a structured dataset or transformation in Spark Connect. +//! It can be either a base relation (direct data source) or derived relation +//! (result of operations on other relations). +//! +//! The protocol represents relations as trees of operations where: +//! - Each node is a Relation with metadata and an operation type +//! - Operations can reference other relations, forming a DAG +//! - The tree describes how to derive the final result +//! +//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age +//! +//! ```text +//! Aggregate (grouping by age) +//! 
↳ Filter (department = 'Engineering') +//! ↳ Read (employees table) +//! ``` +//! +//! Relations abstract away: +//! - Physical storage details +//! - Distributed computation +//! - Query optimization +//! - Data source specifics +//! +//! This allows Spark to optimize and execute queries efficiently across a cluster +//! while providing a consistent API regardless of the underlying data source. +//! ```mermaid +//! +//! ``` + +use eyre::{eyre, Context}; +use spark_connect::{relation::RelType, Relation}; +use tracing::trace; + +use crate::{command::ConcreteDataChannel, convert::formatting::RelTypeExt}; + +mod show_string; +use show_string::show_string; + +mod range; +use range::range; + +pub fn convert_data(plan: Relation, encoder: &mut impl ConcreteDataChannel) -> eyre::Result<()> { + // First check common fields if needed + if let Some(common) = &plan.common { + // contains metadata shared across all relation types + // Log or handle common fields if necessary + trace!("Processing relation with plan_id: {:?}", common.plan_id); + } + + let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?; + + match rel_type { + RelType::ShowString(input) => show_string(*input, encoder).wrap_err("parsing ShowString"), + RelType::Range(input) => range(input, encoder).wrap_err("parsing Range"), + other => Err(eyre!("Unsupported top-level relation: {}", other.name())), + } +} diff --git a/src/daft-connect/src/convert/data_conversion/range.rs b/src/daft-connect/src/convert/data_conversion/range.rs new file mode 100644 index 0000000000..1afb710a81 --- /dev/null +++ b/src/daft-connect/src/convert/data_conversion/range.rs @@ -0,0 +1,42 @@ +use daft_core::prelude::Series; +use daft_schema::prelude::Schema; +use daft_table::Table; +use eyre::{ensure, Context}; +use spark_connect::Range; + +use crate::command::ConcreteDataChannel; + +pub fn range(range: Range, channel: &mut impl ConcreteDataChannel) -> eyre::Result<()> { + let Range { + start, + end, + step, + num_partitions, + } = range; + + let start = start.unwrap_or(0); + + ensure!(num_partitions.is_none(), "num_partitions is not supported"); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect(); + let len = arrow_array.len(); + + let singleton_series = Series::try_from(( + "range", + Box::new(arrow_array) as Box, + )) + .wrap_err("creating singleton series")?; + + let singleton_table = Table::new_with_size( + Schema::new(vec![singleton_series.field().clone()])?, + vec![singleton_series], + len, + )?; + + channel.send_table(&singleton_table)?; + + Ok(()) +} diff --git a/src/daft-connect/src/convert/data_conversion/show_string.rs b/src/daft-connect/src/convert/data_conversion/show_string.rs new file mode 100644 index 0000000000..bed2952a21 --- /dev/null +++ b/src/daft-connect/src/convert/data_conversion/show_string.rs @@ -0,0 +1,59 @@ +use daft_core::prelude::Series; +use daft_schema::prelude::Schema; +use daft_table::Table; +use eyre::{ensure, eyre, Context}; +use spark_connect::ShowString; + +use crate::{ + command::ConcreteDataChannel, + convert::{plan_conversion::to_logical_plan, run_local}, +}; + +pub fn show_string( + show_string: ShowString, + channel: &mut impl ConcreteDataChannel, +) -> eyre::Result<()> { + let ShowString { + input, + num_rows, + truncate, + vertical, + } = show_string; + + ensure!(num_rows > 0, "num_rows must be positive, got {num_rows}"); + 
ensure!(truncate > 0, "truncate must be positive, got {truncate}"); + ensure!(!vertical, "vertical is not yet supported"); + + let input = *input.ok_or_else(|| eyre!("input is None"))?; + + let logical_plan = to_logical_plan(input)?.build(); + + run_local( + &logical_plan, + |table| -> eyre::Result<()> { + let display = format!("{table}"); + + let arrow_array: arrow2::array::Utf8Array = + std::iter::once(display.as_str()).map(Some).collect(); + + let singleton_series = Series::try_from(( + "show_string", + Box::new(arrow_array) as Box, + )) + .wrap_err("creating singleton series")?; + + let singleton_table = Table::new_with_size( + Schema::new(vec![singleton_series.field().clone()])?, + vec![singleton_series], + 1, + )?; + + channel.send_table(&singleton_table)?; + + Ok(()) + }, + || Ok(()), + )??; + + Ok(()) +} diff --git a/src/daft-connect/src/convert/expression.rs b/src/daft-connect/src/convert/expression.rs new file mode 100644 index 0000000000..f79a7bf5a8 --- /dev/null +++ b/src/daft-connect/src/convert/expression.rs @@ -0,0 +1,120 @@ +use daft_dsl::{Expr as DaftExpr, Operator}; +use eyre::{bail, ensure, eyre, Result}; +use spark_connect::{expression, expression::literal::LiteralType, Expression}; + +pub fn convert_expression(expr: Expression) -> Result { + match expr.expr_type { + Some(expression::ExprType::Literal(lit)) => Ok(DaftExpr::Literal(convert_literal(lit)?)), + + Some(expression::ExprType::UnresolvedAttribute(attr)) => { + Ok(DaftExpr::Column(attr.unparsed_identifier.into())) + } + + Some(expression::ExprType::Alias(alias)) => { + let expression::Alias { + expr, + name, + metadata, + } = *alias; + let expr = *expr.ok_or_else(|| eyre!("expr is None"))?; + + // Convert alias + let expr = convert_expression(expr)?; + + if let Some(metadata) = metadata + && !metadata.is_empty() + { + bail!("Metadata is not yet supported"); + } + + // ignore metadata for now + + let [name] = name.as_slice() else { + bail!("Alias name must have exactly one element"); + }; + + Ok(DaftExpr::Alias(expr.into(), name.as_str().into())) + } + + Some(expression::ExprType::UnresolvedFunction(expression::UnresolvedFunction { + function_name, + arguments, + is_distinct, + is_user_defined_function, + })) => { + ensure!(!is_distinct, "Distinct is not yet supported"); + ensure!( + !is_user_defined_function, + "User-defined functions are not yet supported" + ); + + let op = function_name.as_str(); + match op { + ">" | "<" | "<=" | ">=" | "+" | "-" | "*" | "/" => { + let arr: [Expression; 2] = arguments + .try_into() + .map_err(|_| eyre!("Expected 2 arguments"))?; + let [left, right] = arr; + + let left = convert_expression(left)?; + let right = convert_expression(right)?; + + let op = match op { + ">" => Operator::Gt, + "<" => Operator::Lt, + "<=" => Operator::LtEq, + ">=" => Operator::GtEq, + "+" => Operator::Plus, + "-" => Operator::Minus, + "*" => Operator::Multiply, + "/" => Operator::FloorDivide, // todo is this what we want? + _ => unreachable!(), + }; + + Ok(DaftExpr::BinaryOp { + left: left.into(), + op, + right: right.into(), + }) + } + other => bail!("Unsupported function name: {other}"), + } + } + + // Handle other expression types... + _ => Err(eyre!("Unsupported expression type")), + } +} + +// Helper functions to convert literals, function names, operators etc. + +fn convert_literal(lit: expression::Literal) -> Result { + let literal_type = lit + .literal_type + .ok_or_else(|| eyre!("literal_type is None"))?; + + let result = match literal_type { + LiteralType::Null(..) 
=> daft_dsl::LiteralValue::Null, + LiteralType::Binary(input) => daft_dsl::LiteralValue::Binary(input), + LiteralType::Boolean(input) => daft_dsl::LiteralValue::Boolean(input), + LiteralType::Byte(input) => daft_dsl::LiteralValue::Int32(input), + LiteralType::Short(input) => daft_dsl::LiteralValue::Int32(input), + LiteralType::Integer(input) => daft_dsl::LiteralValue::Int32(input), + LiteralType::Long(input) => daft_dsl::LiteralValue::Int64(input), + LiteralType::Float(input) => daft_dsl::LiteralValue::Float64(f64::from(input)), + LiteralType::Double(input) => daft_dsl::LiteralValue::Float64(input), + LiteralType::String(input) => daft_dsl::LiteralValue::Utf8(input), + LiteralType::Date(input) => daft_dsl::LiteralValue::Date(input), + LiteralType::Decimal(_) + | LiteralType::Timestamp(_) + | LiteralType::TimestampNtz(_) + | LiteralType::CalendarInterval(_) + | LiteralType::YearMonthInterval(_) + | LiteralType::DayTimeInterval(_) + | LiteralType::Array(_) + | LiteralType::Map(_) + | LiteralType::Struct(_) => bail!("unimplemented"), + }; + + Ok(result) +} diff --git a/src/daft-connect/src/convert/formatting.rs b/src/daft-connect/src/convert/formatting.rs new file mode 100644 index 0000000000..3310a918fb --- /dev/null +++ b/src/daft-connect/src/convert/formatting.rs @@ -0,0 +1,69 @@ +use spark_connect::relation::RelType; + +/// Extension trait for RelType to add a `name` method. +pub trait RelTypeExt { + /// Returns the name of the RelType as a string. + fn name(&self) -> &'static str; +} + +impl RelTypeExt for RelType { + fn name(&self) -> &'static str { + match self { + Self::Read(_) => "Read", + Self::Project(_) => "Project", + Self::Filter(_) => "Filter", + Self::Join(_) => "Join", + Self::SetOp(_) => "SetOp", + Self::Sort(_) => "Sort", + Self::Limit(_) => "Limit", + Self::Aggregate(_) => "Aggregate", + Self::Sql(_) => "Sql", + Self::LocalRelation(_) => "LocalRelation", + Self::Sample(_) => "Sample", + Self::Offset(_) => "Offset", + Self::Deduplicate(_) => "Deduplicate", + Self::Range(_) => "Range", + Self::SubqueryAlias(_) => "SubqueryAlias", + Self::Repartition(_) => "Repartition", + Self::ToDf(_) => "ToDf", + Self::WithColumnsRenamed(_) => "WithColumnsRenamed", + Self::ShowString(_) => "ShowString", + Self::Drop(_) => "Drop", + Self::Tail(_) => "Tail", + Self::WithColumns(_) => "WithColumns", + Self::Hint(_) => "Hint", + Self::Unpivot(_) => "Unpivot", + Self::ToSchema(_) => "ToSchema", + Self::RepartitionByExpression(_) => "RepartitionByExpression", + Self::MapPartitions(_) => "MapPartitions", + Self::CollectMetrics(_) => "CollectMetrics", + Self::Parse(_) => "Parse", + Self::GroupMap(_) => "GroupMap", + Self::CoGroupMap(_) => "CoGroupMap", + Self::WithWatermark(_) => "WithWatermark", + Self::ApplyInPandasWithState(_) => "ApplyInPandasWithState", + Self::HtmlString(_) => "HtmlString", + Self::CachedLocalRelation(_) => "CachedLocalRelation", + Self::CachedRemoteRelation(_) => "CachedRemoteRelation", + Self::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction", + Self::AsOfJoin(_) => "AsOfJoin", + Self::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource", + Self::WithRelations(_) => "WithRelations", + Self::Transpose(_) => "Transpose", + Self::FillNa(_) => "FillNa", + Self::DropNa(_) => "DropNa", + Self::Replace(_) => "Replace", + Self::Summary(_) => "Summary", + Self::Crosstab(_) => "Crosstab", + Self::Describe(_) => "Describe", + Self::Cov(_) => "Cov", + Self::Corr(_) => "Corr", + Self::ApproxQuantile(_) => "ApproxQuantile", + 
Self::FreqItems(_) => "FreqItems", + Self::SampleBy(_) => "SampleBy", + Self::Catalog(_) => "Catalog", + Self::Extension(_) => "Extension", + Self::Unknown(_) => "Unknown", + } + } +} diff --git a/src/daft-connect/src/convert/plan_conversion.rs b/src/daft-connect/src/convert/plan_conversion.rs new file mode 100644 index 0000000000..8fa9b8a7b3 --- /dev/null +++ b/src/daft-connect/src/convert/plan_conversion.rs @@ -0,0 +1,134 @@ +use std::{collections::HashSet, sync::Arc}; + +use daft_plan::{LogicalPlanBuilder, ParquetScanBuilder}; +use eyre::{bail, eyre, Result, WrapErr}; +use spark_connect::{ + expression::Alias, + read::{DataSource, ReadType}, + relation::RelType, + Filter, Read, Relation, WithColumns, +}; +use tracing::warn; + +use crate::convert::expression; + +pub fn to_logical_plan(plan: Relation) -> Result { + let scope = std::thread::spawn(|| { + let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?; + + match rel_type { + RelType::ShowString(..) => { + bail!("ShowString is only supported as a top-level relation") + } + RelType::Filter(filter) => parse_filter(*filter).wrap_err("parsing Filter"), + RelType::WithColumns(with_columns) => { + parse_with_columns(*with_columns).wrap_err("parsing WithColumns") + } + RelType::Read(read) => parse_read(read), + _ => bail!("Unsupported relation type: {rel_type:?}"), + } + }); + + scope.join().unwrap() +} + +fn parse_filter(filter: Filter) -> Result { + let Filter { input, condition } = filter; + let input = *input.ok_or_else(|| eyre!("input is None"))?; + let input_plan = to_logical_plan(input).wrap_err("parsing input")?; + + let condition = condition.ok_or_else(|| eyre!("condition is None"))?; + let condition = + expression::convert_expression(condition).wrap_err("converting to daft expression")?; + let condition = Arc::new(condition); + + input_plan.filter(condition).wrap_err("applying filter") +} + +fn parse_with_columns(with_columns: WithColumns) -> Result { + let WithColumns { input, aliases } = with_columns; + let input = *input.ok_or_else(|| eyre!("input is None"))?; + let input_plan = to_logical_plan(input).wrap_err("parsing input")?; + + let mut new_exprs = Vec::new(); + let mut existing_columns: HashSet<_> = input_plan.schema().names().into_iter().collect(); + + for alias in aliases { + let Alias { + expr, + name, + metadata, + } = alias; + + if name.len() != 1 { + bail!("Alias name must have exactly one element"); + } + let name = name[0].as_str(); + + if metadata.is_some() { + bail!("Metadata is not yet supported"); + } + + let expr = expr.ok_or_else(|| eyre!("expression is None"))?; + let expr = + expression::convert_expression(*expr).wrap_err("converting to daft expression")?; + let expr = Arc::new(expr); + + new_exprs.push(expr.alias(name)); + + if existing_columns.contains(name) { + existing_columns.remove(name); + } + } + + // Add remaining existing columns + for col_name in existing_columns { + new_exprs.push(daft_dsl::col(col_name)); + } + + input_plan + .select(new_exprs) + .wrap_err("selecting new expressions") +} + +fn parse_read(read: Read) -> Result { + let Read { + is_streaming, + read_type, + } = read; + + warn!("Ignoring is_streaming: {is_streaming}"); + + let read_type = read_type.ok_or_else(|| eyre!("type is None"))?; + + match read_type { + ReadType::NamedTable(_) => bail!("Named tables are not yet supported"), + ReadType::DataSource(data_source) => parse_data_source(data_source), + } +} + +fn parse_data_source(data_source: DataSource) -> Result { + let DataSource { + format, + options, + paths, + 
predicates, + .. + } = data_source; + + let format = format.ok_or_else(|| eyre!("format is None"))?; + if format != "parquet" { + bail!("Only parquet is supported; got {format}"); + } + + if !options.is_empty() { + bail!("Options are not yet supported"); + } + if !predicates.is_empty() { + bail!("Predicates are not yet supported"); + } + + ParquetScanBuilder::new(paths) + .finish() + .wrap_err("creating ParquetScanBuilder") +} diff --git a/src/daft-connect/src/convert/schema_conversion.rs b/src/daft-connect/src/convert/schema_conversion.rs new file mode 100644 index 0000000000..dcce376b94 --- /dev/null +++ b/src/daft-connect/src/convert/schema_conversion.rs @@ -0,0 +1,56 @@ +use spark_connect::{ + data_type::{Kind, Long, Struct, StructField}, + relation::RelType, + DataType, Relation, +}; + +#[tracing::instrument(skip_all)] +pub fn connect_schema(input: Relation) -> Result { + if input.common.is_some() { + tracing::warn!("We do not currently look at common fields"); + } + + let result = match input + .rel_type + .ok_or_else(|| tonic::Status::internal("rel_type is None"))? + { + RelType::Range(spark_connect::Range { num_partitions, .. }) => { + if num_partitions.is_some() { + return Err(tonic::Status::unimplemented( + "num_partitions is not supported", + )); + } + + let long = Long { + type_variation_reference: 0, + }; + + let id_field = StructField { + name: "id".to_string(), + data_type: Some(DataType { + kind: Some(Kind::Long(long)), + }), + nullable: false, + metadata: None, + }; + + let fields = vec![id_field]; + + let strct = Struct { + fields, + type_variation_reference: 0, + }; + + DataType { + kind: Some(Kind::Struct(strct)), + } + } + other => { + return Err(tonic::Status::unimplemented(format!( + "Unsupported relation type: {other:?}" + ))) + } + }; + + Ok(result) +} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs new file mode 100644 index 0000000000..33466183e8 --- /dev/null +++ b/src/daft-connect/src/lib.rs @@ -0,0 +1,447 @@ +#![feature(iterator_try_collect)] +#![feature(let_chains)] +#![feature(try_trait_v2)] +#![feature(coroutines)] +#![feature(iter_from_coroutine)] +#![feature(stmt_expr_attributes)] +#![feature(try_trait_v2_residual)] +#![warn(unused)] + +use std::ops::ControlFlow; + +use dashmap::DashMap; +use eyre::Context; +#[cfg(feature = "python")] +use pyo3::types::PyModuleMethods; +use spark_connect::{ + analyze_plan_response, + command::CommandType, + plan::OpType, + spark_connect_service_server::{SparkConnectService, SparkConnectServiceServer}, + AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, + ArtifactStatusesRequest, ArtifactStatusesResponse, ConfigRequest, ConfigResponse, + ExecutePlanRequest, ExecutePlanResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse, + InterruptRequest, InterruptResponse, Plan, ReattachExecuteRequest, ReleaseExecuteRequest, + ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, +}; +use tonic::{transport::Server, Request, Response, Status}; +use tracing::{info, warn}; +use uuid::Uuid; + +use crate::{convert::run_local, session::Session}; + +mod command; +mod config; +mod convert; +mod session; +pub mod util; + +pub fn start(addr: &str) -> eyre::Result<()> { + info!("Daft-Connect server listening on {addr}"); + let addr = util::parse_spark_connect_address(addr)?; + + let service = DaftSparkConnectService::default(); + + info!("Daft-Connect server listening on {addr}"); + + std::thread::spawn(move || { + let runtime = 
tokio::runtime::Runtime::new().unwrap(); + let result = runtime + .block_on(async { + Server::builder() + .add_service(SparkConnectServiceServer::new(service)) + .serve(addr) + .await + }) + .wrap_err_with(|| format!("Failed to start server on {addr}")); + + if let Err(e) = result { + eprintln!("Daft-Connect server error: {e:?}"); + } + + println!("done with runtime"); + + eyre::Result::<_>::Ok(()) + }); + + Ok(()) +} + +#[derive(Default)] +pub struct DaftSparkConnectService { + client_to_session: DashMap, // To track session data +} + +impl DaftSparkConnectService { + fn get_session( + &self, + session_id: &str, + ) -> Result, Status> { + let Ok(uuid) = Uuid::parse_str(session_id) else { + return Err(Status::invalid_argument( + "Invalid session_id format, must be a UUID", + )); + }; + + let res = self + .client_to_session + .entry(uuid) + .or_insert_with(|| Session::new(session_id.to_string())); + + Ok(res) + } +} + +#[tonic::async_trait] +impl SparkConnectService for DaftSparkConnectService { + type ExecutePlanStream = std::pin::Pin< + Box< + dyn futures::Stream> + Send + Sync + 'static, + >, + >; + type ReattachExecuteStream = std::pin::Pin< + Box< + dyn futures::Stream> + Send + Sync + 'static, + >, + >; + + #[tracing::instrument(skip_all)] + async fn execute_plan( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let session = self.get_session(&request.session_id)?; + + let operation = request + .operation_id + .ok_or_else(|| invalid_argument!("Operation ID is required"))?; + + // Proceed with executing the plan... + let plan = request + .plan + .ok_or_else(|| invalid_argument!("Plan is required"))?; + let plan = plan + .op_type + .ok_or_else(|| invalid_argument!("Plan operation is required"))?; + + use spark_connect::plan::OpType; + + match plan { + OpType::Root(relation) => { + let result = session.handle_root_command(relation, operation).await?; + return Ok(Response::new(result)); + } + OpType::Command(command) => { + let command = command + .command_type + .ok_or_else(|| invalid_argument!("Command type is required"))?; + + match command { + CommandType::RegisterFunction(_) => { + Err(unimplemented_err!("RegisterFunction not implemented")) + } + CommandType::WriteOperation(op) => { + println!("WriteOperation: {:#2?}", op); + + let result = session.handle_write_operation(op, operation)?; + + return Ok(Response::new(result)); + } + CommandType::CreateDataframeView(_) => { + Err(unimplemented_err!("CreateDataframeView not implemented")) + } + CommandType::WriteOperationV2(_) => { + Err(unimplemented_err!("WriteOperationV2 not implemented")) + } + CommandType::SqlCommand(..) 
=> { + Err(unimplemented_err!("SQL execution not yet implemented")) + } + CommandType::WriteStreamOperationStart(_) => Err(unimplemented_err!( + "WriteStreamOperationStart not implemented" + )), + CommandType::StreamingQueryCommand(_) => { + Err(unimplemented_err!("StreamingQueryCommand not implemented")) + } + CommandType::GetResourcesCommand(_) => { + Err(unimplemented_err!("GetResourcesCommand not implemented")) + } + CommandType::StreamingQueryManagerCommand(_) => Err(unimplemented_err!( + "StreamingQueryManagerCommand not implemented" + )), + CommandType::RegisterTableFunction(_) => { + Err(unimplemented_err!("RegisterTableFunction not implemented")) + } + CommandType::StreamingQueryListenerBusCommand(_) => Err(unimplemented_err!( + "StreamingQueryListenerBusCommand not implemented" + )), + CommandType::RegisterDataSource(_) => { + Err(unimplemented_err!("RegisterDataSource not implemented")) + } + CommandType::CreateResourceProfileCommand(_) => Err(unimplemented_err!( + "CreateResourceProfileCommand not implemented" + )), + CommandType::CheckpointCommand(_) => { + Err(unimplemented_err!("CheckpointCommand not implemented")) + } + CommandType::RemoveCachedRemoteRelationCommand(_) => Err(unimplemented_err!( + "RemoveCachedRemoteRelationCommand not implemented" + )), + CommandType::MergeIntoTableCommand(_) => { + Err(unimplemented_err!("MergeIntoTableCommand not implemented")) + } + CommandType::Extension(_) => { + Err(unimplemented_err!("Extension not implemented")) + } + } + } + }?; + + Err(unimplemented_err!("Unsupported plan type")) + } + + #[tracing::instrument(skip_all)] + async fn config( + &self, + request: Request, + ) -> Result, Status> { + println!("got config"); + let request = request.into_inner(); + + let mut session = self.get_session(&request.session_id)?; + + let Some(operation) = request.operation.and_then(|op| op.op_type) else { + return Err(Status::invalid_argument("Missing operation")); + }; + + use spark_connect::config_request::operation::OpType; + + let response = match operation { + OpType::Set(op) => session.set(op), + OpType::Get(op) => session.get(op), + OpType::GetWithDefault(op) => session.get_with_default(op), + OpType::GetOption(op) => session.get_option(op), + OpType::GetAll(op) => session.get_all(op), + OpType::Unset(op) => session.unset(op), + OpType::IsModifiable(op) => session.is_modifiable(op), + }?; + + info!("Response: {response:?}"); + + Ok(Response::new(response)) + } + + #[tracing::instrument(skip_all)] + async fn add_artifacts( + &self, + _request: Request>, + ) -> Result, Status> { + Err(unimplemented_err!( + "add_artifacts operation is not yet implemented" + )) + } + + #[tracing::instrument(skip_all)] + async fn analyze_plan( + &self, + request: Request, + ) -> Result, Status> { + use spark_connect::analyze_plan_request::*; + let request = request.into_inner(); + + let AnalyzePlanRequest { + session_id, + analyze, + .. 
+ } = request; + + let Some(analyze) = analyze else { + return Err(Status::invalid_argument("analyze is required")); + }; + + match analyze { + Analyze::Schema(Schema { plan }) => { + let Some(Plan { op_type }) = plan else { + return Err(Status::invalid_argument("plan is required")); + }; + + let Some(OpType::Root(relation)) = op_type else { + return Err(Status::invalid_argument("op_type is required to be root")); + }; + + let result = convert::connect_schema(relation)?; + + let schema = analyze_plan_response::DdlParse { + parsed: Some(result), + }; + + let response = AnalyzePlanResponse { + session_id, + server_side_session_id: String::new(), + result: Some(analyze_plan_response::Result::DdlParse(schema)), + }; + + println!("response: {response:#?}"); + + Ok(Response::new(response)) + } + Analyze::TreeString(tree_string) => { + if let Some(level) = tree_string.level { + warn!("Ignoring level {level} in TreeString"); + } + + let Some(plan) = tree_string.plan else { + return Err(invalid_argument!("TreeString must have a plan")); + }; + + let Some(op_type) = plan.op_type else { + return Err(invalid_argument!("plan must have an op_type")); + }; + + println!("op_type: {op_type:?}"); + + let OpType::Root(plan) = op_type else { + return Err(invalid_argument!("Only op_type Root is supported")); + }; + + let logical_plan = match convert::to_logical_plan(plan) { + Ok(lp) => lp, + Err(e) => { + return Err(invalid_argument!( + "Failed to convert to logical plan: {e:?}" + )); + } + }; + + let logical_plan = logical_plan.build(); + + let res = std::thread::spawn(move || { + let result = run_local( + &logical_plan, + |table| { + let table = format!("{table}"); + ControlFlow::Break(table) + }, + || ControlFlow::Continue(()), + ) + .unwrap(); + + let result = match result { + ControlFlow::Break(x) => Some(x), + ControlFlow::Continue(()) => None, + } + .unwrap(); + + AnalyzePlanResponse { + session_id, + server_side_session_id: String::new(), + result: Some(analyze_plan_response::Result::TreeString( + analyze_plan_response::TreeString { + tree_string: result, + }, + )), + } + }); + + let res = res.join().unwrap(); + + let response = Response::new(res); + Ok(response) + } + _ => Err(unimplemented_err!( + "Analyze plan operation is not yet implemented" + )), + } + } + + #[tracing::instrument(skip_all)] + async fn artifact_status( + &self, + _request: Request, + ) -> Result, Status> { + println!("got artifact status"); + Err(unimplemented_err!( + "artifact_status operation is not yet implemented" + )) + } + + #[tracing::instrument(skip_all)] + async fn interrupt( + &self, + _request: Request, + ) -> Result, Status> { + println!("got interrupt"); + Err(unimplemented_err!( + "interrupt operation is not yet implemented" + )) + } + + #[tracing::instrument(skip_all)] + async fn reattach_execute( + &self, + _request: Request, + ) -> Result, Status> { + warn!("reattach_execute operation is not yet implemented"); + + let singleton_stream = futures::stream::once(async { + Err(Status::unimplemented( + "reattach_execute operation is not yet implemented", + )) + }); + + Ok(Response::new(Box::pin(singleton_stream))) + } + + #[tracing::instrument(skip_all)] + async fn release_execute( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let session = self.get_session(&request.session_id)?; + + let response = ReleaseExecuteResponse { + session_id: session.client_side_session_id().to_string(), + server_side_session_id: session.server_side_session_id().to_string(), + operation_id: 
+            operation_id: Some(request.operation_id), // todo: impl properly
+        };
+
+        Ok(Response::new(response))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn release_session(
+        &self,
+        _request: Request<ReleaseSessionRequest>,
+    ) -> Result<Response<ReleaseSessionResponse>, Status> {
+        println!("got release session");
+        Err(unimplemented_err!(
+            "release_session operation is not yet implemented"
+        ))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn fetch_error_details(
+        &self,
+        _request: Request<FetchErrorDetailsRequest>,
+    ) -> Result<Response<FetchErrorDetailsResponse>, Status> {
+        println!("got fetch error details");
+        Err(unimplemented_err!(
+            "fetch_error_details operation is not yet implemented"
+        ))
+    }
+}
+#[cfg(feature = "python")]
+#[pyo3::pyfunction]
+#[pyo3(name = "connect_start")]
+pub fn py_connect_start(addr: &str) -> pyo3::PyResult<()> {
+    start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}")))
+}
+
+#[cfg(feature = "python")]
+pub fn register_modules(parent: &pyo3::Bound<pyo3::types::PyModule>) -> pyo3::PyResult<()> {
+    parent.add_function(pyo3::wrap_pyfunction_bound!(py_connect_start, parent)?)?;
+    Ok(())
+}
diff --git a/src/daft-connect/src/main.rs b/src/daft-connect/src/main.rs
new file mode 100644
index 0000000000..6dbe5dac6c
--- /dev/null
+++ b/src/daft-connect/src/main.rs
@@ -0,0 +1,32 @@
+use daft_connect::DaftSparkConnectService;
+use spark_connect::spark_connect_service_server::SparkConnectServiceServer;
+use tonic::transport::Server;
+use tracing_subscriber::layer::SubscriberExt;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing::subscriber::set_global_default(
+        tracing_subscriber::registry()
+            .with(tracing_subscriber::EnvFilter::new("info"))
+            .with(tracing_tracy::TracyLayer::default()),
+    )
+    .expect("setup tracy layer");
+
+    let addr = "[::1]:50051".parse()?;
+    let service = DaftSparkConnectService::default();
+
+    println!("Daft-Connect server listening on {}", addr);
+
+    tokio::select! {
+        result = Server::builder()
+            .add_service(SparkConnectServiceServer::new(service))
+            .serve(addr) => {
+            result?;
+        }
+        _ = tokio::signal::ctrl_c() => {
+            println!("\nReceived Ctrl-C, gracefully shutting down server");
+        }
+    }
+
+    Ok(())
+}
diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs
new file mode 100644
index 0000000000..72b477478f
--- /dev/null
+++ b/src/daft-connect/src/session.rs
@@ -0,0 +1,48 @@
+use std::collections::{BTreeMap, HashMap};
+
+use uuid::Uuid;
+
+pub struct Session {
+    /// Configuration values for this session.
+    ///
+    /// Kept in a `BTreeMap` so order is preserved and we can efficiently do a prefix search.
+    config_values: BTreeMap<String, String>,
+
+    #[expect(
+        unused,
+        reason = "this will be used in the future especially to pass spark connect tests"
+    )]
+    tables_by_name: HashMap<String, daft_table::Table>,
+
+    id: String,
+    server_side_session_id: String,
+}
+
+impl Session {
+    pub fn config_values(&self) -> &BTreeMap<String, String> {
+        &self.config_values
+    }
+
+    pub fn config_values_mut(&mut self) -> &mut BTreeMap<String, String> {
+        &mut self.config_values
+    }
+
+    pub fn new(id: String) -> Self {
+        let server_side_session_id = Uuid::new_v4();
+        let server_side_session_id = server_side_session_id.to_string();
+        Self {
+            config_values: Default::default(),
+            tables_by_name: Default::default(),
+            id,
+            server_side_session_id,
+        }
+    }
+
+    pub fn client_side_session_id(&self) -> &str {
+        &self.id
+    }
+
+    pub fn server_side_session_id(&self) -> &str {
+        &self.server_side_session_id
+    }
+}
diff --git a/src/daft-connect/src/util.rs b/src/daft-connect/src/util.rs
new file mode 100644
index 0000000000..0cdaffff54
--- /dev/null
+++ b/src/daft-connect/src/util.rs
@@ -0,0 +1,114 @@
+use std::net::ToSocketAddrs;
+
+#[macro_export]
+macro_rules! invalid_argument {
+    ($arg: tt) => {{
+        let msg = format!($arg);
+        ::tonic::Status::invalid_argument(msg)
+    }};
+}
+
+#[macro_export]
+macro_rules! unimplemented_err {
+    ($arg: tt) => {{
+        let msg = format!($arg);
+        ::tonic::Status::unimplemented(msg)
+    }};
+}
+
+pub fn parse_spark_connect_address(addr: &str) -> eyre::Result<std::net::SocketAddr> {
+    // Check that the address starts with "sc://"
+    if !addr.starts_with("sc://") {
+        return Err(eyre::eyre!("Address must start with 'sc://'"));
+    }
+
+    // Remove the "sc://" prefix
+    let addr = addr.trim_start_matches("sc://");
+
+    // Reject addresses such as "sc://" that are missing the host:port part
+    if addr.is_empty() {
+        return Err(eyre::eyre!("Invalid address format: expected 'sc://host:port'"));
+    }
+
+    // Resolve the host:port pair using std's blocking DNS resolver
+    let addrs = addr.to_socket_addrs()?;
+
+    // Take the first resolved address
+    addrs
+        .into_iter()
+        .next()
+        .ok_or_else(|| eyre::eyre!("No addresses found for hostname"))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_spark_connect_address_valid() {
+        let addr = "sc://localhost:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_missing_prefix() {
+        let addr = "localhost:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("must start with 'sc://'"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_invalid_port() {
+        let addr = "sc://localhost:invalid";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_missing_port() {
+        let addr = "sc://localhost";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_empty() {
+        let addr = "";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("must start with 'sc://'"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_only_prefix() {
+        let addr = "sc://";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Invalid address format"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_ipv4() {
+        let addr = "sc://127.0.0.1:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_ipv6() {
+        let addr = "sc://[::1]:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+}
diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs
index 553ad18b40..719da409c4 100644
--- a/src/daft-local-execution/src/lib.rs
+++ b/src/daft-local-execution/src/lib.rs
@@ -11,7 +11,7 @@ mod sources;
 
 use common_error::{DaftError, DaftResult};
 use lazy_static::lazy_static;
-pub use run::NativeExecutor;
+pub use run::{run_local, NativeExecutor};
 use snafu::{futures::TryFutureExt, Snafu};
 
 lazy_static! {
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index 0a450ace70..c68546f318 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -99,6 +99,12 @@ impl Table {
         Ok(Self::new_unchecked(schema, columns?, num_rows))
     }
 
+    pub fn get_inner_arrow_arrays(
+        &self,
+    ) -> impl Iterator<Item = Box<dyn arrow2::array::Array>> + '_ {
+        self.columns.iter().map(|s| s.to_arrow())
+    }
+
     /// Create a new [`Table`] and validate against `num_rows`
     ///
     /// Note that this function is slow. You might instead be looking for [`Table::new_unchecked`] if you've already performed your own validation logic.
@@ -194,6 +200,10 @@ impl Table {
         self.num_rows
     }
 
+    pub fn num_rows(&self) -> usize {
+        self.num_rows
+    }
+
     pub fn is_empty(&self) -> bool {
         self.len() == 0
     }
diff --git a/src/lib.rs b/src/lib.rs
index a7a1382538..201f6b229f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -118,6 +118,7 @@ pub mod pylib {
         daft_sql::register_modules(m)?;
         daft_functions::register_modules(m)?;
         daft_functions_json::register_modules(m)?;
+        daft_connect::register_modules(m)?;
 
         m.add_wrapped(wrap_pyfunction!(version))?;
         m.add_wrapped(wrap_pyfunction!(build_type))?;
diff --git a/tests/connect/__init__.py b/tests/connect/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/connect/test_parquet_simple.py b/tests/connect/test_parquet_simple.py
new file mode 100644
index 0000000000..cb3ba9f1b1
--- /dev/null
+++ b/tests/connect/test_parquet_simple.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+import pathlib
+import time
+
+import pyarrow as pa
+import pyarrow.parquet as papq
+from pyspark.sql import SparkSession
+from pyspark.sql.dataframe import DataFrame
+
+from daft.daft import connect_start
+
+
+def test_read_parquet(tmpdir):
+    # Convert tmpdir to Path object
+    test_dir = pathlib.Path(tmpdir)
+    input_parquet_path = test_dir / "input.parquet"
+
+    # Create sample data with sequential IDs
+    sample_data = pa.Table.from_pydict({"id": [0, 1, 2, 3, 4]})
+
+    # Write sample data to input parquet file
+    papq.write_table(sample_data, input_parquet_path)
+
+    # Start Daft Connect server
+    # TODO: Add env var to control server embedding
+    connect_start("sc://localhost:50051")
+
+    # Initialize Spark Connect session
+    spark_session: SparkSession = (
+        SparkSession.builder.appName("DaftParquetReadWriteTest").remote("sc://localhost:50051").getOrCreate()
+    )
+
+    # Read input parquet with Spark Connect
+    spark_df: DataFrame = spark_session.read.parquet(str(input_parquet_path))
+
+    # Write DataFrame to output parquet
+    output_parquet_path = test_dir / "output.parquet"
+    spark_df.write.parquet(str(output_parquet_path))
+
+    # Verify output matches input
+    output_data = papq.read_table(output_parquet_path)
+    assert output_data.equals(sample_data)
+
+    # Clean up Spark session
+    spark_session.stop()
+    time.sleep(2)  # Allow time for session cleanup
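
Note on how these pieces fit together: `py_connect_start` hands the `sc://host:port` string to `start`, which is defined earlier in this patch. The sketch below is an illustrative reconstruction of that wiring only, not the patch's actual code; the public `util` path and the thread-plus-runtime setup are assumptions. It shows how `parse_spark_connect_address` and `SparkConnectServiceServer` combine to serve the gRPC service in the background so the embedding Python process is not blocked.

// Illustrative sketch only -- the real `start` lives earlier in this patch.
// Assumptions: `daft_connect::util::parse_spark_connect_address` is publicly
// reachable, and running the server on a dedicated thread/runtime is acceptable.
use daft_connect::DaftSparkConnectService;
use spark_connect::spark_connect_service_server::SparkConnectServiceServer;
use tonic::transport::Server;

fn start_sketch(addr: &str) -> eyre::Result<()> {
    // Turn "sc://host:port" into a resolved SocketAddr using the util.rs helper.
    let addr = daft_connect::util::parse_spark_connect_address(addr)?;

    let service = DaftSparkConnectService::default();

    // Serve the Spark Connect gRPC service on its own Tokio runtime so the
    // caller (e.g. the Python interpreter) returns immediately.
    std::thread::spawn(move || {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build tokio runtime");

        runtime.block_on(async move {
            Server::builder()
                .add_service(SparkConnectServiceServer::new(service))
                .serve(addr)
                .await
                .expect("daft-connect server terminated unexpectedly");
        });
    });

    Ok(())
}

Compare with src/daft-connect/src/main.rs above, which performs the same wiring but drives the server directly under #[tokio::main] and adds Ctrl-C handling; tests/connect/test_parquet_simple.py relies on the embedded variant returning before the server has necessarily finished binding.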