Skip to content

Commit

Permalink
Fix support for Postgres array of custom types (#1483)
Browse files Browse the repository at this point in the history
This commit fixes the array decoder to support custom types. The core of the issue was that the array decoder did not use the type info retrieved from the database. It means that it only supported native types.

This commit fixes the issue by using the element type info fetched from the database. A new internal helper method is added to the `PgType` struct: it returns the type info for the inner array element, if available.

Closes #1477
  • Loading branch information
demurgos authored Dec 29, 2021
1 parent dee5147 commit 32f1273
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 4 deletions.
6 changes: 5 additions & 1 deletion sqlx-core/src/postgres/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::sync::Arc;
/// Describes the type of the `pg_type.typtype` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypType {
Base,
Composite,
Expand Down Expand Up @@ -45,6 +46,7 @@ impl TryFrom<u8> for TypType {
/// Describes the type of the `pg_type.typcategory` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypCategory {
Array,
Boolean,
Expand Down Expand Up @@ -198,7 +200,9 @@ impl PgConnection {

(Ok(TypType::Base), Ok(TypCategory::Array)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
kind: PgTypeKind::Array(
self.maybe_fetch_type_info_by_oid(element, true).await?,
),
name: name.into(),
oid,
}))))
Expand Down
120 changes: 120 additions & 0 deletions sqlx-core/src/postgres/type_info.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(dead_code)]

use std::borrow::Cow;
use std::fmt::{self, Display, Formatter};
use std::ops::Deref;
use std::sync::Arc;
Expand Down Expand Up @@ -750,6 +751,125 @@ impl PgType {
}
}
}

/// If `self` is an array type, return the type info for its element.
///
/// This method should only be called on resolved types: calling it on
/// a type that is merely declared (DeclareWithOid/Name) is a bug.
pub(crate) fn try_array_element(&self) -> Option<Cow<'_, PgTypeInfo>> {
// We explicitly match on all the `None` cases to ensure an exhaustive match.
match self {
PgType::Bool => None,
PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))),
PgType::Bytea => None,
PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))),
PgType::Char => None,
PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))),
PgType::Name => None,
PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))),
PgType::Int8 => None,
PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))),
PgType::Int2 => None,
PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))),
PgType::Int4 => None,
PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))),
PgType::Text => None,
PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))),
PgType::Oid => None,
PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))),
PgType::Json => None,
PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))),
PgType::Point => None,
PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))),
PgType::Lseg => None,
PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))),
PgType::Path => None,
PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))),
PgType::Box => None,
PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))),
PgType::Polygon => None,
PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))),
PgType::Line => None,
PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))),
PgType::Cidr => None,
PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))),
PgType::Float4 => None,
PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))),
PgType::Float8 => None,
PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))),
PgType::Circle => None,
PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))),
PgType::Macaddr8 => None,
PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))),
PgType::Money => None,
PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))),
PgType::Macaddr => None,
PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))),
PgType::Inet => None,
PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))),
PgType::Bpchar => None,
PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))),
PgType::Varchar => None,
PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))),
PgType::Date => None,
PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))),
PgType::Time => None,
PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))),
PgType::Timestamp => None,
PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))),
PgType::Timestamptz => None,
PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))),
PgType::Interval => None,
PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))),
PgType::Timetz => None,
PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))),
PgType::Bit => None,
PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))),
PgType::Varbit => None,
PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))),
PgType::Numeric => None,
PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))),
PgType::Record => None,
PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))),
PgType::Uuid => None,
PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))),
PgType::Jsonb => None,
PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))),
PgType::Int4Range => None,
PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))),
PgType::NumRange => None,
PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))),
PgType::TsRange => None,
PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))),
PgType::TstzRange => None,
PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))),
PgType::DateRange => None,
PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))),
PgType::Int8Range => None,
PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))),
PgType::Jsonpath => None,
PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))),
// There is no `UnknownArray`
PgType::Unknown => None,
// There is no `VoidArray`
PgType::Void => None,
PgType::Custom(ty) => match &ty.kind {
PgTypeKind::Simple => None,
PgTypeKind::Pseudo => None,
PgTypeKind::Domain(_) => None,
PgTypeKind::Composite(_) => None,
PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)),
PgTypeKind::Enum(_) => None,
PgTypeKind::Range(_) => None,
},
PgType::DeclareWithOid(oid) => {
unreachable!("(bug) use of unresolved type declaration [oid={}]", oid);
}
PgType::DeclareWithName(name) => {
unreachable!("(bug) use of unresolved type declaration [name={}]", name);
}
}
}
}

impl TypeInfo for PgTypeInfo {
Expand Down
7 changes: 4 additions & 3 deletions sqlx-core/src/postgres/types/array.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use bytes::Buf;
use std::borrow::Cow;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -103,7 +104,6 @@ where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let element_type_info;
let format = value.format();

match format {
Expand Down Expand Up @@ -131,7 +131,8 @@ where

// the OID of the element
let element_type_oid = buf.get_u32();
element_type_info = PgTypeInfo::try_from_oid(element_type_oid)
let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
.or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
.unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid)));

// length of the array axis
Expand Down Expand Up @@ -159,7 +160,7 @@ where

PgValueFormat::Text => {
// no type is provided from the database for the element
element_type_info = T::type_info();
let element_type_info = T::type_info();

let s = value.as_str()?;

Expand Down
87 changes: 87 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills (
Ok(())
}

#[sqlx_macros::test]
async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> {
// Only supported in Postgres 11+
let mut conn = new::<Postgres>().await?;
if matches!(conn.server_version_num(), Some(version) if version < 110000) {
return Ok(());
}

// language=PostgreSQL
conn.execute(
r#"
DROP TABLE IF EXISTS pets;
DROP TYPE IF EXISTS pet_name_and_race;
CREATE TYPE pet_name_and_race AS (
name TEXT,
race TEXT
);
CREATE TABLE pets (
owner TEXT NOT NULL,
name TEXT NOT NULL,
race TEXT NOT NULL,
PRIMARY KEY (owner, name)
);
INSERT INTO pets(owner, name, race)
VALUES
('Alice', 'Foo', 'cat');
INSERT INTO pets(owner, name, race)
VALUES
('Alice', 'Bar', 'dog');
"#,
)
.await?;

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PetNameAndRace {
name: String,
race: String,
}

impl sqlx::Type<Postgres> for PetNameAndRace {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race")
}
}

impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
let name = decoder.try_decode::<String>()?;
let race = decoder.try_decode::<String>()?;
Ok(Self { name, race })
}
}

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PetNameAndRaceArray(Vec<PetNameAndRace>);

impl sqlx::Type<Postgres> for PetNameAndRaceArray {
fn type_info() -> sqlx::postgres::PgTypeInfo {
// Array type name is the name of the element type prefixed with `_`
sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race")
}
}

impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
Ok(Self(Vec::<PetNameAndRace>::decode(value)?))
}
}

let mut conn = new::<Postgres>().await?;

let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner")
.fetch_one(&mut conn)
.await?;

let pets: PetNameAndRaceArray = row.get("pets");

assert_eq!(pets.0.len(), 2);
Ok(())
}

#[sqlx_macros::test]
async fn test_pg_server_num() -> anyhow::Result<()> {
use sqlx::postgres::PgConnectionInfo;
Expand Down

0 comments on commit 32f1273

Please sign in to comment.