From 05bfcd996b20b771df7138b159ad45de6639483c Mon Sep 17 00:00:00 2001
From: jimhorng <jimhorng@qnap.com>
Date: Wed, 15 Oct 2014 18:59:02 +0800
Subject: [PATCH] get task result from AsyncResult.get(); add callbacks for
 sent task ack and sent task

Conflicts:
	tcelery/producer.py
---
 README.rst            | 41 +++++++++++++++++++++------
 tcelery/__init__.py   | 10 ++++++-
 tcelery/connection.py | 65 +++++++++++++++++++++++++++++++++++++++++--
 tcelery/handlers.py   | 18 +++++++-----
 tcelery/producer.py   | 34 +++++++++++-----------
 tcelery/result.py     | 26 +++++++++++++++--
 6 files changed, 156 insertions(+), 38 deletions(-)

diff --git a/README.rst b/README.rst
index a4354b8..f98e8d8 100644
--- a/README.rst
+++ b/README.rst
@@ -11,7 +11,7 @@ tornado-celery is a non-blocking Celery client for Tornado web framework
 Usage
 -----
 
-Calling Celery tasks from Tornado RequestHandler: ::
+Calling Celery tasks(has return value) from Tornado RequestHandler: ::
 
     from tornado import gen, web
     import tcelery, tasks
@@ -19,24 +19,47 @@ Calling Celery tasks from Tornado RequestHandler: ::
     tcelery.setup_nonblocking_producer()
 
     class AsyncHandler(web.RequestHandler):
-        @asynchronous
+        @web.asynchronous
         def get(self):
-            tasks.echo.apply_async(args=['Hello world!'], callback=self.on_result)
+            tasks.echo.apply_async(args=['Hello world!'], callback=self.on_async_result)
 
-        def on_result(self, response):
-            self.write(str(response.result))
+        def on_async_result(self, async_result):
+            async_result.get(callback=self.on_actual_result)
+        
+        def on_actual_result(self, result):
+            self.write(str(result))
             self.finish()
 
-Calling tasks with generator-based interface: ::
+with generator-based interface: ::
 
     class GenAsyncHandler(web.RequestHandler):
-        @asynchronous
+        @web.asynchronous
         @gen.coroutine
         def get(self):
-            response = yield gen.Task(tasks.sleep.apply_async, args=[3])
-            self.write(str(response.result))
+            async_result = yield gen.Task(tasks.sleep.apply_async, args=[3])
+            result = yield gen.Task(async_result.get)
+            self.write(str(result))
             self.finish()
 
+Calling Celery tasks(no return value) from Tornado RequestHandler: ::
+
+    @web.asynchronous
+    def get(self):
+        tasks.echo.apply_async(args=['Hello world!'], callback=self.on_async_result)
+
+    def on_async_result(self, async_result):
+        self.write("task sent") # ack-ed if BROKER_TRANSPORT_OPTIONS: {'confirm_publish': True}
+        self.finish()
+
+with generator-based interface: ::
+
+    @web.asynchronous
+    @gen.coroutine
+    def get(self):
+        yield gen.Task(tasks.sleep.apply_async, args=[3])
+        self.write("task sent") # ack-ed if BROKER_TRANSPORT_OPTIONS: {'confirm_publish': True}
+        self.finish()
+
 **NOTE:** Currently callbacks only work with AMQP and Redis backends.
 To use the Redis backend, you must install `tornado-redis
 <https://github.com/leporo/tornado-redis>`_.
diff --git a/tcelery/__init__.py b/tcelery/__init__.py
index 099e2bf..f07ba04 100644
--- a/tcelery/__init__.py
+++ b/tcelery/__init__.py
@@ -29,6 +29,14 @@ def connect():
         options = celery_app.conf.get('CELERYT_PIKA_OPTIONS', {})
         producer_cls.conn_pool.connect(broker_url,
                                        options=options,
-                                       callback=on_ready)
+                                       callback=on_ready,
+                                       confirm_delivery=_get_confirm_publish_conf(celery_app.conf))
 
     io_loop.add_callback(connect)
+
+def _get_confirm_publish_conf(conf):
+    broker_transport_options = conf.get('BROKER_TRANSPORT_OPTIONS', {})
+    if (broker_transport_options and
+        broker_transport_options.get('confirm_publish') is True):
+        return True
+    return False
\ No newline at end of file
diff --git a/tcelery/connection.py b/tcelery/connection.py
index e2aacb5..ec16ebc 100644
--- a/tcelery/connection.py
+++ b/tcelery/connection.py
@@ -16,16 +16,20 @@
 
 from tornado import ioloop
 
+LOGGER = logging.getLogger(__name__)
 
 class Connection(object):
 
     content_type = 'application/x-python-serialize'
 
