Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for nested transaction rollbacks via savepoints in sql #4405

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions quaint/src/connector/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{
use async_trait::async_trait;
use connection_string::JdbcString;
use futures::lock::Mutex;
use std::sync::Arc;
use std::{
convert::TryFrom,
fmt,
Expand Down Expand Up @@ -106,11 +107,23 @@ impl TransactionCapable for Mssql {
.or(self.url.query_params.transaction_isolation_level)
.or(Some(SQL_SERVER_DEFAULT_ISOLATION));

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
let mut transaction_depth = self.transaction_depth.lock().await;
*transaction_depth += 1;
let st_depth = *transaction_depth;

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
let begin_statement = self.begin_statement(st_depth).await;
let commit_stmt = self.commit_statement(st_depth).await;
let rollback_stmt = self.rollback_statement(st_depth).await;

let opts = TransactionOptions::new(
isolation,
self.requires_isolation_first(),
self.transaction_depth.clone(),
commit_stmt,
rollback_stmt,
);

Ok(Box::new(DefaultTransaction::new(self, &begin_statement, opts).await?))
}
}

Expand Down Expand Up @@ -273,6 +286,7 @@ pub struct Mssql {
url: MssqlUrl,
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

impl Mssql {
Expand Down Expand Up @@ -304,6 +318,7 @@ impl Mssql {
url,
socket_timeout,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
};

if let Some(isolation) = this.url.transaction_isolation_level() {
Expand Down Expand Up @@ -443,8 +458,41 @@ impl Queryable for Mssql {
Ok(())
}

fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"BEGIN TRAN".to_string()
};

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
let ret = if depth > 1 {
" ".to_string()
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}

fn requires_isolation_first(&self) -> bool {
Expand Down
35 changes: 35 additions & 0 deletions quaint/src/connector/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use mysql_async::{
prelude::{Query as _, Queryable as _},
};
use percent_encoding::percent_decode;
use std::sync::Arc;
use std::{
borrow::Cow,
future::Future,
Expand Down Expand Up @@ -41,6 +42,7 @@ pub struct Mysql {
socket_timeout: Option<Duration>,
is_healthy: AtomicBool,
statement_cache: Mutex<LruCache<String, my::Statement>>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

/// Wraps a connection url and exposes the parsing logic used by quaint, including default values.
Expand Down Expand Up @@ -376,6 +378,7 @@ impl Mysql {
statement_cache: Mutex::new(url.cache()),
url,
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -583,6 +586,38 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

#[cfg(test)]
Expand Down
35 changes: 35 additions & 0 deletions quaint/src/connector/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use lru_cache::LruCache;
use native_tls::{Certificate, Identity, TlsConnector};
use percent_encoding::percent_decode;
use postgres_native_tls::MakeTlsConnector;
use std::sync::Arc;
use std::{
borrow::{Borrow, Cow},
fmt::{Debug, Display},
Expand Down Expand Up @@ -63,6 +64,7 @@ pub struct PostgreSql {
socket_timeout: Option<Duration>,
statement_cache: Mutex<LruCache<String, Statement>>,
is_healthy: AtomicBool,
transaction_depth: Arc<Mutex<i32>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -652,6 +654,7 @@ impl PostgreSql {
pg_bouncer: url.query_params.pg_bouncer,
statement_cache: Mutex::new(url.cache()),
is_healthy: AtomicBool::new(true),
transaction_depth: Arc::new(Mutex::new(0)),
})
}

Expand Down Expand Up @@ -932,6 +935,38 @@ impl Queryable for PostgreSql {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
34 changes: 30 additions & 4 deletions quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,18 @@ pub trait Queryable: Send + Sync {
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
async fn begin_statement(&self, _depth: i32) -> String {
"BEGIN".to_string()
}

/// Statement to commit a transaction
async fn commit_statement(&self, _depth: i32) -> String {
"COMMIT".to_string()
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, _depth: i32) -> String {
"ROLLBACK".to_string()
}

/// Sets the transaction isolation level to given value.
Expand Down Expand Up @@ -117,10 +127,26 @@ macro_rules! impl_default_TransactionCapable {
&'a self,
isolation: Option<IsolationLevel>,
) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());
let depth = self.transaction_depth.clone();
let mut depth_guard = self.transaction_depth.lock().await;
*depth_guard += 1;

let st_depth = *depth_guard;

let begin_statement = self.begin_statement(st_depth).await;
let commit_stmt = self.commit_statement(st_depth).await;
let rollback_stmt = self.rollback_statement(st_depth).await;

let opts = crate::connector::TransactionOptions::new(
isolation,
self.requires_isolation_first(),
depth,
commit_stmt,
rollback_stmt,
);

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, &begin_statement, opts).await?,
))
}
}
Expand Down
40 changes: 39 additions & 1 deletion quaint/src/connector/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
visitor::{self, Visitor},
};
use async_trait::async_trait;
use std::sync::Arc;
use std::{convert::TryFrom, path::Path, time::Duration};
use tokio::sync::Mutex;

Expand All @@ -25,6 +26,7 @@ pub use rusqlite;
/// A connector interface for the SQLite database
pub struct Sqlite {
pub(crate) client: Mutex<rusqlite::Connection>,
transaction_depth: Arc<futures::lock::Mutex<i32>>,
}

/// Wraps a connection url and exposes the parsing logic used by Quaint,
Expand Down Expand Up @@ -141,7 +143,10 @@ impl TryFrom<&str> for Sqlite {

let client = Mutex::new(conn);

Ok(Sqlite { client })
Ok(Sqlite {
client,
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}
}

Expand All @@ -156,6 +161,7 @@ impl Sqlite {

Ok(Sqlite {
client: Mutex::new(client),
transaction_depth: Arc::new(futures::lock::Mutex::new(0)),
})
}

Expand Down Expand Up @@ -252,6 +258,38 @@ impl Queryable for Sqlite {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
async fn begin_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() };

return ret;
}

/// Statement to commit a transaction
async fn commit_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"COMMIT".to_string()
};

return ret;
}

/// Statement to rollback a transaction
async fn rollback_statement(&self, depth: i32) -> String {
let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth);
let ret = if depth > 1 {
savepoint_stmt
} else {
"ROLLBACK".to_string()
};

return ret;
}
}

#[cfg(test)]
Expand Down
Loading