Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mls): set removal-keys for 1on1 calls from conversation-response (WPB-10743) 🍒 #3019

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ internal class ConversationGroupRepositoryImpl(
val conversationEntity = conversationMapper.fromApiModelToDaoModel(
conversationResponse, mlsGroupState = ConversationEntity.GroupState.PENDING_CREATION, selfTeamId
)
val mlsPublicKeys = conversationMapper.fromApiModel(conversationResponse.publicKeys)
val protocol = protocolInfoMapper.fromEntity(conversationEntity.protocolInfo)

return wrapStorageRequest {
Expand All @@ -166,7 +167,8 @@ internal class ConversationGroupRepositoryImpl(
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
publicKeys = mlsPublicKeys,
allowSkippingUsersWithoutKeyPackages = true,
).map { it.notAddedUsers }
}
}.flatMap { protocolSpecificAdditionFailures ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.AvailabilityStatusMapper
import com.wire.kalium.logic.data.user.BotService
import com.wire.kalium.logic.data.user.Connection
Expand All @@ -41,6 +42,7 @@ import com.wire.kalium.network.api.authenticated.conversation.ConvTeamInfo
import com.wire.kalium.network.api.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.authenticated.conversation.ReceiptMode
import com.wire.kalium.network.api.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.model.ConversationAccessDTO
import com.wire.kalium.network.api.model.ConversationAccessRoleDTO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
Expand All @@ -59,6 +61,7 @@ import kotlin.time.toDuration

interface ConversationMapper {
fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity
fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?): MLSPublicKeys?
fun fromDaoModel(daoModel: ConversationViewEntity): Conversation
fun fromDaoModel(daoModel: ConversationEntity): Conversation
fun fromDaoModelToDetails(
Expand Down Expand Up @@ -136,6 +139,12 @@ internal class ConversationMapperImpl(
legalHoldStatus = ConversationEntity.LegalHoldStatus.DISABLED
)

override fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?) = mlsPublicKeysDTO?.let {
MLSPublicKeys(
removal = mlsPublicKeysDTO.removal
)
}

override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) {
val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) Instant.UNIX_FIRST_DATE
else lastReadDate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.e2ei.RevocationListChecker
Expand Down Expand Up @@ -123,6 +125,7 @@ interface MLSConversationRepository {
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
publicKeys: MLSPublicKeys? = null,
allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

Expand Down Expand Up @@ -554,16 +557,18 @@ internal class MLSConversationDataSource(
override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
publicKeys: MLSPublicKeys?,
allowSkippingUsersWithoutKeyPackages: Boolean
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
mlsPublicKeysRepository.getKeyForCipherSuite(
CipherSuite.fromTag(it.getDefaultCipherSuite())
).flatMap { key ->
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> { mlsClient ->
val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
val keys = publicKeys?.getRemovalKey(cipherSuite) ?: mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)

keys.flatMap { externalSenders ->
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key,
externalSenders = externalSenders,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ data class MLSPublicKeys(
val removal: Map<String, String>?
)

fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {
val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = this.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
return key.decodeBase64Bytes().right()
}

interface MLSPublicKeysRepository {
suspend fun fetchKeys(): Either<CoreFailure, MLSPublicKeys>
suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys>
Expand All @@ -42,7 +51,6 @@ interface MLSPublicKeysRepository {

class MLSPublicKeysRepositoryImpl(
private val mlsPublicKeyApi: MLSPublicKeyApi,
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
) : MLSPublicKeysRepository {

// TODO: make it thread safe
Expand All @@ -60,14 +68,8 @@ class MLSPublicKeysRepositoryImpl(
}

override suspend fun getKeyForCipherSuite(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {

return getKeys().flatMap { serverPublicKeys ->
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = serverPublicKeys.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
key.decodeBase64Bytes().right()
serverPublicKeys.getRemovalKey(cipherSuite)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(once)

coVerify {
mlsConversationRepository.establishMLSGroup(any(), any(), eq(true))
mlsConversationRepository.establishMLSGroup(any(), any(), any(), eq(true))
}.wasInvoked(once)

coVerify {
Expand Down Expand Up @@ -465,7 +465,7 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(once)

coVerify {
mlsConversationRepository.establishMLSGroup(any(), any(), eq(true))
mlsConversationRepository.establishMLSGroup(any(), any(), any(), eq(true))
}.wasInvoked(once)

coVerify {
Expand Down Expand Up @@ -1723,11 +1723,10 @@ class ConversationGroupRepositoryTest {
legalHoldHandler
)

suspend fun withMlsConversationEstablished(additionResult: MLSAdditionResult): Arrangement {
suspend fun withMlsConversationEstablished(additionResult: MLSAdditionResult): Arrangement = apply {
coEvery {
mlsConversationRepository.establishMLSGroup(any(), any(), any())
mlsConversationRepository.establishMLSGroup(any(), any(), any(), any())
}.returns(Either.Right(additionResult))
return this
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arr
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.KEY_PACKAGE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.MLS_PUBLIC_KEY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY
Expand Down Expand Up @@ -98,26 +99,30 @@ import io.mockative.matches
import io.mockative.mock
import io.mockative.once
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlinx.datetime.Instant
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs

class MLSConversationRepositoryTest {

@BeforeTest

@Test
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) {
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetMLSClientSuccessful()
.withDecryptMLSMessageSuccessful(Arrangement.DECRYPTED_MESSAGE_BUNDLE)
.arrange()

val epochChange = async(TestKaliumDispatcher.default) {
val epochChange = async() {
arrangement.epochsFlow.first()
}
yield()
Expand Down Expand Up @@ -168,7 +173,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down Expand Up @@ -280,6 +285,84 @@ class MLSConversationRepositoryTest {
}.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNotNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryNotCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(kaliumDispatcher = testKaliumDispatcher)
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = MLS_PUBLIC_KEY)
result.shouldSucceed()

coVerify {
arrangement.mlsClient.createConversation(
groupId = eq(Arrangement.RAW_GROUP_ID),
externalSenders = any())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.addMember(
groupId = eq(Arrangement.RAW_GROUP_ID),
membersKeyPackages = any())
}.wasInvoked(once)

coVerify {
arrangement.mlsMessageApi.sendCommitBundle(any<MLSMessageApi.CommitBundle>())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.commitAccepted(eq(Arrangement.RAW_GROUP_ID))
}.wasInvoked(once)

coVerify {
arrangement.mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryIsCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
arrangement.mlsClient.createConversation(eq(Arrangement.RAW_GROUP_ID), any())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.addMember(eq(Arrangement.RAW_GROUP_ID), any())
}.wasInvoked(once)

coVerify {
arrangement.mlsMessageApi.sendCommitBundle(any<MLSMessageApi.CommitBundle>())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.commitAccepted(eq(Arrangement.RAW_GROUP_ID))
}.wasInvoked(once)

coVerify {
arrangement.mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.wasInvoked(once)
}

@Test
fun givenNewCrlDistributionPoints_whenEstablishingMLSGroup_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
Expand Down Expand Up @@ -329,7 +412,7 @@ class MLSConversationRepositoryTest {
.withWaitUntilLiveSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down Expand Up @@ -357,7 +440,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR)
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldFail()

coVerify {
Expand Down Expand Up @@ -385,7 +468,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand All @@ -410,7 +493,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList())
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList(), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@ class JoinExistingMLSConversationUseCaseTest {
joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_GROUP_CONVERSATION.id).shouldSucceed()

coVerify {
arrangement.mlsConversationRepository.establishMLSGroup(eq(Arrangement.GROUP_ID3), eq(emptyList()), any())
arrangement.mlsConversationRepository.establishMLSGroup(
groupID = Arrangement.GROUP_ID3,
members = emptyList(),
publicKeys = null,
allowSkippingUsersWithoutKeyPackages = false
)
}.wasNotInvoked()
}

Expand All @@ -148,7 +153,12 @@ class JoinExistingMLSConversationUseCaseTest {
joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_SELF_CONVERSATION.id).shouldSucceed()

coVerify {
arrangement.mlsConversationRepository.establishMLSGroup(eq(Arrangement.GROUP_ID_SELF), eq(emptyList()), any())
arrangement.mlsConversationRepository.establishMLSGroup(
groupID = Arrangement.GROUP_ID_SELF,
members = emptyList(),
publicKeys = null,
allowSkippingUsersWithoutKeyPackages = false
)
}.wasInvoked(once)
}

Expand All @@ -167,7 +177,12 @@ class JoinExistingMLSConversationUseCaseTest {
joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION.id).shouldSucceed()

coVerify {
arrangement.mlsConversationRepository.establishMLSGroup(eq(Arrangement.GROUP_ID_ONE_ON_ONE), eq(members), any())
arrangement.mlsConversationRepository.establishMLSGroup(
groupID = Arrangement.GROUP_ID_ONE_ON_ONE,
members = members,
publicKeys = null,
allowSkippingUsersWithoutKeyPackages = false
)
}.wasInvoked(once)
}

Expand Down Expand Up @@ -256,7 +271,7 @@ class JoinExistingMLSConversationUseCaseTest {

suspend fun withEstablishMLSGroupSuccessful(additionResult: MLSAdditionResult) = apply {
coEvery {
mlsConversationRepository.establishMLSGroup(any(), any(), any())
mlsConversationRepository.establishMLSGroup(any(), any(), any(), any())
}.returns(Either.Right(additionResult))
}

Expand Down
Loading
Loading