From 6e038ebe73dad74bdd3763d8f1b80cee015deadc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:50:40 +0000 Subject: [PATCH] =?UTF-8?q?fix(mls):=20update=20migrated=20conversation=20?= =?UTF-8?q?with=20correct=20group=20id=20and=20cipher=20suites=20(WPB-9169?= =?UTF-8?q?)=20=F0=9F=8D=92=20(#2771)=20(#2788)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Commit with unresolved merge conflicts * Commit with unresolved merge conflicts * trigger CI * fix tests to match candidate changes * fix tests --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Mojtaba Chenani Co-authored-by: Vitor Hugo Schwaab --- .../conversation/ConversationRepository.kt | 31 ++++------ .../ConversationRepositoryTest.kt | 58 +++++++++++-------- .../wire/kalium/persistence/Conversations.sq | 7 ++- .../dao/conversation/ConversationDAO.kt | 10 +++- .../dao/conversation/ConversationDAOImpl.kt | 14 ++++- .../persistence/dao/ConversationDAOTest.kt | 14 ++++- 6 files changed, 82 insertions(+), 52 deletions(-) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt index abf0e386b60..83ff3c6070f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt @@ -52,6 +52,7 @@ import com.wire.kalium.logic.functional.mapRight import com.wire.kalium.logic.functional.mapToRightOr import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess +import com.wire.kalium.logic.functional.right import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.logic.wrapMLSRequest @@ -63,7 +64,6 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConversationR import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse -import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO @@ -1026,17 +1026,8 @@ internal class ConversationDataSource internal constructor( ): Either = wrapApiRequest { conversationApi.updateProtocol(conversationId.toApi(), protocol.toApi()) - }.flatMap { response -> - when (response) { - UpdateConversationProtocolResponse.ProtocolUnchanged -> { - // no need to update conversation - Either.Right(false) - } - - is UpdateConversationProtocolResponse.ProtocolUpdated -> { - updateProtocolLocally(conversationId, protocol) - } - } + }.flatMap { + updateProtocolLocally(conversationId, protocol) } override suspend fun updateProtocolLocally( @@ -1047,19 +1038,19 @@ internal class ConversationDataSource internal constructor( conversationApi.fetchConversationDetails(conversationId.toApi()) }.flatMap { conversationResponse -> wrapStorageRequest { - conversationDAO.updateConversationProtocol( + conversationDAO.updateConversationProtocolAndCipherSuite( conversationId = conversationId.toDao(), - protocol = protocol.toDao() + groupID = conversationResponse.groupId, + protocol = protocol.toDao(), + cipherSuite = ConversationEntity.CipherSuite.fromTag(conversationResponse.mlsCipherSuiteTag) ) }.flatMap { updated -> if (updated) { - val selfUserTeamId = selfTeamIdProvider().getOrNull() - persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true) - } else { - Either.Right(Unit) - }.map { - updated + return@flatMap true.right() } + val selfUserTeamId = selfTeamIdProvider().getOrNull() + persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true) + .map { true } } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt index 2a5addc9a6a..011470907c8 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt @@ -27,6 +27,8 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.PersistenceQualifiedId import com.wire.kalium.logic.data.id.QualifiedID +import com.wire.kalium.logic.data.id.SelfTeamIdProvider +import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toCrypto import com.wire.kalium.logic.data.id.toDao @@ -37,8 +39,6 @@ import com.wire.kalium.logic.data.user.SelfUser import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.di.MapperProvider -import com.wire.kalium.logic.data.id.SelfTeamIdProvider -import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestTeam import com.wire.kalium.logic.framework.TestUser @@ -95,14 +95,13 @@ import com.wire.kalium.util.DateTimeUtil import io.ktor.http.HttpStatusCode import io.mockative.Mock import io.mockative.any -import io.mockative.eq import io.mockative.coEvery import io.mockative.coVerify +import io.mockative.eq import io.mockative.fake.valueOf import io.mockative.matchers.AnyMatcher import io.mockative.matchers.EqualsMatcher import io.mockative.matchers.Matcher -import io.mockative.matchers.PredicateMatcher import io.mockative.matches import io.mockative.mock import io.mockative.once @@ -121,7 +120,6 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue import com.wire.kalium.network.api.base.model.ConversationId as APIConversationId -import com.wire.kalium.network.api.base.model.ConversationId as ConversationIdDTO import com.wire.kalium.persistence.dao.client.Client as ClientEntity @Suppress("LargeClass") @@ -562,7 +560,6 @@ class ConversationRepositoryTest { val (arrange, conversationRepository) = Arrangement() .withApiUpdateAccessRoleReturns(NetworkResponse.Success(newAccess, mapOf(), 200)) - .withDaoUpdateAccessSuccess() .arrange() conversationRepository.updateAccessInfo( @@ -1156,12 +1153,18 @@ class ConversationRepositoryTest { } @Test - fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest { + fun givenNoChange_whenUpdatingProtocolToMls_thenShouldUpdateLocally() = runTest { // given val protocol = Conversation.Protocol.MLS - + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) val (arrange, conversationRepository) = Arrangement() + .withDaoUpdateProtocolSuccess() .withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED) + .withFetchConversationsDetails(conversationResponse) .arrange() // when @@ -1171,8 +1174,13 @@ class ConversationRepositoryTest { with(result) { shouldSucceed() coVerify { - arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) - }.wasNotInvoked() + arrange.conversationDAO.updateConversationProtocolAndCipherSuite( + eq(CONVERSATION_ID.toDao()), + eq(conversationResponse.value.groupId), + eq(protocol.toDao()), + eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag)) + ) + }.wasInvoked(exactly = once) } } @@ -1227,7 +1235,12 @@ class ConversationRepositoryTest { with(result) { shouldSucceed() coVerify { - arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + arrange.conversationDAO.updateConversationProtocolAndCipherSuite( + eq(CONVERSATION_ID.toDao()), + eq(conversationResponse.value.groupId), + eq(protocol.toDao()), + eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag)) + ) }.wasInvoked(exactly = once) } } @@ -1254,7 +1267,12 @@ class ConversationRepositoryTest { with(result) { shouldSucceed() coVerify { - arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + arrange.conversationDAO.updateConversationProtocolAndCipherSuite( + eq(CONVERSATION_ID.toDao()), + eq(conversationResponse.value.groupId), + eq(protocol.toDao()), + eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag)) + ) }.wasInvoked(exactly = once) } } @@ -1274,9 +1292,8 @@ class ConversationRepositoryTest { // then with(result) { shouldFail() - coVerify { - arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) - }.wasNotInvoked() + coVerify { arrange.conversationDAO.updateConversationProtocolAndCipherSuite(any(), any(), any(), any()) } + .wasNotInvoked() } } @@ -1511,16 +1528,9 @@ class ConversationRepositoryTest { }.returns(response) } - suspend fun withDaoUpdateAccessSuccess() = apply { - coEvery { - conversationDAO.updateAccess(any(), any(), any()) - }.returns(Unit) - } - suspend fun withDaoUpdateProtocolSuccess() = apply { - coEvery { - conversationDAO.updateConversationProtocol(any(), any()) - }.returns(true) + coEvery { conversationDAO.updateConversationProtocolAndCipherSuite(any(), any(), any(), any()) } + .returns(true) } suspend fun withGetConversationProtocolInfoReturns(result: ConversationEntity.ProtocolInfo) = apply { diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq index a2d302359b7..9e1f5009098 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq @@ -374,10 +374,11 @@ UPDATE Conversation SET type = ? WHERE qualified_id = ?; -updateConversationProtocol { +updateConversationGroupIdAndProtocolInfo { UPDATE Conversation -SET protocol = :protocol -WHERE qualified_id = :qualified_id AND protocol != :protocol; +SET mls_group_id = :groupId, protocol = :protocol, mls_cipher_suite = :mls_cipher_suite +WHERE qualified_id = :qualified_id AND + protocol != :protocol; SELECT changes(); } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt index 04184c5d23b..2ac03f7663d 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt @@ -52,6 +52,7 @@ interface ConversationDAO { protocol: ConversationEntity.Protocol, teamId: String? = null ): List + suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow suspend fun observeGetConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): Flow @@ -61,6 +62,7 @@ interface ConversationDAO { userId: UserIDEntity, protocol: ConversationEntity.Protocol ): List + suspend fun observeOneOnOneConversationWithOtherUser(userId: UserIDEntity): Flow suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo? suspend fun observeConversationByGroupID(groupID: String): Flow @@ -94,7 +96,13 @@ interface ConversationDAO { suspend fun whoDeletedMeInConversation(conversationId: QualifiedIDEntity, selfUserIdString: String): UserIDEntity? suspend fun updateConversationName(conversationId: QualifiedIDEntity, conversationName: String, timestamp: String) suspend fun updateConversationType(conversationID: QualifiedIDEntity, type: ConversationEntity.Type) - suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean + suspend fun updateConversationProtocolAndCipherSuite( + conversationId: QualifiedIDEntity, + groupID: String?, + protocol: ConversationEntity.Protocol, + cipherSuite: ConversationEntity.CipherSuite + ): Boolean + suspend fun getConversationsByUserId(userId: UserIDEntity): List suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode) suspend fun updateGuestRoomLink( diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt index f3cd7d22ae8..abd72375672 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt @@ -351,9 +351,19 @@ internal class ConversationDAOImpl internal constructor( conversationQueries.updateConversationType(type, conversationID) } - override suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean { + override suspend fun updateConversationProtocolAndCipherSuite( + conversationId: QualifiedIDEntity, + groupID: String?, + protocol: ConversationEntity.Protocol, + cipherSuite: ConversationEntity.CipherSuite + ): Boolean { return withContext(coroutineContext) { - conversationQueries.updateConversationProtocol(protocol, conversationId).executeAsOne() > 0 + conversationQueries.updateConversationGroupIdAndProtocolInfo( + groupID, + protocol, + cipherSuite, + conversationId + ).executeAsOne() > 0 } } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt index 6de2f15a360..3778aeeb918 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt @@ -441,13 +441,18 @@ class ConversationDAOTest : BaseDatabaseTest() { @Test fun givenNewValue_whenUpdatingProtocol_thenItsUpdatedAndReportedAsChanged() = runTest { val conversation = conversationEntity5 + val groupId = "groupId" + val updatedCipherSuite = ConversationEntity.CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521 val updatedProtocol = ConversationEntity.Protocol.MLS conversationDAO.insertConversation(conversation) - val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol) + val changed = + conversationDAO.updateConversationProtocolAndCipherSuite(conversation.id, groupId, updatedProtocol, updatedCipherSuite) assertTrue(changed) assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.protocol, updatedProtocol) + assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsGroupId, groupId) + assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsCipherSuite, updatedCipherSuite) } @Test @@ -456,7 +461,12 @@ class ConversationDAOTest : BaseDatabaseTest() { val updatedProtocol = ConversationEntity.Protocol.PROTEUS conversationDAO.insertConversation(conversation) - val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol) + val changed = conversationDAO.updateConversationProtocolAndCipherSuite( + conversation.id, + null, + updatedProtocol, + cipherSuite = ConversationEntity.CipherSuite.UNKNOWN + ) assertFalse(changed) }