Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ANRs when getting identities [WPB-8753] #2718

Merged
merged 7 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/gradle-ios-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
uses: ./.github/workflows/codestyle.yml
gradle-run-tests:
needs: [detekt]
runs-on: macos-latest
runs-on: macos-12

steps:
- name: Checkout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.Ed22519Key
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.MLSGroupId
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
Expand Down Expand Up @@ -67,6 +68,8 @@ import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.sync.incremental.EventSource
import com.wire.kalium.logic.util.CurrentTimeProvider
import com.wire.kalium.logic.util.ExpirableCache
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
import com.wire.kalium.logic.wrapStorageRequest
Expand All @@ -92,6 +95,7 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.merge
import kotlinx.coroutines.withContext
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

data class ApplicationMessage(
val message: ByteArray,
Expand Down Expand Up @@ -219,6 +223,7 @@ internal class MLSConversationDataSource(
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val checkRevocationList: CheckRevocationListUseCase,
private val certificateRevocationListRepository: CertificateRevocationListRepository,
currentTimeProvider: CurrentTimeProvider = DateTimeUtil::currentInstant,
private val idMapper: IdMapper = MapperProvider.idMapper(),
private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId),
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(),
Expand Down Expand Up @@ -671,15 +676,19 @@ internal class MLSConversationDataSource(
})
})

private val getDeviceIdentitiesCache =
ExpirableCache<Pair<MLSGroupId, Set<CryptoQualifiedClientId>>, List<WireIdentity>>(IDENTITIES_TTL, currentTimeProvider)
private val getUserIdentitiesCache =
ExpirableCache<Pair<MLSGroupId, Set<UserId>>, Map<String, List<WireIdentity>>>(IDENTITIES_TTL, currentTimeProvider)

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {

mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).firstOrNull()
val cryptoQualifiedClientId = CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto())
getDeviceIdentitiesCache.getOrPut(it.mlsGroupId to setOf(cryptoQualifiedClientId)) {
mlsClient.getDeviceIdentities(it.mlsGroupId, listOf(cryptoQualifiedClientId))
}.firstOrNull()
}
}
}
Expand All @@ -694,10 +703,9 @@ internal class MLSConversationDataSource(
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value] ?: emptyList()
getUserIdentitiesCache.getOrPut(mlsGroupId to setOf(userId)) {
mlsClient.getUserIdentities(mlsGroupId, listOf(userId.toCrypto()))
}[userId.value] ?: emptyList()
}
}
}
Expand All @@ -712,13 +720,13 @@ internal class MLSConversationDataSource(
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()

mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
getUserIdentitiesCache.getOrPut(mlsGroupId to userIds.toSet()) {
mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
}.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
}

userIdsAndIdentity
}
Expand Down Expand Up @@ -895,3 +903,5 @@ internal class MLSConversationDataSource(

private data class CommitOperationResult<T>(val commitBundle: CommitBundle?, val result: T)
}

val IDENTITIES_TTL = 1.seconds
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.data.user.UserRepository
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.map
Expand Down Expand Up @@ -55,6 +56,6 @@ class ObserveConversationMembersUseCaseImpl internal constructor(
}
}.flatMapLatest { detailsFlows ->
combine(detailsFlows) { it.toList() }
}
}.distinctUntilChanged()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util

import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.datetime.Instant
import kotlin.time.Duration

class ExpirableCache<K, V>(private val timeToLive: Duration, private val currentTime: CurrentTimeProvider) {
private val mutex = Mutex()
private val map = mutableMapOf<K, Pair<Instant, V>>()

init {
if (timeToLive.isNegative()) throw IllegalArgumentException("timeToLive must be positive")
}

suspend fun getOrPut(key: K, create: suspend () -> V): V = mutex.withLock {
val currentValue = map[key]?.let { (addedAt, value) ->
if (addedAt.plus(timeToLive) >= currentTime()) value else null
}
currentValue ?: create().also { map[key] = currentTime() to it }
}
}

typealias CurrentTimeProvider = () -> Instant
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.left
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.util.CurrentTimeProvider
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.client.ClientApi
Expand Down Expand Up @@ -101,22 +102,30 @@ import io.mockative.once
import io.mockative.thenDoNothing
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlinx.datetime.Instant
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNotEquals
import kotlin.time.Duration.Companion.milliseconds

@OptIn(ExperimentalCoroutinesApi::class)
class MLSConversationRepositoryTest {

val dispatcher = TestKaliumDispatcher.default
private val currentTime: CurrentTimeProvider = { Instant.fromEpochMilliseconds(dispatcher.scheduler.currentTime) }

@Test
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) {
val (arrangement, mlsConversationRepository) = Arrangement()
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(dispatcher) {
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withDecryptMLSMessageSuccessful(Arrangement.DECRYPTED_MESSAGE_BUNDLE)
.arrange()
Expand Down Expand Up @@ -1636,7 +1645,147 @@ class MLSConversationRepositoryTest {
}
}

private class Arrangement {
@Test
fun givenClientId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetDeviceIdentitiesReturn(listOf(identity1))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY.copy(mlsGroupId = groupId))
.arrange()

val result1 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)
advanceTimeBy(IDENTITIES_TTL - 500.milliseconds)
arrangement.withGetDeviceIdentitiesReturn(listOf(identity2))
val result2 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getDeviceIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenClientId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetDeviceIdentitiesReturn(listOf(identity1))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY.copy(mlsGroupId = groupId))
.arrange()

val result1 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)
advanceTimeBy(IDENTITIES_TTL + 500.milliseconds)
arrangement.withGetDeviceIdentitiesReturn(listOf(identity2))
val result2 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getDeviceIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

@Test
fun givenUserId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByUserIdReturns(groupId)
.withGetEstablishedSelfMLSGroupIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)
advanceTimeBy(IDENTITIES_TTL - 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenUserId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByUserIdReturns(groupId)
.withGetEstablishedSelfMLSGroupIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)
advanceTimeBy(IDENTITIES_TTL + 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

@Test
fun givenMemberId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))
advanceTimeBy(IDENTITIES_TTL - 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenMemberId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))
advanceTimeBy(IDENTITIES_TTL + 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

private class Arrangement(val currentTimeProvider: CurrentTimeProvider = DateTimeUtil::currentInstant) {

@Mock
val commitBundleEventReceiver = mock(classOf<CommitBundleEventReceiver>())
Expand Down Expand Up @@ -1699,7 +1848,8 @@ class MLSConversationRepositoryTest {
proposalTimersFlow,
keyPackageLimitsProvider,
checkRevocationList,
certificateRevocationListRepository
certificateRevocationListRepository,
currentTimeProvider
)

fun withCommitBundleEventReceiverSucceeding() = apply {
Expand Down
Loading
Loading