Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for nested transaction rollbacks via savepoints in sql #4375

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ ensure-prisma-present:
echo "⚠️ ../prisma diverges from prisma/prisma main branch. Test results might diverge from those in CI ⚠️ "; \
fi \
else \
echo "git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) ../prisma"; \
git clone --depth=1 https://github.com/prisma/prisma.git --branch=$(DRIVER_ADAPTERS_BRANCH) "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \
echo "git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks ../prisma"; \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will need to be changed back before merge

git clone --depth=1 https://github.com/LucianBuzzo/prisma.git --branch=lucianbuzzo/nested-rollbacks "../prisma" && echo "Prisma repository has been cloned to ../prisma"; \
fi;

# Quick schema validation of whatever you have in the dev_datamodel.prisma file.
Expand Down
9 changes: 9 additions & 0 deletions libs/driver-adapters/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub(crate) struct TransactionProxy {
/// transaction options
options: TransactionOptions,

/// begin transaction
pub begin: AdapterMethod<(), ()>,

/// commit transaction
commit: AdapterMethod<(), ()>,

Expand Down Expand Up @@ -133,11 +136,13 @@ impl TransactionContextProxy {
impl TransactionProxy {
pub fn new(js_transaction: &JsObject) -> JsResult<Self> {
let commit = get_named_property(js_transaction, "commit")?;
let begin = get_named_property(js_transaction, "begin")?;
let rollback = get_named_property(js_transaction, "rollback")?;
let options = get_named_property(js_transaction, "options")?;
let options = from_js_value::<TransactionOptions>(options);

Ok(Self {
begin,
commit,
rollback,
options,
Expand All @@ -149,6 +154,10 @@ impl TransactionProxy {
&self.options
}

pub fn begin(&self) -> UnsafeFuture<impl Future<Output = quaint::Result<()>> + '_> {
UnsafeFuture(self.begin.call_as_async(()))
}

/// Commits the transaction via the driver adapter.
///
/// ## Cancellation safety
Expand Down
2 changes: 2 additions & 0 deletions libs/driver-adapters/src/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ impl JsQueryable {
// 3. Spawn a transaction from the context.
let tx = tx_ctx.start_transaction().await?;

tx.increment_depth();

let begin_stmt = tx.begin_statement();
let tx_opts = tx.options();

Expand Down
98 changes: 94 additions & 4 deletions libs/driver-adapters/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::future::Future;
use std::{
future::Future,
sync::atomic::{AtomicI32, Ordering},
};

use async_trait::async_trait;
use prisma_metrics::gauge;
Expand Down Expand Up @@ -86,11 +89,16 @@ impl Queryable for JsTransactionContext {
pub(crate) struct JsTransaction {
tx_proxy: TransactionProxy,
inner: JsBaseQueryable,
pub depth: AtomicI32,
}

impl JsTransaction {
pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self {
Self { inner, tx_proxy }
Self {
inner,
tx_proxy,
depth: AtomicI32::new(0),
}
}

pub fn options(&self) -> &TransactionOptions {
Expand All @@ -108,14 +116,43 @@ impl JsTransaction {
)
.await
}

pub fn increment_depth(&self) {
self.depth.fetch_add(1, Ordering::Relaxed);
}
}

#[async_trait]
impl QuaintTransaction for JsTransaction {
fn depth(&self) -> i32 {
self.depth.load(Ordering::Relaxed)
}

async fn begin(&self) -> quaint::Result<()> {
// increment of this gauge is done in DriverProxy::startTransaction
gauge!("prisma_client_queries_active").decrement(1.0);

self.depth.fetch_add(1, Ordering::Relaxed);

let begin_stmt = self.begin_statement();

if self.options().use_phantom_query {
let commit_stmt = JsBaseQueryable::phantom_query_message(begin_stmt);
self.raw_phantom_cmd(commit_stmt.as_str()).await?;
} else {
self.inner.raw_cmd(begin_stmt).await?;
}

UnsafeFuture(self.tx_proxy.begin()).await
}

async fn commit(&self) -> quaint::Result<()> {
// increment of this gauge is done in DriverProxy::startTransaction
gauge!("prisma_client_queries_active").decrement(1.0);

// Reset the depth to 0 on commit
self.depth.store(0, Ordering::Relaxed);

let commit_stmt = "COMMIT";

if self.options().use_phantom_query {
Expand All @@ -125,13 +162,18 @@ impl QuaintTransaction for JsTransaction {
self.inner.raw_cmd(commit_stmt).await?;
}

UnsafeFuture(self.tx_proxy.commit()).await
let _ = UnsafeFuture(self.tx_proxy.commit()).await;

Ok(())
}

async fn rollback(&self) -> quaint::Result<()> {
// increment of this gauge is done in DriverProxy::startTransaction
gauge!("prisma_client_queries_active").decrement(1.0);

// Modify the depth value
self.depth.fetch_sub(1, Ordering::Relaxed);

let rollback_stmt = "ROLLBACK";

if self.options().use_phantom_query {
Expand All @@ -141,7 +183,51 @@ impl QuaintTransaction for JsTransaction {
self.inner.raw_cmd(rollback_stmt).await?;
}

UnsafeFuture(self.tx_proxy.rollback()).await
let _ = UnsafeFuture(self.tx_proxy.rollback()).await;

Ok(())
}

async fn create_savepoint(&self) -> quaint::Result<()> {
let new_depth = self.depth.fetch_add(1, Ordering::Relaxed) + 1;

let create_savepoint_statement = self.create_savepoint_statement(new_depth);
if self.options().use_phantom_query {
let create_savepoint_statement = JsBaseQueryable::phantom_query_message(&create_savepoint_statement);
self.raw_phantom_cmd(create_savepoint_statement.as_str()).await?;
} else {
self.inner.raw_cmd(&create_savepoint_statement).await?;
}

Ok(())
}

async fn release_savepoint(&self) -> quaint::Result<()> {
let depth_val = self.depth.fetch_sub(1, Ordering::Relaxed);

let release_savepoint_statement = self.release_savepoint_statement(depth_val);
if self.options().use_phantom_query {
let release_savepoint_statement = JsBaseQueryable::phantom_query_message(&release_savepoint_statement);
self.raw_phantom_cmd(release_savepoint_statement.as_str()).await?;
} else {
self.inner.raw_cmd(&release_savepoint_statement).await?;
}

Ok(())
}

async fn rollback_to_savepoint(&self) -> quaint::Result<()> {
let depth_val = self.depth.fetch_sub(1, Ordering::Relaxed);
let rollback_to_savepoint_statement = self.rollback_to_savepoint_statement(depth_val);
if self.options().use_phantom_query {
let rollback_to_savepoint_statement =
JsBaseQueryable::phantom_query_message(&rollback_to_savepoint_statement);
self.raw_phantom_cmd(rollback_to_savepoint_statement.as_str()).await?;
} else {
self.inner.raw_cmd(&rollback_to_savepoint_statement).await?;
}

Ok(())
}

fn as_queryable(&self) -> &dyn Queryable {
Expand Down Expand Up @@ -198,6 +284,10 @@ impl Queryable for JsTransaction {
fn requires_isolation_first(&self) -> bool {
self.inner.requires_isolation_first()
}

fn begin_statement(&self) -> &'static str {
self.inner.begin_statement()
}
}

#[cfg(target_arch = "wasm32")]
Expand Down
22 changes: 19 additions & 3 deletions quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
};
use async_trait::async_trait;
use futures::lock::Mutex;
use std::borrow::Cow;
use std::{
convert::TryFrom,
future::Future,
Expand Down Expand Up @@ -48,9 +49,7 @@ impl TransactionCapable for Mssql {

let opts = TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
}

Expand Down Expand Up @@ -244,10 +243,27 @@ impl Queryable for Mssql {
Ok(())
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN TRAN"
}

/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("SAVE TRANSACTION savepoint{depth}"))
}

// MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested
// transaction we just continue onwards
fn release_savepoint_statement(&self, _depth: i32) -> Cow<'static, str> {
Cow::Borrowed("")
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TRANSACTION savepoint{depth}"))
}

fn requires_isolation_first(&self) -> bool {
true
}
Expand Down
21 changes: 21 additions & 0 deletions quaint/src/connector/mysql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use mysql_async::{
self as my,
prelude::{Query as _, Queryable as _},
};
use std::borrow::Cow;
use std::{
future::Future,
sync::atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -347,4 +348,24 @@ impl Queryable for Mysql {
fn requires_isolation_first(&self) -> bool {
true
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
}

/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to release a savepoint
fn release_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO savepoint{depth}"))
}
}
25 changes: 22 additions & 3 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use postgres_native_tls::MakeTlsConnector;
use postgres_types::{Kind as PostgresKind, Type as PostgresType};
use prisma_metrics::WithMetricsInstrumentation;
use query::PreparedQuery;
use std::borrow::Cow;
use std::{
fmt::{Debug, Display},
fs,
Expand Down Expand Up @@ -541,9 +542,7 @@ impl<Cache: QueryCache> TransactionCapable for PostgreSql<Cache> {
) -> crate::Result<Box<dyn Transaction + 'a>> {
let opts = TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(
DefaultTransaction::new(self, self.begin_statement(), opts).await?,
))
Ok(Box::new(DefaultTransaction::new(self, opts).await?))
}
}

Expand Down Expand Up @@ -751,6 +750,26 @@ impl<Cache: QueryCache> Queryable for PostgreSql<Cache> {
fn requires_isolation_first(&self) -> bool {
false
}

/// Statement to begin a transaction
fn begin_statement(&self) -> &'static str {
"BEGIN"
}

/// Statement to create a savepoint
fn create_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to release a savepoint
fn release_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint
fn rollback_to_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
}
}

/// Sorted list of CockroachDB's reserved keywords.
Expand Down
19 changes: 18 additions & 1 deletion quaint/src/connector/queryable.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use super::{DescribedQuery, IsolationLevel, ResultSet, Transaction};
use crate::ast::*;
use async_trait::async_trait;
Expand Down Expand Up @@ -94,6 +96,21 @@ pub trait Queryable: Send + Sync {
"BEGIN"
}

/// Statement to create a savepoint in a transaction
fn create_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("SAVEPOINT savepoint{depth}"))
}

/// Statement to release a savepoint in a transaction
fn release_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("RELEASE SAVEPOINT savepoint{depth}"))
}

/// Statement to rollback to a savepoint in a transaction
fn rollback_to_savepoint_statement(&self, depth: i32) -> Cow<'static, str> {
Cow::Owned(format!("ROLLBACK TO SAVEPOINT savepoint{depth}"))
}

/// Sets the transaction isolation level to given value.
/// Implementers have to make sure that the passed isolation level is valid for the underlying database.
async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()>;
Expand Down Expand Up @@ -129,7 +146,7 @@ macro_rules! impl_default_TransactionCapable {
let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first());

Ok(Box::new(
crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?,
crate::connector::DefaultTransaction::new(self, opts).await?,
))
}
}
Expand Down
Loading
Loading