Skip to content

Commit

Permalink
fix: ANRs when getting identities [WPB-8753]
Browse files Browse the repository at this point in the history
  • Loading branch information
saleniuk committed Apr 25, 2024
1 parent 2931055 commit 9a41d93
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.Ed22519Key
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.MLSGroupId
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
Expand Down Expand Up @@ -67,6 +68,8 @@ import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.sync.incremental.EventSource
import com.wire.kalium.logic.util.CurrentTimeProvider
import com.wire.kalium.logic.util.ExpirableCache
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
import com.wire.kalium.logic.wrapStorageRequest
Expand All @@ -92,6 +95,7 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.merge
import kotlinx.coroutines.withContext
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

data class ApplicationMessage(
val message: ByteArray,
Expand Down Expand Up @@ -219,6 +223,7 @@ internal class MLSConversationDataSource(
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val checkRevocationList: CheckRevocationListUseCase,
private val certificateRevocationListRepository: CertificateRevocationListRepository,
currentTimeProvider: CurrentTimeProvider = DateTimeUtil::currentInstant,
private val idMapper: IdMapper = MapperProvider.idMapper(),
private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId),
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(),
Expand Down Expand Up @@ -671,15 +676,19 @@ internal class MLSConversationDataSource(
})
})

private val getDeviceIdentitiesCache =
ExpirableCache<Pair<MLSGroupId, Set<CryptoQualifiedClientId>>, List<WireIdentity>>(IDENTITIES_EXPIRATION, currentTimeProvider)
private val getUserIdentitiesCache =
ExpirableCache<Pair<MLSGroupId, Set<UserId>>, Map<String, List<WireIdentity>>>(IDENTITIES_EXPIRATION, currentTimeProvider)

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {

mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).firstOrNull()
val cryptoQualifiedClientId = CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto())
getDeviceIdentitiesCache.getOrPut(it.mlsGroupId to setOf(cryptoQualifiedClientId)) {
mlsClient.getDeviceIdentities(it.mlsGroupId, listOf(cryptoQualifiedClientId))
}.firstOrNull()
}
}
}
Expand All @@ -694,10 +703,9 @@ internal class MLSConversationDataSource(
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value] ?: emptyList()
getUserIdentitiesCache.getOrPut(mlsGroupId to setOf(userId)) {
mlsClient.getUserIdentities(mlsGroupId, listOf(userId.toCrypto()))
}[userId.value] ?: emptyList()
}
}
}
Expand All @@ -712,13 +720,13 @@ internal class MLSConversationDataSource(
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()

mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
getUserIdentitiesCache.getOrPut(mlsGroupId to userIds.toSet()) {
mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
}.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
}

userIdsAndIdentity
}
Expand Down Expand Up @@ -895,3 +903,5 @@ internal class MLSConversationDataSource(

private data class CommitOperationResult<T>(val commitBundle: CommitBundle?, val result: T)
}

val IDENTITIES_EXPIRATION = 1.seconds
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.data.user.UserRepository
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.map
Expand Down Expand Up @@ -55,6 +56,6 @@ class ObserveConversationMembersUseCaseImpl internal constructor(
}
}.flatMapLatest { detailsFlows ->
combine(detailsFlows) { it.toList() }
}
}.distinctUntilChanged()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util

import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.datetime.Instant
import kotlin.time.Duration

class ExpirableCache<K, V>(private val expiration: Duration, private val currentTime: CurrentTimeProvider) {
private val mutex = Mutex()
private val map = mutableMapOf<K, Pair<Instant, V>>()

init {
if (expiration.isNegative()) throw IllegalArgumentException("Expiration must be positive")
}

suspend fun getOrPut(key: K, create: suspend () -> V): V = mutex.withLock {
val currentValue = map[key]?.let { (addedAt, value) ->
if (addedAt.plus(expiration) >= currentTime()) value else null
}
currentValue ?: create().also { map[key] = currentTime() to it }
}
}

typealias CurrentTimeProvider = () -> Instant
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.left
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.util.CurrentTimeProvider
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.client.ClientApi
Expand Down Expand Up @@ -101,22 +102,30 @@ import io.mockative.once
import io.mockative.thenDoNothing
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlinx.datetime.Instant
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNotEquals
import kotlin.time.Duration.Companion.milliseconds

