diff --git a/graphene_federation/tests/test_key.py b/graphene_federation/tests/test_key.py index 2b9b83f..f410845 100644 --- a/graphene_federation/tests/test_key.py +++ b/graphene_federation/tests/test_key.py @@ -2,7 +2,7 @@ from graphql import graphql_sync -from graphene import ObjectType, ID, String, Field +from graphene import ObjectType, ID, String, Field, Enum from ..entity import key from ..main import build_schema @@ -295,3 +295,106 @@ class Query(ObjectType): build_schema(query=Query, enable_federation_2=True) assert 'Invalid compound key definition for type "User"' == str(err.value) + + +def test_compound_primary_key_with_enum(): + class OrgTypeEnum(Enum): + LOCAL = 0 + GLOBAL = 1 + + class UserTypeEnum(Enum): + ADMIN = "ADMIN" + USER = "USER" + + class Organization(ObjectType): + registration_number = ID() + organization_type = OrgTypeEnum() + + @key("id organization { registration_number organization_type } user_type ") + class User(ObjectType): + id = ID() + user_type = UserTypeEnum() + organization = Field(Organization) + + class Query(ObjectType): + user = Field(User) + + schema = build_schema(query=Query, enable_federation_2=True) + assert ( + str(schema).strip() + == """ +type Query { + user: User + _entities(representations: [_Any!]!): [_Entity]! + _service: _Service! +} + +type User { + id: ID + userType: UserTypeEnum + organization: Organization +} + +enum UserTypeEnum { + ADMIN + USER +} + +type Organization { + registrationNumber: ID + organizationType: OrgTypeEnum +} + +enum OrgTypeEnum { + LOCAL + GLOBAL +} + +union _Entity = User + +scalar _Any + +type _Service { + sdl: String +}""".strip() + ) + # Check the federation service schema definition language + query = """ + query { + _service { + sdl + } + } + """ + result = graphql_sync(schema.graphql_schema, query) + assert not result.errors + assert ( + result.data["_service"]["sdl"].strip() + == """ +extend schema @link(url: "https://specs.apollo.dev/federation/v2.0", import: ["@key"]) +type Query { + user: User +} + +type User @key(fields: "id organization { registrationNumber organizationType } userType ") { + id: ID + userType: UserTypeEnum + organization: Organization +} + +enum UserTypeEnum { + ADMIN + USER +} + +type Organization { + registrationNumber: ID + organizationType: OrgTypeEnum +} + +enum OrgTypeEnum { + LOCAL + GLOBAL +} +""".strip() + ) diff --git a/graphene_federation/utils.py b/graphene_federation/utils.py index 917ce52..38a973d 100644 --- a/graphene_federation/utils.py +++ b/graphene_federation/utils.py @@ -2,7 +2,7 @@ import graphene from graphene import Schema, ObjectType -from graphene.types.definitions import GrapheneObjectType +from graphene.types.definitions import GrapheneObjectType, GrapheneEnumType from graphene.types.enum import EnumOptions from graphene.types.scalars import ScalarOptions from graphene.types.union import UnionOptions @@ -63,23 +63,27 @@ def is_valid_compound_key(type_name: str, key: str, schema: Schema): return False field_type = parent_type_fields[field_name].type + + is_scalar_field_type = isinstance(field_type, GraphQLScalarType) or ( + isinstance(field_type, GraphQLNonNull) + and isinstance(field_type.of_type, GraphQLScalarType) + ) + + is_enum_field_type = isinstance(field_type, GrapheneEnumType) or ( + isinstance(field_type, GraphQLNonNull) + and isinstance(field_type.of_type, GrapheneEnumType) + ) + if field.selection_set: # If the field has sub-selections, add it to node mappings to check for valid subfields - - if isinstance(field_type, GraphQLScalarType) or ( - isinstance(field_type, GraphQLNonNull) - and isinstance(field_type.of_type, GraphQLScalarType) - ): + if is_scalar_field_type or is_enum_field_type: # sub-selections are added to a scalar type, key is not valid return False key_nodes.append((field, field_type)) else: - # If there are no sub-selections for a field, it should be a scalar - if not isinstance(field_type, GraphQLScalarType) and not ( - isinstance(field_type, GraphQLNonNull) - and isinstance(field_type.of_type, GraphQLScalarType) - ): + # If there are no sub-selections for a field, it should be a scalar or Enum + if not (is_scalar_field_type or is_enum_field_type): return False key_nodes.pop(0) # Remove the current node as it is fully processed