Skip to content

Commit

Permalink
Support URL relative resolution, and use it in a ZClientAspect to fol…
Browse files Browse the repository at this point in the history
…low redirects (#2537)

* RFC3986 Relative Resolution

* Add follow redirects client aspect
  • Loading branch information
jgulotta authored Dec 8, 2023
1 parent e3e3ab1 commit 2e8b486
Show file tree
Hide file tree
Showing 6 changed files with 404 additions and 5 deletions.
54 changes: 52 additions & 2 deletions zio-http/src/main/scala/zio/http/Path.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package zio.http

import scala.collection.mutable

import zio.{Chunk, ChunkBuilder}

/**
Expand Down Expand Up @@ -184,6 +182,58 @@ final case class Path private[http] (flags: Path.Flags, segments: Chunk[String])
else Path.empty
} else self

/**
* RFC 3986 § 5.2.4 Remove Dot Segments
* @return
* the Path with `.` and `..` resolved and removed
*/
def removeDotSegments: Path = {
// See https://www.rfc-editor.org/rfc/rfc3986#section-5.2.4
val segments = new Array[String](self.segments.length)
var segmentCount = 0
// leading/trailing slashes may change but is unlikely
var flags = self.flags

var i = 0
val max = self.segments.length

if (!Flag.LeadingSlash.check(flags)) {
// § 5.2.4.2.A/D no leading slash, so skip all initial `./` and `../`
while (i < max && (self.segments(i) == "." | self.segments(i) == "..")) {
i += 1
}
// if the entire input was consumed, there is no more trailing slash
if (i == max) flags = Flag.TrailingSlash.remove(flags)
}

var loop = i < max
while (loop) {
val segment = self.segments(i)

i += 1
loop = i < max

if (segment == "..") {
segmentCount = (segmentCount - 1).max(0)
// § 5.2.4.2.C resolving `/..` and `/../` removes preceding slashes and is itself replaced by a slash
// so if we popped the first one we definitely have a leading slash
if (segmentCount == 0) flags = Flag.LeadingSlash.add(flags)
// § 5.2.4.2.C resolving `/..` and `/../` are both as-if replaced by a `/`
// so if this is the last segment, then we have a trailing slash
if (i == max) flags = Flag.TrailingSlash.add(flags)
} else if (segment == ".") {
// § 5.2.4.2.B resolving `/.` and `/./` are both as-if replaced by a `/`
// so if this is the last segment, then we have a trailing slash
if (i == max) flags = Flag.TrailingSlash.add(flags)
} else {
segments(segmentCount) = segment
segmentCount += 1
}
}

Path(flags, Chunk.fromArray(segments.take(segmentCount)))
}

/**
* Creates a new path from this one with it's segments reversed.
*/
Expand Down
80 changes: 79 additions & 1 deletion zio-http/src/main/scala/zio/http/URL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package zio.http

import java.net.{MalformedURLException, URI, URISyntaxException}
import java.net.{MalformedURLException, URI}

import scala.util.Try

Expand Down Expand Up @@ -184,6 +184,84 @@ final case class URL(
case _ => self.copy(kind = URL.Location.Relative)
}

/**
* RFC 3986 § 5.2 Relative Resolution
* @param reference
* the URL to resolve relative to ``this`` base URL
* @return
* the target URL
*/
def resolve(reference: URL): Either[String, URL] = {
// See https://www.rfc-editor.org/rfc/rfc3986#section-5.2
// § 5.2.1 - `self` is the base and already pre-parsed into components
// § 5.2.2 - strict parsing does not ignore the reference URL scheme, so we use it directly, instead of un-setting it

if (reference.kind.isRelative) {
// § 5.2.2 - reference scheme is undefined, i.e. it is relative
self.kind match {
// § 5.2.1 - `self` is the base and is required to have a scheme, therefore it must be absolute
case Location.Relative => Left("cannot resolve against relative url")

case location: Location.Absolute =>
var path: Path = null
var query: QueryParams = null

if (reference.path.isEmpty) {
// § 5.2.2 - empty reference path keeps base path unmodified
path = self.path
// § 5.2.2 - given an empty reference path, use non-empty reference query params,
// while empty reference query params keeps base query params
// NOTE: strictly, if the reference defines a query it should be used, even if that query is empty
// but currently no-query is not differentiated from empty-query
if (reference.queryParams.isEmpty) {
query = self.queryParams
} else {
query = reference.queryParams
}
} else {
// § 5.2.2 - non-empty reference path always keeps reference query params
query = reference.queryParams

if (reference.path.hasLeadingSlash) {
// § 5.2.2 - reference path starts from root, keep reference path without dot segments
path = reference.path.removeDotSegments
} else {
// § 5.2.2 - merge base and reference paths, then collapse dot segments
// § 5.2.3 - if base has an authority AND an empty path, use the reference path, ensuring a leading slash
// the authority is the [user]@host[:port], which is always present on `self`,
// so we only need to check for an empty path
if (self.path.isEmpty) {
path = reference.path.addLeadingSlash
} else {
// § 5.2.3 - otherwise (base has no authority OR a non-empty path), drop the very last portion of the base path,
// and append all the reference path components
path = Path(
Path.Flags.concat(self.path.flags, reference.path.flags),
self.path.segments.dropRight(1) ++ reference.path.segments,
)
}

path = path.removeDotSegments
}
}

val url = URL(path, location, query, reference.fragment)

Right(url)

}
} else {
// § 5.2.2 - if the reference scheme is defined, i.e. the reference is absolute,
// the target components are the reference components but with dot segments removed

// § 5.2.2 - if the reference scheme is undefined and authority is defined, keep the base scheme
// and take everything else from the reference, removing dot segments from the path
// NOTE: URL currently does not track authority separate from scheme to implement this
// so having an authority is the same as having a scheme and they are treated the same
Right(reference.copy(path = reference.path.removeDotSegments))
}
}

def scheme: Option[Scheme] = kind match {
case Location.Absolute(scheme, _, _) => Some(scheme)
case Location.Relative => None
Expand Down
120 changes: 119 additions & 1 deletion zio-http/src/main/scala/zio/http/ZClientAspect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ object ZClientAspect {
case (duration, Exit.Success(response)) =>
ZIO
.logLevel(level(response.status)) {
def requestHeaders =
def requestHeaders =
headers.collect {
case header: Header if loggedRequestHeaderNames.contains(header.headerName.toLowerCase) =>
LogAnnotation(header.headerName, header.renderedValue)
}.toSet

def responseHeaders =
response.headers.collect {
case header: Header if loggedResponseHeaderNames.contains(header.headerName.toLowerCase) =>
Expand Down Expand Up @@ -318,4 +319,121 @@ object ZClientAspect {
}
}
}

final def followRedirects[R, E](max: Int)(
onRedirectError: (Response, String) => ZIO[R, E, Response],
)(implicit trace: Trace): ZClientAspect[Nothing, R, Nothing, Body, E, Any, Nothing, Response] = {
new ZClientAspect[Nothing, R, Nothing, Body, E, Any, Nothing, Response] {
override def apply[
Env >: Nothing <: R,
In >: Nothing <: Body,
Err >: E <: Any,
Out >: Nothing <: Response,
](client: ZClient[Env, In, Err, Out]): ZClient[Env, In, Err, Out] = {
val oldDriver = client.driver

val newDriver = new ZClient.Driver[Env, Err] {
def scopedRedirectErr(resp: Response, message: String) =
ZIO.scopeWith(_ => onRedirectError(resp, message))

override def request(
version: Version,
method: Method,
url: URL,
headers: Headers,
body: Body,
sslConfig: Option[ClientSSLConfig],
proxy: Option[Proxy],
)(implicit trace: Trace): ZIO[Env & Scope, Err, Response] = {
def req(
attempt: Int,
version: Version,
method: Method,
url: URL,
headers: Headers,
body: Body,
sslConfig: Option[ClientSSLConfig],
proxy: Option[Proxy],
): ZIO[Env & Scope, Err, Response] = {
oldDriver.request(version, method, url, headers, body, sslConfig, proxy).flatMap { resp =>
if (resp.status.isRedirection) {
if (attempt < max) {
resp.headerOrFail(Header.Location) match {
case Some(locOrError) =>
locOrError match {
case Left(locHeaderErr) =>
scopedRedirectErr(resp, locHeaderErr)

case Right(loc) =>
url.resolve(loc.url) match {
case Left(relativeResolveErr) =>
scopedRedirectErr(resp, relativeResolveErr)

case Right(resolved) =>
req(attempt + 1, version, method, resolved, headers, body, sslConfig, proxy)
}
}
case None =>
scopedRedirectErr(resp, "no location header to resolve redirect")
}
} else {
scopedRedirectErr(resp, "followed maximum redirects")
}
} else {
ZIO.succeed(resp)
}
}
}

req(0, version, method, url, headers, body, sslConfig, proxy)
}

override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])(
implicit trace: Trace,
): ZIO[Env1 & Scope, Err, Response] = {
def sock(
attempt: Int,
version: Version,
url: URL,
headers: Headers,
app: WebSocketApp[Env1],
): ZIO[Env1 & Scope, Err, Response] = {
oldDriver.socket(version, url, headers, app).flatMap { resp =>
if (resp.status.isRedirection) {
if (attempt < max) {
resp.headerOrFail(Header.Location) match {
case Some(locOrError) =>
locOrError match {
case Left(locHeaderErr) =>
scopedRedirectErr(resp, locHeaderErr)

case Right(loc) =>
url.resolve(loc.url) match {
case Left(relativeResolveErr) =>
scopedRedirectErr(resp, relativeResolveErr)

case Right(resolved) =>
sock(attempt + 1, version, resolved, headers, app)
}
}
case None =>
scopedRedirectErr(resp, "no location header to resolve redirect")
}
} else {
scopedRedirectErr(resp, "followed maximum redirects")
}
} else {
ZIO.succeed(resp)
}
}
}

sock(0, version, url, headers, app)
}
}

client.transform(client.bodyEncoder, client.bodyDecoder, newDriver)
}
}
}
}
51 changes: 51 additions & 0 deletions zio-http/src/test/scala/zio/http/PathSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -413,5 +413,56 @@ object PathSpec extends ZIOHttpSpec with ExitAssertion {
}
},
),
suite("removeDotSegments")(
test("only leading slash and dots") {
val path = Path.decode("/./../")
val result = path.removeDotSegments
val expected = Path.root

assertTrue(result == expected)
},
test("only leading dots") {
val path = Path.decode("./../")
val result = path.removeDotSegments
val expected = Path.empty

assertTrue(result == expected)
},
test("leading slash and dots") {
val path = Path.decode("/./../path")
val result = path.removeDotSegments
val expected = Path.decode("/path")

assertTrue(result == expected)
},
test("leading dots and path") {
val path = Path.decode("./../path")
val result = path.removeDotSegments
val expected = Path.decode("path")

assertTrue(result == expected)
},
test("double dot to top") {
val path = Path.decode("path/../subpath")
val result = path.removeDotSegments
val expected = Path.decode("/subpath")

assertTrue(result == expected)
},
test("trailing double dots") {
val path = Path.decode("path/ignored/..")
val result = path.removeDotSegments
val expected = Path.decode("path/")

assertTrue(result == expected)
},
test("path traversal") {
val path = Path.decode("/start/ignored/./../path/also/ignored/../../end/.")
val result = path.removeDotSegments
val expected = Path.decode("/start/path/end/")

assertTrue(result == expected)
},
),
)
}
Loading

0 comments on commit 2e8b486

Please sign in to comment.