Skip to content

Commit

Permalink
Let each client have its own CodecRegistry (#162)
Browse files Browse the repository at this point in the history
* Give each client their own codec registry

This should prevent crashes we've seen where many clients try to
register with the global at once

* Bump podspec
  • Loading branch information
nakajima authored Sep 20, 2023
1 parent 7c628a4 commit c9e809c
Show file tree
Hide file tree
Showing 29 changed files with 192 additions and 173 deletions.
2 changes: 1 addition & 1 deletion Sources/XMTP/ApiClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public enum ApiClientError: Error {
case subscribeError(String)
}

protocol ApiClient {
protocol ApiClient: Sendable {
var environment: XMTPEnvironment { get }
init(environment: XMTPEnvironment, secure: Bool, rustClient: XMTPRust.RustClient, appVersion: String?) throws
func setAuthToken(_ token: String)
Expand Down
16 changes: 6 additions & 10 deletions Sources/XMTP/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ public struct ClientOptions {
/// 2. To sign a random salt used to encrypt the key bundle in storage. This happens every time the client is started, including the very first time).
///
/// > Important: The client connects to the XMTP `dev` environment by default. Use ``ClientOptions`` to change this and other parameters of the network connection.
public class Client {
public final class Client: Sendable {
/// The wallet address of the ``SigningKey`` used to create this Client.
public var address: String
var privateKeyBundleV1: PrivateKeyBundleV1
var apiClient: ApiClient
public let address: String
let privateKeyBundleV1: PrivateKeyBundleV1
let apiClient: ApiClient

/// Access ``Conversations`` for this Client.
public lazy var conversations: Conversations = .init(client: self)
Expand All @@ -67,13 +67,9 @@ public class Client {
apiClient.environment
}

static var codecRegistry = {
var registry = CodecRegistry()
registry.register(codec: TextCodec())
return registry
}()
var codecRegistry = CodecRegistry()

public static func register(codec: any ContentCodec) {
public func register(codec: any ContentCodec) {
codecRegistry.register(codec: codec)
}

Expand Down
2 changes: 1 addition & 1 deletion Sources/XMTP/CodecRegistry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import Foundation

struct CodecRegistry {
var codecs: [String: any ContentCodec] = [:]
var codecs: [String: any ContentCodec] = [TextCodec().id: TextCodec()]

mutating func register(codec: any ContentCodec) {
codecs[codec.id] = codec
Expand Down
4 changes: 2 additions & 2 deletions Sources/XMTP/Codecs/AttachmentCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public struct AttachmentCodec: ContentCodec {

public var contentType = ContentTypeAttachment

public func encode(content: Attachment) throws -> EncodedContent {
public func encode(content: Attachment, client _: Client) throws -> EncodedContent {
var encodedContent = EncodedContent()

encodedContent.type = ContentTypeAttachment
Expand All @@ -44,7 +44,7 @@ public struct AttachmentCodec: ContentCodec {
return encodedContent
}

public func decode(content: EncodedContent) throws -> Attachment {
public func decode(content: EncodedContent, client _: Client) throws -> Attachment {
guard let mimeType = content.parameters["mimeType"],
let filename = content.parameters["filename"]
else {
Expand Down
4 changes: 2 additions & 2 deletions Sources/XMTP/Codecs/Composite.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ struct CompositeCodec: ContentCodec {
ContentTypeComposite
}

public func encode(content: DecodedComposite) throws -> EncodedContent {
public func encode(content: DecodedComposite, client _: Client) throws -> EncodedContent {
let composite = toComposite(content: content)
var encoded = EncodedContent()
encoded.type = ContentTypeComposite
encoded.content = try composite.serializedData()
return encoded
}

public func decode(content encoded: EncodedContent) throws -> DecodedComposite {
public func decode(content encoded: EncodedContent, client _: Client) throws -> DecodedComposite {
let composite = try Composite(serializedData: encoded.content)
let decodedComposite = fromComposite(composite: composite)
return decodedComposite
Expand Down
12 changes: 6 additions & 6 deletions Sources/XMTP/Codecs/ContentCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ enum CodecError: String, Error {
public typealias EncodedContent = Xmtp_MessageContents_EncodedContent

extension EncodedContent {
public func decoded<T>() throws -> T {
let codec = Client.codecRegistry.find(for: type)
public func decoded<T>(with client: Client) throws -> T {
let codec = client.codecRegistry.find(for: type)

var encodedContent = self

if hasCompression {
encodedContent = try decompressContent()
}

if let content = try codec.decode(content: encodedContent) as? T {
if let content = try codec.decode(content: encodedContent, client: client) as? T {
return content
}

Expand Down Expand Up @@ -69,9 +69,9 @@ public protocol ContentCodec: Hashable, Equatable {
associatedtype T

var contentType: ContentTypeID { get }
func encode(content: T) throws -> EncodedContent
func decode(content: EncodedContent) throws -> T
func fallback(content: T) throws -> String?
func encode(content: T, client: Client) throws -> EncodedContent
func decode(content: EncodedContent, client: Client) throws -> T
func fallback(content: T) throws -> String?
}

public extension ContentCodec {
Expand Down
4 changes: 2 additions & 2 deletions Sources/XMTP/Codecs/DecodedComposite.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public struct DecodedComposite {
self.encodedContent = encodedContent
}

func content<T>() throws -> T? {
return try encodedContent?.decoded()
func content<T>(with client: Client) throws -> T? {
return try encodedContent?.decoded(with: client)
}
}
4 changes: 2 additions & 2 deletions Sources/XMTP/Codecs/ReactionCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public struct ReactionCodec: ContentCodec {

public init() {}

public func encode(content: Reaction) throws -> EncodedContent {
public func encode(content: Reaction, client _: Client) throws -> EncodedContent {
var encodedContent = EncodedContent()

encodedContent.type = ContentTypeReaction
Expand All @@ -71,7 +71,7 @@ public struct ReactionCodec: ContentCodec {
return encodedContent
}

public func decode(content: EncodedContent) throws -> Reaction {
public func decode(content: EncodedContent, client _: Client) throws -> Reaction {
// First try to decode it in the canonical form.
// swiftlint:disable no_optional_try
if let reaction = try? JSONDecoder().decode(Reaction.self, from: content.content) {
Expand Down
4 changes: 2 additions & 2 deletions Sources/XMTP/Codecs/ReadReceiptCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public struct ReadReceiptCodec: ContentCodec {

public var contentType = ContentTypeReadReceipt

public func encode(content: ReadReceipt) throws -> EncodedContent {
public func encode(content: ReadReceipt, client _: Client) throws -> EncodedContent {
var encodedContent = EncodedContent()

encodedContent.type = ContentTypeReadReceipt
Expand All @@ -29,7 +29,7 @@ public struct ReadReceiptCodec: ContentCodec {
return encodedContent
}

public func decode(content: EncodedContent) throws -> ReadReceipt {
public func decode(content: EncodedContent, client _: Client) throws -> ReadReceipt {
return ReadReceipt()
}

Expand Down
8 changes: 4 additions & 4 deletions Sources/XMTP/Codecs/RemoteAttachmentCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ public struct RemoteAttachment {
}
}

public static func encodeEncrypted<Codec: ContentCodec, T>(content: T, codec: Codec) throws -> EncryptedEncodedContent where Codec.T == T {
public static func encodeEncrypted<Codec: ContentCodec, T>(content: T, codec: Codec, with client: Client) throws -> EncryptedEncodedContent where Codec.T == T {
let secret = try Crypto.secureRandomBytes(count: 32)
let encodedContent = try codec.encode(content: content).serializedData()
let encodedContent = try codec.encode(content: content, client: client).serializedData()
let ciphertext = try Crypto.encrypt(secret, encodedContent)
let contentDigest = sha256(data: ciphertext.aes256GcmHkdfSha256.payload)

Expand Down Expand Up @@ -139,7 +139,7 @@ public struct RemoteAttachmentCodec: ContentCodec {

public var contentType = ContentTypeRemoteAttachment

public func encode(content: RemoteAttachment) throws -> EncodedContent {
public func encode(content: RemoteAttachment, client _: Client) throws -> EncodedContent {
var encodedContent = EncodedContent()

encodedContent.type = ContentTypeRemoteAttachment
Expand All @@ -157,7 +157,7 @@ public struct RemoteAttachmentCodec: ContentCodec {
return encodedContent
}

public func decode(content: EncodedContent) throws -> RemoteAttachment {
public func decode(content: EncodedContent, client _: Client) throws -> RemoteAttachment {
guard let url = String(data: content.content, encoding: .utf8) else {
throw RemoteAttachmentError.invalidURL
}
Expand Down
26 changes: 13 additions & 13 deletions Sources/XMTP/Codecs/ReplyCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,39 @@ public struct Reply {
public var content: Any
public var contentType: ContentTypeID

public init(reference: String, content: Any, contentType: ContentTypeID) {
self.reference = reference
self.content = content
self.contentType = contentType
}
public init(reference: String, content: Any, contentType: ContentTypeID) {
self.reference = reference
self.content = content
self.contentType = contentType
}
}

public struct ReplyCodec: ContentCodec {
public var contentType = ContentTypeReply

public init() {}

public func encode(content reply: Reply) throws -> EncodedContent {
public func encode(content reply: Reply, client: Client) throws -> EncodedContent {
var encodedContent = EncodedContent()
let replyCodec = Client.codecRegistry.find(for: reply.contentType)
let replyCodec = client.codecRegistry.find(for: reply.contentType)

encodedContent.type = contentType
// TODO: cut when we're certain no one is looking for "contentType" here.
encodedContent.parameters["contentType"] = reply.contentType.description
encodedContent.parameters["reference"] = reply.reference
encodedContent.content = try encodeReply(codec: replyCodec, content: reply.content).serializedData()
encodedContent.content = try encodeReply(codec: replyCodec, content: reply.content, client: client).serializedData()

return encodedContent
}

public func decode(content: EncodedContent) throws -> Reply {
public func decode(content: EncodedContent, client: Client) throws -> Reply {
guard let reference = content.parameters["reference"] else {
throw CodecError.invalidContent
}

let replyEncodedContent = try EncodedContent(serializedData: content.content)
let replyCodec = Client.codecRegistry.find(for: replyEncodedContent.type)
let replyContent = try replyCodec.decode(content: replyEncodedContent)
let replyCodec = client.codecRegistry.find(for: replyEncodedContent.type)
let replyContent = try replyCodec.decode(content: replyEncodedContent, client: client)

return Reply(
reference: reference,
Expand All @@ -55,9 +55,9 @@ public struct ReplyCodec: ContentCodec {
)
}

func encodeReply<Codec: ContentCodec>(codec: Codec, content: Any) throws -> EncodedContent {
func encodeReply<Codec: ContentCodec>(codec: Codec, content: Any, client: Client) throws -> EncodedContent {
if let content = content as? Codec.T {
return try codec.encode(content: content)
return try codec.encode(content: content, client: client)
} else {
throw CodecError.invalidContent
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/XMTP/Codecs/TextCodec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ public struct TextCodec: ContentCodec {

public typealias T = String

public init() {}
public init() { }

public var contentType = ContentTypeText

public func encode(content: String) throws -> EncodedContent {
public func encode(content: String, client _: Client) throws -> EncodedContent {
var encodedContent = EncodedContent()

encodedContent.type = ContentTypeText
Expand All @@ -31,7 +31,7 @@ public struct TextCodec: ContentCodec {
return encodedContent
}

public func decode(content: EncodedContent) throws -> String {
public func decode(content: EncodedContent, client _: Client) throws -> String {
if let encoding = content.parameters["encoding"], encoding != "UTF-8" {
throw TextCodecError.invalidEncoding
}
Expand Down
7 changes: 4 additions & 3 deletions Sources/XMTP/ConversationV1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ public struct ConversationV1 {
}

func prepareMessage<T>(content: T, options: SendOptions?) async throws -> PreparedMessage {
let codec = Client.codecRegistry.find(for: options?.contentType)
let codec = client.codecRegistry.find(for: options?.contentType)

func encode<Codec: ContentCodec>(codec: Codec, content: Any) throws -> EncodedContent {
if let content = content as? Codec.T {
return try codec.encode(content: content)
return try codec.encode(content: content, client: client)
} else {
throw CodecError.invalidContent
}
Expand Down Expand Up @@ -203,7 +203,8 @@ public struct ConversationV1 {
let header = try message.v1.header

var decoded = DecodedMessage(
topic: envelope.contentTopic,
client: client,
topic: envelope.contentTopic,
encodedContent: encodedMessage,
senderAddress: header.sender.walletAddress,
sent: message.v1.sentAt
Expand Down
8 changes: 4 additions & 4 deletions Sources/XMTP/ConversationV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ public struct ConversationV2 {
}

func prepareMessage<T>(content: T, options: SendOptions?) async throws -> PreparedMessage {
let codec = Client.codecRegistry.find(for: options?.contentType)
let codec = client.codecRegistry.find(for: options?.contentType)

func encode<Codec: ContentCodec>(codec: Codec, content: Any) throws -> EncodedContent {
if let content = content as? Codec.T {
return try codec.encode(content: content)
return try codec.encode(content: content, client: client)
} else {
throw CodecError.invalidContent
}
Expand Down Expand Up @@ -176,7 +176,7 @@ public struct ConversationV2 {
}

private func decode(_ message: MessageV2) throws -> DecodedMessage {
try MessageV2.decode(message, keyMaterial: keyMaterial)
try MessageV2.decode(message, keyMaterial: keyMaterial, client: client)
}

@discardableResult func send<T>(content: T, options: SendOptions? = nil) async throws -> String {
Expand All @@ -200,7 +200,7 @@ public struct ConversationV2 {
}

public func encode<Codec: ContentCodec, T>(codec: Codec, content: T) async throws -> Data where Codec.T == T {
let content = try codec.encode(content: content)
let content = try codec.encode(content: content, client: client)

let message = try await MessageV2.encode(
client: client,
Expand Down
30 changes: 20 additions & 10 deletions Sources/XMTP/DecodedMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Foundation

/// Decrypted messages from a conversation.
public struct DecodedMessage: Sendable {
public var topic: String
public var topic: String

public var id: String = ""

Expand All @@ -21,15 +21,24 @@ public struct DecodedMessage: Sendable {
/// When the message was sent
public var sent: Date

public init(topic: String, encodedContent: EncodedContent, senderAddress: String, sent: Date) {
self.topic = topic
self.encodedContent = encodedContent
self.senderAddress = senderAddress
self.sent = sent
var client: Client

public init(
client: Client,
topic: String,
encodedContent: EncodedContent,
senderAddress: String,
sent: Date
) {
self.client = client
self.topic = topic
self.encodedContent = encodedContent
self.senderAddress = senderAddress
self.sent = sent
}

public func content<T>() throws -> T {
return try encodedContent.decoded()
return try encodedContent.decoded(with: client)
}

public var fallbackContent: String {
Expand All @@ -46,10 +55,11 @@ public struct DecodedMessage: Sendable {
}

public extension DecodedMessage {
static func preview(topic: String, body: String, senderAddress: String, sent: Date) -> DecodedMessage {
static func preview(client: Client, topic: String, body: String, senderAddress: String, sent: Date) -> DecodedMessage {
// swiftlint:disable force_try
let encoded = try! TextCodec().encode(content: body)
let encoded = try! TextCodec().encode(content: body, client: client)
// swiftlint:enable force_try
return DecodedMessage(topic: topic, encodedContent: encoded, senderAddress: senderAddress, sent: sent)

return DecodedMessage(client: client, topic: topic, encodedContent: encoded, senderAddress: senderAddress, sent: sent)
}
}
Loading

0 comments on commit c9e809c

Please sign in to comment.