-    def __init__(self, io_loop=None):
+    def __init__(self, io_loop=None, confirm_delivery=False):
         self.channel = None
         self.connection = None
         self.url = None
         self.io_loop = io_loop or ioloop.IOLoop.instance()
+        self.confirm_delivery = confirm_delivery
+        if self.confirm_delivery:
+            self.confirm_delivery_handler = ConfirmDeliveryHandler() 
 
     def connect(self, url=None, options=None, callback=None):
         if url is not None:
@@ -61,9 +65,17 @@ def on_connect(self, callback, connection):
 
     def on_channel_open(self, callback, channel):
         self.channel = channel
+        if self.confirm_delivery:
+            self.init_confirm_delivery()
         if callback:
             callback()
 
+    def init_confirm_delivery(self):
+        self.channel.confirm_delivery(callback=self.confirm_delivery_handler.on_delivery_confirmation,
+                                      nowait=True)
+        self.confirm_delivery_handler.reset_message_seq()
+        self.confirm_delivery_handler.reset_coroutine_callbacks()
+
     def on_exchange_declare(self, frame):
         pass
 
@@ -118,10 +130,10 @@ def __init__(self, limit, io_loop=None):
         self._connection = None
         self.io_loop = io_loop
 
-    def connect(self, broker_url, options=None, callback=None):
+    def connect(self, broker_url, options=None, callback=None, confirm_delivery=False):
         self._on_ready = callback
         for _ in range(self._limit):
-            conn = Connection(io_loop=self.io_loop)
+            conn = Connection(io_loop=self.io_loop, confirm_delivery=confirm_delivery)
             conn.connect(broker_url, options=options,
                          callback=partial(self._on_connect, conn))
 
@@ -135,3 +147,50 @@ def _on_connect(self, connection):
     def connection(self):
         assert self._connection is not None
         return next(self._connection)
+
+class ConfirmDeliveryHandler(object):
+    
+    def __init__(self):
+        self._message_seq = 0
+        self._acked = 0
+        self._nacked = 0
+        self._unknown_ack = 0
+        self.coroutine_callbacks = {}
+    
+    def on_delivery_confirmation(self, method_frame):
+        """Invoked by pika when RabbitMQ responds to a Basic.Publish RPC
+        command, passing in either a Basic.Ack or Basic.Nack frame with
+        the delivery tag of the message that was published. The delivery tag
+        is an integer counter indicating the message number that was sent
+        on the channel via Basic.Publish. After Basic.Ack is received, it
+        will call corresponding callback based on delivery tag number.
+
+        :param pika.frame.Method method_frame: Basic.Ack or Basic.Nack frame
+
+        """
+        confirmation_type = method_frame.method.NAME.split('.')[1].lower()
+        delivery_tag = method_frame.method.delivery_tag
+        message = ('Received %s for delivery tag: %i' %
+                   (confirmation_type,
+                    delivery_tag))
+        LOGGER.debug(message)
+        
+        if confirmation_type == 'ack':
+            self._acked += 1
+        elif confirmation_type == 'nack':
+            self._nacked += 1
+        else:
+            self._unknown_ack += 1
+        coroutine_callback = self.coroutine_callbacks.pop(delivery_tag)
+        if coroutine_callback:
+            coroutine_callback(None)
+    
+    def reset_message_seq(self):
+        self._message_seq = 0
+        
+    def reset_coroutine_callbacks(self):
+        self.coroutine_callbacks.clear()
+    
+    def add_callback(self, callback):
+        self._message_seq += 1
+        self.coroutine_callbacks[self._message_seq] = callback
diff --git a/tcelery/handlers.py b/tcelery/handlers.py
index 1458309..aba43f3 100644
--- a/tcelery/handlers.py
+++ b/tcelery/handlers.py
@@ -230,20 +230,24 @@ def post(self, taskname):
                 partial(self.on_time, task_id))
 
         task.apply_async(args=args, kwargs=kwargs, task_id=task_id,
-                         callback=partial(self.on_complete, htimeout),
+                         callback=partial(self.on_async_result, htimeout),
                          **options)
 
-    def on_complete(self, htimeout, result):
+    def on_async_result(self, htimeout, async_result):
+        self._result = async_result
+        async_result.get(callback=partial(self.on_actual_result, htimeout))
+
+    def on_actual_result(self, htimeout, result):
         if self._finished:
             return
         if htimeout:
             ioloop.IOLoop.instance().remove_timeout(htimeout)
-        response = {'task-id': result.task_id, 'state': result.state}
-        if result.successful():
-            response['result'] = result.result
+        response = {'task-id': self._result.task_id, 'state': self._result.state}
+        if self._result.successful():
+            response['result'] = result
         else:
