Skip to content

Commit

Permalink
feat: compatible with postgres interval type (#2146)
Browse files Browse the repository at this point in the history
* feat: impl ToSql/FromSql/ToSqlText for PgInterval.

* chore: remove useless code.

* feat: compatible with postgres interval type.

* chore: cr comment.
  • Loading branch information
QuenKar authored Aug 11, 2023
1 parent 2dcc677 commit 6877d08
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 8 deletions.
60 changes: 52 additions & 8 deletions src/servers/src/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod interval;

use std::collections::HashMap;
use std::ops::Deref;

use chrono::{NaiveDate, NaiveDateTime};
use common_time::Interval;
use datafusion_common::ScalarValue;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::Schema;
Expand All @@ -26,6 +29,7 @@ use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::plan::LogicalPlan;

use self::interval::PgInterval;
use crate::error::{self, Error, Result};
use crate::SqlPlan;

Expand Down Expand Up @@ -98,14 +102,13 @@ pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWir
})))
}
}
Value::Interval(_) | Value::List(_) => {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write value {:?} in postgres protocol: unimplemented",
&value
),
})))
}
Value::Interval(v) => builder.encode_field(&PgInterval::from(*v)),
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write value {:?} in postgres protocol: unimplemented",
&value
),
}))),
}
}

Expand Down Expand Up @@ -195,6 +198,10 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir
.parameter::<NaiveDateTime>(idx, param_type)?
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INTERVAL => Ok(portal
.parameter::<PgInterval>(idx, param_type)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
_ => Err(invalid_parameter_error(
"unsupported_parameter_type",
Some(&param_type.to_string()),
Expand Down Expand Up @@ -478,6 +485,23 @@ pub(super) fn parameters_to_scalar_values(
}
}
}
&Type::INTERVAL => {
let data = portal.parameter::<PgInterval>(idx, &client_type)?;
match server_type {
ConcreteDataType::Interval(_) => {
ScalarValue::IntervalMonthDayNano(data.map(|i| Interval::from(i).to_i128()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
}
}
}
&Type::BYTEA => {
let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
match server_type {
Expand Down Expand Up @@ -559,6 +583,11 @@ mod test {
),
ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
ColumnSchema::new("times", ConcreteDataType::time_second_datatype(), true),
ColumnSchema::new(
"intervals",
ConcreteDataType::interval_month_day_nano_datatype(),
true,
),
];
let pg_field_info = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
Expand Down Expand Up @@ -608,6 +637,13 @@ mod test {
),
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
FieldInfo::new(
"intervals".into(),
None,
None,
Type::INTERVAL,
FieldFormat::Text,
),
];
let schema = Schema::new(column_schemas);
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
Expand Down Expand Up @@ -703,6 +739,13 @@ mod test {
Type::TIMESTAMP,
FieldFormat::Text,
),
FieldInfo::new(
"intervals".into(),
None,
None,
Type::INTERVAL,
FieldFormat::Text,
),
];

let values = vec![
Expand Down Expand Up @@ -732,6 +775,7 @@ mod test {
Value::Time(1001i64.into()),
Value::DateTime(1000001i64.into()),
Value::Timestamp(1000001i64.into()),
Value::Interval(1000001i128.into()),
];
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
Expand Down
127 changes: 127 additions & 0 deletions src/servers/src/postgres/types/interval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Display;

use bytes::{Buf, BufMut};
use common_time::Interval;
use pgwire::types::ToSqlText;
use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};

#[derive(Debug, Clone, Copy, Default)]
pub struct PgInterval {
months: i32,
days: i32,
microseconds: i64,
}

impl From<Interval> for PgInterval {
fn from(interval: Interval) -> Self {
let (months, days, nanos) = interval.to_month_day_nano();
Self {
months,
days,
microseconds: nanos / 1000,
}
}
}

impl From<PgInterval> for Interval {
fn from(interval: PgInterval) -> Self {
Interval::from_month_day_nano(
interval.months,
interval.days,
// Maybe overflow, but most scenarios ok.
interval.microseconds.checked_mul(1000).unwrap_or_else(|| {
if interval.microseconds.is_negative() {
i64::MIN
} else {
i64::MAX
}
}),
)
}
}

impl Display for PgInterval {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", Interval::from(*self).to_postgres_string())
}
}

impl ToSql for PgInterval {
to_sql_checked!();

fn to_sql(
&self,
_: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<postgres_types::IsNull, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
{
// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L989-L991
out.put_i64(self.microseconds);
out.put_i32(self.days);
out.put_i32(self.months);
Ok(postgres_types::IsNull::No)
}

fn accepts(ty: &Type) -> bool
where
Self: Sized,
{
matches!(ty, &Type::INTERVAL)
}
}

impl<'a> FromSql<'a> for PgInterval {
fn from_sql(
_: &Type,
mut raw: &'a [u8],
) -> std::result::Result<Self, Box<dyn snafu::Error + Sync + Send>> {
// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1007-L1010
let microseconds = raw.get_i64();
let days = raw.get_i32();
let months = raw.get_i32();
Ok(PgInterval {
months,
days,
microseconds,
})
}

fn accepts(ty: &Type) -> bool {
matches!(ty, &Type::INTERVAL)
}
}

impl ToSqlText for PgInterval {
fn to_sql_text(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<postgres_types::IsNull, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
{
let fmt = match ty {
&Type::INTERVAL => self.to_string(),
_ => return Err("unsupported type".into()),
};

out.put_slice(fmt.as_bytes());
Ok(IsNull::No)
}
}

0 comments on commit 6877d08

Please sign in to comment.