Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set the correct cipher suite when claiming key packages [WPB-8592] 🍒 #2746

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f9a5f75
Commit with unresolved merge conflicts
MohamadJaara May 2, 2024
d03ac5a
Commit with unresolved merge conflicts
MohamadJaara May 6, 2024
0e5bc02
Commit with unresolved merge conflicts
MohamadJaara May 7, 2024
745ecac
Commit with unresolved merge conflicts
MohamadJaara May 7, 2024
54948e8
Commit with unresolved merge conflicts
MohamadJaara May 7, 2024
5da1fd7
Merge remote-tracking branch 'refs/remotes/origin/release/candidate' …
MohamadJaara May 16, 2024
53357a8
fix tests
MohamadJaara May 17, 2024
14e8e1c
detekt
MohamadJaara May 17, 2024
ec66056
Trigger CI
MohamadJaara May 18, 2024
07b76b2
Merge branch 'release/candidate' into feat/pass-signature-algorithm-w…
MohamadJaara May 18, 2024
a4b4b9d
Merge remote-tracking branch 'refs/remotes/origin/chore/update-CC-to-…
MohamadJaara May 18, 2024
614244e
Merge branch 'refs/heads/feat/pass-signature-algorithm-when-registrin…
MohamadJaara May 18, 2024
0b9ddb8
Merge remote-tracking branch 'refs/remotes/origin/chore/update-CC-to-…
MohamadJaara May 18, 2024
50e9359
Trigger CI
MohamadJaara May 18, 2024
b7ce200
Merge branch 'chore/update-CC-to-RC-59-cherry-pick' into fix/fetch-ML…
MohamadJaara May 18, 2024
802560e
detekt
MohamadJaara May 18, 2024
ba40ef2
test
MohamadJaara May 19, 2024
995d0e8
test
MohamadJaara May 19, 2024
473df21
BaseProteusClientTest
MohamadJaara May 19, 2024
a0b1fbd
fix merge issues
MohamadJaara May 21, 2024
6f4110e
Merge branch 'release/candidate' into chore/update-CC-to-RC-59-cherry…
MohamadJaara May 21, 2024
753914e
Merge branch 'chore/update-CC-to-RC-59-cherry-pick' into fix/fetch-ML…
MohamadJaara May 21, 2024
0164522
Merge branch 'refs/heads/fix/fetch-MLS-config-when-not-avilable-local…
MohamadJaara May 21, 2024
8b05199
Merge branch 'release/candidate' into feat/set-the-correct-public-key…
MohamadJaara May 21, 2024
1202c3f
fix merge issues
MohamadJaara May 21, 2024
d6a3bfd
Merge branch 'refs/heads/feat/set-the-correct-public-key-when-creatin…
MohamadJaara May 21, 2024
7b9d7ac
Merge remote-tracking branch 'refs/remotes/origin/release/candidate' …
MohamadJaara May 21, 2024
89dd9fe
fix merge issues
MohamadJaara May 21, 2024
f8607a5
fix merge issues
MohamadJaara May 21, 2024
cbcd570
fix merge issues
MohamadJaara May 22, 2024
8a02fd7
fix test
MohamadJaara May 22, 2024
37a582b
detekt
MohamadJaara May 22, 2024
b8dac7f
fix tests
MohamadJaara May 22, 2024
915fb30
fix tests
MohamadJaara May 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ class MLSClientTest : BaseMLSClientTest() {
}

private suspend fun createClient(user: SampleUser): MLSClient {
return createMLSClient(user.qualifiedClientId, allowedCipherSuites = ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITES)
return createMLSClient(
clientId = user.qualifiedClientId,
allowedCipherSuites = ALLOWED_CIPHER_SUITES,
defaultCipherSuite = DEFAULT_CIPHER_SUITES
)
}

