diff --git a/modules/core/src/main/scala/doobie/util/update.scala b/modules/core/src/main/scala/doobie/util/update.scala index 5f3aa9adf..dd522c3f5 100644 --- a/modules/core/src/main/scala/doobie/util/update.scala +++ b/modules/core/src/main/scala/doobie/util/update.scala @@ -85,6 +85,14 @@ object update { def run(a: A): ConnectionIO[Int] = IHC.executeWithoutResultSet(prepareExecutionForRun(a), loggingForRun(a)) + /** Just like `run` but allowing to alter `PreparedExecutionWithoutProcessStep`. + */ + def runAlteringExecution( + a: A, + fn: PreparedExecutionWithoutProcessStep[Int] => PreparedExecutionWithoutProcessStep[Int] + ): ConnectionIO[Int] = + IHC.executeWithoutResultSet(fn(prepareExecutionForRun(a)), loggingForRun(a)) + private def prepareExecutionForRun(a: A): PreparedExecutionWithoutProcessStep[Int] = PreparedExecutionWithoutProcessStep( create = IFC.prepareStatement(sql), @@ -110,6 +118,14 @@ object update { def updateMany[F[_]: Foldable](fa: F[A]): ConnectionIO[Int] = IHC.executeWithoutResultSet(prepareExecutionForUpdateMany(fa), loggingInfoForUpdateMany(fa)) + /** Just like `updateMany` but allowing to alter `PreparedExecutionWithoutProcessStep`. + */ + def updateManyAlteringExecution[F[_]: Foldable]( + fa: F[A], + fn: PreparedExecutionWithoutProcessStep[Int] => PreparedExecutionWithoutProcessStep[Int] + ): ConnectionIO[Int] = + IHC.executeWithoutResultSet(fn(prepareExecutionForUpdateMany(fa)), loggingInfoForUpdateMany(fa)) + private def prepareExecutionForUpdateMany[F[_]: Foldable](fa: F[A]): PreparedExecutionWithoutProcessStep[Int] = PreparedExecutionWithoutProcessStep( create = IFC.prepareStatement(sql), @@ -177,6 +193,17 @@ object update { loggingInfoForUpdateWithGeneratedKeys(a) ) + /** Just like `withUniqueGeneratedKeys` but allowing to alter `PreparedExecution`. + */ + def withUniqueGeneratedKeysAlteringExecution[K: Read](columns: String*)( + a: A, + fn: PreparedExecution[K] => PreparedExecution[K] + ): ConnectionIO[K] = + IHC.executeWithResultSet( + fn(prepareExecutionForWithUniqueGeneratedKeys(columns*)(a)), + loggingInfoForUpdateWithGeneratedKeys(a) + ) + private def prepareExecutionForWithUniqueGeneratedKeys[K: Read](columns: String*)(a: A): PreparedExecution[K] = PreparedExecution( create = IFC.prepareStatement(sql, columns.toArray), diff --git a/modules/core/src/test/scala/doobie/util/UpdateSuite.scala b/modules/core/src/test/scala/doobie/util/UpdateSuite.scala new file mode 100644 index 000000000..3bd6114bf --- /dev/null +++ b/modules/core/src/test/scala/doobie/util/UpdateSuite.scala @@ -0,0 +1,75 @@ +// Copyright (c) 2013-2020 Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package doobie.util + +import cats.syntax.all.* +import cats.effect.IO +import cats.effect.unsafe.implicits.global +import doobie.{Transactor, Update} +import doobie.free.preparedstatement as IFPS + +class UpdateSuite extends munit.FunSuite { + val xa: Transactor[IO] = Transactor.fromDriverManager[IO]( + driver = "org.h2.Driver", + url = "jdbc:h2:mem:;DB_CLOSE_DELAY=-1", + user = "sa", + password = "", + logHandler = None + ) + + test("Update runAlteringExecution") { + import doobie.implicits.* + var didRun = false + (for { + _ <- sql"create temp table t1 (a int)".update.run + res <- Update[Int]("insert into t1 (a) values (?)").runAlteringExecution( + 1, + pe => pe.copy(exec = IFPS.delay { didRun = true } *> pe.exec)) + } yield { + assertEquals(res, 1) + }) + .transact(xa) + .unsafeRunSync() + + assert(didRun) + } + + test("Update updateManyAlteringExecution") { + import doobie.implicits.* + var didRun = false + (for { + _ <- sql"create temp table t1 (a int)".update.run + res <- Update[Int]("insert into t1 (a) values (?)").updateManyAlteringExecution( + List(2, 4, 6, 8), + pe => pe.copy(exec = IFPS.delay { didRun = true } *> pe.exec)) + } yield { + assertEquals(res, 4) + }) + .transact(xa) + .unsafeRunSync() + + assert(didRun) + } + + test("Update withUniqueGeneratedKeysAlteringExecution") { + import doobie.implicits.* + var didRun = false + (for { + _ <- sql"create temp table t1 (a int, b int)".update.run + res <- Update[(Int, Int)]("insert into t1 (a, b) values (?, ?)") + .withUniqueGeneratedKeysAlteringExecution[(Int, Int)]("a", "b")( + (5, 6), + pe => pe.copy(exec = IFPS.delay { didRun = true } *> pe.exec) + ) + } yield { + assertEquals(res, (5, 6)) + }) + .transact(xa) + .unsafeRunSync() + + assert(didRun) + } + +}