Skip to content

Commit

Permalink
fix(mls): update migrated conversation with correct group id and ciph…
Browse files Browse the repository at this point in the history
…er suites (WPB-9169) 🍒 (#2771) (#2788)

* Commit with unresolved merge conflicts

* Commit with unresolved merge conflicts

* trigger CI

* fix tests to match candidate changes

* fix tests

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Mojtaba Chenani <[email protected]>
Co-authored-by: Vitor Hugo Schwaab <[email protected]>
  • Loading branch information
3 people authored Jun 5, 2024
1 parent d891f1c commit 6e038eb
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import com.wire.kalium.logic.functional.mapRight
import com.wire.kalium.logic.functional.mapToRightOr
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
Expand All @@ -63,7 +64,6 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConversationR
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse
import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO
import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO
Expand Down Expand Up @@ -1026,17 +1026,8 @@ internal class ConversationDataSource internal constructor(
): Either<CoreFailure, Boolean> =
wrapApiRequest {
conversationApi.updateProtocol(conversationId.toApi(), protocol.toApi())
}.flatMap { response ->
when (response) {
UpdateConversationProtocolResponse.ProtocolUnchanged -> {
// no need to update conversation
Either.Right(false)
}

is UpdateConversationProtocolResponse.ProtocolUpdated -> {
updateProtocolLocally(conversationId, protocol)
}
}
}.flatMap {
updateProtocolLocally(conversationId, protocol)
}

override suspend fun updateProtocolLocally(
Expand All @@ -1047,19 +1038,19 @@ internal class ConversationDataSource internal constructor(
conversationApi.fetchConversationDetails(conversationId.toApi())
}.flatMap { conversationResponse ->
wrapStorageRequest {
conversationDAO.updateConversationProtocol(
conversationDAO.updateConversationProtocolAndCipherSuite(
conversationId = conversationId.toDao(),
protocol = protocol.toDao()
groupID = conversationResponse.groupId,
protocol = protocol.toDao(),
cipherSuite = ConversationEntity.CipherSuite.fromTag(conversationResponse.mlsCipherSuiteTag)
)
}.flatMap { updated ->
if (updated) {
val selfUserTeamId = selfTeamIdProvider().getOrNull()
persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true)
} else {
Either.Right(Unit)
}.map {
updated
return@flatMap true.right()
}
val selfUserTeamId = selfTeamIdProvider().getOrNull()
persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true)
.map { true }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.PersistenceQualifiedId
import com.wire.kalium.logic.data.id.QualifiedID
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
Expand All @@ -37,8 +39,6 @@ import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestTeam
import com.wire.kalium.logic.framework.TestUser
Expand Down Expand Up @@ -95,14 +95,13 @@ import com.wire.kalium.util.DateTimeUtil
import io.ktor.http.HttpStatusCode
import io.mockative.Mock
import io.mockative.any
import io.mockative.eq
import io.mockative.coEvery
import io.mockative.coVerify
import io.mockative.eq
import io.mockative.fake.valueOf
import io.mockative.matchers.AnyMatcher
import io.mockative.matchers.EqualsMatcher
import io.mockative.matchers.Matcher
import io.mockative.matchers.PredicateMatcher
import io.mockative.matches
import io.mockative.mock
import io.mockative.once
Expand All @@ -121,7 +120,6 @@ import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import com.wire.kalium.network.api.base.model.ConversationId as APIConversationId
import com.wire.kalium.network.api.base.model.ConversationId as ConversationIdDTO
import com.wire.kalium.persistence.dao.client.Client as ClientEntity

@Suppress("LargeClass")
Expand Down Expand Up @@ -562,7 +560,6 @@ class ConversationRepositoryTest {

val (arrange, conversationRepository) = Arrangement()
.withApiUpdateAccessRoleReturns(NetworkResponse.Success(newAccess, mapOf(), 200))
.withDaoUpdateAccessSuccess()
.arrange()

conversationRepository.updateAccessInfo(
Expand Down Expand Up @@ -1156,12 +1153,18 @@ class ConversationRepositoryTest {
}

@Test
fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest {
fun givenNoChange_whenUpdatingProtocolToMls_thenShouldUpdateLocally() = runTest {
// given
val protocol = Conversation.Protocol.MLS

val conversationResponse = NetworkResponse.Success(
TestConversation.CONVERSATION_RESPONSE,
emptyMap(),
HttpStatusCode.OK.value
)
val (arrange, conversationRepository) = Arrangement()
.withDaoUpdateProtocolSuccess()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED)
.withFetchConversationsDetails(conversationResponse)
.arrange()

// when
Expand All @@ -1171,8 +1174,13 @@ class ConversationRepositoryTest {
with(result) {
shouldSucceed()
coVerify {
arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
}.wasNotInvoked()
arrange.conversationDAO.updateConversationProtocolAndCipherSuite(
eq(CONVERSATION_ID.toDao()),
eq(conversationResponse.value.groupId),
eq(protocol.toDao()),
eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag))
)
}.wasInvoked(exactly = once)
}
}

Expand Down Expand Up @@ -1227,7 +1235,12 @@ class ConversationRepositoryTest {
with(result) {
shouldSucceed()
coVerify {
arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
arrange.conversationDAO.updateConversationProtocolAndCipherSuite(
eq(CONVERSATION_ID.toDao()),
eq(conversationResponse.value.groupId),
eq(protocol.toDao()),
eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag))
)
}.wasInvoked(exactly = once)
}
}
Expand All @@ -1254,7 +1267,12 @@ class ConversationRepositoryTest {
with(result) {
shouldSucceed()
coVerify {
arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
arrange.conversationDAO.updateConversationProtocolAndCipherSuite(
eq(CONVERSATION_ID.toDao()),
eq(conversationResponse.value.groupId),
eq(protocol.toDao()),
eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag))
)
}.wasInvoked(exactly = once)
}
}
Expand All @@ -1274,9 +1292,8 @@ class ConversationRepositoryTest {
// then
with(result) {
shouldFail()
coVerify {
arrange.conversationDAO.updateConversationProtocol(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
}.wasNotInvoked()
coVerify { arrange.conversationDAO.updateConversationProtocolAndCipherSuite(any(), any(), any(), any()) }
.wasNotInvoked()
}
}

Expand Down Expand Up @@ -1511,16 +1528,9 @@ class ConversationRepositoryTest {
}.returns(response)
}

