Skip to content

Commit

Permalink
Fix LIMIT and OFFSET for MS SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
kiendang committed Sep 10, 2024
1 parent 3a6af73 commit 708cff5
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 23 deletions.
11 changes: 9 additions & 2 deletions scalasql/query/src/CompoundSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ object CompoundSelect {
// columns are duplicates or not, and thus what final set of rows is returned
lazy val preserveAll = query.compoundOps.exists(_.op != "UNION ALL")

def render(liveExprs: LiveExprs) = {
protected def prerender(liveExprs: LiveExprs) = {
val innerLiveExprs =
if (preserveAll) LiveExprs.none
else liveExprs.map(_ ++ newReferencedExpressions)
Expand All @@ -138,7 +138,14 @@ object CompoundSelect {
SqlStr.join(compoundStrs)
}

lhsStr + compound + sortOpt + limitOpt + offsetOpt
(lhsStr, compound, sortOpt, limitOpt, offsetOpt)
}

def render(liveExprs: LiveExprs) = {
prerender(liveExprs) match {
case (lhsStr, compound, sortOpt, limitOpt, offsetOpt) =>
lhsStr + compound + sortOpt + limitOpt + offsetOpt
}
}
def orderToSqlStr(newCtx: Context) =
CompoundSelect.orderToSqlStr(query.orderBy, newCtx, gap = true)
Expand Down
29 changes: 26 additions & 3 deletions scalasql/src/dialects/MsSqlDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import scalasql.core.SqlStr.{Renderable, SqlStringSyntax}
import scalasql.operations.{ConcatOps, MathOps, TrimOps}

import java.time.{Instant, LocalDateTime, OffsetDateTime}
import scalasql.core.LiveExprs

trait MsSqlDialect extends Dialect {
protected def dialectCastParams = false
Expand Down Expand Up @@ -166,7 +167,11 @@ object MsSqlDialect extends MsSqlDialect {
where,
groupBy0
)
with Select[Q, R]
with Select[Q, R] {
override def take(n: Int): scalasql.query.Select[Q,R] = throw new Exception(".take must follow .sortBy")

override def drop(n: Int): scalasql.query.Select[Q,R] = throw new Exception(".drop must follow .sortBy")
}

class CompoundSelect[Q, R](
lhs: scalasql.query.SimpleSelect[Q, R],
Expand All @@ -177,6 +182,11 @@ object MsSqlDialect extends MsSqlDialect {
)(implicit qr: Queryable.Row[Q, R])
extends scalasql.query.CompoundSelect(lhs, compoundOps, orderBy, limit, offset)
with Select[Q, R] {
override def take(n: Int): scalasql.query.Select[Q, R] = copy(
limit = Some(limit.fold(n)(math.min(_, n))),
offset = offset.orElse(Some(0))
)

protected override def selectRenderer(prevContext: Context): SubqueryRef.Wrapped.Renderer =
new CompoundSelectRenderer(this, prevContext)
}
Expand All @@ -185,9 +195,22 @@ object MsSqlDialect extends MsSqlDialect {
query: scalasql.query.CompoundSelect[Q, R],
prevContext: Context
) extends scalasql.query.CompoundSelect.Renderer(query, prevContext) {
override lazy val limitOpt = SqlStr.flatten(SqlStr.opt(query.limit) { limit =>
sql" FETCH FIRST $limit ROWS ONLY"
})

override lazy val limitOpt = SqlStr
.flatten(CompoundSelectRendererForceLimit.limitToSqlStr(query.limit, query.offset))
override lazy val offsetOpt = SqlStr.flatten(
SqlStr.opt(query.offset.orElse(Option.when(query.limit.nonEmpty)(0))) { offset =>
sql" OFFSET $offset ROWS"
}
)

override def render(liveExprs: LiveExprs): SqlStr = {
prerender(liveExprs) match {
case (lhsStr, compound, sortOpt, limitOpt, offsetOpt) =>
lhsStr + compound + sortOpt + offsetOpt + limitOpt
}
}

override def orderToSqlStr(newCtx: Context) = {
SqlStr.optSeq(query.orderBy) { orderBys =>
Expand Down
107 changes: 89 additions & 18 deletions scalasql/test/src/query/CompoundSelectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ trait CompoundSelectTests extends ScalaSqlSuite {

test("sortLimit") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).take(2) },
sql = "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?",
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Cookie", "Socks"),
docs = """
ScalaSql also supports various combinations of `.take` and `.drop`, translating to SQL
Expand All @@ -61,14 +64,18 @@ trait CompoundSelectTests extends ScalaSqlSuite {
query = Text { Product.select.sortBy(_.price).map(_.name).drop(2) },
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?"
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS"
),
value = Seq("Face Mask", "Skate Board", "Guitar", "Camera")
)

test("sortLimitTwiceHigher") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).take(2).take(3) },
sql = "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?",
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Cookie", "Socks"),
docs = """
Note that `.drop` and `.take` follow Scala collections' semantics, so calling e.g. `.take`
Expand All @@ -79,48 +86,68 @@ trait CompoundSelectTests extends ScalaSqlSuite {

test("sortLimitTwiceLower") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).take(2).take(1) },
sql = "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?",
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Cookie")
)

test("sortLimitOffset") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).drop(2).take(2) },
sql =
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Face Mask", "Skate Board")
)

test("sortLimitOffsetTwice") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).drop(2).drop(2).take(1) },
sql =
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Guitar")
)

