Skip to content

Commit

Permalink
Fixed GATTCentral
Browse files Browse the repository at this point in the history
  • Loading branch information
colemancda committed Apr 28, 2023
1 parent 6b88f5e commit be312f1
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 52 deletions.
58 changes: 35 additions & 23 deletions Sources/GATT/GATTCentral.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,34 +88,54 @@ public final class GATTCentral <HostController: BluetoothHostControllerInterface
}

public func connect(to peripheral: Peripheral) async throws {
// get scan data (bluetooth address) for new connection
// get scan data (Bluetooth address) for new connection
guard let (scanData, report) = await self.storage.scanData[peripheral]
else { throw CentralError.unknownPeripheral }
// log
self.log?("[\(scanData.peripheral)]: Open connection (\(report.addressType))")
self.log(scanData.peripheral, "Open connection (\(report.addressType))")
// load cache device address
let localAddress = try await storage.readAddress(hostController)
// open socket
let socket = try await Socket.lowEnergyClient(
address: localAddress,
destination: report
)
// keep connection open and store for future use
Task.detached { [weak self] in
for await event in socket.event {
switch event {
case let .error(error):
self?.log(peripheral, error.localizedDescription)
case .close:
self?.log(peripheral, "Did disconnect.")
default:
break
}
}
await self?.storage.removeConnection(peripheral)
}
let connection = await GATTClientConnection(
peripheral: peripheral,
socket: socket,
maximumTransmissionUnit: self.options.maximumTransmissionUnit,
delegate: self
log: { [weak self] in
self?.log(peripheral, $0)
}
)
// store connection
await self.storage.didConnect(connection)
await self.storage.didConnect(connection, socket)
}

public func disconnect(_ peripheral: Peripheral) async {
if let (_, socket) = await storage.connections[peripheral] {
await socket.close()
}
await storage.removeConnection(peripheral)
}

public func disconnectAll() async {
for (_, socket) in await storage.connections.values {
await socket.close()
}
await storage.removeAllConnections()
}

Expand Down Expand Up @@ -217,23 +237,15 @@ public final class GATTCentral <HostController: BluetoothHostControllerInterface
guard await storage.scanData.keys.contains(peripheral)
else { throw CentralError.unknownPeripheral }

guard let connection = await storage.connections[peripheral]
guard let (connection, _) = await storage.connections[peripheral]
else { throw CentralError.disconnected }

return connection
}
}