suspend fun withDaoUpdateAccessSuccess() = apply {
coEvery {
conversationDAO.updateAccess(any(), any(), any())
}.returns(Unit)
}

suspend fun withDaoUpdateProtocolSuccess() = apply {
coEvery {
conversationDAO.updateConversationProtocol(any(), any())
}.returns(true)
coEvery { conversationDAO.updateConversationProtocolAndCipherSuite(any(), any(), any(), any()) }
.returns(true)
}

suspend fun withGetConversationProtocolInfoReturns(result: ConversationEntity.ProtocolInfo) = apply {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,11 @@ UPDATE Conversation
SET type = ?
WHERE qualified_id = ?;

updateConversationProtocol {
updateConversationGroupIdAndProtocolInfo {
UPDATE Conversation
SET protocol = :protocol
WHERE qualified_id = :qualified_id AND protocol != :protocol;
SET mls_group_id = :groupId, protocol = :protocol, mls_cipher_suite = :mls_cipher_suite
WHERE qualified_id = :qualified_id AND
protocol != :protocol;
SELECT changes();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ interface ConversationDAO {
protocol: ConversationEntity.Protocol,
teamId: String? = null
): List<QualifiedIDEntity>

suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List<QualifiedIDEntity>
suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow<ConversationViewEntity?>
suspend fun observeGetConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): Flow<ConversationEntity?>
Expand All @@ -61,6 +62,7 @@ interface ConversationDAO {
userId: UserIDEntity,
protocol: ConversationEntity.Protocol
): List<QualifiedIDEntity>

suspend fun observeOneOnOneConversationWithOtherUser(userId: UserIDEntity): Flow<ConversationViewEntity?>
suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo?
suspend fun observeConversationByGroupID(groupID: String): Flow<ConversationViewEntity?>
Expand Down Expand Up @@ -94,7 +96,13 @@ interface ConversationDAO {
suspend fun whoDeletedMeInConversation(conversationId: QualifiedIDEntity, selfUserIdString: String): UserIDEntity?
suspend fun updateConversationName(conversationId: QualifiedIDEntity, conversationName: String, timestamp: String)
suspend fun updateConversationType(conversationID: QualifiedIDEntity, type: ConversationEntity.Type)
suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean
suspend fun updateConversationProtocolAndCipherSuite(
conversationId: QualifiedIDEntity,
groupID: String?,
protocol: ConversationEntity.Protocol,
cipherSuite: ConversationEntity.CipherSuite
): Boolean

suspend fun getConversationsByUserId(userId: UserIDEntity): List<ConversationEntity>
suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode)
suspend fun updateGuestRoomLink(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,19 @@ internal class ConversationDAOImpl internal constructor(
conversationQueries.updateConversationType(type, conversationID)
}

override suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean {
override suspend fun updateConversationProtocolAndCipherSuite(
conversationId: QualifiedIDEntity,
groupID: String?,
protocol: ConversationEntity.Protocol,
cipherSuite: ConversationEntity.CipherSuite
): Boolean {
return withContext(coroutineContext) {
conversationQueries.updateConversationProtocol(protocol, conversationId).executeAsOne() > 0
conversationQueries.updateConversationGroupIdAndProtocolInfo(
groupID,
protocol,
cipherSuite,
conversationId
).executeAsOne() > 0
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,13 +441,18 @@ class ConversationDAOTest : BaseDatabaseTest() {
@Test
fun givenNewValue_whenUpdatingProtocol_thenItsUpdatedAndReportedAsChanged() = runTest {
val conversation = conversationEntity5
val groupId = "groupId"
val updatedCipherSuite = ConversationEntity.CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521
val updatedProtocol = ConversationEntity.Protocol.MLS

conversationDAO.insertConversation(conversation)
val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol)
val changed =
conversationDAO.updateConversationProtocolAndCipherSuite(conversation.id, groupId, updatedProtocol, updatedCipherSuite)

assertTrue(changed)
assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.protocol, updatedProtocol)
assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsGroupId, groupId)
assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsCipherSuite, updatedCipherSuite)
}

@Test
Expand All @@ -456,7 +461,12 @@ class ConversationDAOTest : BaseDatabaseTest() {
val updatedProtocol = ConversationEntity.Protocol.PROTEUS

conversationDAO.insertConversation(conversation)
val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol)
val changed = conversationDAO.updateConversationProtocolAndCipherSuite(
conversation.id,
null,
updatedProtocol,
cipherSuite = ConversationEntity.CipherSuite.UNKNOWN
)

assertFalse(changed)
}
Expand Down

0 comments on commit 6e038eb

Please sign in to comment.