diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index a5bc1938ba..255f4ce6f4 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -1,15 +1,7 @@ package zio.http.endpoint.cli -import scala.util.Try - -import zio.cli._ - -import zio.schema._ - import zio.http._ -import zio.http.codec.HttpCodec.Metadata import zio.http.codec._ -import zio.http.codec.internal._ import zio.http.endpoint._ /** @@ -133,10 +125,10 @@ private[cli] object CliEndpoint { case HttpCodec.Path(pathCodec, _) => CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) - case HttpCodec.Query(name, textCodec, _) => - textCodec.asInstanceOf[TextCodec[_]] match { - case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(name, value) :: List()) - case _ => CliEndpoint(url = HttpOptions.Query(name, textCodec) :: List()) + case query: HttpCodec.Query[Input, ?] => + query.textCodec match { + case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(query.name, value) :: List()) + case _ => CliEndpoint(url = HttpOptions.Query(query.name, query.textCodec) :: List()) } case HttpCodec.Status(_, _) => CliEndpoint.empty diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index 5c7ba70f94..ff50eca5d9 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -91,9 +91,9 @@ object EndpointGen { } lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = - Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => + Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).zip(Gen.boolean).map { case (name, codec, isMono) => CliRepr( - HttpCodec.Query(name, codec), + if (isMono) HttpCodec.MonoQuery(name, codec) else HttpCodec.MultiQuery(name, codec), codec match { case TextCodec.Constant(value) => CliEndpoint(url = HttpOptions.QueryConstant(name, value) :: Nil) case _ => CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil) diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala index f80665ac30..4dd554ec08 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala @@ -574,14 +574,49 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } - private[http] final case class Query[A](name: String, textCodec: TextCodec[A], index: Int = 0) - extends Atom[HttpCodecType.Query, A] { - self => - def erase: Query[Any] = self.asInstanceOf[Query[Any]] + + private[http] sealed trait Query[A, I] extends Atom[HttpCodecType.Query, A] { + def erase: Query[Any, I] = asInstanceOf[Query[Any, I]] + + def name: String + + def textCodec: TextCodec[I] + + def index: Int def tag: AtomTag = AtomTag.Query - def index(index: Int): Query[A] = copy(index = index) + def index(index: Int): Query[A, I] + + def encode(value: A): Chunk[String] + + def decode(values: Chunk[String]): A + + @inline final private[HttpCodec] def decodeItem(value: String): I = + if (textCodec.isDefinedAt(value)) textCodec(value) + else throw HttpCodecError.MalformedQueryParam(name, textCodec) + } + + private[http] final case class MonoQuery[A](name: String, textCodec: TextCodec[A], index: Int = 0) + extends Query[A, A] { + def index(index: Int): Query[A, A] = copy(index = index) + + def encode(value: A): Chunk[String] = Chunk(textCodec.encode(value)) + + def decode(values: Chunk[String]): A = values match { + case Chunk(value) => decodeItem(value) + case empty if empty.isEmpty => throw HttpCodecError.MissingQueryParam(name) + case _ => throw HttpCodecError.SingleQueryParamValueExpected(name) + } + } + + private[http] final case class MultiQuery[I](name: String, textCodec: TextCodec[I], index: Int = 0) + extends Query[Chunk[I], I] { + def index(index: Int): Query[Chunk[I], I] = copy(index = index) + + def encode(value: Chunk[I]): Chunk[String] = value map textCodec.encode + + def decode(values: Chunk[String]): Chunk[I] = values map decodeItem } private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala index 8c486cc668..822b363821 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -51,6 +51,9 @@ object HttpCodecError { final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { def message = s"Missing query parameter $queryParamName" } + final case class SingleQueryParamValueExpected(queryParamName: String) extends HttpCodecError { + def message = s"Single query parameter $queryParamName value expected, but multiple values are found" + } final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec" } diff --git a/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala index 5bf72e57e3..4d73bb0026 100644 --- a/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -15,30 +15,36 @@ */ package zio.http.codec +import zio.Chunk import zio.stacktracer.TracingImplicits.disableAutoTrace private[codec] trait QueryCodecs { def query(name: String): QueryCodec[String] = - HttpCodec.Query(name, TextCodec.string) + HttpCodec.MonoQuery(name, TextCodec.string) def queryBool(name: String): QueryCodec[Boolean] = - HttpCodec.Query(name, TextCodec.boolean) + HttpCodec.MonoQuery(name, TextCodec.boolean) def queryInt(name: String): QueryCodec[Int] = - HttpCodec.Query(name, TextCodec.int) + HttpCodec.MonoQuery(name, TextCodec.int) def queryAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = - HttpCodec.Query(name, codec) + HttpCodec.MonoQuery(name, codec) + + def queries[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] = + HttpCodec.MultiQuery(name, codec) def paramStr(name: String): QueryCodec[String] = - HttpCodec.Query(name, TextCodec.string) + HttpCodec.MonoQuery(name, TextCodec.string) def paramBool(name: String): QueryCodec[Boolean] = - HttpCodec.Query(name, TextCodec.boolean) + HttpCodec.MonoQuery(name, TextCodec.boolean) def paramInt(name: String): QueryCodec[Int] = - HttpCodec.Query(name, TextCodec.int) + HttpCodec.MonoQuery(name, TextCodec.int) def paramAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = - HttpCodec.Query(name, codec) + HttpCodec.MonoQuery(name, codec) + def params[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] = + HttpCodec.MultiQuery(name, codec) } diff --git a/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index 4c18c8e466..b84fbbd9b8 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -25,7 +25,7 @@ import zio.http.codec._ private[http] final case class AtomizedCodecs( method: Chunk[SimpleCodec[zio.http.Method, _]], path: Chunk[PathCodec[_]], - query: Chunk[Query[_]], + query: Chunk[Query[_, _]], header: Chunk[Header[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], @@ -33,7 +33,7 @@ private[http] final case class AtomizedCodecs( def append(atom: Atom[_, _]): AtomizedCodecs = atom match { case path0: Path[_] => self.copy(path = path :+ path0.pathCodec) case method0: Method[_] => self.copy(method = method :+ method0.codec) - case query0: Query[_] => self.copy(query = query :+ query0) + case query0: Query[_, _] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.schema, content0.mediaType, content0.name)) diff --git a/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index 720ec51e4a..99c3970110 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -279,20 +279,8 @@ private[codec] object EncoderDecoder { var i = 0 val queries = flattened.query while (i < queries.length) { - val query = queries(i).erase - - val queryParamValue = - queryParams - .getAllOrElse(query.name, Nil) - .collectFirst(query.textCodec) - - queryParamValue match { - case Some(value) => - inputs(i) = value - case None => - throw HttpCodecError.MissingQueryParam(query.name) - } - + val query = queries(i) + inputs(i) = query.decode(queryParams.getAllOrElse(query.name, Nil)) i = i + 1 } } @@ -478,9 +466,7 @@ private[codec] object EncoderDecoder { val query = flattened.query(i).erase val input = inputs(i) - val value = query.textCodec.encode(input) - - queryParams = queryParams.add(query.name, value) + queryParams = queryParams.addAll(query.name, query.encode(input)) i = i + 1 } diff --git a/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala b/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala index 2f4c5aabb3..6d6bbbc7f6 100644 --- a/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala +++ b/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala @@ -37,9 +37,21 @@ object HttpCodecSpec extends ZIOHttpSpec { val emptyJson = Body.fromString("{}") - val isAge = "isAge" - val codecBool = QueryCodec.paramBool(isAge) - def makeRequest(paramValue: String) = Request.get(googleUrl.queryParams(QueryParams(isAge -> paramValue))) + val strParam = "name" + val codecStr = QueryCodec.paramStr(strParam) + val boolParam = "isAge" + val codecBool = QueryCodec.paramBool(boolParam) + val intParam = "age" + val codecInt = QueryCodec.paramInt(intParam) + val longParam = "count" + val codecLong = QueryCodec.paramAs[Long](longParam) + val seqIntParam = "integers" + val codecSeqInt = QueryCodec.params[Int](seqIntParam) + + def makeRequest(name: String, value: Any) = + Request.get(googleUrl.queryParams(QueryParams(name -> value.toString))) + def makeChunkRequest(name: String, values: Chunk[Any]) = + Request.get(googleUrl.queryParams(QueryParams(name -> values.map(_.toString)))) def spec = suite("HttpCodecSpec")( suite("fallback") { @@ -120,25 +132,61 @@ object HttpCodecSpec extends ZIOHttpSpec { } } + suite("QueryCodec")( - test("paramBool decoding with case-insensitive") { - assertZIO(codecBool.decodeRequest(makeRequest("true")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("TRUE")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("yes")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("YES")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("on")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("ON")))(Assertion.isTrue) + test("paramStr decoding and encoding") { + check(Gen.alphaNumericString) { value => + assertZIO(codecStr.decodeRequest(makeRequest(strParam, value)))(Assertion.equalTo(value)) && + assert(codecStr.encodeRequest(value).url.queryParams.get(strParam))( + Assertion.isSome(Assertion.equalTo(value)), + ) + } }, - test("paramBool decoding with different values") { - assertZIO(codecBool.decodeRequest(makeRequest("true")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("1")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("yes")))(Assertion.isTrue) && - assertZIO(codecBool.decodeRequest(makeRequest("on")))(Assertion.isTrue) + test("paramBool decoding true") { + Chunk("true", "TRUE", "yes", "YES", "on", "ON", "1") map { value => + assertZIO(codecBool.decodeRequest(makeRequest(boolParam, value)))(Assertion.isTrue) + } reduce (_ && _) + }, + test("paramBool decoding false") { + Chunk("false", "FALSE", "no", "NO", "off", "OFF", "0") map { value => + assertZIO(codecBool.decodeRequest(makeRequest(boolParam, value)))(Assertion.isFalse) + } reduce (_ && _) }, test("paramBool encoding") { val requestTrue = codecBool.encodeRequest(true) val requestFalse = codecBool.encodeRequest(false) - assert(requestTrue.url.queryParams.get(isAge).get)(Assertion.equalTo("true")) && - assert(requestFalse.url.queryParams.get(isAge).get)(Assertion.equalTo("false")) + assert(requestTrue.url.queryParams.get(boolParam).get)(Assertion.equalTo("true")) && + assert(requestFalse.url.queryParams.get(boolParam).get)(Assertion.equalTo("false")) + }, + test("paramInt decoding and encoding") { + check(Gen.int) { value => + assertZIO(codecInt.decodeRequest(makeRequest(intParam, value)))(Assertion.equalTo(value)) && + assert(codecInt.encodeRequest(value).url.queryParams.get(intParam))( + Assertion.isSome(Assertion.equalTo(value.toString)), + ) + } + }, + test("paramLong decoding and encoding") { + check(Gen.long) { value => + assertZIO(codecLong.decodeRequest(makeRequest(longParam, value)))(Assertion.equalTo(value)) && + assert(codecLong.encodeRequest(value).url.queryParams.get(longParam))( + Assertion.isSome(Assertion.equalTo(value.toString)), + ) + } + }, + test("paramSeq decoding with empty chunk") { + assertZIO(codecSeqInt.decodeRequest(makeChunkRequest(seqIntParam, Chunk.empty)))(Assertion.isEmpty) + }, + test("paramSeq decoding with non-empty chunk") { + assertZIO(codecSeqInt.decodeRequest(makeChunkRequest(seqIntParam, Chunk("2023", "10", "7"))))( + Assertion.equalTo(Chunk(2023, 10, 7)), + ) + }, + test("paramSeq encoding with empty chunk") { + assert(codecSeqInt.encodeRequest(Chunk.empty).url.queryParams.get(seqIntParam))(Assertion.isNone) + }, + test("paramSeq encoding with non-empty chunk") { + assert(codecSeqInt.encodeRequest(Chunk(1974, 5, 3)).url.queryParams.getAll(seqIntParam).get)( + Assertion.equalTo(Chunk("1974", "5", "3")), + ) }, ) + suite("Codec with examples") { diff --git a/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala index 165a1adbec..f7cba19f65 100644 --- a/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala +++ b/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala @@ -30,7 +30,7 @@ import zio.schema.{DeriveSchema, Schema} import zio.http.Header.ContentType import zio.http.Method._ import zio.http._ -import zio.http.codec.HttpCodec.{query, queryInt} +import zio.http.codec.HttpCodec.{queries, query, queryAs, queryInt} import zio.http.codec._ import zio.http.endpoint.EndpointSpec.testEndpoint import zio.http.forms.Fixtures.formField @@ -105,5 +105,25 @@ object QueryParameterSpec extends ZIOHttpSpec { testRoutes(s"/users/$userId?key=$key&value=$value", s"path(users, $userId, Some($key), Some($value))") } }, + test("query parameter with multiple values") { + check(Gen.boolean, Gen.alphaNumericString, Gen.alphaNumericString) { (isSomething, name1, name2) => + val testRoutes = testEndpoint( + Routes( + Endpoint(GET / "data") + .query(queryAs[Boolean]("isSomething")) + .query(queries[String]("name")) + .out[String] + .implement { + Handler.fromFunction { case (isSomething, names) => + s"query($isSomething, ${names mkString ", "})" + } + }, + ), + ) _ + testRoutes(s"/data?isSomething=$isSomething", s"query($isSomething, )") && + testRoutes(s"/data?isSomething=$isSomething&name=$name1", s"query($isSomething, $name1)") && + testRoutes(s"/data?isSomething=$isSomething&name=$name1&name=$name2", s"query($isSomething, $name1, $name2)") + } + }, ) } diff --git a/zio-http/src/test/scala/zio/http/endpoint/RequestSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/RequestSpec.scala index c73d2048dc..fa495e665e 100644 --- a/zio-http/src/test/scala/zio/http/endpoint/RequestSpec.scala +++ b/zio-http/src/test/scala/zio/http/endpoint/RequestSpec.scala @@ -134,6 +134,25 @@ object RequestSpec extends ZIOHttpSpec { assertTrue(contentType.isEmpty) } }, + test("multiple parameters for MonoQuery") { + check(Gen.int, Gen.int, Gen.int) { (id, id1, id2) => + val endpoint = + Endpoint(GET / "posts") + .query(queryInt("id")) + .out[Int] + val routes = + endpoint.implement { + Handler.succeed(id) + } + for { + response <- routes.toHttpApp.runZIO( + Request.get(URL.decode(s"/posts?id=$id1&id=$id2").toOption.get), + ) + contentType = response.header(Header.ContentType) + } yield assertTrue(extractStatus(response).code == 400) && + assertTrue(contentType.isEmpty) + } + }, test("header codec") { check(Gen.int, Gen.alphaNumericString) { (id, notACorrelationId) => val endpoint =