diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index efcf3c6..248ac40 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -7,6 +7,7 @@ from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( ColumnProperty, RelationshipProperty, @@ -14,7 +15,6 @@ interfaces, strategies, ) -from sqlalchemy.ext.hybrid import hybrid_property import graphene from graphene.types.json import JSONString @@ -159,7 +159,7 @@ def convert_sqlalchemy_relationship( ): """ :param sqlalchemy.RelationshipProperty relationship_prop: - :param SQLAlchemyObjectType obj_type: + :param SQLAlchemyBase obj_type: :param function|None connection_field_factory: :param bool batching: :param str orm_field_name: @@ -202,7 +202,7 @@ def _convert_o2o_or_m2o_relationship( Convert one-to-one or many-to-one relationshsip. Return an object field. :param sqlalchemy.RelationshipProperty relationship_prop: - :param SQLAlchemyObjectType obj_type: + :param SQLAlchemyBase obj_type: :param bool batching: :param str orm_field_name: :param dict field_kwargs: @@ -230,7 +230,7 @@ def _convert_o2m_or_m2m_relationship( Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. :param sqlalchemy.RelationshipProperty relationship_prop: - :param SQLAlchemyObjectType obj_type: + :param SQLAlchemyBase obj_type: :param bool batching: :param function|None connection_field_factory: :param dict field_kwargs: @@ -362,7 +362,7 @@ def get_type_from_registry(): raise TypeError( "No model found in Registry for type %s. " "Only references to SQLAlchemy Models mapped to " - "SQLAlchemyObjectTypes are allowed." % type_arg + "SQLAlchemyBase types are allowed." % type_arg ) return get_type_from_registry() @@ -680,7 +680,7 @@ def forward_reference_solver(): raise TypeError( "No model found in Registry for forward reference for type %s. " "Only forward references to other SQLAlchemy Models mapped to " - "SQLAlchemyObjectTypes are allowed." % type_arg + "SQLAlchemyBase types are allowed." % type_arg ) # Always fall back to string if no ForwardRef type found. return get_global_registry().get_type_for_model(model) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 97f8997..e97516f 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -53,10 +53,10 @@ def enum_for_sa_enum(sa_enum, registry): def enum_for_field(obj_type, field_name): """Return the Graphene Enum type for the specified Graphene field.""" - from .types import SQLAlchemyObjectType + from .types import SQLAlchemyBase - if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): - raise TypeError("Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase): + raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type)) if not field_name or not isinstance(field_name, str): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry @@ -88,10 +88,10 @@ def _default_sort_enum_symbol_name(column_name, sort_asc=True): def sort_enum_for_object_type( obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None ): - """Return Graphene Enum for sorting the given SQLAlchemyObjectType. + """Return Graphene Enum for sorting the given SQLAlchemyBase. Parameters - - obj_type : SQLAlchemyObjectType + - obj_type : SQLAlchemyBase The object type for which the sort Enum shall be generated. - name : str, optional, default None Name to use for the sort Enum. @@ -160,10 +160,10 @@ def sort_argument_for_object_type( get_symbol_name=None, has_default=True, ): - """ "Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + """ "Returns Graphene Argument for sorting the given SQLAlchemyBase. Parameters - - obj_type : SQLAlchemyObjectType + - obj_type : SQLAlchemyBase The object type for which the sort Argument shall be generated. - enum_name : str, optional, default None Name to use for the sort Enum. diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index ef79885..f059fda 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -26,21 +26,21 @@ class SQLAlchemyConnectionField(ConnectionField): @property def type(self): - from .types import SQLAlchemyObjectType + from .types import SQLAlchemyBase type_ = super(ConnectionField, self).type nullable_type = get_nullable_type(type_) if issubclass(nullable_type, Connection): return type_ - assert issubclass(nullable_type, SQLAlchemyObjectType), ( - "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" + assert issubclass(nullable_type, SQLAlchemyBase), ( + "SQLALchemyConnectionField only accepts SQLAlchemyBase types, not {}" ).format(nullable_type.__name__) assert nullable_type.connection, "The type {} doesn't have a connection".format( nullable_type.__name__ ) assert type_ == nullable_type, ( - "Passing a SQLAlchemyObjectType instance is deprecated. " - "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" + "Passing a SQLAlchemyBase instance is deprecated. " + "Pass the connection type instead accessible via SQLAlchemyBase.connection" ) return nullable_type.connection @@ -266,7 +266,7 @@ def default_connection_field_factory(relationship, registry, **field_kwargs): def createConnectionField(type_, **field_kwargs): warnings.warn( "createConnectionField is deprecated and will be removed in the next " - "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", + "major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.", DeprecationWarning, ) return __connectionFactory(type_, **field_kwargs) @@ -275,7 +275,7 @@ def createConnectionField(type_, **field_kwargs): def registerConnectionFieldFactory(factoryMethod): warnings.warn( "registerConnectionFieldFactory is deprecated and will be removed in the next " - "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", + "major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory @@ -285,7 +285,7 @@ def registerConnectionFieldFactory(factoryMethod): def unregisterConnectionFieldFactory(): warnings.warn( "registerConnectionFieldFactory is deprecated and will be removed in the next " - "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", + "major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index b959d22..a3d645c 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -95,15 +95,10 @@ def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType): return self._registry_enums.get(sa_enum) def register_sort_enum(self, obj_type, sort_enum: Enum): + from .types import SQLAlchemyBase - from .types import SQLAlchemyObjectType - - if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType - ): - raise TypeError( - "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) - ) + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase): + raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type)) if not isinstance(sort_enum, type(Enum)): raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) self._registry_sort_enums[obj_type] = sort_enum diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py index e8e6191..d5ce48b 100644 --- a/graphene_sqlalchemy/resolvers.py +++ b/graphene_sqlalchemy/resolvers.py @@ -19,7 +19,7 @@ def get_attr_resolver(obj_type, model_attr): In order to support field renaming via `ORMField.model_attr`, we need to define resolver functions for each field. - :param SQLAlchemyObjectType obj_type: + :param SQLAlchemyBase obj_type: :param str model_attr: the name of the SQLAlchemy attribute :rtype: Callable """ diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index e1ee985..b66c314 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -402,6 +402,46 @@ class Employee(Person): } +class Owner(Base): + id = Column(Integer(), primary_key=True) + name = Column(String()) + + accounts = relationship(lambda: Account, back_populates="owner", lazy="selectin") + + __tablename__ = "owner" + + +class Account(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + + owner_id = Column(Integer(), ForeignKey(Owner.__table__.c.id)) + owner = relationship(Owner, back_populates="accounts") + + balance = Column(Integer()) + + __tablename__ = "account" + __mapper_args__ = { + "polymorphic_on": type, + } + + +class CurrentAccount(Account): + overdraft = Column(Integer()) + + __mapper_args__ = { + "polymorphic_identity": "current", + } + + +class SavingsAccount(Account): + interest_rate = Column(Integer()) + + __mapper_args__ = { + "polymorphic_identity": "savings", + } + + ############################################ # Custom Test Models ############################################ diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index e62e07d..86fd510 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -130,7 +130,7 @@ def hybrid_prop(self) -> "MyTypeNotInRegistry": with pytest.raises( TypeError, match=r"(.*)Only forward references to other SQLAlchemy Models mapped to " - "SQLAlchemyObjectTypes are allowed.(.*)", + "SQLAlchemyBase types are allowed.(.*)", ): get_hybrid_property_type(hybrid_prop).type diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index 3de6904..580a9d0 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -119,6 +119,6 @@ class Meta: with pytest.raises(TypeError, match=re_err): PetType.enum_for_field(None) - re_err = "Expected SQLAlchemyObjectType, but got: None" + re_err = "Expected SQLAlchemyBase, but got: None" with pytest.raises(TypeError, match=re_err): enum_for_field(None, "other_kind") diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 9fed146..3d84d47 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -5,8 +5,10 @@ from graphene.relay import Connection, Node from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType +from ..types import SQLAlchemyInterface, SQLAlchemyObjectType from .models import Editor as EditorModel +from .models import Employee as EmployeeModel +from .models import Person as PersonModel from .models import Pet as PetModel @@ -21,6 +23,18 @@ class Meta: model = EditorModel +class Person(SQLAlchemyInterface): + class Meta: + model = PersonModel + use_connection = True + + +class Employee(SQLAlchemyObjectType): + class Meta: + model = EmployeeModel + interfaces = (Person, Node) + + ## # SQLAlchemyConnectionField ## @@ -51,7 +65,7 @@ def resolver(_obj, _info): def test_type_assert_sqlalchemy_object_type(): - with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"): + with pytest.raises(AssertionError, match="only accepts SQLAlchemyBase types"): SQLAlchemyConnectionField(ObjectType).type @@ -91,3 +105,10 @@ def test_custom_sort(): def test_sort_init_raises(): with pytest.raises(TypeError, match="Cannot create sort"): SQLAlchemyConnectionField(Connection) + + +def test_interface_required_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(Person.connection, required=True) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Person diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 4acf89a..cb9b49f 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -6,15 +6,19 @@ from ..fields import SQLAlchemyConnectionField from ..filters import FloatFilter -from ..types import ORMField, SQLAlchemyObjectType +from ..types import ORMField, SQLAlchemyInterface, SQLAlchemyObjectType from .models import ( + Account, Article, + CurrentAccount, Editor, HairKind, Image, + Owner, Pet, Reader, Reporter, + SavingsAccount, ShoppingCart, ShoppingCartItem, Tag, @@ -1199,3 +1203,149 @@ async def test_additional_filters(session): schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) assert_and_raise_result(result, expected) + + +# Test relationship filter for interface fields +async def add_relationship_interface_test_data(session): + owner1 = Owner(name="John Doe") + owner2 = Owner(name="Jane Doe") + session.add_all([owner1, owner2]) + + o1_account1 = CurrentAccount(owner=owner1, balance=1000, overdraft=100) + o1_account2 = CurrentAccount(owner=owner1, balance=2000, overdraft=50) + o1_account3 = SavingsAccount(owner=owner1, balance=300, interest_rate=3) + + o2_account1 = CurrentAccount(owner=owner2, balance=1000, overdraft=100) + o2_account2 = SavingsAccount(owner=owner2, balance=300, interest_rate=3) + session.add_all([o1_account1, o1_account2, o1_account3, o2_account1, o2_account2]) + + await eventually_await_session(session, "commit") + + +def create_relationship_interface_schema(session): + class OwnerType(SQLAlchemyObjectType): + class Meta: + model = Owner + interfaces = (relay.Node,) + + class AccountType(SQLAlchemyInterface): + class Meta: + model = Account + use_connection = True + + class CurrentAccountType(SQLAlchemyObjectType): + class Meta: + model = CurrentAccount + interfaces = ( + AccountType, + relay.Node, + ) + + class SavingsAccountType(SQLAlchemyObjectType): + class Meta: + model = SavingsAccount + interfaces = ( + AccountType, + relay.Node, + ) + + class Query(graphene.ObjectType): + node = relay.Node.Field() + owners = SQLAlchemyConnectionField(OwnerType.connection) + accounts = SQLAlchemyConnectionField(AccountType.connection) + + return (Query, [CurrentAccountType, SavingsAccountType]) + + +@pytest.mark.asyncio +async def test_filter_relationship_interface(session): + await add_relationship_interface_test_data(session) + + (Query, types) = create_relationship_interface_schema(session) + + query = """ + query { + owners(filter: { accounts: { contains: [{balance: {gte: 2000}}]}}) { + edges { + node { + name + accounts { + edges { + node { + __typename + balance + } + } + } + } + } + } + } + """ + expected = { + "owners": { + "edges": [ + { + "node": { + "name": "John Doe", + "accounts": { + "edges": [ + { + "node": { + "__typename": "CurrentAccountType", + "balance": 1000, + }, + }, + { + "node": { + "__typename": "CurrentAccountType", + "balance": 2000, + }, + }, + { + "node": { + "__typename": "SavingsAccountType", + "balance": 300, + }, + }, + ], + }, + }, + }, + ], + }, + } + schema = graphene.Schema(query=Query, types=types) + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) + + query = """ + query { + owners(filter: { accounts: { contains: [{balance: {gte: 1000}}]}}) { + edges { + node { + name + } + } + } + } + """ + expected = { + "owners": { + "edges": [ + { + "node": { + "name": "John Doe", + }, + }, + { + "node": { + "name": "Jane Doe", + }, + }, + ] + }, + } + + result = await schema.execute_async(query, context_value={"session": session}) + assert_and_raise_result(result, expected) diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index e54f08b..1869ea7 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -120,7 +120,7 @@ class Meta: [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], ) - re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*" + re_err = r"Expected SQLAlchemyBase, but got: .*PetSort.*" with pytest.raises(TypeError, match=re_err): reg.register_sort_enum(sort_enum, sort_enum) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 7053988..cccda81 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -67,8 +67,8 @@ def __init__( **field_kwargs, ): """ - Use this to override fields automatically generated by SQLAlchemyObjectType. - Unless specified, options will default to SQLAlchemyObjectType usual behavior + Use this to override fields automatically generated by SQLAlchemyBase. + Unless specified, options will default to SQLAlchemyBase usual behavior for the given SQLAlchemy model property. Usage: @@ -159,8 +159,8 @@ def filter_field_from_field( model_attr: Any, model_attr_name: str, ) -> Optional[graphene.InputField]: - # Field might be a SQLAlchemyObjectType, due to hybrid properties - if issubclass(type_, SQLAlchemyObjectType): + # Field might be a SQLAlchemyBase, due to hybrid properties + if issubclass(type_, SQLAlchemyBase): filter_class = registry.get_filter_for_base_type(type_) # Enum Special Case elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty): @@ -285,13 +285,13 @@ def construct_fields_and_filters( connection_field_factory, ): """ - Construct all the fields for a SQLAlchemyObjectType. + Construct all the fields for a SQLAlchemyBase. The main steps are: - Gather all the relevant attributes from the SQLAlchemy model - Gather all the ORM fields defined on the type - Merge in overrides and build up all the fields - :param SQLAlchemyObjectType obj_type: + :param SQLAlchemyBase obj_type: :param model: the SQLAlchemy model :param Registry registry: :param tuple[string] only_fields: