Skip to content

Commit

Permalink
Revert MySQL & SQLite returning support
Browse files Browse the repository at this point in the history
  • Loading branch information
billy1624 committed Nov 10, 2021
1 parent cc035d7 commit 7560a64
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 213 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ jobs:
name: Examples
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
path: [basic, actix_example, actix4_example, axum_example, rocket_example]
Expand All @@ -312,6 +313,7 @@ jobs:
if: ${{ (needs.init.outputs.run-partial == 'true' && needs.init.outputs.run-issues == 'true') }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
path: [86, 249, 262]
Expand Down
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"
url = "^2.2"
regex = "^1"

[dev-dependencies]
smol = { version = "^1.2" }
Expand Down
10 changes: 5 additions & 5 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ pub trait ConnectionTrait<'a>: Sync {
T: Send,
E: std::error::Error + Send;

/// Check if the connection supports `RETURNING` syntax on insert
fn returning_on_insert(&self) -> bool;

/// Check if the connection supports `RETURNING` syntax on update
fn returning_on_update(&self) -> bool;
/// Check if the connection supports `RETURNING` syntax on insert and update
fn support_returning(&self) -> bool {
let db_backend = self.get_database_backend();
db_backend.support_returning()
}

/// Check if the connection is a test connection for the Mock database
fn is_mock_connection(&self) -> bool {
Expand Down
60 changes: 5 additions & 55 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,61 +214,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}

fn returning_on_insert(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => {
// Supported if it's MariaDB on or after version 10.5.0
// Not supported in all MySQL versions
conn.support_returning
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
// Supported by all Postgres versions
true
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
conn.support_returning
}
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}

fn returning_on_update(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => {
// Not supported in all MySQL & MariaDB versions
false
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => {
// Supported by all Postgres versions
true
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => {
// Supported by SQLite on or after version 3.35.0 (2021-03-12)
conn.support_returning
}
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}

#[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool {
matches!(self, DatabaseConnection::MockDatabaseConnection(_))
Expand Down Expand Up @@ -322,6 +267,11 @@ impl DbBackend {
Self::Sqlite => Box::new(SqliteQueryBuilder),
}
}

/// Check if the database supports `RETURNING` syntax on insert and update
pub fn support_returning(&self) -> bool {
matches!(self, Self::Postgres)
}
}

#[cfg(test)]
Expand Down
33 changes: 2 additions & 31 deletions src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub struct DatabaseTransaction {
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
open: bool,
support_returning: bool,
}

impl std::fmt::Debug for DatabaseTransaction {
Expand All @@ -29,12 +28,10 @@ impl DatabaseTransaction {
#[cfg(feature = "sqlx-mysql")]
pub(crate) async fn new_mysql(
inner: PoolConnection<sqlx::MySql>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::MySql(inner))),
DbBackend::MySql,
support_returning,
)
.await
}
Expand All @@ -46,20 +43,17 @@ impl DatabaseTransaction {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Postgres(inner))),
DbBackend::Postgres,
true,
)
.await
}

#[cfg(feature = "sqlx-sqlite")]
pub(crate) async fn new_sqlite(
inner: PoolConnection<sqlx::Sqlite>,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
Self::begin(
Arc::new(Mutex::new(InnerConnection::Sqlite(inner))),
DbBackend::Sqlite,
support_returning,
)
.await
}
Expand All @@ -69,28 +63,17 @@ impl DatabaseTransaction {
inner: Arc<crate::MockDatabaseConnection>,
) -> Result<DatabaseTransaction, DbErr> {
let backend = inner.get_database_backend();
Self::begin(
Arc::new(Mutex::new(InnerConnection::Mock(inner))),
backend,
match backend {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
)
.await
Self::begin(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await
}

async fn begin(
conn: Arc<Mutex<InnerConnection>>,
backend: DbBackend,
support_returning: bool,
) -> Result<DatabaseTransaction, DbErr> {
let res = DatabaseTransaction {
conn,
backend,
open: true,
support_returning,
};
match *res.conn.lock().await {
#[cfg(feature = "sqlx-mysql")]
Expand Down Expand Up @@ -347,8 +330,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
}

async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend, self.support_returning)
.await
DatabaseTransaction::begin(Arc::clone(&self.conn), self.backend).await
}

/// Execute the function inside a transaction.
Expand All @@ -365,17 +347,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
let transaction = self.begin().await.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
}

