Skip to content

Commit

Permalink
feat: set the correct external sender key when creating MLS conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamadJaara committed May 2, 2024
1 parent 00d9937 commit d50e325
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class MLSClientImpl(
private val keyRotationDuration: Duration = 30.toDuration(DurationUnit.DAYS)
private val defaultGroupConfiguration = CustomConfiguration(keyRotationDuration.toJavaDuration(), MlsWirePolicy.PLAINTEXT)

override fun getDefaultCipherSuite(): UShort {
return defaultCipherSuite
}

override suspend fun close() {
coreCrypto.close()
}
Expand Down Expand Up @@ -104,11 +108,11 @@ class MLSClientImpl(

override suspend fun createConversation(
groupId: MLSGroupId,
externalSenders: List<Ed22519Key>
externalSenders: ByteArray
) {
val conf = ConversationConfiguration(
defaultCipherSuite,
externalSenders.map { it.value },
listOf(externalSenders),
defaultGroupConfiguration
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ data class CrlRegistration(

@Suppress("TooManyFunctions")
interface MLSClient {
/**
* Get the default ciphersuite for the client.
* the Default ciphersuite is set when creating the mls client.
*/
fun getDefaultCipherSuite(): UShort

/**
* Free up any resources and shutdown the client.
Expand Down Expand Up @@ -253,7 +258,7 @@ interface MLSClient {
*/
suspend fun createConversation(
groupId: MLSGroupId,
externalSenders: List<Ed22519Key> = emptyList()
externalSenders: ByteArray
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import com.wire.kalium.cryptography.CommitBundle
import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.Ed22519Key
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.obfuscateId
Expand All @@ -48,6 +47,7 @@ import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.UserId
Expand Down Expand Up @@ -572,14 +572,28 @@ internal class MLSConversationDataSource(
allowSkippingUsersWithoutKeyPackages: Boolean,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsPublicKeysRepository.getKeys().flatMap { publicKeys ->
val keys = publicKeys.map { mlsPublicKeysMapper.toCrypto(it) }
establishMLSGroup(
groupID = groupID,
members = members,
keys = keys,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
val keySignature = CipherSuite.fromTag(it.getDefaultCipherSuite()).let {
mlsPublicKeysMapper.fromCipherSuite(it)
}
val key = publicKeys.removal?.let { removalKeys ->
removalKeys[keySignature.value]
}

if (key == null) {
return@flatMap MLSFailure
.Generic(IllegalArgumentException("No key found for $keySignature, isNullOrEmpty= ${publicKeys.removal.isNullOrEmpty()}"))
.left()
}
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key.decodeBase64Bytes(),
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
}

}

override suspend fun establishMLSSubConversationGroup(
Expand All @@ -592,7 +606,7 @@ internal class MLSConversationDataSource(
establishMLSGroup(
groupID = groupID,
members = emptyList(),
keys = listOf(mlsPublicKeysMapper.toCrypto(externalSenderKey)),
externalSenders = externalSenderKey.value,
allowPartialMemberList = false
).map { Unit }
} ?: Either.Left(StorageFailure.DataNotFound)
Expand All @@ -602,14 +616,14 @@ internal class MLSConversationDataSource(
private suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
keys: List<Ed22519Key>,
externalSenders: ByteArray,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.createConversation(
idMapper.toCryptoModel(groupID),
keys
externalSenders
)
}.flatMapLeft {
if (it is MLSFailure.ConversationAlreadyExists) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,15 @@
package com.wire.kalium.logic.data.mlspublickeys

import com.wire.kalium.cryptography.ExternalSenderKey
import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO
import io.ktor.util.decodeBase64Bytes
import com.wire.kalium.logic.data.mls.CipherSuite

interface MLSPublicKeysMapper {
fun fromDTO(publicKeys: MLSPublicKeysDTO): List<MLSPublicKey>
fun toCrypto(publicKey: MLSPublicKey): com.wire.kalium.cryptography.Ed22519Key
fun toCrypto(externalSenderKey: ExternalSenderKey): com.wire.kalium.cryptography.Ed22519Key
fun fromCipherSuite(cipherSuite: CipherSuite): MLSPublicKeyType
}

class MLSPublicKeysMapperImpl : MLSPublicKeysMapper {
override fun fromDTO(publicKeys: MLSPublicKeysDTO) = with(publicKeys) {
removal?.entries?.mapNotNull {
when (it.key) {
ED25519 -> MLSPublicKey(Ed25519Key(it.value.decodeBase64Bytes()), KeyType.REMOVAL)
else -> null
}
} ?: emptyList()
}

override fun toCrypto(publicKey: MLSPublicKey) = with(publicKey) {
com.wire.kalium.cryptography.Ed22519Key(key.value)
}
Expand All @@ -46,8 +36,54 @@ class MLSPublicKeysMapperImpl : MLSPublicKeysMapper {
com.wire.kalium.cryptography.Ed22519Key(this.value)
}

companion object {
const val ED25519 = "ed25519"
override fun fromCipherSuite(cipherSuite: CipherSuite): MLSPublicKeyType {
return when (cipherSuite) {
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 -> MLSPublicKeyType.P256
CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 -> MLSPublicKeyType.ED25519
CipherSuite.MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 -> MLSPublicKeyType.ED25519
CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519 -> MLSPublicKeyType.ED25519
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384 -> MLSPublicKeyType.P384
CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521 -> MLSPublicKeyType.P521
CipherSuite.MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 -> MLSPublicKeyType.ED448
CipherSuite.MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 -> MLSPublicKeyType.ED448
is CipherSuite.UNKNOWN -> MLSPublicKeyType.Unknown(null)
}
}
}

sealed class MLSPublicKeyType {
abstract val value: String?

data object P256 : MLSPublicKeyType() {
override val value: String = "p256"
}

data object P384 : MLSPublicKeyType() {
override val value: String = "p384"
}

data object P521 : MLSPublicKeyType() {
override val value: String = "p521"
}

data object ED448 : MLSPublicKeyType() {
override val value: String = "ed448"
}

data object ED25519 : MLSPublicKeyType() {
override val value: String = "ed25519"
}

data class Unknown(override val value: String?) : MLSPublicKeyType()

companion object {
fun from(value: String) = when (value) {
P256.value -> P256
P384.value -> P384
P521.value -> P521
ED448.value -> ED448
ED25519.value -> ED25519
else -> Unknown(value)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,32 @@ import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeyApi


data class MLSPublicKeys(
val removal: Map<String, String>?
)


interface MLSPublicKeysRepository {
suspend fun fetchKeys(): Either<CoreFailure, List<MLSPublicKey>>
suspend fun getKeys(): Either<CoreFailure, List<MLSPublicKey>>
suspend fun fetchKeys(): Either<CoreFailure, MLSPublicKeys>
suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys>
}

class MLSPublicKeysRepositoryImpl(
private val mlsPublicKeyApi: MLSPublicKeyApi,
private val mapper: MLSPublicKeysMapper = MLSPublicKeysMapperImpl()
) : MLSPublicKeysRepository {

var publicKeys: List<MLSPublicKey>? = null
// TODO: make it thread safe
var publicKeys: MLSPublicKeys? = null

override suspend fun fetchKeys() =
wrapApiRequest {
mlsPublicKeyApi.getMLSPublicKeys()
}.map {
val keys = mapper.fromDTO(it)
publicKeys = keys
keys
MLSPublicKeys(removal = it.removal)
}

override suspend fun getKeys(): Either<CoreFailure, List<MLSPublicKey>> {
override suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys> {
return publicKeys?.let { Either.Right(it) } ?: fetchKeys()
}

}

0 comments on commit d50e325

Please sign in to comment.