From 05bfcd996b20b771df7138b159ad45de6639483c Mon Sep 17 00:00:00 2001 From: jimhorng <jimhorng@qnap.com> Date: Wed, 15 Oct 2014 18:59:02 +0800 Subject: [PATCH] get task result from AsyncResult.get(); add callbacks for sent task ack and sent task Conflicts: tcelery/producer.py --- README.rst | 41 +++++++++++++++++++++------ tcelery/__init__.py | 10 ++++++- tcelery/connection.py | 65 +++++++++++++++++++++++++++++++++++++++++-- tcelery/handlers.py | 18 +++++++----- tcelery/producer.py | 34 +++++++++++----------- tcelery/result.py | 26 +++++++++++++++-- 6 files changed, 156 insertions(+), 38 deletions(-) diff --git a/README.rst b/README.rst index a4354b8..f98e8d8 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ tornado-celery is a non-blocking Celery client for Tornado web framework Usage ----- -Calling Celery tasks from Tornado RequestHandler: :: +Calling Celery tasks(has return value) from Tornado RequestHandler: :: from tornado import gen, web import tcelery, tasks @@ -19,24 +19,47 @@ Calling Celery tasks from Tornado RequestHandler: :: tcelery.setup_nonblocking_producer() class AsyncHandler(web.RequestHandler): - @asynchronous + @web.asynchronous def get(self): - tasks.echo.apply_async(args=['Hello world!'], callback=self.on_result) + tasks.echo.apply_async(args=['Hello world!'], callback=self.on_async_result) - def on_result(self, response): - self.write(str(response.result)) + def on_async_result(self, async_result): + async_result.get(callback=self.on_actual_result) + + def on_actual_result(self, result): + self.write(str(result)) self.finish() -Calling tasks with generator-based interface: :: +with generator-based interface: :: class GenAsyncHandler(web.RequestHandler): - @asynchronous + @web.asynchronous @gen.coroutine def get(self): - response = yield gen.Task(tasks.sleep.apply_async, args=[3]) - self.write(str(response.result)) + async_result = yield gen.Task(tasks.sleep.apply_async, args=[3]) + result = yield gen.Task(async_result.get) + self.write(str(result)) self.finish() +Calling Celery tasks(no return value) from Tornado RequestHandler: :: + + @web.asynchronous + def get(self): + tasks.echo.apply_async(args=['Hello world!'], callback=self.on_async_result) + + def on_async_result(self, async_result): + self.write("task sent") # ack-ed if BROKER_TRANSPORT_OPTIONS: {'confirm_publish': True} + self.finish() + +with generator-based interface: :: + + @web.asynchronous + @gen.coroutine + def get(self): + yield gen.Task(tasks.sleep.apply_async, args=[3]) + self.write("task sent") # ack-ed if BROKER_TRANSPORT_OPTIONS: {'confirm_publish': True} + self.finish() + **NOTE:** Currently callbacks only work with AMQP and Redis backends. To use the Redis backend, you must install `tornado-redis <https://github.com/leporo/tornado-redis>`_. diff --git a/tcelery/__init__.py b/tcelery/__init__.py index 099e2bf..f07ba04 100644 --- a/tcelery/__init__.py +++ b/tcelery/__init__.py @@ -29,6 +29,14 @@ def connect(): options = celery_app.conf.get('CELERYT_PIKA_OPTIONS', {}) producer_cls.conn_pool.connect(broker_url, options=options, - callback=on_ready) + callback=on_ready, + confirm_delivery=_get_confirm_publish_conf(celery_app.conf)) io_loop.add_callback(connect) + +def _get_confirm_publish_conf(conf): + broker_transport_options = conf.get('BROKER_TRANSPORT_OPTIONS', {}) + if (broker_transport_options and + broker_transport_options.get('confirm_publish') is True): + return True + return False \ No newline at end of file diff --git a/tcelery/connection.py b/tcelery/connection.py index e2aacb5..ec16ebc 100644 --- a/tcelery/connection.py +++ b/tcelery/connection.py @@ -16,16 +16,20 @@ from tornado import ioloop +LOGGER = logging.getLogger(__name__) class Connection(object): content_type = 'application/x-python-serialize' - def __init__(self, io_loop=None): + def __init__(self, io_loop=None, confirm_delivery=False): self.channel = None self.connection = None self.url = None self.io_loop = io_loop or ioloop.IOLoop.instance() + self.confirm_delivery = confirm_delivery + if self.confirm_delivery: + self.confirm_delivery_handler = ConfirmDeliveryHandler() def connect(self, url=None, options=None, callback=None): if url is not None: @@ -61,9 +65,17 @@ def on_connect(self, callback, connection): def on_channel_open(self, callback, channel): self.channel = channel + if self.confirm_delivery: + self.init_confirm_delivery() if callback: callback() + def init_confirm_delivery(self): + self.channel.confirm_delivery(callback=self.confirm_delivery_handler.on_delivery_confirmation, + nowait=True) + self.confirm_delivery_handler.reset_message_seq() + self.confirm_delivery_handler.reset_coroutine_callbacks() + def on_exchange_declare(self, frame): pass @@ -118,10 +130,10 @@ def __init__(self, limit, io_loop=None): self._connection = None self.io_loop = io_loop - def connect(self, broker_url, options=None, callback=None): + def connect(self, broker_url, options=None, callback=None, confirm_delivery=False): self._on_ready = callback for _ in range(self._limit): - conn = Connection(io_loop=self.io_loop) + conn = Connection(io_loop=self.io_loop, confirm_delivery=confirm_delivery) conn.connect(broker_url, options=options, callback=partial(self._on_connect, conn)) @@ -135,3 +147,50 @@ def _on_connect(self, connection): def connection(self): assert self._connection is not None return next(self._connection) + +class ConfirmDeliveryHandler(object): + + def __init__(self): + self._message_seq = 0 + self._acked = 0 + self._nacked = 0 + self._unknown_ack = 0 + self.coroutine_callbacks = {} + + def on_delivery_confirmation(self, method_frame): + """Invoked by pika when RabbitMQ responds to a Basic.Publish RPC + command, passing in either a Basic.Ack or Basic.Nack frame with + the delivery tag of the message that was published. The delivery tag + is an integer counter indicating the message number that was sent + on the channel via Basic.Publish. After Basic.Ack is received, it + will call corresponding callback based on delivery tag number. + + :param pika.frame.Method method_frame: Basic.Ack or Basic.Nack frame + + """ + confirmation_type = method_frame.method.NAME.split('.')[1].lower() + delivery_tag = method_frame.method.delivery_tag + message = ('Received %s for delivery tag: %i' % + (confirmation_type, + delivery_tag)) + LOGGER.debug(message) + + if confirmation_type == 'ack': + self._acked += 1 + elif confirmation_type == 'nack': + self._nacked += 1 + else: + self._unknown_ack += 1 + coroutine_callback = self.coroutine_callbacks.pop(delivery_tag) + if coroutine_callback: + coroutine_callback(None) + + def reset_message_seq(self): + self._message_seq = 0 + + def reset_coroutine_callbacks(self): + self.coroutine_callbacks.clear() + + def add_callback(self, callback): + self._message_seq += 1 + self.coroutine_callbacks[self._message_seq] = callback diff --git a/tcelery/handlers.py b/tcelery/handlers.py index 1458309..aba43f3 100644 --- a/tcelery/handlers.py +++ b/tcelery/handlers.py @@ -230,20 +230,24 @@ def post(self, taskname): partial(self.on_time, task_id)) task.apply_async(args=args, kwargs=kwargs, task_id=task_id, - callback=partial(self.on_complete, htimeout), + callback=partial(self.on_async_result, htimeout), **options) - def on_complete(self, htimeout, result): + def on_async_result(self, htimeout, async_result): + self._result = async_result + async_result.get(callback=partial(self.on_actual_result, htimeout)) + + def on_actual_result(self, htimeout, result): if self._finished: return if htimeout: ioloop.IOLoop.instance().remove_timeout(htimeout) - response = {'task-id': result.task_id, 'state': result.state} - if result.successful(): - response['result'] = result.result + response = {'task-id': self._result.task_id, 'state': self._result.state} + if self._result.successful(): + response['result'] = result else: - response['traceback'] = result.traceback - response['error'] = repr(result.result) + response['traceback'] = self._result.traceback + response['error'] = repr( self._result.result) self.write(response) self.finish() diff --git a/tcelery/producer.py b/tcelery/producer.py index af3b409..bfefcdb 100644 --- a/tcelery/producer.py +++ b/tcelery/producer.py @@ -1,13 +1,14 @@ from __future__ import absolute_import import sys -from functools import partial + from datetime import timedelta from kombu import serialization from kombu.utils import cached_property from celery.app.amqp import TaskProducer from celery.backends.amqp import AMQPBackend from celery.backends.redis import RedisBackend +from celery.backends.base import DisabledBackend from celery.utils import timeutils from .result import AsyncResult @@ -20,7 +21,6 @@ is_py3k = sys.version_info >= (3, 0) - class AMQPConsumer(object): def __init__(self, producer): self.producer = producer @@ -68,10 +68,6 @@ def publish(self, body, routing_key=None, delivery_mode=None, if callback and not callable(callback): raise ValueError('callback should be callable') - if callback and not isinstance(self.app.backend, - (AMQPBackend, RedisBackend)): - raise NotImplementedError( - 'callback can be used only with AMQP or Redis backends') body, content_type, content_encoding = self._prepare( body, serializer, content_type, content_encoding, @@ -94,10 +90,16 @@ def publish(self, body, routing_key=None, delivery_mode=None, exchange=exchange, declare=declare) if callback: - self.consumer.wait_for(task_id, - partial(self.on_result, task_id, callback), - expires=self.prepare_expires(type=int), - persistent=self.app.conf.CELERY_RESULT_PERSISTENT) + async_result = self.result_cls(task_id=task_id, + result=result, + producer=self) + if conn.confirm_delivery: + conn.confirm_delivery_handler.add_callback(lambda result: + callback(async_result)) + + else: + callback(async_result) + return result @cached_property @@ -117,12 +119,6 @@ def decode(self, payload): content_type=self.content_type, content_encoding=self.content_encoding) - def on_result(self, task_id, callback, reply): - reply = self.decode(reply) - reply['task_id'] = task_id - result = self.result_cls(**reply) - callback(result) - def prepare_expires(self, value=None, type=None): if value is None: value = self.app.conf.CELERY_TASK_RESULT_EXPIRES @@ -132,5 +128,11 @@ def prepare_expires(self, value=None, type=None): return type(value * 1000) return value + def fail_if_backend_not_supported(self): + if not isinstance(self.app.backend, + (AMQPBackend, RedisBackend, DisabledBackend)): + raise NotImplementedError( + 'result retrieval can be used only with AMQP or Redis backends') + def __repr__(self): return '<NonBlockingTaskProducer: {0.channel}>'.format(self) diff --git a/tcelery/result.py b/tcelery/result.py index d5626dd..6a9d7ab 100644 --- a/tcelery/result.py +++ b/tcelery/result.py @@ -1,16 +1,19 @@ from __future__ import absolute_import from __future__ import with_statement +from functools import partial + import celery class AsyncResult(celery.result.AsyncResult): - def __init__(self, task_id, status=None, traceback=None, result=None, - **kwargs): + def __init__(self, task_id, status=None, traceback=None, + result=None, producer=None, **kwargs): super(AsyncResult, self).__init__(task_id) self._status = status self._traceback = traceback self._result = result + self._producer = producer @property def status(self): @@ -27,3 +30,22 @@ def traceback(self): @property def result(self): return self._result or super(AsyncResult, self).result + + def get(self, callback=None): + self._producer.fail_if_backend_not_supported() + self._producer.consumer.wait_for(self.task_id, + partial(self.on_result, callback), + expires=self._producer.prepare_expires(type=int), + persistent=self._producer.app.conf.CELERY_RESULT_PERSISTENT) + + def on_result(self, callback, reply): + reply = self._producer.decode(reply) + self._status = reply.get('status') + self._traceback = reply.get('traceback') + self._result = reply.get('result') + if callback: + callback(self._result) + + def _get_task_meta(self): + self._producer.fail_if_backend_not_supported() + return super(AsyncResult, self)._get_task_meta() \ No newline at end of file