Skip to content

Commit

Permalink
Add more cases to AsyncDatabaseClientTest. (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas authored Oct 18, 2024
1 parent 3b4b1ac commit 6f79ccf
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.wfanet.measurement.gcloud.spanner
import com.google.cloud.Timestamp
import com.google.cloud.spanner.AsyncResultSet
import com.google.cloud.spanner.AsyncRunner
import com.google.cloud.spanner.CommitResponse
import com.google.cloud.spanner.DatabaseClient
import com.google.cloud.spanner.DatabaseId
import com.google.cloud.spanner.Key
Expand All @@ -25,6 +26,7 @@ import com.google.cloud.spanner.Mutation
import com.google.cloud.spanner.Options.QueryOption
import com.google.cloud.spanner.Options.ReadOption
import com.google.cloud.spanner.ReadContext
import com.google.cloud.spanner.ReadOnlyTransaction
import com.google.cloud.spanner.Spanner
import com.google.cloud.spanner.SpannerException
import com.google.cloud.spanner.Statement
Expand All @@ -47,7 +49,6 @@ import kotlinx.coroutines.channels.ProducerScope
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.singleOrNull
import kotlinx.coroutines.launch
Expand All @@ -73,9 +74,16 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec
return ReadContextImpl(dbClient.singleUse(bound))
}

/** @see DatabaseClient.singleUseReadOnlyTransaction */
fun singleUseReadOnlyTransaction(
bound: TimestampBound = TimestampBound.strong()
): ReadOnlyTransaction {
return ReadOnlyTransactionImpl(dbClient.singleUseReadOnlyTransaction(bound))
}

/** @see DatabaseClient.readOnlyTransaction */
fun readOnlyTransaction(bound: TimestampBound = TimestampBound.strong()): ReadContext {
return ReadContextImpl(dbClient.readOnlyTransaction(bound))
fun readOnlyTransaction(bound: TimestampBound = TimestampBound.strong()): ReadOnlyTransaction {
return ReadOnlyTransactionImpl(dbClient.readOnlyTransaction(bound))
}

/** @see DatabaseClient.readWriteTransaction */
Expand All @@ -85,7 +93,7 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec

/** @see DatabaseClient.write */
suspend fun write(mutations: Iterable<Mutation>) {
readWriteTransaction().execute { txn -> txn.buffer(mutations) }
readWriteTransaction().run { txn -> txn.buffer(mutations) }
}

/** @see DatabaseClient.write */
Expand Down Expand Up @@ -149,20 +157,35 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec
fun executeQuery(statement: Statement, vararg options: QueryOption): Flow<Struct>
}

/** Coroutine version of [AsyncRunner]. */
/** Async coroutine version of [com.google.cloud.spanner.ReadOnlyTransaction]. */
interface ReadOnlyTransaction : ReadContext {
/** @see com.google.cloud.spanner.ReadOnlyTransaction.getReadTimestamp */
val readTimestamp: Timestamp
}

/** Async coroutine version of [com.google.cloud.spanner.TransactionRunner]. */
interface TransactionRunner {
/**
* Executes a read/write transaction asynchronously, suspending until it is complete.
* Executes a read/write transaction with retries as necessary.
*
* @param doWork function that does work inside a transaction
* @see com.google.cloud.spanner.AsyncRunner.runAsync
*
* This acts as a coroutine builder. [doWork] has a [CoroutineScope] receiver to ensure that
* coroutine builders called from it run in the [CoroutineScope] defined by this function.
*
* @see com.google.cloud.spanner.TransactionRunner.run
*/
suspend fun <R> execute(doWork: TransactionWork<R>): R
suspend fun <R> run(doWork: TransactionWork<R>): R

/** Alias for [run]. */
@Deprecated(message = "Use run", replaceWith = ReplaceWith("run(doWork)"))
suspend fun <R> execute(doWork: TransactionWork<R>): R = run(doWork)

/** @see com.google.cloud.spanner.TransactionRunner.getCommitTimestamp */
suspend fun getCommitTimestamp(): Timestamp

/** @see com.google.cloud.spanner.TransactionRunner.getCommitTimestamp */
suspend fun getCommitResponse(): CommitResponse
}

