Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #1446: monitor dies when exceptions raised before monitor created. #1447

Merged
merged 2 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 61 additions & 41 deletions slack_sdk/socket_mode/aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,48 +342,68 @@
return self.build_session_id(self.current_session)

async def connect(self):
old_session: Optional[ClientWebSocketResponse] = None if self.current_session is None else self.current_session
if self.wss_uri is None:
# If the underlying WSS URL does not exist,
# acquiring a new active WSS URL from the server-side first
self.wss_uri = await self.issue_new_wss_url()

self.current_session = await self.aiohttp_client_session.ws_connect(
self.wss_uri,
autoping=False,
heartbeat=self.ping_interval,
proxy=self.proxy,
ssl=self.web_client.ssl,
)
session_id: str = await self.session_id()
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.stale = False
self.logger.info(f"A new session ({session_id}) has been established")

# The first ping from the new connection
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
t = time.time()
await self.current_session.ping(f"sdk-ping-pong:{t}")

if self.current_session_monitor is not None:
self.current_session_monitor.cancel()

self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")

if self.message_receiver is not None:
self.message_receiver.cancel()

self.message_receiver = asyncio.ensure_future(self.receive_messages())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
# This loop ensures that whenever a new session is created,
# a new monitor and a new message receiver are created as well.
# If the new session is created but creating the new monitor or
# the new message receiver fails, we retry the whole connection.
while True:
try:
old_session: Optional[ClientWebSocketResponse] = (
None if self.current_session is None else self.current_session
)

if old_session is not None:
await old_session.close()
old_session_id = self.build_session_id(old_session)
self.logger.info(f"The old session ({old_session_id}) has been abandoned")
# If the old session is broken (e.g. reset by peer), closing it might fail.
# We don't want to retry the connection when that happens.
try:
# We should close the old session before creating a new one. When the disconnect
# reason is `too_many_websockets`, the old one must be closed first
# to decrease the number of connections.
self.auto_reconnect_enabled = False
if old_session is not None:
await old_session.close()
old_session_id = self.build_session_id(old_session)
self.logger.info(f"The old session ({old_session_id}) has been abandoned")
except Exception as e:
self.logger.exception(f"Failed to close the old session : {e}")

Check warning on line 367 in slack_sdk/socket_mode/aiohttp/__init__.py

View check run for this annotation

Codecov / codecov/patch

slack_sdk/socket_mode/aiohttp/__init__.py#L366-L367

Added lines #L366 - L367 were not covered by tests

if self.wss_uri is None:
# If the underlying WSS URL does not exist,
# acquiring a new active WSS URL from the server-side first
self.wss_uri = await self.issue_new_wss_url()

Check warning on line 372 in slack_sdk/socket_mode/aiohttp/__init__.py

View check run for this annotation

Codecov / codecov/patch

slack_sdk/socket_mode/aiohttp/__init__.py#L372

Added line #L372 was not covered by tests

self.current_session = await self.aiohttp_client_session.ws_connect(
self.wss_uri,
autoping=False,
heartbeat=self.ping_interval,
proxy=self.proxy,
ssl=self.web_client.ssl,
)
session_id: str = await self.session_id()
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.stale = False
self.logger.info(f"A new session ({session_id}) has been established")

# The first ping from the new connection
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"Sending a ping message with the newly established connection ({session_id})...")
t = time.time()
await self.current_session.ping(f"sdk-ping-pong:{t}")

if self.current_session_monitor is not None:
self.current_session_monitor.cancel()
self.current_session_monitor = asyncio.ensure_future(self.monitor_current_session())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new monitor_current_session() executor has been recreated for {session_id}")

if self.message_receiver is not None:
self.message_receiver.cancel()
self.message_receiver = asyncio.ensure_future(self.receive_messages())
if self.logger.level <= logging.DEBUG:
self.logger.debug(f"A new receive_messages() executor has been recreated for {session_id}")
break
except Exception as e:
self.logger.exception(f"Failed to connect : {e}. Retrying...")
seratch marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(self.ping_interval)

Check warning on line 406 in slack_sdk/socket_mode/aiohttp/__init__.py

View check run for this annotation

Codecov / codecov/patch

slack_sdk/socket_mode/aiohttp/__init__.py#L404-L406

Added lines #L404 - L406 were not covered by tests

async def disconnect(self):
if self.current_session is not None:
Expand Down
77 changes: 77 additions & 0 deletions tests/slack_sdk/socket_mode/mock_socket_mode_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import os
import time

from aiohttp import WSMsgType, web

Expand All @@ -24,6 +25,8 @@

