Skip to content

Commit

Permalink
move CTE exprNaming updates into Context.compute
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Dec 11, 2023
1 parent 620e161 commit 9b272a5
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion scalasql/core/src/Context.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ object Context {
prefixedFroms.iterator
.flatMap { t =>
t
.fromExprAliases(prevContext)
.fromExprAliases(prevContext.withFromNaming(newFromNaming))
.map { case (e, s) => (e, sql"${SqlStr.raw(newFromNaming(t), Array(e))}.$s") }
}

Expand Down
9 changes: 6 additions & 3 deletions scalasql/query/src/From.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scalasql.query

import scalasql.core.{Context, LiveSqlExprs, Queryable, Expr, SqlStr}
import scalasql.core.{Context, Expr, ExprsToSql, LiveSqlExprs, Queryable, SqlStr, WithSqlExpr}
import scalasql.core.Context.From
import scalasql.core.SqlStr.SqlStringSyntax

Expand All @@ -11,6 +11,7 @@ class TableRef(val value: Table.Base) extends From {
override def toString = s"TableRef(${Table.name(value)})"

def fromRefPrefix(prevContext: Context) = prevContext.config.tableNameMapper(Table.name(value))

def fromExprAliases(prevContext: Context) = Map()

def renderSql(name: SqlStr, prevContext: Context, liveExprs: LiveSqlExprs) = {
Expand Down Expand Up @@ -48,10 +49,12 @@ object SubqueryRef {
}
}

class WithCteRef() extends From {
class WithCteRef(walked: Queryable.Walked) extends From {
def fromRefPrefix(prevContext: Context) = "cte"

def fromExprAliases(prevContext: Context) = Map()
def fromExprAliases(prevContext: Context) = {
ExprsToSql.selectColumnReferences(walked, prevContext).toMap
}

def renderSql(name: SqlStr, prevContext: Context, liveExprs: LiveSqlExprs) = {
name
Expand Down
67 changes: 32 additions & 35 deletions scalasql/query/src/WithCte.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,34 @@ import scalasql.core.SqlStr.{Renderable, SqlStringSyntax}
* A SQL `WITH` clause
*/
class WithCte[Q, R](
walked: Queryable.Walked,
val lhs: Select[_, _],
val lhsSubQuery: WithCteRef,
val cteRef: WithCteRef,
val rhs: Select[Q, R],
val withPrefix: SqlStr = sql"WITH "
)(implicit val qr: Queryable.Row[Q, R], protected val dialect: DialectTypeMappers)
extends Select.Proxy[Q, R] {

override protected def expr = Joinable.toFromExpr(rhs)._2
private def unprefixed = new WithCte(lhs, lhsSubQuery, rhs, SqlStr.commaSep)
private def unprefixed = new WithCte(walked, lhs, cteRef, rhs, SqlStr.commaSep)

protected def selectToSimpleSelect() = this.subquery

override def map[Q2, R2](f: Q => Q2)(implicit qr2: Queryable.Row[Q2, R2]): Select[Q2, R2] = {
new WithCte(lhs, lhsSubQuery, rhs.map(f))
new WithCte(walked, lhs, cteRef, rhs.map(f))
}

override def filter(f: Q => Expr[Boolean]): Select[Q, R] = {
new WithCte(rhs.filter(f), lhsSubQuery, rhs)
new WithCte(walked, rhs.filter(f), cteRef, rhs)
}

override def sortBy(f: Q => Expr[_]) = new WithCte(lhs, lhsSubQuery, rhs.sortBy(f))
override def sortBy(f: Q => Expr[_]) = new WithCte(walked, lhs, cteRef, rhs.sortBy(f))

override def drop(n: Int) = new WithCte(lhs, lhsSubQuery, rhs.drop(n))
override def take(n: Int) = new WithCte(lhs, lhsSubQuery, rhs.take(n))
override def drop(n: Int) = new WithCte(walked, lhs, cteRef, rhs.drop(n))
override def take(n: Int) = new WithCte(walked, lhs, cteRef, rhs.take(n))

override protected def selectRenderer(prevContext: Context) =
new WithCte.Renderer(withPrefix, this, prevContext)
new WithCte.Renderer(walked, withPrefix, this, prevContext)

override protected def selectExprAliases(prevContext: Context): Map[Expr.Identity, SqlStr] = {
SubqueryRef.Wrapped.exprAliases(rhs, prevContext)
Expand Down Expand Up @@ -92,36 +93,32 @@ object WithCte {
}
}

class Renderer[Q, R](withPrefix: SqlStr, query: WithCte[Q, R], prevContext: Context)
extends SubqueryRef.Wrapped.Renderer {
class Renderer[Q, R](
walked: Queryable.Walked,
withPrefix: SqlStr,
query: WithCte[Q, R],
prevContext: Context
) extends SubqueryRef.Wrapped.Renderer {
def render(liveExprs: LiveSqlExprs) = {
val walked =
query.lhs.qr
.asInstanceOf[Queryable[Any, Any]]
.walkLabelsAndExprs(WithSqlExpr.get(query.lhs))

val newExprNaming = ExprsToSql.selectColumnReferences(walked, prevContext)
val newContext = Context.compute(prevContext, Seq(query.cteRef), None)
val cteName = SqlStr.raw(newContext.fromNaming(query.cteRef))
val leadingSpace = query.rhs match {
case w: WithCte[Q, R] => SqlStr.empty
case r => sql" "
}

val newContext = Context.compute(prevContext, Seq(query.lhsSubQuery), None)
val cteName = SqlStr.raw(newContext.fromNaming(query.lhsSubQuery))
val rhsSql = SqlStr.flatten(
(query.rhs match {
case w: WithCte[Q, R] => SqlStr.empty
case r => sql" "
}) +
SubqueryRef.Wrapped
.renderer(
query.rhs match {
case w: WithCte[Q, R] => w.unprefixed
case r => r
},
newContext.withExprNaming(
newContext.exprNaming ++
newExprNaming.map { case (k, v) => (k, sql"$cteName.$v") }
)
)
.render(liveExprs)
)
val wrapped = SubqueryRef.Wrapped
.renderer(
query.rhs match {
case w: WithCte[Q, R] => w.unprefixed
case r => r
},
newContext
)
.render(liveExprs)

val rhsSql = SqlStr.flatten(leadingSpace + wrapped)
val rhsReferenced = LiveSqlExprs.some(rhsSql.referencedExprs.toSet)
val lhsSql =
SubqueryRef.Wrapped.renderer(query.lhs, prevContext).render(rhsReferenced)
Expand Down
8 changes: 4 additions & 4 deletions scalasql/src/dialects/DbApiQueryOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class DbApiQueryOps(dialect: DialectTypeMappers) {
def values[Q, R](ts: Seq[R])(implicit qr: Queryable.Row[Q, R]): Values[Q, R] =
new scalasql.query.Values(ts)

import scalasql.core.SqlStr.SqlStringSyntax

/** Generates a SQL `WITH` common table expression clause */
def withCte[Q, Q2, R, R2](
lhs: Select[Q, R]
)(block: Select[Q, R] => Select[Q2, R2])(implicit qr: Queryable.Row[Q2, R2]): Select[Q2, R2] = {
val lhsSubQueryRef = new WithCteRef()

val walked = lhs.qr.walkLabelsAndExprs(WithSqlExpr.get(lhs))
val lhsSubQueryRef = new WithCteRef(lhs.qr.walkLabelsAndExprs(WithSqlExpr.get(lhs)))
val rhsSelect = new WithCte.Proxy[Q, R](lhs, lhsSubQueryRef, lhs.qr, dialect)

new WithCte(lhs, lhsSubQueryRef, block(rhsSelect))
new WithCte(walked, lhs, lhsSubQueryRef, block(rhsSelect))
}
}

0 comments on commit 9b272a5

Please sign in to comment.