Skip to content

Commit

Permalink
Add ValuesListBoundStatement for handling Values list statements (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanvuong2021 authored Mar 8, 2024
1 parent 66d8614 commit 8dd5f40
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ import org.wfanet.measurement.common.identity.InternalId

/** An SQL statement with bound parameters. */
class BoundStatement
private constructor(
private val baseSql: String,
private val bindings: Collection<Binding>,
) {
private constructor(private val baseSql: String, private val bindings: Collection<Binding>) {
@DslMarker private annotation class DslBuilder

/** Builder for a single statement binding. */
Expand Down Expand Up @@ -60,6 +57,32 @@ private constructor(

/** Binds the parameter named [name] with type [kClass] to `NULL`. */
@PublishedApi internal abstract fun bindNull(name: String, kClass: KClass<*>)

/** Binds the parameter [index] to [value]. */
fun bind(index: Int, value: ExternalId?) = bind(index, value?.value)
/** Binds the parameter [index] to [value]. */
fun bind(index: Int, value: InternalId?) = bind(index, value?.value)
/** Binds the parameter [index] to [value]. */
fun bind(index: Int, value: Message?) =
bind(index, value?.toByteString()?.asReadOnlyByteBuffer())
/** Binds the parameter [index] to [value]. */
fun bind(index: Int, value: ProtocolMessageEnum?) = bind(index, value?.number)

/** Binds the parameter [index] to [value]. */
@JvmName("bindNullableIndex")
inline fun <reified T> bind(index: Int, value: T) {
if (value == null) {
bindNull(index, T::class)
} else {
bind(index, value)
}
}

/** Binds the parameter [index] to [value]. */
abstract fun <T : Any> bind(index: Int, value: T)

/** Binds the parameter [index] with type [kClass] to `NULL`. */
@PublishedApi internal abstract fun bindNull(index: Int, kClass: KClass<*>)
}

/** Builder for a SQL statement, which could be a query. */
Expand Down Expand Up @@ -92,24 +115,37 @@ private constructor(
* ```
*/
abstract fun addBinding(bind: Binder.() -> Unit)

/** Creates a [BoundStatement] from this [Builder]. */
abstract fun build(baseSql: String): BoundStatement
}

private class BinderImpl() : Binder() {
private val values = mutableMapOf<String, Any>()
private val nulls = mutableMapOf<String, Class<out Any?>>()
private class BinderImpl : Binder() {
private val stringIndexValues = mutableMapOf<String, Any>()
private val intIndexValues = mutableMapOf<Int, Any>()
private val stringIndexNulls = mutableMapOf<String, Class<out Any?>>()
private val intIndexNulls = mutableMapOf<Int, Class<out Any?>>()

override fun <T : Any> bind(name: String, value: T) {
values[name] = value
stringIndexValues[name] = value
}

override fun <T : Any> bind(index: Int, value: T) {
intIndexValues[index] = value
}

override fun bindNull(name: String, kClass: KClass<*>) {
nulls[name] = kClass.javaObjectType
stringIndexNulls[name] = kClass.javaObjectType
}

fun build() = Binding(values, nulls)
override fun bindNull(index: Int, kClass: KClass<*>) {
intIndexNulls[index] = kClass.javaObjectType
}

fun build() = Binding(stringIndexValues, intIndexValues, stringIndexNulls, intIndexNulls)
}

private class BuilderImpl(private val baseSql: String) : Builder() {
private class BuilderImpl : Builder() {
private val binders: MutableList<BinderImpl> = mutableListOf()
private var bindable = true
private val initialBinder: BinderImpl
Expand All @@ -132,8 +168,12 @@ private constructor(

override fun bindNull(name: String, kClass: KClass<*>) = initialBinder.bindNull(name, kClass)

override fun <T : Any> bind(index: Int, value: T) = initialBinder.bind(index, value)

override fun bindNull(index: Int, kClass: KClass<*>) = initialBinder.bindNull(index, kClass)

/** Builds a [BoundStatement] from this builder. */
fun build(): BoundStatement {
override fun build(baseSql: String): BoundStatement {
return BoundStatement(baseSql, binders.map { it.build() })
}
}
Expand All @@ -156,22 +196,41 @@ private constructor(

companion object {
internal fun boundStatement(baseSql: String, bind: Builder.() -> Unit): BoundStatement {
return BuilderImpl(baseSql).apply(bind).build()
return BuilderImpl().apply(bind).build(baseSql)
}

internal fun builder(bind: Builder.() -> Unit): Builder {
return BuilderImpl().apply(bind)
}

private fun Statement.apply(binding: Binding) {
for ((name, value) in binding.values) {
for ((name, value) in binding.stringIndexValues) {
bind(name, value)
}
for ((name, type) in binding.nulls) {
for ((index, value) in binding.intIndexValues) {
bind(index, value)
}
for ((name, type) in binding.stringIndexNulls) {
bindNull(name, type)
}
for ((index, type) in binding.intIndexNulls) {
bindNull(index, type)
}
}
}
}

private data class Binding(val values: Map<String, Any>, val nulls: Map<String, Class<out Any?>>)
private data class Binding(
val stringIndexValues: Map<String, Any>,
val intIndexValues: Map<Int, Any>,
val stringIndexNulls: Map<String, Class<out Any?>>,
val intIndexNulls: Map<Int, Class<out Any?>>,
)

/** Builds a [BoundStatement]. */
fun boundStatement(baseSql: String, bind: BoundStatement.Builder.() -> Unit = {}): BoundStatement =
BoundStatement.boundStatement(baseSql, bind)

/** Creates a [BoundStatement.Builder]. */
fun builder(bind: BoundStatement.Builder.() -> Unit = {}): BoundStatement.Builder =
BoundStatement.builder(bind)
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.common.db.r2dbc.postgres

import com.google.protobuf.Message
import com.google.protobuf.ProtocolMessageEnum
import org.wfanet.measurement.common.db.r2dbc.BoundStatement
import org.wfanet.measurement.common.db.r2dbc.builder
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId

/** Contains a Postgres specific builder for [BoundStatement]s with a values list. */
class ValuesListBoundStatement private constructor() {
@DslMarker private annotation class DslBuilder

@DslBuilder
class ValuesListBoundStatementBuilder(
valuesStartIndex: Int,
@PublishedApi internal val paramCount: Int,
@PublishedApi internal val binder: BoundStatement.Binder,
) {
@PublishedApi
internal var valuesCurIndex = valuesStartIndex
private set

fun addValuesBinding(bind: ValuesListBoundStatementBuilder.() -> Unit) {
this.apply(bind)
valuesCurIndex += paramCount
}

fun bindValuesParam(index: Int, value: ExternalId?) =
binder.bind(index + valuesCurIndex, value?.value)

fun bindValuesParam(index: Int, value: InternalId?) =
binder.bind(index + valuesCurIndex, value?.value)

fun bindValuesParam(index: Int, value: Message?) =
binder.bind(index + valuesCurIndex, value?.toByteString()?.asReadOnlyByteBuffer())

fun bindValuesParam(index: Int, value: ProtocolMessageEnum?) =
binder.bind(index + valuesCurIndex, value?.number)

inline fun <reified T : Any> bindValuesParam(index: Int, value: T?) =
binder.bind(index + valuesCurIndex, value)

fun bind(name: String, value: ExternalId?) = binder.bind(name, value?.value)

fun bind(name: String, value: InternalId?) = binder.bind(name, value?.value)

fun bind(name: String, value: Message?) =
binder.bind(name, value?.toByteString()?.asReadOnlyByteBuffer())

fun bind(name: String, value: ProtocolMessageEnum?) = binder.bind(name, value?.number)

inline fun <reified T : Any> bind(name: String, value: T?) {
binder.bind(name, value)
}

fun bind(index: Int, value: ExternalId?) = binder.bind(index, value?.value)

fun bind(index: Int, value: InternalId?) = binder.bind(index, value?.value)

fun bind(index: Int, value: Message?) =
binder.bind(index, value?.toByteString()?.asReadOnlyByteBuffer())

fun bind(index: Int, value: ProtocolMessageEnum?) = binder.bind(index, value?.number)

inline fun <reified T : Any> bind(index: Int, value: T?) {
binder.bind(index, value)
}
}

companion object {
const val VALUES_LIST_PLACEHOLDER = "VALUES_LIST"

internal fun valuesListBoundStatement(
valuesStartIndex: Int,
paramCount: Int,
baseSql: String,
bind: ValuesListBoundStatementBuilder.() -> Unit = {},
): BoundStatement {
var valuesEndIndex = 0
val builder: BoundStatement.Builder = builder {
val builder = ValuesListBoundStatementBuilder(valuesStartIndex, paramCount, this)
builder.apply(bind)
valuesEndIndex = builder.valuesCurIndex
}

val range = valuesStartIndex + 1..valuesEndIndex
val params = range.toList()
val chunkedParams = params.chunked(paramCount)
val valuesList =
chunkedParams.joinToString(separator = ",") { values ->
values.joinToString(prefix = "(", postfix = ")") { value -> "$$value" }
}

return builder.build(baseSql.replace(VALUES_LIST_PLACEHOLDER, valuesList))
}
}
}

/** Builds a [BoundStatement] with a values list. */
fun valuesListBoundStatement(
valuesStartIndex: Int,
paramCount: Int,
baseSql: String,
bind: ValuesListBoundStatement.ValuesListBoundStatementBuilder.() -> Unit = {},
): BoundStatement =
ValuesListBoundStatement.valuesListBoundStatement(
valuesStartIndex = valuesStartIndex,
paramCount = paramCount,
baseSql = baseSql,
bind,
)
Loading

0 comments on commit 8dd5f40

Please sign in to comment.