/** Async coroutine version of [com.google.cloud.spanner.TransactionContext]. */
Expand All @@ -177,18 +200,18 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec
suspend fun executeUpdate(statement: Statement): Long
}

private inner class TransactionRunnerImpl(private val runner: AsyncRunner) : TransactionRunner {
override suspend fun <R> execute(doWork: TransactionWork<R>): R {
private inner class TransactionRunnerImpl(private val delegate: AsyncRunner) : TransactionRunner {
override suspend fun <R> run(doWork: TransactionWork<R>): R {
try {
return runner.run(executor, doWork)
return delegate.run(executor, doWork)
} catch (e: SpannerException) {
throw e.wrappedException ?: e
}
}

override suspend fun getCommitTimestamp(): Timestamp {
return runner.commitTimestamp.await()
}
override suspend fun getCommitTimestamp(): Timestamp = delegate.commitTimestamp.await()

override suspend fun getCommitResponse(): CommitResponse = delegate.commitResponse.await()
}

companion object {
Expand Down Expand Up @@ -241,6 +264,13 @@ private class ReadContextImpl(private val readContext: ReadContext) :
}
}

private class ReadOnlyTransactionImpl(private val delegate: ReadOnlyTransaction) :
AsyncDatabaseClient.ReadOnlyTransaction,
AsyncDatabaseClient.ReadContext by ReadContextImpl(delegate) {
override val readTimestamp: Timestamp
get() = delegate.readTimestamp
}

private class TransactionContextImpl(private val txn: TransactionContext) :
AsyncDatabaseClient.TransactionContext, AsyncDatabaseClient.ReadContext by ReadContextImpl(txn) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ fun Statement.Builder.bindJson(paramValuePair: Pair<String, AbstractMessage?>):
return bind(paramName).toProtoJson(value)
}

/** Builds a [Statement]. */
fun statement(sql: String): Statement = Statement.newBuilder(sql).build()

/** Builds a [Statement]. */
inline fun statement(sql: String, bind: Statement.Builder.() -> Unit): Statement =
Statement.newBuilder(sql).apply(bind).build()
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,245 @@ package org.wfanet.measurement.gcloud.spanner
import com.google.cloud.spanner.Struct
import com.google.common.truth.Truth.assertThat
import java.nio.file.Path
import kotlin.test.assertFailsWith
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.debug.junit4.CoroutinesTimeout
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.ClassRule
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.wfanet.measurement.common.CountDownLatch
import org.wfanet.measurement.common.getJarResourcePath
import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorDatabaseRule
import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule

