Skip to content

Commit

Permalink
Allow for Endpoint.close to be async
Browse files Browse the repository at this point in the history
  • Loading branch information
egbertbouman committed Feb 12, 2024
1 parent 44c12fc commit 96034c3
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 13 deletions.
8 changes: 5 additions & 3 deletions ipv8/messaging/interfaces/dispatcher/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import typing

from ipv8.util import maybe_coroutine

from ..endpoint import Endpoint, EndpointListener
from ..udp.endpoint import UDPEndpoint, UDPv4Address, UDPv6Address, UDPv6Endpoint

Expand Down Expand Up @@ -167,7 +169,7 @@ def send(self, socket_address: Address, packet: bytes, interface: str | None = N
if ep is not None:
ep.send(socket_address, packet)

async def open(self) -> bool: # noqa: A003
async def open(self) -> bool:
"""
Open all interfaces.
"""
Expand All @@ -176,12 +178,12 @@ async def open(self) -> bool: # noqa: A003
any_success |= await interface.open()
return any_success

def close(self) -> None:
async def close(self) -> None:
"""
Close all interfaces.
"""
for interface in self.interfaces.values():
interface.close()
await maybe_coroutine(interface.close)

def reset_byte_counters(self) -> None:
"""
Expand Down
8 changes: 4 additions & 4 deletions ipv8/messaging/interfaces/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import socket
import struct
import threading
from typing import TYPE_CHECKING, Iterable
from typing import TYPE_CHECKING, Awaitable, Iterable

from .lan_addresses.interfaces import get_lan_addresses

Expand Down Expand Up @@ -115,13 +115,13 @@ def send(self, socket_address: Address, packet: bytes) -> None:
"""

@abc.abstractmethod
async def open(self) -> bool: # noqa: A003
async def open(self) -> bool:
"""
Attempt to open this endpoint and return if this was succesful.
Attempt to open this endpoint and return if this was successful.
"""

@abc.abstractmethod
def close(self) -> None:
def close(self) -> None | Awaitable:
"""
Close this endpoint as quick as possible.
"""
Expand Down
4 changes: 2 additions & 2 deletions ipv8/test/messaging/anonymization/test_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
PEER_FLAG_SPEED_TEST,
)
from ....messaging.interfaces.udp.endpoint import DomainAddress, UDPEndpoint
from ....util import succeed
from ....util import maybe_coroutine, succeed
from ...base import TestBase
from ...mocking.endpoint import MockEndpointListener
from ...mocking.exit_socket import MockTunnelExitSocket
Expand Down Expand Up @@ -256,7 +256,7 @@ async def test_create_circuit_multiple_calls(self) -> None:
await self.introduce_nodes()

# Don't allow the exit node to answer, this keeps peer 0's circuit in EXTENDING state
self.endpoint(1).close()
await maybe_coroutine(self.endpoint(1).close)
self.overlay(0).build_tunnels(1)

# Node 0 should have 1 circuit in the CIRCUIT_STATE_EXTENDING state
Expand Down
4 changes: 2 additions & 2 deletions ipv8/test/messaging/interfaces/dispatcher/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def send(self, socket_address: Address, packet: bytes) -> None:
self.bytes_up += len(packet)
self.sent.append((socket_address, packet))

async def open(self) -> bool: # noqa: A003
async def open(self) -> bool:
"""
Do a fake open.
"""
Expand Down Expand Up @@ -189,7 +189,7 @@ async def test_is_open(self) -> None:
self.assertTrue(endpoint.is_open())

# Close the Dispatcher Endpoint.
endpoint.close()
await endpoint.close()

# The Child Endpoint is closed and the Dispatcher Endpoint propagates the child's status.
self.assertFalse(child_endpoint.is_open())
Expand Down
2 changes: 1 addition & 1 deletion ipv8/test/mocking/ipv8.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def stop(self) -> None:
"""
Stop all registered IPv8 strategies, unload all registered overlays and close the endpoint.
"""
self.endpoint.close()
await maybe_coroutine(self.endpoint.close)
await self.overlay.unload()
if self.dht:
await self.dht.unload()
2 changes: 1 addition & 1 deletion ipv8_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ async def stop(self) -> None:
with self.overlay_lock:
unload_list = [self.unload_overlay(overlay) for overlay in self.overlays[:]]
await gather(*unload_list)
self.endpoint.close()
await maybe_coroutine(self.endpoint.close)


if __name__ == '__main__':
Expand Down

0 comments on commit 96034c3

Please sign in to comment.