Skip to content

Commit

Permalink
improve DbApi.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Dec 10, 2023
1 parent 58d29fd commit 9745c5f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 37 deletions.
6 changes: 4 additions & 2 deletions scalasql/core/src/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ trait Config {
}

/**
* Configures the underlying JDBC connection's `setFetchSize`
* Configures the underlying JDBC connection's `setFetchSize`. Can be overriden
* on a per-query basis by passing `fetchSize = n` to `db.run`
*/
def defaultFetchSize: Int = -1

/**
* Configures the underlying JDBC connection's `setQueryTimeout`
* Configures the underlying JDBC connection's `setQueryTimeout`. Can be overriden
* on a per-query basis by passing `queryTimeoutSeconds = n` to `db.run`
*/
def defaultQueryTimeoutSeconds: Int = -1

Expand Down
87 changes: 52 additions & 35 deletions scalasql/core/src/DbApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,17 @@ object DbApi {
* create savepoints, or roll back the transaction.
*/
trait Txn extends DbApi {
/**
* Creates a SQL Savepoint that is active within the given block; automatically
* releases the savepoint if the block completes successfully and rolls it back
* if the block terminates with an exception, and allows you to roll back the
* savepoint manually via the [[DbApi.Savepoint]] parameter passed to that block
*/
def savepoint[T](block: DbApi.Savepoint => T): T

/**
* Tolls back any active Savepoints and then rolls back this Transaction
*/
def rollback(): Unit
}

Expand All @@ -125,6 +135,19 @@ object DbApi {
def rollback(): Unit
}

//
// run
// |
// runRaw runSql +---------+---------+
// | | | |
// streamRaw streamSql stream updateRaw updateSql
// | | | | |
// | streamFlattened | |
// | | +----+----+
// +-----+------+ |
// | runRawUpdate0
// streamRaw0

class Impl(
connection: java.sql.Connection,
config: Config,
Expand All @@ -139,7 +162,7 @@ object DbApi {
savepointStack.append(savepoint)

try {
val res = block(new DbApi.SavepointApiImpl(savepoint, () => rollbackSavepoint(savepoint)))
val res = block(new DbApi.SavepointImpl(savepoint, () => rollbackSavepoint(savepoint)))
if (savepointStack.lastOption.exists(_ eq savepoint)) {
// Only release if this savepoint has not been rolled back,
// directly or indirectly
Expand Down Expand Up @@ -214,24 +237,6 @@ object DbApi {
variables.map(v => (s: PreparedStatement, n: Int) => s.setObject(n, v))
}

def runRawQuery0[T](
sql: String,
variables: Seq[(PreparedStatement, Int) => Unit],
fetchSize: Int,
queryTimeoutSeconds: Int,
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
)(
block: ResultSet => T
): T = {
val statement = connection.prepareStatement(sql)

for ((variable, i) <- variables.iterator.zipWithIndex) variable(statement, i + 1)
configureRunCloseStatement(statement, fetchSize, queryTimeoutSeconds, sql, fileName, lineNum)(
s => block(s.executeQuery())
)
}

def runSql[R](
sql: SqlStr,
fetchSize: Int = -1,
Expand All @@ -240,8 +245,7 @@ object DbApi {
implicit qr: Queryable.Row[_, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R] =
streamSql(sql, fetchSize, queryTimeoutSeconds).toVector
): IndexedSeq[R] = streamSql(sql, fetchSize, queryTimeoutSeconds).toVector

def runRawUpdate0(
sql: String,
Expand Down Expand Up @@ -301,8 +305,14 @@ object DbApi {
variables: Seq[Any] = Nil,
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(implicit fileName: sourcecode.FileName, lineNum: sourcecode.Line): Int =
runRawUpdate0(sql, anySeqPuts(variables), fetchSize, queryTimeoutSeconds, fileName, lineNum)
)(implicit fileName: sourcecode.FileName, lineNum: sourcecode.Line): Int = runRawUpdate0(
sql,
anySeqPuts(variables),
fetchSize,
queryTimeoutSeconds,
fileName,
lineNum
)

def updateSql(sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
implicit fileName: sourcecode.FileName,
Expand Down Expand Up @@ -363,11 +373,9 @@ object DbApi {
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): Generator[R] = {
val flattened = SqlStr.flatten(sql)

streamFlattened(
qr.construct,
flattened,
SqlStr.flatten(sql),
fetchSize,
queryTimeoutSeconds,
fileName,
Expand Down Expand Up @@ -402,15 +410,24 @@ object DbApi {
lineNum: sourcecode.Line
) = new Generator[R] {
def generate(handleItem: R => Generator.Action): Generator.Action = {
val statement = connection.prepareStatement(sql)
for ((setVariable, i) <- variables.iterator.zipWithIndex) setVariable(statement, i + 1)

runRawQuery0(sql, variables, fetchSize, queryTimeoutSeconds, fileName, lineNum) {
resultSet =>
var action: Generator.Action = Generator.Continue
while (resultSet.next() && action == Generator.Continue) {
val rowRes = construct(new Queryable.ResultSetIterator(resultSet))
action = handleItem(rowRes)
}
action
configureRunCloseStatement(
statement,
fetchSize,
queryTimeoutSeconds,
sql,
fileName,
lineNum
) { stmt =>
val resultSet = stmt.executeQuery()
var action: Generator.Action = Generator.Continue
while (resultSet.next() && action == Generator.Continue) {
val rowRes = construct(new Queryable.ResultSetIterator(resultSet))
action = handleItem(rowRes)
}
action
}
}
}
Expand Down Expand Up @@ -439,7 +456,7 @@ object DbApi {
def close() = connection.close()
}

class SavepointApiImpl(savepoint: java.sql.Savepoint, rollback0: () => Unit) extends Savepoint {
class SavepointImpl(savepoint: java.sql.Savepoint, rollback0: () => Unit) extends Savepoint {
def savepointId = savepoint.getSavepointId
def savepointName = savepoint.getSavepointName
def rollback() = rollback0()
Expand Down

0 comments on commit 9745c5f

Please sign in to comment.