From 5c88002ee30a8cd5926e0608c9206c2847951b24 Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Fri, 26 Jul 2024 16:31:11 -0600 Subject: [PATCH 01/10] Expose CommandStatementIngest as pub in sql module --- arrow-flight/src/sql/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 61eb67b6933e..d7945e925659 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -74,6 +74,7 @@ pub use gen::CommandGetTables; pub use gen::CommandGetXdbcTypeInfo; pub use gen::CommandPreparedStatementQuery; pub use gen::CommandPreparedStatementUpdate; +pub use gen::CommandStatementIngest; pub use gen::CommandStatementQuery; pub use gen::CommandStatementSubstraitPlan; pub use gen::CommandStatementUpdate; @@ -250,6 +251,7 @@ prost_message_ext!( CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, From 6d669b2da4aa48b56ac1931ff7e2dbb0692e024d Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Fri, 26 Jul 2024 16:32:42 -0600 Subject: [PATCH 02/10] Add do_put_statement_ingest to FlightSqlService Dispatch this handler for the new CommandStatementIngest command. --- arrow-flight/src/sql/server.rs | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index b47691c7da5d..e348367a91eb 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -32,9 +32,9 @@ use super::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, - TicketStatementQuery, + CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, + CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, + SqlInfo, TicketStatementQuery, }; use crate::{ flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty, @@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { )) } + /// Execute a bulk ingestion. + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_ingest has no default implementation", + )) + } + /// Bind parameters to given prepared statement. /// /// Returns an opaque handle that the client should pass @@ -713,6 +724,14 @@ where })]); Ok(Response::new(Box::pin(output))) } + Command::CommandStatementIngest(command) => { + let record_count = self.do_put_statement_ingest(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } Command::CommandPreparedStatementQuery(command) => { let result = self .do_put_prepared_statement_query(command, request) From 228342a03e888eef74a8eb5e29a4f3d02c2ce784 Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Fri, 26 Jul 2024 16:36:53 -0600 Subject: [PATCH 03/10] Sort list --- arrow-flight/src/sql/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index d7945e925659..4710c22b6775 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -255,8 +255,8 @@ prost_message_ext!( CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - DoPutUpdateResult, DoPutPreparedStatementResult, + DoPutUpdateResult, TicketStatementQuery, ); From 1599afb5f8e7c42bcacb08355ee55f5e6aa29611 Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Fri, 26 Jul 2024 16:37:47 -0600 Subject: [PATCH 04/10] Implement stub do_put_statement_ingest in example --- arrow-flight/examples/flight_sql_server.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index d5168debc433..81afecf85625 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -46,9 +46,9 @@ use arrow_flight::sql::{ ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable, - SqlInfo, TicketStatementQuery, XdbcDataType, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, + ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType, }; use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ @@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(FAKE_UPDATE_RESULT) } + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + _request: Request, + ) -> Result { + Ok(FAKE_UPDATE_RESULT) + } + async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, From a8a9afc623c712d78e3d3f6b1259a81bbf098660 Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Tue, 6 Aug 2024 12:09:42 -0600 Subject: [PATCH 05/10] Refactor helper functions into tests/common/utils --- arrow-flight/src/sql/mod.rs | 4 + arrow-flight/tests/common/mod.rs | 1 + arrow-flight/tests/common/utils.rs | 118 ++++++++++++++++++++++++++++ arrow-flight/tests/encode_decode.rs | 98 +---------------------- 4 files changed, 127 insertions(+), 94 deletions(-) create mode 100644 arrow-flight/tests/common/utils.rs diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 4710c22b6775..453f608d353a 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -50,6 +50,10 @@ mod gen { } pub use gen::action_end_transaction_request::EndTransaction; +pub use gen::command_statement_ingest::table_definition_options::{ + TableExistsOption, TableNotExistOption, +}; +pub use gen::command_statement_ingest::TableDefinitionOptions; pub use gen::ActionBeginSavepointRequest; pub use gen::ActionBeginSavepointResult; pub use gen::ActionBeginTransactionRequest; diff --git a/arrow-flight/tests/common/mod.rs b/arrow-flight/tests/common/mod.rs index 85716e56058c..c4ac027c5890 100644 --- a/arrow-flight/tests/common/mod.rs +++ b/arrow-flight/tests/common/mod.rs @@ -18,3 +18,4 @@ pub mod fixture; pub mod server; pub mod trailers_layer; +pub mod utils; diff --git a/arrow-flight/tests/common/utils.rs b/arrow-flight/tests/common/utils.rs new file mode 100644 index 000000000000..0f70e4b31021 --- /dev/null +++ b/arrow-flight/tests/common/utils.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utilities for testing flight clients and servers + +use std::sync::Arc; + +use arrow_array::{ + types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, + StringViewArray, UInt8Array, +}; +use arrow_schema::{DataType, Field, Schema}; + +/// Make a primitive batch for testing +/// +/// Example: +/// i: 0, 1, None, 3, 4 +/// f: 5.0, 4.0, None, 2.0, 1.0 +#[allow(dead_code)] +pub fn make_primitive_batch(num_rows: usize) -> RecordBatch { + let i: UInt8Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some(i.try_into().unwrap()) + } + }) + .collect(); + + let f: Float64Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some((num_rows - i) as f64) + } + }) + .collect(); + + RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() +} + +/// Make a dictionary batch for testing +/// +/// Example: +/// a: value0, value1, value2, None, value1, value2 +#[allow(dead_code)] +pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch { + let values: Vec<_> = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + // repeat some values for low cardinality + let v = i / 3; + Some(format!("value{v}")) + } + }) + .collect(); + + let a: DictionaryArray = values + .iter() + .map(|s| s.as_ref().map(|s| s.as_str())) + .collect(); + + RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() +} + +#[allow(dead_code)] +pub fn make_view_batches(num_rows: usize) -> RecordBatch { + const LONG_TEST_STRING: &str = + "This is a long string to make sure binary view array handles it"; + let schema = Schema::new(vec![ + Field::new("field1", DataType::BinaryView, true), + Field::new("field2", DataType::Utf8View, true), + ]); + + let string_view_values: Vec> = (0..num_rows) + .map(|i| match i % 3 { + 0 => None, + 1 => Some("foo"), + 2 => Some(LONG_TEST_STRING), + _ => unreachable!(), + }) + .collect(); + + let bin_view_values: Vec> = (0..num_rows) + .map(|i| match i % 3 { + 0 => None, + 1 => Some("bar".as_bytes()), + 2 => Some(LONG_TEST_STRING.as_bytes()), + _ => unreachable!(), + }) + .collect(); + + let binary_array = BinaryViewArray::from_iter(bin_view_values); + let utf8_array = StringViewArray::from_iter(string_view_values); + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(binary_array), Arc::new(utf8_array)], + ) + .unwrap() +} diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 0185fa77f067..cbfae1825845 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -19,11 +19,7 @@ use std::{collections::HashMap, sync::Arc}; -use arrow_array::types::Int32Type; -use arrow_array::{ - ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, StringViewArray, - UInt8Array, -}; +use arrow_array::{ArrayRef, RecordBatch}; use arrow_cast::pretty::pretty_format_batches; use arrow_flight::flight_descriptor::DescriptorType; use arrow_flight::FlightDescriptor; @@ -36,6 +32,9 @@ use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; use bytes::Bytes; use futures::{StreamExt, TryStreamExt}; +mod common; +use common::utils::{make_dictionary_batch, make_primitive_batch, make_view_batches}; + #[tokio::test] async fn test_empty() { roundtrip(vec![]).await; @@ -415,95 +414,6 @@ async fn test_mismatched_schema_message() { .await; } -/// Make a primitive batch for testing -/// -/// Example: -/// i: 0, 1, None, 3, 4 -/// f: 5.0, 4.0, None, 2.0, 1.0 -fn make_primitive_batch(num_rows: usize) -> RecordBatch { - let i: UInt8Array = (0..num_rows) - .map(|i| { - if i == num_rows / 2 { - None - } else { - Some(i.try_into().unwrap()) - } - }) - .collect(); - - let f: Float64Array = (0..num_rows) - .map(|i| { - if i == num_rows / 2 { - None - } else { - Some((num_rows - i) as f64) - } - }) - .collect(); - - RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() -} - -/// Make a dictionary batch for testing -/// -/// Example: -/// a: value0, value1, value2, None, value1, value2 -fn make_dictionary_batch(num_rows: usize) -> RecordBatch { - let values: Vec<_> = (0..num_rows) - .map(|i| { - if i == num_rows / 2 { - None - } else { - // repeat some values for low cardinality - let v = i / 3; - Some(format!("value{v}")) - } - }) - .collect(); - - let a: DictionaryArray = values - .iter() - .map(|s| s.as_ref().map(|s| s.as_str())) - .collect(); - - RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() -} - -fn make_view_batches(num_rows: usize) -> RecordBatch { - const LONG_TEST_STRING: &str = - "This is a long string to make sure binary view array handles it"; - let schema = Schema::new(vec![ - Field::new("field1", DataType::BinaryView, true), - Field::new("field2", DataType::Utf8View, true), - ]); - - let string_view_values: Vec> = (0..num_rows) - .map(|i| match i % 3 { - 0 => None, - 1 => Some("foo"), - 2 => Some(LONG_TEST_STRING), - _ => unreachable!(), - }) - .collect(); - - let bin_view_values: Vec> = (0..num_rows) - .map(|i| match i % 3 { - 0 => None, - 1 => Some("bar".as_bytes()), - 2 => Some(LONG_TEST_STRING.as_bytes()), - _ => unreachable!(), - }) - .collect(); - - let binary_array = BinaryViewArray::from_iter(bin_view_values); - let utf8_array = StringViewArray::from_iter(string_view_values); - RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(binary_array), Arc::new(utf8_array)], - ) - .unwrap() -} - /// Encodes input as a FlightData stream, and then decodes it using /// FlightRecordBatchStream and validates the decoded record batches /// match the input. From 4dcaeabd9060345e2150ba86e5f7f26ec79c7f51 Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Tue, 6 Aug 2024 12:13:24 -0600 Subject: [PATCH 06/10] Implement execute_ingest for flight sql client I referenced the C++ implementation here: https://github.com/apache/arrow/commit/0d1ea5db1f9312412fe2cc28363e8c9deb2521ba --- arrow-flight/src/sql/client.rs | 37 ++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 91790898b1cb..3f0ed7e07a1e 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -39,8 +39,8 @@ use crate::sql::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, - ProstMessageExt, SqlInfo, + CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate, + DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, }; use crate::trailers::extract_lazy_trailers; use crate::{ @@ -53,10 +53,10 @@ use arrow_ipc::convert::fb_to_schema; use arrow_ipc::reader::read_record_batch; use arrow_ipc::{root_as_message, MessageHeader}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use futures::{stream, TryStreamExt}; +use futures::{stream, StreamExt, TryStreamExt}; use prost::Message; use tonic::transport::Channel; -use tonic::{IntoRequest, Streaming}; +use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; /// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data /// by FlightSQL protocol. @@ -227,6 +227,35 @@ impl FlightSqlServiceClient { Ok(result.record_count) } + /// Execute a bulk ingest on the server and return the number of records added + pub async fn execute_ingest( + &mut self, + command: CommandStatementIngest, + batches: Vec, + ) -> Result { + let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec()); + let flight_data_encoder = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .build(stream::iter(batches).map(Ok)); + // Safe unwrap, explicitly wrapped on line above. + let flight_data = flight_data_encoder.map(|fd| fd.unwrap()); + let req = self.set_request_headers(flight_data.into_streaming_request())?; + let mut result = self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + /// Request a list of catalogs as tabular FlightInfo results pub async fn get_catalogs(&mut self) -> Result { self.get_flight_info_for_command(CommandGetCatalogs {}) From dfa7e559f5c8a5c2f9d5790f440e049afd41514e Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Tue, 6 Aug 2024 12:25:47 -0600 Subject: [PATCH 07/10] Add integration test for sql client execute_ingest --- arrow-flight/tests/flight_sql_client.rs | 61 ++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index 94b768a13621..e38f00cf7565 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -18,14 +18,20 @@ mod common; use crate::common::fixture::TestFixture; +use crate::common::utils::make_primitive_batch; + +use arrow_array::RecordBatch; +use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::flight_service_server::FlightServiceServer; use arrow_flight::sql::client::FlightSqlServiceClient; -use arrow_flight::sql::server::FlightSqlService; +use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; use arrow_flight::sql::{ ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionEndTransactionRequest, - EndTransaction, SqlInfo, + CommandStatementIngest, EndTransaction, SqlInfo, TableDefinitionOptions, TableExistsOption, + TableNotExistOption, }; use arrow_flight::Action; +use futures::TryStreamExt; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; @@ -63,12 +69,49 @@ pub async fn test_begin_end_transaction() { .is_err()); } +#[tokio::test] +pub async fn test_execute_ingest() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + let cmd = CommandStatementIngest { + table_definition_options: Some(TableDefinitionOptions { + if_not_exist: TableNotExistOption::Create.into(), + if_exists: TableExistsOption::Fail.into(), + }), + table: String::from("test"), + schema: None, + catalog: None, + temporary: true, + transaction_id: None, + options: HashMap::default(), + }; + let expected_rows = 10; + let batches = vec![ + make_primitive_batch(5), + make_primitive_batch(3), + make_primitive_batch(2), + ]; + let actual_rows = flight_sql_client + .execute_ingest(cmd, batches) + .await + .expect("ingest should succeed"); + assert_eq!(actual_rows, expected_rows); +} + #[derive(Clone)] pub struct FlightSqlServiceImpl { transactions: Arc>>, } impl FlightSqlServiceImpl { + pub fn new() -> Self { + Self { + transactions: Arc::new(Mutex::new(HashMap::new())), + } + } + /// Return an [`FlightServiceServer`] that can be used with a /// [`Server`](tonic::transport::Server) pub fn service(&self) -> FlightServiceServer { @@ -116,4 +159,18 @@ impl FlightSqlService for FlightSqlServiceImpl { } async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} + + async fn do_put_statement_ingest( + &self, + _ticket: CommandStatementIngest, + request: Request, + ) -> Result { + let batches: Vec = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect() + .await?; + let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum(); + Ok(affected_rows) + } } From b48a06047f67e570f4ac3c39f08271756d58c7e0 Mon Sep 17 00:00:00 2001 From: Douglas Anderson Date: Tue, 6 Aug 2024 16:31:24 -0600 Subject: [PATCH 08/10] Fix lint clippy::new_without_default --- arrow-flight/tests/flight_sql_client.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index e38f00cf7565..81d58138ed11 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -120,6 +120,12 @@ impl FlightSqlServiceImpl { } } +impl Default for FlightSqlServiceImpl { + fn default() -> Self { + Self::new() + } +} + #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; From 6c4ed33b18b6bcb1acbeff40249bd32d5fb47dc1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 13 Aug 2024 16:27:39 -0400 Subject: [PATCH 09/10] Allow streaming ingest for FlightClient::execute_ingest --- arrow-flight/src/sql/client.rs | 13 ++++++++----- arrow-flight/tests/common/fixture.rs | 1 + arrow-flight/tests/flight_sql_client.rs | 14 +++++++++----- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 3f0ed7e07a1e..cfd6c9e78992 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -53,7 +53,7 @@ use arrow_ipc::convert::fb_to_schema; use arrow_ipc::reader::read_record_batch; use arrow_ipc::{root_as_message, MessageHeader}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::{stream, Stream, StreamExt, TryStreamExt}; use prost::Message; use tonic::transport::Channel; use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; @@ -228,15 +228,18 @@ impl FlightSqlServiceClient { } /// Execute a bulk ingest on the server and return the number of records added - pub async fn execute_ingest( + pub async fn execute_ingest( &mut self, command: CommandStatementIngest, - batches: Vec, - ) -> Result { + stream: S, + ) -> Result + where + S: Stream> + Send + 'static, + { let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec()); let flight_data_encoder = FlightDataEncoderBuilder::new() .with_flight_descriptor(Some(descriptor)) - .build(stream::iter(batches).map(Ok)); + .build(stream); // Safe unwrap, explicitly wrapped on line above. let flight_data = flight_data_encoder.map(|fd| fd.unwrap()); let req = self.set_request_headers(flight_data.into_streaming_request())?; diff --git a/arrow-flight/tests/common/fixture.rs b/arrow-flight/tests/common/fixture.rs index 141879e2a358..a666fa5d0d59 100644 --- a/arrow-flight/tests/common/fixture.rs +++ b/arrow-flight/tests/common/fixture.rs @@ -41,6 +41,7 @@ pub struct TestFixture { impl TestFixture { /// create a new test fixture from the server + #[allow(dead_code)] pub async fn new(test_server: FlightServiceServer) -> Self { // let OS choose a free port let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index 81d58138ed11..e65309493649 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -31,7 +31,7 @@ use arrow_flight::sql::{ TableNotExistOption, }; use arrow_flight::Action; -use futures::TryStreamExt; +use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; @@ -40,9 +40,7 @@ use uuid::Uuid; #[tokio::test] pub async fn test_begin_end_transaction() { - let test_server = FlightSqlServiceImpl { - transactions: Arc::new(Mutex::new(HashMap::new())), - }; + let test_server = FlightSqlServiceImpl::new(); let fixture = TestFixture::new(test_server.service()).await; let channel = fixture.channel().await; let mut flight_sql_client = FlightSqlServiceClient::new(channel); @@ -94,21 +92,26 @@ pub async fn test_execute_ingest() { make_primitive_batch(2), ]; let actual_rows = flight_sql_client - .execute_ingest(cmd, batches) + .execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok)) .await .expect("ingest should succeed"); assert_eq!(actual_rows, expected_rows); + // make sure the batches made it through to the server + let ingested_batches = test_server.ingested_batches.lock().await.clone(); + assert_eq!(ingested_batches, batches); } #[derive(Clone)] pub struct FlightSqlServiceImpl { transactions: Arc>>, + ingested_batches: Arc>>, } impl FlightSqlServiceImpl { pub fn new() -> Self { Self { transactions: Arc::new(Mutex::new(HashMap::new())), + ingested_batches: Arc::new(Mutex::new(Vec::new())), } } @@ -177,6 +180,7 @@ impl FlightSqlService for FlightSqlServiceImpl { .try_collect() .await?; let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum(); + *self.ingested_batches.lock().await.as_mut() = batches; Ok(affected_rows) } } From 22a4a1d75374bc13d270e614e37c6b0ef66682ab Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Aug 2024 15:05:15 -0400 Subject: [PATCH 10/10] Properly return client errors --- arrow-flight/src/client.rs | 4 +- arrow-flight/src/sql/client.rs | 23 +++++++++-- arrow-flight/tests/flight_sql_client.rs | 54 +++++++++++++++++++------ 3 files changed, 63 insertions(+), 18 deletions(-) diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 3f62b256d56e..af3c8fba30ff 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -679,7 +679,7 @@ impl FlightClient { /// it encounters an error it uses the oneshot sender to /// notify the error and stop any further streaming. See `do_put` or /// `do_exchange` for it's uses. -struct FallibleRequestStream { +pub(crate) struct FallibleRequestStream { /// sender to notify error sender: Option>, /// fallible stream @@ -687,7 +687,7 @@ struct FallibleRequestStream { } impl FallibleRequestStream { - fn new( + pub(crate) fn new( sender: Sender, fallible_stream: Pin> + Send + 'static>>, ) -> Self { diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index cfd6c9e78992..9f9963c92531 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -24,6 +24,7 @@ use std::collections::HashMap; use std::str::FromStr; use tonic::metadata::AsciiMetadataKey; +use crate::client::FallibleRequestStream; use crate::decode::FlightRecordBatchStream; use crate::encode::FlightDataEncoderBuilder; use crate::error::FlightError; @@ -53,7 +54,7 @@ use arrow_ipc::convert::fb_to_schema; use arrow_ipc::reader::read_record_batch; use arrow_ipc::{root_as_message, MessageHeader}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use futures::{stream, Stream, StreamExt, TryStreamExt}; +use futures::{stream, Stream, TryStreamExt}; use prost::Message; use tonic::transport::Channel; use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; @@ -236,12 +237,18 @@ impl FlightSqlServiceClient { where S: Stream> + Send + 'static, { + let (sender, receiver) = futures::channel::oneshot::channel(); + let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec()); - let flight_data_encoder = FlightDataEncoderBuilder::new() + let flight_data = FlightDataEncoderBuilder::new() .with_flight_descriptor(Some(descriptor)) .build(stream); - // Safe unwrap, explicitly wrapped on line above. - let flight_data = flight_data_encoder.map(|fd| fd.unwrap()); + + // Intercept client errors and send them to the one shot channel above + let flight_data = Box::pin(flight_data); + let flight_data: FallibleRequestStream = + FallibleRequestStream::new(sender, flight_data); + let req = self.set_request_headers(flight_data.into_streaming_request())?; let mut result = self .flight_client @@ -249,6 +256,14 @@ impl FlightSqlServiceClient { .await .map_err(status_to_arrow_error)? .into_inner(); + + // check if the there were any errors in the input stream provided note + // if receiver.await fails, it means the sender was dropped and there is + // no message to return. + if let Ok(msg) = receiver.await { + return Err(ArrowError::ExternalError(Box::new(msg))); + } + let result = result .message() .await diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index e65309493649..349da062a82d 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -22,6 +22,7 @@ use crate::common::utils::make_primitive_batch; use arrow_array::RecordBatch; use arrow_flight::decode::FlightRecordBatchStream; +use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::FlightServiceServer; use arrow_flight::sql::client::FlightSqlServiceClient; use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; @@ -73,18 +74,7 @@ pub async fn test_execute_ingest() { let fixture = TestFixture::new(test_server.service()).await; let channel = fixture.channel().await; let mut flight_sql_client = FlightSqlServiceClient::new(channel); - let cmd = CommandStatementIngest { - table_definition_options: Some(TableDefinitionOptions { - if_not_exist: TableNotExistOption::Create.into(), - if_exists: TableExistsOption::Fail.into(), - }), - table: String::from("test"), - schema: None, - catalog: None, - temporary: true, - transaction_id: None, - options: HashMap::default(), - }; + let cmd = make_ingest_command(); let expected_rows = 10; let batches = vec![ make_primitive_batch(5), @@ -101,6 +91,46 @@ pub async fn test_execute_ingest() { assert_eq!(ingested_batches, batches); } +#[tokio::test] +pub async fn test_execute_ingest_error() { + let test_server = FlightSqlServiceImpl::new(); + let fixture = TestFixture::new(test_server.service()).await; + let channel = fixture.channel().await; + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + let cmd = make_ingest_command(); + // send an error from the client + let batches = vec![ + Ok(make_primitive_batch(5)), + Err(FlightError::NotYetImplemented( + "Client error message".to_string(), + )), + ]; + // make sure the client returns the error from the client + let err = flight_sql_client + .execute_ingest(cmd, futures::stream::iter(batches)) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "External error: Not yet implemented: Client error message" + ); +} + +fn make_ingest_command() -> CommandStatementIngest { + CommandStatementIngest { + table_definition_options: Some(TableDefinitionOptions { + if_not_exist: TableNotExistOption::Create.into(), + if_exists: TableExistsOption::Fail.into(), + }), + table: String::from("test"), + schema: None, + catalog: None, + temporary: true, + transaction_id: None, + options: HashMap::default(), + } +} + #[derive(Clone)] pub struct FlightSqlServiceImpl { transactions: Arc>>,