test("sortOffsetLimit") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).drop(2).take(2) },
sql =
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Face Mask", "Skate Board")
)

test("sortLimitOffset") - checker(
query = Text { Product.select.sortBy(_.price).map(_.name).take(2).drop(1) },
sql =
sqls = Seq(
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?",
"SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY"
),
value = Seq("Socks")
)
}

test("distinct") - checker(
query = Text { Purchase.select.sortBy(_.total).desc.take(3).map(_.shippingInfoId).distinct },
sql = """
sqls = Seq(
"""
SELECT DISTINCT subquery0.res AS res
FROM (SELECT purchase0.shipping_info_id AS res
FROM purchase purchase0
ORDER BY purchase0.total DESC
LIMIT ?) subquery0
""",
""",
"""
SELECT DISTINCT subquery0.res AS res
FROM (SELECT purchase0.shipping_info_id AS res
FROM purchase purchase0
ORDER BY purchase0.total DESC
OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0
"""
),
value = Seq(1, 2),
normalize = (x: Seq[Int]) => x.sorted,
docs = """
Expand All @@ -134,15 +161,26 @@ trait CompoundSelectTests extends ScalaSqlSuite {
Product.crossJoin().filter(_.id === p.productId).map(_.name)
}
},
sql = """
sqls = Seq(
"""
SELECT product1.name AS res
FROM (SELECT purchase0.product_id AS product_id, purchase0.total AS total
FROM purchase purchase0
ORDER BY total DESC
LIMIT ?) subquery0
CROSS JOIN product product1
WHERE (product1.id = subquery0.product_id)
""",
""",
"""
SELECT product1.name AS res
FROM (SELECT purchase0.product_id AS product_id, purchase0.total AS total
FROM purchase purchase0
ORDER BY total DESC
OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0
CROSS JOIN product product1
WHERE (product1.id = subquery0.product_id)
"""
),
value = Seq("Camera", "Face Mask", "Guitar"),
normalize = (x: Seq[String]) => x.sorted,
docs = """
Expand All @@ -155,13 +193,22 @@ trait CompoundSelectTests extends ScalaSqlSuite {

test("sumBy") - checker(
query = Text { Purchase.select.sortBy(_.total).desc.take(3).sumBy(_.total) },
sql = """
sqls = Seq(
"""
SELECT SUM(subquery0.total) AS res
FROM (SELECT purchase0.total AS total
FROM purchase purchase0
ORDER BY total DESC
LIMIT ?) subquery0
""",
""",
"""
SELECT SUM(subquery0.total) AS res
FROM (SELECT purchase0.total AS total
FROM purchase purchase0
ORDER BY total DESC
OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0
"""
),
value = 11788.0,
normalize = (x: Double) => x.round.toDouble
)
Expand All @@ -174,13 +221,22 @@ trait CompoundSelectTests extends ScalaSqlSuite {
.take(3)
.aggregate(p => (p.sumBy(_.total), p.avgBy(_.total)))
},
sql = """
sqls = Seq(
"""
SELECT SUM(subquery0.total) AS res_0, AVG(subquery0.total) AS res_1
FROM (SELECT purchase0.total AS total
FROM purchase purchase0
ORDER BY total DESC
LIMIT ?) subquery0
""",
""",
"""
SELECT SUM(subquery0.total) AS res_0, AVG(subquery0.total) AS res_1
FROM (SELECT purchase0.total AS total
FROM purchase purchase0
ORDER BY total DESC
OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0
"""
),
value = (11788.0, 3929.0),
normalize = (x: (Double, Double)) => (x._1.round.toDouble, x._2.round.toDouble)
)
Expand Down Expand Up @@ -325,7 +381,8 @@ trait CompoundSelectTests extends ScalaSqlSuite {
.drop(4)
.take(4)
},
sql = """
sqls = Seq(
"""
SELECT LOWER(product0.name) AS res
FROM product product0
UNION ALL
Expand All @@ -337,7 +394,21 @@ trait CompoundSelectTests extends ScalaSqlSuite {
ORDER BY res
LIMIT ?
OFFSET ?
""",
""",
"""
SELECT LOWER(product0.name) AS res
FROM product product0
UNION ALL
SELECT LOWER(buyer0.name) AS res
FROM buyer buyer0
UNION
SELECT LOWER(product0.kebab_case_name) AS res
FROM product product0
ORDER BY res
OFFSET ? ROWS
FETCH FIRST ? ROWS ONLY
"""
),
value = Seq("guitar", "james bond", "li haoyi", "skate board")
)
}
Expand Down

0 comments on commit 708cff5

Please sign in to comment.