Skip to content

Commit

Permalink
fix(mls): set removal-keys for 1on1 calls from conversation-response …
Browse files Browse the repository at this point in the history
…(WPB-10743) 🍒 (#3019)

* Commit with unresolved merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix stalling test

* fix tests

* fix tests

---------

Co-authored-by: Mojtaba Chenani <[email protected]>
Co-authored-by: Mohamad Jaara <[email protected]>
  • Loading branch information
3 people authored Sep 26, 2024
1 parent 8268f21 commit 1dd7cb2
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ internal class ConversationGroupRepositoryImpl(
val conversationEntity = conversationMapper.fromApiModelToDaoModel(
conversationResponse, mlsGroupState = ConversationEntity.GroupState.PENDING_CREATION, selfTeamId
)
val mlsPublicKeys = conversationMapper.fromApiModel(conversationResponse.publicKeys)
val protocol = protocolInfoMapper.fromEntity(conversationEntity.protocolInfo)

return wrapStorageRequest {
Expand All @@ -166,7 +167,8 @@ internal class ConversationGroupRepositoryImpl(
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
publicKeys = mlsPublicKeys,
allowSkippingUsersWithoutKeyPackages = true,
).map { it.notAddedUsers }
}
}.flatMap { protocolSpecificAdditionFailures ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.AvailabilityStatusMapper
import com.wire.kalium.logic.data.user.BotService
import com.wire.kalium.logic.data.user.Connection
Expand All @@ -41,6 +42,7 @@ import com.wire.kalium.network.api.authenticated.conversation.ConvTeamInfo
import com.wire.kalium.network.api.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.authenticated.conversation.ReceiptMode
import com.wire.kalium.network.api.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.model.ConversationAccessDTO
import com.wire.kalium.network.api.model.ConversationAccessRoleDTO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
Expand All @@ -59,6 +61,7 @@ import kotlin.time.toDuration

interface ConversationMapper {
fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity
fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?): MLSPublicKeys?
fun fromDaoModel(daoModel: ConversationViewEntity): Conversation
fun fromDaoModel(daoModel: ConversationEntity): Conversation
fun fromDaoModelToDetails(
Expand Down Expand Up @@ -136,6 +139,12 @@ internal class ConversationMapperImpl(
legalHoldStatus = ConversationEntity.LegalHoldStatus.DISABLED
)

override fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?) = mlsPublicKeysDTO?.let {
MLSPublicKeys(
removal = mlsPublicKeysDTO.removal
)
}

override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) {
val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) Instant.UNIX_FIRST_DATE
else lastReadDate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.e2ei.RevocationListChecker
Expand Down Expand Up @@ -123,6 +125,7 @@ interface MLSConversationRepository {
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
publicKeys: MLSPublicKeys? = null,
allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

Expand Down Expand Up @@ -554,16 +557,18 @@ internal class MLSConversationDataSource(
override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
publicKeys: MLSPublicKeys?,
allowSkippingUsersWithoutKeyPackages: Boolean
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
mlsPublicKeysRepository.getKeyForCipherSuite(
CipherSuite.fromTag(it.getDefaultCipherSuite())
).flatMap { key ->
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> { mlsClient ->
val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
val keys = publicKeys?.getRemovalKey(cipherSuite) ?: mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)

keys.flatMap { externalSenders ->
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key,
externalSenders = externalSenders,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ data class MLSPublicKeys(
val removal: Map<String, String>?
)

fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {
val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = this.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
return key.decodeBase64Bytes().right()
}

interface MLSPublicKeysRepository {
suspend fun fetchKeys(): Either<CoreFailure, MLSPublicKeys>
suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys>
Expand All @@ -42,7 +51,6 @@ interface MLSPublicKeysRepository {

class MLSPublicKeysRepositoryImpl(
private val mlsPublicKeyApi: MLSPublicKeyApi,
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
) : MLSPublicKeysRepository {

// TODO: make it thread safe
Expand All @@ -60,14 +68,8 @@ class MLSPublicKeysRepositoryImpl(
}

override suspend fun getKeyForCipherSuite(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {

return getKeys().flatMap { serverPublicKeys ->
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = serverPublicKeys.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
key.decodeBase64Bytes().right()
serverPublicKeys.getRemovalKey(cipherSuite)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(once)

coVerify {
mlsConversationRepository.establishMLSGroup(any(), any(), eq(true))
mlsConversationRepository.establishMLSGroup(any(), any(), any(), eq(true))
}.wasInvoked(once)

coVerify {
Expand Down Expand Up @@ -465,7 +465,7 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(once)

coVerify {
mlsConversationRepository.establishMLSGroup(any(), any(), eq(true))
mlsConversationRepository.establishMLSGroup(any(), any(), any(), eq(true))
}.wasInvoked(once)

coVerify {
Expand Down Expand Up @@ -1723,11 +1723,10 @@ class ConversationGroupRepositoryTest {
legalHoldHandler
)

suspend fun withMlsConversationEstablished(additionResult: MLSAdditionResult): Arrangement {
suspend fun withMlsConversationEstablished(additionResult: MLSAdditionResult): Arrangement = apply {
coEvery {
mlsConversationRepository.establishMLSGroup(any(), any(), any())
mlsConversationRepository.establishMLSGroup(any(), any(), any(), any())
}.returns(Either.Right(additionResult))
return this
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arr
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.KEY_PACKAGE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.MLS_PUBLIC_KEY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY
Expand Down Expand Up @@ -98,26 +99,30 @@ import io.mockative.matches
import io.mockative.mock
import io.mockative.once
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlinx.datetime.Instant
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs

class MLSConversationRepositoryTest {

@BeforeTest

@Test
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) {
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetMLSClientSuccessful()
.withDecryptMLSMessageSuccessful(Arrangement.DECRYPTED_MESSAGE_BUNDLE)
.arrange()

val epochChange = async(TestKaliumDispatcher.default) {
val epochChange = async() {
arrangement.epochsFlow.first()
}
yield()
Expand Down Expand Up @@ -168,7 +173,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down Expand Up @@ -280,6 +285,84 @@ class MLSConversationRepositoryTest {
}.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNotNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryNotCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(kaliumDispatcher = testKaliumDispatcher)
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = MLS_PUBLIC_KEY)
result.shouldSucceed()

coVerify {
arrangement.mlsClient.createConversation(
groupId = eq(Arrangement.RAW_GROUP_ID),
externalSenders = any())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.addMember(
groupId = eq(Arrangement.RAW_GROUP_ID),
membersKeyPackages = any())
}.wasInvoked(once)

coVerify {
arrangement.mlsMessageApi.sendCommitBundle(any<MLSMessageApi.CommitBundle>())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.commitAccepted(eq(Arrangement.RAW_GROUP_ID))
}.wasInvoked(once)

coVerify {
arrangement.mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryIsCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
arrangement.mlsClient.createConversation(eq(Arrangement.RAW_GROUP_ID), any())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.addMember(eq(Arrangement.RAW_GROUP_ID), any())
}.wasInvoked(once)

coVerify {
arrangement.mlsMessageApi.sendCommitBundle(any<MLSMessageApi.CommitBundle>())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.commitAccepted(eq(Arrangement.RAW_GROUP_ID))
}.wasInvoked(once)

coVerify {
arrangement.mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.wasInvoked(once)
}

@Test
fun givenNewCrlDistributionPoints_whenEstablishingMLSGroup_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
Expand Down Expand Up @@ -329,7 +412,7 @@ class MLSConversationRepositoryTest {
.withWaitUntilLiveSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down Expand Up @@ -357,7 +440,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR)
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldFail()

coVerify {
Expand Down Expand Up @@ -385,7 +468,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand All @@ -410,7 +493,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList())
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList(), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@ class JoinExistingMLSConversationUseCaseTest {
joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_GROUP_CONVERSATION.id).shouldSucceed()

coVerify {
arrangement.mlsConversationRepository.establishMLSGroup(eq(Arrangement.GROUP_ID3), eq(emptyList()), any())
arrangement.mlsConversationRepository.establishMLSGroup(
groupID = Arrangement.GROUP_ID3,
members = emptyList(),
publicKeys = null,
allowSkippingUsersWithoutKeyPackages = false
)
}.wasNotInvoked()
}

Expand All @@ -148,7 +153,12 @@ class JoinExistingMLSConversationUseCaseTest {
joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_SELF_CONVERSATION.id).shouldSucceed()

coVerify {
arrangement.mlsConversationRepository.establishMLSGroup(eq(Arrangement.GROUP_ID_SELF), eq(emptyList()), any())
arrangement.mlsConversationRepository.establishMLSGroup(
groupID = Arrangement.GROUP_ID_SELF,
members = emptyList(),
publicKeys = null,
allowSkippingUsersWithoutKeyPackages = false
)
}.wasInvoked(once)
}

Expand All @@ -167,7 +177,12 @@ class JoinExistingMLSConversationUseCaseTest {
joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION.id).shouldSucceed()

coVerify {
arrangement.mlsConversationRepository.establishMLSGroup(eq(Arrangement.GROUP_ID_ONE_ON_ONE), eq(members), any())
arrangement.mlsConversationRepository.establishMLSGroup(
groupID = Arrangement.GROUP_ID_ONE_ON_ONE,
members = members,
publicKeys = null,
allowSkippingUsersWithoutKeyPackages = false
)
}.wasInvoked(once)
}

Expand Down Expand Up @@ -256,7 +271,7 @@ class JoinExistingMLSConversationUseCaseTest {

suspend fun withEstablishMLSGroupSuccessful(additionResult: MLSAdditionResult) = apply {
coEvery {
mlsConversationRepository.establishMLSGroup(any(), any(), any())
mlsConversationRepository.establishMLSGroup(any(), any(), any(), any())
}.returns(Either.Right(additionResult))
}

Expand Down
Loading

0 comments on commit 1dd7cb2

Please sign in to comment.