Skip to content

Commit

Permalink
Merge pull request #455 from xmtp/nmolnar/strongly-typed-client
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas authored Sep 13, 2023
2 parents 8f0c04a + 4d30160 commit ad59a8b
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 176 deletions.
2 changes: 1 addition & 1 deletion bench/encode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const encodeV1 = () => {
const alice = await Client.create(newWallet(), { env: 'local' })
const bobKeys = (await newPrivateKeyBundle()).getPublicKeyBundle()

const message = randomBytes(size)
const message = randomBytes(size).toString()
const timestamp = new Date()

// The returned function is the actual benchmark. Everything above is setup
Expand Down
78 changes: 59 additions & 19 deletions src/Client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import { utils } from 'ethers'
import { Signer } from './types/Signer'
import { Conversations } from './conversations'
import { ContentTypeText, TextCodec } from './codecs/Text'
import { ContentTypeId, ContentCodec } from './MessageContent'
import { compress } from './Compression'
import { ContentTypeId, ContentCodec, EncodedContent } from './MessageContent'
import { compress, decompress } from './Compression'
import { content as proto, messageApi } from '@xmtp/proto'
import { decodeContactBundle, encodeContactBundle } from './ContactBundle'
import HttpApiClient, {
Expand Down Expand Up @@ -41,6 +41,7 @@ import {
} from './keystore/persistence'
import { hasMetamaskWithSnaps } from './keystore/snapHelpers'
import { version as snapVersion, package as snapPackage } from './snapInfo.json'
import { ExtractDecodedType } from './types/client'
const { Compression } = proto

// eslint-disable @typescript-eslint/explicit-module-boundary-types
Expand Down Expand Up @@ -244,7 +245,8 @@ export function defaultOptions(opts?: Partial<ClientOptions>): ClientOptions {
* Client class initiates connection to the XMTP network.
* Should be created with `await Client.create(options)`
*/
export default class Client {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export default class Client<ContentTypes = any> {
address: string
keystore: Keystore
apiClient: ApiClient
Expand All @@ -256,7 +258,7 @@ export default class Client {
> // addresses and key bundles that we have witnessed

private _backupClient: BackupClient
private readonly _conversations: Conversations
private readonly _conversations: Conversations<ContentTypes>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
private _codecs: Map<string, ContentCodec<any>>
private _maxContentSize: number
Expand Down Expand Up @@ -286,7 +288,7 @@ export default class Client {
/**
* @type {Conversations}
*/
get conversations(): Conversations {
get conversations(): Conversations<ContentTypes> {
return this._conversations
}

Expand All @@ -304,10 +306,16 @@ export default class Client {
* @param wallet the wallet as a Signer instance
* @param opts specify how to to connect to the network
*/
static async create(

// eslint-disable-next-line @typescript-eslint/no-explicit-any
static async create<ContentCodecs extends ContentCodec<any>[] = []>(
wallet: Signer | null,
opts?: Partial<ClientOptions>
): Promise<Client> {
opts?: Partial<ClientOptions> & { codecs?: ContentCodecs }
): Promise<
Client<
ExtractDecodedType<[...ContentCodecs, TextCodec][number]> | undefined
>
> {
const options = defaultOptions(opts)
const apiClient = options.apiClientFactory(options)
const keystore = await bootstrapKeystore(options, apiClient, wallet)
Expand All @@ -317,12 +325,9 @@ export default class Client {
const address = publicKeyBundle.walletSignatureAddress()
apiClient.setAuthenticator(new KeystoreAuthenticator(keystore))
const backupClient = await Client.setupBackupClient(address, options.env)
const client = new Client(
publicKeyBundle,
apiClient,
backupClient,
keystore
)
const client = new Client<
ExtractDecodedType<[...ContentCodecs, TextCodec][number]> | undefined
>(publicKeyBundle, apiClient, backupClient, keystore)
await client.init(options)
return client
}
Expand All @@ -337,9 +342,9 @@ export default class Client {
* impersonate a user on the XMTP network and read the user's
* messages.
*/
static async getKeys(
static async getKeys<U>(
wallet: Signer | null,
opts?: Partial<ClientOptions>
opts?: Partial<ClientOptions> & { codecs?: U }
): Promise<Uint8Array> {
const client = await Client.create(wallet, opts)
const keys = await client.keystore.getPrivateKeyBundle()
Expand Down Expand Up @@ -596,10 +601,13 @@ export default class Client {
* messages of the given Content Type
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
registerCodec(codec: ContentCodec<any>): void {
registerCodec<Codec extends ContentCodec<any>>(
codec: Codec
): Client<ContentTypes | ExtractDecodedType<Codec>> {
const id = codec.contentType
const key = `${id.authorityId}/${id.typeId}`
this._codecs.set(key, codec)
return this
}

/**
Expand All @@ -624,8 +632,7 @@ export default class Client {
* with the given options
*/
async encodeContent(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
content: any,
content: ContentTypes,
options?: SendOptions
): Promise<Uint8Array> {
const contentType = options?.contentType || ContentTypeText
Expand All @@ -646,6 +653,39 @@ export default class Client {
return proto.EncodedContent.encode(encoded).finish()
}

async decodeContent(contentBytes: Uint8Array): Promise<{
content: ContentTypes
contentType: ContentTypeId
error?: Error
contentFallback?: string
}> {
const encodedContent = proto.EncodedContent.decode(contentBytes)

if (!encodedContent.type) {
throw new Error('missing content type')
}

let content: any // eslint-disable-line @typescript-eslint/no-explicit-any
const contentType = new ContentTypeId(encodedContent.type)
let error: Error | undefined

await decompress(encodedContent, 1000)

const codec = this.codecFor(contentType)
if (codec) {
content = codec.decode(encodedContent as EncodedContent, this)
} else {
error = new Error('unknown content type ' + contentType)
}

return {
content,
contentType,
error,
contentFallback: encodedContent.fallback,
}
}

listInvitations(opts?: ListMessagesOptions): Promise<messageApi.Envelope[]> {
return this.listEnvelopes(
buildUserInviteTopic(this.address),
Expand Down
87 changes: 29 additions & 58 deletions src/Message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,14 @@ import {
ConversationV2,
} from './conversations/Conversation'
import type Client from './Client'
import {
message as proto,
content as protoContent,
conversationReference,
} from '@xmtp/proto'
import { message as proto, conversationReference } from '@xmtp/proto'
import Long from 'long'
import Ciphertext from './crypto/Ciphertext'
import { PublicKeyBundle, PublicKey } from './crypto'
import { bytesToHex } from './crypto/utils'
import { sha256 } from './crypto/encryption'
import {
ContentTypeFallback,
ContentTypeId,
EncodedContent,
} from './MessageContent'
import { ContentTypeId } from './MessageContent'
import { dateToNs, nsToDate } from './utils'
import { decompress } from './Compression'
import { Keystore } from './keystore'
import { buildDecryptV1Request, getResultOrThrow } from './utils/keystore'

Expand Down Expand Up @@ -228,16 +219,17 @@ export class MessageV2 extends MessageBase implements proto.MessageV2 {

export type Message = MessageV1 | MessageV2

export class DecodedMessage {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export class DecodedMessage<ContentTypes = any> {
id: string
messageVersion: 'v1' | 'v2'
senderAddress: string
recipientAddress?: string
sent: Date
contentTopic: string
conversation: Conversation
conversation: Conversation<ContentTypes>
contentType: ContentTypeId
content: any // eslint-disable-line @typescript-eslint/no-explicit-any
content: ContentTypes
error?: Error
contentBytes: Uint8Array
contentFallback?: string
Expand All @@ -255,7 +247,7 @@ export class DecodedMessage {
sent,
error,
contentFallback,
}: Omit<DecodedMessage, 'toBytes'>) {
}: Omit<DecodedMessage<ContentTypes>, 'toBytes'>) {
this.id = id
this.messageVersion = messageVersion
this.senderAddress = senderAddress
Expand Down Expand Up @@ -283,10 +275,10 @@ export class DecodedMessage {
}).finish()
}

static async fromBytes(
static async fromBytes<ContentTypes>(
data: Uint8Array,
client: Client
): Promise<DecodedMessage> {
client: Client<ContentTypes>
): Promise<DecodedMessage<ContentTypes>> {
const protoVal = proto.DecodedMessage.decode(data)
const messageVersion = protoVal.messageVersion

Expand All @@ -299,7 +291,7 @@ export class DecodedMessage {
}

const { content, contentType, error, contentFallback } =
await decodeContent(protoVal.contentBytes, client)
await client.decodeContent(protoVal.contentBytes)

return new DecodedMessage({
...protoVal,
Expand All @@ -317,16 +309,16 @@ export class DecodedMessage {
})
}

static fromV1Message(
static fromV1Message<ContentTypes>(
message: MessageV1,
content: any, // eslint-disable-line @typescript-eslint/no-explicit-any
content: ContentTypes,
contentType: ContentTypeId,
contentBytes: Uint8Array,
contentTopic: string,
conversation: Conversation,
conversation: Conversation<ContentTypes>,
error?: Error,
contentFallback?: string
): DecodedMessage {
): DecodedMessage<ContentTypes> {
const { id, senderAddress, recipientAddress, sent } = message
if (!senderAddress) {
throw new Error('Sender address is required')
Expand All @@ -347,17 +339,17 @@ export class DecodedMessage {
})
}

static fromV2Message(
static fromV2Message<ContentTypes>(
message: MessageV2,
content: any, // eslint-disable-line @typescript-eslint/no-explicit-any
content: ContentTypes,
contentType: ContentTypeId,
contentTopic: string,
contentBytes: Uint8Array,
conversation: Conversation,
conversation: Conversation<ContentTypes>,
senderAddress: string,
error?: Error,
contentFallback?: string
): DecodedMessage {
): DecodedMessage<ContentTypes> {
const { id, sent } = message

return new DecodedMessage({
Expand All @@ -376,39 +368,11 @@ export class DecodedMessage {
}
}

export async function decodeContent(contentBytes: Uint8Array, client: Client) {
const encodedContent = protoContent.EncodedContent.decode(contentBytes)

if (!encodedContent.type) {
throw new Error('missing content type')
}

let content: any // eslint-disable-line @typescript-eslint/no-explicit-any
const contentType = new ContentTypeId(encodedContent.type)
let error: Error | undefined

await decompress(encodedContent, 1000)

const codec = client.codecFor(contentType)
if (codec) {
content = codec.decode(encodedContent as EncodedContent, client)
} else {
error = new Error('unknown content type ' + contentType)
}

return {
content,
contentType,
error,
contentFallback: encodedContent.fallback,
}
}

function conversationReferenceToConversation(
function conversationReferenceToConversation<ContentTypes>(
reference: conversationReference.ConversationReference,
client: Client,
client: Client<ContentTypes>,
version: DecodedMessage['messageVersion']
): Conversation {
): Conversation<ContentTypes> {
if (version === 'v1') {
return new ConversationV1(
client,
Expand All @@ -427,3 +391,10 @@ function conversationReferenceToConversation(
}
throw new Error(`Unknown conversation version ${version}`)
}

export function decodeContent<ContentTypes>(
contentBytes: Uint8Array,
client: Client<ContentTypes>
) {
return client.decodeContent(contentBytes)
}
13 changes: 7 additions & 6 deletions src/Stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ export type ContentTopicUpdater<M> = (msg: M) => string[] | undefined
* Stream implements an Asynchronous Iterable over messages received from a topic.
* As such can be used with constructs like for-await-of, yield*, array destructing, etc.
*/
export default class Stream<T> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export default class Stream<T, ClientType = any> {
topics: string[]
client: Client
client: Client<ClientType>
// queue of incoming Waku messages
messages: T[]
// queue of already pending Promises
Expand All @@ -32,7 +33,7 @@ export default class Stream<T> {
onConnectionLost?: OnConnectionLostCallback

constructor(
client: Client,
client: Client<ClientType>,
topics: string[],
decoder: MessageDecoder<T>,
contentTopicUpdater?: ContentTopicUpdater<T>,
Expand Down Expand Up @@ -100,13 +101,13 @@ export default class Stream<T> {
)
}

static async create<T>(
client: Client,
static async create<T, ClientType = string>(
client: Client<ClientType>,
topics: string[],
decoder: MessageDecoder<T>,
contentTopicUpdater?: ContentTopicUpdater<T>,
onConnectionLost?: OnConnectionLostCallback
): Promise<Stream<T>> {
): Promise<Stream<T, ClientType>> {
const stream = new Stream(
client,
topics,
Expand Down
Loading

0 comments on commit ad59a8b

Please sign in to comment.