Skip to content

Commit

Permalink
fix #1446: monitor dies when exceptions raised before monitor created. (
Browse files Browse the repository at this point in the history
#1447)

---------
Co-authored-by: Kazuhiro Sera <[email protected]>
  • Loading branch information
woolen-sheep authored Jan 9, 2024
1 parent df47c32 commit b6b9527
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 45 deletions.
102 changes: 61 additions & 41 deletions slack_sdk/socket_mode/aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,48 +342,68 @@ async def session_id(self) -> str:
return self.build_session_id(self.current_session)

async def connect(self):
old_session: Optional[ClientWebSocketResponse] = None if self.current_session is None else self.current_session
if self.wss_uri is None:
# If the underlying WSS URL does not exist,
# acquiring a new active WSS URL from the server-side first
self.wss_uri = await self.issue_new_wss_url()

self.current_session = await self.aiohttp_client_session.ws_connect(
self.wss_uri,
autoping=False,
heartbeat=self.ping_interval,
proxy=self.proxy,
ssl=self.web_client.ssl,
)
session_id: str = await self.session_id()
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.stale = False
self.logger.info(f"A new session ({session_id}) has been established")

# The first ping from the new connection
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
t = time.time()
await self.current_session.ping(f"sdk-ping-pong:{t}")

if self.current_session_monitor is not None:
self.current_session_monitor.cancel()

self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")

if self.message_receiver is not None:
self.message_receiver.cancel()

self.message_receiver = asyncio.ensure_future(self.receive_messages())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
# This loop is used to ensure when a new session is created,
# a new monitor and a new message receiver are also created.
# If a new session is created but we failed to create the new
# monitor or the new message, we should try it.
while True:
try:
old_session: Optional[ClientWebSocketResponse] = (
None if self.current_session is None else self.current_session
)

if old_session is not None:
await old_session.close()
old_session_id = self.build_session_id(old_session)
self.logger.info(f"The old session ({old_session_id}) has been abandoned")
# If the old session is broken (e.g. reset by peer), it might fail to close it.
# We don't want to retry when this kind of cases happen.
try:
# We should close old session before create a new one. Because when disconnect
# reason is `too_many_websockets`, we need to close the old one first to
# to decrease the number of connections.
self.auto_reconnect_enabled = False
if old_session is not None:
await old_session.close()
old_session_id = self.build_session_id(old_session)
self.logger.info(f"The old session ({old_session_id}) has been abandoned")
except Exception as e:
self.logger.exception(f"Failed to close the old session : {e}")

if self.wss_uri is None:
# If the underlying WSS URL does not exist,
# acquiring a new active WSS URL from the server-side first
self.wss_uri = await self.issue_new_wss_url()

self.current_session = await self.aiohttp_client_session.ws_connect(
self.wss_uri,
autoping=False,
heartbeat=self.ping_interval,
proxy=self.proxy,
ssl=self.web_client.ssl,
)
session_id: str = await self.session_id()
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.stale = False
self.logger.info(f"A new session ({session_id}) has been established")

# The first ping from the new connection
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
t = time.time()
await self.current_session.ping(f"sdk-ping-pong:{t}")

if self.current_session_monitor is not None:
self.current_session_monitor.cancel()
self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")

if self.message_receiver is not None:
self.message_receiver.cancel()
self.message_receiver = asyncio.ensure_future(self.receive_messages())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
break
except Exception as e:
self.logger.exception(f"Failed to connect (error: {e}); Retrying...")
await asyncio.sleep(self.ping_interval)

async def disconnect(self):
if self.current_session is not None:
Expand Down
77 changes: 77 additions & 0 deletions tests/slack_sdk/socket_mode/mock_socket_mode_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import os
import time

from aiohttp import WSMsgType, web

Expand All @@ -24,6 +25,8 @@

socket_mode_hello_message = """{"type":"hello","num_connections":2,"debug_info":{"host":"applink-111-xxx","build_number":10,"approximate_connection_time":18060},"connection_info":{"app_id":"A111"}}"""

socket_mode_disconnect_message = """{"type":"disconnect","reason":"too_many_websockets","num_connections":2,"debug_info":{"host":"applink-111-xxx"},"connection_info":{"app_id":"A111"}}"""


def start_socket_mode_server(self, port: int):
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,3 +85,77 @@ def run_server():
loop.close()

return run_server


def start_socket_mode_server_with_disconnection(self, port: int):
logger = logging.getLogger(__name__)
state = {}

def reset_server_state():
state.update(
hello_sent=False,
disconnect_sent=False,
envelopes_to_consume=list(socket_mode_envelopes),
)

self.reset_server_state = reset_server_state

async def link(request):
disconnected = False
ws = web.WebSocketResponse()
await ws.prepare(request)

async for msg in ws:
# To ensure disconnect message is received and handled,
# need to keep this ws alive to bypass client ping-pong check.
if msg.type == WSMsgType.PING:
t = time.time()
await ws.pong(f"sdk-ping-pong:{t}")
continue
if msg.type != WSMsgType.TEXT:
continue

message = msg.data
logger.debug(f"Server received a message: {message}")

if not state["hello_sent"]:
state["hello_sent"] = True
await ws.send_str(socket_mode_hello_message)

if not state["disconnect_sent"]:
state["hello_sent"] = False
state["disconnect_sent"] = True
disconnected = True
await ws.send_str(socket_mode_disconnect_message)
logger.debug(f"Disconnect message sent")

if state["envelopes_to_consume"] and not disconnected:
e = state["envelopes_to_consume"].pop(0)
logger.debug(f"Send an envelope: {e}")
await ws.send_str(e)