fn returning_on_insert(&self) -> bool {
self.support_returning
}

fn returning_on_update(&self) -> bool {
match self.backend {
DbBackend::MySql => false,
_ => self.support_returning,
}
}
}

/// Defines errors for handling transaction failures
Expand Down
67 changes: 9 additions & 58 deletions src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use regex::Regex;
use std::{future::Future, pin::Pin};

use sqlx::{
mysql::{MySqlArguments, MySqlConnectOptions, MySqlQueryResult, MySqlRow},
MySql, MySqlPool, Row,
MySql, MySqlPool,
};

sea_query::sea_query_driver_mysql!();
use sea_query_driver_mysql::bind_query;

use crate::{
debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction,
DbBackend, QueryStream, Statement, TransactionError,
QueryStream, Statement, TransactionError,
};

use super::sqlx_common::*;
Expand All @@ -24,7 +23,6 @@ pub struct SqlxMySqlConnector;
#[derive(Debug, Clone)]
pub struct SqlxMySqlPoolConnection {
pool: MySqlPool,
pub(crate) support_returning: bool,
}

impl SqlxMySqlConnector {
Expand All @@ -44,7 +42,9 @@ impl SqlxMySqlConnector {
opt.disable_statement_logging();
}
if let Ok(pool) = options.pool_options().connect_with(opt).await {
into_db_connection(pool).await
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
SqlxMySqlPoolConnection { pool },
))
} else {
Err(DbErr::Conn("Failed to connect.".to_owned()))
}
Expand All @@ -53,8 +53,8 @@ impl SqlxMySqlConnector {

impl SqlxMySqlConnector {
/// Instantiate a sqlx pool connection to a [DatabaseConnection]
pub async fn from_sqlx_mysql_pool(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
into_db_connection(pool).await
pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool })
}
}

Expand Down Expand Up @@ -129,7 +129,7 @@ impl SqlxMySqlPoolConnection {
/// Bundle a set of SQL statements that execute together.
pub async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
if let Ok(conn) = self.pool.acquire().await {
DatabaseTransaction::new_mysql(conn, self.support_returning).await
DatabaseTransaction::new_mysql(conn).await
} else {
Err(DbErr::Query(
"Failed to acquire connection from pool.".to_owned(),
Expand All @@ -148,7 +148,7 @@ impl SqlxMySqlPoolConnection {
E: std::error::Error + Send,
{
if let Ok(conn) = self.pool.acquire().await {
let transaction = DatabaseTransaction::new_mysql(conn, self.support_returning)
let transaction = DatabaseTransaction::new_mysql(conn)
.await
.map_err(|e| TransactionError::Connection(e))?;
transaction.run(callback).await
Expand Down Expand Up @@ -183,52 +183,3 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
}
query
}

async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
let support_returning = parse_support_returning(&pool).await?;
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
SqlxMySqlPoolConnection {
pool,
support_returning,
},
))
}

async fn parse_support_returning(pool: &MySqlPool) -> Result<bool, DbErr> {
let stmt = Statement::from_string(
DbBackend::MySql,
r#"SHOW VARIABLES LIKE "version""#.to_owned(),
);
let query = sqlx_query(&stmt);
let row = query
.fetch_one(pool)
.await
.map_err(sqlx_error_to_query_err)?;
let version: String = row.try_get("Value").map_err(sqlx_error_to_query_err)?;
let support_returning = if !version.contains("MariaDB") {
// This is MySQL
// Not supported in all MySQL versions
false
} else {
// This is MariaDB
let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap();
let captures = regex.captures(&version).unwrap();
macro_rules! parse_captures {
( $idx: expr ) => {
captures.get($idx).map_or(0, |m| {
m.as_str()
.parse::<usize>()
.map_err(|e| DbErr::Conn(e.to_string()))
.unwrap()
})
};
}
let ver_major = parse_captures!(1);
let ver_minor = parse_captures!(2);
// Supported if it's MariaDB with version 10.5.0 or after
ver_major >= 10 && ver_minor >= 5
};
debug_print!("db_version: {}", version);
debug_print!("db_support_returning: {}", support_returning);
Ok(support_returning)
}
Loading

0 comments on commit 7560a64

Please sign in to comment.