Skip to content

Commit

Permalink
Fix Client hanging on unexpected Channel closures (#2610)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Jan 15, 2024
1 parent 2d267b5 commit bc595f3
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 92 deletions.
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zio/http/ZClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ object ZClient {
_ <-
onComplete.await.interruptible.exit.flatMap { exit =>
if (exit.isInterrupted) {
channelInterface.interrupt
channelInterface.interrupt.ignore
.zipRight(connectionPool.invalidate(connection))
.uninterruptible
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ final class ClientInboundHandler(
}

override def exceptionCaught(ctx: ChannelHandlerContext, error: Throwable): Unit = {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(
onResponse.fail(error) *> onComplete.fail(error),
)(unsafeClass, trace)
ctx.fireExceptionCaught(error)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ final class ClientResponseStreamHandler(
)
else {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(
NettyFutureExecutor
.executed(ctx.close())
.as(ChannelState.Invalid)
.exit
.flatMap(onComplete.done(_)),
onComplete.succeed(ChannelState.Invalid) *> NettyFutureExecutor.executed(ctx.close()),
)(unsafeClass, trace)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import zio.http.netty.model.Conversions
import zio.http.netty.socket.NettySocketProtocol

import io.netty.channel.{Channel, ChannelFactory, ChannelHandler, EventLoopGroup}
import io.netty.handler.codec.PrematureChannelClosureException
import io.netty.handler.codec.http.websocketx.{WebSocketClientProtocolHandler, WebSocketFrame => JWebSocketFrame}
import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator}

Expand All @@ -50,7 +51,7 @@ final case class NettyClientDriver private[netty] (
createSocketApp: () => WebSocketApp[Any],
webSocketConfig: WebSocketConfig,
)(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] = {
NettyRequestEncoder.encode(req).flatMap { jReq =>
val f = NettyRequestEncoder.encode(req).flatMap { jReq =>
for {
_ <- Scope.addFinalizer {
ZIO.attempt {
Expand Down Expand Up @@ -152,6 +153,27 @@ final case class NettyClientDriver private[netty] (
}
}
}

f.ensuring(
ZIO
.unless(location.scheme.isWebSocket) {
// If the channel was closed and the promises were not completed, this will lead to the request hanging so we need
// to listen to the close future and complete the promises
NettyFutureExecutor
.executed(channel.closeFuture())
.interruptible
.zipRight(
// If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
// so we don't need to fulfill `onResponse`
onComplete.interrupt && onResponse.fail(
new PrematureChannelClosureException(
"Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
),
),
)
}
.forkScoped,
)
}

override def createConnectionPool(dnsResolver: DnsResolver, config: ConnectionPoolConfig)(implicit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ object NettyConnectionPool {

private final class ZioNettyConnectionPool(
pool: ZKeyedPool[Throwable, PoolKey, JChannel],
maxItems: PoolKey => Int,
) extends NettyConnectionPool {
override def get(
location: Location.Absolute,
Expand All @@ -189,25 +190,44 @@ object NettyConnectionPool {
idleTimeout: Option[Duration],
connectionTimeout: Option[Duration],
localAddress: Option[InetSocketAddress] = None,
)(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] =
pool
.get(
PoolKey(
location,
proxy,
sslOptions,
maxInitialLineLength,
maxHeaderSize,
decompression,
idleTimeout,
connectionTimeout,
),
)
)(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = ZIO.uninterruptibleMask { restore =>
val key = PoolKey(
location,
proxy,
sslOptions,
maxInitialLineLength,
maxHeaderSize,
decompression,
idleTimeout,
connectionTimeout,
)

restore(pool.get(key)).withEarlyRelease.flatMap { case (release, channel) =>
// Channel might have closed while in the pool, either because of a timeout or because of a connection error
// We retry a few times hoping to obtain an open channel
// NOTE: We need to release the channel before retrying, so that it can be closed and removed from the pool
// We do that in a forked fiber so that we don't "block" the current fiber while the new resource is obtained
if (channel.isOpen) ZIO.succeed(channel)
else invalidate(channel) *> release.forkDaemon *> ZIO.fail(None)
}
.retry(retrySchedule(key))
.catchAll {
case None => pool.get(key) // We did all we could, let the caller handle it
case e: Throwable => ZIO.fail(e)
}
.withFinalizer(c => ZIO.unless(c.isOpen)(invalidate(c)))
}

override def invalidate(channel: JChannel)(implicit trace: Trace): ZIO[Any, Nothing, Unit] =
pool.invalidate(channel)

override def enableKeepAlive: Boolean = true

private def retrySchedule[E](key: PoolKey)(implicit trace: Trace) =
Schedule.recurWhile[E] {
case None => true
case _ => false
} && Schedule.recurs(maxItems(key))
}

def fromConfig(
Expand Down Expand Up @@ -251,7 +271,6 @@ object NettyConnectionPool {
for {
driver <- ZIO.service[NettyClientDriver]
dnsResolver <- ZIO.service[DnsResolver]
poolPromise <- Promise.make[Nothing, ZKeyedPool[Throwable, PoolKey, JChannel]]
poolFn = (key: PoolKey) =>
createChannel(
driver.channelFactory,
Expand All @@ -266,20 +285,10 @@ object NettyConnectionPool {
key.connectionTimeout,
None,
dnsResolver,
).tap { channel =>
NettyFutureExecutor
.executed(channel.closeFuture())
.interruptible
.zipRight(
poolPromise.await.flatMap(_.invalidate(channel)),
)
.forkDaemon
}.uninterruptible
keyedPool <- ZKeyedPool
.make(poolFn, (key: PoolKey) => size(key.location))
.tap(poolPromise.succeed)
.tapErrorCause(poolPromise.failCause)
} yield new ZioNettyConnectionPool(keyedPool)
).uninterruptible
_size = (key: PoolKey) => size(key.location)
keyedPool <- ZKeyedPool.make(poolFn, _size)
} yield new ZioNettyConnectionPool(keyedPool, _size)

private def createDynamic(
min: Int,
Expand All @@ -296,7 +305,6 @@ object NettyConnectionPool {
for {
driver <- ZIO.service[NettyClientDriver]
dnsResolver <- ZIO.service[DnsResolver]
poolPromise <- Promise.make[Nothing, ZKeyedPool[Throwable, PoolKey, JChannel]]
poolFn = (key: PoolKey) =>
createChannel(
driver.channelFactory,
Expand All @@ -311,20 +319,13 @@ object NettyConnectionPool {
key.connectionTimeout,
None,
dnsResolver,
).tap { channel =>
NettyFutureExecutor
.executed(channel.closeFuture())
.interruptible
.zipRight(
poolPromise.await.flatMap(_.invalidate(channel)),
)
.forkDaemon
}.uninterruptible
keyedPool <- ZKeyedPool
.make(poolFn, (key: PoolKey) => min(key.location) to max(key.location), (key: PoolKey) => ttl(key.location))
.tap(poolPromise.succeed)
.tapErrorCause(poolPromise.failCause)
} yield new ZioNettyConnectionPool(keyedPool)
).uninterruptible
keyedPool <- ZKeyedPool.make(
poolFn,
(key: PoolKey) => min(key.location) to max(key.location),
(key: PoolKey) => ttl(key.location),
)
} yield new ZioNettyConnectionPool(keyedPool, key => max(key.location))

implicit final class BootstrapSyntax(val bootstrap: Bootstrap) extends AnyVal {
def withOption[T](option: ChannelOption[T], value: Option[T]): Bootstrap =
Expand Down
51 changes: 30 additions & 21 deletions zio-http/src/test/scala/zio/http/ClientHttpsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package zio.http

import zio.test.Assertion.{anything, equalTo, fails, hasField}
import zio.test.TestAspect.{ignore, timeout}
import zio.test.assertZIO
import zio.{Scope, ZLayer, durationInt}
import zio._
import zio.test.Assertion._
import zio.test.TestAspect.{ignore, nonFlaky}
import zio.test.{TestAspect, assertZIO}

import zio.http.netty.NettyConfig
import zio.http.netty.client.NettyClientDriver
Expand All @@ -31,8 +31,8 @@ object ClientHttpsSpec extends ZIOHttpSpec {
trustStorePassword = "changeit",
)

val waterAerobics =
URL.decode("https://sports.api.decathlon.com/groups/water-aerobics").toOption.get
val zioDev =
URL.decode("https://zio.dev").toOption.get

val badRequest =
URL
Expand All @@ -47,11 +47,11 @@ object ClientHttpsSpec extends ZIOHttpSpec {

override def spec = suite("Https Client request")(
test("respond Ok") {
val actual = Client.request(Request.get(waterAerobics))
val actual = Client.request(Request.get(zioDev))
assertZIO(actual)(anything)
},
}.provide(ZLayer.succeed(ZClient.Config.default), partialClientLayer, Scope.default),
test("respond Ok with sslConfig") {
val actual = Client.request(Request.get(waterAerobics))
val actual = Client.request(Request.get(zioDev))
assertZIO(actual)(anything)
},
test("should respond as Bad Request") {
Expand All @@ -63,19 +63,28 @@ object ClientHttpsSpec extends ZIOHttpSpec {
assertZIO(actual)(equalTo(Status.BadRequest))
} @@ ignore,
test("should throw DecoderException for handshake failure") {
val actual = Client
.request(
Request.get(untrusted),
)
.exit
assertZIO(actual)(fails(hasField("class.simpleName", _.getClass.getSimpleName, equalTo("DecoderException"))))
},
).provide(
ZLayer.succeed(ZClient.Config.default.ssl(sslConfig)),
val actual = Client.request(Request.get(untrusted)).exit
assertZIO(actual)(
fails(
hasField(
"class.simpleName",
_.getClass.getSimpleName,
isOneOf(List("DecoderException", "PrematureChannelClosureException")),
),
),
)
} @@ nonFlaky(20),
)
.provideSomeLayer[Client](Scope.default)
.provideShared(
ZLayer.succeed(ZClient.Config.default.ssl(sslConfig)),
partialClientLayer,
) @@ TestAspect.withLiveClock

private val partialClientLayer = ZLayer.makeSome[ZClient.Config, Client](
Client.customized,
NettyClientDriver.live,
DnsResolver.default,
ZLayer.succeed(NettyConfig.default),
Scope.default,
) @@ ignore
ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
)
}
26 changes: 16 additions & 10 deletions zio-http/src/test/scala/zio/http/SSLSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
package zio.http

import zio.test.Assertion.equalTo
import zio.test.TestAspect.{ignore, timeout}
import zio.test.{Gen, assertZIO, check}
import zio.{Scope, ZIO, ZLayer, durationInt}
import zio.test.{Gen, assertCompletes, assertNever, assertZIO}
import zio.{Scope, ZLayer}

import zio.http.netty.NettyConfig
import zio.http.netty.client.NettyClientDriver
Expand Down Expand Up @@ -62,14 +61,21 @@ object SSLSpec extends ZIOHttpSpec {
ZLayer.succeed(NettyConfig.default),
Scope.default,
),
test("fail with DecoderException when client doesn't have the server certificate") {
val actual = Client
// Unfortunately if the channel closes before we create the request, we can't extract the DecoderException
test(
"fail with DecoderException or PrematureChannelClosureException when client doesn't have the server certificate",
) {
Client
.request(Request.get(httpsUrl))
.catchSome {
case e if e.getClass.getSimpleName == "DecoderException" =>
ZIO.succeed("DecoderException")
}
assertZIO(actual)(equalTo("DecoderException"))
.fold(
{ e =>
val expectedErrors = List("DecoderException", "PrematureChannelClosureException")
val errorType = e.getClass.getSimpleName
if (expectedErrors.contains(errorType)) assertCompletes
else assertNever(s"request failed with unexpected error type: $errorType")
},
_ => assertNever("expected request to fail"),
)
}.provide(
Client.customized,
ZLayer.succeed(ZClient.Config.default.ssl(clientSSL2)),
Expand Down
4 changes: 2 additions & 2 deletions zio-http/src/test/scala/zio/http/ServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ object ServerSpec extends HttpRunnableSpec {
.mapZIO(_.asString)
.run()
.exit
assertZIO(res)(failsWithA[java.io.IOException])
assertZIO(res)(fails(anything))
} @@ TestAspect.timeout(10.seconds),
test("streaming failure - unknown content type") {
val res =
Expand All @@ -395,7 +395,7 @@ object ServerSpec extends HttpRunnableSpec {
.mapZIO(_.asString)
.run()
.exit
assertZIO(res)(failsWithA[java.io.IOException])
assertZIO(res)(fails(anything))
} @@ TestAspect.timeout(10.seconds),
suite("html")(
test("body") {
Expand Down
Loading

0 comments on commit bc595f3

Please sign in to comment.