Skip to content

Commit

Permalink
fix(api): propagate connectionLost error from websocket client to syn…
Browse files Browse the repository at this point in the history
…c engine (#3800)

* restart remote sync engine when websocket has connection error

* test(api): fix broken test cases

* add unit test case

* fix(datastore): propagate connectionLost URLError
  • Loading branch information
5d authored Aug 9, 2024
1 parent 694aae0 commit 00aac42
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extension AppSyncRealTimeClient {
// listen to response
self.subject
.setFailureType(to: AppSyncRealTimeRequest.Error.self)
.flatMap { Self.filterResponse(request: request, response: $0) }
.flatMap { Self.filterResponse(request: request, result: $0) }
.timeout(.seconds(timeout), scheduler: DispatchQueue.global(qos: .userInitiated), customError: { .timeout })
.first()
.sink(receiveCompletion: { completion in
Expand Down Expand Up @@ -65,47 +65,59 @@ extension AppSyncRealTimeClient {

private static func filterResponse(
request: AppSyncRealTimeRequest,
response: AppSyncRealTimeResponse
result: Result<AppSyncRealTimeResponse, Error>
) -> AnyPublisher<AppSyncRealTimeResponse, AppSyncRealTimeRequest.Error> {
let justTheResponse = Just(response)
.setFailureType(to: AppSyncRealTimeRequest.Error.self)
.eraseToAnyPublisher()

switch (request, response.type) {
case (.connectionInit, .connectionAck):
return justTheResponse

case (.start(let startRequest), .startAck) where startRequest.id == response.id:
return justTheResponse

case (.stop(let id), .stopAck) where id == response.id:
return justTheResponse

case (_, .error)
where request.id != nil
&& request.id == response.id
&& response.payload?.errors != nil:
let errorsJson: JSONValue = (response.payload?.errors)!
let errors = errorsJson.asArray ?? [errorsJson]
let reqeustErrors = errors.compactMap(AppSyncRealTimeRequest.parseResponseError(error:))
if reqeustErrors.isEmpty {

switch result {
case .success(let response):
let justTheResponse = Just(response)
.setFailureType(to: AppSyncRealTimeRequest.Error.self)
.eraseToAnyPublisher()

switch (request, response.type) {
case (.connectionInit, .connectionAck):
return justTheResponse

case (.start(let startRequest), .startAck) where startRequest.id == response.id:
return justTheResponse

case (.stop(let id), .stopAck) where id == response.id:
return justTheResponse

case (_, .error)
where request.id != nil
&& request.id == response.id
&& response.payload?.errors != nil:
let errorsJson: JSONValue = (response.payload?.errors)!
let errors = errorsJson.asArray ?? [errorsJson]
let reqeustErrors = errors.compactMap(AppSyncRealTimeRequest.parseResponseError(error:))
if reqeustErrors.isEmpty {
return Empty(
outputType: AppSyncRealTimeResponse.self,
failureType: AppSyncRealTimeRequest.Error.self
).eraseToAnyPublisher()
} else {
return Fail(
outputType: AppSyncRealTimeResponse.self,
failure: reqeustErrors.first!
).eraseToAnyPublisher()
}

default:
return Empty(
outputType: AppSyncRealTimeResponse.self,
failureType: AppSyncRealTimeRequest.Error.self
).eraseToAnyPublisher()
} else {
return Fail(
outputType: AppSyncRealTimeResponse.self,
failure: reqeustErrors.first!
).eraseToAnyPublisher()

}

default:
return Empty(
case .failure:
return Fail(
outputType: AppSyncRealTimeResponse.self,
failureType: AppSyncRealTimeRequest.Error.self
failure: .timeout
).eraseToAnyPublisher()

}


}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
/// WebSocketClient offering connections at the WebSocket protocol level
internal var webSocketClient: AppSyncWebSocketClientProtocol
/// Writable data stream convert WebSocketEvent to AppSyncRealTimeResponse
internal let subject = PassthroughSubject<AppSyncRealTimeResponse, Never>()
internal let subject = PassthroughSubject<Result<AppSyncRealTimeResponse, Error>, Never>()

var isConnected: Bool {
self.state.value == .connected
Expand Down Expand Up @@ -283,15 +283,25 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
private func filterAppSyncSubscriptionEvent(
with id: String
) -> AnyPublisher<AppSyncSubscriptionEvent, Never> {
subject.filter { $0.id == id || $0.type == .connectionError }
.map { response -> AppSyncSubscriptionEvent? in
switch response.type {
case .connectionError, .error:
return .error(Self.decodeAppSyncRealTimeResponseError(response.payload))
case .data:
return response.payload.map { .data($0) }
default:
return nil
subject.filter {
switch $0 {
case .success(let response): return response.id == id || response.type == .connectionError
case .failure(let error): return true
}
}
.map { result -> AppSyncSubscriptionEvent? in
switch result {
case .success(let response):
switch response.type {
case .connectionError, .error:
return .error(Self.decodeAppSyncRealTimeResponseError(response.payload))
case .data:
return response.payload.map { .data($0) }
default:
return nil
}
case .failure(let error):
return .error([error])
}
}
.compactMap { $0 }
Expand Down Expand Up @@ -368,9 +378,9 @@ extension AppSyncRealTimeClient {
self.cancellablesBindToConnection = Set()

case .error(let error):
// Since we've activated auto-reconnect functionality in WebSocketClient upon connection failure,
// we only record errors here for debugging purposes.
// Propagate connection error to downstream for Sync engine to restart
log.debug("[AppSyncRealTimeClient] WebSocket error event: \(error)")
self.subject.send(.failure(error))
case .string(let string):
guard let data = string.data(using: .utf8) else {
log.debug("[AppSyncRealTimeClient] Failed to decode string \(string)")
Expand Down Expand Up @@ -400,7 +410,7 @@ extension AppSyncRealTimeClient {
switch event.type {
case .connectionAck:
log.debug("[AppSyncRealTimeClient] AppSync connected: \(String(describing: event.payload))")
subject.send(event)
subject.send(.success(event))

self.resumeExistingSubscriptions()
self.state.send(.connected)
Expand All @@ -411,7 +421,7 @@ extension AppSyncRealTimeClient {

default:
log.debug("[AppSyncRealTimeClient] AppSync received response: \(event)")
subject.send(event)
subject.send(.success(event))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import Amplify
import Foundation
import AWSPluginsCore
@_spi(WebSocket) import AWSPluginsCore
import InternalAmplifyCredentials
import Combine


public class AWSGraphQLSubscriptionTaskRunner<R: Decodable>: InternalTaskRunner, InternalTaskAsyncThrowingSequence, InternalTaskThrowingChannel {
public typealias Request = GraphQLOperationRequest<R>
public typealias InProcess = GraphQLSubscriptionEvent<R>
Expand Down Expand Up @@ -387,32 +388,7 @@ fileprivate func toAPIError<R: Decodable>(_ errors: [Error], type: R.Type) -> AP
"Subscription item event failed with error" +
(hasAuthorizationError ? ": \(APIError.UnauthorizedMessageString)" : "")
}

#if swift(<5.8)
if let errors = errors.cast(to: AppSyncRealTimeRequest.Error.self) {
let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized})
return APIError.operationError(
errorDescription(hasAuthorizationError),
"",
errors.first
)
} else if let errors = errors.cast(to: GraphQLError.self) {
let hasAuthorizationError = errors.map(\.extensions)
.compactMap { $0.flatMap { $0["errorType"]?.stringValue } }
.contains(where: { AppSyncErrorType($0) == .unauthorized })
return APIError.operationError(
errorDescription(hasAuthorizationError),
"",
GraphQLResponseError<R>.error(errors)
)
} else {
return APIError.operationError(
errorDescription(),
"",
errors.first
)
}
#else

switch errors {
case let errors as [AppSyncRealTimeRequest.Error]:
let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized})
Expand All @@ -430,12 +406,14 @@ fileprivate func toAPIError<R: Decodable>(_ errors: [Error], type: R.Type) -> AP
"",
GraphQLResponseError<R>.error(errors)
)

case let errors as [WebSocketClient.Error]:
return APIError.networkError("WebSocketClient connection aborted", nil, URLError(.networkConnectionLost))
default:
return APIError.operationError(
errorDescription(),
"",
errors.first
)
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class AppSyncRealTimeClientTests: XCTestCase {
}
Task {
try await Task.sleep(nanoseconds: 80 * 1000)
await appSyncClient.subject.send(.init(id: nil, payload: nil, type: .connectionAck))
await appSyncClient.subject.send(.success(.init(id: nil, payload: nil, type: .connectionAck)))
}
await fulfillment(of: [finishExpectation], timeout: timeout + 1)
}
Expand Down Expand Up @@ -91,7 +91,7 @@ class AppSyncRealTimeClientTests: XCTestCase {
}
Task {
try await Task.sleep(nanoseconds: 80 * 1000)
await appSyncClient.subject.send(.init(
await appSyncClient.subject.send(.success(.init(
id: id,
payload: .object([
"errors": .array([
Expand All @@ -101,7 +101,7 @@ class AppSyncRealTimeClientTests: XCTestCase {
])
]),
type: .error
))
)))
}
await fulfillment(of: [limitExceededErrorExpectation], timeout: timeout + 1)
}
Expand Down Expand Up @@ -134,7 +134,7 @@ class AppSyncRealTimeClientTests: XCTestCase {

Task {
try await Task.sleep(nanoseconds: 80 * 1000)
await appSyncClient.subject.send(.init(
await appSyncClient.subject.send(.success(.init(
id: id,
payload: .object([
"errors": .array([
Expand All @@ -144,7 +144,7 @@ class AppSyncRealTimeClientTests: XCTestCase {
])
]),
type: .error
))
)))
}
await fulfillment(of: [
maxSubscriptionsReachedExpectation
Expand Down Expand Up @@ -181,7 +181,7 @@ class AppSyncRealTimeClientTests: XCTestCase {

Task {
try await Task.sleep(nanoseconds: 80 * 1000)
await appSyncClient.subject.send(.init(
await appSyncClient.subject.send(.success(.init(
id: id,
payload: .object([
"errors": .array([
Expand All @@ -191,7 +191,7 @@ class AppSyncRealTimeClientTests: XCTestCase {
])
]),
type: .error
))
)))
}
await fulfillment(of: [
triggerUnknownErrorExpectation
Expand Down Expand Up @@ -487,4 +487,68 @@ class AppSyncRealTimeClientTests: XCTestCase {
startTriggered
], timeout: 3, enforceOrder: true)
}

func testNetworkInterrupt_withAppSyncRealTimeClientConnected_triggersApiNetworkError() async throws {
var cancellables = Set<AnyCancellable>()
let mockWebSocketClient = MockWebSocketClient()
let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor()
let appSyncClient = AppSyncRealTimeClient(
endpoint: URL(string: "https://example.com")!,
requestInterceptor: mockAppSyncRequestInterceptor,
webSocketClient: mockWebSocketClient
)
let id = UUID().uuidString
let query = UUID().uuidString

let startTriggered = expectation(description: "webSocket writing start event to connection")
let errorReceived = expectation(description: "webSocket connection lost error is received")

await mockWebSocketClient.setStateToConnected()
Task {
try await Task.sleep(nanoseconds: 80 * 1_000_000)
await mockWebSocketClient.subject.send(.connected)
try await Task.sleep(nanoseconds: 80 * 1_000_000)
await mockWebSocketClient.subject.send(.string("""
{"type": "connection_ack", "payload": { "connectionTimeoutMs": 300000 }}
"""))
try await Task.sleep(nanoseconds: 80 * 1_000_000)
await mockWebSocketClient.subject.send(.error(WebSocketClient.Error.connectionLost))
}
try await appSyncClient.subscribe(id: id, query: query).sink { event in
if case .error(let errors) = event,
errors.count == 1,
let error = errors.first,
let connectionLostError = error as? WebSocketClient.Error,
connectionLostError == WebSocketClient.Error.connectionLost
{
errorReceived.fulfill()
}
}.store(in: &cancellables)
await mockWebSocketClient.actionSubject
.sink { action in
switch action {
case .write(let message):
guard let response = try? JSONDecoder().decode(
JSONValue.self,
from: message.data(using: .utf8)!
) else {
XCTFail("Response should be able to decode to AppSyncRealTimeResponse")
return
}

if response.type?.stringValue == "start" {
XCTAssertEqual(response.id?.stringValue, id)
XCTAssertEqual(response.payload?.asObject?["data"]?.stringValue, query)
startTriggered.fulfill()
}

default:
break
}
}
.store(in: &cancellables)

await fulfillment(of: [startTriggered, errorReceived], timeout: 2)

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0
//


import Network
import Combine

Expand Down Expand Up @@ -38,6 +37,7 @@ public final class AmplifyNetworkMonitor {
label: "com.amazonaws.amplify.ios.network.websocket.monitor",
qos: .userInitiated
))

}

public func updateState(_ nextState: State) {
Expand All @@ -48,4 +48,5 @@ public final class AmplifyNetworkMonitor {
subject.send(completion: .finished)
monitor.cancel()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ extension RemoteSyncEngine {
}

func scheduleRestartOrTerminate(error: AmplifyError) {
Self.log.debug("scheduling restart or terminate on error: \(error)")
let advice = getRetryAdvice(error: error)
if advice.shouldRetry {
scheduleRestart(advice: advice)
Expand Down

0 comments on commit 00aac42

Please sign in to comment.