Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set the correct external sender key when creating MLS conversation [WPB-8592] 🍒 🍒 #2779

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class MLSClientImpl(

private val keyRotationDuration: Duration = 30.toDuration(DurationUnit.DAYS)
private val defaultGroupConfiguration = CustomConfiguration(keyRotationDuration, MlsWirePolicy.PLAINTEXT)
override fun getDefaultCipherSuite(): UShort {
return defaultCipherSuite
}

@Suppress("EmptyFunctionBlock")
override suspend fun close() {
Expand Down Expand Up @@ -97,11 +100,11 @@ class MLSClientImpl(

override suspend fun createConversation(
groupId: MLSGroupId,
externalSenders: List<Ed22519Key>
externalSenders: ByteArray
) {
val conf = ConversationConfiguration(
CiphersuiteName.MLS_128_DHKEMX25519_AES128GCM_SHA256_ED25519,
externalSenders.map { toUByteList(it.value) },
listOf(toUByteList(externalSenders)),
defaultGroupConfiguration
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class MLSClientImpl(
coreCrypto.close()
}

override fun getDefaultCipherSuite(): UShort {
return defaultCipherSuite
}

override suspend fun getPublicKey(): Pair<ByteArray, Ciphersuite> {
return coreCrypto.clientPublicKey(defaultCipherSuite, toCredentialType(getMLSCredentials())) to defaultCipherSuite
}
Expand Down Expand Up @@ -104,11 +108,12 @@ class MLSClientImpl(

override suspend fun createConversation(
groupId: MLSGroupId,
externalSenders: List<Ed22519Key>
externalSenders: ByteArray
) {
kaliumLogger.d("createConversation: using defaultCipherSuite=$defaultCipherSuite")
val conf = ConversationConfiguration(
defaultCipherSuite,
externalSenders.map { it.value },
listOf(externalSenders),
defaultGroupConfiguration
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,6 @@ data class DecryptedMessageBundle(
}
}

@JvmInline
value class Ed22519Key(
val value: ByteArray
)

@JvmInline
value class ExternalSenderKey(
val value: ByteArray
Expand All @@ -153,6 +148,11 @@ data class CrlRegistration(

@Suppress("TooManyFunctions")
interface MLSClient {
/**
* Get the default ciphersuite for the client.
* the Default ciphersuite is set when creating the mls client.
*/
fun getDefaultCipherSuite(): UShort

/**
* Free up any resources and shutdown the client.
Expand Down Expand Up @@ -253,7 +253,7 @@ interface MLSClient {
*/
suspend fun createConversation(
groupId: MLSGroupId,
externalSenders: List<Ed22519Key> = emptyList()
externalSenders: ByteArray
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ class MLSClientTest : BaseMLSClientTest() {
}

private suspend fun createClient(user: SampleUser): MLSClient {
return createMLSClient(user.qualifiedClientId, ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITES)
return createMLSClient(user.qualifiedClientId, allowedCipherSuites = ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITES)
}

@Test
fun givemMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest {
val mlsClient = createClient(ALICE1)
assertEquals(DEFAULT_CIPHER_SUITES, mlsClient.getDefaultCipherSuite())
}

@Test
Expand All @@ -51,7 +57,7 @@ class MLSClientTest : BaseMLSClientTest() {
@Test
fun givenNewConversation_whenCallingConversationEpoch_ReturnZeroEpoch() = runTest {
val mlsClient = createClient(ALICE1)
mlsClient.createConversation(MLS_CONVERSATION_ID)
mlsClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
assertEquals(0UL, mlsClient.conversationEpoch(MLS_CONVERSATION_ID))
}

Expand All @@ -64,7 +70,7 @@ class MLSClientTest : BaseMLSClientTest() {

val aliceKeyPackage = aliceClient.generateKeyPackages(1).first()
val clientKeyPackageList = listOf(aliceKeyPackage)
bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
val welcome = bobClient.addMember(MLS_CONVERSATION_ID, clientKeyPackageList)?.welcome!!
bobClient.commitAccepted(MLS_CONVERSATION_ID)
val welcomeBundle = aliceClient.processWelcomeMessage(welcome)
Expand All @@ -82,7 +88,7 @@ class MLSClientTest : BaseMLSClientTest() {

val aliceKeyPackage = aliceClient.generateKeyPackages(1).first()
val clientKeyPackageList = listOf(aliceKeyPackage)
bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
val welcome = bobClient.addMember(MLS_CONVERSATION_ID, clientKeyPackageList)!!.welcome!!
val welcomeBundle = aliceClient.processWelcomeMessage(welcome)

Expand All @@ -98,7 +104,7 @@ class MLSClientTest : BaseMLSClientTest() {
val alice1KeyPackage = alice1Client.generateKeyPackages(1).first()
val clientKeyPackageList = listOf(alice1KeyPackage)

bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
bobClient.addMember(MLS_CONVERSATION_ID, clientKeyPackageList)
bobClient.commitAccepted(MLS_CONVERSATION_ID)
val proposal = alice2Client.joinConversation(MLS_CONVERSATION_ID, 1UL)
Expand All @@ -117,7 +123,7 @@ class MLSClientTest : BaseMLSClientTest() {

val clientKeyPackageList = listOf(aliceClient.generateKeyPackages(1).first())

bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
val welcome = bobClient.addMember(MLS_CONVERSATION_ID, clientKeyPackageList)?.welcome!!
bobClient.commitAccepted(MLS_CONVERSATION_ID)
val welcomeBundle = aliceClient.processWelcomeMessage(welcome)
Expand All @@ -135,7 +141,7 @@ class MLSClientTest : BaseMLSClientTest() {

val clientKeyPackageList = listOf(aliceClient.generateKeyPackages(1).first())

bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
val welcome = bobClient.addMember(MLS_CONVERSATION_ID, clientKeyPackageList)?.welcome!!
bobClient.commitAccepted((MLS_CONVERSATION_ID))
val welcomeBundle = aliceClient.processWelcomeMessage(welcome)
Expand All @@ -149,7 +155,7 @@ class MLSClientTest : BaseMLSClientTest() {
val bobClient = createClient(BOB1)
val carolClient = createClient(CAROL1)

bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
val welcome = bobClient.addMember(
MLS_CONVERSATION_ID,
listOf(aliceClient.generateKeyPackages(1).first())
Expand All @@ -160,7 +166,7 @@ class MLSClientTest : BaseMLSClientTest() {

val commit = bobClient.addMember(
MLS_CONVERSATION_ID,
listOf( carolClient.generateKeyPackages(1).first())
listOf(carolClient.generateKeyPackages(1).first())
)?.commit!!

assertNull(aliceClient.decryptMessage(MLS_CONVERSATION_ID, commit).first().message)
Expand All @@ -176,7 +182,7 @@ class MLSClientTest : BaseMLSClientTest() {
aliceClient.generateKeyPackages(1).first(),
carolClient.generateKeyPackages(1).first()
)
bobClient.createConversation(MLS_CONVERSATION_ID)
bobClient.createConversation(MLS_CONVERSATION_ID, externalSenderKey)
val welcome = bobClient.addMember(MLS_CONVERSATION_ID, clientKeyPackageList)?.welcome!!
bobClient.commitAccepted(MLS_CONVERSATION_ID)
val welcomeBundle = aliceClient.processWelcomeMessage(welcome)
Expand All @@ -188,8 +194,9 @@ class MLSClientTest : BaseMLSClientTest() {
}

companion object {
val ALLOWED_CIPHER_SUITES = listOf(1.toUShort())
val externalSenderKey = ByteArray(32)
val DEFAULT_CIPHER_SUITES = 1.toUShort()
val ALLOWED_CIPHER_SUITES = listOf(1.toUShort())
const val MLS_CONVERSATION_ID = "JfflcPtUivbg+1U3Iyrzsh5D2ui/OGS5Rvf52ipH5KY="
const val PLAIN_TEXT = "Hello World"
val ALICE1 = SampleUser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import kotlin.time.Duration

@Suppress("TooManyFunctions")
class MLSClientImpl : MLSClient {
override fun getDefaultCipherSuite(): UShort {
TODO("Not yet implemented")
}

override suspend fun close() {
TODO("Not yet implemented")
}
Expand Down Expand Up @@ -66,7 +70,7 @@ class MLSClientImpl : MLSClient {
TODO("Not yet implemented")
}

override suspend fun createConversation(groupId: MLSGroupId, externalSenders: List<Ed22519Key>) {
override suspend fun createConversation(groupId: MLSGroupId, externalSenders: ByteArray) {
TODO("Not yet implemented")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ interface MLSFailure : CoreFailure {
data object StaleProposal : MLSFailure
data object StaleCommit : MLSFailure

class Generic(internal val exception: Exception) : MLSFailure {
data class Generic(internal val exception: Exception) : MLSFailure {
val rootCause: Throwable get() = exception
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ import com.wire.kalium.cryptography.CommitBundle
import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.Ed22519Key
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.configuration.server.ServerConfig
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.mls.MLSAdditionResult
Expand All @@ -47,7 +45,7 @@ import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
Expand Down Expand Up @@ -218,10 +216,8 @@ internal class MLSConversationDataSource(
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val revocationListChecker: RevocationListChecker,
private val certificateRevocationListRepository: CertificateRevocationListRepository,
private val serverConfigLinks: ServerConfig.Links,
private val idMapper: IdMapper = MapperProvider.idMapper(),
private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId),
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(),
private val mlsCommitBundleMapper: MLSCommitBundleMapper = MapperProvider.mlsCommitBundleMapper(),
kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl
) : MLSConversationRepository {
Expand Down Expand Up @@ -310,10 +306,12 @@ internal class MLSConversationDataSource(
}
sendCommitBundleForExternalCommit(groupID, commitBundle)
}.onSuccess {
conversationDAO.updateConversationGroupState(
ConversationEntity.GroupState.ESTABLISHED,
idMapper.toCryptoModel(groupID)
)
wrapStorageRequest {
conversationDAO.updateConversationGroupState(
ConversationEntity.GroupState.ESTABLISHED,
idMapper.toCryptoModel(groupID)
)
}
}
}
}
Expand Down Expand Up @@ -552,14 +550,17 @@ internal class MLSConversationDataSource(
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsPublicKeysRepository.getKeys().flatMap { publicKeys ->
val keys = publicKeys.map { mlsPublicKeysMapper.toCrypto(it) }
establishMLSGroup(
groupID = groupID,
members = members,
keys = keys,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
mlsPublicKeysRepository.getKeyForCipherSuite(
CipherSuite.fromTag(it.getDefaultCipherSuite())
).flatMap { key ->
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
}
}

Expand All @@ -573,7 +574,7 @@ internal class MLSConversationDataSource(
establishMLSGroup(
groupID = groupID,
members = emptyList(),
keys = listOf(mlsPublicKeysMapper.toCrypto(externalSenderKey)),
externalSenders = externalSenderKey.value,
allowPartialMemberList = false
).map { Unit }
} ?: Either.Left(StorageFailure.DataNotFound)
Expand All @@ -583,14 +584,14 @@ internal class MLSConversationDataSource(
private suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
keys: List<Ed22519Key>,
externalSenders: ByteArray,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.createConversation(
idMapper.toCryptoModel(groupID),
keys
externalSenders
)
}.flatMapLeft {
if (it is MLSFailure.ConversationAlreadyExists) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,18 @@ sealed class CipherSuite(open val tag: Int) {
61489 -> MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519
else -> UNKNOWN(tag)
}

fun fromTag(tag: UShort) = fromTag(tag.toInt())
}
}

fun CipherSuite.signatureAlgorithm(): MLSPublicKeyTypeDTO? = when (this) {
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 -> MLSPublicKeyTypeDTO.P256
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 -> MLSPublicKeyTypeDTO.ECDSA_SECP256R1_SHA256
CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 -> MLSPublicKeyTypeDTO.ED25519
CipherSuite.MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 -> MLSPublicKeyTypeDTO.ED25519
CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519 -> MLSPublicKeyTypeDTO.ED25519
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384 -> MLSPublicKeyTypeDTO.P384
CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521 -> MLSPublicKeyTypeDTO.P521
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384 -> MLSPublicKeyTypeDTO.ECDSA_SECP384R1_SHA384
CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521 -> MLSPublicKeyTypeDTO.ECDSA_SECP521R1_SHA512
CipherSuite.MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 -> MLSPublicKeyTypeDTO.ED448
CipherSuite.MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 -> MLSPublicKeyTypeDTO.ED448
is CipherSuite.UNKNOWN -> null
Expand Down

This file was deleted.

Loading
Loading