Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filters #164

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .types import SQLAlchemyObjectType
from .fields import SQLAlchemyConnectionField
from .fields import SQLAlchemyConnectionField, FilterableConnectionField
from .utils import get_query, get_session

__version__ = "2.1.0"
Expand All @@ -8,6 +8,7 @@
"__version__",
"SQLAlchemyObjectType",
"SQLAlchemyConnectionField",
"FilterableConnectionField",
"get_query",
"get_session",
]
19 changes: 19 additions & 0 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .filters import filter_class_for_module, Filter
from .utils import get_query, sort_argument_for_model


Expand Down Expand Up @@ -94,6 +95,24 @@ def __init__(self, type, *args, **kwargs):
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)


class FilterableConnectionField(SQLAlchemyConnectionField):
def __init__(self, type, *args, **kwargs):
if 'filter' not in kwargs and issubclass(type, Connection):
model = type.Edge.node._type._meta.model
kwargs.setdefault('filter', filter_class_for_module(model))
elif "filter" in kwargs and kwargs["filter"] is None:
del kwargs["filter"]
super(FilterableConnectionField, self).__init__(type, *args, **kwargs)

@classmethod
def get_query(cls, model, info, filter=None, **kwargs):
query = super(FilterableConnectionField, cls).get_query(model, info, **kwargs)
if filter:
for k, v in filter.items():
query = Filter.add_filter_to_query(query, model, k, v)
return query


__connectionFactory = UnsortedSQLAlchemyConnectionField


Expand Down
71 changes: 71 additions & 0 deletions graphene_sqlalchemy/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import graphene

from collections import OrderedDict
from graphene import Argument, Field
from sqlalchemy import inspect

# Cache for the generated classes, to avoid name clash
_INPUT_CACHE = {}
_INPUT_FIELDS_CACHE = {}


class Filter:
@staticmethod
def add_filter_to_query(query, model, field, value):
[(operator, value)] = value.items()
if operator == 'eq':
query = query.filter(getattr(model, field) == value)
elif operator == 'ne':
query = query.filter(getattr(model, field) == value)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should == actually be !=?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, nice catch :)

elif operator == 'lt':
query = query.filter(getattr(model, field) < value)
elif operator == 'gt':
query = query.filter(getattr(model, field) > value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about the le and ge operators?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't have time for it.

elif operator == 'like':
query = query.filter(getattr(model, field).like(value))
return query


def filter_class_for_module(cls):
name = cls.__name__ + "InputFilter"
if name in _INPUT_CACHE:
return Argument(_INPUT_CACHE[name])

class InputFilterBase:
pass

fields = OrderedDict()
for column in inspect(cls).columns.values():
maybe_field = create_input_filter_field(column)
if maybe_field:
fields[column.name] = maybe_field
input_class = type(name, (InputFilterBase, graphene.InputObjectType), {})
input_class._meta.fields.update(fields)
_INPUT_CACHE[name] = input_class
return Argument(input_class)


def create_input_filter_field(column):
from .converter import convert_sqlalchemy_type
graphene_type = convert_sqlalchemy_type(column.type, column)
if graphene_type.__class__ == Field: # TODO enum not supported
return None
name = str(graphene_type.__class__) + 'Filter'

if name in _INPUT_FIELDS_CACHE:
return Field(_INPUT_FIELDS_CACHE[name])

field_class = Filter
fields = OrderedDict()
fields['eq'] = Field(graphene_type.__class__, description='Field should be equal to given value')
fields['ne'] = Field(graphene_type.__class__, description='Field should not be equal to given value')
fields['lt'] = Field(graphene_type.__class__, description='Field should be less then given value')
fields['gt'] = Field(graphene_type.__class__, description='Field should be great then given value')
fields['like'] = Field(graphene_type.__class__, description='Field should have a pattern of given value')
# TODO construct operators based on __class__
# TODO complex filter support: OR

field_class = type(name, (field_class, graphene.InputObjectType), {})
field_class._meta.fields.update(fields)
_INPUT_FIELDS_CACHE[name] = field_class
return Field(field_class)
35 changes: 34 additions & 1 deletion graphene_sqlalchemy/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphene.relay import Connection, Node

from ..registry import reset_global_registry
from ..fields import SQLAlchemyConnectionField
from ..fields import SQLAlchemyConnectionField, FilterableConnectionField
from ..types import SQLAlchemyObjectType
from ..utils import sort_argument_for_model, sort_enum_for_model
from .models import Article, Base, Editor, Pet, Reporter
Expand Down Expand Up @@ -484,3 +484,36 @@ def makeNodes(nodeList):
node["node"]["name"] for node in expectedNoSort[key]["edges"]
)


def test_filter(session):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this won't get accepted until all of the possible ways to filter are tested.

sort_setup(session)

class PetNode(SQLAlchemyObjectType):
class Meta:
model = Pet
interfaces = (Node,)

class PetConnection(Connection):
class Meta:
node = PetNode

class Query(graphene.ObjectType):
pets = FilterableConnectionField(PetConnection)

only_lassie_query = """
query {
pets(filter: {name: {eq: "Lassie"}}) {
edges {
node {
name
}
}
}
}
"""
schema = graphene.Schema(query=Query)
result = schema.execute(only_lassie_query, context_value={"session": session})
assert len(result.data['pets']['edges']) == 1
assert result.data['pets']['edges'][0]['node']['name'] == 'Lassie'