From be312f1b8378b25da32dca1f36c664eacb21b851 Mon Sep 17 00:00:00 2001 From: Alsey Coleman Miller Date: Fri, 28 Apr 2023 01:43:50 -0700 Subject: [PATCH] Fixed `GATTCentral` --- Sources/GATT/GATTCentral.swift | 58 +++++++++++++++---------- Sources/GATT/GATTClientConnection.swift | 23 ++-------- Tests/GATTTests/GATTTests.swift | 6 +-- Tests/GATTTests/TestL2CAPSocket.swift | 15 ++++--- 4 files changed, 50 insertions(+), 52 deletions(-) diff --git a/Sources/GATT/GATTCentral.swift b/Sources/GATT/GATTCentral.swift index 509655b..367c8ad 100644 --- a/Sources/GATT/GATTCentral.swift +++ b/Sources/GATT/GATTCentral.swift @@ -88,11 +88,11 @@ public final class GATTCentral , HCILEAdvertisingReport.Report)]() + var connections = [Peripheral: (connection: GATTClientConnection, socket: Socket)](minimumCapacity: 2) + func found(_ report: HCILEAdvertisingReport.Report) -> ScanData { let peripheral = Peripheral(id: report.address) let scanData = ScanData( @@ -258,8 +274,6 @@ internal extension GATTCentral { return scanData } - var address: BluetoothAddress? - func readAddress(_ hostController: HostController) async throws -> BluetoothAddress { if let cachedAddress = self.address { return cachedAddress @@ -270,17 +284,15 @@ internal extension GATTCentral { } } - var connections = [Peripheral: GATTClientConnection](minimumCapacity: 2) - - func didConnect(_ connection: GATTClientConnection) { - self.connections[connection.peripheral] = connection + func didConnect(_ connection: GATTClientConnection, _ socket: Socket) { + self.connections[connection.peripheral] = (connection, socket) } - func removeConnection(_ peripheral: Peripheral) { + func removeConnection(_ peripheral: Peripheral) async { self.connections[peripheral] = nil } - func removeAllConnections() { + func removeAllConnections() async { self.connections.removeAll(keepingCapacity: true) } } diff --git a/Sources/GATT/GATTClientConnection.swift b/Sources/GATT/GATTClientConnection.swift index edb44c0..8000530 100644 --- a/Sources/GATT/GATTClientConnection.swift +++ b/Sources/GATT/GATTClientConnection.swift @@ -17,12 +17,10 @@ internal final class GATTClientConnection { let peripheral: Peripheral - private weak var delegate: GATTClientConnectionDelegate? - let client: GATTClient private let cache = Cache() - + var maximumUpdateValueLength: Int { get async { // ATT_MTU-3 @@ -36,17 +34,13 @@ internal final class GATTClientConnection { peripheral: Peripheral, socket: Socket, maximumTransmissionUnit: ATTMaximumTransmissionUnit, - delegate: GATTClientConnectionDelegate + log: ((String) -> ())? = nil ) async { self.peripheral = peripheral self.client = await GATTClient( socket: socket, maximumTransmissionUnit: maximumTransmissionUnit, - log: { [weak delegate] message in - delegate?.connection(peripheral, log: message) - }, didDisconnect: { [weak delegate] error in - await delegate?.connection(peripheral, didDisconnect: error) - } + log: log ) } @@ -221,21 +215,10 @@ internal final class GATTClientConnection { descriptors: descriptors ) } - - private func log(_ message: String) { - delegate?.connection(peripheral, log: message) - } } // MARK: - Supporting Types -internal protocol GATTClientConnectionDelegate: AnyObject { - - func connection(_ peripheral: Peripheral, log: String) - - func connection(_ peripheral: Peripheral, didDisconnect error: Swift.Error?) async -} - internal extension GATTClientConnection { typealias Cache = GATTClientConnectionCache diff --git a/Tests/GATTTests/GATTTests.swift b/Tests/GATTTests/GATTTests.swift index 640c23b..61deff0 100644 --- a/Tests/GATTTests/GATTTests.swift +++ b/Tests/GATTTests/GATTTests.swift @@ -102,9 +102,9 @@ final class GATTTests: XCTestCase { XCTAssertEqual(services.count, 0) let clientMTU = try await central.maximumTransmissionUnit(for: peripheral) XCTAssertEqual(clientMTU, finalMTU) - let maximumUpdateValueLength = await central.storage.connections.values.first?.maximumUpdateValueLength + let maximumUpdateValueLength = await central.storage.connections.first?.value.connection.maximumUpdateValueLength XCTAssertEqual(maximumUpdateValueLength, Int(finalMTU.rawValue) - 3) - let clientCache = await (central.storage.connections.values.first?.client.connection.socket as? TestL2CAPSocket)?.cache + let clientCache = await (central.storage.connections.values.first?.connection.client.connection.socket as? TestL2CAPSocket)?.cache XCTAssertEqual(clientCache?.prefix(1), mockData.client.prefix(1)) // not same because extra service discovery request } ) @@ -213,7 +213,7 @@ final class GATTTests: XCTestCase { else { XCTFail(); return } XCTAssertEqual(foundService.uuid, .batteryService) XCTAssertEqual(foundService.isPrimary, true) - let clientCache = await (central.storage.connections.values.first?.client.connection.socket as? TestL2CAPSocket)?.cache + let clientCache = await (central.storage.connections.values.first?.connection.client.connection.socket as? TestL2CAPSocket)?.cache XCTAssertEqual(clientCache, mockData.client) } ) diff --git a/Tests/GATTTests/TestL2CAPSocket.swift b/Tests/GATTTests/TestL2CAPSocket.swift index fd9fdd5..7404d51 100644 --- a/Tests/GATTTests/TestL2CAPSocket.swift +++ b/Tests/GATTTests/TestL2CAPSocket.swift @@ -13,7 +13,7 @@ import GATT /// Test L2CAP socket internal actor TestL2CAPSocket: L2CAPSocket { - + private actor Cache { static let shared = Cache() @@ -107,6 +107,10 @@ internal actor TestL2CAPSocket: L2CAPSocket { // MARK: - Methods + func close() async { + + } + func accept() async throws -> TestL2CAPSocket { // sleep until a client socket is created while (await Cache.shared.pendingClients[address] ?? []).isEmpty { @@ -129,7 +133,7 @@ internal actor TestL2CAPSocket: L2CAPSocket { else { throw POSIXError(.ECONNRESET) } await target.receive(data) - eventContinuation.yield(.write(data.count)) + eventContinuation.yield(.didWrite(data.count)) } /// Reads from the socket. @@ -145,18 +149,17 @@ internal actor TestL2CAPSocket: L2CAPSocket { let data = self.receivedData.removeFirst() cache.append(data) - eventContinuation.yield(.read(data.count)) - try await Task.sleep(nanoseconds: 1_000_000) + eventContinuation.yield(.didRead(data.count)) return data } fileprivate func receive(_ data: Data) { receivedData.append(data) print("L2CAP Socket: \(name) recieved \([UInt8](data))") - eventContinuation.yield(.pendingRead) + eventContinuation.yield(.read) } - fileprivate func connect(to socket: TestL2CAPSocket) { + internal func connect(to socket: TestL2CAPSocket) { self.target = socket } }