Skip to content

Commit

Permalink
Fix stream cancelation cancels underlying stream (#395)
Browse files Browse the repository at this point in the history
* update the stream to correctly terminate the underlying stream when canceled

* fix up a bunch of test warnings

* update group streaming as well

* bump the pod

* make sure all streams cancel correctly
  • Loading branch information
nplasterer authored Sep 3, 2024
1 parent 3fe600e commit bcce4f3
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 103 deletions.
130 changes: 96 additions & 34 deletions Sources/XMTPiOS/Conversations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,52 @@ public actor Conversations {

public func streamGroups() async throws -> AsyncThrowingStream<Group, Error> {
AsyncThrowingStream { continuation in
Task {
let task = Task {
let groupCallback = GroupStreamCallback(client: self.client) { group in
guard !Task.isCancelled else {
continuation.finish()
return
}
continuation.yield(group)
}
guard let stream = try await self.client.v3Client?.conversations().stream(callback: groupCallback) else {
guard let stream = await self.client.v3Client?.conversations().stream(callback: groupCallback) else {
continuation.finish(throwing: GroupError.streamingFailure)
return
}

self.streamHolder.stream = stream
continuation.onTermination = { @Sendable reason in
stream.end()
}
}

continuation.onTermination = { @Sendable reason in
task.cancel()
self.streamHolder.stream?.end()
}
}
}

private func streamGroupConversations() -> AsyncThrowingStream<Conversation, Error> {
AsyncThrowingStream { continuation in
Task {
self.streamHolder.stream = try await self.client.v3Client?.conversations().stream(
let task = Task {
self.streamHolder.stream = await self.client.v3Client?.conversations().stream(
callback: GroupStreamCallback(client: self.client) { group in
guard !Task.isCancelled else {
continuation.finish()
return
}
continuation.yield(Conversation.group(group))
}
)
continuation.onTermination = { @Sendable reason in
self.streamHolder.stream?.end()
}
}

continuation.onTermination = { @Sendable reason in
task.cancel()
self.streamHolder.stream?.end()
}
}
}
Expand Down Expand Up @@ -383,63 +406,88 @@ public actor Conversations {

public func streamAllGroupMessages() -> AsyncThrowingStream<DecodedMessage, Error> {
AsyncThrowingStream { continuation in
Task {
let messageCallback = MessageCallback(client: self.client) { message in
if let decodedMessage = MessageV3(client: self.client, ffiMessage: message).decodeOrNull() {
continuation.yield(decodedMessage)
let task = Task {
self.streamHolder.stream = await self.client.v3Client?.conversations().streamAllMessages(
messageCallback: MessageCallback(client: self.client) { message in
guard !Task.isCancelled else {
continuation.finish()
self.streamHolder.stream?.end() // End the stream upon cancellation
return
}
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decode())
} catch {
print("Error onMessage \(error)")
}
}
}
guard let stream = try await client.v3Client?.conversations().streamAllMessages(messageCallback: messageCallback) else {
continuation.finish(throwing: GroupError.streamingFailure)
return
}
continuation.onTermination = { @Sendable reason in
stream.end()
}
)
}

continuation.onTermination = { _ in
task.cancel()
self.streamHolder.stream?.end()
}
}
}

public func streamAllMessages(includeGroups: Bool = false) async throws -> AsyncThrowingStream<DecodedMessage, Error> {
public func streamAllMessages(includeGroups: Bool = false) -> AsyncThrowingStream<DecodedMessage, Error> {
AsyncThrowingStream<DecodedMessage, Error> { continuation in
@Sendable func forwardStreamToMerged(stream: AsyncThrowingStream<DecodedMessage, Error>) async {
do {
var iterator = stream.makeAsyncIterator()
while let element = try await iterator.next() {
guard !Task.isCancelled else {
continuation.finish()
self.streamHolder.stream?.end()
return
}
continuation.yield(element)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
Task {

let task = Task {
await forwardStreamToMerged(stream: streamAllV2Messages())
}

if includeGroups {
Task {
await forwardStreamToMerged(stream: streamAllGroupMessages())
}
}

continuation.onTermination = { _ in
task.cancel()
self.streamHolder.stream?.end()
}
}
}

public func streamAllGroupDecryptedMessages() -> AsyncThrowingStream<DecryptedMessage, Error> {
AsyncThrowingStream { continuation in
Task {
do {
self.streamHolder.stream = try await self.client.v3Client?.conversations().streamAllMessages(
messageCallback: MessageCallback(client: self.client) { message in
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt())
} catch {
print("Error onMessage \(error)")
}
let task = Task {
self.streamHolder.stream = await self.client.v3Client?.conversations().streamAllMessages(
messageCallback: MessageCallback(client: self.client) { message in
guard !Task.isCancelled else {
continuation.finish()
self.streamHolder.stream?.end() // End the stream upon cancellation
return
}
)
} catch {
print("STREAM ERR: \(error)")
}
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt())
} catch {
print("Error onMessage \(error)")
}
}
)
}

continuation.onTermination = { _ in
task.cancel()
self.streamHolder.stream?.end()
}
}
}
Expand All @@ -450,23 +498,37 @@ public actor Conversations {
do {
var iterator = stream.makeAsyncIterator()
while let element = try await iterator.next() {
guard !Task.isCancelled else {
continuation.finish()
self.streamHolder.stream?.end()
return
}
continuation.yield(element)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
Task {
await forwardStreamToMerged(stream: try streamAllV2DecryptedMessages())

let task = Task {
await forwardStreamToMerged(stream: streamAllV2DecryptedMessages())
}
if (includeGroups) {

if includeGroups {
Task {
await forwardStreamToMerged(stream: streamAllGroupDecryptedMessages())
}
}

continuation.onTermination = { _ in
task.cancel()
self.streamHolder.stream?.end()
}
}
}




func streamAllV2DecryptedMessages() -> AsyncThrowingStream<DecryptedMessage, Error> {
Expand Down
68 changes: 44 additions & 24 deletions Sources/XMTPiOS/Group.swift
Original file line number Diff line number Diff line change
Expand Up @@ -293,41 +293,61 @@ public struct Group: Identifiable, Equatable, Hashable {

public func streamMessages() -> AsyncThrowingStream<DecodedMessage, Error> {
AsyncThrowingStream { continuation in
Task.detached {
do {
self.streamHolder.stream = try await ffiGroup.stream(
messageCallback: MessageCallback(client: self.client) { message in
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decode())
} catch {
print("Error onMessage \(error)")
}
let task = Task.detached {
self.streamHolder.stream = await self.ffiGroup.stream(
messageCallback: MessageCallback(client: self.client) { message in
guard !Task.isCancelled else {
continuation.finish()
return
}
)
} catch {
print("STREAM ERR: \(error)")
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decode())
} catch {
print("Error onMessage \(error)")
continuation.finish(throwing: error)
}
}
)

continuation.onTermination = { @Sendable reason in
self.streamHolder.stream?.end()
}
}

continuation.onTermination = { @Sendable reason in
task.cancel()
self.streamHolder.stream?.end()
}
}
}

public func streamDecryptedMessages() -> AsyncThrowingStream<DecryptedMessage, Error> {
AsyncThrowingStream { continuation in
Task.detached {
do {
self.streamHolder.stream = try await ffiGroup.stream(
messageCallback: MessageCallback(client: self.client) { message in
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt())
} catch {
print("Error onMessage \(error)")
}
let task = Task.detached {
self.streamHolder.stream = await self.ffiGroup.stream(
messageCallback: MessageCallback(client: self.client) { message in
guard !Task.isCancelled else {
continuation.finish()
return
}
)
} catch {
print("STREAM ERR: \(error)")
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt())
} catch {
print("Error onMessage \(error)")
continuation.finish(throwing: error)
}
}
)

continuation.onTermination = { @Sendable reason in
self.streamHolder.stream?.end()
}
}

continuation.onTermination = { @Sendable reason in
task.cancel()
self.streamHolder.stream?.end()
}
}
}

Expand Down
18 changes: 9 additions & 9 deletions Tests/XMTPTests/ConversationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ class ConversationTests: XCTestCase {
}

func testDoesNotAllowConversationWithSelf() async throws {
let expectation = expectation(description: "convo with self throws")
let expectation = XCTestExpectation(description: "convo with self throws")
let client = aliceClient!

do {
try await client.conversations.newConversation(with: alice.walletAddress)
_ = try await client.conversations.newConversation(with: alice.walletAddress)
} catch {
expectation.fulfill()
}

wait(for: [expectation], timeout: 0.1)
await fulfillment(of: [expectation], timeout: 3)
}

func testCanStreamConversationsV2() async throws {
Expand All @@ -103,7 +103,7 @@ class ConversationTests: XCTestCase {

let wallet2 = try PrivateKey.generate()
let client2 = try await Client.create(account: wallet2, options: options)
let expectation1 = expectation(description: "got a conversation")
let expectation1 = XCTestExpectation(description: "got a conversation")
expectation1.expectedFulfillmentCount = 2

Task(priority: .userInitiated) {
Expand Down Expand Up @@ -140,7 +140,7 @@ class ConversationTests: XCTestCase {

try await conversation2.send(content: "hi from new wallet")

await waitForExpectations(timeout: 30)
await fulfillment(of: [expectation1], timeout: 30)
}

func publishLegacyContact(client: Client) async throws {
Expand All @@ -161,7 +161,7 @@ class ConversationTests: XCTestCase {
return
}

let expectation = expectation(description: "got a message")
let expectation = XCTestExpectation(description: "got a message")

Task(priority: .userInitiated) {
for try await message in conversation.streamMessages() {
Expand All @@ -174,7 +174,7 @@ class ConversationTests: XCTestCase {
// Stream a message
try await conversation.send(content: "hi alice")

await waitForExpectations(timeout: 3)
await fulfillment(of: [expectation], timeout: 3)
}

func testCanLoadV2Messages() async throws {
Expand Down Expand Up @@ -458,7 +458,7 @@ class ConversationTests: XCTestCase {
XCTAssertTrue(isAllowed)

try await bobClient.contacts.deny(addresses: [alice.address])
try await bobClient.contacts.refreshConsentList()
_ = try await bobClient.contacts.refreshConsentList()

let isDenied = (try await bobConversation.consentState()) == .denied

Expand Down Expand Up @@ -491,7 +491,7 @@ class ConversationTests: XCTestCase {
XCTAssertTrue(isUnknown)

try await aliceConversation.send(content: "hey bob")
try await aliceClient.contacts.refreshConsentList()
_ = try await aliceClient.contacts.refreshConsentList()
let isNowAllowed = (try await aliceConversation.consentState()) == .allowed

// Conversations you send a message to get marked as allowed
Expand Down
Loading

0 comments on commit bcce4f3

Please sign in to comment.