Skip to content

Commit

Permalink
Add new semaphore-adjacent object that does logging as well
Browse files Browse the repository at this point in the history
  • Loading branch information
4Kaylum committed Jan 17, 2024
1 parent 9661219 commit e72b936
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 62 deletions.
131 changes: 71 additions & 60 deletions novus/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,48 @@
dump = json.dumps


class _LoggingSemapohreContext:

def __init__(
self,
semaphore: asyncio.Semaphore,
shard_id: int,
fmt: str,
sleep_time: float = 5.0) -> None:
self.semapohre: asyncio.Semaphore = semaphore
self.shard_id: int = shard_id
self.fmt: str = fmt
self.sleep_time: float = sleep_time
self.logging_task: asyncio.Task | None = None

async def _loop(self, offset: int = 0) -> None:
try:
await asyncio.sleep(self.sleep_time)
except asyncio.CancelledError:
return
log.debug(self.fmt.format(shard=self.shard_id, time=5 * (offset + 1)))
await self._loop(offset + 1)

async def __aenter__(self) -> None:
await self.semapohre.__aenter__()
self.logging_task = asyncio.create_task(self._loop())

async def __aexit__(self, *args: Any) -> None:
await self.semapohre.__aexit__(*args)
if self.logging_task is not None:
self.logging_task.cancel()


class LoggingSemaphore(asyncio.Semaphore):

def log(
self,
shard_id: int,
fmt: str = "[{shard}] Waiting at semapohre for {time} seconds",
**kwargs: Any) -> _LoggingSemapohreContext:
return _LoggingSemapohreContext(self, shard_id, fmt=fmt, **kwargs)


class GatewayConnection:

