Skip to content

Commit

Permalink
revert ExpirableCache
Browse files Browse the repository at this point in the history
  • Loading branch information
saleniuk committed Apr 26, 2024
1 parent f735a46 commit 9f08ed5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 282 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.Ed22519Key
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.MLSGroupId
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
Expand Down Expand Up @@ -68,8 +67,6 @@ import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.sync.incremental.EventSource
import com.wire.kalium.logic.util.CurrentTimeProvider
import com.wire.kalium.logic.util.ExpirableCache
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
import com.wire.kalium.logic.wrapStorageRequest
Expand All @@ -95,7 +92,6 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.merge
import kotlinx.coroutines.withContext
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

data class ApplicationMessage(
val message: ByteArray,
Expand Down Expand Up @@ -223,7 +219,6 @@ internal class MLSConversationDataSource(
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val checkRevocationList: CheckRevocationListUseCase,
private val certificateRevocationListRepository: CertificateRevocationListRepository,
currentTimeProvider: CurrentTimeProvider = DateTimeUtil::currentInstant,
private val idMapper: IdMapper = MapperProvider.idMapper(),
private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId),
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(),
Expand Down Expand Up @@ -676,19 +671,15 @@ internal class MLSConversationDataSource(
})
})

private val getDeviceIdentitiesCache =
ExpirableCache<Pair<MLSGroupId, Set<CryptoQualifiedClientId>>, List<WireIdentity>>(IDENTITIES_TTL, currentTimeProvider)
private val getUserIdentitiesCache =
ExpirableCache<Pair<MLSGroupId, Set<UserId>>, Map<String, List<WireIdentity>>>(IDENTITIES_TTL, currentTimeProvider)

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val cryptoQualifiedClientId = CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto())
getDeviceIdentitiesCache.getOrPut(it.mlsGroupId to setOf(cryptoQualifiedClientId)) {
mlsClient.getDeviceIdentities(it.mlsGroupId, listOf(cryptoQualifiedClientId))
}.firstOrNull()

mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).firstOrNull()
}
}
}
Expand All @@ -703,9 +694,10 @@ internal class MLSConversationDataSource(
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
getUserIdentitiesCache.getOrPut(mlsGroupId to setOf(userId)) {
mlsClient.getUserIdentities(mlsGroupId, listOf(userId.toCrypto()))
}[userId.value] ?: emptyList()
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value] ?: emptyList()
}
}
}
Expand All @@ -720,13 +712,13 @@ internal class MLSConversationDataSource(
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()
getUserIdentitiesCache.getOrPut(mlsGroupId to userIds.toSet()) {
mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
}.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities

mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
}
}

userIdsAndIdentity
}
Expand Down Expand Up @@ -903,5 +895,3 @@ internal class MLSConversationDataSource(

private data class CommitOperationResult<T>(val commitBundle: CommitBundle?, val result: T)
}

val IDENTITIES_TTL = 1.seconds

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.left
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.util.CurrentTimeProvider
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.client.ClientApi
Expand Down Expand Up @@ -102,30 +101,22 @@ import io.mockative.once
import io.mockative.thenDoNothing
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import kotlinx.datetime.Instant
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNotEquals
import kotlin.time.Duration.Companion.milliseconds

@OptIn(ExperimentalCoroutinesApi::class)
class MLSConversationRepositoryTest {

val dispatcher = TestKaliumDispatcher.default
private val currentTime: CurrentTimeProvider = { Instant.fromEpochMilliseconds(dispatcher.scheduler.currentTime) }

@Test
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(dispatcher) {
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withDecryptMLSMessageSuccessful(Arrangement.DECRYPTED_MESSAGE_BUNDLE)
.arrange()
Expand Down Expand Up @@ -1645,147 +1636,7 @@ class MLSConversationRepositoryTest {
}
}

@Test
fun givenClientId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetDeviceIdentitiesReturn(listOf(identity1))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY.copy(mlsGroupId = groupId))
.arrange()

val result1 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)
advanceTimeBy(IDENTITIES_TTL - 500.milliseconds)
arrangement.withGetDeviceIdentitiesReturn(listOf(identity2))
val result2 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getDeviceIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenClientId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetDeviceIdentitiesReturn(listOf(identity1))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY.copy(mlsGroupId = groupId))
.arrange()

val result1 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)
advanceTimeBy(IDENTITIES_TTL + 500.milliseconds)
arrangement.withGetDeviceIdentitiesReturn(listOf(identity2))
val result2 = mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getDeviceIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

@Test
fun givenUserId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByUserIdReturns(groupId)
.withGetEstablishedSelfMLSGroupIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)
advanceTimeBy(IDENTITIES_TTL - 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenUserId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByUserIdReturns(groupId)
.withGetEstablishedSelfMLSGroupIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)
advanceTimeBy(IDENTITIES_TTL + 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getUserIdentity(TestUser.USER_ID)

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

@Test
fun givenMemberId_whenGettingIdentitiesTwiceInAShortTime_thenGetIdentitiesFromMlsClientOnceAndCacheThem() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))
advanceTimeBy(IDENTITIES_TTL - 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))

assertEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)
}

@Test
fun givenMemberId_whenGettingIdentitiesTwiceInALongTime_thenBothTimesGetIdentitiesFromMlsClient() = runTest(dispatcher) {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val identity1 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name1"))
val identity2 = WIRE_IDENTITY.copy(certificate = WIRE_IDENTITY.certificate?.copy(displayName = "name2"))
val (arrangement, mlsConversationRepository) = Arrangement(currentTime)
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity1)))
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

val result1 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))
advanceTimeBy(IDENTITIES_TTL + 500.milliseconds)
arrangement.withGetUserIdentitiesReturn(mapOf(TestUser.USER_ID.value to listOf(identity2)))
val result2 = mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(TestUser.USER_ID))

assertNotEquals(result1, result2)
verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(twice)
}

private class Arrangement(val currentTimeProvider: CurrentTimeProvider = DateTimeUtil::currentInstant) {
private class Arrangement {

@Mock
val commitBundleEventReceiver = mock(classOf<CommitBundleEventReceiver>())
Expand Down Expand Up @@ -1848,8 +1699,7 @@ class MLSConversationRepositoryTest {
proposalTimersFlow,
keyPackageLimitsProvider,
checkRevocationList,
certificateRevocationListRepository,
currentTimeProvider
certificateRevocationListRepository
)

fun withCommitBundleEventReceiverSucceeding() = apply {
Expand Down
Loading

0 comments on commit 9f08ed5

Please sign in to comment.