diff --git a/Package.swift b/Package.swift index 0f9a3c75..6475730b 100644 --- a/Package.swift +++ b/Package.swift @@ -4,14 +4,14 @@ import PackageDescription let package = Package( name: "fluent", platforms: [ - .macOS(.v10_15) + .macOS(.v10_15), ], products: [ .library(name: "Fluent", targets: ["Fluent"]), ], dependencies: [ - .package(url: "https://github.com/vapor/fluent-kit.git", from: "1.0.0-rc.1"), - .package(url: "https://github.com/vapor/vapor.git", from: "4.0.0-rc.1"), + .package(url: "https://github.com/vapor/fluent-kit.git", from: "1.0.0-rc.1.19"), + .package(url: "https://github.com/vapor/vapor.git", from: "4.0.0"), ], targets: [ .target(name: "Fluent", dependencies: [ @@ -20,6 +20,7 @@ let package = Package( ]), .testTarget(name: "FluentTests", dependencies: [ .target(name: "Fluent"), + .product(name: "XCTFluent", package: "fluent-kit"), .product(name: "XCTVapor", package: "vapor"), ]), ] diff --git a/Sources/Fluent/Deprecated.swift b/Sources/Fluent/Deprecated.swift new file mode 100644 index 00000000..30c7069a --- /dev/null +++ b/Sources/Fluent/Deprecated.swift @@ -0,0 +1,19 @@ +import Vapor + +@available(*, deprecated, renamed: "ModelAuthenticatable") +public typealias ModelUser = ModelAuthenticatable + +@available(*, deprecated, renamed: "ModelAuthenticatable") +public typealias ModelUserToken = ModelTokenAuthenticatable + +extension Application.Fluent.Sessions { + @available(*, deprecated, renamed: "Model.sessionAuthenticator()") + public func middleware( + for user: User.Type, + databaseID: DatabaseID? = nil + ) -> Middleware + where User: SessionAuthenticatable, User: Model, User.SessionID == User.IDValue + { + User.sessionAuthenticator(databaseID) + } +} diff --git a/Sources/Fluent/Fluent+Sessions.swift b/Sources/Fluent/Fluent+Sessions.swift index d5593d55..d8021257 100644 --- a/Sources/Fluent/Fluent+Sessions.swift +++ b/Sources/Fluent/Fluent+Sessions.swift @@ -10,16 +10,28 @@ extension Application.Fluent { } } -extension Application.Fluent.Sessions { - public func middleware( - for user: User.Type, - databaseID: DatabaseID? = nil - ) -> Middleware - where User: SessionAuthenticatable, User: Model, User.SessionID == User.IDValue - { - DatabaseSessionAuthenticator(databaseID: databaseID).middleware() +public protocol ModelSessionAuthenticatable: Model, SessionAuthenticatable + where Self.SessionID == Self.IDValue +{ } + +extension ModelSessionAuthenticatable { + public var sessionID: SessionID { + guard let id = self.id else { + fatalError("Cannot persist unsaved model to session.") + } + return id } +} +extension Model where Self: SessionAuthenticatable, Self.SessionID == Self.IDValue { + public static func sessionAuthenticator( + _ databaseID: DatabaseID? = nil + ) -> Authenticator { + DatabaseSessionAuthenticator(databaseID: databaseID) + } +} + +extension Application.Fluent.Sessions { public func driver(_ databaseID: DatabaseID? = nil) -> SessionDriver { DatabaseSessions(databaseID: databaseID) } @@ -46,20 +58,20 @@ private struct DatabaseSessions: SessionDriver { func createSession(_ data: SessionData, for request: Request) -> EventLoopFuture { let id = self.generateID() - return Session(key: id, data: data) + return SessionRecord(key: id, data: data) .create(on: request.db(self.databaseID)) .map { id } } func readSession(_ sessionID: SessionID, for request: Request) -> EventLoopFuture { - return Session.query(on: request.db(self.databaseID)) + SessionRecord.query(on: request.db(self.databaseID)) .filter(\.$key == sessionID) .first() .map { $0?.data } } func updateSession(_ sessionID: SessionID, to data: SessionData, for request: Request) -> EventLoopFuture { - return Session.query(on: request.db(self.databaseID)) + SessionRecord.query(on: request.db(self.databaseID)) .filter(\.$key == sessionID) .set(\.$data, to: data) .update() @@ -67,7 +79,7 @@ private struct DatabaseSessions: SessionDriver { } func deleteSession(_ sessionID: SessionID, for request: Request) -> EventLoopFuture { - return Session.query(on: request.db(self.databaseID)) + SessionRecord.query(on: request.db(self.databaseID)) .filter(\.$key == sessionID) .delete() } @@ -86,15 +98,37 @@ private struct DatabaseSessionAuthenticator: SessionAuthenticator { let databaseID: DatabaseID? - func resolve(sessionID: User.SessionID, for request: Request) -> EventLoopFuture { - User.find(sessionID, on: request.db) + func authenticate(sessionID: User.SessionID, for request: Request) -> EventLoopFuture { + User.find(sessionID, on: request.db).map { + if let user = $0 { + request.auth.login(user) + } + } } } -public final class Session: Model { - public static let schema = "sessions" +public final class SessionRecord: Model { + public static let schema = "_fluent_sessions" + + private struct _Migration: Migration { + func prepare(on database: Database) -> EventLoopFuture { + database.schema("_fluent_sessions") + .id() + .field("key", .string, .required) + .field("data", .json, .required) + .create() + } + + func revert(on database: Database) -> EventLoopFuture { + database.schema("_fluent_sessions").delete() + } + } + + public static var migration: Migration { + _Migration() + } - @ID(key: "id") + @ID(key: .id) public var id: UUID? @Field(key: "key") @@ -103,9 +137,7 @@ public final class Session: Model { @Field(key: "data") public var data: SessionData - public init() { - - } + public init() { } public init(id: UUID? = nil, key: SessionID, data: SessionData) { self.id = id @@ -113,17 +145,3 @@ public final class Session: Model { self.data = data } } - -public struct CreateSession: Migration { - public func prepare(on database: Database) -> EventLoopFuture { - return database.schema("sessions") - .field("id", .uuid, .identifier(auto: false)) - .field("key", .string, .required) - .field("data", .json, .required) - .create() - } - - public func revert(on database: Database) -> EventLoopFuture { - return database.schema("sessions").delete() - } -} diff --git a/Sources/Fluent/ModelUser.swift b/Sources/Fluent/ModelAuthenticatable.swift similarity index 70% rename from Sources/Fluent/ModelUser.swift rename to Sources/Fluent/ModelAuthenticatable.swift index 0479f4c2..5414db36 100644 --- a/Sources/Fluent/ModelUser.swift +++ b/Sources/Fluent/ModelAuthenticatable.swift @@ -1,16 +1,16 @@ import Vapor -public protocol ModelUser: Model, Authenticatable { +public protocol ModelAuthenticatable: Model, Authenticatable { static var usernameKey: KeyPath> { get } static var passwordHashKey: KeyPath> { get } func verify(password: String) throws -> Bool } -extension ModelUser { +extension ModelAuthenticatable { public static func authenticator( database: DatabaseID? = nil - ) -> ModelUserAuthenticator { - ModelUserAuthenticator(database: database) + ) -> Authenticator { + ModelAuthenticator(database: database) } var _$username: Field { @@ -22,27 +22,27 @@ extension ModelUser { } } -public struct ModelUserAuthenticator: BasicAuthenticator - where User: ModelUser +private struct ModelAuthenticator: BasicAuthenticator + where User: ModelAuthenticatable { public let database: DatabaseID? public func authenticate( basic: BasicAuthorization, for request: Request - ) -> EventLoopFuture { + ) -> EventLoopFuture { User.query(on: request.db(self.database)) .filter(\._$username == basic.username) .first() .flatMapThrowing { guard let user = $0 else { - return nil + return } guard try user.verify(password: basic.password) else { - return nil + return } - return user + request.auth.login(user) } } } diff --git a/Sources/Fluent/ModelUserToken.swift b/Sources/Fluent/ModelTokenAuthenticatable.swift similarity index 59% rename from Sources/Fluent/ModelUserToken.swift rename to Sources/Fluent/ModelTokenAuthenticatable.swift index a10748cd..df889dc1 100644 --- a/Sources/Fluent/ModelUserToken.swift +++ b/Sources/Fluent/ModelTokenAuthenticatable.swift @@ -1,17 +1,17 @@ import Vapor -public protocol ModelUserToken: Model { +public protocol ModelTokenAuthenticatable: Model, Authenticatable { associatedtype User: Model & Authenticatable static var valueKey: KeyPath> { get } static var userKey: KeyPath> { get } var isValid: Bool { get } } -extension ModelUserToken { +extension ModelTokenAuthenticatable { public static func authenticator( database: DatabaseID? = nil - ) -> ModelUserTokenAuthenticator { - ModelUserTokenAuthenticator(database: database) + ) -> Authenticator { + ModelTokenAuthenticator(database: database) } var _$value: Field { @@ -23,8 +23,8 @@ extension ModelUserToken { } } -public struct ModelUserTokenAuthenticator: BearerAuthenticator - where Token: ModelUserToken +private struct ModelTokenAuthenticator: BearerAuthenticator + where Token: ModelTokenAuthenticatable { public typealias User = Token.User public let database: DatabaseID? @@ -32,21 +32,23 @@ public struct ModelUserTokenAuthenticator: BearerAuthenticator public func authenticate( bearer: BearerAuthorization, for request: Request - ) -> EventLoopFuture { + ) -> EventLoopFuture { let db = request.db(self.database) return Token.query(on: db) .filter(\._$value == bearer.token) .first() .flatMap - { token -> EventLoopFuture in + { token -> EventLoopFuture in guard let token = token else { - return request.eventLoop.makeSucceededFuture(nil) + return request.eventLoop.makeSucceededFuture(()) } guard token.isValid else { - return token.delete(on: db).map { nil } + return token.delete(on: db) + } + request.auth.login(token) + return token._$user.get(on: db).map { + request.auth.login($0) } - return token._$user.get(on: db) - .map { $0 } } } } diff --git a/Tests/FluentTests/FluentOperatorTests.swift b/Tests/FluentTests/OperatorTests.swift similarity index 97% rename from Tests/FluentTests/FluentOperatorTests.swift rename to Tests/FluentTests/OperatorTests.swift index 068a71ee..518a5050 100644 --- a/Tests/FluentTests/FluentOperatorTests.swift +++ b/Tests/FluentTests/OperatorTests.swift @@ -2,7 +2,7 @@ import Fluent import Vapor import XCTVapor -final class FluentOperatorTests: XCTestCase { +final class OperatorTests: XCTestCase { func testCustomOperators() throws { let db = DummyDatabase() diff --git a/Tests/FluentTests/FluentPaginationTests.swift b/Tests/FluentTests/PaginationTests.swift similarity index 80% rename from Tests/FluentTests/FluentPaginationTests.swift rename to Tests/FluentTests/PaginationTests.swift index 63dbee37..07b3a897 100644 --- a/Tests/FluentTests/FluentPaginationTests.swift +++ b/Tests/FluentTests/PaginationTests.swift @@ -1,33 +1,37 @@ import Fluent import Vapor import XCTVapor +import XCTFluent -final class FluentPaginationTests: XCTestCase { +final class PaginationTests: XCTestCase { func testPagination() throws { let app = Application(.testing) defer { app.shutdown() } - var rows: [TestRow] = [] + var rows: [TestOutput] = [] for i in 1...1_000 { - rows.append(TestRow(data: ["id": i, "title": "Todo #\(i)"])) + rows.append(TestOutput([ + "id": i, + "title": "Todo #\(i)" + ])) } - - app.databases.use(TestDatabaseConfiguration { query in + let test = CallbackTestDatabase { query in XCTAssertEqual(query.schema, "todos") - let result: [TestRow] + let result: [TestOutput] if let limit = query.limits.first?.value, let offset = query.offsets.first?.value { - result = [TestRow](rows[min(offset, rows.count - 1).. EventLoopFuture> in Todo.query(on: req.db).paginate(for: req) diff --git a/Tests/FluentTests/FluentRepositoryTests.swift b/Tests/FluentTests/RepositoryTests.swift similarity index 90% rename from Tests/FluentTests/FluentRepositoryTests.swift rename to Tests/FluentTests/RepositoryTests.swift index 45639e26..338f17c5 100644 --- a/Tests/FluentTests/FluentRepositoryTests.swift +++ b/Tests/FluentTests/RepositoryTests.swift @@ -1,8 +1,9 @@ import Fluent import Vapor +import XCTFluent import XCTVapor -final class FluentRepositoryTests: XCTestCase { +final class RepositoryTests: XCTestCase { func testRepositoryPatternStatic() throws { let app = Application(.testing) defer { app.shutdown() } @@ -37,13 +38,8 @@ final class FluentRepositoryTests: XCTestCase { let app = Application(.testing) defer { app.shutdown() } - app.databases.use(TestDatabaseConfiguration { query in - XCTAssertEqual(query.schema, "posts") - return [ - TestRow(data: ["id": 1, "content": "a"]), - TestRow(data: ["id": 2, "content": "b"]), - ] - }, as: .test) + let test = ArrayTestDatabase() + app.databases.use(test.configuration, as: .test) app.posts.use { req in DatabasePostRepository(database: req.db(.test)) @@ -57,6 +53,11 @@ final class FluentRepositoryTests: XCTestCase { .init(id: 1, content: "a"), .init(id: 2, content: "b") ] + + test.append([ + TestOutput(["id": 1, "content": "a"]), + TestOutput(["id": 2, "content": "b"]), + ]) try app.testable().test(.GET, "foo") { res in XCTAssertEqual(res.status, .ok) diff --git a/Tests/FluentTests/SessionTests.swift b/Tests/FluentTests/SessionTests.swift new file mode 100644 index 00000000..a4753755 --- /dev/null +++ b/Tests/FluentTests/SessionTests.swift @@ -0,0 +1,103 @@ +import XCTFluent +import XCTVapor +import Fluent +import Vapor + +final class SessionTests: XCTestCase { + func testSessions() throws { + let app = Application(.testing) + defer { app.shutdown() } + + // Setup test db. + let test = ArrayTestDatabase() + app.databases.use(test.configuration, as: .test) + app.migrations.add(SessionRecord.migration) + + // Configure sessions. + app.sessions.use(.fluent) + app.middleware.use(app.sessions.middleware) + + // Setup routes. + app.get("set", ":value") { req -> HTTPStatus in + req.session.data["name"] = req.parameters.get("value") + return .ok + } + app.get("get") { req -> String in + req.session.data["name"] ?? "n/a" + } + app.get("del") { req -> HTTPStatus in + req.session.destroy() + return .ok + } + + // Add single query output with empty row. + test.append([TestOutput()]) + // Store session id. + var sessionID: String? + try app.test(.GET, "/set/vapor") { res in + sessionID = res.headers.setCookie?["vapor-session"]?.string + XCTAssertEqual(res.status, .ok) + } + + // Add single query output with session data for session read. + test.append([ + TestOutput([ + "id": UUID(), + "key": SessionID(string: sessionID!), + "data": SessionData(["name": "vapor"]) + ]) + ]) + // Add empty query output for session update. + test.append([]) + try app.test(.GET, "/get", beforeRequest: { req in + var cookies = HTTPCookies() + cookies["vapor-session"] = .init(string: sessionID!) + req.headers.cookie = cookies + }) { res in + XCTAssertEqual(res.status, .ok) + XCTAssertEqual(res.body.string, "vapor") + } + } +} + +final class User: Model { + static let schema = "users" + + @ID(key: .id) + var id: UUID? + + @Field(key: "name") + var name: String + + init() { } + + init(id: UUID? = nil, name: String) { + self.id = id + self.name = name + } +} + +extension User: ModelSessionAuthenticatable { } + +extension DatabaseID { + static var test: Self { + .init(string: "test") + } +} + +struct StaticDatabase: DatabaseConfiguration, DatabaseDriver { + let database: Database + var middleware: [AnyModelMiddleware] = [] + + func makeDriver(for databases: Databases) -> DatabaseDriver { + self + } + + func makeDatabase(with context: DatabaseContext) -> Database { + self.database + } + + func shutdown() { + // Do nothing. + } +} diff --git a/Tests/FluentTests/TestDatabase.swift b/Tests/FluentTests/TestDatabase.swift deleted file mode 100644 index b14ead33..00000000 --- a/Tests/FluentTests/TestDatabase.swift +++ /dev/null @@ -1,81 +0,0 @@ -import Fluent - -extension DatabaseID { - static var test: DatabaseID { .init(string: "test") } -} - -struct TestDatabase: Database { - let driver: TestDatabaseDriver - let context: DatabaseContext - - func execute(query: DatabaseQuery, onOutput: @escaping (DatabaseOutput) -> ()) -> EventLoopFuture { - self.driver.handler(query).forEach { row in - self.context.eventLoop.execute { - onOutput(row) - } - } - return self.eventLoop.makeSucceededFuture(()) - } - - func execute(schema: DatabaseSchema) -> EventLoopFuture { - fatalError() - } - - func execute(enum: DatabaseEnum) -> EventLoopFuture { - fatalError() - } - - func withConnection(_ closure: @escaping (Database) -> EventLoopFuture) -> EventLoopFuture { - closure(self) - } - - func transaction(_ closure: @escaping (Database) -> EventLoopFuture) -> EventLoopFuture { - closure(self) - } -} - -struct TestRow: DatabaseOutput { - var data: [FieldKey: Any] - - var description: String { - self.data.description - } - - func schema(_ schema: String) -> DatabaseOutput { - self - } - - func contains(_ path: [FieldKey]) -> Bool { - self.data.keys.contains(path[0]) - } - - func decode(_ path: [FieldKey], as type: T.Type) throws -> T where T : Decodable { - self.data[path[0]]! as! T - } -} - -final class TestDatabaseDriver: DatabaseDriver { - let handler: (DatabaseQuery) -> [DatabaseOutput] - - init(_ handler: @escaping (DatabaseQuery) -> [DatabaseOutput]) { - self.handler = handler - } - - func makeDatabase(with context: DatabaseContext) -> Database { - TestDatabase(driver: self, context: context) - } - - func shutdown() { - // nothing - } -} - -struct TestDatabaseConfiguration: DatabaseConfiguration { - let handler: (DatabaseQuery) -> [DatabaseOutput] - - var middleware: [AnyModelMiddleware] = [] - - func makeDriver(for databases: Databases) -> DatabaseDriver { - TestDatabaseDriver(handler) - } -}