diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index d8b022d0fa49..1eecd2e33306 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -3,6 +3,7 @@ use crate::{ proxy::{CommonProxy, DriverProxy, Query}, }; use async_trait::async_trait; +use futures::lock::Mutex; use napi::JsObject; use psl::datamodel_connector::Flavour; use quaint::{ @@ -12,6 +13,7 @@ use quaint::{ visitor::{self, Visitor}, Value, }; +use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -182,6 +184,7 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, + pub transaction_depth: Arc>, } impl std::fmt::Display for JsQueryable { @@ -259,14 +262,19 @@ impl TransactionCapable for JsQueryable { } } - let begin_stmt = tx.begin_statement(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_stmt = tx.begin_statement(st_depth).await; let tx_opts = tx.options(); if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; } else { - tx.raw_cmd(begin_stmt).await?; + tx.raw_cmd(&begin_stmt).await?; } if !isolation_first { @@ -288,5 +296,6 @@ pub fn from_napi(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 0d26c7f863aa..d3076e9d4dfb 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; use napi::{bindgen_prelude::FromNapiValue, JsObject}; use quaint::{ @@ -6,6 +7,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; +use std::sync::Arc; use crate::{ proxy::{CommonProxy, TransactionOptions, TransactionProxy}, @@ -18,13 +20,22 @@ use crate::{ pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: Arc>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { increment_gauge!("prisma_client_queries_active", 1.0); - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + commit_stmt: "COMMIT".to_string(), + rollback_stmt: "ROLLBACK".to_string(), + depth: Arc::new(futures::lock::Mutex::new(0)), + } } pub fn options(&self) -> &TransactionOptions { @@ -39,10 +50,11 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn commit(&mut self) -> quaint::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - let commit_stmt = "COMMIT"; + let mut depth_guard = self.depth.lock().await; + let commit_stmt = &self.commit_stmt; if self.options().use_phantom_query { let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); @@ -51,13 +63,17 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(commit_stmt).await?; } + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + self.tx_proxy.commit().await } - async fn rollback(&self) -> quaint::Result<()> { + async fn rollback(&mut self) -> quaint::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - let rollback_stmt = "ROLLBACK"; + let mut depth_guard = self.depth.lock().await; + let rollback_stmt = &self.rollback_stmt; if self.options().use_phantom_query { let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); @@ -66,6 +82,9 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(rollback_stmt).await?; } + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + self.tx_proxy.rollback().await }