From e6b89aec50adf78e6cbc96968d28126e01aa5049 Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:28:30 +0300 Subject: [PATCH] Fix ArgBuilder derivation for case classes containing only optional fields (#2408) --- .../caliban/schema/ArgBuilderDerivation.scala | 10 +- .../caliban/schema/DerivationUtils.scala | 11 +++ .../caliban/schema/SchemaDerivation.scala | 3 +- .../caliban/schema/ArgBuilderDerivation.scala | 36 +++----- .../caliban/schema/DerivationUtils.scala | 11 ++- .../caliban/schema/SchemaDerivation.scala | 6 +- .../caliban/execution/ExecutionSpec.scala | 91 ++++++++++++++++++- .../scala/caliban/schema/ArgBuilderSpec.scala | 27 +++++- 8 files changed, 162 insertions(+), 33 deletions(-) create mode 100644 core/src/main/scala-2/caliban/schema/DerivationUtils.scala diff --git a/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala b/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala index 34324a550..29d29f8ce 100644 --- a/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala +++ b/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala @@ -3,7 +3,7 @@ package caliban.schema import caliban.CalibanError.ExecutionError import caliban.InputValue import caliban.Value._ -import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput } +import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput, GQLValueType } import magnolia1._ import scala.collection.compat._ @@ -27,7 +27,6 @@ trait CommonArgBuilderDerivation { } def join[T](ctx: CaseClass[ArgBuilder, T]): ArgBuilder[T] = new ArgBuilder[T] { - private val params = { val arr = Array.ofDim[(String, EitherExecutionError[Any])](ctx.parameters.length) ctx.parameters.zipWithIndex.foreach { case (p, i) => @@ -40,14 +39,17 @@ trait CommonArgBuilderDerivation { private val required = params.collect { case (label, default) if default.isLeft => label } + private val isValueType = DerivationUtils.isValueType(ctx) + override private[schema] val partial: PartialFunction[InputValue, Either[ExecutionError, T]] = { case InputValue.ObjectValue(fields) if required.forall(fields.contains) => fromFields(fields) } def build(input: InputValue): Either[ExecutionError, T] = input match { - case InputValue.ObjectValue(fields) => fromFields(fields) - case value => ctx.constructMonadic(p => p.typeclass.build(value)) + case InputValue.ObjectValue(fields) if !isValueType => fromFields(fields) + case value if isValueType => ctx.constructMonadic(p => p.typeclass.build(value)) + case _ => Left(ExecutionError("Expected an input object")) } private[this] def fromFields(fields: Map[String, InputValue]): Either[ExecutionError, T] = diff --git a/core/src/main/scala-2/caliban/schema/DerivationUtils.scala b/core/src/main/scala-2/caliban/schema/DerivationUtils.scala new file mode 100644 index 000000000..941f656c3 --- /dev/null +++ b/core/src/main/scala-2/caliban/schema/DerivationUtils.scala @@ -0,0 +1,11 @@ +package caliban.schema + +import caliban.schema.Annotations.GQLValueType +import magnolia1.ReadOnlyCaseClass + +private object DerivationUtils { + + def isValueType[F[_]](ctx: ReadOnlyCaseClass[F, ?]): Boolean = + (ctx.isValueClass || ctx.annotations.exists(_.isInstanceOf[GQLValueType])) && ctx.parameters.nonEmpty + +} diff --git a/core/src/main/scala-2/caliban/schema/SchemaDerivation.scala b/core/src/main/scala-2/caliban/schema/SchemaDerivation.scala index defc2f192..9c27de0d4 100644 --- a/core/src/main/scala-2/caliban/schema/SchemaDerivation.scala +++ b/core/src/main/scala-2/caliban/schema/SchemaDerivation.scala @@ -1,6 +1,5 @@ package caliban.schema -import caliban.CalibanError.ValidationError import caliban.Value._ import caliban.introspection.adt._ import caliban.parsing.adt.{ Directive, Directives } @@ -59,7 +58,7 @@ trait CommonSchemaDerivation[R] { } ) - private lazy val _isValueType = (ctx.isValueClass || isValueType(ctx)) && ctx.parameters.nonEmpty + private lazy val _isValueType = DerivationUtils.isValueType(ctx) override def toType(isInput: Boolean, isSubscription: Boolean): __Type = { val _ = objectResolver // Initializes lazy val diff --git a/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala b/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala index 1160acfdc..2f445cf44 100644 --- a/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala +++ b/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala @@ -1,9 +1,8 @@ package caliban.schema import caliban.CalibanError.ExecutionError -import caliban.InputValue.{ ListValue, VariableValue } import caliban.Value.* -import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput } +import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput, GQLValueType } import caliban.schema.macros.Macros import caliban.{ CalibanError, InputValue } import magnolia1.Macro as MagnoliaMacro @@ -70,7 +69,8 @@ trait CommonArgBuilderDerivation { case m: Mirror.ProductOf[A] => makeProductArgBuilder( recurseProduct[A, m.MirroredElemLabels, m.MirroredElemTypes](), - MagnoliaMacro.paramAnns[A].toMap + MagnoliaMacro.paramAnns[A].toMap, + DerivationUtils.isValueType[A, m.MirroredElemLabels] )(m.fromProduct) } @@ -124,7 +124,8 @@ trait CommonArgBuilderDerivation { private def makeProductArgBuilder[A]( _fields: => List[(String, ArgBuilder[Any])], - annotations: Map[String, List[Any]] + annotations: Map[String, List[Any]], + isValueType: Boolean )(fromProduct: Product => A): ArgBuilder[A] = new ArgBuilder[A] { private val params = Array.from(_fields.map { (label, builder) => @@ -134,6 +135,8 @@ trait CommonArgBuilderDerivation { (finalLabel, default, builder) }) + assert(!isValueType || params.length == 1, "value classes must have exactly one field") + private val required = params.collect { case (label, default, _) if default.isLeft => label } override private[schema] val partial: PartialFunction[InputValue, Either[ExecutionError, A]] = { @@ -142,8 +145,9 @@ trait CommonArgBuilderDerivation { def build(input: InputValue): Either[ExecutionError, A] = input match { - case InputValue.ObjectValue(fields) => fromFields(fields) - case value => fromValue(value) + case InputValue.ObjectValue(fields) if !isValueType => fromFields(fields) + case value if isValueType => fromValue(value) + case _ => Left(ExecutionError("Expected an input object")) } private def fromFields(fields: Map[String, InputValue]): Either[ExecutionError, A] = { @@ -163,22 +167,12 @@ trait CommonArgBuilderDerivation { Right(fromProduct(Tuple.fromArray(arr))) } - private def fromValue(input: InputValue): Either[ExecutionError, A] = { - val l = params.length - val arr = Array.ofDim[Any](l) - var i = 0 - while (i < l) { - val (_, _, builder) = params(i) - builder.build(input) match { - case Right(v) => arr(i) = v - case Left(e) => return Left(e) - } - i += 1 - } - Right(fromProduct(Tuple.fromArray(arr))) - } - + private def fromValue(input: InputValue): Either[ExecutionError, A] = + params(0)._3 + .build(input) + .map(v => fromProduct(Tuple1(v))) } + } trait ArgBuilderDerivation extends CommonArgBuilderDerivation { diff --git a/core/src/main/scala-3/caliban/schema/DerivationUtils.scala b/core/src/main/scala-3/caliban/schema/DerivationUtils.scala index 12acd8043..f1cfd2f0f 100644 --- a/core/src/main/scala-3/caliban/schema/DerivationUtils.scala +++ b/core/src/main/scala-3/caliban/schema/DerivationUtils.scala @@ -4,7 +4,10 @@ import caliban.introspection.adt.* import caliban.parsing.adt.{ Directive, Directives } import caliban.schema.Annotations.* import caliban.schema.Types.* -import magnolia1.TypeInfo +import caliban.schema.macros.Macros +import magnolia1.{ Macro as MagnoliaMacro, TypeInfo } + +import scala.compiletime.erasedValue private object DerivationUtils { @@ -44,6 +47,12 @@ private object DerivationUtils { def getDeprecatedReason(annotations: Seq[Any]): Option[String] = annotations.collectFirst { case GQLDeprecated(reason) => reason } + transparent inline def isValueType[A, Labels]: Boolean = + inline erasedValue[Labels] match { + case _: EmptyTuple => false + case _ => MagnoliaMacro.isValueClass[A] || Macros.hasAnnotation[A, GQLValueType] + } + def mkEnum(annotations: List[Any], info: TypeInfo, subTypes: List[(String, __Type, List[Any])]): __Type = makeEnum( Some(getName(annotations, info)), diff --git a/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala b/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala index 325953767..9ffe54e3c 100644 --- a/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala +++ b/core/src/main/scala-3/caliban/schema/SchemaDerivation.scala @@ -106,20 +106,20 @@ trait CommonSchemaDerivation { case m: Mirror.ProductOf[A] => inline erasedValue[m.MirroredElemLabels] match { - case _: EmptyTuple if !Macros.hasFieldsFromMethods[A] => + case _: EmptyTuple if !Macros.hasFieldsFromMethods[A] => new EnumValueSchema[R, A]( MagnoliaMacro.typeInfo[A], // Workaround until we figure out why the macro uses the parent's annotations when the leaf is a Scala 3 enum inline if (!MagnoliaMacro.isEnum[A]) MagnoliaMacro.anns[A] else Nil, config.enableSemanticNonNull ) - case _ if Macros.hasAnnotation[A, GQLValueType] => + case _ if DerivationUtils.isValueType[A, m.MirroredElemLabels] => new ValueTypeSchema[R, A]( valueTypeSchema[R, m.MirroredElemLabels, m.MirroredElemTypes], MagnoliaMacro.typeInfo[A], MagnoliaMacro.anns[A] ) - case _ => + case _ => new ObjectSchema[R, A]( recurseProduct[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()(), Macros.fieldsFromMethods[R, A], diff --git a/core/src/test/scala/caliban/execution/ExecutionSpec.scala b/core/src/test/scala/caliban/execution/ExecutionSpec.scala index c8b328fb6..950f7be08 100644 --- a/core/src/test/scala/caliban/execution/ExecutionSpec.scala +++ b/core/src/test/scala/caliban/execution/ExecutionSpec.scala @@ -1505,8 +1505,97 @@ object ExecutionSpec extends ZIOSpecDefault { ZIO.foldLeft(cases)(assertCompletes) { case (acc, (query, expected)) => api.interpreter .flatMap(_.execute(query, variables = Map("args" -> ObjectValue(Map("intValue" -> IntValue(42)))))) - .map(response => assertTrue(response.data.toString == expected)) + .map(response => acc && assertTrue(response.data.toString == expected)) + } + }, + test("oneOf input with input object with all optional fields") { + + case class AddPet(pet: Pet.Wrapper) + case class Queries(addPet: AddPet => Pet) + + val api: GraphQL[Any] = graphQL( + RootResolver( + Queries(_.pet.pet) + ) + ) + + val cases = List( + gqldoc("""{ + addPet(pet: { cat: { name: "a" } }) { + __typename + ... on Cat { name } + ... on Dog { name } + } + }""") -> """{"addPet":{"__typename":"Cat","name":"a"}}""", + gqldoc("""{ + addPet(pet: { dog: {} }) { + __typename + ... on Cat { name } + ... on Dog { name } + } + }""") -> """{"addPet":{"__typename":"Dog","name":null}}""", + gqldoc("""{ + addPet(pet: { dog: { name: "b" } }) { + __typename + ... on Cat { name } + ... on Dog { name } + } + }""") -> """{"addPet":{"__typename":"Dog","name":"b"}}""" + ) + + ZIO.foldLeft(cases)(assertCompletes) { case (acc, (query, expected)) => + api.interpreter + .flatMap(_.execute(query)) + .map(response => acc && assertTrue(response.data.toString == expected)) } } ) } + +// needs to be outside for Scala 2 +sealed trait Pet +object Pet { parent => + + @GQLOneOfInput + @GQLName("Pet") + sealed trait Wrapper { + def pet: Pet + } + object Wrapper { + implicit val argBuilder: ArgBuilder[Wrapper] = ArgBuilder.gen + implicit val schema: Schema[Any, Wrapper] = Schema.gen + } + + case class Cat(name: Option[String], numberOfLives: Option[Int]) extends Pet + object Cat { + @GQLName("Cat") + case class Wrapper(cat: Cat) extends parent.Wrapper { + override val pet = cat + } + object Wrapper { + implicit val argBuilder: ArgBuilder[Wrapper] = ArgBuilder.gen + implicit val schema: Schema[Any, Wrapper] = Schema.gen + } + + implicit val argBuilder: ArgBuilder[Cat] = ArgBuilder.gen + implicit val schema: Schema[Any, Cat] = Schema.gen + } + + case class Dog(name: Option[String], wagsTail: Option[Boolean]) extends Pet + object Dog { + @GQLName("Dog") + case class Wrapper(dog: Dog) extends parent.Wrapper { + override val pet = dog + } + object Wrapper { + implicit val argBuilder: ArgBuilder[Wrapper] = ArgBuilder.gen + implicit val schema: Schema[Any, Wrapper] = Schema.gen + } + + implicit val argBuilder: ArgBuilder[Dog] = ArgBuilder.gen + implicit val schema: Schema[Any, Dog] = Schema.gen + } + + implicit val argBuilder: ArgBuilder[Pet] = ArgBuilder.gen + implicit val schema: Schema[Any, Pet] = Schema.gen +} diff --git a/core/src/test/scala/caliban/schema/ArgBuilderSpec.scala b/core/src/test/scala/caliban/schema/ArgBuilderSpec.scala index 6045149fa..ab6c97fc3 100644 --- a/core/src/test/scala/caliban/schema/ArgBuilderSpec.scala +++ b/core/src/test/scala/caliban/schema/ArgBuilderSpec.scala @@ -5,7 +5,7 @@ import caliban.InputValue import caliban.InputValue.ObjectValue import caliban.schema.ArgBuilder.auto._ import caliban.Value.{ IntValue, NullValue, StringValue } -import caliban.schema.Annotations.GQLOneOfInput +import caliban.schema.Annotations.{ GQLOneOfInput, GQLValueType } import zio.test.Assertion._ import zio.test._ @@ -71,6 +71,31 @@ object ArgBuilderSpec extends ZIOSpecDefault { ) ) ), + suite("derived build")( + test("should fail when null is provided for case class with optional fields") { + case class Foo(value: Option[String]) + val ab = ArgBuilder.gen[Foo] + assertTrue( + ab.build(NullValue).isLeft, + // Sanity checks + ab.build(ObjectValue(Map())) == Right(Foo(None)), + ab.build(ObjectValue(Map("value" -> StringValue("foo")))) == Right(Foo(Some("foo"))), + ab.build(ObjectValue(Map("bar" -> StringValue("foo")))) == Right(Foo(None)) + ) + }, + test("should fail when an empty object is provided for GQLValueType case classes") { + @GQLValueType + case class Foo(value: Option[String]) + val ab = ArgBuilder.gen[Foo] + assertTrue( + ab.build(ObjectValue(Map())).isLeft, + // Sanity checks + ab.build(NullValue) == Right(Foo(None)), + ab.build(StringValue("foo")) == Right(Foo(Some("foo"))), + ab.build(IntValue(42)).isLeft + ) + } + ), suite("buildMissing")( test("works with derived case class ArgBuilders") { sealed abstract class Nullable[+T]