diff --git a/modules/core/src/main/scala/doobie/util/query.scala b/modules/core/src/main/scala/doobie/util/query.scala index fc351202e..7a8453c7e 100644 --- a/modules/core/src/main/scala/doobie/util/query.scala +++ b/modules/core/src/main/scala/doobie/util/query.scala @@ -10,17 +10,17 @@ import cats.data.NonEmptyList import doobie.free.connection.ConnectionIO import doobie.free.preparedstatement.PreparedStatementIO import doobie.free.resultset.ResultSetIO +import doobie.free.{connection as IFC, preparedstatement as IFPS} +import doobie.hi.{connection as IHC, preparedstatement as IHPS, resultset as IHRS} +import doobie.util.MultiVersionTypeSupport.=:= import doobie.util.analysis.Analysis import doobie.util.compat.FactoryCompat +import doobie.util.fragment.Fragment import doobie.util.log.{LoggingInfo, Parameters} import doobie.util.pos.Pos -import doobie.free.{connection as IFC, preparedstatement as IFPS} -import doobie.hi.{connection as IHC, preparedstatement as IHPS, resultset as IHRS} import fs2.Stream -import scala.collection.immutable.Map -import doobie.util.MultiVersionTypeSupport.=:= -import doobie.util.fragment.Fragment +import java.sql.{PreparedStatement, ResultSet} /** Module defining queries parameterized by input and output types. */ object query { @@ -112,10 +112,16 @@ object query { * @group Results */ def toMap[K, V](a: A)(implicit ev: B =:= (K, V), f: FactoryCompat[(K, V), Map[K, V]]): ConnectionIO[Map[K, V]] = - toConnectionIO( - a, - IHRS.buildPair[Map, K, V](f, read.map(ev)) - ) + toConnectionIO(a, IHRS.buildPair[Map, K, V](f, read.map(ev))) + + /** + * Just like `toMap` but allowing to alter `PreparedExecution`. + */ + def toMapAlteringExecution[K, V](a: A, fn: PreparedExecutionUpdate[Map[K, V]])(implicit + ev: B =:= (K, V), + f: FactoryCompat[(K, V), Map[K, V]] + ): ConnectionIO[Map[K, V]] = + toConnectionIOAlteringExecution(a, IHRS.buildPair[Map, K, V](f, read.map(ev)), fn) /** Apply the argument `a` to construct a program in `[[doobie.free.connection.ConnectionIO ConnectionIO]]` yielding * an `F[B]` accumulated via `MonadPlus` append. This method is more general but less efficient than `to`. @@ -146,15 +152,15 @@ object query { def nel(a: A): ConnectionIO[NonEmptyList[B]] = toConnectionIO(a, IHRS.nel[B]) - private def toConnectionIO[C](a: A, rsio: ResultSetIO[C]): ConnectionIO[C] = { - IHC.executeWithResultSet( - create = IFC.prepareStatement(sql), - prep = IHPS.set(a), - exec = IFPS.executeQuery, - process = rsio, - loggingInfo = mkLoggingInfo(a) - ) - } + private def toConnectionIO[C](a: A, rsio: ResultSetIO[C]): ConnectionIO[C] = + PreparedExecution(sql, a, rsio).execute(mkLoggingInfo(a)) + + private def toConnectionIOAlteringExecution[C]( + a: A, + rsio: ResultSetIO[C], + fn: PreparedExecutionUpdate[C] + ): ConnectionIO[C] = + fn(PreparedExecution(sql, a, rsio)).execute(mkLoggingInfo(a)) private def mkLoggingInfo(a: A): LoggingInfo = LoggingInfo( @@ -252,6 +258,33 @@ object query { } + type PreparedExecutionUpdate[A] = PreparedExecution[A] => PreparedExecution[A] + + final case class PreparedExecution[C]( + create: ConnectionIO[PreparedStatement], + prep: PreparedStatementIO[Unit], + exec: PreparedStatementIO[ResultSet], + process: ResultSetIO[C] + ) { ctx => + private[util] def execute(loggingInfo: LoggingInfo) = IHC.executeWithResultSet( + create = ctx.create, + prep = ctx.prep, + exec = ctx.exec, + process = ctx.process, + loggingInfo = loggingInfo + ) + } + + private object PreparedExecution { + def apply[C, A](sql: String, a: A, rsio: ResultSetIO[C])(implicit w: Write[A]): PreparedExecution[C] = + PreparedExecution( + create = IFC.prepareStatement(sql), + prep = IHPS.set(a), + exec = IFPS.executeQuery, + process = rsio + ) + } + /** An abstract query closed over its input arguments and yielding values of type `B`, without a specified * disposition. Methods provided on `[[Query0]]` allow the query to be interpreted as a stream or program in * `CollectionIO`. diff --git a/modules/core/src/test/scala/doobie/util/QuerySuite.scala b/modules/core/src/test/scala/doobie/util/QuerySuite.scala index 242792b5c..987858f01 100644 --- a/modules/core/src/test/scala/doobie/util/QuerySuite.scala +++ b/modules/core/src/test/scala/doobie/util/QuerySuite.scala @@ -5,8 +5,10 @@ package doobie.util import cats.effect.IO -import doobie.*, doobie.implicits.* -import scala.Predef.* +import cats.syntax.all.* +import doobie.* +import doobie.hi.resultset as IHRS +import doobie.implicits.* class QuerySuite extends munit.FunSuite { @@ -61,6 +63,20 @@ class QuerySuite extends munit.FunSuite { assertEquals(q.contramap[Int](n => "bar" * n).to[List](1).transact(xa).unsafeRunSync(), Nil) } + test("Query toMapAlteringExecution (result set operations)") { + var didRun = false + + pairQuery.toMapAlteringExecution[String, Int]( + "x", + { preparedExec => + val process = IHRS.delay { didRun = true } *> preparedExec.process + preparedExec.copy(process = process) + }) + .transact(xa).unsafeRunSync() + + assert(didRun) + } + test("Query0 from Query (non-empty) to") { assertEquals(q.toQuery0("foo").to[List].transact(xa).unsafeRunSync(), List(123)) }