Skip to content

Commit

Permalink
extract DelegateQuery and DelegateQueryable
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Dec 11, 2023
1 parent e9aacf2 commit e39e4a8
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 60 deletions.
2 changes: 1 addition & 1 deletion scalasql/core/src/DbApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ object DbApi {
fileName,
lineNum
)
if (qr.singleRow(query)) {
if (qr.isSingleRow(query)) {
val results = res.take(2).toVector
assert(
results.size == 1,
Expand Down
4 changes: 2 additions & 2 deletions scalasql/core/src/Queryable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ trait Queryable[-Q, R] {
* Whether this query expects a single row to be returned, if so we can assert on
* the number of rows and raise an error if 0 rows or 2+ rows are present
*/
def singleRow(q: Q): Boolean
def isSingleRow(q: Q): Boolean

/**
* Converts the given queryable value into a [[SqlStr]], that can then be executed
Expand Down Expand Up @@ -83,7 +83,7 @@ object Queryable {
*/
trait Row[Q, R] extends Queryable[Q, R] {
def isExecuteUpdate(q: Q): Boolean = false
def singleRow(q: Q): Boolean = true
def isSingleRow(q: Q): Boolean = true
def walkLabels(): Seq[List[String]]
def walkLabels(q: Q): Seq[List[String]] = walkLabels()

Expand Down
11 changes: 4 additions & 7 deletions scalasql/query/src/Aggregate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ import scalasql.core.{Queryable, Expr, SqlStr, TypeMapper, Context}
class Aggregate[Q, R](
toSqlStr0: Context => SqlStr,
construct0: Queryable.ResultSetIterator => R,
expr: Q
)(
qr: Queryable[Q, R]
) extends Query[R] {
protected val expr: Q,
protected val qr: Queryable[Q, R]
) extends Query.DelegateQueryable[Q, R] {

protected def queryWalkLabels() = qr.walkLabels(expr)
protected def queryWalkExprs() = qr.walkExprs(expr)
protected def queryIsSingleRow: Boolean = true
protected override def queryIsSingleRow: Boolean = true
protected def renderSql(ctx: Context) = toSqlStr0(ctx)

override protected def queryConstruct(args: Queryable.ResultSetIterator): R = construct0(args)
Expand Down
13 changes: 4 additions & 9 deletions scalasql/query/src/OnConflict.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ object OnConflict {
query: Query[R] with InsertReturnable[Q],
columns: Seq[Column[_]],
val table: TableRef
) extends Query[R]
) extends Query.DelegateQuery[R]
with InsertReturnable[Q] {
protected def expr = WithSqlExpr.get(query)
protected def queryWalkLabels() = Query.walkLabels(query)
protected def queryWalkExprs() = Query.walkSqlExprs(query)
protected def queryIsSingleRow = Query.isSingleRow(query)
protected def queryDelegate = query
protected def renderSql(ctx: Context) = {
val str = Renderable.renderSql(query)(ctx)
str + sql" ON CONFLICT (${SqlStr.join(columns.map(c => SqlStr.raw(c.name)), SqlStr.commaSep)}) DO NOTHING"
Expand All @@ -40,13 +38,10 @@ object OnConflict {
columns: Seq[Column[_]],
updates: Seq[Column.Assignment[_]],
val table: TableRef
) extends Query[R]
) extends Query.DelegateQuery[R]
with InsertReturnable[Q] {
protected def expr = WithSqlExpr.get(query)

protected def queryWalkLabels() = Query.walkLabels(query)
protected def queryWalkExprs() = Query.walkSqlExprs(query)
protected def queryIsSingleRow = Query.isSingleRow(query)
protected def queryDelegate = query
protected def renderSql(ctx: Context) = {
implicit val implicitCtx = Context.compute(ctx, Nil, Some(table))
val str = Renderable.renderSql(query)
Expand Down
32 changes: 21 additions & 11 deletions scalasql/query/src/Query.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package scalasql.query

import scalasql.core.SqlStr.Renderable
import scalasql.core.{Queryable, SqlStr, Expr}
import scalasql.core.Context
import scalasql.core.{Context, Expr, Queryable, SqlStr, WithSqlExpr}

/**
* A SQL Query, either a [[Query.Multiple]] that returns multiple rows, or
Expand All @@ -25,6 +24,22 @@ object Query {
protected override def queryIsExecuteUpdate = true
}

trait DelegateQuery[R] extends scalasql.query.Query[R] {
protected def queryDelegate: Query[_]
protected def queryWalkLabels() = queryDelegate.queryWalkLabels()
protected def queryWalkExprs() = queryDelegate.queryWalkExprs()
protected override def queryIsSingleRow = queryDelegate.queryIsSingleRow
protected override def queryIsExecuteUpdate = queryDelegate.queryIsExecuteUpdate
}

trait DelegateQueryable[Q, R] extends scalasql.query.Query[R] with WithSqlExpr[Q] {
protected def qr: Queryable[Q, _]
protected def queryWalkLabels() = qr.walkLabels(expr)
protected def queryWalkExprs() = qr.walkExprs(expr)
protected override def queryIsSingleRow = qr.isSingleRow(expr)
protected override def queryIsExecuteUpdate = qr.isExecuteUpdate(expr)
}

implicit def QueryQueryable[R]: Queryable[Query[R], R] = new QueryQueryable[Query[R], R]()

def walkLabels[R](q: Query[R]) = q.queryWalkLabels()
Expand All @@ -35,21 +50,16 @@ object Query {
override def isExecuteUpdate(q: Q) = q.queryIsExecuteUpdate
override def walkLabels(q: Q) = q.queryWalkLabels()
override def walkExprs(q: Q) = q.queryWalkExprs()
override def singleRow(q: Q) = q.queryIsSingleRow
override def isSingleRow(q: Q) = q.queryIsSingleRow

def toSqlStr(q: Q, ctx: Context): SqlStr = q.renderSql(ctx)

override def construct(q: Q, args: Queryable.ResultSetIterator): R = q.queryConstruct(args)
}

trait Multiple[R] extends Query[Seq[R]]

class Single[R](query: Multiple[R]) extends Query[R] {
override def queryIsExecuteUpdate = query.queryIsExecuteUpdate
protected def queryWalkLabels() = query.queryWalkLabels()
protected def queryWalkExprs() = query.queryWalkExprs()

protected def queryIsSingleRow: Boolean = true
class Single[R](query: Query[Seq[R]]) extends Query.DelegateQuery[R] {
protected def queryDelegate = query
protected override def queryIsSingleRow: Boolean = true

protected def renderSql(ctx: Context): SqlStr = Renderable.renderSql(query)(ctx)
protected override def queryConstruct(args: Queryable.ResultSetIterator): R =
Expand Down
23 changes: 10 additions & 13 deletions scalasql/query/src/Returning.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,42 @@ trait InsertReturnable[Q] extends Returnable[Q]
/**
* A query with a `RETURNING` clause
*/
trait Returning[Q, R] extends Query.Multiple[R] {
trait Returning[Q, R] extends Query[Seq[R]] with Query.DelegateQueryable[Q, Seq[R]] {
def single: Query.Single[R] = new Query.Single(this)
}

object InsertReturning {
class Impl[Q, R](returnable: InsertReturnable[_], returning: Q)(
implicit val qr: Queryable.Row[Q, R]
implicit qr: Queryable.Row[Q, R]
) extends Returning.Impl0[Q, R](qr, returnable, returning)
with Returning[Q, R] {
protected def expr: Q = returning
}
with Returning[Q, R] {}
}
object Returning {
class Impl0[Q, R](qr: Queryable.Row[Q, R], returnable: Returnable[_], returning: Q)
extends Returning[Q, R] {
class Impl0[Q, R](
protected val qr: Queryable.Row[Q, R],
returnable: Returnable[_],
protected val expr: Q
) extends Returning[Q, R] {

override protected def queryConstruct(args: Queryable.ResultSetIterator): Seq[R] = {
Seq(qr.construct(args))
}

protected def queryWalkLabels() = qr.walkLabels(returning)

protected def queryWalkExprs() = qr.walkExprs(returning)

override def queryIsSingleRow = false

protected override def renderSql(ctx0: Context) = {
implicit val implicitCtx = Context.compute(ctx0, Nil, Some(returnable.table))

val prefix = Renderable.renderSql(returnable)
val walked = qr.walkLabelsAndExprs(returning)
val walked = qr.walkLabelsAndExprs(expr)
val exprStr = ExprsToSql.apply(walked, implicitCtx, SqlStr.empty)
val suffix = sql" RETURNING $exprStr"

prefix + suffix
}

}
class Impl[Q, R](returnable: Returnable[_], returning: Q)(implicit val qr: Queryable.Row[Q, R])
class Impl[Q, R](returnable: Returnable[_], returning: Q)(implicit qr: Queryable.Row[Q, R])
extends Impl0[Q, R](qr, returnable, returning)
with Returning[Q, R]

Expand Down
5 changes: 2 additions & 3 deletions scalasql/query/src/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ trait Select[Q, R]
with Aggregatable[Q]
with Joinable[Q, R]
with JoinOps[Select, Q, R]
with Query.Multiple[R]
with Query[Seq[R]]
with Query.DelegateQueryable[Q, Seq[R]]
with SelectBase {

protected def dialect: DialectTypeMappers
Expand Down Expand Up @@ -191,8 +192,6 @@ trait Select[Q, R]

renderer.render(LiveSqlExprs.none).withCompleteQuery(true)
}
protected def queryWalkLabels() = qr.walkLabels(expr)
protected def queryWalkExprs() = qr.walkExprs(expr)
protected override def queryIsSingleRow = false

/**
Expand Down
5 changes: 3 additions & 2 deletions scalasql/query/src/SimpleSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ class SimpleSelect[Q, R](
new Aggregate[E, V](
implicit ctx => copied.renderSql(ctx),
r => Query.construct(copied, r).head,
selectProxyExpr
)(qr)
selectProxyExpr,
qr
)
}

def mapAggregate[Q2, R2](
Expand Down
9 changes: 3 additions & 6 deletions scalasql/query/src/Values.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ import scalasql.core.SqlStr.{Renderable, SqlStringSyntax}
class Values[Q, R](val ts: Seq[R])(
implicit val qr: Queryable.Row[Q, R],
protected val dialect: DialectTypeMappers
) extends Select.Proxy[Q, R] {
) extends Select.Proxy[Q, R]
with Query.DelegateQueryable[Q, Seq[R]] {
assert(ts.nonEmpty, "`Values` clause does not support empty sequence")

protected def selectToSimpleSelect() = this.subquery
val tableRef = new SubqueryRef(this, qr)
protected def columnName(n: Int) = s"column${n + 1}"

override val expr: Q = qr.deconstruct(ts.head)

override protected def queryWalkLabels() = qr.walkExprs(expr).indices.map(i => List(i.toString))

override protected def queryWalkExprs() = qr.walkExprs(expr)
protected override val expr: Q = qr.deconstruct(ts.head)

override protected def selectRenderer(prevContext: Context): SelectBase.Renderer =
new Values.Renderer(this)(implicitly, prevContext)
Expand Down
8 changes: 2 additions & 6 deletions scalasql/src/dialects/MySqlDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,9 @@ object MySqlDialect extends MySqlDialect {
insert: OnConflictable[Q, R],
updates: Seq[Column.Assignment[_]],
table: TableRef
) extends Query[R] {

) extends Query.DelegateQuery[R] {
protected def queryDelegate = insert.query
override def queryIsExecuteUpdate = true
protected def queryWalkLabels() = Query.walkLabels(insert.query)
protected def queryWalkExprs() = Query.walkSqlExprs(insert.query)

protected def queryIsSingleRow = Query.isSingleRow(insert.query)

protected def renderSql(ctx: Context) = {
implicit val implicitCtx = Context.compute(ctx, Nil, Some(table))
Expand Down

0 comments on commit e39e4a8

Please sign in to comment.