diff --git a/mongo/options/transactionoptions.go b/mongo/options/transactionoptions.go index 9270cd20d4..60a51c9329 100644 --- a/mongo/options/transactionoptions.go +++ b/mongo/options/transactionoptions.go @@ -39,6 +39,9 @@ type TransactionOptions struct { // be used in its place to control the amount of time that a single operation can run before returning an error. // MaxCommitTime is ignored if Timeout is set on the client. MaxCommitTime *time.Duration + + // NonRetryableOnTransientErrors indicates whether the transaction should not be retried on transient errors. + NonRetryableOnTransientErrors bool } // Transaction creates a new TransactionOptions instance. diff --git a/mongo/session.go b/mongo/session.go index 288bf63efd..d9964a6b93 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -177,8 +177,16 @@ func (s *sessionImpl) EndSession(ctx context.Context) { } // WithTransaction implements the Session interface. -func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), - opts ...*options.TransactionOptions) (interface{}, error) { +func (s *sessionImpl) WithTransaction( + ctx context.Context, + fn func(ctx SessionContext) (interface{}, error), + opts ...*options.TransactionOptions, +) (interface{}, error) { + var options options.TransactionOptions + if len(opts) > 0 && opts[0] != nil { + options = *opts[0] + } + timeout := time.NewTimer(withTransactionTimeout) defer timeout.Stop() var err error @@ -202,7 +210,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo default: } - if errorHasLabel(err, driver.TransientTransactionError) { + if !options.NonRetryableOnTransientErrors && errorHasLabel(err, driver.TransientTransactionError) { continue } return res, err @@ -247,7 +255,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() { continue } - if cerr.HasErrorLabel(driver.TransientTransactionError) { + if !options.NonRetryableOnTransientErrors && cerr.HasErrorLabel(driver.TransientTransactionError) { break CommitLoop } }