diff --git a/driver-sync/src/main/kotlin/FilteredMongoCollection.kt b/driver-sync/src/main/kotlin/FilteredMongoCollection.kt index 018e32c..9d64b31 100644 --- a/driver-sync/src/main/kotlin/FilteredMongoCollection.kt +++ b/driver-sync/src/main/kotlin/FilteredMongoCollection.kt @@ -8,7 +8,6 @@ import com.mongodb.client.result.UpdateResult import com.mongodb.kotlin.client.FindIterable import fr.qsh.ktmongo.dsl.expr.FilterExpression import fr.qsh.ktmongo.dsl.expr.UpdateExpression -import java.util.concurrent.TimeUnit private class FilteredMongoCollection( private val upstream: MongoCollection, @@ -19,8 +18,8 @@ private class FilteredMongoCollection( override fun count(options: CountOptions): Long = upstream.count(options, baseFilter) // countEstimated is a real count when a filter is present, it's slower but at least it won't break the app - override fun countEstimated(options: EstimatedDocumentCountOptions): Long = upstream.count( - CountOptions().maxTime(options.getMaxTime(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS).comment(options.comment), + override fun countEstimated(options: EstimatedDocumentCountOptions): Long = upstream.countForReal( + options, baseFilter, ) diff --git a/driver-sync/src/main/kotlin/MongoCollection.kt b/driver-sync/src/main/kotlin/MongoCollection.kt index dd9287c..db7e8d9 100644 --- a/driver-sync/src/main/kotlin/MongoCollection.kt +++ b/driver-sync/src/main/kotlin/MongoCollection.kt @@ -8,6 +8,7 @@ import com.mongodb.client.result.UpdateResult import com.mongodb.kotlin.client.FindIterable import fr.qsh.ktmongo.dsl.expr.FilterExpression import fr.qsh.ktmongo.dsl.expr.UpdateExpression +import java.util.concurrent.TimeUnit /** * Parent interface to all collection types provided by KtMongo. @@ -136,6 +137,8 @@ sealed interface MongoCollection { * Views do not possess the required metadata. * When this function is called on a view (either a MongoDB view or a [filter] logical view), a regular [count] is executed instead. * + * When this function is called from within a [transaction], a regular [count] is executed instead. + * * ### External resources * * - [Official documentation](https://www.mongodb.com/docs/manual/reference/method/db.collection.estimatedDocumentCount/) @@ -341,3 +344,13 @@ sealed interface MongoCollection { // endregion } + +internal fun MongoCollection.countForReal( + options: EstimatedDocumentCountOptions, + predicate: FilterExpression.() -> Unit = {}, +) = count( + options = CountOptions() + .maxTime(options.getMaxTime(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS) + .comment(options.comment), + predicate = predicate +) diff --git a/driver-sync/src/main/kotlin/NativeMongoCollection.kt b/driver-sync/src/main/kotlin/NativeMongoCollection.kt index ea99a33..7ccb7ea 100644 --- a/driver-sync/src/main/kotlin/NativeMongoCollection.kt +++ b/driver-sync/src/main/kotlin/NativeMongoCollection.kt @@ -35,21 +35,30 @@ class NativeMongoCollection( // region Find override fun find(): FindIterable = - unsafe.find() + when (val session = getCurrentSession()) { + null -> unsafe.find() + else -> unsafe.find(session) + } override fun find(predicate: FilterExpression.() -> Unit): FindIterable { val bson = FilterExpression(unsafe.codecRegistry) .apply(predicate) .toBsonDocument() - return unsafe.find(bson) + return when (val session = getCurrentSession()) { + null -> unsafe.find(bson) + else -> unsafe.find(session, bson) + } } // endregion // region Count override fun count(options: CountOptions): Long = - unsafe.countDocuments(options = options) + when (val session = getCurrentSession()) { + null -> unsafe.countDocuments(options = options) + else -> unsafe.countDocuments(session, options = options) + } override fun count( options: CountOptions, @@ -59,11 +68,17 @@ class NativeMongoCollection( .apply(predicate) .toBsonDocument() - return unsafe.countDocuments(bson, options) + return when (val session = getCurrentSession()) { + null -> unsafe.countDocuments(bson, options) + else -> unsafe.countDocuments(session, bson, options) + } } override fun countEstimated(options: EstimatedDocumentCountOptions): Long = - unsafe.estimatedDocumentCount(options) + when (getCurrentSession()) { + null -> unsafe.estimatedDocumentCount(options) + else -> countForReal(options) // Downgrade to a regular count + } // endregion // region Update @@ -81,11 +96,10 @@ class NativeMongoCollection( .apply(update) .toBsonDocument() - return unsafe.updateOne( - filter = filterBson, - update = updateBson, - options = options, - ) + return when (val session = getCurrentSession()) { + null -> unsafe.updateOne(filterBson, updateBson, options) + else -> unsafe.updateOne(session, filterBson, updateBson, options) + } } override fun updateMany( @@ -101,11 +115,10 @@ class NativeMongoCollection( .apply(update) .toBsonDocument() - return unsafe.updateMany( - filter = filterBson, - update = updateBson, - options = options, - ) + return when (val session = getCurrentSession()) { + null -> unsafe.updateMany(filterBson, updateBson, options) + else -> unsafe.updateMany(session, filterBson, updateBson, options) + } } override fun findOneAndUpdate( @@ -121,11 +134,10 @@ class NativeMongoCollection( .apply(update) .toBsonDocument() - return unsafe.findOneAndUpdate( - filter = filterBson, - update = updateBson, - options = options, - ) + return when (val session = getCurrentSession()) { + null -> unsafe.findOneAndUpdate(filterBson, updateBson, options) + else -> unsafe.findOneAndUpdate(session, filterBson, updateBson, options) + } } // endregion diff --git a/driver-sync/src/main/kotlin/Transactions.kt b/driver-sync/src/main/kotlin/Transactions.kt new file mode 100644 index 0000000..ba39421 --- /dev/null +++ b/driver-sync/src/main/kotlin/Transactions.kt @@ -0,0 +1,109 @@ +package fr.qsh.ktmongo.sync + +import com.mongodb.ClientSessionOptions +import com.mongodb.kotlin.client.ClientSession +import com.mongodb.kotlin.client.MongoClient + +/** + * See [transaction]. + */ +interface TransactionScope { + + /** + * Commits any uncommitted work from the current [transaction]. + */ + fun commit() + + /** + * Aborts any uncommitted work from the current [transaction]. + */ + fun abort() +} + +private class TransactionScopeImpl( + val session: ClientSession +) : TransactionScope { + + override fun commit() { + session.commitTransaction() + } + + override fun abort() { + session.abortTransaction() + } +} + +private val currentClientSession = ThreadLocal() + +/** + * Manages a distributed transaction. + * + * ### Transactions shouldn't be used, most of the time + * + * In MongoDB, documents are always updated atomically. + * This reduces the need for transactions, since all data needing to be updated at once is expected to be in a single + * document. + * + * However, sometimes, distributed transactions are still necessary. + * Note that MongoDB isn't optimized for heavy use of distributed transactions. + * + * ### Usage + * + * **This example purely demonstrates the syntax. It is not a valid situation to use distributed transactions.** + * + * ```kotlin + * val client = MongoClient.create() + * val database = client.getDatabase("test") + * + * val jedi = database.getCollection("jedi").asKtMongo() + * val padawan = database.getCollection("padawan").asKtMongo() + * + * client.transaction { + * padawan.insertOne { + * Padawan::id set 1234 + * Padawan::name set "Alexsandr" + * } + * + * jedi.updateOne( + * filter = { Jedi::id eq 967 }, + * update = { Jedi::padawans add 1234 } + * ) + * } + * ``` + * + * ### Behavior + * + * If the block terminates with an exception, the transaction is aborted. + * If the block terminates normally, the transaction is committed. + * + * Alternatively, the [TransactionScope.commit] and [TransactionScope.abort] functions can be called + * to manually commit or abort transactions. + * + * ### External resources + * + * - [Official documentation](https://www.mongodb.com/docs/manual/core/transactions/) + */ +fun MongoClient.transaction( + options: ClientSessionOptions = ClientSessionOptions.builder().build(), + block: TransactionScope.() -> R +): R { + val previousSession: ClientSession? = currentClientSession.get() + val session = startSession(options) + + try { + currentClientSession.set(session) + val ret = block(TransactionScopeImpl(session)) + + session.commitTransaction() + return ret + } catch (e: Throwable) { + session.abortTransaction() + throw e + } finally { + currentClientSession.set(previousSession) + session.close() + } +} + +internal fun getCurrentSession(): ClientSession? = + currentClientSession.get()