Skip to content

Commit

Permalink
Postgres: Array enum encoding (#1511)
Browse files Browse the repository at this point in the history
* Postgres: Add test for array enum

* Allow produces() to override type_info() as per doc

* run cargo fmt
  • Loading branch information
chesedo authored Dec 29, 2021
1 parent 04109d9 commit dee5147
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 2 deletions.
8 changes: 7 additions & 1 deletion sqlx-core/src/postgres/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,17 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
let type_info = if self.len() < 1 {
T::type_info()
} else {
self[0].produces().unwrap_or_else(T::type_info)
};

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

// element type
match T::type_info().0 {
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),

ty => {
Expand Down
2 changes: 1 addition & 1 deletion sqlx-core/src/postgres/types/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<'a> PgRecordEncoder<'a> {
'a: 'q,
T: Encode<'q, Postgres> + Type<Postgres>,
{
let ty = T::type_info();
let ty = value.produces().unwrap_or_else(T::type_info);

if let PgType::DeclareWithName(name) = ty.0 {
// push a hole for this type ID
Expand Down
122 changes: 122 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,128 @@ async fn it_can_copy_out() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn it_encodes_custom_array_issue_1504() -> anyhow::Result<()> {
use sqlx::encode::IsNull;
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo};
use sqlx::{Decode, Encode, Type, ValueRef};

#[derive(Debug, PartialEq)]
enum Value {
String(String),
Number(i32),
Array(Vec<Value>),
}

impl<'r> Decode<'r, Postgres> for Value {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let typ = value.type_info().into_owned();

if typ == PgTypeInfo::with_name("text") {
let s = <String as Decode<'_, Postgres>>::decode(value)?;

Ok(Self::String(s))
} else if typ == PgTypeInfo::with_name("int4") {
let n = <i32 as Decode<'_, Postgres>>::decode(value)?;

Ok(Self::Number(n))
} else if typ == PgTypeInfo::with_name("_text") {
let arr = Vec::<String>::decode(value)?;
let v = arr.into_iter().map(|s| Value::String(s)).collect();

Ok(Self::Array(v))
} else if typ == PgTypeInfo::with_name("_int4") {
let arr = Vec::<i32>::decode(value)?;
let v = arr.into_iter().map(|n| Value::Number(n)).collect();

Ok(Self::Array(v))
} else {
Err("unknown type".into())
}
}
}

impl Encode<'_, Postgres> for Value {
fn produces(&self) -> Option<PgTypeInfo> {
match self {
Self::Array(a) => {
if a.len() < 1 {
return Some(PgTypeInfo::with_name("_text"));
}

match a[0] {
Self::String(_) => Some(PgTypeInfo::with_name("_text")),
Self::Number(_) => Some(PgTypeInfo::with_name("_int4")),
Self::Array(_) => None,
}
}
Self::String(_) => Some(PgTypeInfo::with_name("text")),
Self::Number(_) => Some(PgTypeInfo::with_name("int4")),
}
}

fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
match self {
Value::String(s) => <String as Encode<'_, Postgres>>::encode_by_ref(s, buf),
Value::Number(n) => <i32 as Encode<'_, Postgres>>::encode_by_ref(n, buf),
Value::Array(arr) => arr.encode(buf),
}
}
}

impl Type<Postgres> for Value {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("unknown")
}

fn compatible(ty: &PgTypeInfo) -> bool {
[
PgTypeInfo::with_name("text"),
PgTypeInfo::with_name("_text"),
PgTypeInfo::with_name("int4"),
PgTypeInfo::with_name("_int4"),
]
.contains(ty)
}
}

let mut conn = new::<Postgres>().await?;

let (row,): (Value,) = sqlx::query_as("SELECT $1::text[] as Dummy")
.bind(Value::Array(vec![
Value::String("Test 0".to_string()),
Value::String("Test 1".to_string()),
]))
.fetch_one(&mut conn)
.await?;

assert_eq!(
row,
Value::Array(vec![
Value::String("Test 0".to_string()),
Value::String("Test 1".to_string()),
])
);

let (row,): (Value,) = sqlx::query_as("SELECT $1::int4[] as Dummy")
.bind(Value::Array(vec![
Value::Number(3),
Value::Number(2),
Value::Number(1),
]))
.fetch_one(&mut conn)
.await?;

assert_eq!(
row,
Value::Array(vec![Value::Number(3), Value::Number(2), Value::Number(1)])
);

Ok(())
}

#[sqlx_macros::test]
async fn test_issue_1254() -> anyhow::Result<()> {
#[derive(sqlx::Type)]
Expand Down

0 comments on commit dee5147

Please sign in to comment.