Skip to content

Commit

Permalink
fix: failure fetching federated certificate chain (#2660)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitorhugods authored and github-actions[bot] committed Mar 14, 2024
1 parent d1c8a89 commit 0dd7b4a
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.foldToEitherWhileRight
import com.wire.kalium.logic.functional.getOrFail
import com.wire.kalium.logic.functional.left
import com.wire.kalium.logic.functional.onSuccess
Expand Down Expand Up @@ -344,24 +345,29 @@ class E2EIRepositoryImpl(

override suspend fun fetchFederationCertificates() = discoveryUrl().flatMap {
wrapApiRequest {
acmeApi.getACMEFederation(it)
acmeApi.getACMEFederationCertificateChain(it)
}.fold({
E2EIFailure.IntermediateCert(it).left()
}, { data ->
currentClientIdProvider().fold({
E2EIFailure.TrustAnchors(it).left()
}, { clientId ->
mlsClientProvider.getCoreCrypto(clientId).fold({
E2EIFailure.MissingMLSClient(it).left()
}, { coreCrypto ->
wrapE2EIRequest {
coreCrypto.registerIntermediateCa(data)
}
})
})
registerIntermediateCAs(data)
})
}

private suspend fun registerIntermediateCAs(data: List<String>) =
currentClientIdProvider().fold({
E2EIFailure.TrustAnchors(it).left()
}, { clientId ->
mlsClientProvider.getCoreCrypto(clientId).fold({
E2EIFailure.MissingMLSClient(it).left()
}, { coreCrypto ->
data.foldToEitherWhileRight(Unit) { item, _ ->
wrapE2EIRequest {
coreCrypto.registerIntermediateCa(item)
}
}
})
})

override fun discoveryUrl() =
userConfigRepository.getE2EISettings().fold({
E2EIFailure.MissingTeamSettings.left()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class EnrollE2EIUseCaseImpl internal constructor(
e2EIRepository.initFreshE2EIClient(isNewClient = isNewClientRegistration)

e2EIRepository.fetchAndSetTrustAnchors()
e2EIRepository.fetchFederationCertificates().getOrFail {
kaliumLogger.e("Failure fetching federation certificates during E2EI Enrolling!. Failure:$it")
return it.left()
}

val acmeDirectories = e2EIRepository.loadACMEDirectories().getOrFail {
return it.left()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import com.wire.kalium.network.api.base.unbound.acme.DtoAuthorizationChallengeTy
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.util.DateTimeUtil
import io.ktor.http.Url
import io.mockative.Mock
import io.mockative.any
import io.mockative.anyInstanceOf
Expand Down Expand Up @@ -881,7 +880,7 @@ class E2EIRepositoryTest {
.wasInvoked(once)

verify(arrangement.acmeApi)
.suspendFunction(arrangement.acmeApi::getACMEFederation)
.suspendFunction(arrangement.acmeApi::getACMEFederationCertificateChain)
.with(any())
.wasNotInvoked()

Expand All @@ -892,12 +891,12 @@ class E2EIRepositoryTest {
}

@Test
fun givenACMEFederationApiSucceed_whenFetchACMECertificates_thenItSucceed() = runTest {
fun givenACMEFederationApiSucceeds_whenFetchACMECertificates_thenAllCertificatesAreRegistered() = runTest {
val certificateList = listOf("a", "b", "potato")
// Given

val (arrangement, e2eiRepository) = Arrangement()
.withGettingE2EISettingsReturns(Either.Right(E2EI_TEAM_SETTINGS))
.withAcmeFederationApiSucceed()
.withAcmeFederationApiSucceed(certificateList)
.withCurrentClientIdProviderSuccessful()
.withGetCoreCryptoSuccessful()
.withRegisterIntermediateCABag()
Expand All @@ -915,14 +914,16 @@ class E2EIRepositoryTest {
.wasInvoked(once)

verify(arrangement.acmeApi)
.suspendFunction(arrangement.acmeApi::getACMEFederation)
.suspendFunction(arrangement.acmeApi::getACMEFederationCertificateChain)
.with(any())
.wasInvoked(once)

verify(arrangement.coreCryptoCentral)
.suspendFunction(arrangement.coreCryptoCentral::registerIntermediateCa)
.with(any())
.wasInvoked(once)
certificateList.forEach { certificateValue ->
verify(arrangement.coreCryptoCentral)
.suspendFunction(arrangement.coreCryptoCentral::registerIntermediateCa)
.with(eq(certificateValue))
.wasInvoked(once)
}
}

@Test
Expand All @@ -949,7 +950,7 @@ class E2EIRepositoryTest {
.wasInvoked(once)

verify(arrangement.acmeApi)
.suspendFunction(arrangement.acmeApi::getACMEFederation)
.suspendFunction(arrangement.acmeApi::getACMEFederationCertificateChain)
.with(any())
.wasNotInvoked()

Expand Down Expand Up @@ -1280,16 +1281,16 @@ class E2EIRepositoryTest {
.thenReturn(NetworkResponse.Error(INVALID_REQUEST_ERROR))
}

fun withAcmeFederationApiSucceed() = apply {
fun withAcmeFederationApiSucceed(certificateList: List<String>) = apply {
given(acmeApi)
.suspendFunction(acmeApi::getACMEFederation)
.suspendFunction(acmeApi::getACMEFederationCertificateChain)
.whenInvokedWith(any())
.thenReturn(NetworkResponse.Success("", mapOf(), 200))
.thenReturn(NetworkResponse.Success(certificateList, mapOf(), 200))
}

fun withAcmeFederationApiFails() = apply {
given(acmeApi)
.suspendFunction(acmeApi::getACMEFederation)
.suspendFunction(acmeApi::getACMEFederationCertificateChain)
.whenInvokedWith(any())
.thenReturn(NetworkResponse.Error(INVALID_REQUEST_ERROR))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class EnrollE2EICertificateUseCaseTest {
// given
arrangement.withInitializingE2EIClientSucceed()
arrangement.withLoadTrustAnchorsResulting(Either.Right(Unit))
arrangement.withFetchFederationCertificateChainResulting(Either.Right(Unit))
arrangement.withLoadACMEDirectoriesResulting(E2EIFailure.AcmeDirectories(TEST_CORE_FAILURE).left())

// when
Expand Down Expand Up @@ -144,6 +145,7 @@ class EnrollE2EICertificateUseCaseTest {
// given
arrangement.withInitializingE2EIClientSucceed()
arrangement.withLoadTrustAnchorsResulting(Either.Right(Unit))
arrangement.withFetchFederationCertificateChainResulting(Either.Right(Unit))
arrangement.withLoadACMEDirectoriesResulting(Either.Right(ACME_DIRECTORIES))
arrangement.withGetACMENonceResulting(E2EIFailure.AcmeNonce(TEST_CORE_FAILURE).left())

Expand Down Expand Up @@ -224,6 +226,7 @@ class EnrollE2EICertificateUseCaseTest {
// given
arrangement.withInitializingE2EIClientSucceed()
arrangement.withLoadTrustAnchorsResulting(Either.Right(Unit))
arrangement.withFetchFederationCertificateChainResulting(Either.Right(Unit))
arrangement.withLoadACMEDirectoriesResulting(Either.Right(ACME_DIRECTORIES))
arrangement.withGetACMENonceResulting(Either.Right(RANDOM_NONCE))
arrangement.withCreateNewAccountResulting(E2EIFailure.AcmeNewAccount(TEST_CORE_FAILURE).left())
Expand Down Expand Up @@ -309,6 +312,7 @@ class EnrollE2EICertificateUseCaseTest {
// given
arrangement.withInitializingE2EIClientSucceed()
arrangement.withLoadTrustAnchorsResulting(Either.Right(Unit))
arrangement.withFetchFederationCertificateChainResulting(Either.Right(Unit))
arrangement.withLoadACMEDirectoriesResulting(Either.Right(ACME_DIRECTORIES))
arrangement.withGetACMENonceResulting(Either.Right(RANDOM_NONCE))
arrangement.withCreateNewAccountResulting(Either.Right(RANDOM_NONCE))
Expand Down Expand Up @@ -396,6 +400,7 @@ class EnrollE2EICertificateUseCaseTest {
// given
arrangement.withInitializingE2EIClientSucceed()
arrangement.withLoadTrustAnchorsResulting(Either.Right(Unit))
arrangement.withFetchFederationCertificateChainResulting(Either.Right(Unit))
arrangement.withLoadACMEDirectoriesResulting(Either.Right(ACME_DIRECTORIES))
arrangement.withGetACMENonceResulting(Either.Right(RANDOM_NONCE))
arrangement.withCreateNewAccountResulting(Either.Right(RANDOM_NONCE))
Expand Down Expand Up @@ -487,6 +492,7 @@ class EnrollE2EICertificateUseCaseTest {
// given
arrangement.withInitializingE2EIClientSucceed()
arrangement.withLoadTrustAnchorsResulting(Either.Right(Unit))
arrangement.withFetchFederationCertificateChainResulting(Either.Right(Unit))
arrangement.withLoadACMEDirectoriesResulting(Either.Right(ACME_DIRECTORIES))
arrangement.withGetACMENonceResulting(Either.Right(RANDOM_NONCE))
arrangement.withCreateNewAccountResulting(Either.Right(RANDOM_NONCE))
Expand Down Expand Up @@ -1111,6 +1117,13 @@ class EnrollE2EICertificateUseCaseTest {
.thenReturn(result)
}

fun withFetchFederationCertificateChainResulting(result: Either<E2EIFailure, Unit>) = apply {
given(e2EIRepository)
.suspendFunction(e2EIRepository::fetchFederationCertificates)
.whenInvoked()
.thenReturn(result)
}

fun withLoadACMEDirectoriesResulting(result: Either<E2EIFailure, AcmeDirectory>) = apply {
given(e2EIRepository)
.suspendFunction(e2EIRepository::loadACMEDirectories)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.network.utils.CustomErrors
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.network.utils.flatMap
import com.wire.kalium.network.utils.handleUnsuccessfulResponse
import com.wire.kalium.network.utils.mapSuccess
import com.wire.kalium.network.utils.wrapKaliumResponse
import io.ktor.client.call.body
import io.ktor.client.request.accept
Expand All @@ -48,7 +49,14 @@ interface ACMEApi {
suspend fun sendACMERequest(url: String, body: ByteArray? = null): NetworkResponse<ACMEResponse>
suspend fun sendAuthorizationRequest(url: String, body: ByteArray? = null): NetworkResponse<ACMEAuthorizationResponse>
suspend fun sendChallengeRequest(url: String, body: ByteArray): NetworkResponse<ChallengeResponse>
suspend fun getACMEFederation(discoveryUrl: String): NetworkResponse<String>

/**
* Retrieves the ACME federation certificate chain from the specified discovery URL.
*
* @param discoveryUrl The non-blank URL of the ACME federation discovery endpoint.
* @return A [NetworkResponse] object containing the certificate chain as a list of strings.
*/
suspend fun getACMEFederationCertificateChain(discoveryUrl: String): NetworkResponse<List<String>>
suspend fun getClientDomainCRL(url: String): NetworkResponse<ByteArray>
}

Expand Down Expand Up @@ -223,7 +231,7 @@ class ACMEApiImpl internal constructor(
}
}

override suspend fun getACMEFederation(discoveryUrl: String): NetworkResponse<String> {
override suspend fun getACMEFederationCertificateChain(discoveryUrl: String): NetworkResponse<List<String>> {
val protocolWithAuthority = Url(discoveryUrl).protocolWithAuthority
if (discoveryUrl.isBlank() || protocolWithAuthority.isBlank()) {
return NetworkResponse.Error(
Expand All @@ -237,9 +245,9 @@ class ACMEApiImpl internal constructor(
)
}

return wrapKaliumResponse {
return wrapKaliumResponse<FederationCertificateChainResponse> {
httpClient.get("$protocolWithAuthority/$PATH_ACME_FEDERATION")
}
}.mapSuccess { it.certificates }
}

override suspend fun getClientDomainCRL(url: String): NetworkResponse<ByteArray> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,11 @@ enum class DtoAuthorizationChallengeType {
OIDC
}

@Serializable
data class FederationCertificateChainResponse(
@SerialName("crts")
val certificates: List<String>
)

@JvmInline
value class CertificateChain(val value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ import kotlin.test.*

internal class ACMEApiTest : ApiTest() {

@Test
fun givingASuccessfulResponse_whenGettingACMEFederationCertificateChain_thenAllCertificatesShouldBeParsed() = runTest {
val expected = listOf("a", "b", "potato")

val networkClient = mockUnboundNetworkClient(
responseBody = """
{
"crts": ["a", "b", "potato"]
}
""".trimIndent(),
statusCode = HttpStatusCode.OK
)

val acmeApi: ACMEApi = ACMEApiImpl(networkClient, networkClient)

val result = acmeApi.getACMEFederationCertificateChain("someURL")

assertTrue(result.isSuccessful())
assertContentEquals(expected, result.value)
}

@Ignore
@Test
fun whenCallingGeTrustAnchorsApi_theResponseShouldBeConfigureCorrectly() = runTest {
Expand Down Expand Up @@ -185,6 +206,7 @@ internal class ACMEApiTest : ApiTest() {
assertEquals(expected, actual.value)
}
}

companion object {
private const val ACME_DISCOVERY_URL = "https://balderdash.hogwash.work:9000/acme/google-android/directory"
private const val ACME_DIRECTORIES_PATH = "https://balderdash.hogwash.work:9000/acme/google-android/directory"
Expand Down

0 comments on commit 0dd7b4a

Please sign in to comment.