extension GATTCentral: GATTClientConnectionDelegate {

func connection(_ peripheral: Peripheral, log message: String) {
private func log(_ peripheral: Peripheral, _ message: String) {
log?("[\(peripheral)]: " + message)
}

func connection(_ peripheral: Peripheral, didDisconnect error: Swift.Error?) async {
await storage.removeConnection(peripheral)
log?("[\(peripheral)]: " + "did disconnect \(error?.localizedDescription ?? "")")
}
}

// MARK: - Supporting Types
Expand All @@ -243,8 +255,12 @@ internal extension GATTCentral {

actor Storage {

var address: BluetoothAddress?

var scanData = [Peripheral: (ScanData<Peripheral, Advertisement>, HCILEAdvertisingReport.Report)]()

var connections = [Peripheral: (connection: GATTClientConnection<Socket>, socket: Socket)](minimumCapacity: 2)

func found(_ report: HCILEAdvertisingReport.Report) -> ScanData<Peripheral, Advertisement> {
let peripheral = Peripheral(id: report.address)
let scanData = ScanData(
Expand All @@ -258,8 +274,6 @@ internal extension GATTCentral {
return scanData
}

var address: BluetoothAddress?

func readAddress(_ hostController: HostController) async throws -> BluetoothAddress {
if let cachedAddress = self.address {
return cachedAddress
Expand All @@ -270,17 +284,15 @@ internal extension GATTCentral {
}
}

var connections = [Peripheral: GATTClientConnection<Socket>](minimumCapacity: 2)

func didConnect(_ connection: GATTClientConnection<Socket>) {
self.connections[connection.peripheral] = connection
func didConnect(_ connection: GATTClientConnection<Socket>, _ socket: Socket) {
self.connections[connection.peripheral] = (connection, socket)
}

func removeConnection(_ peripheral: Peripheral) {
func removeConnection(_ peripheral: Peripheral) async {
self.connections[peripheral] = nil
}

func removeAllConnections() {
func removeAllConnections() async {
self.connections.removeAll(keepingCapacity: true)
}
}
Expand Down
23 changes: 3 additions & 20 deletions Sources/GATT/GATTClientConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ internal final class GATTClientConnection <Socket: L2CAPSocket> {

let peripheral: Peripheral

private weak var delegate: GATTClientConnectionDelegate?

let client: GATTClient

private let cache = Cache()

var maximumUpdateValueLength: Int {
get async {
// ATT_MTU-3
Expand All @@ -36,17 +34,13 @@ internal final class GATTClientConnection <Socket: L2CAPSocket> {
peripheral: Peripheral,
socket: Socket,
maximumTransmissionUnit: ATTMaximumTransmissionUnit,
delegate: GATTClientConnectionDelegate
log: ((String) -> ())? = nil
) async {
self.peripheral = peripheral
self.client = await GATTClient(
socket: socket,
maximumTransmissionUnit: maximumTransmissionUnit,
log: { [weak delegate] message in
delegate?.connection(peripheral, log: message)
}, didDisconnect: { [weak delegate] error in
await delegate?.connection(peripheral, didDisconnect: error)
}
log: log
)
}

Expand Down Expand Up @@ -221,21 +215,10 @@ internal final class GATTClientConnection <Socket: L2CAPSocket> {
descriptors: descriptors
)
}

private func log(_ message: String) {
delegate?.connection(peripheral, log: message)
}
}

// MARK: - Supporting Types

internal protocol GATTClientConnectionDelegate: AnyObject {

func connection(_ peripheral: Peripheral, log: String)

func connection(_ peripheral: Peripheral, didDisconnect error: Swift.Error?) async
}

internal extension GATTClientConnection {

typealias Cache = GATTClientConnectionCache
Expand Down
6 changes: 3 additions & 3 deletions Tests/GATTTests/GATTTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ final class GATTTests: XCTestCase {
XCTAssertEqual(services.count, 0)
let clientMTU = try await central.maximumTransmissionUnit(for: peripheral)
XCTAssertEqual(clientMTU, finalMTU)
let maximumUpdateValueLength = await central.storage.connections.values.first?.maximumUpdateValueLength
let maximumUpdateValueLength = await central.storage.connections.first?.value.connection.maximumUpdateValueLength
XCTAssertEqual(maximumUpdateValueLength, Int(finalMTU.rawValue) - 3)
let clientCache = await (central.storage.connections.values.first?.client.connection.socket as? TestL2CAPSocket)?.cache
let clientCache = await (central.storage.connections.values.first?.connection.client.connection.socket as? TestL2CAPSocket)?.cache
XCTAssertEqual(clientCache?.prefix(1), mockData.client.prefix(1)) // not same because extra service discovery request
}
)
Expand Down Expand Up @@ -213,7 +213,7 @@ final class GATTTests: XCTestCase {
else { XCTFail(); return }
XCTAssertEqual(foundService.uuid, .batteryService)
XCTAssertEqual(foundService.isPrimary, true)
let clientCache = await (central.storage.connections.values.first?.client.connection.socket as? TestL2CAPSocket)?.cache
let clientCache = await (central.storage.connections.values.first?.connection.client.connection.socket as? TestL2CAPSocket)?.cache
XCTAssertEqual(clientCache, mockData.client)
}
)
Expand Down
15 changes: 9 additions & 6 deletions Tests/GATTTests/TestL2CAPSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import GATT

/// Test L2CAP socket
internal actor TestL2CAPSocket: L2CAPSocket {

private actor Cache {

static let shared = Cache()
Expand Down Expand Up @@ -107,6 +107,10 @@ internal actor TestL2CAPSocket: L2CAPSocket {

// MARK: - Methods

func close() async {

}

func accept() async throws -> TestL2CAPSocket {
// sleep until a client socket is created
while (await Cache.shared.pendingClients[address] ?? []).isEmpty {
Expand All @@ -129,7 +133,7 @@ internal actor TestL2CAPSocket: L2CAPSocket {
else { throw POSIXError(.ECONNRESET) }

await target.receive(data)
eventContinuation.yield(.write(data.count))
eventContinuation.yield(.didWrite(data.count))
}

/// Reads from the socket.
Expand All @@ -145,18 +149,17 @@ internal actor TestL2CAPSocket: L2CAPSocket {

let data = self.receivedData.removeFirst()
cache.append(data)
eventContinuation.yield(.read(data.count))
try await Task.sleep(nanoseconds: 1_000_000)
eventContinuation.yield(.didRead(data.count))
return data
}

fileprivate func receive(_ data: Data) {
receivedData.append(data)
print("L2CAP Socket: \(name) recieved \([UInt8](data))")
eventContinuation.yield(.pendingRead)
eventContinuation.yield(.read)
}

fileprivate func connect(to socket: TestL2CAPSocket) {
internal func connect(to socket: TestL2CAPSocket) {
self.target = socket
}
}
Expand Down

0 comments on commit be312f1

Please sign in to comment.