Skip to content

Commit

Permalink
fix(mls): set removal-keys for 1on1 calls from conversation-response (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mchenani authored Sep 18, 2024
1 parent 8728f13 commit 1b85149
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ internal class ConversationGroupRepositoryImpl(
val conversationEntity = conversationMapper.fromApiModelToDaoModel(
conversationResponse, mlsGroupState = ConversationEntity.GroupState.PENDING_CREATION, selfTeamId
)
val mlsPublicKeys = conversationMapper.fromApiModel(conversationResponse.publicKeys)
val protocol = protocolInfoMapper.fromEntity(conversationEntity.protocolInfo)

return wrapStorageRequest {
Expand All @@ -147,7 +148,8 @@ internal class ConversationGroupRepositoryImpl(
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
publicKeys = mlsPublicKeys,
allowSkippingUsersWithoutKeyPackages = true,
).map { it.notAddedUsers }
}
}.flatMap { additionalFailedUsers ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.AvailabilityStatusMapper
import com.wire.kalium.logic.data.user.BotService
import com.wire.kalium.logic.data.user.Connection
Expand All @@ -40,6 +41,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConvTeamInfo
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.base.authenticated.conversation.ReceiptMode
import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.base.model.ConversationAccessDTO
import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
Expand All @@ -59,6 +61,7 @@ import kotlin.time.toDuration

interface ConversationMapper {
fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity
fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?): MLSPublicKeys?
fun fromDaoModel(daoModel: ConversationViewEntity): Conversation
fun fromDaoModel(daoModel: ConversationEntity): Conversation
fun fromDaoModelToDetails(
Expand Down Expand Up @@ -130,6 +133,12 @@ internal class ConversationMapperImpl(
legalHoldStatus = ConversationEntity.LegalHoldStatus.DISABLED
)

override fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?) = mlsPublicKeysDTO?.let {
MLSPublicKeys(
removal = mlsPublicKeysDTO.removal
)
}

override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) {
val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) UNIX_FIRST_DATE
else lastReadDate.toIsoDateTimeString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.feature.e2ei.usecase.CheckRevocationListUseCase
Expand Down Expand Up @@ -126,6 +128,7 @@ interface MLSConversationRepository {
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
publicKeys: MLSPublicKeys? = null,
allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

Expand Down Expand Up @@ -575,16 +578,18 @@ internal class MLSConversationDataSource(
override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
publicKeys: MLSPublicKeys?,
allowSkippingUsersWithoutKeyPackages: Boolean
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
mlsPublicKeysRepository.getKeyForCipherSuite(
CipherSuite.fromTag(it.getDefaultCipherSuite())
).flatMap { key ->
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> { mlsClient ->
val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
val keys = publicKeys?.getRemovalKey(cipherSuite) ?: mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)

keys.flatMap { externalSenders ->
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key,
externalSenders = externalSenders,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ data class MLSPublicKeys(
val removal: Map<String, String>?
)

fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {
val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = this.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
return key.decodeBase64Bytes().right()
}

interface MLSPublicKeysRepository {
suspend fun fetchKeys(): Either<CoreFailure, MLSPublicKeys>
suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys>
Expand All @@ -42,7 +51,6 @@ interface MLSPublicKeysRepository {

class MLSPublicKeysRepositoryImpl(
private val mlsPublicKeyApi: MLSPublicKeyApi,
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
) : MLSPublicKeysRepository {

// TODO: make it thread safe
Expand All @@ -60,14 +68,8 @@ class MLSPublicKeysRepositoryImpl(
}

override suspend fun getKeyForCipherSuite(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {

return getKeys().flatMap { serverPublicKeys ->
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = serverPublicKeys.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
key.decodeBase64Bytes().right()
serverPublicKeys.getRemovalKey(cipherSuite)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class ConversationGroupRepositoryTest {

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything(), eq(true))
.with(anything(), anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
Expand Down Expand Up @@ -323,7 +323,7 @@ class ConversationGroupRepositoryTest {

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything(), eq(true))
.with(anything(), anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arr
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.KEY_PACKAGE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.MLS_PUBLIC_KEY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY
Expand Down Expand Up @@ -174,7 +175,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
Expand Down Expand Up @@ -300,6 +301,90 @@ class MLSConversationRepositoryTest {
.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNotNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryNotCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = MLS_PUBLIC_KEY)
result.shouldSucceed()

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::createConversation)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::addMember)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsMessageApi)
.suspendFunction(arrangement.mlsMessageApi::sendCommitBundle)
.with(anyInstanceOf(MLSMessageApi.CommitBundle::class))
.wasInvoked(once)

verify(arrangement.mlsClient)
.function(arrangement.mlsClient::commitAccepted)
.with(eq(Arrangement.RAW_GROUP_ID))
.wasInvoked(once)

verify(arrangement.mlsPublicKeysRepository)
.function(arrangement.mlsPublicKeysRepository::getKeyForCipherSuite)
.with(anything<CipherSuite>())
.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryIsCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::createConversation)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::addMember)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsMessageApi)
.suspendFunction(arrangement.mlsMessageApi::sendCommitBundle)
.with(anyInstanceOf(MLSMessageApi.CommitBundle::class))
.wasInvoked(once)

verify(arrangement.mlsClient)
.function(arrangement.mlsClient::commitAccepted)
.with(eq(Arrangement.RAW_GROUP_ID))
.wasInvoked(once)

verify(arrangement.mlsPublicKeysRepository)
.function(arrangement.mlsPublicKeysRepository::getKeyForCipherSuite)
.with(anything<CipherSuite>())
.wasInvoked(once)
}

@Test
fun givenNewCrlDistributionPoints_whenEstablishingMLSGroup_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
Expand Down Expand Up @@ -351,7 +436,7 @@ class MLSConversationRepositoryTest {
.withWaitUntilLiveSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
Expand Down Expand Up @@ -382,7 +467,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR)
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldFail()

verify(arrangement.mlsMessageApi)
Expand Down Expand Up @@ -413,7 +498,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.keyPackageRepository)
Expand All @@ -434,7 +519,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList())
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList(), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.network.api.base.authenticated.conversation

import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.base.model.ConversationAccessDTO
import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO
import com.wire.kalium.network.api.base.model.ConversationId
Expand Down Expand Up @@ -86,7 +87,10 @@ data class ConversationResponse(
val accessRole: Set<ConversationAccessRoleDTO> = ConversationAccessRoleDTO.DEFAULT_VALUE_WHEN_NULL,

@SerialName("receipt_mode")
val receiptMode: ReceiptMode
val receiptMode: ReceiptMode,

@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO? = null
) {

@Suppress("MagicNumber")
Expand Down Expand Up @@ -152,6 +156,14 @@ data class ConversationResponseV3(
val receiptMode: ReceiptMode,
)

@Serializable
data class ConversationResponseV6(
@SerialName("conversation")
val conversation: ConversationResponseV3,
@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO
)

@Serializable
data class ConversationMembersResponse(
@SerialName("self")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.wire.kalium.network.api.base.model

import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV3
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV6
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequestV3
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest
Expand All @@ -33,6 +34,7 @@ internal interface ApiModelMapper {
fun toApiV3(request: CreateConversationRequest): CreateConversationRequestV3
fun toApiV3(request: UpdateConversationAccessRequest): UpdateConversationAccessRequestV3
fun fromApiV3(response: ConversationResponseV3): ConversationResponse
fun fromApiV6(response: ConversationResponseV6): ConversationResponse
}

internal class ApiModelMapperImpl : ApiModelMapper {
Expand Down Expand Up @@ -76,4 +78,23 @@ internal class ApiModelMapperImpl : ApiModelMapper {
response.receiptMode
)

override fun fromApiV6(response: ConversationResponseV6): ConversationResponse =
ConversationResponse(
creator = response.conversation.creator,
members = response.conversation.members,
name = response.conversation.name,
id = response.conversation.id,
groupId = response.conversation.groupId,
epoch = response.conversation.epoch,
type = response.conversation.type,
messageTimer = response.conversation.messageTimer,
teamId = response.conversation.teamId,
protocol = response.conversation.protocol,
lastEventTime = response.conversation.lastEventTime,
mlsCipherSuiteTag = response.conversation.mlsCipherSuiteTag,
access = response.conversation.access,
accessRole = response.conversation.accessRole,
receiptMode = response.conversation.receiptMode,
publicKeys = response.publicKeys
)
}
Loading

0 comments on commit 1b85149

Please sign in to comment.