socket_mode_hello_message = """{"type":"hello","num_connections":2,"debug_info":{"host":"applink-111-xxx","build_number":10,"approximate_connection_time":18060},"connection_info":{"app_id":"A111"}}"""

socket_mode_disconnect_message = """{"type":"disconnect","reason":"too_many_websockets","num_connections":2,"debug_info":{"host":"applink-111-xxx"},"connection_info":{"app_id":"A111"}}"""


def start_socket_mode_server(self, port: int):
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,3 +85,77 @@ def run_server():
loop.close()

return run_server


def start_socket_mode_server_with_disconnection(self, port: int):
    """Build a runner for a mock Socket Mode server that disconnects once.

    The returned callable starts an aiohttp application exposing ``/link`` on
    ``127.0.0.1:<port>``. On the first text message it receives after a state
    reset, the server sends the hello message followed by a single
    ``too_many_websockets`` disconnect message. Connections that did not
    receive the disconnect message are then served the canned envelopes, one
    per incoming text message. Every text message is echoed back.

    Args:
        self: The test case instance. ``reset_server_state`` is attached to it
            immediately, and ``loop`` is attached once the runner executes.
        port: TCP port to bind the mock server to.

    Returns:
        A zero-argument function, intended as a ``Thread`` target, that runs
        the server until ``self.loop`` is stopped from another thread.
    """
    logger = logging.getLogger(__name__)
    # Shared across all connections so that the disconnect is sent only once.
    state = {}

    def reset_server_state():
        state.update(
            hello_sent=False,
            disconnect_sent=False,
            envelopes_to_consume=list(socket_mode_envelopes),
        )

    self.reset_server_state = reset_server_state

    async def link(request):
        # True only for the connection that was told to disconnect; that
        # connection must not consume any of the canned envelopes.
        disconnected = False
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            # To ensure the disconnect message is received and handled,
            # keep this ws alive so the client's ping-pong check passes.
            # NOTE(review): web.WebSocketResponse is created with its default
            # autoping setting here; if aiohttp answers PING frames itself,
            # this branch may never run -- confirm against the aiohttp docs.
            if msg.type == WSMsgType.PING:
                t = time.time()
                await ws.pong(f"sdk-ping-pong:{t}")
                continue
            if msg.type != WSMsgType.TEXT:
                continue

            message = msg.data
            logger.debug(f"Server received a message: {message}")

            if not state["hello_sent"]:
                state["hello_sent"] = True
                await ws.send_str(socket_mode_hello_message)

            if not state["disconnect_sent"]:
                # Re-arm the hello flag so that the next connection
                # gets its own hello message.
                state["hello_sent"] = False
                state["disconnect_sent"] = True
                disconnected = True
                await ws.send_str(socket_mode_disconnect_message)
                logger.debug("Disconnect message sent")

            if state["envelopes_to_consume"] and not disconnected:
                e = state["envelopes_to_consume"].pop(0)
                logger.debug(f"Send an envelope: {e}")
                await ws.send_str(e)

            # Echo the client's message back.
            await ws.send_str(message)

        return ws

    app = web.Application()
    app.add_routes([web.get("/link", link)])
    runner = web.AppRunner(app)

    def run_server():
        reset_server_state()

        self.loop = loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, "127.0.0.1", port, reuse_port=True)
        loop.run_until_complete(site.start())

        # run until it's stopped from the main thread
        loop.run_forever()

        loop.run_until_complete(runner.cleanup())
        loop.run_until_complete(asyncio.sleep(1))
        loop.close()

    return run_server
2 changes: 1 addition & 1 deletion tests/slack_sdk/socket_mode/mock_web_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _handle(self):
if self.path == "/apps.connections.open":
body = {
"ok": True,
"url": "wss://test-server/link/?ticket=xxx&app_id=yyy",
"url": "ws://0.0.0.0:3001/link",
}
if self.path == "/api.test" and request_body:
body = {"ok": True, "args": request_body}
Expand Down
2 changes: 1 addition & 1 deletion tests/slack_sdk/socket_mode/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_issue_new_wss_url(self):
web_client=self.web_client,
)
url = client.issue_new_wss_url()
self.assertTrue(url.startswith("wss://"))
self.assertTrue(url.startswith("ws://"))

legacy_client = LegacyWebClient(token="xoxb-api_test", base_url="http://localhost:8888")
response = legacy_client.apps_connections_open(app_token="xapp-A111-222-xyz")
Expand Down
2 changes: 1 addition & 1 deletion tests/slack_sdk_async/socket_mode/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def test_issue_new_wss_url(self):
)
try:
url = await client.issue_new_wss_url()
self.assertTrue(url.startswith("wss://"))
self.assertTrue(url.startswith("ws://"))
finally:
await client.close()

