From 857577d37858a9bfd5f2f7613993a0e592c963c3 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sun, 29 Sep 2024 22:25:54 -0700 Subject: [PATCH] Add sqlite_busytimeout parameter as user configurable param (#121) * Add sqlite_busytimeout parameter as user configurable param * Remove debug log * Fix lint, fix integration test --------- Co-authored-by: Phillip LeBlanc --- src/sql/db_connection_pool/sqlitepool.rs | 63 ++++++++++++++++-------- src/sqlite.rs | 31 ++++++++++-- tests/sqlite/mod.rs | 12 +++-- 3 files changed, 79 insertions(+), 27 deletions(-) diff --git a/src/sql/db_connection_pool/sqlitepool.rs b/src/sql/db_connection_pool/sqlitepool.rs index 5b16cd0..6c22baf 100644 --- a/src/sql/db_connection_pool/sqlitepool.rs +++ b/src/sql/db_connection_pool/sqlitepool.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use async_trait::async_trait; use snafu::{prelude::*, ResultExt}; @@ -26,14 +26,16 @@ pub struct SqliteConnectionPoolFactory { path: Arc, mode: Mode, attach_databases: Option>>, + busy_timeout: Duration, } impl SqliteConnectionPoolFactory { - pub fn new(path: &str, mode: Mode) -> Self { + pub fn new(path: &str, mode: Mode, busy_timeout: Duration) -> Self { SqliteConnectionPoolFactory { path: path.into(), mode, attach_databases: None, + busy_timeout, } } @@ -80,9 +82,14 @@ impl SqliteConnectionPoolFactory { vec![] }; - let pool = - SqliteConnectionPool::new(&self.path, self.mode, join_push_down, attach_databases) - .await?; + let pool = SqliteConnectionPool::new( + &self.path, + self.mode, + join_push_down, + attach_databases, + self.busy_timeout, + ) + .await?; pool.setup().await?; @@ -96,6 +103,7 @@ pub struct SqliteConnectionPool { mode: Mode, path: Arc, attach_databases: Vec>, + busy_timeout: Duration, } impl SqliteConnectionPool { @@ -113,6 +121,7 @@ impl SqliteConnectionPool { mode: Mode, join_push_down: JoinPushDown, attach_databases: Vec>, + busy_timeout: Duration, ) -> Result { let conn = match mode { Mode::Memory => Connection::open_in_memory() @@ -130,6 +139,7 @@ impl SqliteConnectionPool { mode, attach_databases, path: path.into(), + busy_timeout, }) } @@ -147,19 +157,23 @@ impl SqliteConnectionPool { pub async fn setup(&self) -> Result<()> { let conn = self.conn.clone(); + let busy_timeout = self.busy_timeout; // these configuration options are only applicable for file-mode databases if self.mode == Mode::File { // change transaction mode to Write-Ahead log instead of default atomic rollback journal: https://www.sqlite.org/wal.html // NOTE: This is a no-op if the database is in-memory, as only MEMORY or OFF are supported: https://www.sqlite.org/pragma.html#pragma_journal_mode - conn.call(|conn| { + conn.call(move |conn| { conn.pragma_update(None, "journal_mode", "WAL")?; - conn.pragma_update(None, "busy_timeout", "5000")?; conn.pragma_update(None, "synchronous", "NORMAL")?; conn.pragma_update(None, "cache_size", "-20000")?; conn.pragma_update(None, "foreign_keys", "true")?; conn.pragma_update(None, "temp_store", "memory")?; // conn.set_transaction_behavior(TransactionBehavior::Immediate); introduced in rustqlite 0.32.1, but tokio-rusqlite is still on 0.31.0 + + // Set user configurable connection timeout + conn.busy_timeout(busy_timeout)?; + Ok(()) }) .await @@ -212,6 +226,7 @@ impl SqliteConnectionPool { mode: self.mode, path: Arc::clone(&self.path), attach_databases: self.attach_databases.clone(), + busy_timeout: self.busy_timeout, }), Mode::File => { let attach_databases = if self.attach_databases.is_empty() { @@ -220,7 +235,7 @@ impl SqliteConnectionPool { Some(self.attach_databases.clone()) }; - SqliteConnectionPoolFactory::new(&self.path, self.mode) + SqliteConnectionPoolFactory::new(&self.path, self.mode, self.busy_timeout) .with_databases(attach_databases) .build() .await @@ -250,6 +265,7 @@ mod tests { use crate::sql::db_connection_pool::Mode; use rand::Rng; use rstest::rstest; + use std::time::Duration; fn random_db_name() -> String { let mut rng = rand::thread_rng(); @@ -266,7 +282,7 @@ mod tests { #[tokio::test] async fn test_sqlite_connection_pool_factory() { let db_name = random_db_name(); - let factory = SqliteConnectionPoolFactory::new(&db_name, Mode::File); + let factory = SqliteConnectionPoolFactory::new(&db_name, Mode::File, None); let pool = factory.build().await.unwrap(); assert!(pool.join_push_down == JoinPushDown::AllowedFor(db_name.clone())); @@ -285,10 +301,11 @@ mod tests { db_names.sort(); let factory = - SqliteConnectionPoolFactory::new(&db_names[0], Mode::File).with_databases(Some(vec![ - db_names[1].clone().into(), - db_names[2].clone().into(), - ])); + SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000)) + .with_databases(Some(vec![ + db_names[1].clone().into(), + db_names[2].clone().into(), + ])); SqliteConnectionPool::init(&db_names[1], Mode::File) .await @@ -317,7 +334,8 @@ mod tests { async fn test_sqlite_connection_pool_factory_with_empty_attachments() { let db_name = random_db_name(); let factory = - SqliteConnectionPoolFactory::new(&db_name, Mode::File).with_databases(Some(vec![])); + SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_millis(5000)) + .with_databases(Some(vec![])); let pool = factory.build().await.unwrap(); @@ -333,8 +351,12 @@ mod tests { #[tokio::test] async fn test_sqlite_connection_pool_factory_memory_with_attachments() { - let factory = SqliteConnectionPoolFactory::new("./test.sqlite", Mode::Memory) - .with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()])); + let factory = SqliteConnectionPoolFactory::new( + "./test.sqlite", + Mode::Memory, + Duration::from_millis(5000), + ) + .with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()])); let pool = factory.build().await.unwrap(); assert!(pool.join_push_down == JoinPushDown::Disallow); @@ -355,10 +377,11 @@ mod tests { db_names.sort(); let factory = - SqliteConnectionPoolFactory::new(&db_names[0], Mode::File).with_databases(Some(vec![ - db_names[1].clone().into(), - db_names[2].clone().into(), - ])); + SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000)) + .with_databases(Some(vec![ + db_names[1].clone().into(), + db_names[2].clone().into(), + ])); let pool = factory.build().await; assert!(pool.is_err()); diff --git a/src/sqlite.rs b/src/sqlite.rs index 9f905c2..173937c 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -26,6 +26,7 @@ use rusqlite::{ToSql, Transaction}; use snafu::prelude::*; use sql_table::SQLiteTable; use std::collections::HashSet; +use std::time::Duration; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; use tokio_rusqlite::Connection; @@ -93,6 +94,9 @@ pub enum Error { #[snafu(display("Unable to infer schema: {source}"))] UnableToInferSchema { source: dbconnection::Error }, + + #[snafu(display("Invalid SQLite busy_timeout value"))] + InvalidBusyTimeoutValue { value: String }, } type Result = std::result::Result; @@ -104,6 +108,7 @@ pub struct SqliteTableProviderFactory { const SQLITE_DB_PATH_PARAM: &str = "file"; const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory"; const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases"; +const SQLITE_BUSY_TIMEOUT_PARAM: &str = "sqlite_busy_timeout"; impl SqliteTableProviderFactory { #[must_use] @@ -139,10 +144,27 @@ impl SqliteTableProviderFactory { .unwrap_or(default_filepath) } + pub fn sqlite_busy_timeout(&self, options: &HashMap) -> Result { + let busy_timeout = options.get(SQLITE_BUSY_TIMEOUT_PARAM).cloned(); + match busy_timeout { + Some(busy_timeout) => { + let result: u64 = busy_timeout.parse().map_err(|_| { + InvalidBusyTimeoutValueSnafu { + value: busy_timeout, + } + .build() + })?; + Ok(Duration::from_millis(result)) + } + None => Ok(Duration::from_millis(5000)), + } + } + pub async fn get_or_init_instance( &self, db_path: impl Into>, mode: Mode, + busy_timeout: Duration, ) -> Result { let db_path = db_path.into(); let key = match mode { @@ -155,7 +177,7 @@ impl SqliteTableProviderFactory { return instance.try_clone().await.context(DbConnectionPoolSnafu); } - let pool = SqliteConnectionPoolFactory::new(&db_path, mode) + let pool = SqliteConnectionPoolFactory::new(&db_path, mode, busy_timeout) .build() .await .context(DbConnectionPoolSnafu)?; @@ -219,10 +241,13 @@ impl TableProviderFactory for SqliteTableProviderFactory { ); } + let busy_timeout = self + .sqlite_busy_timeout(&cmd.options) + .map_err(to_datafusion_error)?; let db_path: Arc = self.sqlite_file_path(&name, &cmd.options).into(); let pool: Arc = Arc::new( - self.get_or_init_instance(Arc::clone(&db_path), mode) + self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout) .await .map_err(to_datafusion_error)?, ); @@ -234,7 +259,7 @@ impl TableProviderFactory for SqliteTableProviderFactory { // even though we setup SQLite to use WAL mode, the pool isn't really a pool so shares the same connection // and we can't have concurrent writes when sharing the same connection Arc::new( - self.get_or_init_instance(Arc::clone(&db_path), mode) + self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout) .await .map_err(to_datafusion_error)?, ) diff --git a/tests/sqlite/mod.rs b/tests/sqlite/mod.rs index 74adbaf..0862e57 100644 --- a/tests/sqlite/mod.rs +++ b/tests/sqlite/mod.rs @@ -21,10 +21,14 @@ async fn arrow_sqlite_round_trip( tracing::debug!("Running tests on {table_name}"); let ctx = SessionContext::new(); - let pool = SqliteConnectionPoolFactory::new(":memory:", Mode::Memory) - .build() - .await - .expect("Sqlite connection pool to be created"); + let pool = SqliteConnectionPoolFactory::new( + ":memory:", + Mode::Memory, + std::time::Duration::from_millis(5000), + ) + .build() + .await + .expect("Sqlite connection pool to be created"); let conn = pool .connect()