Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalise path templates when generating ZIO Http RoutePattern-s #4051

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,25 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
basicRequest.get(uri"$baseUri/p1/x/p3").send(backend).map(_.body shouldBe Right("2: x")) >>
basicRequest.get(uri"$baseUri/p1/y/p3").send(backend).map(_.body shouldBe Right("2: y")) >>
basicRequest.get(uri"$baseUri/p1/p2/p4").send(backend).map(_.code shouldBe StatusCode.NotFound)
},
// #4050
testServer(
"two endpoints with fixed path & path capture as the middle component, different methods",
NonEmptyList.of(
route(
List[ServerEndpoint[Any, F]](
endpoint.get.in("p1" / "p2").out(stringBody).serverLogic(_ => pureResult("1".asRight[Unit])),
endpoint.delete.in("p1" / path[String]("p")).out(stringBody).serverLogic((v: String) => pureResult(s"2: $v".asRight[Unit]))
)
)
)
) { (backend, baseUri) =>
basicRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe Right("1")) >>
basicRequest.get(uri"$baseUri/p1/x").send(backend).map(_.code shouldBe StatusCode.MethodNotAllowed) >>
basicRequest.delete(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe Right("2: p2")) >>
basicRequest.delete(uri"$baseUri/p1/p3").send(backend).map(_.body shouldBe Right("2: p3")) >>
basicRequest.get(uri"$baseUri/p1/p2/p3").send(backend).map(_.code shouldBe StatusCode.NotFound) >>
basicRequest.delete(uri"$baseUri/p1/p2/p3").send(backend).map(_.code shouldBe StatusCode.NotFound)
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import sttp.tapir.ztapir._
import zio._
import zio.http.codec.PathCodec
import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _}
import scala.util.chaining._