@Test
fun givemMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest {
fun givenMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest {
val mlsClient = createClient(ALICE1)
assertEquals(DEFAULT_CIPHER_SUITES, mlsClient.getDefaultCipherSuite())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessageContent.MemberChange.FailedToAdd
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.service.ServiceId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
Expand Down Expand Up @@ -130,20 +131,13 @@ internal class ConversationGroupRepositoryImpl(
}

when (apiResult) {
is Either.Left -> {
val canRetryOnce = apiResult.value.hasUnreachableDomainsError && lastUsersAttempt is LastUsersAttempt.None
if (canRetryOnce) {
extractValidUsersForRetryableError(apiResult.value, usersList)
.flatMap { (validUsers, failedUsers, failType) ->
// edge case, in case backend goes 🍌 and returns non-matching domains
if (failedUsers.isEmpty()) Either.Left(apiResult.value)

createGroupConversation(name, validUsers, options, LastUsersAttempt.Failed(failedUsers, failType))
}
} else {
Either.Left(apiResult.value)
}
}
is Either.Left -> handleCreateConverstionFailure(
apiResult = apiResult,
usersList = usersList,
name = name,
options = options,
lastUsersAttempt = lastUsersAttempt
)

is Either.Right -> handleGroupConversationCreated(apiResult.value, selfTeamId, usersList, lastUsersAttempt)
}
Expand Down Expand Up @@ -210,6 +204,27 @@ internal class ConversationGroupRepositoryImpl(
}
}

private suspend fun handleCreateConverstionFailure(
apiResult: Either.Left<NetworkFailure>,
usersList: List<UserId>,
name: String?,
options: ConversationOptions,
lastUsersAttempt: LastUsersAttempt
): Either<CoreFailure, Conversation> {
val canRetryOnce = apiResult.value.hasUnreachableDomainsError && lastUsersAttempt is LastUsersAttempt.None
return if (canRetryOnce) {
extractValidUsersForRetryableError(apiResult.value, usersList)
.flatMap { (validUsers, failedUsers, failType) ->
// edge case, in case backend goes 🍌 and returns non-matching domains
if (failedUsers.isEmpty()) Either.Left(apiResult.value)

createGroupConversation(name, validUsers, options, LastUsersAttempt.Failed(failedUsers, failType))
}
} else {
Either.Left(apiResult.value)
}
}

override suspend fun addMembers(
userIdList: List<UserId>,
conversationId: ConversationId
Expand All @@ -224,11 +239,21 @@ internal class ConversationGroupRepositoryImpl(
tryAddMembersToCloudAndStorage(userIdList, conversationId, LastUsersAttempt.None)
.flatMap {
// best effort approach for migrated conversations, no retries
mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), userIdList)
mlsConversationRepository.addMemberToMLSGroup(
GroupID(protocol.groupId),
userIdList,
CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag)
)
}

