From 1ab45837fcbfda6a79b0710f922784c6e91e643a Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Mon, 17 Jun 2024 15:25:29 +1000 Subject: [PATCH 1/3] Minor adapter optimizations and code readability improvements --- .../scala/caliban/QuickRequestHandler.scala | 30 ++++++++------ .../main/scala/caliban/GraphQLResponse.scala | 39 ++++++++++--------- core/src/main/scala/caliban/HttpUtils.scala | 9 ++--- .../caliban/interop/tapir/TapirAdapter.scala | 20 +++++----- 4 files changed, 53 insertions(+), 45 deletions(-) diff --git a/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala b/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala index 8426b1a01..554148102 100644 --- a/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala +++ b/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala @@ -44,8 +44,10 @@ final private class QuickRequestHandler[R]( def handleHttpRequest(request: Request)(implicit trace: Trace): URIO[R, Response] = transformHttpRequest(request) .flatMap(executeRequest(request.method, _)) - .map(transformResponse(request, _)) - .merge + .foldZIO( + Exit.succeed, + resp => Exit.succeed(transformResponse(request, resp)) + ) def handleUploadRequest(request: Request)(implicit trace: Trace): URIO[R, Response] = transformUploadRequest(request).flatMap { case (req, fileHandle) => @@ -93,10 +95,10 @@ final private class QuickRequestHandler[R]( def decodeJson(): ZIO[Any, Response, GraphQLRequest] = body.asArray.foldZIO( - _ => ZIO.fail(BodyDecodeErrorResponse), + _ => Exit.fail(BodyDecodeErrorResponse), arr => try checkNonEmptyRequest(readFromArray[GraphQLRequest](arr)) - catch { case NonFatal(_) => ZIO.fail(BodyDecodeErrorResponse) } + catch { case NonFatal(_) => Exit.fail(BodyDecodeErrorResponse) } ) val isApplicationGql = @@ -111,7 +113,7 @@ final private class QuickRequestHandler[R]( val queryParams = httpReq.url.queryParams if ((httpReq.method eq Method.GET) || queryParams.hasQueryParam("query")) { - decodeQueryParams(queryParams).fold(ZIO.fail(_), checkNonEmptyRequest) + decodeQueryParams(queryParams).fold(Exit.fail, checkNonEmptyRequest) } else { val req = decodeBody(httpReq.body) if (isFtv1Request(httpReq)) req.map(_.withFederatedTracing) @@ -168,11 +170,14 @@ final private class QuickRequestHandler[R]( } private def responseHeaders(headers: Headers, cacheDirective: Option[String]): Headers = - cacheDirective.fold(headers)(headers.addHeader(Header.CacheControl.name, _)) + cacheDirective match { + case None => headers + case Some(h) => headers.addHeader(Header.CacheControl.name, h) + } private def transformResponse(httpReq: Request, resp: GraphQLResponse[Any])(implicit trace: Trace): Response = { val accepts = new HttpUtils.AcceptsGqlEncodings(httpReq.headers.get(Header.Accept.name)) - val cacheDirective = HttpUtils.computeCacheDirective(resp.extensions) + val cacheDirective = resp.extensions.flatMap(HttpUtils.computeCacheDirective) resp match { case resp @ GraphQLResponse(StreamValue(stream), _, _, _) => @@ -184,9 +189,10 @@ final private class QuickRequestHandler[R]( case resp if accepts.serverSentEvents => Response.fromServerSentEvents(encodeTextEventStream(resp)) case resp if accepts.graphQLJson => - val isBadRequest = resp.errors.collectFirst { + val isBadRequest = resp.errors.exists { case _: CalibanError.ParsingError | _: CalibanError.ValidationError => true - }.getOrElse(false) + case _ => false + } Response( status = if (isBadRequest) Status.BadRequest else Status.Ok, headers = responseHeaders(ContentTypeGql, cacheDirective), @@ -194,9 +200,9 @@ final private class QuickRequestHandler[R]( encodeSingleResponse(resp, keepDataOnErrors = !isBadRequest, hasCacheDirective = cacheDirective.isDefined) ) case resp => + val isBadRequest = resp.errors.contains(HttpRequestMethod.MutationOverGetError) Response( - status = resp.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => Status.BadRequest } - .getOrElse(Status.Ok), + status = if (isBadRequest) Status.BadRequest else Status.Ok, headers = responseHeaders(ContentTypeJson, cacheDirective), body = encodeSingleResponse(resp, keepDataOnErrors = true, hasCacheDirective = cacheDirective.isDefined) ) @@ -298,5 +304,7 @@ object QuickRequestHandler { null.asInstanceOf[InputValue.ObjectValue] } + private implicit val responseCodec: JsonValueCodec[ResponseValue] = ValueJsoniter.responseValueCodec + private implicit val stringListCodec: JsonValueCodec[Map[String, List[String]]] = JsonCodecMaker.make } diff --git a/core/src/main/scala/caliban/GraphQLResponse.scala b/core/src/main/scala/caliban/GraphQLResponse.scala index f342c4545..9a7844f5f 100644 --- a/core/src/main/scala/caliban/GraphQLResponse.scala +++ b/core/src/main/scala/caliban/GraphQLResponse.scala @@ -19,24 +19,27 @@ case class GraphQLResponse[+E]( ) { def toResponseValue: ResponseValue = toResponseValue(keepDataOnErrors = true) - def toResponseValue(keepDataOnErrors: Boolean, excludeExtensions: Option[Set[String]] = None): ResponseValue = { - val hasErrors = errors.nonEmpty - ObjectValue( - List( - "data" -> (if (!hasErrors || keepDataOnErrors) Some(data) else None), - "errors" -> (if (hasErrors) - Some(ListValue(errors.map { - case e: CalibanError => e.toResponseValue - case e => ObjectValue(List("message" -> StringValue(e.toString))) - })) - else None), - "extensions" -> excludeExtensions.fold(extensions)(excl => - extensions.map(obj => ObjectValue(obj.fields.filterNot(f => excl.contains(f._1)))) - ), - "hasNext" -> hasNext.map(BooleanValue.apply) - ).collect { case (name, Some(v)) => name -> v } - ) - } + def toResponseValue(keepDataOnErrors: Boolean, excludeExtensions: Option[Set[String]] = None): ResponseValue = + if (errors.isEmpty && extensions.isEmpty && hasNext.isEmpty) { + ObjectValue(("data", data) :: Nil) + } else { + val hasErrors = errors.nonEmpty + ObjectValue( + List( + "data" -> (if (!hasErrors || keepDataOnErrors) Some(data) else None), + "errors" -> (if (hasErrors) + Some(ListValue(errors.map { + case e: CalibanError => e.toResponseValue + case e => ObjectValue(List("message" -> StringValue(e.toString))) + })) + else None), + "extensions" -> excludeExtensions.fold(extensions)(excl => + extensions.map(obj => ObjectValue(obj.fields.filterNot(f => excl.contains(f._1)))) + ), + "hasNext" -> hasNext.map(BooleanValue.apply) + ).collect { case (name, Some(v)) => name -> v } + ) + } def withExtension(key: String, value: ResponseValue): GraphQLResponse[E] = copy(extensions = Some(ObjectValue(extensions.foldLeft(List(key -> value)) { case (value, ObjectValue(fields)) => diff --git a/core/src/main/scala/caliban/HttpUtils.scala b/core/src/main/scala/caliban/HttpUtils.scala index 5fead8d84..9c19116c0 100644 --- a/core/src/main/scala/caliban/HttpUtils.scala +++ b/core/src/main/scala/caliban/HttpUtils.scala @@ -60,11 +60,10 @@ private[caliban] object HttpUtils { }).map(v => toSse(v.toResponseValue)) ++ ZStream.succeed(done) } - def computeCacheDirective(extensions: Option[ResponseValue.ObjectValue]): Option[String] = - extensions - .flatMap(_.fields.collectFirst { case (Caching.DirectiveName, ResponseValue.ObjectValue(fields)) => - fields.collectFirst { case ("httpHeader", Value.StringValue(cacheHeader)) => cacheHeader } - }.flatten) + def computeCacheDirective(extensions: ResponseValue.ObjectValue): Option[String] = + extensions.fields.collectFirst { case (Caching.DirectiveName, ResponseValue.ObjectValue(fields)) => + fields.collectFirst { case ("httpHeader", Value.StringValue(cacheHeader)) => cacheHeader } + }.flatten final class AcceptsGqlEncodings(header0: Option[String]) { private val isEmpty = header0.isEmpty diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala index 4afc9fb6a..7a4d48783 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala @@ -105,7 +105,8 @@ object TapirAdapter { streamConstructor: StreamConstructor[BS], responseCodec: JsonCodec[ResponseValue] ): (MediaType, StatusCode, Option[String], CalibanBody[BS]) = { - val accepts = new HttpUtils.AcceptsGqlEncodings(request.header(HeaderNames.Accept)) + val accepts = new HttpUtils.AcceptsGqlEncodings(request.header(HeaderNames.Accept)) + val cacheDirective = response.extensions.flatMap(HttpUtils.computeCacheDirective) response match { case resp @ GraphQLResponse(StreamValue(stream), _, _, _) => @@ -116,15 +117,14 @@ object TapirAdapter { encodeMultipartMixedResponse(resp, stream) ) case resp if accepts.graphQLJson => - val isBadRequest = response.errors.collectFirst { + val isBadRequest = response.errors.exists { case _: CalibanError.ParsingError | _: CalibanError.ValidationError => true - }.getOrElse(false) - val code = if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok - val cacheDirective = HttpUtils.computeCacheDirective(response.extensions) + case _ => false + } ( GraphqlResponseJson.mediaType, - code, - HttpUtils.computeCacheDirective(response.extensions), + if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok, + cacheDirective, encodeSingleResponse( resp, keepDataOnErrors = !isBadRequest, @@ -139,12 +139,10 @@ object TapirAdapter { encodeTextEventStreamResponse(resp) ) case resp => - val code = response.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => StatusCode.BadRequest } - .getOrElse(StatusCode.Ok) - val cacheDirective = HttpUtils.computeCacheDirective(response.extensions) + val isBadRequest = response.errors.contains(HttpRequestMethod.MutationOverGetError) ( MediaType.ApplicationJson, - code, + if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok, cacheDirective, encodeSingleResponse( resp, From 8040c6f807b47907c362bd014a108e72aaae20d2 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Mon, 17 Jun 2024 15:36:42 +1000 Subject: [PATCH 2/3] Fix scala 2 warning --- .../src/main/scala/caliban/interop/tapir/TapirAdapter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala index 7a4d48783..cb942053d 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala @@ -139,7 +139,7 @@ object TapirAdapter { encodeTextEventStreamResponse(resp) ) case resp => - val isBadRequest = response.errors.contains(HttpRequestMethod.MutationOverGetError) + val isBadRequest = response.errors.contains(HttpRequestMethod.MutationOverGetError: Any) ( MediaType.ApplicationJson, if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok, From ae0cad9d40c16335fdd6e7265f0e0e409bf4b877 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Mon, 17 Jun 2024 16:23:15 +1000 Subject: [PATCH 3/3] Use a ListBuffer in `toResponseValue` --- .../main/scala/caliban/GraphQLResponse.scala | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/caliban/GraphQLResponse.scala b/core/src/main/scala/caliban/GraphQLResponse.scala index 9a7844f5f..99809c139 100644 --- a/core/src/main/scala/caliban/GraphQLResponse.scala +++ b/core/src/main/scala/caliban/GraphQLResponse.scala @@ -8,6 +8,8 @@ import caliban.interop.play.{ IsPlayJsonReads, IsPlayJsonWrites } import caliban.interop.tapir.IsTapirSchema import caliban.interop.zio.{ IsZIOJsonCodec, IsZIOJsonDecoder, IsZIOJsonEncoder } +import scala.collection.mutable.ListBuffer + /** * Represents the result of a GraphQL query, containing a data object and a list of errors. */ @@ -19,28 +21,33 @@ case class GraphQLResponse[+E]( ) { def toResponseValue: ResponseValue = toResponseValue(keepDataOnErrors = true) - def toResponseValue(keepDataOnErrors: Boolean, excludeExtensions: Option[Set[String]] = None): ResponseValue = - if (errors.isEmpty && extensions.isEmpty && hasNext.isEmpty) { - ObjectValue(("data", data) :: Nil) - } else { - val hasErrors = errors.nonEmpty - ObjectValue( - List( - "data" -> (if (!hasErrors || keepDataOnErrors) Some(data) else None), - "errors" -> (if (hasErrors) - Some(ListValue(errors.map { - case e: CalibanError => e.toResponseValue - case e => ObjectValue(List("message" -> StringValue(e.toString))) - })) - else None), - "extensions" -> excludeExtensions.fold(extensions)(excl => - extensions.map(obj => ObjectValue(obj.fields.filterNot(f => excl.contains(f._1)))) - ), - "hasNext" -> hasNext.map(BooleanValue.apply) - ).collect { case (name, Some(v)) => name -> v } - ) + def toResponseValue(keepDataOnErrors: Boolean, excludeExtensions: Option[Set[String]] = None): ResponseValue = { + val builder = new ListBuffer[(String, ResponseValue)] + val hasErrors = errors.nonEmpty + val extensions0 = excludeExtensions match { + case None => extensions + case Some(excl) => + extensions.flatMap { obj => + val newFields = obj.fields.filterNot(f => excl.contains(f._1)) + if (newFields.nonEmpty) Some(ObjectValue(newFields)) else None + } } + if (!hasErrors || keepDataOnErrors) + builder += "data" -> data + if (hasErrors) + builder += "errors" -> ListValue(errors.map { + case e: CalibanError => e.toResponseValue + case e => ObjectValue(List("message" -> StringValue(e.toString))) + }) + if (extensions0.nonEmpty) + builder += "extensions" -> extensions0.get + if (hasNext.nonEmpty) + builder += "hasNext" -> BooleanValue(hasNext.get) + + ObjectValue(builder.result()) + } + def withExtension(key: String, value: ResponseValue): GraphQLResponse[E] = copy(extensions = Some(ObjectValue(extensions.foldLeft(List(key -> value)) { case (value, ObjectValue(fields)) => value ::: fields