Skip to content

Commit

Permalink
🚸 Add support for predicates on optionals
Browse files Browse the repository at this point in the history
  • Loading branch information
ftchirou committed Apr 18, 2021
1 parent 966d8a4 commit b7f9071
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 19 deletions.
2 changes: 1 addition & 1 deletion PredicateKit.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

Pod::Spec.new do |spec|
spec.name = "PredicateKit"
spec.version = "1.3.0"
spec.version = "1.4.0"
spec.summary = "Write expressive and type-safe predicates for CoreData using key-paths, comparisons and logical operators, literal values, and functions."
spec.description = <<-DESC
PredicateKit allows Swift developers to write expressive and type-safe predicates for CoreData using key-paths, comparisons and logical operators, literal values, and functions.
Expand Down
19 changes: 18 additions & 1 deletion PredicateKit/CoreData/NSFetchRequestBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct NSFetchRequestBuilder {
case .direct, .any, .all:
return NSComparisonPredicate(
leftExpression: makeExpression(from: comparison.expression),
rightExpression: NSExpression(forConstantValue: comparison.value),
rightExpression: makeExpression(from: comparison.value),
modifier: makeComparisonModifier(from: comparison.modifier),
type: makeOperator(from: comparison.operator),
options: makeComparisonOptions(from: comparison.options)
Expand Down Expand Up @@ -110,6 +110,10 @@ struct NSFetchRequestBuilder {
expression.toNSExpression(conversionOptions)
}

private func makeExpression(from primitive: Primitive) -> NSExpression {
return NSExpression(forConstantValue: primitive.value)
}

private func makeOperator(from operator: ComparisonOperator) -> NSComparisonPredicate.Operator {
switch `operator` {
case .beginsWith:
Expand Down Expand Up @@ -298,6 +302,19 @@ extension Query: NSExpressionConvertible {
}
}

// MARK: - Primitive

private extension Primitive {
var value: Any? {
switch Self.type {
case .nil:
return NSNull()
default:
return self
}
}
}

// MARK: - KeyPath

extension AnyKeyPath {
Expand Down
38 changes: 30 additions & 8 deletions PredicateKit/Predicate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ public enum Function<Input: Expression, Output>: Expression where Input.Value: A

public enum Index<Array: Expression>: Expression where Array.Value: AnyArray {
public typealias Root = Array.Root
public typealias Value = Array.Value.Element
public typealias Value = Array.Value.ArrayElement

case index(Array, Int)
case first(Array)
Expand Down Expand Up @@ -371,6 +371,11 @@ public func == <E: Expression, T: Equatable & Primitive> (lhs: E, rhs: T) -> Pre
.comparison(.init(lhs, .equal, rhs))
}

@_disfavoredOverload
public func == <E: Expression> (lhs: E, rhs: Nil) -> Predicate<E.Root> where E.Value: OptionalType {
.comparison(.init(lhs, .equal, rhs))
}

public func != <E: Expression, T: Equatable & Primitive> (lhs: E, rhs: T) -> Predicate<E.Root> where E.Value == T {
.comparison(.init(lhs, .notEqual, rhs))
}
Expand Down Expand Up @@ -495,15 +500,15 @@ extension Expression where Value: AnyArray {
.last(self)
}

public func at<T>(index: Int, _ keyPath: KeyPath<Value.Element, T>) -> ArrayElementKeyPath<Self, T> {
public func at<T>(index: Int, _ keyPath: KeyPath<Value.ArrayElement, T>) -> ArrayElementKeyPath<Self, T> {
.init(.index(index), self, keyPath)
}

public func first<T>(_ keyPath: KeyPath<Value.Element, T>) -> ArrayElementKeyPath<Self, T> {
public func first<T>(_ keyPath: KeyPath<Value.ArrayElement, T>) -> ArrayElementKeyPath<Self, T> {
.init(.first, self, keyPath)
}

public func last<T>(_ keyPath: KeyPath<Value.Element, T>) -> ArrayElementKeyPath<Self, T> {
public func last<T>(_ keyPath: KeyPath<Value.ArrayElement, T>) -> ArrayElementKeyPath<Self, T> {
.init(.last, self, keyPath)
}
}
Expand Down Expand Up @@ -600,6 +605,8 @@ extension Expression {

// MARK: - Supporting Protocols

// MARK: - StringValue

public protocol StringValue {
}

Expand All @@ -609,10 +616,15 @@ extension String: StringValue {
extension Optional: StringValue where Wrapped == String {
}

// MARK: - AnyArrayOrSet

public protocol AnyArrayOrSet {
associatedtype Element
}

extension Array: AnyArrayOrSet {
}

extension Set: AnyArrayOrSet {
}

Expand All @@ -623,16 +635,22 @@ extension Optional: AnyArrayOrSet where Wrapped: AnyArrayOrSet {
public typealias Element = Wrapped.Element
}

// MARK: - AnyArray

public protocol AnyArray {
associatedtype Element
associatedtype ArrayElement
}

extension Array: AnyArrayOrSet {
extension Array: AnyArray {
public typealias ArrayElement = Element
}

extension Array: AnyArray {
extension Optional: AnyArray where Wrapped: AnyArray {
public typealias ArrayElement = Wrapped.ArrayElement
}

// MARK: - PrimitiveCollection

public protocol PrimitiveCollection {
associatedtype PrimitiveElement: Primitive
}
Expand All @@ -649,6 +667,8 @@ extension Optional: PrimitiveCollection where Wrapped: PrimitiveCollection {
public typealias PrimitiveElement = Wrapped.PrimitiveElement
}

// MARK: - AdditiveCollection

public protocol AdditiveCollection {
associatedtype AdditiveElement: AdditiveArithmetic & Primitive
}
Expand All @@ -661,6 +681,8 @@ extension Optional: AdditiveCollection where Wrapped: PrimitiveCollection & Addi
public typealias AdditiveElement = Wrapped.AdditiveElement
}

// MARK: - ComparableCollection

public protocol ComparableCollection {
associatedtype ComparableElement: Comparable & Primitive
}
Expand All @@ -673,7 +695,7 @@ extension Optional: ComparableCollection where Wrapped: ComparableCollection {
public typealias ComparableElement = Wrapped.ComparableElement
}

// MARK: -
// MARK: - Private Initializers

extension Comparison {
fileprivate init<E: Expression>(
Expand Down
19 changes: 19 additions & 0 deletions PredicateKit/Primitive.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import Foundation

// MARK: - Primitive

public protocol Primitive {
static var type: Type { get }
}
Expand All @@ -45,6 +47,7 @@ public indirect enum Type: Equatable {
case data
case wrapped(Type)
case array(Type)
case `nil`
}

extension Bool: Primitive {
Expand Down Expand Up @@ -131,6 +134,22 @@ extension Optional: Primitive where Wrapped: Primitive {
public static var type: Type { Wrapped.type }
}

public struct Nil: Primitive, ExpressibleByNilLiteral {
public static var type: Type { .nil }

public init(nilLiteral: ()) {
}
}

// MARK: - Optional

public protocol OptionalType {
associatedtype Wrapped
}

extension Optional: OptionalType {
}

extension Optional: Comparable where Wrapped: Comparable {
public static func < (lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
Expand Down
29 changes: 29 additions & 0 deletions PredicateKit/SwiftUI/SwiftUISupport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,32 @@ extension FetchRequest {
self.init(context: context, predicate: predicate)
}
}

@available(iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension FetchRequest {
/// Creates a fetch request that returns all objects in the underlying store.
///
/// - Important: Use this initializer **only** in conjunction with the SwiftUI property wrapper` @FetchRequest`. Fetch
/// requests created with this initializer cannot be executed outside of SwiftUI as they rely on the CoreData
/// managed object context injected in the environment of a SwiftUI view.
///
/// ## Example
///
/// struct ContentView: View {
/// @SwiftUI.FetchRequest()
/// .sorted(by: \Note.creationDate, .ascending)
/// .limit(100)
/// )
/// var notes: FetchedResults<Note>
///
/// var body: some View {
/// List(notes, id: \.self) {
/// Text($0.text)
/// }
/// }
/// }
///
public init() {
self.init(predicate: true)
}
}
50 changes: 50 additions & 0 deletions PredicateKitTests/CoreDataTests/NSFetchRequestBuilderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,45 @@ final class NSFetchRequestBuilderTests: XCTestCase {

XCTAssertTrue(fatalError.contains("does not conform to NSExpressionConvertible"))
}

func testObjectNilEqualityPredicate() throws {
let request = makeRequest(\Data.optionalRelationship == nil)
let builder = makeRequestBuilder()

let result: NSFetchRequest<Data> = builder.makeRequest(from: request)

let comparison = try XCTUnwrap(result.predicate as? NSComparisonPredicate)
XCTAssertEqual(comparison.leftExpression, NSExpression(forKeyPath: "optionalRelationship"))
XCTAssertEqual(comparison.rightExpression, NSExpression(forConstantValue: NSNull()))
XCTAssertEqual(comparison.predicateOperatorType, .equalTo)
XCTAssertEqual(comparison.comparisonPredicateModifier, .direct)
}

func testArrayNilEqualityPredicate() throws {
let request = makeRequest(\Data.optionalRelationships == nil)
let builder = makeRequestBuilder()

let result: NSFetchRequest<Data> = builder.makeRequest(from: request)

let comparison = try XCTUnwrap(result.predicate as? NSComparisonPredicate)
XCTAssertEqual(comparison.leftExpression, NSExpression(forKeyPath: "optionalRelationships"))
XCTAssertEqual(comparison.rightExpression, NSExpression(forConstantValue: NSNull()))
XCTAssertEqual(comparison.predicateOperatorType, .equalTo)
XCTAssertEqual(comparison.comparisonPredicateModifier, .direct)
}

func testNestedPrimitiveNilEqualityPredicate() throws {
let request = makeRequest(\Data.optionalRelationship?.text == nil)
let builder = makeRequestBuilder()

let result: NSFetchRequest<Data> = builder.makeRequest(from: request)

let comparison = try XCTUnwrap(result.predicate as? NSComparisonPredicate)
XCTAssertEqual(comparison.leftExpression, NSExpression(forKeyPath: "optionalRelationship.text"))
XCTAssertEqual(comparison.rightExpression, NSExpression(forConstantValue: NSNull()))
XCTAssertEqual(comparison.predicateOperatorType, .equalTo)
XCTAssertEqual(comparison.comparisonPredicateModifier, .direct)
}
}

// MARK: -
Expand All @@ -1051,6 +1090,8 @@ private class Data: NSManagedObject {
@NSManaged var creationDate: Date
@NSManaged var relationship: Relationship
@NSManaged var relationships: [Relationship]
@NSManaged var optionalRelationship: Relationship?
@NSManaged var optionalRelationships: [Relationship]?
}

private class Relationship: NSManagedObject {
Expand Down Expand Up @@ -1079,3 +1120,12 @@ private func makeRequestBuilder(
) -> NSFetchRequestBuilder {
.init(entityName: "")
}

class NoteGroup: NSManagedObject {
@NSManaged var notes: [NewNote]?
}

class NewNote: NSManagedObject {
@NSManaged var group: NoteGroup?
@NSManaged var id: String?
}
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,40 @@ final class NSManagedObjectContextExtensionsTests: XCTestCase {
XCTAssertTrue(inspector.inspectCalled)
}

func testFetchWithNilEquality() throws {
let now = Date()

try container.viewContext.insertNotes(
(text: "Hello, World!", creationDate: .distantFuture, updateDate: now, numberOfViews: 42, tags: ["greeting"]),
(text: "Goodbye!", creationDate: .distantPast, updateDate: nil, numberOfViews: 3, tags: ["greeting"])
)

let notes: [Note] = try container.viewContext
.fetch(where: \Note.updateDate == nil)
.result()

XCTAssertEqual(notes.count, 1)
XCTAssertEqual(notes.first?.text, "Goodbye!")
XCTAssertEqual(notes.first?.tags, ["greeting"])
XCTAssertEqual(notes.first?.numberOfViews, 3)
}

func testFetchWithArrayNilEqualityNilEquality() throws {
try container.viewContext.insertUsers(
(name: "John Doe", billingAccountType: "Pro", purchases: [35.0, 120.0]),
(name: "Jane Doe", billingAccountType: "Default", purchases: nil)
)

let users: [User] = try container.viewContext
.fetch(where: \User.billingInfo.purchases == nil)
.inspect(on: MockNSFetchRequestInspector())
.result()

XCTAssertEqual(users.count, 1)
XCTAssertEqual(users.first?.name, "Jane Doe")
XCTAssertEqual(users.first?.billingInfo.accountType, "Default")
}

private func makePersistentContainer() -> NSPersistentContainer {
return self.makePersistentContainer(with: model)
}
Expand All @@ -639,6 +673,7 @@ final class NSManagedObjectContextExtensionsTests: XCTestCase {
class Note: NSManagedObject {
@NSManaged var text: String
@NSManaged var creationDate: Date
@NSManaged var updateDate: Date?
@NSManaged var numberOfViews: Int
@NSManaged var tags: [String]
}
Expand All @@ -654,7 +689,7 @@ class User: NSManagedObject {

class BillingInfo: NSManagedObject {
@NSManaged var accountType: String
@NSManaged var purchases: [Double]
@NSManaged var purchases: [Double]?
}

class UserAccount: NSManagedObject {
Expand Down Expand Up @@ -703,6 +738,21 @@ private extension NSManagedObjectContext {
try save()
}

func insertNotes(
_ notes: (text: String, creationDate: Date, updateDate: Date?, numberOfViews: Int, tags: [String])...
) throws {
for description in notes {
let note = NSEntityDescription.insertNewObject(forEntityName: "Note", into: self) as! Note
note.text = description.text
note.tags = description.tags
note.numberOfViews = description.numberOfViews
note.creationDate = description.creationDate
note.updateDate = description.updateDate
}

try save()
}

func insertAccounts(purchases: [[Double]]) throws {
for description in purchases {
let account = NSEntityDescription.insertNewObject(forEntityName: "Account", into: self) as! Account
Expand All @@ -712,7 +762,7 @@ private extension NSManagedObjectContext {
try save()
}

func insertUsers(_ users: (name: String, billingAccountType: String, purchases: [Double])...) throws {
func insertUsers(_ users: (name: String, billingAccountType: String, purchases: [Double]?)...) throws {
for description in users {
let user = NSEntityDescription.insertNewObject(forEntityName: "User", into: self) as! User
user.name = description.name
Expand Down
Loading

0 comments on commit b7f9071

Please sign in to comment.