await ws.send_str(message)

return ws

app = web.Application()
app.add_routes([web.get("/link", link)])
runner = web.AppRunner(app)

def run_server():
reset_server_state()

self.loop = loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, "127.0.0.1", port, reuse_port=True)
loop.run_until_complete(site.start())

# run until it's stopped from the main thread
loop.run_forever()

loop.run_until_complete(runner.cleanup())
loop.run_until_complete(asyncio.sleep(1))
loop.close()

return run_server
2 changes: 1 addition & 1 deletion tests/slack_sdk/socket_mode/mock_web_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _handle(self):
if self.path == "/apps.connections.open":
body = {
"ok": True,
"url": "wss://test-server/link/?ticket=xxx&app_id=yyy",
"url": "ws://0.0.0.0:3001/link",
}
if self.path == "/api.test" and request_body:
body = {"ok": True, "args": request_body}
Expand Down
2 changes: 1 addition & 1 deletion tests/slack_sdk/socket_mode/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_issue_new_wss_url(self):
web_client=self.web_client,
)
url = client.issue_new_wss_url()
self.assertTrue(url.startswith("wss://"))
self.assertTrue(url.startswith("ws://"))

legacy_client = LegacyWebClient(token="xoxb-api_test", base_url="http://localhost:8888")
response = legacy_client.apps_connections_open(app_token="xapp-A111-222-xyz")
Expand Down
2 changes: 1 addition & 1 deletion tests/slack_sdk_async/socket_mode/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def test_issue_new_wss_url(self):
)
try:
url = await client.issue_new_wss_url()
self.assertTrue(url.startswith("wss://"))
self.assertTrue(url.startswith("ws://"))
finally:
await client.close()

Expand Down
90 changes: 90 additions & 0 deletions tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from tests.helpers import is_ci_unstable_test_skip_enabled
from tests.slack_sdk.socket_mode.mock_socket_mode_server import (
start_socket_mode_server,
start_socket_mode_server_with_disconnection,
socket_mode_envelopes,
socket_mode_hello_message,
socket_mode_disconnect_message,
)
from tests.slack_sdk.socket_mode.mock_web_api_server import (
setup_mock_web_api_server,
Expand Down Expand Up @@ -104,6 +106,94 @@ async def socket_mode_listener(
self.loop.stop()
t.join(timeout=5)

@async_test
async def test_interactions_with_disconnection(self):
if is_ci_unstable_test_skip_enabled():
return
t = Thread(target=start_socket_mode_server_with_disconnection(self, 3001))
t.daemon = True
t.start()

self.disconnected = False
received_messages = []
received_socket_mode_requests = []

async def message_handler(message: WSMessage):
session_id = client.build_session_id(client.current_session)
if "wait_for_disconnect" in message.data:
return
self.logger.info(f"Raw Message: {message}")
await asyncio.sleep(randint(50, 200) / 1000)
self.disconnected = "disconnect" in message.data
received_messages.append(message.data + "_" + session_id)

async def socket_mode_listener(
self: AsyncBaseSocketModeClient,
request: SocketModeRequest,
):
self.logger.info(f"Socket Mode Request: {request.payload}")
await asyncio.sleep(randint(50, 200) / 1000)
received_socket_mode_requests.append(request.payload)

client = SocketModeClient(
app_token="xapp-A111-222-xyz",
web_client=self.web_client,
on_message_listeners=[message_handler],
auto_reconnect_enabled=True,
trace_enabled=True,
)
client.socket_mode_request_listeners.append(socket_mode_listener)

try:
time.sleep(1) # wait for the server
client.wss_uri = "ws://0.0.0.0:3001/link"
await client.connect()
await asyncio.sleep(1) # wait for the message receiver

# Because we want to check the expected messages of new session,
# we need to ensure we send messaged after disconnected.
count = 0
while not self.disconnected and count < 10:
try:
await client.send_message("wait_for_disconnect")
except Exception as e:
self.logger.exception(e)
finally:
await asyncio.sleep(1)
count += 1
await asyncio.sleep(10)
expected_session_id = client.build_session_id(client.current_session)

for _ in range(10):
await client.send_message("foo")
await client.send_message("bar")
await client.send_message("baz")

expected = socket_mode_envelopes + [socket_mode_hello_message] + ["foo", "bar", "baz"] * 10
expected.sort()

count = 0
while count < 10 and (
len([msg for msg in received_messages if expected_session_id in msg]) < len(expected)
or len(received_socket_mode_requests) < len(socket_mode_envelopes)
):
await asyncio.sleep(0.2)
count += 0.2

received_messages.sort()

# Only check messages of current alive session. Ignore the disconnected session.
received_messages = [msg for msg in received_messages if expected_session_id in msg]
expected = [msg + "_" + expected_session_id for msg in expected]

self.assertEqual(received_messages, expected)

self.assertEqual(len(socket_mode_envelopes), len(received_socket_mode_requests))
finally:
await client.close()
self.loop.stop()
t.join(timeout=5)

@async_test
async def test_send_message_while_disconnection(self):
if is_ci_unstable_test_skip_enabled():
Expand Down
2 changes: 1 addition & 1 deletion tests/slack_sdk_async/socket_mode/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def test_issue_new_wss_url(self):
)
try:
url = await client.issue_new_wss_url()
self.assertTrue(url.startswith("wss://"))
self.assertTrue(url.startswith("ws://"))
finally:
await client.close()

Expand Down

0 comments on commit b6b9527

Please sign in to comment.