@RunWith(JUnit4::class)
class AsyncDatabaseClientTest {
@get:Rule val database = SpannerEmulatorDatabaseRule(spannerEmulator, CHANGELOG_PATH)
@get:Rule val timeout = CoroutinesTimeout.seconds(5)

private val databaseClient: AsyncDatabaseClient
get() = database.databaseClient

@Test
fun `executes simple query`() {
fun `executeQuery returns result`() {
val results: List<Struct> = runBlocking {
databaseClient.singleUse().executeQuery(statement("SELECT TRUE") {}).toList()
databaseClient.singleUse().executeQuery(statement("SELECT TRUE")).toList()
}

assertThat(results.single().getBoolean(0)).isTrue()
}

@Test
fun `run applies buffered mutations`() {
runBlocking {
databaseClient.readWriteTransaction().run { txn ->
txn.bufferInsertMutation("Cars") {
set("CarId").to(1)
set("Year").to(1990)
set("Make").to("Nissan")
set("Model").to("Stanza")
}
txn.bufferInsertMutation("Cars") {
set("CarId").to(2)
set("Year").to(1997)
set("Make").to("Honda")
set("Model").to("CR-V")
}
}
}

val results: List<Struct> = runBlocking {
databaseClient.singleUse().use { txn ->
txn
.executeQuery(statement("SELECT CarId, Year, Make, Model FROM Cars ORDER BY CarId"))
.toList()
}
}
assertThat(results)
.containsExactly(
struct {
set("CarId").to(1)
set("Year").to(1990)
set("Make").to("Nissan")
set("Model").to("Stanza")
},
struct {
set("CarId").to(2)
set("Year").to(1997)
set("Make").to("Honda")
set("Model").to("CR-V")
},
)
.inOrder()
}

@Test
fun `run executes statement`() {
val statementSql =
"""
INSERT INTO Cars(CarId, Year, Make, Model)
VALUES
(1, 1990, 'Nissan', 'Stanza'),
(2, 1997, 'Honda', 'CR-V')
"""
.trimIndent()

runBlocking {
databaseClient.readWriteTransaction().run { txn ->
txn.executeUpdate(statement(statementSql))
}
}

val results: List<Struct> = runBlocking {
databaseClient.singleUse().use { txn ->
txn
.executeQuery(statement("SELECT CarId, Year, Make, Model FROM Cars ORDER BY CarId"))
.toList()
}
}
assertThat(results)
.containsExactly(
struct {
set("CarId").to(1)
set("Year").to(1990)
set("Make").to("Nissan")
set("Model").to("Stanza")
},
struct {
set("CarId").to(2)
set("Year").to(1997)
set("Make").to("Honda")
set("Model").to("CR-V")
},
)
.inOrder()
}

@Test
fun `run bubbles exceptions from transaction work`() = runBlocking {
val message = "Error inside transaction work"

val exception =
assertFailsWith<Exception> {
databaseClient.readWriteTransaction().run { _ -> throw Exception(message) }
}

assertThat(exception).hasMessageThat().isEqualTo(message)
}

@Test
fun `run can read results within transaction`() = runBlocking {
databaseClient.readWriteTransaction().run { txn ->
txn.bufferInsertMutation("Cars") {
set("CarId").to(1)
set("Year").to(1990)
set("Make").to("Nissan")
set("Model").to("Stanza")
}
txn.bufferInsertMutation("Cars") {
set("CarId").to(2)
set("Year").to(1997)
set("Make").to("Honda")
set("Model").to("CR-V")
}
}

val maxCarId =
databaseClient.readWriteTransaction().run { txn ->
val maxCarId =
txn
.executeQuery(statement("SELECT MAX(CarId) AS MaxCarId FROM Cars"))
.toList()
.single()
.getLong("MaxCarId")
txn.bufferInsertMutation("Cars") {
set("CarId").to(maxCarId + 1)
set("Year").to(2004)
set("Make").to("Infiniti")
set("Model").to("G35")
}
maxCarId + 1
}

assertThat(maxCarId).isEqualTo(3)
}

@Test
fun `run can execute concurrent transactions`() =
runBlocking(Dispatchers.Default) {
coroutineScope {
val latch = CountDownLatch(1)
launch {
databaseClient.readWriteTransaction().run { txn ->
// Use a latch to ensure that transactions are running concurrently and that one will be
// forced to abort and retry.
latch.await()
txn.executeQuery(statement("SELECT * FROM Cars")).toList()

txn.bufferInsertMutation("Cars") {
set("CarId").to(1)
set("Year").to(1990)
set("Make").to("Nissan")
set("Model").to("Stanza")
}
txn.bufferInsertMutation("Cars") {
set("CarId").to(2)
set("Year").to(1997)
set("Make").to("Honda")
set("Model").to("CR-V")
}
}
}

databaseClient.readWriteTransaction().run { txn ->
latch.countDown()
txn.executeQuery(statement("SELECT * FROM Cars")).toList()

txn.bufferInsertMutation("Cars") {
set("CarId").to(3)
set("Year").to(2004)
set("Make").to("Infiniti")
set("Model").to("G35")
}
}
}

val results: List<Struct> =
databaseClient.singleUse().use { txn ->
txn
.executeQuery(statement("SELECT CarId, Year, Make, Model FROM Cars ORDER BY CarId"))
.toList()
}
assertThat(results)
.containsExactly(
struct {
set("CarId").to(1)
set("Year").to(1990)
set("Make").to("Nissan")
set("Model").to("Stanza")
},
struct {
set("CarId").to(2)
set("Year").to(1997)
set("Make").to("Honda")
set("Model").to("CR-V")
},
struct {
set("CarId").to(3)
set("Year").to(2004)
set("Make").to("Infiniti")
set("Model").to("G35")
},
)
.inOrder()
}

companion object {
private const val CHANGELOG_RESOURCE_NAME = "db/spanner/changelog.yaml"
private val CHANGELOG_PATH: Path =
Expand Down
Loading

0 comments on commit 6f79ccf

Please sign in to comment.