Postgres should respect target decimal precision and scale (#120)
sgrebnov authored Sep 26, 2024
1 parent 4615519 commit ee092f8
Showing 2 changed files with 38 additions and 8 deletions.
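
This commit threads the caller's projected Arrow schema into `rows_to_arrow` so that Postgres `NUMERIC` columns are decoded as `Decimal128` with the precision and scale declared in that schema, rather than falling back to the default column-type mapping. Arrow's `Decimal128` carries a single fixed precision and scale per column, and every value is appended as an `i128` mantissa already expressed at that scale. A minimal standalone sketch of that constraint using plain arrow-rs (an illustration, not code from this commit):

```rust
use arrow::array::{Array, Decimal128Builder};
use arrow::datatypes::DataType;

fn main() {
    // One fixed (precision, scale) per column; here the target is Decimal128(38, 4).
    let mut builder = Decimal128Builder::new()
        .with_precision_and_scale(38, 4)
        .expect("38/4 is a valid precision/scale");

    // Values are appended as i128 mantissas at the target scale:
    // 12.5 at scale 4 is the mantissa 125_000 (i.e. 12.5000).
    builder.append_value(125_000);

    let array = builder.finish();
    assert_eq!(array.data_type(), &DataType::Decimal128(38, 4));
    assert_eq!(array.value_as_string(0), "12.5000");
}
```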
38 changes: 34 additions & 4 deletions src/sql/arrow_sql_gen/postgres.rs
@@ -12,7 +12,8 @@ use arrow::array::{
     Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder,
 };
 use arrow::datatypes::{
-    DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit,
+    DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema,
+    SchemaRef, TimeUnit,
 };
 use bigdecimal::num_bigint::BigInt;
 use bigdecimal::num_bigint::Sign;
@@ -182,7 +183,7 @@ macro_rules! handle_composite_types {
 ///
 /// Returns an error if there is a failure in converting the rows to a `RecordBatch`.
 #[allow(clippy::too_many_lines)]
-pub fn rows_to_arrow(rows: &[Row]) -> Result<RecordBatch> {
+pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Result<RecordBatch> {
     let mut arrow_fields: Vec<Option<Field>> = Vec::new();
     let mut arrow_columns_builders: Vec<Option<Box<dyn ArrayBuilder>>> = Vec::new();
     let mut postgres_types: Vec<Type> = Vec::new();
@@ -194,14 +195,32 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result<RecordBatch> {
         for column in row.columns() {
             let column_name = column.name();
             let column_type = column.type_();
-            let data_type = map_column_type_to_data_type(column_type);
+
+            let mut numeric_scale: Option<u16> = None;
+
+            let data_type = if *column_type == Type::NUMERIC {
+                if let Some(schema) = projected_schema.as_ref() {
+                    match get_decimal_column_precision_and_scale(column_name, schema) {
+                        Some((precision, scale)) => {
+                            numeric_scale = Some(u16::try_from(scale).unwrap_or_default());
+                            Some(DataType::Decimal128(precision, scale))
+                        }
+                        None => None,
+                    }
+                } else {
+                    None
+                }
+            } else {
+                map_column_type_to_data_type(column_type)
+            };
+
             match &data_type {
                 Some(data_type) => {
                     arrow_fields.push(Some(Field::new(column_name, data_type.clone(), true)));
                 }
                 None => arrow_fields.push(None),
             }
-            postgres_numeric_scales.push(None);
+            postgres_numeric_scales.push(numeric_scale);
             arrow_columns_builders
                 .push(map_data_type_to_array_builder_optional(data_type.as_ref()));
             postgres_types.push(column_type.clone());
@@ -1251,3 +1270,14 @@ mod tests {
         assert_eq!(positive_result.wkb, positive_geometry);
     }
 }
+
+fn get_decimal_column_precision_and_scale(
+    column_name: &str,
+    projected_schema: &SchemaRef,
+) -> Option<(u8, i8)> {
+    let field = projected_schema.field_with_name(column_name).ok()?;
+    match field.data_type() {
+        DataType::Decimal128(precision, scale) => Some((*precision, *scale)),
+        _ => None,
+    }
+}
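
The new `get_decimal_column_precision_and_scale` helper only looks a column up in the projected schema and returns its precision and scale when the field is declared as `Decimal128`; for any other field, or a missing column, it returns `None` and the default mapping applies. A short usage sketch against a hypothetical projected schema (it assumes the helper above is in scope):

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

fn main() {
    // Hypothetical projected schema: the NUMERIC column "price" is declared as Decimal128(38, 9).
    let schema: SchemaRef = Arc::new(Schema::new(vec![
        Field::new("price", DataType::Decimal128(38, 9), true),
        Field::new("id", DataType::Int64, false),
    ]));

    // Decimal128 field: precision and scale are extracted.
    assert_eq!(
        get_decimal_column_precision_and_scale("price", &schema),
        Some((38, 9))
    );

    // Non-decimal or unknown columns: None, so the caller keeps the default mapping.
    assert_eq!(get_decimal_column_precision_and_scale("id", &schema), None);
    assert_eq!(get_decimal_column_precision_and_scale("missing", &schema), None);
}
```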
8 changes: 4 additions & 4 deletions src/sql/db_connection_pool/dbconnection/postgresconn.rs
@@ -111,7 +111,7 @@ impl<'a>
             }
         };
 
-        let rec = match rows_to_arrow(rows.as_slice()) {
+        let rec = match rows_to_arrow(rows.as_slice(), &None) {
             Ok(rec) => rec,
             Err(e) => {
                 return Err(super::Error::UnableToGetSchema {
@@ -128,7 +128,7 @@
         &self,
         sql: &str,
         params: &[&'a (dyn ToSql + Sync)],
-        _projected_schema: Option<SchemaRef>,
+        projected_schema: Option<SchemaRef>,
     ) -> Result<SendableRecordBatchStream> {
         // TODO: We should have a way to detect if params have been passed
         // if they haven't we should use .copy_out instead, because it should be much faster
@@ -139,12 +139,12 @@
             .context(QuerySnafu)?;
 
         // chunk the stream into groups of rows
-        let mut stream = streamable.chunks(4_000).boxed().map(|rows| {
+        let mut stream = streamable.chunks(4_000).boxed().map(move |rows| {
             let rows = rows
                 .into_iter()
                 .collect::<std::result::Result<Vec<_>, _>>()
                 .context(QuerySnafu)?;
-            let rec = rows_to_arrow(rows.as_slice()).context(ConversionSnafu)?;
+            let rec = rows_to_arrow(rows.as_slice(), &projected_schema).context(ConversionSnafu)?;
             Ok::<_, PostgresError>(rec)
         });

Expand Down