-            response['traceback'] = result.traceback
-            response['error'] = repr(result.result)
+            response['traceback'] = self._result.traceback
+            response['error'] = repr( self._result.result)
         self.write(response)
         self.finish()
 
diff --git a/tcelery/producer.py b/tcelery/producer.py
index af3b409..bfefcdb 100644
--- a/tcelery/producer.py
+++ b/tcelery/producer.py
@@ -1,13 +1,14 @@
 from __future__ import absolute_import
 
 import sys
-from functools import partial
+
 from datetime import timedelta
 from kombu import serialization
 from kombu.utils import cached_property
 from celery.app.amqp import TaskProducer
 from celery.backends.amqp import AMQPBackend
 from celery.backends.redis import RedisBackend
+from celery.backends.base import DisabledBackend
 from celery.utils import timeutils
 
 from .result import AsyncResult
@@ -20,7 +21,6 @@
 
 is_py3k = sys.version_info >= (3, 0)
 
-
 class AMQPConsumer(object):
     def __init__(self, producer):
         self.producer = producer
@@ -68,10 +68,6 @@ def publish(self, body, routing_key=None, delivery_mode=None,
 
         if callback and not callable(callback):
             raise ValueError('callback should be callable')
-        if callback and not isinstance(self.app.backend,
-                                       (AMQPBackend, RedisBackend)):
-            raise NotImplementedError(
-                'callback can be used only with AMQP or Redis backends')
 
         body, content_type, content_encoding = self._prepare(
             body, serializer, content_type, content_encoding,
@@ -94,10 +90,16 @@ def publish(self, body, routing_key=None, delivery_mode=None,
                          exchange=exchange, declare=declare)
 
         if callback:
-            self.consumer.wait_for(task_id,
-                                   partial(self.on_result, task_id, callback),
-                                   expires=self.prepare_expires(type=int),
-                                   persistent=self.app.conf.CELERY_RESULT_PERSISTENT)
+            async_result = self.result_cls(task_id=task_id,
+                                           result=result,
+                                           producer=self)
+            if conn.confirm_delivery:
+                conn.confirm_delivery_handler.add_callback(lambda result:
+                                                           callback(async_result))
+
+            else:
+                callback(async_result)
+
         return result
 
     @cached_property
@@ -117,12 +119,6 @@ def decode(self, payload):
                                     content_type=self.content_type,
                                     content_encoding=self.content_encoding)
 
-    def on_result(self, task_id, callback, reply):
-        reply = self.decode(reply)
-        reply['task_id'] = task_id
-        result = self.result_cls(**reply)
-        callback(result)
-
     def prepare_expires(self, value=None, type=None):
         if value is None:
             value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
@@ -132,5 +128,11 @@ def prepare_expires(self, value=None, type=None):
             return type(value * 1000)
         return value
 
+    def fail_if_backend_not_supported(self):
+        if not isinstance(self.app.backend,
+                          (AMQPBackend, RedisBackend, DisabledBackend)):
+            raise NotImplementedError(
+                'result retrieval can be used only with AMQP or Redis backends')
+
     def __repr__(self):
         return '<NonBlockingTaskProducer: {0.channel}>'.format(self)
diff --git a/tcelery/result.py b/tcelery/result.py
index d5626dd..6a9d7ab 100644
--- a/tcelery/result.py
+++ b/tcelery/result.py
@@ -1,16 +1,19 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+from functools import partial
+
 import celery
 
 
 class AsyncResult(celery.result.AsyncResult):
-    def __init__(self, task_id, status=None, traceback=None, result=None,
-                 **kwargs):
+    def __init__(self, task_id, status=None, traceback=None,
+                 result=None, producer=None, **kwargs):
         super(AsyncResult, self).__init__(task_id)
         self._status = status
         self._traceback = traceback
         self._result = result
+        self._producer = producer
 
     @property
     def status(self):
@@ -27,3 +30,22 @@ def traceback(self):
     @property
     def result(self):
         return self._result or super(AsyncResult, self).result
+
+    def get(self, callback=None):
+        self._producer.fail_if_backend_not_supported()
+        self._producer.consumer.wait_for(self.task_id,
+                                         partial(self.on_result, callback),
+                                         expires=self._producer.prepare_expires(type=int),
+                                         persistent=self._producer.app.conf.CELERY_RESULT_PERSISTENT)
+
+    def on_result(self, callback, reply):
+        reply = self._producer.decode(reply)
+        self._status = reply.get('status')
+        self._traceback = reply.get('traceback')
+        self._result = reply.get('result')
+        if callback:
+            callback(self._result)
+
+    def _get_task_meta(self):
+        self._producer.fail_if_backend_not_supported()
+        return super(AsyncResult, self)._get_task_meta()
\ No newline at end of file