Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Commit

Permalink
feat(sync): Implicit distributed transactions
Browse files Browse the repository at this point in the history
See #14
  • Loading branch information
CLOVIS-AI committed Aug 15, 2024
1 parent 69f1271 commit 65d27c0
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 23 deletions.
5 changes: 2 additions & 3 deletions driver-sync/src/main/kotlin/FilteredMongoCollection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.mongodb.client.result.UpdateResult
import com.mongodb.kotlin.client.FindIterable
import fr.qsh.ktmongo.dsl.expr.FilterExpression
import fr.qsh.ktmongo.dsl.expr.UpdateExpression
import java.util.concurrent.TimeUnit

private class FilteredMongoCollection<Document : Any>(
private val upstream: MongoCollection<Document>,
Expand All @@ -19,8 +18,8 @@ private class FilteredMongoCollection<Document : Any>(
override fun count(options: CountOptions): Long = upstream.count(options, baseFilter)

// countEstimated is a real count when a filter is present, it's slower but at least it won't break the app
override fun countEstimated(options: EstimatedDocumentCountOptions): Long = upstream.count(
CountOptions().maxTime(options.getMaxTime(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS).comment(options.comment),
override fun countEstimated(options: EstimatedDocumentCountOptions): Long = upstream.countForReal(
options,
baseFilter,
)

Expand Down
13 changes: 13 additions & 0 deletions driver-sync/src/main/kotlin/MongoCollection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.mongodb.client.result.UpdateResult
import com.mongodb.kotlin.client.FindIterable
import fr.qsh.ktmongo.dsl.expr.FilterExpression
import fr.qsh.ktmongo.dsl.expr.UpdateExpression
import java.util.concurrent.TimeUnit

/**
* Parent interface to all collection types provided by KtMongo.
Expand Down Expand Up @@ -136,6 +137,8 @@ sealed interface MongoCollection<Document : Any> {
* Views do not possess the required metadata.
* When this function is called on a view (either a MongoDB view or a [filter] logical view), a regular [count] is executed instead.
*
* When this function is called from within a [transaction], a regular [count] is executed instead.
*
* ### External resources
*
* - [Official documentation](https://www.mongodb.com/docs/manual/reference/method/db.collection.estimatedDocumentCount/)
Expand Down Expand Up @@ -341,3 +344,13 @@ sealed interface MongoCollection<Document : Any> {
// endregion

}

internal fun <Document : Any> MongoCollection<Document>.countForReal(
options: EstimatedDocumentCountOptions,
predicate: FilterExpression<Document>.() -> Unit = {},
) = count(
options = CountOptions()
.maxTime(options.getMaxTime(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS)
.comment(options.comment),
predicate = predicate
)
52 changes: 32 additions & 20 deletions driver-sync/src/main/kotlin/NativeMongoCollection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,30 @@ class NativeMongoCollection<Document : Any>(
// region Find

override fun find(): FindIterable<Document> =
unsafe.find()
when (val session = getCurrentSession()) {
null -> unsafe.find()
else -> unsafe.find(session)
}

override fun find(predicate: FilterExpression<Document>.() -> Unit): FindIterable<Document> {
val bson = FilterExpression<Document>(unsafe.codecRegistry)
.apply(predicate)
.toBsonDocument()

return unsafe.find(bson)
return when (val session = getCurrentSession()) {
null -> unsafe.find(bson)
else -> unsafe.find(session, bson)
}
}

// endregion
// region Count

override fun count(options: CountOptions): Long =
unsafe.countDocuments(options = options)
when (val session = getCurrentSession()) {
null -> unsafe.countDocuments(options = options)
else -> unsafe.countDocuments(session, options = options)
}

override fun count(
options: CountOptions,
Expand All @@ -59,11 +68,17 @@ class NativeMongoCollection<Document : Any>(
.apply(predicate)
.toBsonDocument()

return unsafe.countDocuments(bson, options)
return when (val session = getCurrentSession()) {
null -> unsafe.countDocuments(bson, options)
else -> unsafe.countDocuments(session, bson, options)
}
}

override fun countEstimated(options: EstimatedDocumentCountOptions): Long =
unsafe.estimatedDocumentCount(options)
when (getCurrentSession()) {
null -> unsafe.estimatedDocumentCount(options)
else -> countForReal(options) // Downgrade to a regular count
}

// endregion
// region Update
Expand All @@ -81,11 +96,10 @@ class NativeMongoCollection<Document : Any>(
.apply(update)
.toBsonDocument()

return unsafe.updateOne(
filter = filterBson,
update = updateBson,
options = options,
)
return when (val session = getCurrentSession()) {
null -> unsafe.updateOne(filterBson, updateBson, options)
else -> unsafe.updateOne(session, filterBson, updateBson, options)
}
}

override fun updateMany(
Expand All @@ -101,11 +115,10 @@ class NativeMongoCollection<Document : Any>(
.apply(update)
.toBsonDocument()

return unsafe.updateMany(
filter = filterBson,
update = updateBson,
options = options,
)
return when (val session = getCurrentSession()) {
null -> unsafe.updateMany(filterBson, updateBson, options)
else -> unsafe.updateMany(session, filterBson, updateBson, options)
}
}

override fun findOneAndUpdate(
Expand All @@ -121,11 +134,10 @@ class NativeMongoCollection<Document : Any>(
.apply(update)
.toBsonDocument()

return unsafe.findOneAndUpdate(
filter = filterBson,
update = updateBson,
options = options,
)
return when (val session = getCurrentSession()) {
null -> unsafe.findOneAndUpdate(filterBson, updateBson, options)
else -> unsafe.findOneAndUpdate(session, filterBson, updateBson, options)
}
}

// endregion
Expand Down
109 changes: 109 additions & 0 deletions driver-sync/src/main/kotlin/Transactions.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package fr.qsh.ktmongo.sync

import com.mongodb.ClientSessionOptions
import com.mongodb.kotlin.client.ClientSession
import com.mongodb.kotlin.client.MongoClient

/**
* See [transaction].
*/
interface TransactionScope {

/**
* Commits any uncommitted work from the current [transaction].
*/
fun commit()

/**
* Aborts any uncommitted work from the current [transaction].
*/
fun abort()
}

private class TransactionScopeImpl(
val session: ClientSession
) : TransactionScope {

override fun commit() {
session.commitTransaction()
}

override fun abort() {
session.abortTransaction()
}
}

private val currentClientSession = ThreadLocal<ClientSession>()

/**
* Manages a distributed transaction.
*
* ### Transactions shouldn't be used, most of the time
*
* In MongoDB, documents are always updated atomically.
* This reduces the need for transactions, since all data needing to be updated at once is expected to be in a single
* document.
*
* However, sometimes, distributed transactions are still necessary.
* Note that MongoDB isn't optimized for heavy use of distributed transactions.
*
* ### Usage
*
* **This example purely demonstrates the syntax. It is not a valid situation to use distributed transactions.**
*
* ```kotlin
* val client = MongoClient.create()
* val database = client.getDatabase("test")
*
* val jedi = database.getCollection<Jedi>("jedi").asKtMongo()
* val padawan = database.getCollection<Padawan>("padawan").asKtMongo()
*
* client.transaction {
* padawan.insertOne {
* Padawan::id set 1234
* Padawan::name set "Alexsandr"
* }
*
* jedi.updateOne(
* filter = { Jedi::id eq 967 },
* update = { Jedi::padawans add 1234 }
* )
* }
* ```
*
* ### Behavior
*
* If the block terminates with an exception, the transaction is aborted.
* If the block terminates normally, the transaction is committed.
*
* Alternatively, the [TransactionScope.commit] and [TransactionScope.abort] functions can be called
* to manually commit or abort transactions.
*
* ### External resources
*
* - [Official documentation](https://www.mongodb.com/docs/manual/core/transactions/)
*/
fun <R> MongoClient.transaction(
options: ClientSessionOptions = ClientSessionOptions.builder().build(),
block: TransactionScope.() -> R
): R {
val previousSession: ClientSession? = currentClientSession.get()
val session = startSession(options)

try {
currentClientSession.set(session)
val ret = block(TransactionScopeImpl(session))

session.commitTransaction()
return ret
} catch (e: Throwable) {
session.abortTransaction()
throw e
} finally {
currentClientSession.set(previousSession)
session.close()
}
}

internal fun getCurrentSession(): ClientSession? =
currentClientSession.get()

0 comments on commit 65d27c0

Please sign in to comment.