diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 1fdc3d9..de684e5 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -24,16 +24,6 @@ jobs: scheme: SpeziFoundation artifactname: SpeziFoundation.xcresult resultBundle: SpeziFoundation.xcresult - packageios_latest: - name: Build and Test Swift Package iOS Latest - uses: StanfordSpezi/.github/.github/workflows/xcodebuild-or-fastlane.yml@v2 - with: - runsonlabels: '["macOS", "self-hosted"]' - scheme: SpeziFoundation - xcodeversion: latest - swiftVersion: 6 - artifactname: SpeziFoundation-Latest.xcresult - resultBundle: SpeziFoundation-Latest.xcresult packagewatchos: name: Build and Test Swift Package watchOS uses: StanfordSpezi/.github/.github/workflows/xcodebuild-or-fastlane.yml@v2 @@ -73,16 +63,6 @@ jobs: resultBundle: SpeziFoundationMacOS.xcresult destination: 'platform=macOS,arch=arm64' artifactname: SpeziFoundationMacOS.xcresult - codeql: - name: CodeQL - uses: StanfordSpezi/.github/.github/workflows/xcodebuild-or-fastlane.yml@v2 - with: - codeql: true - test: false - scheme: SpeziFoundation - permissions: - security-events: write - actions: read uploadcoveragereport: name: Upload Coverage Report needs: [packageios, packagewatchos, packagevisionos, packagetvos, packagemacos] diff --git a/Package.swift b/Package.swift index e5fa034..b8aece2 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.9 +// swift-tools-version:6.0 // // This source file is part of the Stanford Spezi open-source project @@ -12,13 +12,6 @@ import class Foundation.ProcessInfo import PackageDescription -#if swift(<6) -let swiftConcurrency: SwiftSetting = .enableExperimentalFeature("StrictConcurrency") -#else -let swiftConcurrency: SwiftSetting = .enableUpcomingFeature("StrictConcurrency") -#endif - - let package = Package( name: "SpeziFoundation", defaultLocalization: "en", @@ -32,16 +25,18 @@ let package = Package( products: [ .library(name: "SpeziFoundation", targets: ["SpeziFoundation"]) ], - dependencies: swiftLintPackage(), + dependencies: [ + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0") + ] + swiftLintPackage(), targets: [ .target( name: "SpeziFoundation", + dependencies: [ + .product(name: "Atomics", package: "swift-atomics") + ], resources: [ .process("Resources") ], - swiftSettings: [ - swiftConcurrency - ], plugins: [] + swiftLintPlugin() ), .testTarget( @@ -49,9 +44,6 @@ let package = Package( dependencies: [ .target(name: "SpeziFoundation") ], - swiftSettings: [ - swiftConcurrency - ], plugins: [] + swiftLintPlugin() ) ] diff --git a/Sources/SpeziFoundation/Semaphore/AsyncSemaphore.swift b/Sources/SpeziFoundation/Concurrency/AsyncSemaphore.swift similarity index 79% rename from Sources/SpeziFoundation/Semaphore/AsyncSemaphore.swift rename to Sources/SpeziFoundation/Concurrency/AsyncSemaphore.swift index 6deb803..c8bc7e7 100644 --- a/Sources/SpeziFoundation/Semaphore/AsyncSemaphore.swift +++ b/Sources/SpeziFoundation/Concurrency/AsyncSemaphore.swift @@ -109,17 +109,16 @@ public final class AsyncSemaphore: Sendable { /// This method allows the `Task` calling ``waitCheckingCancellation()`` to be cancelled while waiting, throwing a `CancellationError` if the `Task` is cancelled before it can proceed. /// /// - Throws: `CancellationError` if the task is cancelled while waiting. - public func waitCheckingCancellation() async throws { - try Task.checkCancellation() // check if we are already cancelled + public func waitCheckingCancellation() async throws(CancellationError) { + if Task.isCancelled { // check if we are already cancelled + throw CancellationError() + } unsafeLock() // this is okay, as the continuation body actually runs sync, so we do no have async code within critical region - do { - // check if we got cancelled while acquiring the lock - try Task.checkCancellation() - } catch { + if Task.isCancelled { // check if we got cancelled while acquiring the lock unsafeUnlock() - throw error + throw CancellationError() } value -= 1 // decrease the value @@ -130,37 +129,42 @@ public final class AsyncSemaphore: Sendable { let id = UUID() - try await withTaskCancellationHandler { - try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in - if Task.isCancelled { - value += 1 // restore the value - unsafeUnlock() - - continuation.resume(throwing: CancellationError()) - } else { - suspendedTasks.append(SuspendedTask(id: id, suspension: .cancelable(continuation))) - unsafeUnlock() + do { + try await withTaskCancellationHandler { + try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in + if Task.isCancelled { + value += 1 // restore the value + unsafeUnlock() + + continuation.resume(throwing: CancellationError()) + } else { + suspendedTasks.append(SuspendedTask(id: id, suspension: .cancelable(continuation))) + unsafeUnlock() + } } - } - } onCancel: { - let task = nsLock.withLock { - value += 1 + } onCancel: { + let task = nsLock.withLock { + value += 1 - guard let index = suspendedTasks.firstIndex(where: { $0.id == id }) else { - preconditionFailure("Inconsistent internal state reached") + guard let index = suspendedTasks.firstIndex(where: { $0.id == id }) else { + preconditionFailure("Inconsistent internal state reached") + } + + let task = suspendedTasks[index] + suspendedTasks.remove(at: index) + return task } - let task = suspendedTasks[index] - suspendedTasks.remove(at: index) - return task - } - - switch task.suspension { - case .regular: - preconditionFailure("Tried to cancel a task that was not cancellable!") - case let .cancelable(continuation): - continuation.resume(throwing: CancellationError()) + switch task.suspension { + case .regular: + preconditionFailure("Tried to cancel a task that was not cancellable!") + case let .cancelable(continuation): + continuation.resume(throwing: CancellationError()) + } } + } catch { + assert(error is CancellationError, "Injected unexpected error into continuation: \(error)") + throw CancellationError() } } diff --git a/Sources/SpeziFoundation/Concurrency/ManagedAsynchronousAccess.swift b/Sources/SpeziFoundation/Concurrency/ManagedAsynchronousAccess.swift new file mode 100644 index 0000000..730dab3 --- /dev/null +++ b/Sources/SpeziFoundation/Concurrency/ManagedAsynchronousAccess.swift @@ -0,0 +1,186 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + + +/// A continuation with exclusive access. +/// +/// +public final class ManagedAsynchronousAccess { + private final class CallSiteState { + var wasCancelled = false + + init() {} + } + + private let access: AsyncSemaphore + private var continuation: CheckedContinuation? + private var associatedState: CallSiteState? + + /// Determine if the is currently an ongoing access. + public var ongoingAccess: Bool { + continuation != nil + } + + /// Create a new managed asynchronous access. + public init() { + self.access = AsyncSemaphore(value: 1) + } + + private func markCancelled() { + if let associatedState { + associatedState.wasCancelled = true + self.associatedState = nil + } + } + + /// Resume the continuation by either returning a value or throwing an error. + /// - Parameter result: The result to return from the continuation. + /// - Returns: Returns `true`, if there was another task waiting to access the continuation and it was resumed. + @discardableResult + public func resume(with result: sending Result) -> Bool { + self.associatedState = nil + if let continuation { + self.continuation = nil + let didSignalAnyone = access.signal() + continuation.resume(with: result) + return didSignalAnyone + } + + return false + } + + /// Resume the continuation by returning a value. + /// - Parameter value: The value to return from the continuation. + /// - Returns: Returns `true`, if there was another task waiting to access the continuation and it was resumed. + @discardableResult + public func resume(returning value: sending Value) -> Bool { + resume(with: .success(value)) + } + + /// Resume the continuation by throwing an error. + /// - Parameter error: The error that is thrown from the continuation. + /// - Returns: Returns `true`, if there was another task waiting to access the continuation and it was resumed. + @discardableResult + public func resume(throwing error: E) -> Bool { + resume(with: .failure(error)) + } +} + + +extension ManagedAsynchronousAccess where Value == Void { + /// Resume the continuation. + /// - Returns: Returns `true`, if there was another task waiting to access the continuation and it was resumed. + @discardableResult + public func resume() -> Bool { + self.resume(returning: ()) + } +} + + +extension ManagedAsynchronousAccess where E == Error { + /// Perform an managed, asynchronous access. + /// + /// Call this method to perform an managed, asynchronous access. This method awaits exclusive access, creates a continuation and + /// calls the provided closure and then suspends until ``resume(with:)`` is called. + /// + /// - Parameters: + /// - isolation: Inherits actor isolation from the call site. + /// - action: The action that is executed inside the continuation closure that triggers an asynchronous operation. + /// - Returns: The value from the continuation. + public func perform( + isolation: isolated (any Actor)? = #isolation, + action: () -> Void + ) async throws -> Value { + try await access.waitCheckingCancellation() + + let state = CallSiteState() + + defer { + if state.wasCancelled { + withUnsafeCurrentTask { task in + task?.cancel() + } + } + } + + return try await withCheckedThrowingContinuation { continuation in + assert(self.continuation == nil, "continuation was unexpectedly not nil") + self.continuation = continuation + assert(self.associatedState == nil, "associatedState was unexpectedly not nil") + self.associatedState = state + action() + } + } + + /// Cancel all ongoing accesses. + /// + /// Calling this methods will cancel all tasks that currently await exclusive access and will resume the continuation by throwing a + /// cancellation error. + /// - Parameter error: A custom error that is thrown instead of the cancellation error. + public func cancelAll(error: E? = nil) { + markCancelled() + if let continuation { + self.continuation = nil + continuation.resume(throwing: error ?? CancellationError()) + } + access.cancelAll() + } +} + + +extension ManagedAsynchronousAccess where E == Never { + /// Perform an managed, asynchronous access. + /// + /// Call this method to perform an managed, asynchronous access. This method awaits exclusive access, creates a continuation and + /// calls the provided closure and then suspends until ``resume(with:)`` is called. + /// + /// - Parameters: + /// - isolation: Inherits actor isolation from the call site. + /// - action: The action that is executed inside the continuation closure that triggers an asynchronous operation. + public func perform( + isolation: isolated (any Actor)? = #isolation, + action: () -> Void + ) async throws(CancellationError) -> Value { + try await access.waitCheckingCancellation() + + let state = CallSiteState() + + let value = await withCheckedContinuation { continuation in + assert(self.continuation == nil, "continuation was unexpectedly not nil") + self.continuation = continuation + assert(self.associatedState == nil, "associatedState was unexpectedly not nil") + self.associatedState = state + action() + } + + if state.wasCancelled { + withUnsafeCurrentTask { task in + task?.cancel() + } + throw CancellationError() + } + + return value + } +} + + +extension ManagedAsynchronousAccess where Value == Void, E == Never { + /// Cancel all ongoing accesses. + /// + /// Calling this methods will cancel all tasks that currently await exclusive access. + /// The continuation will be resumed. Make sure to propagate cancellation information yourself. + public func cancelAll() { + markCancelled() + if let continuation { + self.continuation = nil + continuation.resume() + } + access.cancelAll() + } +} diff --git a/Sources/SpeziFoundation/Concurrency/RWLock.swift b/Sources/SpeziFoundation/Concurrency/RWLock.swift new file mode 100644 index 0000000..c96f492 --- /dev/null +++ b/Sources/SpeziFoundation/Concurrency/RWLock.swift @@ -0,0 +1,110 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + + +import Foundation + + +protocol PThreadReadWriteLock: AnyObject { + // We need the unsafe mutable pointer, as otherwise we need to pass the property as inout parameter which isn't thread safe. + var rwLock: UnsafeMutablePointer { get } +} + + +/// Read-Write Lock using `pthread_rwlock`. +/// +/// Looking at [Benchmarking Swift Locking APIs](https://www.vadimbulavin.com/benchmarking-locking-apis), using `pthread_rwlock` +/// is favorable over using dispatch queues. +/// +/// - Note: Refer to ``RecursiveRWLock`` if you need a recursive read-write lock. +public final class RWLock: PThreadReadWriteLock, @unchecked Sendable { + let rwLock: UnsafeMutablePointer + + /// Create a new read-write lock. + public init() { + rwLock = Self.pthreadInit() + } + + /// Call `body` with a reading lock. + /// + /// - Parameter body: A function that reads a value while locked. + /// - Returns: The value returned from the given function. + public func withReadLock(body: () throws -> T) rethrows -> T { + pthreadWriteLock() + defer { + pthreadUnlock() + } + return try body() + } + + /// Call `body` with a writing lock. + /// + /// - Parameter body: A function that writes a value while locked, then returns some value. + /// - Returns: The value returned from the given function. + public func withWriteLock(body: () throws -> T) rethrows -> T { + pthreadWriteLock() + defer { + pthreadUnlock() + } + return try body() + } + + /// Determine if the lock is currently write locked. + /// - Returns: Returns `true` if the lock is currently write locked. + public func isWriteLocked() -> Bool { + let status = pthread_rwlock_trywrlock(rwLock) + + // see status description https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/pthread_rwlock_trywrlock.3.html + switch status { + case 0: + pthreadUnlock() + return false + case EBUSY: // The calling thread is not able to acquire the lock without blocking. + return false // means we aren't locked + case EDEADLK: // The calling thread already owns the read/write lock (for reading or writing). + return true + default: + preconditionFailure("Unexpected status from pthread_rwlock_tryrdlock: \(status)") + } + } + + deinit { + pthreadDeinit() + } +} + + +extension PThreadReadWriteLock { + static func pthreadInit() -> UnsafeMutablePointer { + let lock: UnsafeMutablePointer = .allocate(capacity: 1) + let status = pthread_rwlock_init(lock, nil) + precondition(status == 0, "pthread_rwlock_init failed with status \(status)") + return lock + } + + func pthreadWriteLock() { + let status = pthread_rwlock_wrlock(rwLock) + assert(status == 0, "pthread_rwlock_wrlock failed with statusĀ \(status)") + } + + func pthreadReadLock() { + let status = pthread_rwlock_rdlock(rwLock) + assert(status == 0, "pthread_rwlock_rdlock failed with status \(status)") + } + + func pthreadUnlock() { + let status = pthread_rwlock_unlock(rwLock) + assert(status == 0, "pthread_rwlock_unlock failed with status \(status)") + } + + func pthreadDeinit() { + let status = pthread_rwlock_destroy(rwLock) + assert(status == 0) + rwLock.deallocate() + } +} diff --git a/Sources/SpeziFoundation/Concurrency/RecursiveRWLock.swift b/Sources/SpeziFoundation/Concurrency/RecursiveRWLock.swift new file mode 100644 index 0000000..b904f10 --- /dev/null +++ b/Sources/SpeziFoundation/Concurrency/RecursiveRWLock.swift @@ -0,0 +1,114 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + + +import Atomics +import Foundation + + +/// Recursive Read-Write Lock using `pthread_rwlock`. +/// +/// This is a recursive version of the ``RWLock`` implementation. +public final class RecursiveRWLock: PThreadReadWriteLock, @unchecked Sendable { + let rwLock: UnsafeMutablePointer + + private let writerThread = ManagedAtomic(nil) + private var writerCount = 0 + private var readerCount = 0 + + /// Create a new recursive read-write lock. + public init() { + rwLock = Self.pthreadInit() + } + + + private func writeLock() { + let selfThread = pthread_self() + + if let writer = writerThread.load(ordering: .relaxed), + pthread_equal(writer, selfThread) != 0 { + // we know that the writerThread is us, so access to `writerCount` is synchronized (its us that holds the rwLock). + writerCount += 1 + assert(writerCount > 1, "Synchronization issue. Writer count is unexpectedly low: \(writerCount)") + return + } + + pthreadWriteLock() + + writerThread.store(selfThread, ordering: .relaxed) + writerCount = 1 + } + + private func writeUnlock() { + // we assume this is called while holding the write lock, so access to `writerCount` is safe + if writerCount > 1 { + writerCount -= 1 + return + } + + // otherwise it is the last unlock + writerThread.store(nil, ordering: .relaxed) + writerCount = 0 + + pthreadUnlock() + } + + private func readLock() { + let selfThread = pthread_self() + + if let writer = writerThread.load(ordering: .relaxed), + pthread_equal(writer, selfThread) != 0 { + // we know that the writerThread is us, so access to `readerCount` is synchronized (its us that holds the rwLock). + readerCount += 1 + assert(readerCount > 0, "Synchronization issue. Reader count is unexpectedly low: \(readerCount)") + return + } + + pthreadReadLock() + } + + private func readUnlock() { + // we assume this is called while holding the reader lock, so access to `readerCount` is safe + if readerCount > 0 { + // fine to go down to zero (we still hold the lock in write mode) + readerCount -= 1 + return + } + + pthreadUnlock() + } + + + /// Call `body` with a writing lock. + /// + /// - Parameter body: A function that writes a value while locked, then returns some value. + /// - Returns: The value returned from the given function. + public func withWriteLock(body: () throws -> T) rethrows -> T { + writeLock() + defer { + writeUnlock() + } + return try body() + } + + /// Call `body` with a reading lock. + /// + /// - Parameter body: A function that reads a value while locked. + /// - Returns: The value returned from the given function. + public func withReadLock(body: () throws -> T) rethrows -> T { + readLock() + defer { + readUnlock() + } + return try body() + } + + deinit { + pthreadDeinit() + } +} diff --git a/Sources/SpeziFoundation/Misc/DataDescriptor.swift b/Sources/SpeziFoundation/Misc/DataDescriptor.swift new file mode 100644 index 0000000..919dc56 --- /dev/null +++ b/Sources/SpeziFoundation/Misc/DataDescriptor.swift @@ -0,0 +1,112 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + +import Foundation + + +/// A matching descriptor for a `Data`-based field. +/// +/// To match against `Data`, you provide the ``data`` you want to match to, with an additional ``mask`` that defines which bits should be considered for the check. +public struct DataDescriptor { + /// The data to match against. + public let data: Data + /// The mask that + public let mask: Data + + /// Create a new data descriptor. + /// - Parameters: + /// - data: The data. + /// - mask: The mask. + public init(data: Data, mask: Data) { + self.data = data + self.mask = mask + precondition(mask.count == data.count, "The data mask must the data size. Mask length \(mask.count), data length \(data.count).") + } + + /// Create a new data descriptor with a mask that matches all bits. + /// - Parameter data: The data. + public init(data: Data) { + let mask = Data(repeating: 0xFF, count: data.count) + self.init(data: data, mask: mask) + } + + private static func bitwiseAnd(lhs: Data, rhs: Data) -> Data { + if rhs.count > lhs.count { + return bitwiseAnd(lhs: rhs, rhs: lhs) + } + + var value = lhs + + for index in rhs.indices { + value[index] = lhs[index] & rhs[index] + } + + return value + } + + /// Determine if the data descriptor matches the provided Data value. + /// - Parameter value: The data value to check if it matches the descriptor. + /// - Returns: Return `true` if the bits as defined by ``mask`` if `value` and ``data`` are equal. + public func matches(_ value: Data) -> Bool { + let valueMasked = Self.bitwiseAnd(lhs: value, rhs: mask) + let dataMasked = Self.bitwiseAnd(lhs: data, rhs: mask) + + return Self.equalBitPattern(lhs: valueMasked, rhs: dataMasked) + } +} + + +extension DataDescriptor: Sendable, Hashable { + /// Determine if two data blobs expose the same bit pattern (e.g., additional zero bytes do not matter) + /// - Parameters: + /// - lhs: The left-hand-side. + /// - rhs: The right-hand-side. + /// - Returns: Returns `true` if the bit pattern matches. + static func equalBitPattern(lhs: Data, rhs: Data) -> Bool { + if rhs.count > lhs.count { + return Self.equalBitPattern(lhs: rhs, rhs: lhs) + } + + if lhs.count > rhs.count { + guard lhs[rhs.endIndex...].allSatisfy({ $0 == 0 }) else { + return false + } + } + + for index in rhs.indices { + guard lhs[index] == rhs[index] else { + return false + } + } + + return true + } + + public static func == (lhs: DataDescriptor, rhs: DataDescriptor) -> Bool { + Self.equalBitPattern(lhs: lhs.mask, rhs: rhs.mask) + && Self.equalBitPattern( + lhs: Self.bitwiseAnd(lhs: lhs.data, rhs: lhs.mask), + rhs: Self.bitwiseAnd(lhs: rhs.data, rhs: rhs.mask) + ) + } +} + + +extension DataDescriptor: CustomStringConvertible, CustomDebugStringConvertible { + public var description: String { + """ + DataDescriptor(\ + data: \(data.map { String(format: "%02hhx", $0) }.joined()), \ + mask: \(mask.map { String(format: "%02hhx", $0) }.joined()) + """ + } + + public var debugDescription: String { + description + } +} diff --git a/Sources/SpeziFoundation/Misc/TimeoutError.swift b/Sources/SpeziFoundation/Misc/TimeoutError.swift index 8368f00..e82c036 100644 --- a/Sources/SpeziFoundation/Misc/TimeoutError.swift +++ b/Sources/SpeziFoundation/Misc/TimeoutError.swift @@ -98,7 +98,7 @@ extension TimeoutError: LocalizedError { /// - timeout: The duration of the timeout. /// - action: The action to run once the timeout passed. @inlinable -public func withTimeout(of timeout: Duration, perform action: @Sendable () async -> Void) async { +public func withTimeout(of timeout: Duration, perform action: sending () async -> Void) async { try? await Task.sleep(for: timeout) guard !Task.isCancelled else { return diff --git a/Sources/SpeziFoundation/SpeziFoundation.docc/SpeziFoundation.md b/Sources/SpeziFoundation/SpeziFoundation.docc/SpeziFoundation.md index 3d26433..430fb60 100644 --- a/Sources/SpeziFoundation/SpeziFoundation.docc/SpeziFoundation.md +++ b/Sources/SpeziFoundation/SpeziFoundation.docc/SpeziFoundation.md @@ -23,14 +23,20 @@ Spezi Foundation provides a base layer of functionality useful in many applicati - ``AnyArray`` - ``AnyOptional`` -### Semaphore +### Concurrency +- ``RWLock`` +- ``RecursiveRWLock`` - ``AsyncSemaphore`` +- ``ManagedAsynchronousAccess`` ### Encoders and Decoders - ``TopLevelEncoder`` - ``TopLevelDecoder`` +### Data +- ``DataDescriptor`` + ### Timeout - ``TimeoutError`` diff --git a/Tests/SpeziFoundationTests/AsyncSemaphoreTests.swift b/Tests/SpeziFoundationTests/AsyncSemaphoreTests.swift index 0524554..c909161 100644 --- a/Tests/SpeziFoundationTests/AsyncSemaphoreTests.swift +++ b/Tests/SpeziFoundationTests/AsyncSemaphoreTests.swift @@ -261,7 +261,7 @@ final class AsyncSemaphoreTests: XCTestCase { // swiftlint:disable:this type_b /// its `run()` method, and counts the effective number of /// concurrent executions for testing purpose. @MainActor - class Runner { + final class Runner { private let semaphore: AsyncSemaphore private var count = 0 private(set) var effectiveMaxConcurrentRuns = 0 diff --git a/Tests/SpeziFoundationTests/DataDescriptorTests.swift b/Tests/SpeziFoundationTests/DataDescriptorTests.swift new file mode 100644 index 0000000..ecf5cc5 --- /dev/null +++ b/Tests/SpeziFoundationTests/DataDescriptorTests.swift @@ -0,0 +1,123 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + +@testable import SpeziFoundation +import XCTest + + +final class DataDescriptorTests: XCTestCase { + func testEqualBitPattern() { + let data1 = Data([0xFF, 0x00, 0xAA]) + let data2 = Data([0xFF, 0x00, 0xAA]) + let data3 = Data([0xFF, 0x00]) + let data4 = Data([0xFF, 0x00, 0xAA, 0x00]) + + XCTAssertTrue(DataDescriptor.equalBitPattern(lhs: data1, rhs: data2), "Identical data should be equal") + XCTAssertFalse(DataDescriptor.equalBitPattern(lhs: data1, rhs: data3), "Different length data should not be equal") + XCTAssertTrue(DataDescriptor.equalBitPattern(lhs: data1, rhs: data4), "Additional zero bytes in rhs should be ignored") + } + + func testDataDescriptorEquality_sameDataAndMask() { + let data1 = Data([0xFF, 0x00, 0xAA]) + let mask1 = Data([0xFF, 0xFF, 0xFF]) + + let descriptor1 = DataDescriptor(data: data1, mask: mask1) + let descriptor2 = DataDescriptor(data: data1, mask: mask1) + + XCTAssertEqual(descriptor1, descriptor2, "Data descriptors with the same data and mask should be equal") + } + + func testDataDescriptorEquality_differentDataSameMask() { + let data1 = Data([0xFF, 0x00, 0xAA]) + let data2 = Data([0xFF, 0x00, 0xBB]) + let mask1 = Data([0xFF, 0xFF, 0xFF]) + + let descriptor1 = DataDescriptor(data: data1, mask: mask1) + let descriptor2 = DataDescriptor(data: data2, mask: mask1) + + XCTAssertNotEqual(descriptor1, descriptor2, "Data descriptors with different data but the same mask should not be equal") + } + + func testDataDescriptorEquality_sameDataDifferentMask() { + let data1 = Data([0xFF, 0x00, 0xAA]) + let mask1 = Data([0xFF, 0xFF, 0xFF]) + let mask2 = Data([0xFF, 0x00, 0xFF]) + + let descriptor1 = DataDescriptor(data: data1, mask: mask1) + let descriptor2 = DataDescriptor(data: data1, mask: mask2) + + XCTAssertNotEqual(descriptor1, descriptor2, "Data descriptors with the same data but different masks should not be equal") + } + + func testDataDescriptorEquality_bitwiseAndComparison() { + let data1 = Data([0xFF, 0x00, 0xAA]) + let data2 = Data([0xFF, 0x00, 0xAB]) + let mask1 = Data([0xFF, 0xFF, 0xFE]) + + let descriptor1 = DataDescriptor(data: data1, mask: mask1) + let descriptor2 = DataDescriptor(data: data2, mask: mask1) + + XCTAssertEqual(descriptor1, descriptor2, "Data descriptors with different data but the same masked result should be equal") + } + + func testMatches_sameDataAndMask() { + let data = Data([0xFF, 0x00, 0xAA]) + let mask = Data([0xFF, 0xFF, 0xFF]) + let descriptor = DataDescriptor(data: data, mask: mask) + + XCTAssertTrue(descriptor.matches(data), "Data should match exactly when both the data and the mask are identical.") + } + + func testMatches_differentDataSameMask() { + let data = Data([0xFF, 0x00, 0xAA]) + let otherData = Data([0xFF, 0x00, 0xAB]) + let mask = Data([0xFF, 0xFF, 0xFF]) + let descriptor = DataDescriptor(data: data, mask: mask) + + XCTAssertFalse(descriptor.matches(otherData), "Data should not match when the data differs and the mask is fully applied.") + } + + func testMatches_maskedDataMatches() { + let data = Data([0xFF, 0x00, 0xAA]) + let otherData = Data([0xFF, 0x00, 0xAB]) + let mask = Data([0xFF, 0xFF, 0xFE]) // Ignore the last bit of the last byte + let descriptor = DataDescriptor(data: data, mask: mask) + + XCTAssertTrue(descriptor.matches(otherData), "Data should match if the masked bits are equal.") + } + + func testMatches_dataShorterThanDescriptor() { + let data = Data([0xFF, 0x00, 0xAA]) + let shorterData = Data([0xFF, 0x00]) + let mask = Data([0xFF, 0xFF, 0xFF]) + let descriptor = DataDescriptor(data: data, mask: mask) + + XCTAssertFalse(descriptor.matches(shorterData), "Data shorter than the descriptor should not match.") + } + + func testMatches_dataLongerThanDescriptor() { + let data = Data([0xFF, 0x00, 0xAA]) + let longerData = Data([0xFF, 0x00, 0xAA, 0x00]) + let mask = Data([0xFF, 0xFF, 0xFF]) + let descriptor = DataDescriptor(data: data, mask: mask) + + XCTAssertTrue(descriptor.matches(longerData), "Data longer than the descriptor should match as long as the relevant bits match.") + } + + func testMatches_differentMasksSameData() { + let data = Data([0xFF, 0x00, 0xAA]) + let maskedData = Data([0xFF, 0x00, 0xAB]) + let mask1 = Data([0xFF, 0xFF, 0xFE]) // Ignore the last bit of the last byte + let mask2 = Data([0xFF, 0xFF, 0xFF]) // Fully consider all bits + let descriptor1 = DataDescriptor(data: data, mask: mask1) + let descriptor2 = DataDescriptor(data: data, mask: mask2) + + XCTAssertTrue(descriptor1.matches(maskedData), "Descriptor with mask that ignores last bit should match data with last bit difference.") + XCTAssertFalse(descriptor2.matches(maskedData), "Descriptor with fully applied mask should not match data with last bit difference.") + } +} diff --git a/Tests/SpeziFoundationTests/ManagedAsynchronousAccessTests.swift b/Tests/SpeziFoundationTests/ManagedAsynchronousAccessTests.swift new file mode 100644 index 0000000..7bafe23 --- /dev/null +++ b/Tests/SpeziFoundationTests/ManagedAsynchronousAccessTests.swift @@ -0,0 +1,265 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + +import SpeziFoundation +import XCTest + + +final class ManagedAsynchronousAccessTests: XCTestCase { + @MainActor + func testResumeWithSuccess() async throws { + let access = ManagedAsynchronousAccess() + let expectedValue = "Success" + + let expectation = XCTestExpectation(description: "task") + + Task { + do { + let value = try await access.perform { + // this is were you would trigger your operation + } + XCTAssertEqual(value, expectedValue) + } catch { + XCTFail("Unexpected error: \(error)") + } + expectation.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + let didResume = access.resume(returning: expectedValue) + + XCTAssertFalse(didResume) + XCTAssertFalse(access.ongoingAccess) + + await fulfillment(of: [expectation]) + } + + @MainActor + func testResumeWithError() async throws { + let access = ManagedAsynchronousAccess() + + let expectation = XCTestExpectation(description: "task") + Task { + do { + _ = try await access.perform {} + XCTFail("Expected error, but got success.") + } catch { + XCTAssertTrue(error is TimeoutError) + } + expectation.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + // throw some error + let didResume = access.resume(throwing: TimeoutError()) + + XCTAssertFalse(didResume) + XCTAssertFalse(access.ongoingAccess) + + await fulfillment(of: [expectation], timeout: 2.0) + } + + @MainActor + func testCancelAll() async throws { + let access = ManagedAsynchronousAccess() + + let expectation = XCTestExpectation(description: "task") + let handle = Task { + do { + _ = try await access.perform {} + XCTFail("Expected cancellation error.") + } catch { + XCTAssertTrue(error is CancellationError) + } + expectation.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + XCTAssertFalse(handle.isCancelled) + + access.cancelAll() + + XCTAssertFalse(access.ongoingAccess) + await fulfillment(of: [expectation]) + + XCTAssert(handle.isCancelled, "Task was not marked as cancelled.") + } + + @MainActor + func testCancelAllNeverError() async throws { + let access = ManagedAsynchronousAccess() + + let expectation = XCTestExpectation(description: "task") + let handle = Task { + do { + try await access.perform {} + XCTFail("Expected cancellation to turn into a cancellation error") + } catch { + XCTAssertTrue(error is CancellationError) + } + expectation.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + XCTAssertFalse(handle.isCancelled) + + access.cancelAll() + + XCTAssertFalse(access.ongoingAccess) + await fulfillment(of: [expectation]) + XCTAssert(handle.isCancelled, "Task was not marked as cancelled.") + } + + func testResumeWithoutOngoingAccess() { + let access = ManagedAsynchronousAccess() + + let didResume = access.resume(returning: "No Access") + + XCTAssertFalse(didResume) + } + + @MainActor + func testResumeWithVoidValue() async throws { + let access = ManagedAsynchronousAccess() + + let expectation = XCTestExpectation(description: "task") + Task { + try await access.perform {} + expectation.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + let didResume = access.resume() + + XCTAssertFalse(didResume) + XCTAssertFalse(access.ongoingAccess) + await fulfillment(of: [expectation]) + } + + + @MainActor + func testExclusiveAccess() async throws { + let access = ManagedAsynchronousAccess() + let expectedValue0 = "Success0" + let expectedValue1 = "Success1" + + let expectation0 = XCTestExpectation(description: "task0") + let expectation1 = XCTestExpectation(description: "task1") + + Task { + do { + let value = try await access.perform {} + XCTAssertEqual(value, expectedValue0) + } catch { + XCTFail("Unexpected error: \(error)") + } + expectation0.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + Task { + do { + let value = try await access.perform {} + XCTAssertEqual(value, expectedValue1) + } catch { + XCTFail("Unexpected error: \(error)") + } + expectation1.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + let didResume0 = access.resume(returning: expectedValue0) + + XCTAssertTrue(didResume0) + XCTAssertFalse(access.ongoingAccess) + + await fulfillment(of: [expectation0]) + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + let didResume1 = access.resume(returning: expectedValue1) + + XCTAssertFalse(didResume1) + XCTAssertFalse(access.ongoingAccess) + + await fulfillment(of: [expectation1]) + } + + @MainActor + func testExclusiveAccessNeverError() async throws { + let access = ManagedAsynchronousAccess() + let expectedValue0 = "Success0" + let expectedValue1 = "Success1" + + let expectation0 = XCTestExpectation(description: "task0") + let expectation1 = XCTestExpectation(description: "task1") + + Task { + do { + let value = try await access.perform {} + XCTAssertEqual(value, expectedValue0) + } catch is CancellationError { + XCTFail("Unexpected error cancellation") + } + expectation0.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + Task { + do { + let value = try await access.perform {} + XCTAssertEqual(value, expectedValue1) + } catch is CancellationError { + XCTFail("Unexpected error cancellation") + } + expectation1.fulfill() + } + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + let didResume0 = access.resume(returning: expectedValue0) + + XCTAssertTrue(didResume0) + XCTAssertFalse(access.ongoingAccess) + + await fulfillment(of: [expectation0]) + + try await Task.sleep(for: .milliseconds(100)) + + XCTAssertTrue(access.ongoingAccess) + + let didResume1 = access.resume(returning: expectedValue1) + + XCTAssertFalse(didResume1) + XCTAssertFalse(access.ongoingAccess) + + await fulfillment(of: [expectation1]) + } +} diff --git a/Tests/SpeziFoundationTests/RWLockTests.swift b/Tests/SpeziFoundationTests/RWLockTests.swift new file mode 100644 index 0000000..ddbbc30 --- /dev/null +++ b/Tests/SpeziFoundationTests/RWLockTests.swift @@ -0,0 +1,256 @@ +// +// This source file is part of the Stanford Spezi open-source project +// +// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) +// +// SPDX-License-Identifier: MIT +// + +import SpeziFoundation +import XCTest + + +final class RWLockTests: XCTestCase { + func testConcurrentReads() { + let lock = RWLock() + let expectation1 = self.expectation(description: "First read") + let expectation2 = self.expectation(description: "Second read") + + Task.detached { + lock.withReadLock { + usleep(100_000) // Simulate read delay (200ms) + expectation1.fulfill() + } + } + + Task.detached { + lock.withReadLock { + usleep(100_000) // Simulate read delay (200ms) + expectation2.fulfill() + } + } + + wait(for: [expectation1, expectation2], timeout: 1.0) + } + + func testWriteBlocksOtherWrites() { + let lock = RWLock() + let expectation1 = self.expectation(description: "First write") + let expectation2 = self.expectation(description: "Second write") + + Task.detached { + lock.withWriteLock { + usleep(200_000) // Simulate write delay (200ms) + expectation1.fulfill() + } + } + + Task.detached { + try await Task.sleep(for: .milliseconds(100)) + lock.withWriteLock { + expectation2.fulfill() + } + } + + wait(for: [expectation1, expectation2], timeout: 1.0) + } + + func testWriteBlocksReads() { + let lock = RWLock() + let expectation1 = self.expectation(description: "Write") + let expectation2 = self.expectation(description: "Read") + + Task.detached { + lock.withWriteLock { + usleep(200_000) // Simulate write delay (200ms) + expectation1.fulfill() + } + } + + Task.detached { + try await Task.sleep(for: .milliseconds(100)) + lock.withReadLock { + expectation2.fulfill() + } + } + + wait(for: [expectation1, expectation2], timeout: 1.0) + } + + func testIsWriteLocked() { + let lock = RWLock() + + Task.detached { + lock.withWriteLock { + XCTAssertTrue(lock.isWriteLocked()) + usleep(100_000) // Simulate write delay (100ms) + } + } + + usleep(50_000) // Give the other thread time to lock (50ms) + XCTAssertFalse(lock.isWriteLocked()) + } + + func testMultipleLocksAcquired() { + let lock1 = RWLock() + let lock2 = RWLock() + let expectation1 = self.expectation(description: "Read") + + Task.detached { + lock1.withReadLock { + lock2.withReadLock { + expectation1.fulfill() + } + } + } + + wait(for: [expectation1], timeout: 1.0) + } + + + func testConcurrentReadsRecursive() { + let lock = RecursiveRWLock() + let expectation1 = self.expectation(description: "First read") + let expectation2 = self.expectation(description: "Second read") + + Task.detached { + lock.withReadLock { + usleep(100_000) // Simulate read delay 100 ms + expectation1.fulfill() + } + } + + Task.detached { + lock.withReadLock { + usleep(100_000) // Simulate read delay 100ms + expectation2.fulfill() + } + } + + wait(for: [expectation1, expectation2], timeout: 1.0) + } + + func testWriteBlocksOtherWritesRecursive() { + let lock = RecursiveRWLock() + let expectation1 = self.expectation(description: "First write") + let expectation2 = self.expectation(description: "Second write") + + Task.detached { + lock.withWriteLock { + usleep(200_000) // Simulate write delay 200ms + expectation1.fulfill() + } + } + + Task.detached { + try await Task.sleep(for: .milliseconds(100)) + lock.withWriteLock { + expectation2.fulfill() + } + } + + wait(for: [expectation1, expectation2], timeout: 1.0) + } + + func testWriteBlocksReadsRecursive() { + let lock = RecursiveRWLock() + let expectation1 = self.expectation(description: "Write") + let expectation2 = self.expectation(description: "Read") + + Task.detached { + lock.withWriteLock { + usleep(200_000) // Simulate write delay 200 ms + expectation1.fulfill() + } + } + + Task.detached { + try await Task.sleep(for: .milliseconds(100)) + lock.withReadLock { + expectation2.fulfill() + } + } + + wait(for: [expectation1, expectation2], timeout: 1.0) + } + + func testMultipleLocksAcquiredRecursive() { + let lock1 = RecursiveRWLock() + let lock2 = RecursiveRWLock() + let expectation1 = self.expectation(description: "Read") + + Task.detached { + lock1.withReadLock { + lock2.withReadLock { + expectation1.fulfill() + } + } + } + + wait(for: [expectation1], timeout: 1.0) + } + + func testRecursiveReadReadAcquisition() { + let lock = RecursiveRWLock() + let expectation1 = self.expectation(description: "Read") + + Task.detached { + lock.withReadLock { + lock.withReadLock { + expectation1.fulfill() + } + } + } + + wait(for: [expectation1], timeout: 1.0) + } + + func testRecursiveWriteRecursiveAcquisition() { + let lock = RecursiveRWLock() + let expectation1 = self.expectation(description: "Read") + let expectation2 = self.expectation(description: "ReadWrite") + let expectation3 = self.expectation(description: "WriteRead") + let expectation4 = self.expectation(description: "Write") + + let expectation5 = self.expectation(description: "Race") + + Task.detached { + lock.withWriteLock { + usleep(50_000) // Simulate write delay 50 ms + lock.withReadLock { + expectation1.fulfill() + usleep(200_000) // Simulate write delay 200 ms + lock.withWriteLock { + expectation2.fulfill() + } + } + + lock.withWriteLock { + usleep(200_000) // Simulate write delay 200 ms + lock.withReadLock { + expectation3.fulfill() + } + expectation4.fulfill() + } + } + } + + Task.detached { + await withDiscardingTaskGroup { group in + for _ in 0..<10 { + group.addTask { + // random sleep up to 50 ms + try? await Task.sleep(nanoseconds: UInt64.random(in: 0...50_000_000)) + lock.withWriteLock { + _ = usleep(100) + } + } + } + } + + expectation5.fulfill() + } + + wait(for: [expectation1, expectation2, expectation3, expectation4, expectation5], timeout: 20.0) + } +} diff --git a/Tests/SpeziFoundationTests/TimeoutTests.swift b/Tests/SpeziFoundationTests/TimeoutTests.swift index ce516b8..933da6a 100644 --- a/Tests/SpeziFoundationTests/TimeoutTests.swift +++ b/Tests/SpeziFoundationTests/TimeoutTests.swift @@ -12,31 +12,38 @@ import XCTest final class TimeoutTests: XCTestCase { - @MainActor private var continuation: CheckedContinuation? + @MainActor + private final class Storage { + var continuation: CheckedContinuation? + } + + private let storage = Storage() @MainActor func operation(for duration: Duration) { Task { @MainActor in try? await Task.sleep(for: duration) - if let continuation = self.continuation { + if let continuation = storage.continuation { continuation.resume() - self.continuation = nil + storage.continuation = nil } } } @MainActor func operationMethod(timeout: Duration, operation: Duration, timeoutExpectation: XCTestExpectation) async throws { - async let _ = withTimeout(of: timeout) { @MainActor in + let storage = storage + async let _ = withTimeout(of: timeout) { @MainActor [storage] in + XCTAssertFalse(Task.isCancelled) timeoutExpectation.fulfill() - if let continuation { + if let continuation = storage.continuation { + storage.continuation = nil continuation.resume(throwing: TimeoutError()) - self.continuation = nil } } try await withCheckedThrowingContinuation { continuation in - self.continuation = continuation + storage.continuation = continuation self.operation(for: operation) } }