is ConversationEntity.ProtocolInfo.MLS -> {
tryAddMembersToMLSGroup(conversationId, protocol.groupId, userIdList, LastUsersAttempt.None)
tryAddMembersToMLSGroup(
conversationId,
protocol.groupId,
userIdList,
LastUsersAttempt.None,
cipherSuite = CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag)
)
}
}
}
Expand All @@ -237,14 +262,22 @@ internal class ConversationGroupRepositoryImpl(
* Handle the error cases and retry for claimPackages offline and out of packages.
* Handle error case and retry for sendingCommit unreachable or missing legal hold consent.
*/
@Suppress("LongMethod")
private suspend fun tryAddMembersToMLSGroup(
conversationId: ConversationId,
groupId: String,
userIdList: List<UserId>,
lastUsersAttempt: LastUsersAttempt,
cipherSuite: CipherSuite,
remainingAttempts: Int = 2
): Either<CoreFailure, Unit> {
return when (val addingMemberResult = mlsConversationRepository.addMemberToMLSGroup(GroupID(groupId), userIdList)) {
return when (
val addingMemberResult = mlsConversationRepository.addMemberToMLSGroup(
GroupID(groupId),
userIdList,
cipherSuite
)
) {
is Either.Right -> handleMLSMembersNotAdded(conversationId, lastUsersAttempt)
is Either.Left -> {
addingMemberResult.value.handleMLSMembersFailed(
Expand All @@ -253,17 +286,20 @@ internal class ConversationGroupRepositoryImpl(
userIdList = userIdList,
lastUsersAttempt = lastUsersAttempt,
remainingAttempts = remainingAttempts,
cipherSuite = cipherSuite
)
}
}
}

@Suppress("LongMethod")
private suspend fun CoreFailure.handleMLSMembersFailed(
conversationId: ConversationId,
groupId: String,
userIdList: List<UserId>,
lastUsersAttempt: LastUsersAttempt,
remainingAttempts: Int,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit> {
return when {
// claiming key packages offline or out of packages
Expand All @@ -277,7 +313,8 @@ internal class ConversationGroupRepositoryImpl(
failedUsers = lastUsersAttempt.failedUsers + failedUsers,
failType = FailedToAdd.Type.Federation,
),
remainingAttempts = remainingAttempts - 1
remainingAttempts = remainingAttempts - 1,
cipherSuite = cipherSuite
)
}

Expand All @@ -292,7 +329,8 @@ internal class ConversationGroupRepositoryImpl(
failedUsers = lastUsersAttempt.failedUsers + failedUsers,
failType = FailedToAdd.Type.Federation,
),
remainingAttempts = remainingAttempts - 1
remainingAttempts = remainingAttempts - 1,
cipherSuite = cipherSuite
)
}

Expand All @@ -308,7 +346,8 @@ internal class ConversationGroupRepositoryImpl(
failedUsers = lastUsersAttempt.failedUsers + failedUsers,
failType = FailedToAdd.Type.LegalHold,
),
remainingAttempts = remainingAttempts - 1
remainingAttempts = remainingAttempts - 1,
cipherSuite = cipherSuite
)
}
}
Expand Down Expand Up @@ -479,7 +518,11 @@ internal class ConversationGroupRepositoryImpl(

is ConversationEntity.ProtocolInfo.MLSCapable -> {
joinExistingMLSConversation(conversationId).flatMap {
mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), listOf(selfUserId))
mlsConversationRepository.addMemberToMLSGroup(
GroupID(protocol.groupId),
listOf(selfUserId),
CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag)
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ interface MLSConversationRepository {

suspend fun establishMLSSubConversationGroup(groupID: GroupID, parentId: ConversationId): Either<CoreFailure, Unit>
suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean>
suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit>
suspend fun addMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit>

suspend fun removeMembersFromMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit>
suspend fun removeClientsFromMLSGroup(groupID: GroupID, clientIdList: List<QualifiedClientID>): Either<CoreFailure, Unit>
suspend fun leaveGroup(groupID: GroupID): Either<CoreFailure, Unit>
Expand Down Expand Up @@ -202,7 +207,7 @@ private fun CoreFailure.getStrategy(

// TODO: refactor this repository as it's doing too much.
// A Repository should be a dummy class that get and set some values
@Suppress("TooManyFunctions", "LongParameterList")
@Suppress("TooManyFunctions", "LongParameterList", "LargeClass")
internal class MLSConversationDataSource(
private val selfUserId: UserId,
private val keyPackageRepository: KeyPackageRepository,
Expand Down Expand Up @@ -448,23 +453,29 @@ internal class MLSConversationDataSource(
conversationDAO.getProposalTimers().map { it.map(conversationMapper::fromDaoModel) }.flatten()
)

override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit> =
override suspend fun addMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit> =
internalAddMemberToMLSGroup(
groupID = groupID,
userIdList = userIdList,
retryOnStaleMessage = true,
allowPartialMemberList = false
allowPartialMemberList = false,
cipherSuite = cipherSuite
).map { Unit }

private suspend fun internalAddMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
retryOnStaleMessage: Boolean,
cipherSuite: CipherSuite,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
commitPendingProposals(groupID).flatMap {
produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) {
keyPackageRepository.claimKeyPackages(userIdList).flatMap { result ->
keyPackageRepository.claimKeyPackages(userIdList, cipherSuite).flatMap { result ->
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) {
Either.Left(CoreFailure.MissingKeyPackages(result.usersWithoutKeyPackagesAvailable))
} else {
Expand Down Expand Up @@ -606,7 +617,8 @@ internal class MLSConversationDataSource(
groupID = groupID,
userIdList = members,
retryOnStaleMessage = false,
allowPartialMemberList = allowPartialMemberList
allowPartialMemberList = allowPartialMemberList,
cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
).onFailure {
wrapMLSRequest {
mlsClient.wipeConversation(groupID.toCrypto())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.mls.KeyPackageClaimResult
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
Expand All @@ -50,7 +51,10 @@ interface KeyPackageRepository {
* available. If the operation fails, it will be [Either.Left] with a [CoreFailure] object indicating the reason for the failure.
* If **no** KeyPackages are available, [CoreFailure.MissingKeyPackages] will be the cause.
*/
suspend fun claimKeyPackages(userIds: List<UserId>): Either<CoreFailure, KeyPackageClaimResult>
suspend fun claimKeyPackages(
userIds: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, KeyPackageClaimResult>

suspend fun uploadNewKeyPackages(clientId: ClientId, amount: Int = 100): Either<CoreFailure, Unit>

Expand All @@ -61,7 +65,6 @@ interface KeyPackageRepository {
suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO>

suspend fun validKeyPackageCount(clientId: ClientId): Either<CoreFailure, Int>

}

class KeyPackageDataSource(
Expand All @@ -71,13 +74,22 @@ class KeyPackageDataSource(
private val selfUserId: UserId,
) : KeyPackageRepository {

override suspend fun claimKeyPackages(userIds: List<UserId>): Either<CoreFailure, KeyPackageClaimResult> =
override suspend fun claimKeyPackages(
userIds: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, KeyPackageClaimResult> =
currentClientIdProvider().flatMap { selfClientId ->
val failedUsers = mutableSetOf<UserId>()
val claimedKeyPackages = mutableListOf<KeyPackageDTO>()
userIds.forEach { userId ->
wrapApiRequest {
keyPackageApi.claimKeyPackages(KeyPackageApi.Param.SkipOwnClient(userId.toApi(), selfClientId.value))
keyPackageApi.claimKeyPackages(
KeyPackageApi.Param.SkipOwnClient(
userId.toApi(),
selfClientId.value,
cipherSuite = cipherSuite.tag
)
)
}.fold({ failedUsers.add(userId) }) {
if (it.keyPackages.isEmpty() && userId != selfUserId) {
failedUsers.add(userId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ internal class MLSMigratorImpl(
mlsConversationRepository.establishMLSGroup(protocolInfo.groupId, emptyList())
.flatMap {
conversationRepository.getConversationMembers(conversationId).flatMap { members ->
mlsConversationRepository.addMemberToMLSGroup(protocolInfo.groupId, members)
mlsConversationRepository.addMemberToMLSGroup(
protocolInfo.groupId,
members,
protocolInfo.cipherSuite
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class E2EIClientProviderTest {
return this to e2eiClientProvider
}

suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) {
override suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) {
coEvery { mlsClientProvider.getOrFetchMLSConfig() }.returns(result.right())
}
}
Expand Down
Loading
Loading