From af52382fe74817ad0312b08a1737ba826ea52320 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 5 May 2024 09:51:15 -0700 Subject: [PATCH] Added support for bit type to SQLx --- CHANGELOG.md | 2 +- src/bit.rs | 2 +- src/sqlx_ext/bit.rs | 109 ++++++++++++++++++++++++++++++++++++++++++++ src/sqlx_ext/mod.rs | 1 + 4 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 src/sqlx_ext/bit.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 672f9c9..c02414d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ ## 0.3.3 (unreleased) - Added support for `halfvec`, `bit`, and `sparsevec` types to Rust-Postgres -- Added support for `halfvec` and `sparsevec` type to SQLx +- Added support for `halfvec`, `bit`, and `sparsevec` type to SQLx - Added support for `halfvec` and `sparsevec` type to Diesel - Added `l1_distance` function for Diesel diff --git a/src/bit.rs b/src/bit.rs index d48d7b5..4ff1444 100644 --- a/src/bit.rs +++ b/src/bit.rs @@ -24,7 +24,7 @@ impl<'a> Bit<'a> { self.data } - #[cfg(any(feature = "postgres"))] + #[cfg(any(feature = "postgres", feature = "sqlx"))] pub(crate) fn from_sql(buf: &[u8]) -> Result> { let len = i32::from_be_bytes(buf[0..4].try_into()?) as usize; let data = &buf[4..4 + len / 8]; diff --git a/src/sqlx_ext/bit.rs b/src/sqlx_ext/bit.rs new file mode 100644 index 0000000..e6f2628 --- /dev/null +++ b/src/sqlx_ext/bit.rs @@ -0,0 +1,109 @@ +use sqlx::encode::IsNull; +use sqlx::error::BoxDynError; +use sqlx::postgres::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef}; +use sqlx::{Decode, Encode, Postgres, Type}; +use std::convert::TryFrom; + +use crate::Bit; + +impl<'a> Type for Bit<'a> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("bit") + } +} + +impl<'a> Encode<'a, Postgres> for Bit<'a> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + let len = self.len; + buf.extend(&i32::try_from(len).unwrap().to_be_bytes()); + + for v in self.data { + buf.extend(&v.to_be_bytes()); + } + + IsNull::No + } +} + +impl<'a> Decode<'a, Postgres> for Bit<'a> { + fn decode(value: PgValueRef<'a>) -> Result { + let buf = <&[u8] as Decode>::decode(value)?; + Bit::from_sql(buf) + } +} + +impl<'a> PgHasArrayType for Bit<'a> { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_bit") + } +} + +#[cfg(test)] +mod tests { + use crate::Bit; + use sqlx::postgres::PgPoolOptions; + use sqlx::Row; + + #[async_std::test] + async fn it_works() -> Result<(), sqlx::Error> { + let pool = PgPoolOptions::new() + .max_connections(1) + .connect("postgres://localhost/pgvector_rust_test") + .await?; + + sqlx::query("CREATE EXTENSION IF NOT EXISTS vector") + .execute(&pool) + .await?; + sqlx::query("DROP TABLE IF EXISTS sqlx_bit_items") + .execute(&pool) + .await?; + sqlx::query("CREATE TABLE sqlx_bit_items (id bigserial PRIMARY KEY, embedding bit(8))") + .execute(&pool) + .await?; + + let vec = Bit::from_bytes(&[0b10101010]); + let vec2 = Bit::from_bytes(&[0b01010101]); + sqlx::query("INSERT INTO sqlx_bit_items (embedding) VALUES ($1), ($2), (NULL)") + .bind(&vec) + .bind(&vec2) + .execute(&pool) + .await?; + + let query_vec = Bit::from_bytes(&[0b10101010]); + let row = + sqlx::query("SELECT embedding FROM sqlx_bit_items ORDER BY embedding <~> $1 LIMIT 1") + .bind(query_vec) + .fetch_one(&pool) + .await?; + let res_vec: Bit = row.try_get("embedding").unwrap(); + assert_eq!(vec, res_vec); + assert_eq!(&[0b10101010], res_vec.as_bytes()); + + let null_row = + sqlx::query("SELECT embedding FROM sqlx_bit_items WHERE embedding IS NULL LIMIT 1") + .fetch_one(&pool) + .await?; + let null_res: Option = null_row.try_get("embedding").unwrap(); + assert!(null_res.is_none()); + + // ensures binary format is correct + let text_row = + sqlx::query("SELECT embedding::text FROM sqlx_bit_items ORDER BY id LIMIT 1") + .fetch_one(&pool) + .await?; + let text_res: String = text_row.try_get("embedding").unwrap(); + assert_eq!("10101010", text_res); + + sqlx::query("ALTER TABLE sqlx_bit_items ADD COLUMN factors bit(8)[]") + .execute(&pool) + .await?; + + let vecs = &[vec, vec2]; + sqlx::query("INSERT INTO sqlx_bit_items (factors) VALUES ($1)") + .bind(vecs) + .execute(&pool) + .await?; + + Ok(()) + } +} diff --git a/src/sqlx_ext/mod.rs b/src/sqlx_ext/mod.rs index 6cb3982..0e7494e 100644 --- a/src/sqlx_ext/mod.rs +++ b/src/sqlx_ext/mod.rs @@ -1,3 +1,4 @@ +mod bit; mod sparsevec; mod vector;