Expose bulk ingest in flight sql client and server #6201

Merged: 12 commits, Aug 15, 2024
14 changes: 11 additions & 3 deletions arrow-flight/examples/flight_sql_server.rs
@@ -46,9 +46,9 @@ use arrow_flight::sql::{
ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference,
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys,
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable,
SqlInfo, TicketStatementQuery, XdbcDataType,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable,
ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType,
};
use arrow_flight::utils::batches_to_flight_data;
use arrow_flight::{
@@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_statement_ingest(
&self,
_ticket: CommandStatementIngest,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
4 changes: 2 additions & 2 deletions arrow-flight/src/client.rs
@@ -679,15 +679,15 @@ impl FlightClient {
/// it encounters an error it uses the oneshot sender to
/// notify the error and stop any further streaming. See `do_put` or
/// `do_exchange` for its uses.
struct FallibleRequestStream<T, E> {
pub(crate) struct FallibleRequestStream<T, E> {
/// sender to notify error
sender: Option<Sender<E>>,
/// fallible stream
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
}

impl<T, E> FallibleRequestStream<T, E> {
fn new(
pub(crate) fn new(
sender: Sender<E>,
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
) -> Self {
55 changes: 51 additions & 4 deletions arrow-flight/src/sql/client.rs
@@ -24,6 +24,7 @@ use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;

use crate::client::FallibleRequestStream;
use crate::decode::FlightRecordBatchStream;
use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
@@ -39,8 +40,8 @@ use crate::sql::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult,
ProstMessageExt, SqlInfo,
CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::trailers::extract_lazy_trailers;
use crate::{
@@ -53,10 +54,10 @@ use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::reader::read_record_batch;
use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, TryStreamExt};
use futures::{stream, Stream, TryStreamExt};
use prost::Message;
use tonic::transport::Channel;
use tonic::{IntoRequest, Streaming};
use tonic::{IntoRequest, IntoStreamingRequest, Streaming};

/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
/// by FlightSQL protocol.
@@ -227,6 +228,52 @@ impl FlightSqlServiceClient<Channel> {
Ok(result.record_count)
}

/// Execute a bulk ingest on the server and return the number of records added
pub async fn execute_ingest<S>(
&mut self,
command: CommandStatementIngest,
stream: S,
) -> Result<i64, ArrowError>
where
S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
{
let (sender, receiver) = futures::channel::oneshot::channel();

let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
let flight_data = FlightDataEncoderBuilder::new()
.with_flight_descriptor(Some(descriptor))
.build(stream);

// Intercept client errors and send them to the one shot channel above
let flight_data = Box::pin(flight_data);
let flight_data: FallibleRequestStream<FlightData, FlightError> =
FallibleRequestStream::new(sender, flight_data);

let req = self.set_request_headers(flight_data.into_streaming_request())?;
let mut result = self
.flight_client
.do_put(req)
.await
.map_err(status_to_arrow_error)?
.into_inner();

// Check if there were any errors in the input stream provided. Note that
// if receiver.await fails, it means the sender was dropped and there is
// no message to return.
if let Ok(msg) = receiver.await {
return Err(ArrowError::ExternalError(Box::new(msg)));
}

let result = result
.message()
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}

/// Request a list of catalogs as tabular FlightInfo results
pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
self.get_flight_info_for_command(CommandGetCatalogs {})
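
Note on usage: the new `execute_ingest` takes the `CommandStatementIngest` describing the target table plus any fallible stream of `RecordBatch`es, and returns the server-reported record count. A minimal calling sketch (the `ingest_batches` helper and the idea of wrapping in-memory batches with `stream::iter` are illustrative, not part of this change):

```rust
use arrow_array::RecordBatch;
use arrow_flight::error::FlightError;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::sql::CommandStatementIngest;
use arrow_schema::ArrowError;
use futures::stream;
use tonic::transport::Channel;

/// Ingest a set of in-memory batches and return the server-reported row count.
async fn ingest_batches(
    client: &mut FlightSqlServiceClient<Channel>,
    command: CommandStatementIngest,
    batches: Vec<RecordBatch>,
) -> Result<i64, ArrowError> {
    // `execute_ingest` wants a Stream<Item = Result<RecordBatch, FlightError>>,
    // so wrap the in-memory batches in an (infallible) iterator stream.
    let input = stream::iter(batches.into_iter().map(Ok::<RecordBatch, FlightError>));
    client.execute_ingest(command, input).await
}
```

Any error surfaced by the input stream is intercepted through the `FallibleRequestStream` made `pub(crate)` above and reported back as an `ArrowError::ExternalError`.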
8 changes: 7 additions & 1 deletion arrow-flight/src/sql/mod.rs
@@ -50,6 +50,10 @@ mod gen {
}

pub use gen::action_end_transaction_request::EndTransaction;
pub use gen::command_statement_ingest::table_definition_options::{
TableExistsOption, TableNotExistOption,
};
pub use gen::command_statement_ingest::TableDefinitionOptions;
pub use gen::ActionBeginSavepointRequest;
pub use gen::ActionBeginSavepointResult;
pub use gen::ActionBeginTransactionRequest;
@@ -74,6 +78,7 @@ pub use gen::CommandGetTables;
pub use gen::CommandGetXdbcTypeInfo;
pub use gen::CommandPreparedStatementQuery;
pub use gen::CommandPreparedStatementUpdate;
pub use gen::CommandStatementIngest;
pub use gen::CommandStatementQuery;
pub use gen::CommandStatementSubstraitPlan;
pub use gen::CommandStatementUpdate;
@@ -250,11 +255,12 @@ prost_message_ext!(
CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery,
CommandPreparedStatementUpdate,
CommandStatementIngest,
CommandStatementQuery,
CommandStatementSubstraitPlan,
CommandStatementUpdate,
DoPutUpdateResult,
DoPutPreparedStatementResult,
DoPutUpdateResult,
TicketStatementQuery,
);

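
These re-exports are plain prost-generated types from the Flight SQL protocol definition. A sketch of building an ingest command with them, assuming the field and enum-variant names prost generates from the upstream `CommandStatementIngest` message (the table name and chosen options are purely illustrative):

```rust
use std::collections::HashMap;

use arrow_flight::sql::{
    CommandStatementIngest, TableDefinitionOptions, TableExistsOption, TableNotExistOption,
};

/// Describe an ingest into a hypothetical "ingest_target" table: create the
/// table if it is missing, fail rather than overwrite if it already exists.
fn example_ingest_command() -> CommandStatementIngest {
    CommandStatementIngest {
        table_definition_options: Some(TableDefinitionOptions {
            if_not_exist: TableNotExistOption::Create.into(),
            if_exists: TableExistsOption::Fail.into(),
        }),
        table: "ingest_target".to_string(),
        schema: None,
        catalog: None,
        temporary: false,
        transaction_id: None,
        options: HashMap::new(),
    }
}
```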
25 changes: 22 additions & 3 deletions arrow-flight/src/sql/server.rs
@@ -32,9 +32,9 @@ use super::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
SqlInfo, TicketStatementQuery,
};
use crate::{
flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
@@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
))
}

/// Execute a bulk ingestion.
async fn do_put_statement_ingest(
&self,
_ticket: CommandStatementIngest,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_statement_ingest has no default implementation",
))
}

/// Bind parameters to given prepared statement.
///
/// Returns an opaque handle that the client should pass
@@ -713,6 +724,14 @@
})]);
Ok(Response::new(Box::pin(output)))
}
Command::CommandStatementIngest(command) => {
let record_count = self.do_put_statement_ingest(command, request).await?;
let result = DoPutUpdateResult { record_count };
let output = futures::stream::iter(vec![Ok(PutResult {
app_metadata: result.as_any().encode_to_vec().into(),
})]);
Ok(Response::new(Box::pin(output)))
}
Command::CommandPreparedStatementQuery(command) => {
let result = self
.do_put_prepared_statement_query(command, request)
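
The default `do_put_statement_ingest` returns unimplemented; the dispatch arm above wraps whatever row count an override returns into a `DoPutUpdateResult`. One possible override, sketched under the assumption that the server decodes the uploaded `FlightData` with `FlightRecordBatchStream` and only counts rows (the `CountingIngestService` type and the error mapping are illustrative, not prescribed by this PR):

```rust
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream};
use arrow_flight::sql::{CommandStatementIngest, SqlInfo};
use futures::TryStreamExt;
use tonic::{Request, Status};

/// Hypothetical server that counts the rows it receives.
#[derive(Clone)]
pub struct CountingIngestService;

#[tonic::async_trait]
impl FlightSqlService for CountingIngestService {
    type FlightService = CountingIngestService;

    async fn do_put_statement_ingest(
        &self,
        _ticket: CommandStatementIngest,
        request: Request<PeekableFlightDataStream>,
    ) -> Result<i64, Status> {
        // Decode the uploaded FlightData stream into RecordBatches.
        let mut batches = FlightRecordBatchStream::new_from_flight_data(
            request.into_inner().map_err(FlightError::from),
        );

        // Count rows; a real implementation would also persist the batches.
        let mut record_count: i64 = 0;
        while let Some(batch) = batches
            .try_next()
            .await
            .map_err(|e| Status::internal(e.to_string()))?
        {
            record_count += batch.num_rows() as i64;
        }
        Ok(record_count)
    }

    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
}
```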
1 change: 1 addition & 0 deletions arrow-flight/tests/common/fixture.rs
@@ -41,6 +41,7 @@ pub struct TestFixture {

impl TestFixture {
/// create a new test fixture from the server
#[allow(dead_code)]
pub async fn new<T: FlightService>(test_server: FlightServiceServer<T>) -> Self {
// let OS choose a free port
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1 change: 1 addition & 0 deletions arrow-flight/tests/common/mod.rs
@@ -18,3 +18,4 @@
pub mod fixture;
pub mod server;
pub mod trailers_layer;
pub mod utils;
118 changes: 118 additions & 0 deletions arrow-flight/tests/common/utils.rs
@@ -0,0 +1,118 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Common utilities for testing flight clients and servers

use std::sync::Arc;

use arrow_array::{
types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch,
StringViewArray, UInt8Array,
};
use arrow_schema::{DataType, Field, Schema};

/// Make a primitive batch for testing
///
/// Example:
/// i: 0, 1, None, 3, 4
/// f: 5.0, 4.0, None, 2.0, 1.0
#[allow(dead_code)]
pub fn make_primitive_batch(num_rows: usize) -> RecordBatch {
let i: UInt8Array = (0..num_rows)
.map(|i| {
if i == num_rows / 2 {
None
} else {
Some(i.try_into().unwrap())
}
})
.collect();

let f: Float64Array = (0..num_rows)
.map(|i| {
if i == num_rows / 2 {
None
} else {
Some((num_rows - i) as f64)
}
})
.collect();

RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap()
}

/// Make a dictionary batch for testing
///
/// Example:
/// a: value0, value1, value2, None, value1, value2
#[allow(dead_code)]
pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
let values: Vec<_> = (0..num_rows)
.map(|i| {
if i == num_rows / 2 {
None
} else {
// repeat some values for low cardinality
let v = i / 3;
Some(format!("value{v}"))
}
})
.collect();

let a: DictionaryArray<Int32Type> = values
.iter()
.map(|s| s.as_ref().map(|s| s.as_str()))
.collect();

RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
}

#[allow(dead_code)]
pub fn make_view_batches(num_rows: usize) -> RecordBatch {
const LONG_TEST_STRING: &str =
"This is a long string to make sure binary view array handles it";
let schema = Schema::new(vec![
Field::new("field1", DataType::BinaryView, true),
Field::new("field2", DataType::Utf8View, true),
]);

let string_view_values: Vec<Option<&str>> = (0..num_rows)
.map(|i| match i % 3 {
0 => None,
1 => Some("foo"),
2 => Some(LONG_TEST_STRING),
_ => unreachable!(),
})
.collect();

let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
.map(|i| match i % 3 {
0 => None,
1 => Some("bar".as_bytes()),
2 => Some(LONG_TEST_STRING.as_bytes()),
_ => unreachable!(),
})
.collect();

let binary_array = BinaryViewArray::from_iter(bin_view_values);
let utf8_array = StringViewArray::from_iter(string_view_values);
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(binary_array), Arc::new(utf8_array)],
)
.unwrap()
}