From ee092f87ffe19a53626697f0da44a7951c7014d0 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Wed, 25 Sep 2024 22:10:43 -0700 Subject: [PATCH] Postgres should respect target decimal precision and scale (#120) --- src/sql/arrow_sql_gen/postgres.rs | 38 +++++++++++++++++-- .../dbconnection/postgresconn.rs | 8 ++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/sql/arrow_sql_gen/postgres.rs b/src/sql/arrow_sql_gen/postgres.rs index f247c16..221618f 100644 --- a/src/sql/arrow_sql_gen/postgres.rs +++ b/src/sql/arrow_sql_gen/postgres.rs @@ -12,7 +12,8 @@ use arrow::array::{ Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder, }; use arrow::datatypes::{ - DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema, + SchemaRef, TimeUnit, }; use bigdecimal::num_bigint::BigInt; use bigdecimal::num_bigint::Sign; @@ -182,7 +183,7 @@ macro_rules! handle_composite_types { /// /// Returns an error if there is a failure in converting the rows to a `RecordBatch`. 
#[allow(clippy::too_many_lines)] -pub fn rows_to_arrow(rows: &[Row]) -> Result<RecordBatch> { +pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Result<RecordBatch> { let mut arrow_fields: Vec<Option<Field>> = Vec::new(); let mut arrow_columns_builders: Vec<Option<Box<dyn ArrayBuilder>>> = Vec::new(); let mut postgres_types: Vec<Type> = Vec::new(); @@ -194,14 +195,32 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result<RecordBatch> { for column in row.columns() { let column_name = column.name(); let column_type = column.type_(); - let data_type = map_column_type_to_data_type(column_type); + + let mut numeric_scale: Option<u16> = None; + + let data_type = if *column_type == Type::NUMERIC { + if let Some(schema) = projected_schema.as_ref() { + match get_decimal_column_precision_and_scale(column_name, schema) { + Some((precision, scale)) => { + numeric_scale = Some(u16::try_from(scale).unwrap_or_default()); + Some(DataType::Decimal128(precision, scale)) + } + None => None, + } + } else { + None + } + } else { + map_column_type_to_data_type(column_type) + }; + match &data_type { Some(data_type) => { arrow_fields.push(Some(Field::new(column_name, data_type.clone(), true))); } None => arrow_fields.push(None), } - postgres_numeric_scales.push(None); + postgres_numeric_scales.push(numeric_scale); arrow_columns_builders .push(map_data_type_to_array_builder_optional(data_type.as_ref())); postgres_types.push(column_type.clone()); @@ -1251,3 +1270,14 @@ mod tests { assert_eq!(positive_result.wkb, positive_geometry); } } + +fn get_decimal_column_precision_and_scale( + column_name: &str, + projected_schema: &SchemaRef, +) -> Option<(u8, i8)> { + let field = projected_schema.field_with_name(column_name).ok()?; + match field.data_type() { + DataType::Decimal128(precision, scale) => Some((*precision, *scale)), + _ => None, + } +} diff --git a/src/sql/db_connection_pool/dbconnection/postgresconn.rs b/src/sql/db_connection_pool/dbconnection/postgresconn.rs index 870e92c..0c78d48 100644 --- a/src/sql/db_connection_pool/dbconnection/postgresconn.rs +++ 
b/src/sql/db_connection_pool/dbconnection/postgresconn.rs @@ -111,7 +111,7 @@ impl<'a> } }; - let rec = match rows_to_arrow(rows.as_slice()) { + let rec = match rows_to_arrow(rows.as_slice(), &None) { Ok(rec) => rec, Err(e) => { return Err(super::Error::UnableToGetSchema { @@ -128,7 +128,7 @@ impl<'a> &self, sql: &str, params: &[&'a (dyn ToSql + Sync)], - _projected_schema: Option<SchemaRef>, + projected_schema: Option<SchemaRef>, ) -> Result<SendableRecordBatchStream> { // TODO: We should have a way to detect if params have been passed // if they haven't we should use .copy_out instead, because it should be much faster @@ -139,12 +139,12 @@ impl<'a> .context(QuerySnafu)?; // chunk the stream into groups of rows - let mut stream = streamable.chunks(4_000).boxed().map(|rows| { + let mut stream = streamable.chunks(4_000).boxed().map(move |rows| { let rows = rows .into_iter() .collect::<Result<Vec<_>, _>>() .context(QuerySnafu)?; - let rec = rows_to_arrow(rows.as_slice()).context(ConversionSnafu)?; + let rec = rows_to_arrow(rows.as_slice(), &projected_schema).context(ConversionSnafu)?; Ok::<_, PostgresError>(rec) });