Skip to content

Commit

Permalink
Merge pull request #12 from hballard/tests
Browse files Browse the repository at this point in the history
Added tests for subscription_manager module
  • Loading branch information
hballard authored Apr 16, 2017
2 parents 4326eef + 39990e7 commit f7c095e
Show file tree
Hide file tree
Showing 7 changed files with 732 additions and 177 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This is a implementation of apollographql [subscriptions-transport-ws](https://

Meant to be used in conjunction with [graphql-python](https://github.com/graphql-python) / [graphene](http://graphene-python.org/) server and [apollo-client](http://dev.apollodata.com/) for graphql. The api is below, but if you want more information, consult the apollo graphql libraries referenced above.

Initial implementation. Currently only works with Python 2. No tests yet.
Initial implementation. Currently only works with Python 2.

## Installation
```
Expand Down Expand Up @@ -39,7 +39,7 @@ $ pip install graphql-subscriptions
args = kwargs.get('args')
return {
'new_user_channel': {
'filter': lambda user, context: user.active == args.active
'filter': lambda root, context: root.active == args.active
}
}

Expand Down
143 changes: 58 additions & 85 deletions graphql_subscriptions/subscription_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import redis
import gevent
import cPickle
from types import FunctionType
from promise import Promise

from graphql import parse, validate, specified_rules, value_from_ast, execute
from graphql.language.ast import OperationDefinition
from promise import Promise
import gevent
import redis

from .utils import to_snake_case
from .validation import SubscriptionHasSingleRootField

class RedisPubsub(object):

class RedisPubsub(object):
def __init__(self, host='localhost', port=6379, *args, **kwargs):
redis.connection.socket = gevent.socket
self.redis = redis.StrictRedis(host, port, *args, **kwargs)
Expand All @@ -29,13 +32,10 @@ def subscribe(self, trigger_name, on_message_handler, options):
except IndexError:
self.pubsub.subscribe(trigger_name)
self.subscriptions[self.sub_id_counter] = [
trigger_name,
on_message_handler
trigger_name, on_message_handler
]
if not self.greenlet:
self.greenlet = gevent.spawn(
self.wait_and_get_message
)
self.greenlet = gevent.spawn(self.wait_and_get_message)
return Promise.resolve(self.sub_id_counter)

def unsubscribe(self, sub_id):
Expand Down Expand Up @@ -63,14 +63,12 @@ def handle_message(self, message):


class ValidationError(Exception):

def __init__(self, errors):
self.errors = errors
self.message = 'Subscription query has validation errors'


class SubscriptionManager(object):

def __init__(self, schema, pubsub, setup_funcs={}):
self.schema = schema
self.pubsub = pubsub
Expand All @@ -84,16 +82,11 @@ def publish(self, trigger_name, payload):
def subscribe(self, query, operation_name, callback, variables, context,
format_error, format_response):
parsed_query = parse(query)
errors = validate(
self.schema,
parsed_query,
# TODO: Need to create/add subscriptionHasSingleRootField
# rule from apollo subscription manager package
rules=specified_rules
)
rules = specified_rules + [SubscriptionHasSingleRootField]
errors = validate(self.schema, parsed_query, rules=rules)

if errors:
return Promise.reject(ValidationError(errors))
return Promise.rejected(ValidationError(errors))

args = {}

Expand All @@ -110,29 +103,25 @@ def subscribe(self, query, operation_name, callback, variables, context,
for arg in root_field.arguments:

arg_definition = [
arg_def for _, arg_def in
fields.get(subscription_name).args.iteritems() if
arg_def.out_name == arg.name.value
arg_def
for _, arg_def in fields.get(subscription_name)
.args.iteritems() if arg_def.out_name == arg.name.value
][0]

args[arg_definition.out_name] = value_from_ast(
arg.value,
arg_definition.type,
variables=variables
)

if self.setup_funcs.get(subscription_name):
trigger_map = self.setup_funcs[subscription_name](
query,
operation_name,
callback,
variables,
context,
format_error,
format_response,
args,
subscription_name
)
arg.value, arg_definition.type, variables=variables)

if self.setup_funcs.get(to_snake_case(subscription_name)):
trigger_map = self.setup_funcs[to_snake_case(subscription_name)](
query=query,
operation_name=operation_name,
callback=callback,
variables=variables,
context=context,
format_error=format_error,
format_response=format_response,
args=args,
subscription_name=subscription_name)
else:
trigger_map = {}
trigger_map[subscription_name] = {}
Expand All @@ -143,71 +132,55 @@ def subscribe(self, query, operation_name, callback, variables, context,
subscription_promises = []

for trigger_name in trigger_map.viewkeys():
channel_options = trigger_map[trigger_name].get(
'channel_options',
{}
)
filter = trigger_map[trigger_name].get(
'filter',
lambda arg1, arg2: True
)
try:
channel_options = trigger_map[trigger_name].get(
'channel_options', {})
filter = trigger_map[trigger_name].get('filter',
lambda arg1, arg2: True)
# TODO: Think about this some more...the Apollo library
# let's all messages through by default, even if
# the users incorrectly uses the setup_funcs (does not
# use 'filter' or 'channel_options' keys); I think it
# would be better to raise an exception here
except AttributeError:
channel_options = {}

def filter(arg1, arg2):
return True

def on_message(root_value):

def context_promise_handler(result):
if isinstance(context, FunctionType):
return context()
else:
return context

def filter_func_promise_handler(context):
return Promise.all([
context,
filter(root_value, context)
])
return Promise.all([context, filter(root_value, context)])

def context_do_execute_handler(result):
context, do_execute = result
if not do_execute:
return
else:
return execute(
self.schema,
parsed_query,
root_value,
context,
variables,
operation_name
)

return Promise.resolve(
True
).then(
context_promise_handler
).then(
filter_func_promise_handler
).then(
context_do_execute_handler
).then(
lambda result: callback(None, result)
).catch(
lambda error: callback(error, None)
)
return execute(self.schema, parsed_query, root_value,
context, variables, operation_name)

return Promise.resolve(True).then(
context_promise_handler).then(
filter_func_promise_handler).then(
context_do_execute_handler).then(
lambda result: callback(None, result)).catch(
lambda error: callback(error, None))

subscription_promises.append(
self.pubsub.subscribe(
trigger_name,
on_message,
channel_options
).then(
lambda id: self.subscriptions[
external_subscription_id].append(id)
)
)
self.pubsub.
subscribe(trigger_name, on_message, channel_options).then(
lambda id: self.subscriptions[external_subscription_id].append(id)
))

return Promise.all(subscription_promises).then(
lambda result: external_subscription_id
)
lambda result: external_subscription_id)

def unsubscribe(self, sub_id):
for internal_id in self.subscriptions.get(sub_id):
Expand Down
Loading

0 comments on commit f7c095e

Please sign in to comment.