Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added stop_polling & changed way polling is run for (graceful) shutdown #2392

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 43 additions & 64 deletions telebot/async_telebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[
util.validate_token(self.token)

self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change, unspecified

self.__polling: Optional[asyncio.Event] = None
self._stop_event = asyncio.Event()

self._update_tasks_set = set()


@property
Expand Down Expand Up @@ -317,72 +322,34 @@ async def polling(self, non_stop: bool=True, skip_pending=False, interval: int=0
await self.skip_updates()

if restart_on_change:
self._setup_change_detector(path_to_watch)
self._setup_change_detector(path_to_watch)

tasks = [] # only polling & event task
# we will stop polling when either of these two fail/complete:
# 1. polling task: due to exception, etc
# 2. stop_event: due to stop_polling call which causes _stop_event.set()
tasks.append(asyncio.create_task(self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates)))
tasks.append(asyncio.create_task(self._stop_event.wait()))
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

pending = pending.union(self._update_tasks_set)
for task in pending:
try:
task.cancel()
except asyncio.CancelledError: # handled just in case, not necessary
await task # cleanup

await self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates)
await asyncio.gather(*done, return_exceptions=True)

async def infinity_polling(self, timeout: Optional[int]=20, skip_pending: Optional[bool]=False, request_timeout: Optional[int]=None,
logger_level: Optional[int]=logging.ERROR, allowed_updates: Optional[List[str]]=None,
restart_on_change: Optional[bool]=False, path_to_watch: Optional[str]=None, *args, **kwargs):
"""
Wrap polling with infinite loop and exception handling to avoid bot stops polling.

.. note::
Install watchdog and psutil before using restart_on_change option.

:param timeout: Timeout in seconds for get_updates(Defaults to None)
:type timeout: :obj:`int`

:param skip_pending: skip old updates
:type skip_pending: :obj:`bool`

:param request_timeout: Aiohttp's request timeout. Defaults to 5 minutes(aiohttp.ClientTimeout).
:type request_timeout: :obj:`int`

:param logger_level: Custom logging level for infinity_polling logging.
Use logger levels from logging as a value. None/NOTSET = no error logging
:type logger_level: :obj:`int`

:param allowed_updates: A list of the update types you want your bot to receive.
For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types.
See util.update_types for a complete list of available update types.
Specify an empty list to receive all update types except chat_member (default).
If not specified, the previous setting will be used.

Please note that this parameter doesn't affect updates created before the call to the get_updates,
so unwanted updates may be received for a short period of time.
:type allowed_updates: :obj:`list` of :obj:`str`

:param restart_on_change: Restart a file on file(s) change. Defaults to False
:type restart_on_change: :obj:`bool`

:param path_to_watch: Path to watch for changes. Defaults to current directory
:type path_to_watch: :obj:`str`

:return: None
Deprecated. Use polling instead.
"""
if skip_pending:
await self.skip_updates()
self._polling = True

if restart_on_change:
self._setup_change_detector(path_to_watch)

while self._polling:
try:
await self._process_polling(non_stop=True, timeout=timeout, request_timeout=request_timeout,
allowed_updates=allowed_updates, *args, **kwargs)
except Exception as e:
if logger_level and logger_level >= logging.ERROR:
logger.error("Infinity polling exception: %s", self.__hide_token(str(e)))
if logger_level and logger_level >= logging.DEBUG:
logger.error("Exception traceback:\n%s", self.__hide_token(traceback.format_exc()))
await asyncio.sleep(3)
continue
if logger_level and logger_level >= logging.INFO:
logger.error("Infinity polling: polling exited")
if logger_level and logger_level >= logging.INFO:
logger.error("Break infinity polling")
# logger_level is useless & not used;
await self.polling(non_stop=True, skip_pending=skip_pending, interval=0, timeout=timeout, request_timeout=request_timeout,
allowed_updates=allowed_updates, restart_on_change=restart_on_change, path_to_watch=path_to_watch, *args, **kwargs)

async def _handle_exception(self, exception: Exception) -> bool:
if self.exception_handler is None:
Expand Down Expand Up @@ -410,6 +377,13 @@ async def _handle_error_interval(self, error_interval: float):
error_interval = 60
return error_interval

async def stop_polling(self):
"""
Stop polling.
"""
self._stop_event.set()
self.__polling.clear()

async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20,
request_timeout: int=None, allowed_updates: Optional[List[str]]=None):
"""
Expand Down Expand Up @@ -440,12 +414,13 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:

logger.info('Starting your bot with username: [@%s]', self.user.username)

self._polling = True

error_interval = 0.25

self.__polling = asyncio.Event()
self.__polling.set()

try:
while self._polling:
while self.__polling.is_set():
try:
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
if updates:
Expand Down Expand Up @@ -499,7 +474,7 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:
else:
break
finally:
self._polling = False
self.__polling.clear() # clear polling event
await self.close_session()
logger.warning('Polling is stopped.')

Expand All @@ -518,7 +493,11 @@ async def _process_updates(self, handlers, messages, update_type):
tasks = []
middlewares = await self._get_middlewares(update_type)
for message in messages:
tasks.append(self._run_middlewares_and_handlers(message, handlers, middlewares, update_type))
task = asyncio.create_task(self._run_middlewares_and_handlers(message, handlers, middlewares, update_type))
tasks.append(task)
task.add_done_callback(self._update_tasks_set.discard)
self._update_tasks_set.add(task)

await asyncio.gather(*tasks)

async def _run_middlewares_and_handlers(self, message, handlers, middlewares, update_type):
Expand Down
Loading