def __init__(self, parent: HTTPConnection) -> None:
Expand Down Expand Up @@ -98,8 +140,8 @@ async def connect(

# Make some semaphores so we can control which shards connect
# simultaneously
identify_semaphore = asyncio.Semaphore(max_concurrency)
connect_semaphore = asyncio.Semaphore(10_000)
identify_semaphore = LoggingSemaphore(max_concurrency)
connect_semaphore = LoggingSemaphore(10_000)

# Create shard objects
shard_ids = shard_ids or list(range(shard_count))
Expand Down Expand Up @@ -164,8 +206,8 @@ def __init__(
shard_count: int,
presence: None = None,
intents: Intents = Intents.none(),
connect_semaphore: asyncio.Semaphore,
identify_semaphore: asyncio.Semaphore) -> None:
connect_semaphore: LoggingSemaphore,
identify_semaphore: LoggingSemaphore) -> None:
self.parent = parent
self.ws_url = Route.WS_BASE + "?" + urlencode({
"v": 10,
Expand All @@ -175,8 +217,9 @@ def __init__(
self.dispatch = GatewayDispatch(self)
self.connect_semaphore = connect_semaphore
self.identify_semaphore = identify_semaphore

self.connecting = asyncio.Event()
self.ready = asyncio.Event()
self.ready_received = asyncio.Event()
self.heartbeat_received = asyncio.Event()

# Initial identify data (for entire reconnects)
Expand Down Expand Up @@ -209,7 +252,7 @@ def __repr__(self) -> str:
state: str
if self.socket is None or self.socket.closed:
state = "DISCONNECTED"
elif self.ready.is_set():
elif self.ready_received.is_set():
state = "READY"
elif self.connecting.is_set():
state = "CONNECTING"
Expand Down Expand Up @@ -495,19 +538,12 @@ async def _connect(
# Open socket
if reconnect is False:
self.sequence = None
log.info(f"[{self.shard_id}] Getting session")
session = await self.parent.get_session()
ws_url = ws_url or self.ws_url
log.info("[%s] Creating websocket connection to %s", self.shard_id, ws_url)
try:
ws = await session.ws_connect(ws_url)
except Exception as e:
if attempt >= 5:
log.info(
"[%s] Failed to connect to open ws connection, closing (%s)",
self.shard_id, attempt,
)
return await self.close()
log.info(
"[%s] Failed to connect to open ws connection (%s), reattempting (%s)",
self.shard_id, e, attempt,
Expand All @@ -522,16 +558,10 @@ async def _connect(
# Get hello
log.info("[%s] Waiting for a HELLO", self.shard_id)
try:
got = await asyncio.wait_for(self.receive(), timeout=10.0)
got = await asyncio.wait_for(self.receive(), timeout=60.0)
except Exception as e:
if attempt >= 5:
log.info(
"[%s] Failed to get a HELLO (%s), closing (%s)",
self.shard_id, e, attempt,
)
return await self.close()
log.info(
"[%s] Failed to get a HELLO (%s), reattempting (%s)",
"[%s] Failed to get a HELLO after 60 seconds (%s), reattempting (%s)",
self.shard_id, e, attempt,
)
return await self._connect(
Expand All @@ -551,45 +581,25 @@ async def _connect(
)

# Send identify or resume
self.ready.clear()
timeout = 10
self.ready_received.clear()
self.message_task = asyncio.create_task(self.message_handler())
if reconnect:
fmt = "[{shard}] Waited {time} seconds at semapohre to RESUME"
else:
fmt = "[{shard}] Waited {time} seconds at semapohre to IDENTIFY"
async with self.identify_semaphore.log(self.shard_id, fmt):
if reconnect:
await self.resume()
else:
await self.identify()

# Wait for a ready or resume
log.info("[%s] Waiting for a READY/RESUMED", self.shard_id)
try:
async def connection_timer(offset: int = 0) -> None:
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
return
log.debug(
"[%s] Waited %s seconds at semaphore to %s",
self.shard_id, format(5 * (offset + 1), ","),
"RESUME" if reconnect else "IDENTIFY",
)
await connection_timer(offset + 1)
log_timer = asyncio.create_task(connection_timer())
async with self.identify_semaphore:
if reconnect:
await asyncio.wait_for(self.resume(), timeout)
else:
await asyncio.wait_for(self.identify(), timeout)
log_timer.cancel()
except asyncio.CancelledError:
log.info(
"[%s] Failed to get a response from resume/idenfity after %s seconds, reattempting connect (%s)",
self.shard_id, timeout, attempt,
)
self.message_task.cancel()
return await self._connect(
ws_url=ws_url,
reconnect=reconnect,
attempt=attempt + 1,
)
log.info("[%s] Waiting for a READY", self.shard_id)
try:
await asyncio.wait_for(self.ready.wait(), timeout=30)
await asyncio.wait_for(self.ready_received.wait(), timeout=60)
except asyncio.TimeoutError:
log.info(
"[%s] Failed to get a READY from the gateway after 30 seconds; reattempting connect (%s)",
"[%s] Failed to get a READY from the gateway after 60 seconds; reattempting connect (%s)",
self.shard_id, attempt,
)
self.message_task.cancel()
Expand All @@ -599,6 +609,7 @@ async def connection_timer(offset: int = 0) -> None:
attempt=attempt + 1,
)

# We are no longer connecting
self.connecting.clear()

async def close(self, code: int = 1_000) -> None:
Expand Down Expand Up @@ -794,7 +805,7 @@ async def message_handler(self) -> None:
async for opcode, event_name, sequence, message in self.messages():
match opcode:

# Ignore heartbeat acks
# Ignore heartbeats
case GatewayOpcode.heartbeat_ack:
self.heartbeat_received.set()

Expand All @@ -804,7 +815,7 @@ async def message_handler(self) -> None:
sequence = cast(int, sequence)
if event_name == "READY" or event_name == "RESUMED":
log.info("[%s] Received %s", self.shard_id, event_name)
self.ready.set()
self.ready_received.set()
self.sequence = sequence
t = asyncio.create_task(
self.handle_dispatch(event_name, message),
Expand All @@ -818,7 +829,7 @@ async def message_handler(self) -> None:

# Since we disconnect and reconnect here,
# we need the ready flag to be set so the event is triggered
self.ready.set()
self.ready_received.set()

t = asyncio.create_task(self.reconnect())
self.running_tasks.add(t)
Expand All @@ -830,7 +841,7 @@ async def message_handler(self) -> None:

# Since we disconnect and reconnect here,
# we need the ready flag to be set so the event is triggered
self.ready.set()
self.ready_received.set()

if message is True:
t = asyncio.create_task(self.reconnect())
Expand Down
4 changes: 2 additions & 2 deletions novus/ext/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def commands(self) -> set[Command]:
@property
def is_ready(self) -> bool:
for i in self.state.gateway.shards.values():
if not i.ready.is_set():
if not i.ready_received.is_set():
return False
return True

Expand Down Expand Up @@ -218,7 +218,7 @@ async def wait_until_ready(self) -> None:
"""

for i in self.state.gateway.shards.values():
await i.ready.wait()
await i.ready_received.wait()

def add_command(self, command: Command) -> None:
"""
Expand Down

0 comments on commit e72b936

Please sign in to comment.