From e39e4a8e439f922cafe2c804b31b605da6727acb Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sun, 10 Dec 2023 17:32:30 -0800 Subject: [PATCH] extract DelegateQuery and DelegateQueryable --- scalasql/core/src/DbApi.scala | 2 +- scalasql/core/src/Queryable.scala | 4 +-- scalasql/query/src/Aggregate.scala | 11 +++----- scalasql/query/src/OnConflict.scala | 13 +++------- scalasql/query/src/Query.scala | 32 ++++++++++++++++-------- scalasql/query/src/Returning.scala | 23 ++++++++--------- scalasql/query/src/Select.scala | 5 ++-- scalasql/query/src/SimpleSelect.scala | 5 ++-- scalasql/query/src/Values.scala | 9 +++---- scalasql/src/dialects/MySqlDialect.scala | 8 ++---- 10 files changed, 52 insertions(+), 60 deletions(-) diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index dbb2379d..d2287834 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -189,7 +189,7 @@ object DbApi { fileName, lineNum ) - if (qr.singleRow(query)) { + if (qr.isSingleRow(query)) { val results = res.take(2).toVector assert( results.size == 1, diff --git a/scalasql/core/src/Queryable.scala b/scalasql/core/src/Queryable.scala index d2a258cf..409b52e4 100644 --- a/scalasql/core/src/Queryable.scala +++ b/scalasql/core/src/Queryable.scala @@ -40,7 +40,7 @@ trait Queryable[-Q, R] { * Whether this query expects a single row to be returned, if so we can assert on * the number of rows and raise an error if 0 rows or 2+ rows are present */ - def singleRow(q: Q): Boolean + def isSingleRow(q: Q): Boolean /** * Converts the given queryable value into a [[SqlStr]], that can then be executed @@ -83,7 +83,7 @@ object Queryable { */ trait Row[Q, R] extends Queryable[Q, R] { def isExecuteUpdate(q: Q): Boolean = false - def singleRow(q: Q): Boolean = true + def isSingleRow(q: Q): Boolean = true def walkLabels(): Seq[List[String]] def walkLabels(q: Q): Seq[List[String]] = walkLabels() diff --git a/scalasql/query/src/Aggregate.scala b/scalasql/query/src/Aggregate.scala index 953cb753..a4a2c6d2 100644 --- a/scalasql/query/src/Aggregate.scala +++ b/scalasql/query/src/Aggregate.scala @@ -5,14 +5,11 @@ import scalasql.core.{Queryable, Expr, SqlStr, TypeMapper, Context} class Aggregate[Q, R]( toSqlStr0: Context => SqlStr, construct0: Queryable.ResultSetIterator => R, - expr: Q -)( - qr: Queryable[Q, R] -) extends Query[R] { + protected val expr: Q, + protected val qr: Queryable[Q, R] +) extends Query.DelegateQueryable[Q, R] { - protected def queryWalkLabels() = qr.walkLabels(expr) - protected def queryWalkExprs() = qr.walkExprs(expr) - protected def queryIsSingleRow: Boolean = true + protected override def queryIsSingleRow: Boolean = true protected def renderSql(ctx: Context) = toSqlStr0(ctx) override protected def queryConstruct(args: Queryable.ResultSetIterator): R = construct0(args) diff --git a/scalasql/query/src/OnConflict.scala b/scalasql/query/src/OnConflict.scala index 344ec9fa..70275bd5 100644 --- a/scalasql/query/src/OnConflict.scala +++ b/scalasql/query/src/OnConflict.scala @@ -18,12 +18,10 @@ object OnConflict { query: Query[R] with InsertReturnable[Q], columns: Seq[Column[_]], val table: TableRef - ) extends Query[R] + ) extends Query.DelegateQuery[R] with InsertReturnable[Q] { protected def expr = WithSqlExpr.get(query) - protected def queryWalkLabels() = Query.walkLabels(query) - protected def queryWalkExprs() = Query.walkSqlExprs(query) - protected def queryIsSingleRow = Query.isSingleRow(query) + protected def queryDelegate = query protected def renderSql(ctx: Context) = { val str = Renderable.renderSql(query)(ctx) str + sql" ON CONFLICT (${SqlStr.join(columns.map(c => SqlStr.raw(c.name)), SqlStr.commaSep)}) DO NOTHING" @@ -40,13 +38,10 @@ object OnConflict { columns: Seq[Column[_]], updates: Seq[Column.Assignment[_]], val table: TableRef - ) extends Query[R] + ) extends Query.DelegateQuery[R] with InsertReturnable[Q] { protected def expr = WithSqlExpr.get(query) - - protected def queryWalkLabels() = Query.walkLabels(query) - protected def queryWalkExprs() = Query.walkSqlExprs(query) - protected def queryIsSingleRow = Query.isSingleRow(query) + protected def queryDelegate = query protected def renderSql(ctx: Context) = { implicit val implicitCtx = Context.compute(ctx, Nil, Some(table)) val str = Renderable.renderSql(query) diff --git a/scalasql/query/src/Query.scala b/scalasql/query/src/Query.scala index 504f323f..e9c2a516 100644 --- a/scalasql/query/src/Query.scala +++ b/scalasql/query/src/Query.scala @@ -1,8 +1,7 @@ package scalasql.query import scalasql.core.SqlStr.Renderable -import scalasql.core.{Queryable, SqlStr, Expr} -import scalasql.core.Context +import scalasql.core.{Context, Expr, Queryable, SqlStr, WithSqlExpr} /** * A SQL Query, either a [[Query.Multiple]] that returns multiple rows, or @@ -25,6 +24,22 @@ object Query { protected override def queryIsExecuteUpdate = true } + trait DelegateQuery[R] extends scalasql.query.Query[R] { + protected def queryDelegate: Query[_] + protected def queryWalkLabels() = queryDelegate.queryWalkLabels() + protected def queryWalkExprs() = queryDelegate.queryWalkExprs() + protected override def queryIsSingleRow = queryDelegate.queryIsSingleRow + protected override def queryIsExecuteUpdate = queryDelegate.queryIsExecuteUpdate + } + + trait DelegateQueryable[Q, R] extends scalasql.query.Query[R] with WithSqlExpr[Q] { + protected def qr: Queryable[Q, _] + protected def queryWalkLabels() = qr.walkLabels(expr) + protected def queryWalkExprs() = qr.walkExprs(expr) + protected override def queryIsSingleRow = qr.isSingleRow(expr) + protected override def queryIsExecuteUpdate = qr.isExecuteUpdate(expr) + } + implicit def QueryQueryable[R]: Queryable[Query[R], R] = new QueryQueryable[Query[R], R]() def walkLabels[R](q: Query[R]) = q.queryWalkLabels() @@ -35,21 +50,16 @@ object Query { override def isExecuteUpdate(q: Q) = q.queryIsExecuteUpdate override def walkLabels(q: Q) = q.queryWalkLabels() override def walkExprs(q: Q) = q.queryWalkExprs() - override def singleRow(q: Q) = q.queryIsSingleRow + override def isSingleRow(q: Q) = q.queryIsSingleRow def toSqlStr(q: Q, ctx: Context): SqlStr = q.renderSql(ctx) override def construct(q: Q, args: Queryable.ResultSetIterator): R = q.queryConstruct(args) } - trait Multiple[R] extends Query[Seq[R]] - - class Single[R](query: Multiple[R]) extends Query[R] { - override def queryIsExecuteUpdate = query.queryIsExecuteUpdate - protected def queryWalkLabels() = query.queryWalkLabels() - protected def queryWalkExprs() = query.queryWalkExprs() - - protected def queryIsSingleRow: Boolean = true + class Single[R](query: Query[Seq[R]]) extends Query.DelegateQuery[R] { + protected def queryDelegate = query + protected override def queryIsSingleRow: Boolean = true protected def renderSql(ctx: Context): SqlStr = Renderable.renderSql(query)(ctx) protected override def queryConstruct(args: Queryable.ResultSetIterator): R = diff --git a/scalasql/query/src/Returning.scala b/scalasql/query/src/Returning.scala index dbb5c538..2ee46cad 100644 --- a/scalasql/query/src/Returning.scala +++ b/scalasql/query/src/Returning.scala @@ -17,37 +17,34 @@ trait InsertReturnable[Q] extends Returnable[Q] /** * A query with a `RETURNING` clause */ -trait Returning[Q, R] extends Query.Multiple[R] { +trait Returning[Q, R] extends Query[Seq[R]] with Query.DelegateQueryable[Q, Seq[R]] { def single: Query.Single[R] = new Query.Single(this) } object InsertReturning { class Impl[Q, R](returnable: InsertReturnable[_], returning: Q)( - implicit val qr: Queryable.Row[Q, R] + implicit qr: Queryable.Row[Q, R] ) extends Returning.Impl0[Q, R](qr, returnable, returning) - with Returning[Q, R] { - protected def expr: Q = returning - } + with Returning[Q, R] {} } object Returning { - class Impl0[Q, R](qr: Queryable.Row[Q, R], returnable: Returnable[_], returning: Q) - extends Returning[Q, R] { + class Impl0[Q, R]( + protected val qr: Queryable.Row[Q, R], + returnable: Returnable[_], + protected val expr: Q + ) extends Returning[Q, R] { override protected def queryConstruct(args: Queryable.ResultSetIterator): Seq[R] = { Seq(qr.construct(args)) } - protected def queryWalkLabels() = qr.walkLabels(returning) - - protected def queryWalkExprs() = qr.walkExprs(returning) - override def queryIsSingleRow = false protected override def renderSql(ctx0: Context) = { implicit val implicitCtx = Context.compute(ctx0, Nil, Some(returnable.table)) val prefix = Renderable.renderSql(returnable) - val walked = qr.walkLabelsAndExprs(returning) + val walked = qr.walkLabelsAndExprs(expr) val exprStr = ExprsToSql.apply(walked, implicitCtx, SqlStr.empty) val suffix = sql" RETURNING $exprStr" @@ -55,7 +52,7 @@ object Returning { } } - class Impl[Q, R](returnable: Returnable[_], returning: Q)(implicit val qr: Queryable.Row[Q, R]) + class Impl[Q, R](returnable: Returnable[_], returning: Q)(implicit qr: Queryable.Row[Q, R]) extends Impl0[Q, R](qr, returnable, returning) with Returning[Q, R] diff --git a/scalasql/query/src/Select.scala b/scalasql/query/src/Select.scala index c5ee5f11..f82cb2b4 100644 --- a/scalasql/query/src/Select.scala +++ b/scalasql/query/src/Select.scala @@ -39,7 +39,8 @@ trait Select[Q, R] with Aggregatable[Q] with Joinable[Q, R] with JoinOps[Select, Q, R] - with Query.Multiple[R] + with Query[Seq[R]] + with Query.DelegateQueryable[Q, Seq[R]] with SelectBase { protected def dialect: DialectTypeMappers @@ -191,8 +192,6 @@ trait Select[Q, R] renderer.render(LiveSqlExprs.none).withCompleteQuery(true) } - protected def queryWalkLabels() = qr.walkLabels(expr) - protected def queryWalkExprs() = qr.walkExprs(expr) protected override def queryIsSingleRow = false /** diff --git a/scalasql/query/src/SimpleSelect.scala b/scalasql/query/src/SimpleSelect.scala index 6a3613d8..2a110026 100644 --- a/scalasql/query/src/SimpleSelect.scala +++ b/scalasql/query/src/SimpleSelect.scala @@ -152,8 +152,9 @@ class SimpleSelect[Q, R]( new Aggregate[E, V]( implicit ctx => copied.renderSql(ctx), r => Query.construct(copied, r).head, - selectProxyExpr - )(qr) + selectProxyExpr, + qr + ) } def mapAggregate[Q2, R2]( diff --git a/scalasql/query/src/Values.scala b/scalasql/query/src/Values.scala index 8dce41b0..e950c515 100644 --- a/scalasql/query/src/Values.scala +++ b/scalasql/query/src/Values.scala @@ -18,18 +18,15 @@ import scalasql.core.SqlStr.{Renderable, SqlStringSyntax} class Values[Q, R](val ts: Seq[R])( implicit val qr: Queryable.Row[Q, R], protected val dialect: DialectTypeMappers -) extends Select.Proxy[Q, R] { +) extends Select.Proxy[Q, R] + with Query.DelegateQueryable[Q, Seq[R]] { assert(ts.nonEmpty, "`Values` clause does not support empty sequence") protected def selectToSimpleSelect() = this.subquery val tableRef = new SubqueryRef(this, qr) protected def columnName(n: Int) = s"column${n + 1}" - override val expr: Q = qr.deconstruct(ts.head) - - override protected def queryWalkLabels() = qr.walkExprs(expr).indices.map(i => List(i.toString)) - - override protected def queryWalkExprs() = qr.walkExprs(expr) + protected override val expr: Q = qr.deconstruct(ts.head) override protected def selectRenderer(prevContext: Context): SelectBase.Renderer = new Values.Renderer(this)(implicitly, prevContext) diff --git a/scalasql/src/dialects/MySqlDialect.scala b/scalasql/src/dialects/MySqlDialect.scala index b7dbc72c..7262764e 100644 --- a/scalasql/src/dialects/MySqlDialect.scala +++ b/scalasql/src/dialects/MySqlDialect.scala @@ -268,13 +268,9 @@ object MySqlDialect extends MySqlDialect { insert: OnConflictable[Q, R], updates: Seq[Column.Assignment[_]], table: TableRef - ) extends Query[R] { - + ) extends Query.DelegateQuery[R] { + protected def queryDelegate = insert.query override def queryIsExecuteUpdate = true - protected def queryWalkLabels() = Query.walkLabels(insert.query) - protected def queryWalkExprs() = Query.walkSqlExprs(insert.query) - - protected def queryIsSingleRow = Query.isSingleRow(insert.query) protected def renderSql(ctx: Context) = { implicit val implicitCtx = Context.compute(ctx, Nil, Some(table))