@OptIn(ExperimentalCoroutinesApi::class)
class MLSConversationRepositoryTest {

val dispatcher = TestKaliumDispatcher.default
private val currentTime: CurrentTimeProvider = { Instant.fromEpochMilliseconds(dispatcher.scheduler.currentTime) }

@Test
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) {
val (arrangement, mlsConversationRepository) = Arrangement()
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(dispatcher) {
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withDecryptMLSMessageSuccessful(Arrangement.DECRYPTED_MESSAGE_BUNDLE)
.arrange()
Expand Down Expand Up @@ -1636,7 +1645,147 @@ class MLSConversationRepositoryTest {
}
}

private class Arrangement {
@Test
fun givenClientId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetDeviceIdentitiesReturn(listOf(identity1))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY.copy(mlsGroupId = groupId))
.arrange()

val result1 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)
advanceTimeBy(IDENTITIES_EXPIRATION - 500.milliseconds)
arrangement.withGetDeviceIdentitiesReturn(listOf(identity2))
val result2 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getDeviceIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenClientId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetDeviceIdentitiesReturn(listOf(identity1))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY.copy(mlsGroupId = groupId))
.arrange()

val result1 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)
advanceTimeBy(IDENTITIES_EXPIRATION + 500.milliseconds)
arrangement.withGetDeviceIdentitiesReturn(listOf(identity2))
val result2 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getDeviceIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

@Test
fun givenUserId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByUserIdReturns(groupId)
.withGetEstablishedSelfMLSGroupIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)
advanceTimeBy(IDENTITIES_EXPIRATION - 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenUserId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByUserIdReturns(groupId)
.withGetEstablishedSelfMLSGroupIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)
advanceTimeBy(IDENTITIES_EXPIRATION + 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

@Test
fun givenMemberId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))
advanceTimeBy(IDENTITIES_EXPIRATION - 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenMemberId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))
advanceTimeBy(IDENTITIES_EXPIRATION + 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

private class Arrangement(val currentTimeProvider: CurrentTimeProvider = DateTimeUtil::currentInstant) {

@Mock
val commitBundleEventReceiver = mock(classOf<CommitBundleEventReceiver>())
Expand Down Expand Up @@ -1699,7 +1848,8 @@ class MLSConversationRepositoryTest {
proposalTimersFlow,
keyPackageLimitsProvider,
checkRevocationList,
certificateRevocationListRepository
certificateRevocationListRepository,
currentTimeProvider
)

fun withCommitBundleEventReceiverSucceeding() = apply {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,39 @@ class ObserveConversationMembersUseCaseTest {
awaitComplete()
}
}

@Test
fun givenAConversationID_whenObservingMembersAnDataDidNotChange_thenDoNotEmitTheSameValuesAgain() = runTest {
val conversationID = TestConversation.ID
val otherUser = TestUser.OTHER
val selfUser = TestUser.SELF
val membersListChannel = Channel<List<Member>>(Channel.UNLIMITED)

given(userRepository)
.suspendFunction(userRepository::observeUser)
.whenInvokedWith(eq(TestUser.SELF.id))
.thenReturn(flowOf(selfUser))

given(userRepository)
.suspendFunction(userRepository::observeUser)
.whenInvokedWith(eq(otherUser.id))
.thenReturn(flowOf(otherUser))

given(conversationRepository)
.suspendFunction(conversationRepository::observeConversationMembers)
.whenInvokedWith(eq(conversationID))
.thenReturn(membersListChannel.consumeAsFlow())

observeConversationMembers(conversationID).test {

membersListChannel.send(listOf(Member(otherUser.id, Member.Role.Member)))
assertContentEquals(listOf(MemberDetails(otherUser, Member.Role.Member)), awaitItem())

membersListChannel.send(listOf(Member(otherUser.id, Member.Role.Member)))
expectNoEvents()

membersListChannel.close()
awaitComplete()
}
}
}
Loading

0 comments on commit 9a41d93

Please sign in to comment.