diff --git a/slack_sdk/socket_mode/aiohttp/__init__.py b/slack_sdk/socket_mode/aiohttp/__init__.py
index 19ae7af0..ed379976 100644
--- a/slack_sdk/socket_mode/aiohttp/__init__.py
+++ b/slack_sdk/socket_mode/aiohttp/__init__.py
@@ -342,48 +342,74 @@ async def session_id(self) -> str:
         return self.build_session_id(self.current_session)
 
     async def connect(self):
-        old_session: Optional[ClientWebSocketResponse] = None if self.current_session is None else self.current_session
-        if self.wss_uri is None:
-            # If the underlying WSS URL does not exist,
-            # acquiring a new active WSS URL from the server-side first
-            self.wss_uri = await self.issue_new_wss_url()
-
-        self.current_session = await self.aiohttp_client_session.ws_connect(
-            self.wss_uri,
-            autoping=False,
-            heartbeat=self.ping_interval,
-            proxy=self.proxy,
-            ssl=self.web_client.ssl,
-        )
-        session_id: str = await self.session_id()
-        self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
-        self.stale = False
-        self.logger.info(f"A new session ({session_id}) has been established")
-
-        # The first ping from the new connection
-        if self.logger.level <= logging.DEBUG:
-            self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
-        t = time.time()
-        await self.current_session.ping(f"sdk-ping-pong:{t}")
-
-        if self.current_session_monitor is not None:
-            self.current_session_monitor.cancel()
-
-        self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
-        if self.logger.level <= logging.DEBUG:
-            self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")
-
-        if self.message_receiver is not None:
-            self.message_receiver.cancel()
-
-        self.message_receiver = asyncio.ensure_future(self.receive_messages())
-        if self.logger.level <= logging.DEBUG:
-            self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
+        """Establish a new session, retrying until it succeeds.
+
+        Closes the previous session first (if any), then opens a new WebSocket
+        session and restarts the session monitor and the message receiver.
+        Every failure is logged and retried after `ping_interval` seconds.
+        """
+        # This loop is used to ensure when a new session is created,
+        # a new monitor and a new message receiver are also created.
+        # If a new session is created but we failed to create the new
+        # monitor or the new message receiver, we should retry.
+        while True:
+            try:
+                old_session: Optional[ClientWebSocketResponse] = (
+                    None if self.current_session is None else self.current_session
+                )
 
-        if old_session is not None:
-            await old_session.close()
-            old_session_id = self.build_session_id(old_session)
-            self.logger.info(f"The old session ({old_session_id}) has been abandoned")
+                # If the old session is broken (e.g. reset by peer), closing it might fail.
+                # We don't want to retry when this kind of case happens.
+                try:
+                    # We should close the old session before creating a new one. When the
+                    # disconnect reason is `too_many_websockets`, the old one must be closed
+                    # first to decrease the number of connections.
+                    self.auto_reconnect_enabled = False
+                    if old_session is not None:
+                        await old_session.close()
+                        old_session_id = self.build_session_id(old_session)
+                        self.logger.info(f"The old session ({old_session_id}) has been abandoned")
+                except Exception as e:
+                    self.logger.exception(f"Failed to close the old session : {e}")
+
+                if self.wss_uri is None:
+                    # If the underlying WSS URL does not exist,
+                    # acquiring a new active WSS URL from the server-side first
+                    self.wss_uri = await self.issue_new_wss_url()
+
+                self.current_session = await self.aiohttp_client_session.ws_connect(
+                    self.wss_uri,
+                    autoping=False,
+                    heartbeat=self.ping_interval,
+                    proxy=self.proxy,
+                    ssl=self.web_client.ssl,
+                )
+                session_id: str = await self.session_id()
+                self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
+                self.stale = False
+                self.logger.info(f"A new session ({session_id}) has been established")
+
+                # The first ping from the new connection
+                if self.logger.level <= logging.DEBUG:
+                    self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
+                t = time.time()
+                await self.current_session.ping(f"sdk-ping-pong:{t}")
+
+                if self.current_session_monitor is not None:
+                    self.current_session_monitor.cancel()
+                self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
+                if self.logger.level <= logging.DEBUG:
+                    self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")
+
+                if self.message_receiver is not None:
+                    self.message_receiver.cancel()
+                self.message_receiver = asyncio.ensure_future(self.receive_messages())
+                if self.logger.level <= logging.DEBUG:
+                    self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
+                break
+            except Exception as e:
+                self.logger.exception(f"Failed to connect (error: {e}); Retrying...")
+                await asyncio.sleep(self.ping_interval)
 
     async def disconnect(self):
         if self.current_session is not None:
diff --git a/tests/slack_sdk/socket_mode/mock_socket_mode_server.py b/tests/slack_sdk/socket_mode/mock_socket_mode_server.py
index 7dba25e5..02f9b84a 100644
--- a/tests/slack_sdk/socket_mode/mock_socket_mode_server.py
+++ b/tests/slack_sdk/socket_mode/mock_socket_mode_server.py
@@ -1,6 +1,7 @@
 import asyncio
 import logging
 import os
