Skip to content

Commit

Permalink
Add sqlite_busytimeout parameter as user configurable param (#121)
Browse files Browse the repository at this point in the history
* Add sqlite_busytimeout parameter as user configurable param

* Remove debug log

* Fix lint, fix integration test

---------

Co-authored-by: Phillip LeBlanc <[email protected]>
  • Loading branch information
Sevenannn and phillipleblanc authored Sep 30, 2024
1 parent b0af919 commit 857577d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 27 deletions.
63 changes: 43 additions & 20 deletions src/sql/db_connection_pool/sqlitepool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use async_trait::async_trait;
use snafu::{prelude::*, ResultExt};
Expand Down Expand Up @@ -26,14 +26,16 @@ pub struct SqliteConnectionPoolFactory {
path: Arc<str>,
mode: Mode,
attach_databases: Option<Vec<Arc<str>>>,
busy_timeout: Duration,
}

impl SqliteConnectionPoolFactory {
pub fn new(path: &str, mode: Mode) -> Self {
pub fn new(path: &str, mode: Mode, busy_timeout: Duration) -> Self {
SqliteConnectionPoolFactory {
path: path.into(),
mode,
attach_databases: None,
busy_timeout,
}
}

Expand Down Expand Up @@ -80,9 +82,14 @@ impl SqliteConnectionPoolFactory {
vec![]
};

let pool =
SqliteConnectionPool::new(&self.path, self.mode, join_push_down, attach_databases)
.await?;
let pool = SqliteConnectionPool::new(
&self.path,
self.mode,
join_push_down,
attach_databases,
self.busy_timeout,
)
.await?;

pool.setup().await?;

Expand All @@ -96,6 +103,7 @@ pub struct SqliteConnectionPool {
mode: Mode,
path: Arc<str>,
attach_databases: Vec<Arc<str>>,
busy_timeout: Duration,
}

impl SqliteConnectionPool {
Expand All @@ -113,6 +121,7 @@ impl SqliteConnectionPool {
mode: Mode,
join_push_down: JoinPushDown,
attach_databases: Vec<Arc<str>>,
busy_timeout: Duration,
) -> Result<Self> {
let conn = match mode {
Mode::Memory => Connection::open_in_memory()
Expand All @@ -130,6 +139,7 @@ impl SqliteConnectionPool {
mode,
attach_databases,
path: path.into(),
busy_timeout,
})
}

Expand All @@ -147,19 +157,23 @@ impl SqliteConnectionPool {

pub async fn setup(&self) -> Result<()> {
let conn = self.conn.clone();
let busy_timeout = self.busy_timeout;

// these configuration options are only applicable for file-mode databases
if self.mode == Mode::File {
// change transaction mode to Write-Ahead log instead of default atomic rollback journal: https://www.sqlite.org/wal.html
// NOTE: This is a no-op if the database is in-memory, as only MEMORY or OFF are supported: https://www.sqlite.org/pragma.html#pragma_journal_mode
conn.call(|conn| {
conn.call(move |conn| {
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "busy_timeout", "5000")?;
conn.pragma_update(None, "synchronous", "NORMAL")?;
conn.pragma_update(None, "cache_size", "-20000")?;
conn.pragma_update(None, "foreign_keys", "true")?;
conn.pragma_update(None, "temp_store", "memory")?;
// conn.set_transaction_behavior(TransactionBehavior::Immediate); introduced in rustqlite 0.32.1, but tokio-rusqlite is still on 0.31.0

// Set user configurable connection timeout
conn.busy_timeout(busy_timeout)?;

Ok(())
})
.await
Expand Down Expand Up @@ -212,6 +226,7 @@ impl SqliteConnectionPool {
mode: self.mode,
path: Arc::clone(&self.path),
attach_databases: self.attach_databases.clone(),
busy_timeout: self.busy_timeout,
}),
Mode::File => {
let attach_databases = if self.attach_databases.is_empty() {
Expand All @@ -220,7 +235,7 @@ impl SqliteConnectionPool {
Some(self.attach_databases.clone())
};

SqliteConnectionPoolFactory::new(&self.path, self.mode)
SqliteConnectionPoolFactory::new(&self.path, self.mode, self.busy_timeout)
.with_databases(attach_databases)
.build()
.await
Expand Down Expand Up @@ -250,6 +265,7 @@ mod tests {
use crate::sql::db_connection_pool::Mode;
use rand::Rng;
use rstest::rstest;
use std::time::Duration;

fn random_db_name() -> String {
let mut rng = rand::thread_rng();
Expand All @@ -266,7 +282,7 @@ mod tests {
#[tokio::test]
async fn test_sqlite_connection_pool_factory() {
let db_name = random_db_name();
let factory = SqliteConnectionPoolFactory::new(&db_name, Mode::File);
let factory = SqliteConnectionPoolFactory::new(&db_name, Mode::File, None);
let pool = factory.build().await.unwrap();

assert!(pool.join_push_down == JoinPushDown::AllowedFor(db_name.clone()));
Expand All @@ -285,10 +301,11 @@ mod tests {
db_names.sort();

let factory =
SqliteConnectionPoolFactory::new(&db_names[0], Mode::File).with_databases(Some(vec![
db_names[1].clone().into(),
db_names[2].clone().into(),
]));
SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000))
.with_databases(Some(vec![
db_names[1].clone().into(),
db_names[2].clone().into(),
]));

SqliteConnectionPool::init(&db_names[1], Mode::File)
.await
Expand Down Expand Up @@ -317,7 +334,8 @@ mod tests {
async fn test_sqlite_connection_pool_factory_with_empty_attachments() {
let db_name = random_db_name();
let factory =
SqliteConnectionPoolFactory::new(&db_name, Mode::File).with_databases(Some(vec![]));
SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_millis(5000))
.with_databases(Some(vec![]));

let pool = factory.build().await.unwrap();

Expand All @@ -333,8 +351,12 @@ mod tests {

#[tokio::test]
async fn test_sqlite_connection_pool_factory_memory_with_attachments() {
let factory = SqliteConnectionPoolFactory::new("./test.sqlite", Mode::Memory)
.with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()]));
let factory = SqliteConnectionPoolFactory::new(
"./test.sqlite",
Mode::Memory,
Duration::from_millis(5000),
)
.with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()]));
let pool = factory.build().await.unwrap();

assert!(pool.join_push_down == JoinPushDown::Disallow);
Expand All @@ -355,10 +377,11 @@ mod tests {
db_names.sort();

let factory =
SqliteConnectionPoolFactory::new(&db_names[0], Mode::File).with_databases(Some(vec![
db_names[1].clone().into(),
db_names[2].clone().into(),
]));
SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000))
.with_databases(Some(vec![
db_names[1].clone().into(),
db_names[2].clone().into(),
]));
let pool = factory.build().await;

assert!(pool.is_err());
Expand Down
31 changes: 28 additions & 3 deletions src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use rusqlite::{ToSql, Transaction};
use snafu::prelude::*;
use sql_table::SQLiteTable;
use std::collections::HashSet;
use std::time::Duration;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
use tokio_rusqlite::Connection;
Expand Down Expand Up @@ -93,6 +94,9 @@ pub enum Error {

#[snafu(display("Unable to infer schema: {source}"))]
UnableToInferSchema { source: dbconnection::Error },

#[snafu(display("Invalid SQLite busy_timeout value"))]
InvalidBusyTimeoutValue { value: String },
}

type Result<T, E = Error> = std::result::Result<T, E>;
Expand All @@ -104,6 +108,7 @@ pub struct SqliteTableProviderFactory {
const SQLITE_DB_PATH_PARAM: &str = "file";
const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory";
const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases";
const SQLITE_BUSY_TIMEOUT_PARAM: &str = "sqlite_busy_timeout";

impl SqliteTableProviderFactory {
#[must_use]
Expand Down Expand Up @@ -139,10 +144,27 @@ impl SqliteTableProviderFactory {
.unwrap_or(default_filepath)
}

pub fn sqlite_busy_timeout(&self, options: &HashMap<String, String>) -> Result<Duration> {
let busy_timeout = options.get(SQLITE_BUSY_TIMEOUT_PARAM).cloned();
match busy_timeout {
Some(busy_timeout) => {
let result: u64 = busy_timeout.parse().map_err(|_| {
InvalidBusyTimeoutValueSnafu {
value: busy_timeout,
}
.build()
})?;
Ok(Duration::from_millis(result))
}
None => Ok(Duration::from_millis(5000)),
}
}

pub async fn get_or_init_instance(
&self,
db_path: impl Into<Arc<str>>,
mode: Mode,
busy_timeout: Duration,
) -> Result<SqliteConnectionPool> {
let db_path = db_path.into();
let key = match mode {
Expand All @@ -155,7 +177,7 @@ impl SqliteTableProviderFactory {
return instance.try_clone().await.context(DbConnectionPoolSnafu);
}

let pool = SqliteConnectionPoolFactory::new(&db_path, mode)
let pool = SqliteConnectionPoolFactory::new(&db_path, mode, busy_timeout)
.build()
.await
.context(DbConnectionPoolSnafu)?;
Expand Down Expand Up @@ -219,10 +241,13 @@ impl TableProviderFactory for SqliteTableProviderFactory {
);
}

let busy_timeout = self
.sqlite_busy_timeout(&cmd.options)
.map_err(to_datafusion_error)?;
let db_path: Arc<str> = self.sqlite_file_path(&name, &cmd.options).into();

let pool: Arc<SqliteConnectionPool> = Arc::new(
self.get_or_init_instance(Arc::clone(&db_path), mode)
self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout)
.await
.map_err(to_datafusion_error)?,
);
Expand All @@ -234,7 +259,7 @@ impl TableProviderFactory for SqliteTableProviderFactory {
// even though we setup SQLite to use WAL mode, the pool isn't really a pool so shares the same connection
// and we can't have concurrent writes when sharing the same connection
Arc::new(
self.get_or_init_instance(Arc::clone(&db_path), mode)
self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout)
.await
.map_err(to_datafusion_error)?,
)
Expand Down
12 changes: 8 additions & 4 deletions tests/sqlite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ async fn arrow_sqlite_round_trip(
tracing::debug!("Running tests on {table_name}");
let ctx = SessionContext::new();

let pool = SqliteConnectionPoolFactory::new(":memory:", Mode::Memory)
.build()
.await
.expect("Sqlite connection pool to be created");
let pool = SqliteConnectionPoolFactory::new(
":memory:",
Mode::Memory,
std::time::Duration::from_millis(5000),
)
.build()
.await
.expect("Sqlite connection pool to be created");

let conn = pool
.connect()
Expand Down

0 comments on commit 857577d

Please sign in to comment.