|
| 1 | +import Logging |
| 2 | +import FluentKit |
| 3 | +import FluentBenchmark |
| 4 | +import FluentPostgresDriver |
| 5 | +import XCTest |
| 6 | +import PostgresKit |
| 7 | + |
| 8 | +final class FluentPostgresTransactionControlTests: XCTestCase { |
| 9 | + |
| 10 | + func testTransactionControl() throws { |
| 11 | + try (self.db as! TransactionControlDatabase).beginTransaction().wait() |
| 12 | + |
| 13 | + let todo1 = Todo(title: "Test") |
| 14 | + let todo2 = Todo(title: "Test2") |
| 15 | + try todo1.save(on: self.db).wait() |
| 16 | + try todo2.save(on: self.db).wait() |
| 17 | + |
| 18 | + try (self.db as! TransactionControlDatabase).commitTransaction().wait() |
| 19 | + |
| 20 | + let count = try Todo.query(on: self.db).count().wait() |
| 21 | + XCTAssertEqual(count, 2) |
| 22 | + } |
| 23 | + |
| 24 | + func testRollback() throws { |
| 25 | + try (self.db as! TransactionControlDatabase).beginTransaction().wait() |
| 26 | + |
| 27 | + let todo1 = Todo(title: "Test") |
| 28 | + |
| 29 | + try todo1.save(on: self.db).wait() |
| 30 | + |
| 31 | + let duplicate = Todo(title: "Test") |
| 32 | + var errorCaught = false |
| 33 | + |
| 34 | + do { |
| 35 | + try duplicate.create(on: self.db).wait() |
| 36 | + } catch { |
| 37 | + errorCaught = true |
| 38 | + try (self.db as! TransactionControlDatabase).rollbackTransaction().wait() |
| 39 | + } |
| 40 | + |
| 41 | + if !errorCaught { |
| 42 | + try (self.db as! TransactionControlDatabase).commitTransaction().wait() |
| 43 | + } |
| 44 | + |
| 45 | + XCTAssertTrue(errorCaught) |
| 46 | + let count2 = try Todo.query(on: self.db).count().wait() |
| 47 | + XCTAssertEqual(count2, 0) |
| 48 | + } |
| 49 | + |
| 50 | + var benchmarker: FluentBenchmarker { |
| 51 | + return .init(databases: self.dbs) |
| 52 | + } |
| 53 | + var eventLoopGroup: EventLoopGroup! |
| 54 | + var threadPool: NIOThreadPool! |
| 55 | + var dbs: Databases! |
| 56 | + var db: Database { |
| 57 | + self.benchmarker.database |
| 58 | + } |
| 59 | + var postgres: PostgresDatabase { |
| 60 | + self.db as! PostgresDatabase |
| 61 | + } |
| 62 | + |
| 63 | + override func setUpWithError() throws { |
| 64 | + try super.setUpWithError() |
| 65 | + |
| 66 | + XCTAssert(isLoggingConfigured) |
| 67 | + self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) |
| 68 | + self.threadPool = NIOThreadPool(numberOfThreads: 1) |
| 69 | + self.dbs = Databases(threadPool: threadPool, on: self.eventLoopGroup) |
| 70 | + |
| 71 | + self.dbs.use(.testPostgres(subconfig: "A"), as: .a) |
| 72 | + self.dbs.use(.testPostgres(subconfig: "B"), as: .b) |
| 73 | + |
| 74 | + let a = self.dbs.database(.a, logger: Logger(label: "test.fluent.a"), on: self.eventLoopGroup.next()) |
| 75 | + _ = try (a as! PostgresDatabase).query("drop schema public cascade").wait() |
| 76 | + _ = try (a as! PostgresDatabase).query("create schema public").wait() |
| 77 | + |
| 78 | + let b = self.dbs.database(.b, logger: Logger(label: "test.fluent.b"), on: self.eventLoopGroup.next()) |
| 79 | + _ = try (b as! PostgresDatabase).query("drop schema public cascade").wait() |
| 80 | + _ = try (b as! PostgresDatabase).query("create schema public").wait() |
| 81 | + |
| 82 | + try CreateTodo().prepare(on: self.db).wait() |
| 83 | + } |
| 84 | + |
| 85 | + override func tearDownWithError() throws { |
| 86 | + try CreateTodo().revert(on: self.db).wait() |
| 87 | + self.dbs.shutdown() |
| 88 | + try self.threadPool.syncShutdownGracefully() |
| 89 | + try self.eventLoopGroup.syncShutdownGracefully() |
| 90 | + try super.tearDownWithError() |
| 91 | + } |
| 92 | + |
| 93 | + final class Todo: Model { |
| 94 | + static let schema = "todos" |
| 95 | + |
| 96 | + @ID |
| 97 | + var id: UUID? |
| 98 | + |
| 99 | + @Field(key: "title") |
| 100 | + var title: String |
| 101 | + |
| 102 | + init() { } |
| 103 | + init(title: String) { self.title = title; id = nil } |
| 104 | + } |
| 105 | + |
| 106 | + struct CreateTodo: Migration { |
| 107 | + func prepare(on database: Database) -> EventLoopFuture<Void> { |
| 108 | + return database.schema("todos") |
| 109 | + .id() |
| 110 | + .field("title", .string, .required) |
| 111 | + .unique(on: "title") |
| 112 | + .create() |
| 113 | + } |
| 114 | + |
| 115 | + func revert(on database: Database) -> EventLoopFuture<Void> { |
| 116 | + return database.schema("todos").delete() |
| 117 | + } |
| 118 | + } |
| 119 | +} |
0 commit comments