+import time
 
 from aiohttp import WSMsgType, web
 
@@ -24,6 +25,8 @@
 
 socket_mode_hello_message = """{"type":"hello","num_connections":2,"debug_info":{"host":"applink-111-xxx","build_number":10,"approximate_connection_time":18060},"connection_info":{"app_id":"A111"}}"""
 
+socket_mode_disconnect_message = """{"type":"disconnect","reason":"too_many_websockets","num_connections":2,"debug_info":{"host":"applink-111-xxx"},"connection_info":{"app_id":"A111"}}"""
+
 
 def start_socket_mode_server(self, port: int):
     logger = logging.getLogger(__name__)
@@ -82,3 +85,83 @@ def run_server():
         loop.close()
 
     return run_server
+
+
+def start_socket_mode_server_with_disconnection(self, port: int):
+    """Return a thread target that runs a mock Socket Mode server.
+
+    The server sends `hello`, then a one-time `disconnect` message. Sessions that
+    did not receive the `disconnect` are fed the queued envelopes. Every text
+    message is echoed back to the sender.
+    """
+    logger = logging.getLogger(__name__)
+    state = {}
+
+    def reset_server_state():
+        state.update(
+            hello_sent=False,
+            disconnect_sent=False,
+            envelopes_to_consume=list(socket_mode_envelopes),
+        )
+
+    self.reset_server_state = reset_server_state
+
+    async def link(request):
+        disconnected = False
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async for msg in ws:
+            # To ensure disconnect message is received and handled,
+            # need to keep this ws alive to bypass client ping-pong check.
+            if msg.type == WSMsgType.PING:
+                t = time.time()
+                await ws.pong(f"sdk-ping-pong:{t}")
+                continue
+            if msg.type != WSMsgType.TEXT:
+                continue
+
+            message = msg.data
+            logger.debug(f"Server received a message: {message}")
+
+            if not state["hello_sent"]:
+                state["hello_sent"] = True
+                await ws.send_str(socket_mode_hello_message)
+
+            if not state["disconnect_sent"]:
+                state["hello_sent"] = False
+                state["disconnect_sent"] = True
+                disconnected = True
+                await ws.send_str(socket_mode_disconnect_message)
+                logger.debug("Disconnect message sent")
+
+            if state["envelopes_to_consume"] and not disconnected:
+                e = state["envelopes_to_consume"].pop(0)
+                logger.debug(f"Send an envelope: {e}")
+                await ws.send_str(e)
+
+            await ws.send_str(message)
+
+        return ws
+
+    app = web.Application()
+    app.add_routes([web.get("/link", link)])
+    runner = web.AppRunner(app)
+
+    def run_server():
+        reset_server_state()
+
+        self.loop = loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(runner.setup())
+        site = web.TCPSite(runner, "127.0.0.1", port, reuse_port=True)
+        loop.run_until_complete(site.start())
+
+        # run until it's stopped from the main thread
+        loop.run_forever()
+
+        loop.run_until_complete(runner.cleanup())
+        loop.run_until_complete(asyncio.sleep(1))
+        loop.close()
+
+    return run_server
diff --git a/tests/slack_sdk/socket_mode/mock_web_api_server.py b/tests/slack_sdk/socket_mode/mock_web_api_server.py
index b0fffef8..f37529ec 100644
--- a/tests/slack_sdk/socket_mode/mock_web_api_server.py
+++ b/tests/slack_sdk/socket_mode/mock_web_api_server.py
@@ -89,7 +89,7 @@ def _handle(self):
             if self.path == "/apps.connections.open":
                 body = {
                     "ok": True,
-                    "url": "wss://test-server/link/?ticket=xxx&app_id=yyy",
+                    "url": "ws://127.0.0.1:3001/link",
                 }
             if self.path == "/api.test" and request_body:
                 body = {"ok": True, "args": request_body}
diff --git a/tests/slack_sdk/socket_mode/test_builtin.py b/tests/slack_sdk/socket_mode/test_builtin.py
index 66777208..db4ef9bd 100644
--- a/tests/slack_sdk/socket_mode/test_builtin.py
+++ b/tests/slack_sdk/socket_mode/test_builtin.py
@@ -56,7 +56,7 @@ def test_issue_new_wss_url(self):
             web_client=self.web_client,
         )
         url = client.issue_new_wss_url()
-        self.assertTrue(url.startswith("wss://"))
+        self.assertTrue(url.startswith("ws://"))
 
         legacy_client = LegacyWebClient(token="xoxb-api_test", base_url="http://localhost:8888")
         response = legacy_client.apps_connections_open(app_token="xapp-A111-222-xyz")
diff --git a/tests/slack_sdk_async/socket_mode/test_aiohttp.py b/tests/slack_sdk_async/socket_mode/test_aiohttp.py
index fdbda0d4..1d2a7967 100644
--- a/tests/slack_sdk_async/socket_mode/test_aiohttp.py
+++ b/tests/slack_sdk_async/socket_mode/test_aiohttp.py
@@ -42,7 +42,7 @@ async def test_issue_new_wss_url(self):
         )
         try:
             url = await client.issue_new_wss_url()
