Skip to content

Commit

Permalink
feat: FederatedSearchParser accept if other domain allowed [WPB-10788] (
Browse files Browse the repository at this point in the history
#3004)

* feat: FederatedSearchParser accept if other domain allowed

* Added GetConversationProtocolInfoUseCase

* Code style fix
  • Loading branch information
borichellow authored Sep 16, 2024
1 parent c9ca072 commit 85e5c25
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ class ConversationScope internal constructor(
val observeConversationDetails: ObserveConversationDetailsUseCase
get() = ObserveConversationDetailsUseCase(conversationRepository)

val getConversationProtocolInfo: GetConversationProtocolInfoUseCase
get() = GetConversationProtocolInfoUseCase(conversationRepository)

val notifyConversationIsOpen: NotifyConversationIsOpenUseCase
get() = NotifyConversationIsOpenUseCaseImpl(
oneOnOneResolver,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.logic.feature.conversation

import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.withContext

/**
* This use case that get Conversation.ProtocolInfo for a specific conversation.
* @see Conversation.ProtocolInfo
*/
interface GetConversationProtocolInfoUseCase {
sealed class Result {
data class Success(val protocolInfo: Conversation.ProtocolInfo) : Result()
data class Failure(val storageFailure: StorageFailure) : Result()
}

/**
* @param conversationId the id of the conversation to observe
* @return a [Result] with the [Conversation.ProtocolInfo] of the conversation
*/
suspend operator fun invoke(conversationId: ConversationId): Result
}

@Suppress("FunctionNaming")
internal fun GetConversationProtocolInfoUseCase(
conversationRepository: ConversationRepository,
dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) = object : GetConversationProtocolInfoUseCase {
override suspend operator fun invoke(conversationId: ConversationId): GetConversationProtocolInfoUseCase.Result =
withContext(dispatcher.io) {
conversationRepository.getConversationProtocolInfo(conversationId)
.fold({
GetConversationProtocolInfoUseCase.Result.Failure(it)
}, {
GetConversationProtocolInfoUseCase.Result.Success(it)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class FederatedSearchParser(
private val cachedIsFederationEnabled = AtomicReference<Boolean?>(null)
private val mutex = Mutex()

suspend operator fun invoke(searchQuery: String): Result {
suspend operator fun invoke(searchQuery: String, isOtherDomainAllowed: Boolean): Result {

val isFederated = cachedIsFederationEnabled.get()
?: mutex.withLock {
Expand All @@ -51,7 +51,7 @@ class FederatedSearchParser(
}

return when {
!isFederated -> Result(searchQuery, selfUserId.domain)
!isFederated || !isOtherDomainAllowed -> Result(searchQuery, selfUserId.domain)

searchQuery.matches(regex) -> {
val domain = searchQuery.substringAfterLast(DOMAIN_SEPARATOR)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.conversation

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.test_util.testKaliumDispatcher
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl
import com.wire.kalium.util.KaliumDispatcher
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class GetConversationProtocolInfoUseCaseTest {

@Test
fun givenGetConversationProtocolFails_whenInvoke_thenFailureReturned() = runTest {
val (_, useCase) = Arrangement()
.arrange {
dispatcher = this@runTest.testKaliumDispatcher
withConversationProtocolInfo(Either.Left(StorageFailure.DataNotFound))
}

assertTrue(useCase(ConversationId("ss", "dd")) is GetConversationProtocolInfoUseCase.Result.Failure)
}

@Test
fun givenGetConversationProtocolSucceed_whenInvoke_thenSuccessReturned() = runTest {
val (_, useCase) = Arrangement()
.arrange {
dispatcher = this@runTest.testKaliumDispatcher
withConversationProtocolInfo(Either.Right(Conversation.ProtocolInfo.Proteus))
}

val result = useCase(ConversationId("ss", "dd"))

assertTrue(result is GetConversationProtocolInfoUseCase.Result.Success)
assertEquals(Conversation.ProtocolInfo.Proteus, result.protocolInfo)
}

private class Arrangement : ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl() {

var dispatcher: KaliumDispatcher = TestKaliumDispatcher
private lateinit var getConversationProtocolInfo: GetConversationProtocolInfoUseCase

suspend fun arrange(block: suspend Arrangement.() -> Unit): Pair<Arrangement, GetConversationProtocolInfoUseCase> {
block()
getConversationProtocolInfo = GetConversationProtocolInfoUseCase(
conversationRepository,
dispatcher
)

return this to getConversationProtocolInfo
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class FederatedSearchParserTest {
}

val searchQuery = "searchQuery"
val result = federatedSearchParser(searchQuery)
val result = federatedSearchParser(searchQuery, true)

assertEquals(searchQuery, result.searchTerm)
assertEquals(selfUserId.domain, result.domain)
Expand All @@ -58,7 +58,7 @@ class FederatedSearchParserTest {
}

val searchQuery = "search Query"
val result = federatedSearchParser(searchQuery)
val result = federatedSearchParser(searchQuery, true)

assertEquals(searchQuery, result.searchTerm)
assertEquals(selfUserId.domain, result.domain)
Expand All @@ -75,7 +75,7 @@ class FederatedSearchParserTest {
}

val searchQuery = " search Query @domain.co"
val result = federatedSearchParser(searchQuery)
val result = federatedSearchParser(searchQuery, true)

assertEquals(" search Query ", result.searchTerm)
assertEquals("domain.co", result.domain)
Expand All @@ -92,9 +92,26 @@ class FederatedSearchParserTest {
}

val searchQuery = " search Query @domain.co"
federatedSearchParser(searchQuery)
federatedSearchParser(searchQuery)
federatedSearchParser(searchQuery)
federatedSearchParser(searchQuery, true)
federatedSearchParser(searchQuery, true)
federatedSearchParser(searchQuery, true)

coVerify {
arrangement.sessionRepository.isFederated(eq(selfUserId))
}.wasInvoked(exactly = once)
}

@Test
fun givenUserIsNotFederated_whenSearchQueryIncludeDomainButRemoteDomainForbidden_thenSearchQueryIsNotModified() = runTest {
val (arrangement, federatedSearchParser) = Arrangement().arrange {
withIsFederated(result = true.right(), userId = AnyMatcher(valueOf()))
}

val searchQuery = " search Query @domain.co"
val result = federatedSearchParser(searchQuery, false)

assertEquals(" search Query @domain.co", result.searchTerm)
assertEquals(selfUserId.domain, result.domain)

coVerify {
arrangement.sessionRepository.isFederated(eq(selfUserId))
Expand Down

0 comments on commit 85e5c25

Please sign in to comment.