Skip to content

Commit a8b2839

Browse files
authored
Add ability to control transactions (#199)
* Conform Database to TransactionControlDatabase * Fix manifest for CI * Fix CI * Use release of FluentKit
1 parent 5230817 commit a8b2839

File tree

3 files changed

+143
-1
lines changed

3 files changed

+143
-1
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ let package = Package(
1111
],
1212
dependencies: [
1313
.package(url: "https://github.com/vapor/async-kit.git", from: "1.2.0"),
14-
.package(url: "https://github.com/vapor/fluent-kit.git", from: "1.27.0"),
14+
.package(url: "https://github.com/vapor/fluent-kit.git", from: "1.31.0"),
1515
.package(url: "https://github.com/vapor/postgres-kit.git", from: "2.5.1"),
1616
],
1717
targets: [

Sources/FluentPostgresDriver/FluentPostgresDatabase.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,29 @@ extension _FluentPostgresDatabase: Database {
124124
}
125125
}
126126

127+
extension _FluentPostgresDatabase: TransactionControlDatabase {
128+
func beginTransaction() -> EventLoopFuture<Void> {
129+
self.database.withConnection { conn in
130+
self.logger.log(level: self.sqlLogLevel, "BEGIN")
131+
return conn.simpleQuery("BEGIN").map { _ in }
132+
}
133+
}
134+
135+
func commitTransaction() -> NIOCore.EventLoopFuture<Void> {
136+
self.database.withConnection { conn in
137+
self.logger.log(level: self.sqlLogLevel, "COMMIT")
138+
return conn.simpleQuery("COMMIT").map { _ in }
139+
}
140+
}
141+
142+
func rollbackTransaction() -> NIOCore.EventLoopFuture<Void> {
143+
self.database.withConnection { conn in
144+
self.logger.log(level: self.sqlLogLevel, "ROLLBACK")
145+
return conn.simpleQuery("ROLLBACK").map { _ in }
146+
}
147+
}
148+
}
149+
127150
extension _FluentPostgresDatabase: SQLDatabase {
128151
var dialect: SQLDialect {
129152
PostgresDialect()
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import Logging
2+
import FluentKit
3+
import FluentBenchmark
4+
import FluentPostgresDriver
5+
import XCTest
6+
import PostgresKit
7+
8+
final class FluentPostgresTransactionControlTests: XCTestCase {
9+
10+
func testTransactionControl() throws {
11+
try (self.db as! TransactionControlDatabase).beginTransaction().wait()
12+
13+
let todo1 = Todo(title: "Test")
14+
let todo2 = Todo(title: "Test2")
15+
try todo1.save(on: self.db).wait()
16+
try todo2.save(on: self.db).wait()
17+
18+
try (self.db as! TransactionControlDatabase).commitTransaction().wait()
19+
20+
let count = try Todo.query(on: self.db).count().wait()
21+
XCTAssertEqual(count, 2)
22+
}
23+
24+
func testRollback() throws {
25+
try (self.db as! TransactionControlDatabase).beginTransaction().wait()
26+
27+
let todo1 = Todo(title: "Test")
28+
29+
try todo1.save(on: self.db).wait()
30+
31+
let duplicate = Todo(title: "Test")
32+
var errorCaught = false
33+
34+
do {
35+
try duplicate.create(on: self.db).wait()
36+
} catch {
37+
errorCaught = true
38+
try (self.db as! TransactionControlDatabase).rollbackTransaction().wait()
39+
}
40+
41+
if !errorCaught {
42+
try (self.db as! TransactionControlDatabase).commitTransaction().wait()
43+
}
44+
45+
XCTAssertTrue(errorCaught)
46+
let count2 = try Todo.query(on: self.db).count().wait()
47+
XCTAssertEqual(count2, 0)
48+
}
49+
50+
var benchmarker: FluentBenchmarker {
51+
return .init(databases: self.dbs)
52+
}
53+
var eventLoopGroup: EventLoopGroup!
54+
var threadPool: NIOThreadPool!
55+
var dbs: Databases!
56+
var db: Database {
57+
self.benchmarker.database
58+
}
59+
var postgres: PostgresDatabase {
60+
self.db as! PostgresDatabase
61+
}
62+
63+
override func setUpWithError() throws {
64+
try super.setUpWithError()
65+
66+
XCTAssert(isLoggingConfigured)
67+
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
68+
self.threadPool = NIOThreadPool(numberOfThreads: 1)
69+
self.dbs = Databases(threadPool: threadPool, on: self.eventLoopGroup)
70+
71+
self.dbs.use(.testPostgres(subconfig: "A"), as: .a)
72+
self.dbs.use(.testPostgres(subconfig: "B"), as: .b)
73+
74+
let a = self.dbs.database(.a, logger: Logger(label: "test.fluent.a"), on: self.eventLoopGroup.next())
75+
_ = try (a as! PostgresDatabase).query("drop schema public cascade").wait()
76+
_ = try (a as! PostgresDatabase).query("create schema public").wait()
77+
78+
let b = self.dbs.database(.b, logger: Logger(label: "test.fluent.b"), on: self.eventLoopGroup.next())
79+
_ = try (b as! PostgresDatabase).query("drop schema public cascade").wait()
80+
_ = try (b as! PostgresDatabase).query("create schema public").wait()
81+
82+
try CreateTodo().prepare(on: self.db).wait()
83+
}
84+
85+
override func tearDownWithError() throws {
86+
try CreateTodo().revert(on: self.db).wait()
87+
self.dbs.shutdown()
88+
try self.threadPool.syncShutdownGracefully()
89+
try self.eventLoopGroup.syncShutdownGracefully()
90+
try super.tearDownWithError()
91+
}
92+
93+
final class Todo: Model {
94+
static let schema = "todos"
95+
96+
@ID
97+
var id: UUID?
98+
99+
@Field(key: "title")
100+
var title: String
101+
102+
init() { }
103+
init(title: String) { self.title = title; id = nil }
104+
}
105+
106+
struct CreateTodo: Migration {
107+
func prepare(on database: Database) -> EventLoopFuture<Void> {
108+
return database.schema("todos")
109+
.id()
110+
.field("title", .string, .required)
111+
.unique(on: "title")
112+
.create()
113+
}
114+
115+
func revert(on database: Database) -> EventLoopFuture<Void> {
116+
return database.schema("todos").delete()
117+
}
118+
}
119+
}

0 commit comments

Comments
 (0)