Skip to content

Commit

Permalink
Fix ArgBuilder derivation for case classes containing only optional f…
Browse files Browse the repository at this point in the history
…ields (#2408)
  • Loading branch information
kyri-petrou authored Sep 20, 2024
1 parent f45e0d3 commit e6b89ae
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 33 deletions.
10 changes: 6 additions & 4 deletions core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package caliban.schema
import caliban.CalibanError.ExecutionError
import caliban.InputValue
import caliban.Value._
import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput }
import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput, GQLValueType }
import magnolia1._

import scala.collection.compat._
Expand All @@ -27,7 +27,6 @@ trait CommonArgBuilderDerivation {
}

def join[T](ctx: CaseClass[ArgBuilder, T]): ArgBuilder[T] = new ArgBuilder[T] {

private val params = {
val arr = Array.ofDim[(String, EitherExecutionError[Any])](ctx.parameters.length)
ctx.parameters.zipWithIndex.foreach { case (p, i) =>
Expand All @@ -40,14 +39,17 @@ trait CommonArgBuilderDerivation {

private val required = params.collect { case (label, default) if default.isLeft => label }

private val isValueType = DerivationUtils.isValueType(ctx)

override private[schema] val partial: PartialFunction[InputValue, Either[ExecutionError, T]] = {
case InputValue.ObjectValue(fields) if required.forall(fields.contains) => fromFields(fields)
}

def build(input: InputValue): Either[ExecutionError, T] =
input match {
case InputValue.ObjectValue(fields) => fromFields(fields)
case value => ctx.constructMonadic(p => p.typeclass.build(value))
case InputValue.ObjectValue(fields) if !isValueType => fromFields(fields)
case value if isValueType => ctx.constructMonadic(p => p.typeclass.build(value))
case _ => Left(ExecutionError("Expected an input object"))
}

private[this] def fromFields(fields: Map[String, InputValue]): Either[ExecutionError, T] =
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/scala-2/caliban/schema/DerivationUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package caliban.schema

import caliban.schema.Annotations.GQLValueType
import magnolia1.ReadOnlyCaseClass

private object DerivationUtils {

def isValueType[F[_]](ctx: ReadOnlyCaseClass[F, ?]): Boolean =
(ctx.isValueClass || ctx.annotations.exists(_.isInstanceOf[GQLValueType])) && ctx.parameters.nonEmpty

}
3 changes: 1 addition & 2 deletions core/src/main/scala-2/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package caliban.schema

import caliban.CalibanError.ValidationError
import caliban.Value._
import caliban.introspection.adt._
import caliban.parsing.adt.{ Directive, Directives }
Expand Down Expand Up @@ -59,7 +58,7 @@ trait CommonSchemaDerivation[R] {
}
)

private lazy val _isValueType = (ctx.isValueClass || isValueType(ctx)) && ctx.parameters.nonEmpty
private lazy val _isValueType = DerivationUtils.isValueType(ctx)

override def toType(isInput: Boolean, isSubscription: Boolean): __Type = {
val _ = objectResolver // Initializes lazy val
Expand Down
36 changes: 15 additions & 21 deletions core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package caliban.schema

import caliban.CalibanError.ExecutionError
import caliban.InputValue.{ ListValue, VariableValue }
import caliban.Value.*
import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput }
import caliban.schema.Annotations.{ GQLDefault, GQLName, GQLOneOfInput, GQLValueType }
import caliban.schema.macros.Macros
import caliban.{ CalibanError, InputValue }
import magnolia1.Macro as MagnoliaMacro
Expand Down Expand Up @@ -70,7 +69,8 @@ trait CommonArgBuilderDerivation {
case m: Mirror.ProductOf[A] =>
makeProductArgBuilder(
recurseProduct[A, m.MirroredElemLabels, m.MirroredElemTypes](),
MagnoliaMacro.paramAnns[A].toMap
MagnoliaMacro.paramAnns[A].toMap,
DerivationUtils.isValueType[A, m.MirroredElemLabels]
)(m.fromProduct)
}

Expand Down Expand Up @@ -124,7 +124,8 @@ trait CommonArgBuilderDerivation {

private def makeProductArgBuilder[A](
_fields: => List[(String, ArgBuilder[Any])],
annotations: Map[String, List[Any]]
annotations: Map[String, List[Any]],
isValueType: Boolean
)(fromProduct: Product => A): ArgBuilder[A] = new ArgBuilder[A] {

private val params = Array.from(_fields.map { (label, builder) =>
Expand All @@ -134,6 +135,8 @@ trait CommonArgBuilderDerivation {
(finalLabel, default, builder)
})

assert(!isValueType || params.length == 1, "value classes must have exactly one field")

private val required = params.collect { case (label, default, _) if default.isLeft => label }

override private[schema] val partial: PartialFunction[InputValue, Either[ExecutionError, A]] = {
Expand All @@ -142,8 +145,9 @@ trait CommonArgBuilderDerivation {

def build(input: InputValue): Either[ExecutionError, A] =
input match {
case InputValue.ObjectValue(fields) => fromFields(fields)
case value => fromValue(value)
case InputValue.ObjectValue(fields) if !isValueType => fromFields(fields)
case value if isValueType => fromValue(value)
case _ => Left(ExecutionError("Expected an input object"))
}

private def fromFields(fields: Map[String, InputValue]): Either[ExecutionError, A] = {
Expand All @@ -163,22 +167,12 @@ trait CommonArgBuilderDerivation {
Right(fromProduct(Tuple.fromArray(arr)))
}

private def fromValue(input: InputValue): Either[ExecutionError, A] = {
val l = params.length
val arr = Array.ofDim[Any](l)
var i = 0
while (i < l) {
val (_, _, builder) = params(i)
builder.build(input) match {
case Right(v) => arr(i) = v
case Left(e) => return Left(e)
}
i += 1
}
Right(fromProduct(Tuple.fromArray(arr)))
}

private def fromValue(input: InputValue): Either[ExecutionError, A] =
params(0)._3
.build(input)
.map(v => fromProduct(Tuple1(v)))
}

}

trait ArgBuilderDerivation extends CommonArgBuilderDerivation {
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/scala-3/caliban/schema/DerivationUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import caliban.introspection.adt.*
import caliban.parsing.adt.{ Directive, Directives }
import caliban.schema.Annotations.*
import caliban.schema.Types.*
import magnolia1.TypeInfo
import caliban.schema.macros.Macros
import magnolia1.{ Macro as MagnoliaMacro, TypeInfo }

import scala.compiletime.erasedValue

private object DerivationUtils {

Expand Down Expand Up @@ -44,6 +47,12 @@ private object DerivationUtils {
def getDeprecatedReason(annotations: Seq[Any]): Option[String] =
annotations.collectFirst { case GQLDeprecated(reason) => reason }

transparent inline def isValueType[A, Labels]: Boolean =
inline erasedValue[Labels] match {
case _: EmptyTuple => false
case _ => MagnoliaMacro.isValueClass[A] || Macros.hasAnnotation[A, GQLValueType]
}

def mkEnum(annotations: List[Any], info: TypeInfo, subTypes: List[(String, __Type, List[Any])]): __Type =
makeEnum(
Some(getName(annotations, info)),
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala-3/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,20 @@ trait CommonSchemaDerivation {

case m: Mirror.ProductOf[A] =>
inline erasedValue[m.MirroredElemLabels] match {
case _: EmptyTuple if !Macros.hasFieldsFromMethods[A] =>
case _: EmptyTuple if !Macros.hasFieldsFromMethods[A] =>
new EnumValueSchema[R, A](
MagnoliaMacro.typeInfo[A],
// Workaround until we figure out why the macro uses the parent's annotations when the leaf is a Scala 3 enum
inline if (!MagnoliaMacro.isEnum[A]) MagnoliaMacro.anns[A] else Nil,
config.enableSemanticNonNull
)
case _ if Macros.hasAnnotation[A, GQLValueType] =>
case _ if DerivationUtils.isValueType[A, m.MirroredElemLabels] =>
new ValueTypeSchema[R, A](
valueTypeSchema[R, m.MirroredElemLabels, m.MirroredElemTypes],
MagnoliaMacro.typeInfo[A],
MagnoliaMacro.anns[A]
)
case _ =>
case _ =>
new ObjectSchema[R, A](
recurseProduct[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()(),
Macros.fieldsFromMethods[R, A],
Expand Down
91 changes: 90 additions & 1 deletion core/src/test/scala/caliban/execution/ExecutionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1505,8 +1505,97 @@ object ExecutionSpec extends ZIOSpecDefault {
ZIO.foldLeft(cases)(assertCompletes) { case (acc, (query, expected)) =>
api.interpreter
.flatMap(_.execute(query, variables = Map("args" -> ObjectValue(Map("intValue" -> IntValue(42))))))
.map(response => assertTrue(response.data.toString == expected))
.map(response => acc && assertTrue(response.data.toString == expected))
}
},
test("oneOf input with input object with all optional fields") {

case class AddPet(pet: Pet.Wrapper)
case class Queries(addPet: AddPet => Pet)

val api: GraphQL[Any] = graphQL(
RootResolver(
Queries(_.pet.pet)
)
)

val cases = List(
gqldoc("""{
addPet(pet: { cat: { name: "a" } }) {
__typename
... on Cat { name }
... on Dog { name }
}
}""") -> """{"addPet":{"__typename":"Cat","name":"a"}}""",
gqldoc("""{
addPet(pet: { dog: {} }) {
__typename
... on Cat { name }
... on Dog { name }
}
}""") -> """{"addPet":{"__typename":"Dog","name":null}}""",
gqldoc("""{
addPet(pet: { dog: { name: "b" } }) {
__typename
... on Cat { name }
... on Dog { name }
}
}""") -> """{"addPet":{"__typename":"Dog","name":"b"}}"""
)

ZIO.foldLeft(cases)(assertCompletes) { case (acc, (query, expected)) =>
api.interpreter
.flatMap(_.execute(query))
.map(response => acc && assertTrue(response.data.toString == expected))
}
}
)
}

// needs to be outside for Scala 2
sealed trait Pet
object Pet { parent =>

@GQLOneOfInput
@GQLName("Pet")
sealed trait Wrapper {
def pet: Pet
}
object Wrapper {
implicit val argBuilder: ArgBuilder[Wrapper] = ArgBuilder.gen
implicit val schema: Schema[Any, Wrapper] = Schema.gen
}

case class Cat(name: Option[String], numberOfLives: Option[Int]) extends Pet
object Cat {
@GQLName("Cat")
case class Wrapper(cat: Cat) extends parent.Wrapper {
override val pet = cat
}
object Wrapper {
implicit val argBuilder: ArgBuilder[Wrapper] = ArgBuilder.gen
implicit val schema: Schema[Any, Wrapper] = Schema.gen
}

implicit val argBuilder: ArgBuilder[Cat] = ArgBuilder.gen
implicit val schema: Schema[Any, Cat] = Schema.gen
}

case class Dog(name: Option[String], wagsTail: Option[Boolean]) extends Pet
object Dog {
@GQLName("Dog")
case class Wrapper(dog: Dog) extends parent.Wrapper {
override val pet = dog
}
object Wrapper {
implicit val argBuilder: ArgBuilder[Wrapper] = ArgBuilder.gen
implicit val schema: Schema[Any, Wrapper] = Schema.gen
}

implicit val argBuilder: ArgBuilder[Dog] = ArgBuilder.gen
implicit val schema: Schema[Any, Dog] = Schema.gen
}

implicit val argBuilder: ArgBuilder[Pet] = ArgBuilder.gen
implicit val schema: Schema[Any, Pet] = Schema.gen
}
27 changes: 26 additions & 1 deletion core/src/test/scala/caliban/schema/ArgBuilderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import caliban.InputValue
import caliban.InputValue.ObjectValue
import caliban.schema.ArgBuilder.auto._
import caliban.Value.{ IntValue, NullValue, StringValue }
import caliban.schema.Annotations.GQLOneOfInput
import caliban.schema.Annotations.{ GQLOneOfInput, GQLValueType }
import zio.test.Assertion._
import zio.test._

Expand Down Expand Up @@ -71,6 +71,31 @@ object ArgBuilderSpec extends ZIOSpecDefault {
)
)
),
suite("derived build")(
test("should fail when null is provided for case class with optional fields") {
case class Foo(value: Option[String])
val ab = ArgBuilder.gen[Foo]
assertTrue(
ab.build(NullValue).isLeft,
// Sanity checks
ab.build(ObjectValue(Map())) == Right(Foo(None)),
ab.build(ObjectValue(Map("value" -> StringValue("foo")))) == Right(Foo(Some("foo"))),
ab.build(ObjectValue(Map("bar" -> StringValue("foo")))) == Right(Foo(None))
)
},
test("should fail when an empty object is provided for GQLValueType case classes") {
@GQLValueType
case class Foo(value: Option[String])
val ab = ArgBuilder.gen[Foo]
assertTrue(
ab.build(ObjectValue(Map())).isLeft,
// Sanity checks
ab.build(NullValue) == Right(Foo(None)),
ab.build(StringValue("foo")) == Right(Foo(Some("foo"))),
ab.build(IntValue(42)).isLeft
)
}
),
suite("buildMissing")(
test("works with derived case class ArgBuilders") {
sealed abstract class Nullable[+T]
Expand Down

0 comments on commit e6b89ae

Please sign in to comment.