Skip to content

Commit

Permalink
fix(mls): fetch and set mls-removal keys for 1on1 conversations (WPB-…
Browse files Browse the repository at this point in the history
…10743) πŸ’ πŸ’ (#3044)

* Commit with unresolved merge conflicts

* fix merge conflicts

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Mohamad Jaara <[email protected]>
  • Loading branch information
github-actions[bot] and MohamadJaara authored Oct 1, 2024
1 parent c667979 commit 8b65be7
Show file tree
Hide file tree
Showing 17 changed files with 93 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.message.UnreadEventType
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.user.ConnectionState
import com.wire.kalium.logic.data.user.OtherUser
import com.wire.kalium.logic.data.user.User
Expand Down Expand Up @@ -79,7 +80,8 @@ data class Conversation(
val archivedDateTime: Instant?,
val mlsVerificationStatus: VerificationStatus,
val proteusVerificationStatus: VerificationStatus,
val legalHoldStatus: LegalHoldStatus
val legalHoldStatus: LegalHoldStatus,
val mlsPublicKeys: MLSPublicKeys? = null
) {

companion object {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.mls

data class MLSPublicKeys(
val removal: Map<String, String>?
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.user.AvailabilityStatusMapper
import com.wire.kalium.logic.data.user.BotService
import com.wire.kalium.logic.data.user.Connection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ internal class ConversationDataSource internal constructor(
wrapApiRequest {
conversationApi.fetchMlsOneToOneConversation(userId.toApi())
}.map { conversationResponse ->
// question: do we need to do this? since it's one on one!
addOtherMemberIfMissing(conversationResponse, userId)
}.flatMap { conversationResponse ->
val selfUserTeamId = selfTeamIdProvider().getOrNull()
Expand All @@ -554,7 +555,9 @@ internal class ConversationDataSource internal constructor(
selfUserTeamId = selfUserTeamId
).map { conversationResponse }
}.flatMap { response ->
this.getConversationById(response.id.toModel())
this.getConversationById(response.id.toModel()).map {
it.copy(mlsPublicKeys = conversationMapper.fromApiModel(response.publicKeys))
}
}

private fun addOtherMemberIfMissing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
Expand All @@ -51,7 +52,7 @@ import kotlinx.coroutines.withContext
* but has not yet joined the corresponding MLS group.
*/
internal interface JoinExistingMLSConversationUseCase {
suspend operator fun invoke(conversationId: ConversationId): Either<CoreFailure, Unit>
suspend operator fun invoke(conversationId: ConversationId, mlsPublicKeys: MLSPublicKeys? = null): Either<CoreFailure, Unit>
}

@Suppress("LongParameterList")
Expand All @@ -65,7 +66,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
) : JoinExistingMLSConversationUseCase {
private val dispatcher = kaliumDispatcher.io

override suspend operator fun invoke(conversationId: ConversationId): Either<CoreFailure, Unit> =
override suspend operator fun invoke(conversationId: ConversationId, mlsPublicKeys: MLSPublicKeys?): Either<CoreFailure, Unit> =
if (!featureSupport.isMLSSupported ||
!clientRepository.hasRegisteredMLSClient().getOrElse(false)
) {
Expand All @@ -76,15 +77,16 @@ internal class JoinExistingMLSConversationUseCaseImpl(
Either.Left(StorageFailure.DataNotFound)
}, { conversation ->
withContext(dispatcher) {
joinOrEstablishMLSGroupAndRetry(conversation)
joinOrEstablishMLSGroupAndRetry(conversation, mlsPublicKeys)
}
})
}

private suspend fun joinOrEstablishMLSGroupAndRetry(
conversation: Conversation
conversation: Conversation,
mlsPublicKeys: MLSPublicKeys?
): Either<CoreFailure, Unit> =
joinOrEstablishMLSGroup(conversation)
joinOrEstablishMLSGroup(conversation, mlsPublicKeys)
.flatMapLeft { failure ->
if (failure is NetworkFailure.ServerMiscommunication && failure.kaliumException is KaliumException.InvalidRequestError) {
if (failure.kaliumException.isMlsStaleMessage()) {
Expand All @@ -101,13 +103,15 @@ internal class JoinExistingMLSConversationUseCaseImpl(
// Re-fetch current epoch and try again
if (conversation.type == Conversation.Type.ONE_ON_ONE) {
conversationRepository.getConversationMembers(conversation.id).flatMap {
conversationRepository.fetchMlsOneToOneConversation(it.first())
conversationRepository.fetchMlsOneToOneConversation(it.first()).map {
it.mlsPublicKeys
}
}
} else {
conversationRepository.fetchConversation(conversation.id)
}.flatMap {
conversationRepository.getConversationById(conversation.id).flatMap { conversation ->
joinOrEstablishMLSGroup(conversation)
joinOrEstablishMLSGroup(conversation, null)
}
}
} else if (failure.kaliumException.isMlsMissingGroupInfo()) {
Expand All @@ -122,7 +126,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

@Suppress("LongMethod")
private suspend fun joinOrEstablishMLSGroup(conversation: Conversation): Either<CoreFailure, Unit> {
private suspend fun joinOrEstablishMLSGroup(conversation: Conversation, publicKeys: MLSPublicKeys?): Either<CoreFailure, Unit> {
val protocol = conversation.protocol
val type = conversation.type
return when {
Expand Down Expand Up @@ -202,7 +206,8 @@ internal class JoinExistingMLSConversationUseCaseImpl(
conversationRepository.getConversationMembers(conversation.id).flatMap { members ->
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
members
members,
publicKeys
)
}.onSuccess {
kaliumLogger.logStructuredJson(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.e2ei.RevocationListChecker
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package com.wire.kalium.logic.data.mlspublickeys
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
Expand All @@ -30,10 +31,6 @@ import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeyApi
import io.ktor.util.decodeBase64Bytes

data class MLSPublicKeys(
val removal: Map<String, String>?
)

fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {
val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ internal class MLSOneOnOneConversationResolverImpl(
} else {
kaliumLogger.d("Establishing mls group for one-on-one with ${userId.toLogString()}")
conversationRepository.fetchMlsOneToOneConversation(userId).flatMap { conversation ->
joinExistingMLSConversationUseCase(conversation.id).map { conversation.id }
joinExistingMLSConversationUseCase(conversation.id, conversation.mlsPublicKeys).map { conversation.id }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,10 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(exactly = once)

coVerify {
arrangement.joinExistingMLSConversation.invoke(eq(ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE.event.qualifiedConversation.toModel()))
arrangement.joinExistingMLSConversation.invoke(
ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE.event.qualifiedConversation.toModel(),
null
)
}.wasInvoked(exactly = once)

coVerify {
Expand Down Expand Up @@ -950,7 +953,10 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(exactly = once)

coVerify {
arrangement.joinExistingMLSConversation.invoke(eq(ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE.event.qualifiedConversation.toModel()))
arrangement.joinExistingMLSConversation.invoke(
ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE.event.qualifiedConversation.toModel(),
null
)
}.wasInvoked(exactly = once)

coVerify {
Expand Down Expand Up @@ -1282,9 +1288,10 @@ class ConversationGroupRepositoryTest {
@Test
fun givenAConversationFailsWithUnreachableAndNotFromUsersInRequest_whenAddingMembers_thenRetryIsNotExecutedAndCreateSysMessage() =
runTest {
val conversation = TestConversation.CONVERSATION.copy(id = ConversationId("valueConvo", "domainConvo"))
// given
val (arrangement, conversationGroupRepository) = Arrangement()
.withConversationDetailsById(TestConversation.CONVERSATION)
.withConversationDetailsById(conversation)
.withProtocolInfoById(PROTEUS_PROTOCOL_INFO)
.withFetchUsersIfUnknownByIdsSuccessful()
.withAddMemberAPIFailsFirstWithUnreachableThenSucceed(
Expand All @@ -1309,11 +1316,9 @@ class ConversationGroupRepositoryTest {

coVerify {
arrangement.newGroupConversationSystemMessagesCreator.conversationFailedToAddMembers(
conversationId = any(),
userIdList = matches {
it.containsAll(expectedInitialUsersNotFromUnreachableInformed)
},
type = any()
conversationId = conversation.id,
userIdList = expectedInitialUsersNotFromUnreachableInformed,
type = MessageContent.MemberChange.FailedToAdd.Type.Federation
)
}.wasInvoked(once)
}
Expand Down Expand Up @@ -1772,7 +1777,7 @@ class ConversationGroupRepositoryTest {

suspend fun withJoinExistingMlsConversationSucceeds() = apply {
coEvery {
joinExistingMLSConversation.invoke(any())
joinExistingMLSConversation.invoke(any(), any())
}.returns(Either.Right(Unit))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.framework.TestClient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class JoinExistingMLSConversationsUseCaseTest {
}.wasNotInvoked()

coVerify {
arrangement.joinExistingMLSConversationUseCase.invoke(any())
arrangement.joinExistingMLSConversationUseCase.invoke(any(), any())
}.wasNotInvoked()
}

Expand All @@ -76,7 +76,7 @@ class JoinExistingMLSConversationsUseCaseTest {
}.wasNotInvoked()

coVerify {
arrangement.joinExistingMLSConversationUseCase.invoke(any())
arrangement.joinExistingMLSConversationUseCase.invoke(any(), any())
}.wasNotInvoked()
}

Expand All @@ -88,7 +88,7 @@ class JoinExistingMLSConversationsUseCaseTest {
joinExistingMLSConversationsUseCase().shouldSucceed()

coVerify {
arrangement.joinExistingMLSConversationUseCase.invoke(any())
arrangement.joinExistingMLSConversationUseCase.invoke(any(), any())
}.wasInvoked(twice)
}

Expand All @@ -100,7 +100,7 @@ class JoinExistingMLSConversationsUseCaseTest {
joinExistingMLSConversationsUseCase().shouldSucceed()

coVerify {
arrangement.joinExistingMLSConversationUseCase.invoke(any())
arrangement.joinExistingMLSConversationUseCase.invoke(any(), any())
}.wasInvoked(twice)
}

Expand All @@ -113,7 +113,7 @@ class JoinExistingMLSConversationsUseCaseTest {
assertIs<NetworkFailure>(it)
}
coVerify {
arrangement.joinExistingMLSConversationUseCase.invoke(any())
arrangement.joinExistingMLSConversationUseCase.invoke(any(), any())
}.wasInvoked(twice)
}

Expand All @@ -125,7 +125,7 @@ class JoinExistingMLSConversationsUseCaseTest {
joinExistingMLSConversationsUseCase().shouldSucceed()

coVerify {
arrangement.joinExistingMLSConversationUseCase.invoke(any())
arrangement.joinExistingMLSConversationUseCase.invoke(any(), any())
}.wasInvoked(twice)
}

Expand Down Expand Up @@ -161,25 +161,25 @@ class JoinExistingMLSConversationsUseCaseTest {

suspend fun withJoinExistingMLSConversationSuccessful() = apply {
coEvery {
joinExistingMLSConversationUseCase.invoke(any())
joinExistingMLSConversationUseCase.invoke(any(), any())
}.returns(Either.Right(Unit))
}

suspend fun withJoinExistingMLSConversationNetworkFailure() = apply {
coEvery {
joinExistingMLSConversationUseCase.invoke(any())
joinExistingMLSConversationUseCase.invoke(any(), any())
}.returns(Either.Left(NetworkFailure.NoNetworkConnection(null)))
}

suspend fun withJoinExistingMLSConversationFailure() = apply {
coEvery {
joinExistingMLSConversationUseCase.invoke(any())
joinExistingMLSConversationUseCase.invoke(any(), any())
}.returns(Either.Left(CoreFailure.NotSupportedByProteus))
}

suspend fun withNoKeyPackagesAvailable() = apply {
coEvery {
joinExistingMLSConversationUseCase.invoke(any())
joinExistingMLSConversationUseCase.invoke(any(), any())
}.returns(Either.Left(CoreFailure.MissingKeyPackages(setOf())))
}

Expand Down
Loading

0 comments on commit 8b65be7

Please sign in to comment.