Skip to content

Commit

Permalink
tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Dec 11, 2023
1 parent b7013ea commit 591286e
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 25 deletions.
8 changes: 4 additions & 4 deletions scalasql/query/src/CompoundSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ class CompoundSelect[Q, R](
override protected def selectRenderer(prevContext: Context) =
new CompoundSelect.Renderer(this, prevContext)

override protected def selectLhsMap(prevContext: Context): Map[Expr.Identity, SqlStr] = {
SelectBase.lhsMap(lhs, prevContext)
override protected def selectColumnExprs(prevContext: Context): Map[Expr.Identity, SqlStr] = {
SelectBase.columnExprs(lhs, prevContext)
}
}

Expand All @@ -102,7 +102,7 @@ object CompoundSelect {
import query.dialect._
lazy val lhsToSqlQuery = SimpleSelect.getRenderer(query.lhs, prevContext)

lazy val lhsLhsMap = SelectBase.lhsMap(query.lhs, prevContext)
lazy val lhsLhsMap = SelectBase.columnExprs(query.lhs, prevContext)
lazy val context = lhsToSqlQuery.context
.withExprNaming(lhsToSqlQuery.context.exprNaming ++ lhsLhsMap)

Expand Down Expand Up @@ -130,7 +130,7 @@ object CompoundSelect {
val compound = SqlStr.optSeq(query.compoundOps) { compoundOps =>
val compoundStrs = compoundOps.map { op =>
val rhsToSqlQuery = SimpleSelect.getRenderer(op.rhs, prevContext)
lazy val rhsLhsMap = SelectBase.lhsMap(op.rhs, prevContext)
lazy val rhsLhsMap = SelectBase.columnExprs(op.rhs, prevContext)
// We match up the RHS SimpleSelect's lhsMap with the LHS SimpleSelect's lhsMap,
// because the expressions in the CompoundSelect's lhsMap correspond to those
// belonging to the LHS SimpleSelect, but we need the corresponding expressions
Expand Down
2 changes: 1 addition & 1 deletion scalasql/query/src/From.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TableRef(val value: Table.Base) extends From {
class SubqueryRef(val value: SelectBase, val qr: Queryable[_, _]) extends From {
def fromRefPrefix(prevContext: Context): String = "subquery"

def fromExprAliases(prevContext: Context) = SelectBase.lhsMap(value, prevContext)
def fromExprAliases(prevContext: Context) = SelectBase.columnExprs(value, prevContext)

def renderSql(name: SqlStr, prevContext: Context, liveExprs: LiveSqlExprs) = {
val renderSql = SelectBase.renderer(value, prevContext)
Expand Down
6 changes: 2 additions & 4 deletions scalasql/query/src/OnConflict.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ class OnConflict[Q, R](query: Query[R] with InsertReturnable[Q], expr: Q, table:

object OnConflict {
class Ignore[Q, R](
query: Query[R] with InsertReturnable[Q],
protected val query: Query[R] with InsertReturnable[Q],
columns: Seq[Column[_]],
val table: TableRef
) extends Query.DelegateQuery[R]
with InsertReturnable[Q] {
protected def expr = WithSqlExpr.get(query)
protected def queryDelegate = query
protected def renderSql(ctx: Context) = {
val str = Renderable.renderSql(query)(ctx)
str + sql" ON CONFLICT (${SqlStr.join(columns.map(c => SqlStr.raw(c.name)), SqlStr.commaSep)}) DO NOTHING"
Expand All @@ -34,14 +33,13 @@ object OnConflict {
}

class Update[Q, R](
query: Query[R] with InsertReturnable[Q],
protected val query: Query[R] with InsertReturnable[Q],
columns: Seq[Column[_]],
updates: Seq[Column.Assignment[_]],
val table: TableRef
) extends Query.DelegateQuery[R]
with InsertReturnable[Q] {
protected def expr = WithSqlExpr.get(query)
protected def queryDelegate = query
protected def renderSql(ctx: Context) = {
implicit val implicitCtx = Context.compute(ctx, Nil, Some(table))
val str = Renderable.renderSql(query)
Expand Down
13 changes: 6 additions & 7 deletions scalasql/query/src/Query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ object Query {
* most of the abstract methods to it
*/
trait DelegateQuery[R] extends scalasql.query.Query[R] {
protected def queryDelegate: Query[_]
protected def queryWalkLabels() = queryDelegate.queryWalkLabels()
protected def queryWalkExprs() = queryDelegate.queryWalkExprs()
protected override def queryIsSingleRow = queryDelegate.queryIsSingleRow
protected override def queryIsExecuteUpdate = queryDelegate.queryIsExecuteUpdate
protected def query: Query[_]
protected def queryWalkLabels() = query.queryWalkLabels()
protected def queryWalkExprs() = query.queryWalkExprs()
protected override def queryIsSingleRow = query.queryIsSingleRow
protected override def queryIsExecuteUpdate = query.queryIsExecuteUpdate
}

/**
Expand Down Expand Up @@ -77,8 +77,7 @@ object Query {
/**
* A [[Query]] that wraps another [[Query]] but sets [[queryIsSingleRow]] to `true`
*/
class Single[R](query: Query[Seq[R]]) extends Query.DelegateQuery[R] {
protected def queryDelegate = query
class Single[R](protected val query: Query[Seq[R]]) extends Query.DelegateQuery[R] {
protected override def queryIsSingleRow: Boolean = true

protected def renderSql(ctx: Context): SqlStr = Renderable.renderSql(query)(ctx)
Expand Down
4 changes: 2 additions & 2 deletions scalasql/query/src/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ object Select {
override protected def selectRenderer(prevContext: Context): SelectBase.Renderer =
SelectBase.renderer(selectToSimpleSelect(), prevContext)

override protected def selectLhsMap(prevContext: Context): Map[Expr.Identity, SqlStr] =
SelectBase.lhsMap(selectToSimpleSelect(), prevContext)
override protected def selectColumnExprs(prevContext: Context): Map[Expr.Identity, SqlStr] =
SelectBase.columnExprs(selectToSimpleSelect(), prevContext)

override protected def selectToSimpleSelect(): SimpleSelect[Q, R]

Expand Down
4 changes: 2 additions & 2 deletions scalasql/query/src/SelectBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package scalasql.query
import scalasql.core.{Context, LiveSqlExprs, Expr, SqlStr}

trait SelectBase {
protected def selectLhsMap(prevContext: Context): Map[Expr.Identity, SqlStr]
protected def selectColumnExprs(prevContext: Context): Map[Expr.Identity, SqlStr]
protected def selectRenderer(prevContext: Context): SelectBase.Renderer
}
object SelectBase {
def lhsMap(s: SelectBase, prevContext: Context) = s.selectLhsMap(prevContext)
def columnExprs(s: SelectBase, prevContext: Context) = s.selectColumnExprs(prevContext)
def renderer(s: SelectBase, prevContext: Context) = s.selectRenderer(prevContext)

trait Renderer {
Expand Down
2 changes: 1 addition & 1 deletion scalasql/query/src/SimpleSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class SimpleSelect[Q, R](
protected def selectRenderer(prevContext: Context): SimpleSelect.Renderer[_, _] =
new SimpleSelect.Renderer(this, prevContext)

protected def selectLhsMap(prevContext: Context): Map[Expr.Identity, SqlStr] = {
protected def selectColumnExprs(prevContext: Context): Map[Expr.Identity, SqlStr] = {

lazy val flattenedExpr = qr.walkLabelsAndExprs(expr)

Expand Down
2 changes: 1 addition & 1 deletion scalasql/query/src/Values.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Values[Q, R](val ts: Seq[R])(
override protected def selectRenderer(prevContext: Context): SelectBase.Renderer =
new Values.Renderer(this)(implicitly, prevContext)

override protected def selectLhsMap(prevContext: Context): Map[Expr.Identity, SqlStr] = {
override protected def selectColumnExprs(prevContext: Context): Map[Expr.Identity, SqlStr] = {
qr.walkExprs(expr)
.zipWithIndex
.map { case (e, i) => (Expr.identity(e), SqlStr.raw(columnName(i))) }
Expand Down
4 changes: 2 additions & 2 deletions scalasql/query/src/WithCte.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class WithCte[Q, R](
override protected def selectRenderer(prevContext: Context) =
new WithCte.Renderer(withPrefix, this, prevContext)

override protected def selectLhsMap(prevContext: Context): Map[Expr.Identity, SqlStr] = {
SelectBase.lhsMap(rhs, prevContext)
override protected def selectColumnExprs(prevContext: Context): Map[Expr.Identity, SqlStr] = {
SelectBase.columnExprs(rhs, prevContext)
}

override protected def queryConstruct(args: Queryable.ResultSetIterator): Seq[R] =
Expand Down
2 changes: 1 addition & 1 deletion scalasql/src/dialects/MySqlDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ object MySqlDialect extends MySqlDialect {
updates: Seq[Column.Assignment[_]],
table: TableRef
) extends Query.DelegateQuery[R] {
protected def queryDelegate = insert.query
protected def query = insert.query
override def queryIsExecuteUpdate = true

protected def renderSql(ctx: Context) = {
Expand Down

0 comments on commit 591286e

Please sign in to comment.