diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..a217721e0019 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -11,6 +11,7 @@ use crate::{ use async_trait::async_trait; use connection_string::JdbcString; use futures::lock::Mutex; +use std::sync::Arc; use std::{ convert::TryFrom, fmt, @@ -106,11 +107,23 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let mut transaction_depth = self.transaction_depth.lock().await; + *transaction_depth += 1; + let st_depth = *transaction_depth; - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) + let begin_statement = self.begin_statement(st_depth).await; + let commit_stmt = self.commit_statement(st_depth).await; + let rollback_stmt = self.rollback_statement(st_depth).await; + + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + commit_stmt, + rollback_stmt, + ); + + Ok(Box::new(DefaultTransaction::new(self, &begin_statement, opts).await?)) } } @@ -273,6 +286,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option<Duration>, is_healthy: AtomicBool, + transaction_depth: Arc<Mutex<i32>>, } impl Mssql { @@ -304,6 +318,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -443,8 +458,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 4b6f27a583da..4bdbbcbf510b 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -14,6 +14,7 @@ use mysql_async::{ prelude::{Query as _, Queryable as _}, }; use percent_encoding::percent_decode; +use std::sync::Arc; use std::{ borrow::Cow, future::Future, @@ -41,6 +42,7 @@ pub struct Mysql { socket_timeout: Option<Duration>, is_healthy: AtomicBool, statement_cache: Mutex<LruCache<String, my::Statement>>, + transaction_depth: Arc<futures::lock::Mutex<i32>>, } /// Wraps a connection url and exposes the parsing logic used by quaint, including default values. @@ -376,6 +378,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -583,6 +586,38 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } #[cfg(test)] diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 766be38b27e4..469eaaeb72bc 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -13,6 +13,7 @@ use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; use percent_encoding::percent_decode; use postgres_native_tls::MakeTlsConnector; +use std::sync::Arc; use std::{ borrow::{Borrow, Cow}, fmt::{Debug, Display}, @@ -63,6 +64,7 @@ pub struct PostgreSql { socket_timeout: Option<Duration>, statement_cache: Mutex<LruCache<String, Statement>>, is_healthy: AtomicBool, + transaction_depth: Arc<Mutex<i32>>, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -652,6 +654,7 @@ impl PostgreSql { pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -932,6 +935,38 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 09dbc7abba4c..490d7dc6f86b 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -87,8 +87,18 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, _depth: i32) -> String { + "BEGIN".to_string() + } + + /// Statement to commit a transaction + async fn commit_statement(&self, _depth: i32) -> String { + "COMMIT".to_string() + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, _depth: i32) -> String { + "ROLLBACK".to_string() } /// Sets the transaction isolation level to given value. @@ -117,10 +127,26 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option<IsolationLevel>, ) -> crate::Result<Box<dyn crate::connector::Transaction + 'a>> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let depth = self.transaction_depth.clone(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.begin_statement(st_depth).await; + let commit_stmt = self.commit_statement(st_depth).await; + let rollback_stmt = self.rollback_statement(st_depth).await; + + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + depth, + commit_stmt, + rollback_stmt, + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, &begin_statement, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 3a1ef72b4883..eb490d4569ed 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -13,6 +13,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; +use std::sync::Arc; use std::{convert::TryFrom, path::Path, time::Duration}; use tokio::sync::Mutex; @@ -25,6 +26,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex<rusqlite::Connection>, + transaction_depth: Arc<futures::lock::Mutex<i32>>, } /// Wraps a connection url and exposes the parsing logic used by Quaint, @@ -141,7 +143,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -156,6 +161,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -252,6 +258,38 @@ impl Queryable for Sqlite { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } #[cfg(test)] diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index b7e91e97f6a8..fd857c935653 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -4,7 +4,9 @@ use crate::{ error::{Error, ErrorKind}, }; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; +use std::sync::Arc; use std::{fmt, str::FromStr}; extern crate metrics as metrics; @@ -12,10 +14,10 @@ extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result<()>; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result<()>; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +29,15 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc<Mutex<i32>>, + + /// The statement to use to commit the transaction. + pub commit_stmt: String, + + /// The statement to use to rollback the transaction. + pub rollback_stmt: String, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,6 +47,9 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc<Mutex<i32>>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl<'a> DefaultTransaction<'a> { @@ -44,7 +58,12 @@ impl<'a> DefaultTransaction<'a> { begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result<DefaultTransaction<'a>> { - let this = Self { inner }; + let this = Self { + inner, + depth: tx_opts.depth, + commit_stmt: tx_opts.commit_stmt, + rollback_stmt: tx_opts.rollback_stmt, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -70,17 +89,29 @@ impl<'a> DefaultTransaction<'a> { #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; + + let mut depth_guard = self.depth.lock().await; + + self.inner.raw_cmd(&self.commit_stmt).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; Ok(()) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; + + let mut depth_guard = self.depth.lock().await; + + self.inner.raw_cmd(&self.rollback_stmt).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; Ok(()) } @@ -190,10 +221,19 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option<IsolationLevel>, isolation_first: bool) -> Self { + pub fn new( + isolation_level: Option<IsolationLevel>, + isolation_first: bool, + depth: Arc<Mutex<i32>>, + commit_stmt: String, + rollback_stmt: String, + ) -> Self { Self { isolation_level, isolation_first, + depth, + commit_stmt, + rollback_stmt, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 4c4152923377..458a3412ecec 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -500,7 +500,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..e1e028bb0ef5 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -10,12 +10,15 @@ use crate::{ error::Error, }; use async_trait::async_trait; +use futures::lock::Mutex; use mobc::{Connection as MobcPooled, Manager}; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled<QuaintManager>, + pub transaction_depth: Arc<Mutex<i32>>, } impl_default_TransactionCapable!(PooledConnection); @@ -62,8 +65,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..e5c6175fbbed 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -7,6 +7,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; +use futures::lock::Mutex; use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite")] @@ -17,6 +18,7 @@ use std::convert::TryFrom; pub struct Quaint { inner: Arc<dyn Queryable>, connection_info: Arc<ConnectionInfo>, + transaction_depth: Arc<Mutex<i32>>, } impl fmt::Debug for Quaint { @@ -163,7 +165,11 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { + inner, + connection_info, + transaction_depth: Arc::new(Mutex::new(0)), + }) } #[cfg(feature = "sqlite")] @@ -174,6 +180,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), }), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -228,8 +235,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 06bebe1a9601..cf471fbf7330 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 69c57332b6d3..67334858576e 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?; diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ddbb7dfc8429..2f1572ca922a 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -72,7 +72,6 @@ pub struct TransactionOptions { /// An optional pre-defined transaction id. Some value might be provided in case we want to generate /// a new id at the beginning of the transaction - #[serde(skip)] pub new_tx_id: Option<TxId>, } diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index ce125e8fa17e..5c99ebd9f8d3 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,6 +1,6 @@ use crate::CoreError; use connector::Transaction; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use tokio::time::{Duration, Instant}; @@ -38,7 +38,7 @@ pub(crate) use messages::*; /// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction /// rather than an error message stating that the transaction does not exist. -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)] pub struct TxId(String); const MINIMUM_TX_ID_LENGTH: usize = 24; diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index b9a8cfe6564d..fe274658e0b3 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -3,6 +3,7 @@ use crate::{ proxy::{CommonProxy, DriverProxy, Query}, }; use async_trait::async_trait; +use futures::lock::Mutex; use napi::JsObject; use psl::datamodel_connector::Flavour; use quaint::{ @@ -11,6 +12,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, }; +use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -185,6 +187,7 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, + pub transaction_depth: Arc<Mutex<i32>>, } impl std::fmt::Display for JsQueryable { @@ -262,14 +265,19 @@ impl TransactionCapable for JsQueryable { } } - let begin_stmt = tx.begin_statement(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_stmt = tx.begin_statement(st_depth).await; let tx_opts = tx.options(); if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; } else { - tx.raw_cmd(begin_stmt).await?; + tx.raw_cmd(&begin_stmt).await?; } if !isolation_first { @@ -291,5 +299,6 @@ pub fn from_napi(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index d35a9019c6bc..d4c8f606b918 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use futures::lock::Mutex; use metrics::decrement_gauge; use napi::{bindgen_prelude::FromNapiValue, JsObject}; use quaint::{ @@ -6,6 +7,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; +use std::sync::Arc; use crate::{ proxy::{CommonProxy, TransactionOptions, TransactionProxy}, @@ -18,11 +20,20 @@ use crate::{ pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: Arc<Mutex<i32>>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + commit_stmt: "COMMIT".to_string(), + rollback_stmt: "ROLLBACK".to_string(), + depth: Arc::new(futures::lock::Mutex::new(0)), + } } pub fn options(&self) -> &TransactionOptions { @@ -37,11 +48,12 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn commit(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let commit_stmt = "COMMIT"; + let mut depth_guard = self.depth.lock().await; + let commit_stmt = &self.commit_stmt; if self.options().use_phantom_query { let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); @@ -50,14 +62,18 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(commit_stmt).await?; } + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + self.tx_proxy.commit().await } - async fn rollback(&self) -> quaint::Result<()> { + async fn rollback(&mut self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let rollback_stmt = "ROLLBACK"; + let mut depth_guard = self.depth.lock().await; + let rollback_stmt = &self.rollback_stmt; if self.options().use_phantom_query { let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); @@ -66,6 +82,9 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(rollback_stmt).await?; } + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + self.tx_proxy.rollback().await }