trait ZioHttpInterpreter[R] {
def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default
Expand Down Expand Up @@ -61,8 +62,8 @@ trait ZioHttpInterpreter[R] {
// here we'll keep the endpoint together with the meta-data needed to create the zio-http routing information
case class ServerEndpointWithPattern(
index: Int,
pathTemplate: String,
routePattern: RoutePattern[_],
pathTemplate: Vector[String],
routePattern: RoutePattern[Any], // the Any here is a way to work around the type checker
endpoint: ZServerEndpoint[R & R2, ZioStreams with WebSockets]
)

Expand All @@ -72,13 +73,13 @@ trait ZioHttpInterpreter[R] {

// Creating the path template - no-trailing-slash inputs are treated as wildcard inputs, as they are usually
// accompanied by endpoints which handle wildcard path inputs, when the `/` is present (to serve files). They
// need to end up in the same group (see below), so that they are disambiguated by tapir's logic.
val pathTemplate = inputs.foldLeft("") { case (p, component) =>
// need to end up in the same group (see below), so that they are disambiguated by Tapir's logic.
val pathTemplate = inputs.foldLeft(Vector.empty[String]) { case (p, component) =>
component match {
case _: EndpointInput.PathCapture[_] => p + "/?"
case _: EndpointInput.PathsCapture[_] => p + "/..."
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p + "/..."
case i: EndpointInput.FixedPath[_] => p + "/" + i.s
case _: EndpointInput.PathCapture[_] => p :+ "?"
case _: EndpointInput.PathsCapture[_] => p :+ "..."
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p :+ "..."
case i: EndpointInput.FixedPath[_] => p :+ s"{${i.s}}"
case _ => p
}
}
Expand All @@ -94,7 +95,7 @@ trait ZioHttpInterpreter[R] {
case _ => false
}

val routePattern = if (hasPath) {
val routePattern: RoutePattern[Any] = if (hasPath) {
val initialPattern = RoutePattern(Method.ANY, PathCodec.empty).asInstanceOf[RoutePattern[Any]]
// The second tuple parameter specifies if PathCodec.trailing should be added to the route's pattern. It can
// be added either because of a PathsCapture, or because of an noTrailingSlash input.
Expand All @@ -109,7 +110,7 @@ trait ZioHttpInterpreter[R] {
}
}

if (addTrailing) p / PathCodec.trailing else p
if (addTrailing) (p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]] else p
} else {
// if there are no path inputs, we return a catch-all
RoutePattern(Method.ANY, PathCodec.trailing).asInstanceOf[RoutePattern[Any]]
Expand All @@ -118,21 +119,59 @@ trait ZioHttpInterpreter[R] {
ServerEndpointWithPattern(index, pathTemplate, routePattern, se)
}

// Grouping the endpoints by path template. This way, if there are multiple endpoints with/without trailing slash or
// with path wildcards, they will end up in the same group, and they will be disambiguated by the tapir logic.
// That's because there's not way currently to create a zio-http route pattern which would match on
// no-trailing-slashes. A group also includes multiple endpoints with different methods, but same path.
val widenedSesGroupedByPathPrefixTemplate = widenedSes.zipWithIndex
.map { case (se, index) => toPattern(se, index) }
.groupBy(_.pathTemplate)
.toList
.map(_._2)
// we try to maintain the order of endpoints as passed by the user; this order might be changed if there are
// endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent
// values in the original `ses` list
.sortBy(_.map(_.index).min)

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.map { sesWithPattern =>
/** `t1` and `t2` are both path templates as created by `toPattern` above. Each path template is a vector of: ? | ... | {string}. This
* method checks if `t1` is at least as general as `t2`, that is if each request that matches `t2` also matches `t1`
*/
def isAtLeastAsGeneralAs(t1: Vector[String], t2: Vector[String]): Boolean = (t1, t2) match {
case ("..." +: _, _) => true
case (_, "..." +: _) => false
case ("?" +: tail1, "?" +: tail2) => isAtLeastAsGeneralAs(tail1, tail2)
case ("?" +: tail1, _ +: tail2) => isAtLeastAsGeneralAs(tail1, tail2)
case (_ +: _, "?" +: _) => false
case (p1 +: tail1, p2 +: tail2) => (p1 == p2) && isAtLeastAsGeneralAs(tail1, tail2)
case (Vector(), Vector()) => true
case _ => false
}

/** For each server endpoint, find the most general template among all the templates in the list, and use it for the endpoint, along
* with the `RoutePattern` corresponding to that template.
*/
def generaliseTemplates(endpoints: List[ServerEndpointWithPattern]): List[ServerEndpointWithPattern] = {
// de-duplicating the path templates
val allTemplates: List[(Vector[String], RoutePattern[Any])] = endpoints.map(se => (se.pathTemplate, se.routePattern)).toMap.toList
endpoints.map { se =>
val mostGeneral: (Vector[String], RoutePattern[Any]) =
allTemplates.foldLeft((se.pathTemplate, se.routePattern)) {
case ((mostGeneralTemplate, mostGeneralPattern), (template, pattern)) =>
if (template != mostGeneralTemplate && isAtLeastAsGeneralAs(template, mostGeneralTemplate)) {
(template, pattern)
} else {
(mostGeneralTemplate, mostGeneralPattern)
}
}
se.copy(pathTemplate = mostGeneral._1, routePattern = mostGeneral._2)
}
}

// Generating a path tempalte for each endpoint, and then finding the most general template among all of the
// endpoints. Once this is done, grouping the endpoints by path template. This way, if there are multiple endpoints
// with/without trailing slash or with path wildcards, they will end up in the same group, and they will be
// disambiguated by the Tapir logic. That's because there's no way currently to create a zio-http route pattern
// which would match on no-trailing-slashes. A group also includes multiple endpoints with different methods, but
// same path.
val widenedSesGroupedByPathTemplate =
widenedSes.zipWithIndex
.map { case (se, index) => toPattern(se, index) }
.pipe(generaliseTemplates)
.groupBy(_.pathTemplate)
.toList
.map(_._2)
// we try to maintain the order of endpoints as passed by the user; this order might be changed if there are
// endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent
// values in the original `ses` list
.sortBy(_.map(_.index).min)

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathTemplate.map { sesWithPattern =>
val pattern = sesWithPattern.head.routePattern
val endpoints = sesWithPattern.sortBy(_.index).map(_.endpoint)
// The pattern that we generate should be the same for all endpoints in a group
Expand Down
Loading