diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt index 30e02bc0415..b10043bc728 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt @@ -135,6 +135,7 @@ internal class ConversationGroupRepositoryImpl( val conversationEntity = conversationMapper.fromApiModelToDaoModel( conversationResponse, mlsGroupState = ConversationEntity.GroupState.PENDING_CREATION, selfTeamId ) + val mlsPublicKeys = conversationMapper.fromApiModel(conversationResponse.publicKeys) val protocol = protocolInfoMapper.fromEntity(conversationEntity.protocolInfo) return wrapStorageRequest { @@ -147,7 +148,8 @@ internal class ConversationGroupRepositoryImpl( is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup( groupID = protocol.groupId, members = usersList + selfUserId, - allowSkippingUsersWithoutKeyPackages = true + publicKeys = mlsPublicKeys, + allowSkippingUsersWithoutKeyPackages = true, ).map { it.notAddedUsers } } }.flatMap { additionalFailedUsers -> diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt index d77339e5b2e..4c47e7fa4c4 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt @@ -27,6 +27,7 @@ import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.message.MessagePreview +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys import com.wire.kalium.logic.data.user.AvailabilityStatusMapper import com.wire.kalium.logic.data.user.BotService import com.wire.kalium.logic.data.user.Connection @@ -40,6 +41,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConvTeamInfo import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest import com.wire.kalium.network.api.base.authenticated.conversation.ReceiptMode +import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO import com.wire.kalium.network.api.base.model.ConversationAccessDTO import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO import com.wire.kalium.persistence.dao.conversation.ConversationEntity @@ -59,6 +61,7 @@ import kotlin.time.toDuration interface ConversationMapper { fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity + fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?): MLSPublicKeys? fun fromDaoModel(daoModel: ConversationViewEntity): Conversation fun fromDaoModel(daoModel: ConversationEntity): Conversation fun fromDaoModelToDetails( @@ -130,6 +133,12 @@ internal class ConversationMapperImpl( legalHoldStatus = ConversationEntity.LegalHoldStatus.DISABLED ) + override fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?) = mlsPublicKeysDTO?.let { + MLSPublicKeys( + removal = mlsPublicKeysDTO.removal + ) + } + override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) { val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) UNIX_FIRST_DATE else lastReadDate.toIsoDateTimeString() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 7629f5d0110..a36555eef9c 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -47,7 +47,9 @@ import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.mls.CipherSuite +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository +import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.feature.e2ei.usecase.CheckRevocationListUseCase @@ -126,6 +128,7 @@ interface MLSConversationRepository { suspend fun establishMLSGroup( groupID: GroupID, members: List, + publicKeys: MLSPublicKeys? = null, allowSkippingUsersWithoutKeyPackages: Boolean = false ): Either @@ -575,16 +578,18 @@ internal class MLSConversationDataSource( override suspend fun establishMLSGroup( groupID: GroupID, members: List, - allowSkippingUsersWithoutKeyPackages: Boolean, + publicKeys: MLSPublicKeys?, + allowSkippingUsersWithoutKeyPackages: Boolean ): Either = withContext(serialDispatcher) { - mlsClientProvider.getMLSClient().flatMap { - mlsPublicKeysRepository.getKeyForCipherSuite( - CipherSuite.fromTag(it.getDefaultCipherSuite()) - ).flatMap { key -> + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()) + val keys = publicKeys?.getRemovalKey(cipherSuite) ?: mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite) + + keys.flatMap { externalSenders -> establishMLSGroup( groupID = groupID, members = members, - externalSenders = key, + externalSenders = externalSenders, allowPartialMemberList = allowSkippingUsersWithoutKeyPackages ) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/mlspublickeys/MLSPublicKeysRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/mlspublickeys/MLSPublicKeysRepository.kt index 48709255f9b..180c506a838 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/mlspublickeys/MLSPublicKeysRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/mlspublickeys/MLSPublicKeysRepository.kt @@ -34,6 +34,15 @@ data class MLSPublicKeys( val removal: Map? ) +fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either { + val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper() + val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite) + val key = this.removal?.let { removalKeys -> + removalKeys[keySignature.value] + } ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite"))) + return key.decodeBase64Bytes().right() +} + interface MLSPublicKeysRepository { suspend fun fetchKeys(): Either suspend fun getKeys(): Either @@ -42,7 +51,6 @@ interface MLSPublicKeysRepository { class MLSPublicKeysRepositoryImpl( private val mlsPublicKeyApi: MLSPublicKeyApi, - private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper() ) : MLSPublicKeysRepository { // TODO: make it thread safe @@ -60,14 +68,8 @@ class MLSPublicKeysRepositoryImpl( } override suspend fun getKeyForCipherSuite(cipherSuite: CipherSuite): Either { - return getKeys().flatMap { serverPublicKeys -> - val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite) - val key = serverPublicKeys.removal?.let { removalKeys -> - removalKeys[keySignature.value] - } ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite"))) - key.decodeBase64Bytes().right() + serverPublicKeys.getRemovalKey(cipherSuite) } } - } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt index 6025b53873a..e4733ac2618 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt @@ -279,7 +279,7 @@ class ConversationGroupRepositoryTest { verify(mlsConversationRepository) .suspendFunction(mlsConversationRepository::establishMLSGroup) - .with(anything(), anything(), eq(true)) + .with(anything(), anything(), anything(), eq(true)) .wasInvoked(once) verify(newConversationMembersRepository) @@ -323,7 +323,7 @@ class ConversationGroupRepositoryTest { verify(mlsConversationRepository) .suspendFunction(mlsConversationRepository::establishMLSGroup) - .with(anything(), anything(), eq(true)) + .with(anything(), anything(), anything(), eq(true)) .wasInvoked(once) verify(newConversationMembersRepository) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index a81bf905943..e00d437401a 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -42,6 +42,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arr import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.KEY_PACKAGE +import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.MLS_PUBLIC_KEY import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY @@ -174,7 +175,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful() .arrange() - val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1)) + val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null) result.shouldSucceed() verify(arrangement.mlsClient) @@ -300,6 +301,90 @@ class MLSConversationRepositoryTest { .wasNotInvoked() } + @Test + fun givenPublicKeysIsNotNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryNotCalled() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519) + .withCommitPendingProposalsReturningNothing() + .withClaimKeyPackagesSuccessful() + .withGetMLSClientSuccessful() + .withKeyForCipherSuite() + .withAddMLSMemberSuccessful() + .withSendCommitBundleSuccessful() + .arrange() + + val result = + mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = MLS_PUBLIC_KEY) + result.shouldSucceed() + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::createConversation) + .with(eq(Arrangement.RAW_GROUP_ID), anything()) + .wasInvoked(once) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::addMember) + .with(eq(Arrangement.RAW_GROUP_ID), anything()) + .wasInvoked(once) + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasInvoked(once) + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::commitAccepted) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + + verify(arrangement.mlsPublicKeysRepository) + .function(arrangement.mlsPublicKeysRepository::getKeyForCipherSuite) + .with(anything()) + .wasNotInvoked() + } + + @Test + fun givenPublicKeysIsNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryIsCalled() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519) + .withCommitPendingProposalsReturningNothing() + .withClaimKeyPackagesSuccessful() + .withGetMLSClientSuccessful() + .withKeyForCipherSuite() + .withAddMLSMemberSuccessful() + .withSendCommitBundleSuccessful() + .arrange() + + val result = + mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null) + result.shouldSucceed() + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::createConversation) + .with(eq(Arrangement.RAW_GROUP_ID), anything()) + .wasInvoked(once) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::addMember) + .with(eq(Arrangement.RAW_GROUP_ID), anything()) + .wasInvoked(once) + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasInvoked(once) + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::commitAccepted) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + + verify(arrangement.mlsPublicKeysRepository) + .function(arrangement.mlsPublicKeysRepository::getKeyForCipherSuite) + .with(anything()) + .wasInvoked(once) + } + @Test fun givenNewCrlDistributionPoints_whenEstablishingMLSGroup_thenCheckRevocationList() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() @@ -351,7 +436,7 @@ class MLSConversationRepositoryTest { .withWaitUntilLiveSuccessful() .arrange() - val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1)) + val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null) result.shouldSucceed() verify(arrangement.mlsClient) @@ -382,7 +467,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR) .arrange() - val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1)) + val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null) result.shouldFail() verify(arrangement.mlsMessageApi) @@ -413,7 +498,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful() .arrange() - val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1)) + val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null) result.shouldSucceed() verify(arrangement.keyPackageRepository) @@ -434,7 +519,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful() .arrange() - val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList()) + val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList(), publicKeys = null) result.shouldSucceed() verify(arrangement.mlsClient) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt index 5a8e10c3e82..f5c506f8139 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt @@ -18,6 +18,7 @@ package com.wire.kalium.network.api.base.authenticated.conversation +import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO import com.wire.kalium.network.api.base.model.ConversationAccessDTO import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO import com.wire.kalium.network.api.base.model.ConversationId @@ -86,7 +87,10 @@ data class ConversationResponse( val accessRole: Set = ConversationAccessRoleDTO.DEFAULT_VALUE_WHEN_NULL, @SerialName("receipt_mode") - val receiptMode: ReceiptMode + val receiptMode: ReceiptMode, + + @SerialName("public_keys") + val publicKeys: MLSPublicKeysDTO? = null ) { @Suppress("MagicNumber") @@ -152,6 +156,14 @@ data class ConversationResponseV3( val receiptMode: ReceiptMode, ) +@Serializable +data class ConversationResponseV6( + @SerialName("conversation") + val conversation: ConversationResponseV3, + @SerialName("public_keys") + val publicKeys: MLSPublicKeysDTO +) + @Serializable data class ConversationMembersResponse( @SerialName("self") diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/ApiModelMapper.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/ApiModelMapper.kt index 54bfa3df9a1..1991398d143 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/ApiModelMapper.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/ApiModelMapper.kt @@ -20,6 +20,7 @@ package com.wire.kalium.network.api.base.model import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV3 +import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV6 import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequestV3 import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest @@ -33,6 +34,7 @@ internal interface ApiModelMapper { fun toApiV3(request: CreateConversationRequest): CreateConversationRequestV3 fun toApiV3(request: UpdateConversationAccessRequest): UpdateConversationAccessRequestV3 fun fromApiV3(response: ConversationResponseV3): ConversationResponse + fun fromApiV6(response: ConversationResponseV6): ConversationResponse } internal class ApiModelMapperImpl : ApiModelMapper { @@ -76,4 +78,23 @@ internal class ApiModelMapperImpl : ApiModelMapper { response.receiptMode ) + override fun fromApiV6(response: ConversationResponseV6): ConversationResponse = + ConversationResponse( + creator = response.conversation.creator, + members = response.conversation.members, + name = response.conversation.name, + id = response.conversation.id, + groupId = response.conversation.groupId, + epoch = response.conversation.epoch, + type = response.conversation.type, + messageTimer = response.conversation.messageTimer, + teamId = response.conversation.teamId, + protocol = response.conversation.protocol, + lastEventTime = response.conversation.lastEventTime, + mlsCipherSuiteTag = response.conversation.mlsCipherSuiteTag, + access = response.conversation.access, + accessRole = response.conversation.accessRole, + receiptMode = response.conversation.receiptMode, + publicKeys = response.publicKeys + ) } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v6/authenticated/ConversationApiV6.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v6/authenticated/ConversationApiV6.kt index bcdaa50c0a4..69969799443 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v6/authenticated/ConversationApiV6.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v6/authenticated/ConversationApiV6.kt @@ -19,8 +19,29 @@ package com.wire.kalium.network.api.v6.authenticated import com.wire.kalium.network.AuthenticatedNetworkClient +import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse +import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV6 +import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest +import com.wire.kalium.network.api.base.model.ApiModelMapper +import com.wire.kalium.network.api.base.model.ApiModelMapperImpl import com.wire.kalium.network.api.v5.authenticated.ConversationApiV5 +import com.wire.kalium.network.utils.NetworkResponse +import com.wire.kalium.network.utils.mapSuccess +import com.wire.kalium.network.utils.wrapKaliumResponse +import io.ktor.client.request.post +import io.ktor.client.request.setBody internal open class ConversationApiV6 internal constructor( authenticatedNetworkClient: AuthenticatedNetworkClient, -) : ConversationApiV5(authenticatedNetworkClient) + private val apiModelMapper: ApiModelMapper = ApiModelMapperImpl() +) : ConversationApiV5(authenticatedNetworkClient) { + override suspend fun createOne2OneConversation( + createConversationRequest: CreateConversationRequest + ): NetworkResponse = wrapKaliumResponse { + httpClient.post("$PATH_CONVERSATIONS/$PATH_ONE_2_ONE") { + setBody(apiModelMapper.toApiV3(createConversationRequest)) + } + }.mapSuccess { + apiModelMapper.fromApiV6(it) + } +}