Expand Down
90 changes: 90 additions & 0 deletions tests/slack_sdk_async/socket_mode/test_interactions_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from tests.helpers import is_ci_unstable_test_skip_enabled
from tests.slack_sdk.socket_mode.mock_socket_mode_server import (
start_socket_mode_server,
start_socket_mode_server_with_disconnection,
socket_mode_envelopes,
socket_mode_hello_message,
socket_mode_disconnect_message,
)
from tests.slack_sdk.socket_mode.mock_web_api_server import (
setup_mock_web_api_server,
Expand Down Expand Up @@ -104,6 +106,94 @@ async def socket_mode_listener(
self.loop.stop()
t.join(timeout=5)

@async_test
async def test_interactions_with_disconnection(self):
    """Verify the client recovers after a server-initiated disconnect.

    Starts the mock server that sends one ``disconnect`` message, connects a
    client with auto-reconnect enabled, waits until the disconnect has been
    observed, then checks that the session current at that point receives the
    hello message, every canned envelope and all echoed messages.
    """
    if is_ci_unstable_test_skip_enabled():
        return
    t = Thread(target=start_socket_mode_server_with_disconnection(self, 3001))
    t.daemon = True
    t.start()

    self.disconnected = False
    received_messages = []
    received_socket_mode_requests = []

    async def message_handler(message: WSMessage):
        # Capture the session id up front so that each message is tagged
        # with the session that was current when it arrived.
        session_id = client.build_session_id(client.current_session)
        if "wait_for_disconnect" in message.data:
            return
        self.logger.info(f"Raw Message: {message}")
        await asyncio.sleep(randint(50, 200) / 1000)
        self.disconnected = "disconnect" in message.data
        received_messages.append(message.data + "_" + session_id)

    async def socket_mode_listener(
        self: AsyncBaseSocketModeClient,
        request: SocketModeRequest,
    ):
        self.logger.info(f"Socket Mode Request: {request.payload}")
        await asyncio.sleep(randint(50, 200) / 1000)
        received_socket_mode_requests.append(request.payload)

    client = SocketModeClient(
        app_token="xapp-A111-222-xyz",
        web_client=self.web_client,
        on_message_listeners=[message_handler],
        auto_reconnect_enabled=True,
        trace_enabled=True,
    )
    client.socket_mode_request_listeners.append(socket_mode_listener)

    try:
        # Wait for the server without blocking the event loop.
        await asyncio.sleep(1)
        client.wss_uri = "ws://0.0.0.0:3001/link"
        await client.connect()
        await asyncio.sleep(1)  # wait for the message receiver

        # Because we want to check the expected messages of the new session,
        # we need to ensure we send messages only after the disconnection.
        count = 0
        while not self.disconnected and count < 10:
            try:
                await client.send_message("wait_for_disconnect")
            except Exception as e:
                self.logger.exception(e)
            finally:
                await asyncio.sleep(1)
                count += 1
        await asyncio.sleep(10)
        expected_session_id = client.build_session_id(client.current_session)

        for _ in range(10):
            await client.send_message("foo")
            await client.send_message("bar")
            await client.send_message("baz")

        # Append the session suffix BEFORE sorting: received messages are
        # sorted with their suffix attached, and sorting the bare payloads
        # first could yield a different order when one payload is a prefix
        # of another.
        expected = sorted(
            msg + "_" + expected_session_id
            for msg in socket_mode_envelopes + [socket_mode_hello_message] + ["foo", "bar", "baz"] * 10
        )

        # Poll for up to ~10 seconds until everything has arrived.
        waited = 0
        while waited < 10 and (
            len([msg for msg in received_messages if expected_session_id in msg]) < len(expected)
            or len(received_socket_mode_requests) < len(socket_mode_envelopes)
        ):
            await asyncio.sleep(0.2)
            waited += 0.2

        # Only check messages of the current alive session. Ignore the disconnected session.
        current_session_messages = sorted(msg for msg in received_messages if expected_session_id in msg)

        self.assertEqual(current_session_messages, expected)

        self.assertEqual(len(socket_mode_envelopes), len(received_socket_mode_requests))
    finally:
        await client.close()
        self.loop.stop()
        t.join(timeout=5)

@async_test
async def test_send_message_while_disconnection(self):
if is_ci_unstable_test_skip_enabled():
Expand Down
2 changes: 1 addition & 1 deletion tests/slack_sdk_async/socket_mode/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def test_issue_new_wss_url(self):
)
try:
url = await client.issue_new_wss_url()
self.assertTrue(url.startswith("wss://"))
self.assertTrue(url.startswith("ws://"))
finally:
await client.close()

Expand Down
Loading