Skip to content

Commit

Permalink
Schema based HeaderOps (#3310)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Feb 18, 2025
1 parent 08a2531 commit e75382e
Show file tree
Hide file tree
Showing 14 changed files with 237 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ jobs:
matrix:
os: [ubuntu-latest]
scala: [2.13.16]
java: [temurin@8]
java: [zulu@8]
runs-on: ${{ matrix.os }}
steps:
- uses: coursier/setup-action@v1
Expand Down
2 changes: 2 additions & 0 deletions project/MimaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ object MimaSettings {
exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.apply"),
exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.this"),
exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.copy"),
exclude[IncompatibleMethTypeProblem]("zio.http.Middleware.addHeader"),
exclude[IncompatibleMethTypeProblem]("zio.http.HandlerAspect.addHeader")
),
mimaFailOnProblem := failOnProblem,
)
Expand Down
127 changes: 124 additions & 3 deletions zio-http/jvm/src/test/scala/zio/http/HeaderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@

package zio.http

import zio.NonEmptyChunk
import java.time.Instant
import java.util.UUID

import zio._
import zio.test.Assertion._
import zio.test.assert
import zio.test._

import zio.schema._

object HeaderSpec extends ZIOHttpSpec {

case class SimpleWrapper(a: String)
implicit val simpleWrapperSchema: Schema[SimpleWrapper] = DeriveSchema.gen[SimpleWrapper]
case class Foo(a: Int, b: SimpleWrapper, c: NonEmptyChunk[String], chunk: Chunk[String])
implicit val fooSchema: Schema[Foo] = DeriveSchema.gen[Foo]

def spec = suite("Header")(
suite("getHeader")(
test("should not return header that doesn't exist in list") {
Expand Down Expand Up @@ -83,6 +93,117 @@ object HeaderSpec extends ZIOHttpSpec {
assert(actual)(isFalse)
},
),
suite("add typed")(
test("primitives") {
val uuid = "123e4567-e89b-12d3-a456-426614174000"
assertTrue(
Headers.empty.addHeader("a", 1).rawHeader("a").get == "1",
Headers.empty.addHeader("a", 1.0d).rawHeader("a").get == "1.0",
Headers.empty.addHeader("a", 1.0f).rawHeader("a").get == "1.0",
Headers.empty.addHeader("a", 1L).rawHeader("a").get == "1",
Headers.empty.addHeader("a", 1.toShort).rawHeader("a").get == "1",
Headers.empty.addHeader("a", true).rawHeader("a").get == "true",
Headers.empty.addHeader("a", 'a').rawHeader("a").get == "a",
Headers.empty.addHeader("a", Instant.EPOCH).rawHeader("a").get == "1970-01-01T00:00:00Z",
Headers.empty
.addHeader("a", UUID.fromString(uuid))
.rawHeader("a")
.get == uuid,
)

},
test("collections") {
assertTrue(
// Chunk
Headers.empty.addHeader("a", Chunk.empty[Int]).rawHeader("a").isEmpty,
Headers.empty.addHeader("a", Chunk(1)).rawHeaders("a") == Chunk("1"),
Headers.empty.addHeader("a", Chunk(1, 2)).rawHeaders("a") == Chunk("1", "2"),
Headers.empty.addHeader("a", Chunk(1.0, 2.0)).rawHeaders("a") == Chunk("1.0", "2.0"),
// List
Headers.empty.addHeader("a", List.empty[Int]).rawHeader("a").isEmpty,
Headers.empty.addHeader("a", List(1)).rawHeaders("a") == Chunk("1"),
// NonEmptyChunk
Headers.empty.addHeader("a", NonEmptyChunk(1)).rawHeaders("a") == Chunk("1"),
Headers.empty.addHeader("a", NonEmptyChunk(1, 2)).rawHeaders("a") == Chunk("1", "2"),
)
},
test("case class") {
val foo = Foo(1, SimpleWrapper("foo"), NonEmptyChunk("1", "2"), Chunk("foo", "bar"))
val fooEmpty = Foo(0, SimpleWrapper(""), NonEmptyChunk("1"), Chunk.empty)
assertTrue(
Headers.empty.addHeader(foo).rawHeader("a").get == "1",
Headers.empty.addHeader(foo).rawHeader("b").get == "foo",
Headers.empty.addHeader(foo).rawHeaders("c") == Chunk("1", "2"),
Headers.empty.addHeader(foo).rawHeaders("chunk") == Chunk("foo", "bar"),
Headers.empty.addHeader(fooEmpty).rawHeader("a").get == "0",
Headers.empty.addHeader(fooEmpty).rawHeader("b").get == "",
Headers.empty.addHeader(fooEmpty).rawHeaders("c") == Chunk("1"),
Headers.empty.addHeader(fooEmpty).rawHeaders("chunk").isEmpty,
)
},
),
suite("schema based getters")(
test("pure") {
val typed = "typed"
val default = 3
val invalidTyped = "invalidTyped"
val unknown = "non-existent"
val headers = Headers(typed -> "1", typed -> "2", "invalid-typed" -> "str")
val single = Headers(typed -> "1")
val headersFoo = Headers("a" -> "1", "b" -> "foo", "c" -> "2", "chunk" -> "foo", "chunk" -> "bar")
assertTrue(
single.header[Int](typed) == Right(1),
headers.header[Int](invalidTyped).isLeft,
headers.header[Int](unknown).isLeft,
single.headerOrElse[Int](typed, default) == 1,
headers.headerOrElse[Int](invalidTyped, default) == default,
headers.headerOrElse[Int](unknown, default) == default,
headers.header[Chunk[Int]](typed) == Right(Chunk(1, 2)),
headers.header[Chunk[Int]](invalidTyped).isLeft,
headers.header[Chunk[Int]](unknown) == Right(Chunk.empty),
headers.header[NonEmptyChunk[Int]](unknown).isLeft,
headers.headerOrElse[Chunk[Int]](typed, Chunk(default)) == Chunk(1, 2),
headers.headerOrElse[Chunk[Int]](invalidTyped, Chunk(default)) == Chunk(default),
headers.headerOrElse[Chunk[Int]](unknown, Chunk(default)) == Chunk.empty,
headers.headerOrElse[NonEmptyChunk[Int]](unknown, NonEmptyChunk(default)) == NonEmptyChunk(default),
// case class
headersFoo.header[Foo] == Right(Foo(1, SimpleWrapper("foo"), NonEmptyChunk("2"), Chunk("foo", "bar"))),
headersFoo.header[SimpleWrapper] == Right(SimpleWrapper("1")),
headersFoo.header[SimpleWrapper]("b") == Right(SimpleWrapper("foo")),
headers.header[Foo].isLeft,
headersFoo.headerOrElse[Foo](Foo(0, SimpleWrapper(""), NonEmptyChunk("1"), Chunk.empty)) == Foo(
1,
SimpleWrapper("foo"),
NonEmptyChunk("2"),
Chunk("foo", "bar"),
),
headers.headerOrElse[Foo](Foo(0, SimpleWrapper(""), NonEmptyChunk("1"), Chunk.empty)) == Foo(
0,
SimpleWrapper(""),
NonEmptyChunk("1"),
Chunk.empty,
),
)
},
test("as ZIO") {
val typed = "typed"
val invalidTyped = "invalidTyped"
val unknown = "non-existent"
val headers = Headers(typed -> "1", typed -> "2", "invalid-typed" -> "str")
val single = Headers(typed -> "1")
assertZIO(single.headerZIO[Int](typed))(equalTo(1)) &&
assertZIO(single.headerZIO[Int](unknown).exit)(fails(anything)) &&
assertZIO(single.headerZIO[Chunk[Int]](typed))(hasSize(equalTo(1))) &&
assertZIO(single.headerZIO[Chunk[Int]](unknown).exit)(succeeds(equalTo(Chunk.empty[Int]))) &&
assertZIO(single.headerZIO[NonEmptyChunk[Int]](unknown).exit)(fails(anything)) &&
assertZIO(headers.headerZIO[Int](invalidTyped).exit)(fails(anything)) &&
assertZIO(headers.headerZIO[Int](unknown).exit)(fails(anything)) &&
assertZIO(headers.headerZIO[Chunk[Int]](typed))(hasSize(equalTo(2))) &&
assertZIO(headers.headerZIO[Chunk[Int]](invalidTyped).exit)(fails(anything)) &&
assertZIO(headers.headerZIO[Chunk[Int]](unknown).exit)(succeeds(equalTo(Chunk.empty[Int]))) &&
assertZIO(headers.headerZIO[NonEmptyChunk[Int]](unknown).exit)(fails(anything))
},
),
suite("cookie")(
test("should be able to extract more than one header with the same name") {
val firstCookie = Cookie.Response("first", "value")
Expand All @@ -97,7 +218,7 @@ object HeaderSpec extends ZIOHttpSpec {
)
},
test("should return an empty sequence if no headers in the response") {
val headers = Headers()
val headers = Headers.empty
assert(headers.getAll(Header.SetCookie))(hasSameElements(Seq.empty))
},
),
Expand Down
4 changes: 2 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import zio.stream.ZStream

import zio.http.Handler.ApplyContextAspect
import zio.http.Header.HeaderType
import zio.http.internal.HeaderModifier
import zio.http.internal.{HeaderGetters, HeaderModifier}
import zio.http.template._

sealed trait Handler[-R, +Err, -In, +Out] { self =>
Expand Down Expand Up @@ -1139,7 +1139,7 @@ object Handler extends HandlerPlatformSpecific with HandlerVersionSpecific {
* Updates the current Headers with new one, using the provided update
* function passed.
*/
override def updateHeaders(update: Headers => Headers)(implicit trace: Trace): RequestHandler[R, Err] =
def updateHeaders(update: Headers => Headers)(implicit trace: Trace): RequestHandler[R, Err] =
self.map(_.updateHeaders(update))
}

Expand Down
3 changes: 3 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/HandlerAspect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ final case class HandlerAspect[-Env, +CtxOut](
}
object HandlerAspect extends HandlerAspects {

final protected override def addHeader(name: CharSequence, value: CharSequence): HandlerAspect[Any, Unit] =
HandlerAspect.addHeader[String](name.toString, value.toString)

final class InterceptPatch[State](val fromRequest: Request => State) extends AnyVal {
def apply(result: (Response, State) => Response.Patch): HandlerAspect[Any, Unit] =
HandlerAspect.interceptHandlerStateful(
Expand Down
11 changes: 7 additions & 4 deletions zio-http/shared/src/main/scala/zio/http/Header.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object Header {
override def headerType: HeaderType.Typed[Custom] = new Header.HeaderType {
override type HeaderValue = Custom

override def name: String = self.customName.toString
override def name: String = self.customName.toString.toLowerCase

override def parse(value: String): Either[String, HeaderValue] = Right(Custom(self.customName, value))

Expand Down Expand Up @@ -228,14 +228,17 @@ object Header {
override def equals(that: Any): Boolean = {
that match {
case Custom(k, v) =>
def eqs(l: CharSequence, r: CharSequence): Boolean = {
def eqs(l: CharSequence, r: CharSequence, caseSensitive: Boolean): Boolean = {
if (l.length() != r.length()) false
else {
var i = 0
var equal = true

while (i < l.length()) {
if (l.charAt(i) != r.charAt(i)) {
if (
(caseSensitive && l.charAt(i) != r
.charAt(i)) || (!caseSensitive && l.charAt(i).toLower != r.charAt(i).toLower)
) {
equal = false
i = l.length()
}
Expand All @@ -245,7 +248,7 @@ object Header {
}
}

eqs(self.customName, k) && eqs(self.value, v)
eqs(self.customName, k, caseSensitive = false) && eqs(self.value, v, caseSensitive = true)

case _ => false
}
Expand Down
3 changes: 3 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/Headers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ object Headers {

def apply(tuple2: (CharSequence, CharSequence)): Headers = apply(tuple2._1, tuple2._2)

def apply(value: (CharSequence, CharSequence), values: (CharSequence, CharSequence)*): Headers =
Headers.FromIterable((value +: values).map { case (k, v) => Header.Custom(k, v) })

def apply(headers: Header*): Headers = FromIterable(headers)

def apply(iter: Iterable[Header]): Headers = FromIterable(iter)
Expand Down
5 changes: 4 additions & 1 deletion zio-http/shared/src/main/scala/zio/http/Middleware.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ trait Middleware[-UpperEnv] { self =>
@nowarn("msg=shadows type")
object Middleware extends HandlerAspects {

final protected override def addHeader(name: CharSequence, value: CharSequence): HandlerAspect[Any, Unit] =
HandlerAspect.addHeader[String](name.toString, value.toString)

/**
* Configuration for the CORS aspect.
*/
Expand Down Expand Up @@ -191,7 +194,7 @@ object Middleware extends HandlerAspects {
routes.transform[Env1] { h =>
handler { (req: Request) =>
if (req.headers.contains(headerName)) h(req)
else h(req.addHeader(headerName, make))
else h(req.addHeader[String](headerName, make))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion zio-http/shared/src/main/scala/zio/http/Response.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ final case class Response(
}

def contentType(mediaType: MediaType): Response =
self.addHeader("content-type", mediaType.fullType)
self.addHeader[String]("content-type", mediaType.fullType)

/**
* Consumes the streaming body fully and then discards it while also ignoring
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package zio.http.internal

import zio.Chunk
import zio._

import zio.schema.Schema

import zio.http.Header.HeaderType
import zio.http.Headers
import zio.http.codec.HttpCodecError

/**
* Maintains a list of operators that parse and extract data from the headers.
Expand All @@ -39,6 +42,65 @@ trait HeaderGetters { self =>
parsed.toOption
}

/**
* Retrieves the header with the specified name as a value of the specified
* type. The type must have a schema and can be a primitive type (e.g. Int,
* String, UUID, Instant etc.), a case class with a single field or a
* collection of either of these.
*/
final def header[T](name: String)(implicit schema: Schema[T]): Either[HttpCodecError.HeaderError, T] =
try
Right(
StringSchemaCodec
.headerFromSchema(schema, ErrorConstructor.header, name)
.decode(headers),
)
catch {
case e: HttpCodecError.HeaderError => Left(e)
}

/**
* Retrieves headers as a value of the specified type. The type must have a
* schema and be a case class and all fields must be headers. So fields must
* be of primitive types (e.g. Int, String, UUID, Instant etc.), a case class
* with a single field or a collection of either of these. Headers are
* selected by field names.
*/
final def header[T](implicit schema: Schema[T]): Either[HttpCodecError.HeaderError, T] =
try
Right(
StringSchemaCodec
.headerFromSchema(schema, ErrorConstructor.header, null)
.decode(headers),
)
catch {
case e: HttpCodecError.HeaderError => Left(e)
}

/**
* Retrieves the header with the specified name as a value of the specified
* type T, or returns a default value if the header is not present or could
* not be parsed. The type T must have a schema and can be a primitive type
* (e.g. Int, String, UUID, Instant etc.), a case class with a single field or
* a collection of either of these.
*/
final def headerOrElse[T](name: String, default: => T)(implicit schema: Schema[T]): T =
header[T](name).getOrElse(default)

/**
* Retrieves headers as a value of the specified type T, or returns a default
* value if the headers are not present or could not be parsed. The type T
* must have a schema and be a case class and all fields must be headers. So
* fields must be of primitive types (e.g. Int, String, UUID, Instant etc.), a
* case class with a single field or a collection of either of these. Headers
* are selected by field names.
*/
final def headerOrElse[T](default: => T)(implicit schema: Schema[T]): T =
header[T].getOrElse(default)

final def headerZIO[T](name: String)(implicit schema: Schema[T]): IO[HttpCodecError.HeaderError, T] =
ZIO.fromEither(header[T](name))

final def headers(headerType: HeaderType): Chunk[headerType.HeaderValue] =
Chunk.fromIterator(
headers.iterator
Expand Down
Loading

0 comments on commit e75382e

Please sign in to comment.