feat: add an option to turn on compression for arrow output #4730

Merged · 2 commits · Sep 19, 2024
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -90,7 +90,7 @@ aquamarine = "0.3"
arrow = { version = "51.0.0", features = ["prettyprint"] }
arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] }
arrow-flight = "51.0"
-arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] }
+arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4", "zstd"] }
arrow-schema = { version = "51.0", features = ["serde"] }
async-stream = "0.3"
async-trait = "0.1"
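The new `zstd` entry matters because `arrow-ipc` gates each IPC compression codec behind a Cargo feature; without it, asking for ZSTD at runtime fails with an error. A minimal standalone sketch of the round trip this enables (illustrative code, not part of this PR), using only crates already in the dependency list above:

    // Sketch: write and read an Arrow IPC buffer with ZSTD compression.
    // Requires the `zstd` feature of arrow-ipc, as enabled in the change above.
    use std::io::Cursor;
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch};
    use arrow_ipc::reader::FileReader;
    use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
    use arrow_ipc::CompressionType;
    use arrow_schema::{DataType, Field, Schema};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;

        // Ask the writer to ZSTD-compress every buffer it writes.
        let options = IpcWriteOptions::default()
            .try_with_compression(Some(CompressionType::ZSTD))?;
        let mut buf = Vec::new();
        {
            let mut writer = FileWriter::try_new_with_options(&mut buf, &schema, options)?;
            writer.write(&batch)?;
            writer.finish()?;
        }

        // The reader detects the codec from the payload and decompresses transparently.
        let mut reader = FileReader::try_new(Cursor::new(buf), None)?;
        assert_eq!(reader.next().unwrap()?.num_rows(), 3);
        Ok(())
    }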
2 changes: 1 addition & 1 deletion src/servers/src/http.rs
@@ -1131,7 +1131,7 @@ mod test {
RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
let json_resp = match format {
-ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
+ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
ResponseFormat::Table => TableResponse::from_output(outputs).await,
ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
95 changes: 90 additions & 5 deletions src/servers/src/http/arrow_result.rs
@@ -16,7 +16,8 @@ use std::pin::Pin;
use std::sync::Arc;

use arrow::datatypes::Schema;
-use arrow_ipc::writer::FileWriter;
+use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
+use arrow_ipc::CompressionType;
use axum::http::{header, HeaderValue};
use axum::response::{IntoResponse, Response};
use common_error::status_code::StatusCode;
@@ -41,10 +42,15 @@ pub struct ArrowResponse {
async fn write_arrow_bytes(
mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
schema: &Arc<Schema>,
+compression: Option<CompressionType>,
) -> Result<Vec<u8>, Error> {
let mut bytes = Vec::new();
{
-let mut writer = FileWriter::try_new(&mut bytes, schema).context(error::ArrowSnafu)?;
+let options = IpcWriteOptions::default()
+.try_with_compression(compression)
+.context(error::ArrowSnafu)?;
+let mut writer = FileWriter::try_new_with_options(&mut bytes, schema, options)
+.context(error::ArrowSnafu)?;

while let Some(rb) = recordbatches.next().await {
let rb = rb.context(error::CollectRecordbatchSnafu)?;
@@ -59,15 +65,31 @@ async fn write_arrow_bytes(
Ok(bytes)
}

+fn compression_type(compression: Option<String>) -> Option<CompressionType> {
+match compression
+.map(|compression| compression.to_lowercase())
+.as_deref()
+{
+Some("zstd") => Some(CompressionType::ZSTD),
+Some("lz4") => Some(CompressionType::LZ4_FRAME),
+_ => None,
+}
+}
+
impl ArrowResponse {
-pub async fn from_output(mut outputs: Vec<error::Result<Output>>) -> HttpResponse {
+pub async fn from_output(
+mut outputs: Vec<error::Result<Output>>,
+compression: Option<String>,
+) -> HttpResponse {
if outputs.len() > 1 {
return HttpResponse::Error(ErrorResponse::from_error_message(
StatusCode::InvalidArguments,
"cannot output multi-statements result in arrow format".to_string(),
));
}

+let compression = compression_type(compression);
+
match outputs.pop() {
None => HttpResponse::Arrow(ArrowResponse {
data: vec![],
@@ -80,7 +102,9 @@ impl ArrowResponse {
}),
OutputData::RecordBatches(batches) => {
let schema = batches.schema();
-match write_arrow_bytes(batches.as_stream(), schema.arrow_schema()).await {
+match write_arrow_bytes(batches.as_stream(), schema.arrow_schema(), compression)
+.await
+{
Ok(payload) => HttpResponse::Arrow(ArrowResponse {
data: payload,
execution_time_ms: 0,
@@ -90,7 +114,7 @@ impl ArrowResponse {
}
OutputData::Stream(batches) => {
let schema = batches.schema();
-match write_arrow_bytes(batches, schema.arrow_schema()).await {
+match write_arrow_bytes(batches, schema.arrow_schema(), compression).await {
Ok(payload) => HttpResponse::Arrow(ArrowResponse {
data: payload,
execution_time_ms: 0,
@@ -136,3 +160,64 @@ impl IntoResponse for ArrowResponse {
.into_response()
}
}

+#[cfg(test)]
+mod test {
+use std::io::Cursor;
+
+use arrow_ipc::reader::FileReader;
+use arrow_schema::DataType;
+use common_recordbatch::{RecordBatch, RecordBatches};
+use datatypes::prelude::*;
+use datatypes::schema::{ColumnSchema, Schema};
+use datatypes::vectors::{StringVector, UInt32Vector};
+
+use super::*;
+
+#[tokio::test]
+async fn test_arrow_output() {
+let column_schemas = vec![
+ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+];
+let schema = Arc::new(Schema::new(column_schemas));
+let columns: Vec<VectorRef> = vec![
+Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+Arc::new(StringVector::from(vec![
+None,
+Some("hello"),
+Some("greptime"),
+None,
+])),
+];
+
+for compression in [None, Some("zstd".to_string()), Some("lz4".to_string())].into_iter() {
+let recordbatch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
+let recordbatches =
+RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
+let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
+
+let http_resp = ArrowResponse::from_output(outputs, compression).await;
+match http_resp {
+HttpResponse::Arrow(resp) => {
+let output = resp.data;
+let mut reader =
+FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
+let schema = reader.schema();
+assert_eq!(schema.fields[0].name(), "numbers");
+assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
+assert_eq!(schema.fields[1].name(), "strings");
+assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
+
+let rb = reader.next().unwrap().expect("read record batch failed");
+assert_eq!(rb.num_columns(), 2);
+assert_eq!(rb.num_rows(), 4);
+}
+HttpResponse::Error(e) => {
+panic!("unexpected {:?}", e);
+}
+_ => unreachable!(),
+}
+}
+}
+}
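One behavioral note on `compression_type` above: matching is case-insensitive, and anything other than `zstd` or `lz4` (including an absent parameter or an unrecognized codec name) silently falls back to uncompressed output rather than returning an error. Hypothetical checks of that mapping (the helper is module-private in the PR, so picture these inside its own unit tests):

    // Hypothetical asserts against the module-private helper above.
    assert_eq!(compression_type(Some("ZSTD".to_string())), Some(CompressionType::ZSTD));
    assert_eq!(compression_type(Some("lz4".to_string())), Some(CompressionType::LZ4_FRAME));
    assert_eq!(compression_type(Some("gzip".to_string())), None); // unknown codec: uncompressed
    assert_eq!(compression_type(None), None); // parameter omitted: uncompressed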
9 changes: 7 additions & 2 deletions src/servers/src/http/handler.rs
@@ -51,7 +51,8 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
pub struct SqlQuery {
pub db: Option<String>,
pub sql: Option<String>,
-// (Optional) result format: [`greptimedb_v1`, `influxdb_v1`, `csv`],
+// (Optional) result format: [`greptimedb_v1`, `influxdb_v1`, `csv`,
+// `arrow`],
// the default value is `greptimedb_v1`
pub format: Option<String>,
// Returns epoch timestamps with the specified precision.
@@ -64,6 +65,8 @@ pub struct SqlQuery {
// param too.
pub epoch: Option<String>,
pub limit: Option<usize>,
+// (Optional) compression type for arrow output: `zstd` or `lz4`; omit for uncompressed.
+pub compression: Option<String>,
}

/// Handler to execute sql
@@ -128,7 +131,9 @@ pub async fn sql(
};

let mut resp = match format {
-ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
+ResponseFormat::Arrow => {
+ArrowResponse::from_output(outputs, query_params.compression).await
+}
ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
ResponseFormat::Table => TableResponse::from_output(outputs).await,
ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
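End-to-end, the codec is now chosen per request through the query string: `format=arrow` selects the Arrow response, and `compression=zstd` or `compression=lz4` turns on IPC compression, while omitting the parameter keeps the previous uncompressed behavior. A hypothetical client sketch (assumes the `reqwest` and `tokio` crates plus GreptimeDB's `/v1/sql` HTTP route on the default port; adjust for your deployment):

    // Hypothetical client, not part of this PR: request ZSTD-compressed Arrow
    // output and decode it with arrow-ipc's FileReader.
    use std::io::Cursor;

    use arrow_ipc::reader::FileReader;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let body = reqwest::get(
            "http://127.0.0.1:4000/v1/sql?sql=SELECT%201&format=arrow&compression=zstd",
        )
        .await?
        .bytes()
        .await?;

        // The reader decompresses transparently when built with the zstd feature.
        for batch in FileReader::try_new(Cursor::new(body.to_vec()), None)? {
            println!("{} rows", batch?.num_rows());
        }
        Ok(())
    }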