-            self.assertTrue(url.startswith("wss://"))
+            self.assertTrue(url.startswith("ws://"))
         finally:
             await client.close()
 
diff --git a/tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py b/tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py
index 74f30d04..8bb0a962 100644
--- a/tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py
+++ b/tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py
@@ -17,8 +17,10 @@
 from tests.helpers import is_ci_unstable_test_skip_enabled
 from tests.slack_sdk.socket_mode.mock_socket_mode_server import (
     start_socket_mode_server,
+    start_socket_mode_server_with_disconnection,
     socket_mode_envelopes,
     socket_mode_hello_message,
+    socket_mode_disconnect_message,
 )
 from tests.slack_sdk.socket_mode.mock_web_api_server import (
     setup_mock_web_api_server,
@@ -104,6 +106,95 @@ async def socket_mode_listener(
             self.loop.stop()
             t.join(timeout=5)
 
+    @async_test
+    async def test_interactions_with_disconnection(self):
+        """Messages and envelopes reach the new session after a server-sent disconnect."""
+        if is_ci_unstable_test_skip_enabled():
+            return
+        t = Thread(target=start_socket_mode_server_with_disconnection(self, 3001))
+        t.daemon = True
+        t.start()
+
+        self.disconnected = False
+        received_messages = []
+        received_socket_mode_requests = []
+
+        async def message_handler(message: WSMessage):
+            session_id = client.build_session_id(client.current_session)
+            if "wait_for_disconnect" in message.data:
+                return
+            self.logger.info(f"Raw Message: {message}")
+            await asyncio.sleep(randint(50, 200) / 1000)
+            self.disconnected = "disconnect" in message.data
+            received_messages.append(message.data + "_" + session_id)
+
+        async def socket_mode_listener(
+            self: AsyncBaseSocketModeClient,
+            request: SocketModeRequest,
+        ):
+            self.logger.info(f"Socket Mode Request: {request.payload}")
+            await asyncio.sleep(randint(50, 200) / 1000)
+            received_socket_mode_requests.append(request.payload)
+
+        client = SocketModeClient(
+            app_token="xapp-A111-222-xyz",
+            web_client=self.web_client,
+            on_message_listeners=[message_handler],
+            auto_reconnect_enabled=True,
+            trace_enabled=True,
+        )
+        client.socket_mode_request_listeners.append(socket_mode_listener)
+
+        try:
+            time.sleep(1)  # wait for the server
+            client.wss_uri = "ws://127.0.0.1:3001/link"
+            await client.connect()
+            await asyncio.sleep(1)  # wait for the message receiver
+
+            # Because we want to check the expected messages of the new session,
+            # we need to ensure we send messages only after the disconnection.
+            count = 0
+            while not self.disconnected and count < 10:
+                try:
+                    await client.send_message("wait_for_disconnect")
+                except Exception as e:
+                    self.logger.exception(e)
+                finally:
+                    await asyncio.sleep(1)
+                    count += 1
+            await asyncio.sleep(10)
+            expected_session_id = client.build_session_id(client.current_session)
+
+            for _ in range(10):
+                await client.send_message("foo")
+                await client.send_message("bar")
+                await client.send_message("baz")
+
+            expected = socket_mode_envelopes + [socket_mode_hello_message] + ["foo", "bar", "baz"] * 10
+            expected.sort()
+
+            count = 0
+            while count < 10 and (
+                len([msg for msg in received_messages if expected_session_id in msg]) < len(expected)
+                or len(received_socket_mode_requests) < len(socket_mode_envelopes)
+            ):
+                await asyncio.sleep(0.2)
+                count += 0.2
+
+            received_messages.sort()
+
+            # Only check messages of current alive session. Ignore the disconnected session.
+            received_messages = [msg for msg in received_messages if expected_session_id in msg]
+            expected = [msg + "_" + expected_session_id for msg in expected]
+
+            self.assertEqual(received_messages, expected)
+
+            self.assertEqual(len(socket_mode_envelopes), len(received_socket_mode_requests))
+        finally:
+            await client.close()
+            self.loop.stop()
+            t.join(timeout=5)
+
     @async_test
     async def test_send_message_while_disconnection(self):
         if is_ci_unstable_test_skip_enabled():
diff --git a/tests/slack_sdk_async/socket_mode/test_websockets.py b/tests/slack_sdk_async/socket_mode/test_websockets.py
index 14bd6f19..4fc50b34 100644
--- a/tests/slack_sdk_async/socket_mode/test_websockets.py
+++ b/tests/slack_sdk_async/socket_mode/test_websockets.py
@@ -36,7 +36,7 @@ async def test_issue_new_wss_url(self):
         )
         try:
             url = await client.issue_new_wss_url()
-            self.assertTrue(url.startswith("wss://"))
+            self.assertTrue(url.startswith("ws://"))
         finally:
             await client.close()
 