-
Notifications
You must be signed in to change notification settings - Fork 229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add filters #164
Add filters #164
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import graphene | ||
|
||
from collections import OrderedDict | ||
from graphene import Argument, Field | ||
from sqlalchemy import inspect | ||
|
||
# Cache for the generated classes, to avoid name clash | ||
_INPUT_CACHE = {} | ||
_INPUT_FIELDS_CACHE = {} | ||
|
||
|
||
class Filter: | ||
@staticmethod | ||
def add_filter_to_query(query, model, field, value): | ||
[(operator, value)] = value.items() | ||
if operator == 'eq': | ||
query = query.filter(getattr(model, field) == value) | ||
elif operator == 'ne': | ||
query = query.filter(getattr(model, field) == value) | ||
elif operator == 'lt': | ||
query = query.filter(getattr(model, field) < value) | ||
elif operator == 'gt': | ||
query = query.filter(getattr(model, field) > value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I didn't have time for it. |
||
elif operator == 'like': | ||
query = query.filter(getattr(model, field).like(value)) | ||
return query | ||
|
||
|
||
def filter_class_for_module(cls): | ||
name = cls.__name__ + "InputFilter" | ||
if name in _INPUT_CACHE: | ||
return Argument(_INPUT_CACHE[name]) | ||
|
||
class InputFilterBase: | ||
pass | ||
|
||
fields = OrderedDict() | ||
for column in inspect(cls).columns.values(): | ||
maybe_field = create_input_filter_field(column) | ||
if maybe_field: | ||
fields[column.name] = maybe_field | ||
input_class = type(name, (InputFilterBase, graphene.InputObjectType), {}) | ||
input_class._meta.fields.update(fields) | ||
_INPUT_CACHE[name] = input_class | ||
return Argument(input_class) | ||
|
||
|
||
def create_input_filter_field(column): | ||
from .converter import convert_sqlalchemy_type | ||
graphene_type = convert_sqlalchemy_type(column.type, column) | ||
if graphene_type.__class__ == Field: # TODO enum not supported | ||
return None | ||
name = str(graphene_type.__class__) + 'Filter' | ||
|
||
if name in _INPUT_FIELDS_CACHE: | ||
return Field(_INPUT_FIELDS_CACHE[name]) | ||
|
||
field_class = Filter | ||
fields = OrderedDict() | ||
fields['eq'] = Field(graphene_type.__class__, description='Field should be equal to given value') | ||
fields['ne'] = Field(graphene_type.__class__, description='Field should not be equal to given value') | ||
fields['lt'] = Field(graphene_type.__class__, description='Field should be less then given value') | ||
fields['gt'] = Field(graphene_type.__class__, description='Field should be great then given value') | ||
fields['like'] = Field(graphene_type.__class__, description='Field should have a pattern of given value') | ||
# TODO construct operators based on __class__ | ||
# TODO complex filter support: OR | ||
|
||
field_class = type(name, (field_class, graphene.InputObjectType), {}) | ||
field_class._meta.fields.update(fields) | ||
_INPUT_FIELDS_CACHE[name] = field_class | ||
return Field(field_class) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from graphene.relay import Connection, Node | ||
|
||
from ..registry import reset_global_registry | ||
from ..fields import SQLAlchemyConnectionField | ||
from ..fields import SQLAlchemyConnectionField, FilterableConnectionField | ||
from ..types import SQLAlchemyObjectType | ||
from ..utils import sort_argument_for_model, sort_enum_for_model | ||
from .models import Article, Base, Editor, Pet, Reporter | ||
|
@@ -484,3 +484,36 @@ def makeNodes(nodeList): | |
node["node"]["name"] for node in expectedNoSort[key]["edges"] | ||
) | ||
|
||
|
||
def test_filter(session): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect this won't get accepted until all of the possible ways to filter are tested. |
||
sort_setup(session) | ||
|
||
class PetNode(SQLAlchemyObjectType): | ||
class Meta: | ||
model = Pet | ||
interfaces = (Node,) | ||
|
||
class PetConnection(Connection): | ||
class Meta: | ||
node = PetNode | ||
|
||
class Query(graphene.ObjectType): | ||
pets = FilterableConnectionField(PetConnection) | ||
|
||
only_lassie_query = """ | ||
query { | ||
pets(filter: {name: {eq: "Lassie"}}) { | ||
edges { | ||
node { | ||
name | ||
} | ||
} | ||
} | ||
} | ||
""" | ||
schema = graphene.Schema(query=Query) | ||
result = schema.execute(only_lassie_query, context_value={"session": session}) | ||
assert len(result.data['pets']['edges']) == 1 | ||
assert result.data['pets']['edges'][0]['node']['name'] == 'Lassie' | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should
==
actually be!=
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, nice catch :)