diff --git a/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClient.kt b/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClient.kt index a9483f84b..a77939ef2 100644 --- a/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClient.kt +++ b/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClient.kt @@ -17,6 +17,7 @@ package org.wfanet.measurement.gcloud.spanner import com.google.cloud.Timestamp import com.google.cloud.spanner.AsyncResultSet import com.google.cloud.spanner.AsyncRunner +import com.google.cloud.spanner.CommitResponse import com.google.cloud.spanner.DatabaseClient import com.google.cloud.spanner.DatabaseId import com.google.cloud.spanner.Key @@ -25,6 +26,7 @@ import com.google.cloud.spanner.Mutation import com.google.cloud.spanner.Options.QueryOption import com.google.cloud.spanner.Options.ReadOption import com.google.cloud.spanner.ReadContext +import com.google.cloud.spanner.ReadOnlyTransaction import com.google.cloud.spanner.Spanner import com.google.cloud.spanner.SpannerException import com.google.cloud.spanner.Statement @@ -47,7 +49,6 @@ import kotlinx.coroutines.channels.ProducerScope import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.buffer import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.flow.singleOrNull import kotlinx.coroutines.launch @@ -73,9 +74,16 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec return ReadContextImpl(dbClient.singleUse(bound)) } + /** @see DatabaseClient.singleUseReadOnlyTransaction */ + fun singleUseReadOnlyTransaction( + bound: TimestampBound = TimestampBound.strong() + ): ReadOnlyTransaction { + return ReadOnlyTransactionImpl(dbClient.singleUseReadOnlyTransaction(bound)) + } + /** @see DatabaseClient.readOnlyTransaction */ - fun readOnlyTransaction(bound: TimestampBound = TimestampBound.strong()): ReadContext { - return ReadContextImpl(dbClient.readOnlyTransaction(bound)) + fun readOnlyTransaction(bound: TimestampBound = TimestampBound.strong()): ReadOnlyTransaction { + return ReadOnlyTransactionImpl(dbClient.readOnlyTransaction(bound)) } /** @see DatabaseClient.readWriteTransaction */ @@ -85,7 +93,7 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec /** @see DatabaseClient.write */ suspend fun write(mutations: Iterable) { - readWriteTransaction().execute { txn -> txn.buffer(mutations) } + readWriteTransaction().run { txn -> txn.buffer(mutations) } } /** @see DatabaseClient.write */ @@ -149,20 +157,35 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec fun executeQuery(statement: Statement, vararg options: QueryOption): Flow } - /** Coroutine version of [AsyncRunner]. */ + /** Async coroutine version of [com.google.cloud.spanner.ReadOnlyTransaction]. */ + interface ReadOnlyTransaction : ReadContext { + /** @see com.google.cloud.spanner.ReadOnlyTransaction.getReadTimestamp */ + val readTimestamp: Timestamp + } + + /** Async coroutine version of [com.google.cloud.spanner.TransactionRunner]. */ interface TransactionRunner { /** - * Executes a read/write transaction asynchronously, suspending until it is complete. + * Executes a read/write transaction with retries as necessary. * * @param doWork function that does work inside a transaction - * @see com.google.cloud.spanner.AsyncRunner.runAsync * * This acts as a coroutine builder. [doWork] has a [CoroutineScope] receiver to ensure that * coroutine builders called from it run in the [CoroutineScope] defined by this function. + * + * @see com.google.cloud.spanner.TransactionRunner.run */ - suspend fun execute(doWork: TransactionWork): R + suspend fun run(doWork: TransactionWork): R + + /** Alias for [run]. */ + @Deprecated(message = "Use run", replaceWith = ReplaceWith("run(doWork)")) + suspend fun execute(doWork: TransactionWork): R = run(doWork) + /** @see com.google.cloud.spanner.TransactionRunner.getCommitTimestamp */ suspend fun getCommitTimestamp(): Timestamp + + /** @see com.google.cloud.spanner.TransactionRunner.getCommitTimestamp */ + suspend fun getCommitResponse(): CommitResponse } /** Async coroutine version of [com.google.cloud.spanner.TransactionContext]. */ @@ -177,18 +200,18 @@ class AsyncDatabaseClient(private val dbClient: DatabaseClient, private val exec suspend fun executeUpdate(statement: Statement): Long } - private inner class TransactionRunnerImpl(private val runner: AsyncRunner) : TransactionRunner { - override suspend fun execute(doWork: TransactionWork): R { + private inner class TransactionRunnerImpl(private val delegate: AsyncRunner) : TransactionRunner { + override suspend fun run(doWork: TransactionWork): R { try { - return runner.run(executor, doWork) + return delegate.run(executor, doWork) } catch (e: SpannerException) { throw e.wrappedException ?: e } } - override suspend fun getCommitTimestamp(): Timestamp { - return runner.commitTimestamp.await() - } + override suspend fun getCommitTimestamp(): Timestamp = delegate.commitTimestamp.await() + + override suspend fun getCommitResponse(): CommitResponse = delegate.commitResponse.await() } companion object { @@ -241,6 +264,13 @@ private class ReadContextImpl(private val readContext: ReadContext) : } } +private class ReadOnlyTransactionImpl(private val delegate: ReadOnlyTransaction) : + AsyncDatabaseClient.ReadOnlyTransaction, + AsyncDatabaseClient.ReadContext by ReadContextImpl(delegate) { + override val readTimestamp: Timestamp + get() = delegate.readTimestamp +} + private class TransactionContextImpl(private val txn: TransactionContext) : AsyncDatabaseClient.TransactionContext, AsyncDatabaseClient.ReadContext by ReadContextImpl(txn) { diff --git a/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/Statements.kt b/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/Statements.kt index d33eb12a3..4b5736a6a 100644 --- a/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/Statements.kt +++ b/src/main/kotlin/org/wfanet/measurement/gcloud/spanner/Statements.kt @@ -131,6 +131,9 @@ fun Statement.Builder.bindJson(paramValuePair: Pair): return bind(paramName).toProtoJson(value) } +/** Builds a [Statement]. */ +fun statement(sql: String): Statement = Statement.newBuilder(sql).build() + /** Builds a [Statement]. */ inline fun statement(sql: String, bind: Statement.Builder.() -> Unit): Statement = Statement.newBuilder(sql).apply(bind).build() diff --git a/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClientTest.kt b/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClientTest.kt index f0b85bf96..a3dd73271 100644 --- a/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClientTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/AsyncDatabaseClientTest.kt @@ -19,13 +19,19 @@ package org.wfanet.measurement.gcloud.spanner import com.google.cloud.spanner.Struct import com.google.common.truth.Truth.assertThat import java.nio.file.Path +import kotlin.test.assertFailsWith +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.debug.junit4.CoroutinesTimeout import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.junit.ClassRule import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import org.wfanet.measurement.common.CountDownLatch import org.wfanet.measurement.common.getJarResourcePath import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorDatabaseRule import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule @@ -33,19 +39,225 @@ import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule @RunWith(JUnit4::class) class AsyncDatabaseClientTest { @get:Rule val database = SpannerEmulatorDatabaseRule(spannerEmulator, CHANGELOG_PATH) + @get:Rule val timeout = CoroutinesTimeout.seconds(5) private val databaseClient: AsyncDatabaseClient get() = database.databaseClient @Test - fun `executes simple query`() { + fun `executeQuery returns result`() { val results: List = runBlocking { - databaseClient.singleUse().executeQuery(statement("SELECT TRUE") {}).toList() + databaseClient.singleUse().executeQuery(statement("SELECT TRUE")).toList() } assertThat(results.single().getBoolean(0)).isTrue() } + @Test + fun `run applies buffered mutations`() { + runBlocking { + databaseClient.readWriteTransaction().run { txn -> + txn.bufferInsertMutation("Cars") { + set("CarId").to(1) + set("Year").to(1990) + set("Make").to("Nissan") + set("Model").to("Stanza") + } + txn.bufferInsertMutation("Cars") { + set("CarId").to(2) + set("Year").to(1997) + set("Make").to("Honda") + set("Model").to("CR-V") + } + } + } + + val results: List = runBlocking { + databaseClient.singleUse().use { txn -> + txn + .executeQuery(statement("SELECT CarId, Year, Make, Model FROM Cars ORDER BY CarId")) + .toList() + } + } + assertThat(results) + .containsExactly( + struct { + set("CarId").to(1) + set("Year").to(1990) + set("Make").to("Nissan") + set("Model").to("Stanza") + }, + struct { + set("CarId").to(2) + set("Year").to(1997) + set("Make").to("Honda") + set("Model").to("CR-V") + }, + ) + .inOrder() + } + + @Test + fun `run executes statement`() { + val statementSql = + """ + INSERT INTO Cars(CarId, Year, Make, Model) + VALUES + (1, 1990, 'Nissan', 'Stanza'), + (2, 1997, 'Honda', 'CR-V') + """ + .trimIndent() + + runBlocking { + databaseClient.readWriteTransaction().run { txn -> + txn.executeUpdate(statement(statementSql)) + } + } + + val results: List = runBlocking { + databaseClient.singleUse().use { txn -> + txn + .executeQuery(statement("SELECT CarId, Year, Make, Model FROM Cars ORDER BY CarId")) + .toList() + } + } + assertThat(results) + .containsExactly( + struct { + set("CarId").to(1) + set("Year").to(1990) + set("Make").to("Nissan") + set("Model").to("Stanza") + }, + struct { + set("CarId").to(2) + set("Year").to(1997) + set("Make").to("Honda") + set("Model").to("CR-V") + }, + ) + .inOrder() + } + + @Test + fun `run bubbles exceptions from transaction work`() = runBlocking { + val message = "Error inside transaction work" + + val exception = + assertFailsWith { + databaseClient.readWriteTransaction().run { _ -> throw Exception(message) } + } + + assertThat(exception).hasMessageThat().isEqualTo(message) + } + + @Test + fun `run can read results within transaction`() = runBlocking { + databaseClient.readWriteTransaction().run { txn -> + txn.bufferInsertMutation("Cars") { + set("CarId").to(1) + set("Year").to(1990) + set("Make").to("Nissan") + set("Model").to("Stanza") + } + txn.bufferInsertMutation("Cars") { + set("CarId").to(2) + set("Year").to(1997) + set("Make").to("Honda") + set("Model").to("CR-V") + } + } + + val maxCarId = + databaseClient.readWriteTransaction().run { txn -> + val maxCarId = + txn + .executeQuery(statement("SELECT MAX(CarId) AS MaxCarId FROM Cars")) + .toList() + .single() + .getLong("MaxCarId") + txn.bufferInsertMutation("Cars") { + set("CarId").to(maxCarId + 1) + set("Year").to(2004) + set("Make").to("Infiniti") + set("Model").to("G35") + } + maxCarId + 1 + } + + assertThat(maxCarId).isEqualTo(3) + } + + @Test + fun `run can execute concurrent transactions`() = + runBlocking(Dispatchers.Default) { + coroutineScope { + val latch = CountDownLatch(1) + launch { + databaseClient.readWriteTransaction().run { txn -> + // Use a latch to ensure that transactions are running concurrently and that one will be + // forced to abort and retry. + latch.await() + txn.executeQuery(statement("SELECT * FROM Cars")).toList() + + txn.bufferInsertMutation("Cars") { + set("CarId").to(1) + set("Year").to(1990) + set("Make").to("Nissan") + set("Model").to("Stanza") + } + txn.bufferInsertMutation("Cars") { + set("CarId").to(2) + set("Year").to(1997) + set("Make").to("Honda") + set("Model").to("CR-V") + } + } + } + + databaseClient.readWriteTransaction().run { txn -> + latch.countDown() + txn.executeQuery(statement("SELECT * FROM Cars")).toList() + + txn.bufferInsertMutation("Cars") { + set("CarId").to(3) + set("Year").to(2004) + set("Make").to("Infiniti") + set("Model").to("G35") + } + } + } + + val results: List = + databaseClient.singleUse().use { txn -> + txn + .executeQuery(statement("SELECT CarId, Year, Make, Model FROM Cars ORDER BY CarId")) + .toList() + } + assertThat(results) + .containsExactly( + struct { + set("CarId").to(1) + set("Year").to(1990) + set("Make").to("Nissan") + set("Model").to("Stanza") + }, + struct { + set("CarId").to(2) + set("Year").to(1997) + set("Make").to("Honda") + set("Model").to("CR-V") + }, + struct { + set("CarId").to(3) + set("Year").to(2004) + set("Make").to("Infiniti") + set("Model").to("G35") + }, + ) + .inOrder() + } + companion object { private const val CHANGELOG_RESOURCE_NAME = "db/spanner/changelog.yaml" private val CHANGELOG_PATH: Path = diff --git a/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/BUILD.bazel index 4afdea350..7390dec4d 100644 --- a/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/gcloud/spanner/BUILD.bazel @@ -2,6 +2,7 @@ load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") kt_jvm_test( name = "AsyncDatabaseClientTest", + timeout = "short", srcs = ["AsyncDatabaseClientTest.kt"], resources = ["//src/test/resources/db/spanner"], tags = [ @@ -13,7 +14,9 @@ kt_jvm_test( "//imports/java/com/google/cloud/spanner", "//imports/java/com/google/common/truth", "//imports/java/org/junit", + "//imports/kotlin/kotlin/test", "//imports/kotlin/kotlinx/coroutines:core", + "//imports/kotlin/kotlinx/coroutines/debug", "//src/main/kotlin/org/wfanet/measurement/common", "//src/main/kotlin/org/wfanet/measurement/gcloud/spanner", "//src/main/kotlin/org/wfanet/measurement/gcloud/spanner/testing", diff --git a/src/test/resources/db/spanner/BUILD.bazel b/src/test/resources/db/spanner/BUILD.bazel index d71bdc4b6..a55b479c6 100644 --- a/src/test/resources/db/spanner/BUILD.bazel +++ b/src/test/resources/db/spanner/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_proto//proto:defs.bzl", "proto_descriptor_set") + package( default_testonly = True, default_visibility = ["//visibility:public"], @@ -7,5 +9,15 @@ filegroup( name = "spanner", srcs = glob([ "*.yaml", + "*.sql", ]), ) + +proto_descriptor_set( + name = "proto_descriptor_set", + visibility = ["//visibility:private"], + deps = [ + "@com_google_googleapis//google/type:dayofweek_proto", + "@com_google_googleapis//google/type:latlng_proto", + ], +) diff --git a/src/test/resources/db/spanner/changelog.yaml b/src/test/resources/db/spanner/changelog.yaml index 5dde243a3..a8919178c 100644 --- a/src/test/resources/db/spanner/changelog.yaml +++ b/src/test/resources/db/spanner/changelog.yaml @@ -17,4 +17,7 @@ databaseChangeLog: - preConditions: onFail: HALT - onError: HALT \ No newline at end of file + onError: HALT + - include: + file: create-test-schema.sql + relativeToChangeLogFile: true \ No newline at end of file diff --git a/src/test/resources/db/spanner/create-test-schema.sql b/src/test/resources/db/spanner/create-test-schema.sql new file mode 100644 index 000000000..99140940b --- /dev/null +++ b/src/test/resources/db/spanner/create-test-schema.sql @@ -0,0 +1,40 @@ +-- liquibase formatted sql + +-- Copyright 2024 The Cross-Media Measurement Authors +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- changeset sanjayvas:create-cars-table dbms:cloudspanner + +-- Set protobuf FileDescriptorSet as a base64 string. This gets applied to the next DDL batch. +SET PROTO_DESCRIPTORS = 'CqQCChtnb29nbGUvdHlwZS9kYXlvZndlZWsucHJvdG8SC2dvb2dsZS50eXBlKoQBCglEYXlPZldlZWsSGwoXREFZX09GX1dFRUtfVU5TUEVDSUZJRUQQABIKCgZNT05EQVkQARILCgdUVUVTREFZEAISDQoJV0VETkVTREFZEAMSDAoIVEhVUlNEQVkQBBIKCgZGUklEQVkQBRIMCghTQVRVUkRBWRAGEgoKBlNVTkRBWRAHQmkKD2NvbS5nb29nbGUudHlwZUIORGF5T2ZXZWVrUHJvdG9QAVo+Z29vZ2xlLmdvbGFuZy5vcmcvZ2VucHJvdG8vZ29vZ2xlYXBpcy90eXBlL2RheW9md2VlaztkYXlvZndlZWuiAgNHVFBiBnByb3RvMwrYAQoYZ29vZ2xlL3R5cGUvbGF0bG5nLnByb3RvEgtnb29nbGUudHlwZSJCCgZMYXRMbmcSGgoIbGF0aXR1ZGUYASABKAFSCGxhdGl0dWRlEhwKCWxvbmdpdHVkZRgCIAEoAVIJbG9uZ2l0dWRlQmMKD2NvbS5nb29nbGUudHlwZUILTGF0TG5nUHJvdG9QAVo4Z29vZ2xlLmdvbGFuZy5vcmcvZ2VucHJvdG8vZ29vZ2xlYXBpcy90eXBlL2xhdGxuZztsYXRsbmf4AQGiAgNHVFBiBnByb3RvMw=='; + +START BATCH DDL; + +CREATE PROTO BUNDLE( + `google.type.DayOfWeek`, + `google.type.LatLng` +); + +CREATE TABLE Cars ( + CarId INT64 NOT NULL, + Year INT64 NOT NULL, + Make STRING(MAX) NOT NULL, + Model STRING(MAX) NOT NULL, + Owner STRING(MAX), + + CurrentLocation `google.type.LatLng`, + WeeklyWashDay `google.type.DayOfWeek` NOT NULL DEFAULT (0), +) PRIMARY KEY (CarId); + +RUN BATCH; \ No newline at end of file