Skip to content

Commit

Permalink
Add more options to the SSL configuration (#3139)
Browse files Browse the repository at this point in the history
Co-authored-by: John A. De Goes <[email protected]>
  • Loading branch information
varshith257 and jdegoes authored Sep 27, 2024
1 parent e1afcf2 commit 40c4b6f
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package zio.http.netty.client

import java.io.{File, FileInputStream, InputStream}
import java.security.KeyStore
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.{KeyManagerFactory, TrustManagerFactory}

import scala.util.Using

Expand All @@ -31,6 +31,39 @@ import zio.http.ClientSSLConfig
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
import io.netty.handler.ssl.{SslContext, SslContextBuilder}
private[netty] object ClientSSLConverter {
private def keyManagerTrustManagerToSslContext(
keyManagerInfo: Option[(String, InputStream, Option[Secret])],
trustManagerInfo: Option[(String, InputStream, Option[Secret])],
sslContextBuilder: SslContextBuilder,
): SslContextBuilder = {
val mkeyManagerFactory =
keyManagerInfo.map { case (keyStoreType, inputStream, maybePassword) =>
val keyStore = KeyStore.getInstance(keyStoreType)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
val password = maybePassword.map(_.value.toArray).orNull

keyStore.load(inputStream, password)
keyManagerFactory.init(keyStore, password)
keyManagerFactory
}

val mtrustManagerFactory =
trustManagerInfo.map { case (keyStoreType, inputStream, maybePassword) =>
val keyStore = KeyStore.getInstance(keyStoreType)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
val password = maybePassword.map(_.value.toArray).orNull

keyStore.load(inputStream, password)
trustManagerFactory.init(keyStore)
trustManagerFactory
}

var bldr = SslContextBuilder.forClient()
mkeyManagerFactory.foreach(kmf => bldr = bldr.keyManager(kmf))
mtrustManagerFactory.foreach(tmf => bldr = bldr.trustManager(tmf))
bldr
}

private def trustStoreToSslContext(
trustStoreStream: InputStream,
trustStorePassword: Secret,
Expand Down Expand Up @@ -78,6 +111,41 @@ private[netty] object ClientSSLConverter {
case ClientSSLConfig.FromTrustStoreFile(trustStorePath, trustStorePassword) =>
val trustStoreStream = new FileInputStream(trustStorePath)
trustStoreToSslContext(trustStoreStream, trustStorePassword, sslContextBuilder)
case ClientSSLConfig.FromJavaxNetSsl(
keyManagerKeyStoreType,
keyManagerSource,
keyManagerPassword,
trustManagerKeyStoreType,
trustManagerSource,
trustManagerPassword,
) =>
val keyManagerInfo =
keyManagerSource match {
case ClientSSLConfig.FromJavaxNetSsl.File(path) =>
Option(new FileInputStream(path)).map(inputStream =>
(keyManagerKeyStoreType, inputStream, keyManagerPassword),
)
case ClientSSLConfig.FromJavaxNetSsl.Resource(path) =>
Option(getClass.getClassLoader.getResourceAsStream(path)).map(inputStream =>
(keyManagerKeyStoreType, inputStream, keyManagerPassword),
)
case ClientSSLConfig.FromJavaxNetSsl.Empty => None
}

val trustManagerInfo =
trustManagerSource match {
case ClientSSLConfig.FromJavaxNetSsl.File(path) =>
Option(new FileInputStream(path)).map(inputStream =>
(trustManagerKeyStoreType, inputStream, trustManagerPassword),
)
case ClientSSLConfig.FromJavaxNetSsl.Resource(path) =>
Option(getClass.getClassLoader.getResourceAsStream(path)).map(inputStream =>
(trustManagerKeyStoreType, inputStream, trustManagerPassword),
)
case ClientSSLConfig.FromJavaxNetSsl.Empty => None
}

keyManagerTrustManagerToSslContext(keyManagerInfo, trustManagerInfo, sslContextBuilder)
}

def toNettySSLContext(sslConfig: ClientSSLConfig): SslContext = {
Expand Down
33 changes: 23 additions & 10 deletions zio-http/jvm/src/test/scala/zio/http/ClientHttpsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,22 @@ package zio.http

import zio._
import zio.test.Assertion._
import zio.test.TestAspect.{ignore, nonFlaky}
import zio.test.TestAspect.nonFlaky
import zio.test.{TestAspect, assertZIO}

import zio.http.netty.NettyConfig
import zio.http.netty.client.NettyClientDriver

object ClientHttpsSpec extends ZIOHttpSpec {

val sslConfig = ClientSSLConfig.FromTrustStoreResource(
trustStorePath = "truststore.jks",
trustStorePassword = "changeit",
)
abstract class ClientHttpsSpecBase extends ZIOHttpSpec {
val sslConfig: ClientSSLConfig

val zioDev =
URL.decode("https://zio.dev").toOption.get

val badRequest =
URL
.decode(
"https://www.whatissslcertificate.com/google-has-made-the-list-of-untrusted-providers-of-digital-certificates/",
"https://httpbin.org/status/400",
)
.toOption
.get
Expand All @@ -57,7 +53,7 @@ object ClientHttpsSpec extends ZIOHttpSpec {
test("should respond as Bad Request") {
val actual = Client.batched(Request.get(badRequest)).map(_.status)
assertZIO(actual)(equalTo(Status.BadRequest))
} @@ ignore,
},
test("should throw DecoderException for handshake failure") {
val actual = Client.batched(Request.get(untrusted)).exit
assertZIO(actual)(
Expand All @@ -69,7 +65,7 @@ object ClientHttpsSpec extends ZIOHttpSpec {
),
),
)
} @@ nonFlaky(20) @@ ignore,
} @@ nonFlaky(20),
)
.provideShared(
ZLayer.succeed(ZClient.Config.default.ssl(sslConfig)),
Expand All @@ -83,3 +79,20 @@ object ClientHttpsSpec extends ZIOHttpSpec {
ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
)
}

object ClientHttpsSpec extends ClientHttpsSpecBase {

val sslConfig = ClientSSLConfig.FromTrustStoreResource(
trustStorePath = "truststore.jks",
trustStorePassword = "changeit",
)
}

object ClientHttpsFromJavaxNetSslSpec extends ClientHttpsSpecBase {

val sslConfig =
ClientSSLConfig.FromJavaxNetSsl
.builderWithTrustManagerResource("trustStore.jks")
.trustManagerPassword("changeit")
.build()
}
97 changes: 97 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ object ClientSSLConfig {
val trustStorePath = Config.string("trust-store-path")
val trustStorePassword = Config.secret("trust-store-password")

val keyManagerKeyStoreType = Config.string("keyManagerKeyStoreType")
val keyManagerFile = Config.string("keyManagerFile")
val keyManagerResource = Config.string("keyManagerResource")
val keyManagerPassword = Config.secret("keyManagerPassword")
val trustManagerKeyStoreType = Config.string("trustManagerKeyStoreType")
val trustManagerFile = Config.string("trustManagerFile")
val trustManagerResource = Config.string("trustManagerResource")
val trustManagerPassword = Config.secret("trustManagerPassword")

val default = Config.succeed(Default)
val fromCertFile = certPath.map(FromCertFile(_))
val fromCertResource = certPath.map(FromCertResource(_))
Expand All @@ -39,6 +48,45 @@ object ClientSSLConfig {
serverCertConfig.zipWith(clientCertConfig)(FromClientAndServerCert(_, _))
}

val fromJavaxNetSsl = {
keyManagerKeyStoreType.optional
.zip(keyManagerFile.optional)
.zip(keyManagerResource.optional)
.zip(keyManagerPassword.optional)
.zip(trustManagerKeyStoreType.optional)
.zip(
trustManagerFile.optional
.zip(trustManagerResource.optional)
.validate("must supply trustManagerFile or trustManagerResource")(pair =>
pair._1.isDefined || pair._2.isDefined,
),
)
.zip(trustManagerPassword.optional)
.map { case (kmkst, kmf, kmr, kmpass, tmkst, (tmf, tmr), tmpass) =>
val bldr0 =
List[(Option[String], FromJavaxNetSsl => String => FromJavaxNetSsl)](
(kmkst, b => b.keyManagerKeyStoreType(_)),
(kmf, b => b.keyManagerFile),
(kmr, b => b.keyManagerResource),
(tmkst, b => b.trustManagerKeyStoreType(_)),
(tmf, b => b.trustManagerFile),
(tmr, b => b.trustManagerResource),
)
.foldLeft(FromJavaxNetSsl()) { case (bldr, (maybe, lens)) =>
maybe.fold(bldr)(s => lens(bldr)(s))
}

List[(Option[Secret], FromJavaxNetSsl => Secret => FromJavaxNetSsl)](
(kmpass, b => b.keyManagerPassword(_)),
(tmpass, b => b.trustManagerPassword(_)),
)
.foldLeft(bldr0) { case (bldr, (maybe, lens)) =>
maybe.fold(bldr)(s => lens(bldr)(s))
}
.build()
}
}

tpe.switch(
"Default" -> default,
"FromCertFile" -> fromCertFile,
Expand All @@ -58,6 +106,55 @@ object ClientSSLConfig {
clientCertConfig: ClientSSLCertConfig,
) extends ClientSSLConfig

final case class FromJavaxNetSsl(
keyManagerKeyStoreType: String = "JKS",
keyManagerSource: FromJavaxNetSsl.Source = FromJavaxNetSsl.Empty,
keyManagerPassword: Option[Secret] = None,
trustManagerKeyStoreType: String = "JKS",
trustManagerSource: FromJavaxNetSsl.Source = FromJavaxNetSsl.Empty,
trustManagerPassword: Option[Secret] = None,
) extends ClientSSLConfig { self =>

def isValidBuild: Boolean = trustManagerSource != FromJavaxNetSsl.Empty
def isInvalidBuild: Boolean = !isValidBuild
def build(): FromJavaxNetSsl = this

def keyManagerKeyStoreType(tpe: String): FromJavaxNetSsl = self.copy(keyManagerKeyStoreType = tpe)
def keyManagerFile(file: String): FromJavaxNetSsl =
keyManagerSource match {
case FromJavaxNetSsl.Resource(_) => this
case _ => self.copy(keyManagerSource = FromJavaxNetSsl.File(file))
}
def keyManagerResource(path: String): FromJavaxNetSsl = self.copy(keyManagerSource = FromJavaxNetSsl.Resource(path))
def keyManagerPassword(password: Secret): FromJavaxNetSsl = self.copy(keyManagerPassword = Some(password))
def keyManagerPassword(password: String): FromJavaxNetSsl = keyManagerPassword(Secret(password))

def trustManagerKeyStoreType(tpe: String): FromJavaxNetSsl = self.copy(trustManagerKeyStoreType = tpe)
def trustManagerFile(file: String): FromJavaxNetSsl =
trustManagerSource match {
case FromJavaxNetSsl.Resource(_) => this
case _ => self.copy(trustManagerSource = FromJavaxNetSsl.File(file))
}
def trustManagerResource(path: String): FromJavaxNetSsl =
self.copy(trustManagerSource = FromJavaxNetSsl.Resource(path))
def trustManagerPassword(password: Secret): FromJavaxNetSsl = self.copy(trustManagerPassword = Some(password))
def trustManagerPassword(password: String): FromJavaxNetSsl = trustManagerPassword(Secret(password))
}

object FromJavaxNetSsl {

sealed trait Source extends Product with Serializable
case object Empty extends Source
final case class File(file: String) extends Source
final case class Resource(resource: String) extends Source

def builderWithTrustManagerFile(file: String): FromJavaxNetSsl =
FromJavaxNetSsl().trustManagerFile(file)

def builderWithTrustManagerResource(resource: String): FromJavaxNetSsl =
FromJavaxNetSsl().trustManagerResource(resource)
}

object FromTrustStoreResource {
def apply(trustStorePath: String, trustStorePassword: String): FromTrustStoreResource =
FromTrustStoreResource(trustStorePath, Secret(trustStorePassword))
Expand Down

0 comments on commit 40c4b6f

Please sign in to comment.