diff --git a/Sources/WebSockets/WebSocket.swift b/Sources/WebSockets/WebSocket.swift index 72c522c..f5e01b4 100644 --- a/Sources/WebSockets/WebSocket.swift +++ b/Sources/WebSockets/WebSocket.swift @@ -571,6 +571,12 @@ private extension WebSocket { continue } } + // The `AsyncThrowingStream` used by Connection finishes the stream (yielding `nil`) if the Task is + // canceled. However, we want an error to be thrown in this instance so that it's not confused with + // the server simply dropping the connection. + if Task.isCancelled { + throw WebSocketError.canceled + } if openingHandshakeDidExpire { throw WebSocketError.timeout } diff --git a/Sources/WebSockets/WebSocketError.swift b/Sources/WebSockets/WebSocketError.swift index 2414139..e0b8f44 100644 --- a/Sources/WebSockets/WebSocketError.swift +++ b/Sources/WebSockets/WebSocketError.swift @@ -53,6 +53,9 @@ public enum WebSocketError: Error { /// The handshake did not complete within the specified timeframe. case timeout + /// The `Task` performing the handshake was canceled. + case canceled + /// The redirect limit was exceeded. This usually indicates a redirect loop. case maximumRedirectsExceeded @@ -99,6 +102,8 @@ extension WebSocketError: CustomDebugStringConvertible { return "Unexpected disconnect" case .timeout: return "WebSocket handshake timed out" + case .canceled: + return "The task performing the WebSocket handshake was canceled" case .maximumRedirectsExceeded: return "Maximum number of HTTP redirects exceeded" case .invalidRedirectLocation(let location): diff --git a/Tests/WebSocketsTests/WebSocketTests.swift b/Tests/WebSocketsTests/WebSocketTests.swift index 5bd3f42..9e53962 100644 --- a/Tests/WebSocketsTests/WebSocketTests.swift +++ b/Tests/WebSocketsTests/WebSocketTests.swift @@ -550,6 +550,27 @@ class WebSocketTests: XCTestCase { XCTAssert(events == expected) } + func testHandshakeCancellation() async throws { + let server = TestServer() + defer { + Task { + await server.stop() + } + } + let socket = WebSocket(url: try await server.start(path: "/test")) + let t = Task { + await socket.send(text: "Hello") + } + t.cancel() + let _ = await t.value + do { + for try await _ in socket { + } + XCTFail("Expected an exception ") + } catch WebSocketError.canceled { + } + } + func expectCloseCode(quirk: QuirkyTestServer.Quirk, code: WebSocket.CloseCode, reason: String? = nil) async throws { let server = QuirkyTestServer(with: quirk) defer {