diff --git a/ipv8/messaging/interfaces/dispatcher/endpoint.py b/ipv8/messaging/interfaces/dispatcher/endpoint.py index 791112ea6..2960891f4 100644 --- a/ipv8/messaging/interfaces/dispatcher/endpoint.py +++ b/ipv8/messaging/interfaces/dispatcher/endpoint.py @@ -4,6 +4,8 @@ import logging import typing +from ipv8.util import maybe_coroutine + from ..endpoint import Endpoint, EndpointListener from ..udp.endpoint import UDPEndpoint, UDPv4Address, UDPv6Address, UDPv6Endpoint @@ -167,7 +169,7 @@ def send(self, socket_address: Address, packet: bytes, interface: str | None = N if ep is not None: ep.send(socket_address, packet) - async def open(self) -> bool: # noqa: A003 + async def open(self) -> bool: """ Open all interfaces. """ @@ -176,12 +178,12 @@ async def open(self) -> bool: # noqa: A003 any_success |= await interface.open() return any_success - def close(self) -> None: + async def close(self) -> None: """ Close all interfaces. """ for interface in self.interfaces.values(): - interface.close() + await maybe_coroutine(interface.close) def reset_byte_counters(self) -> None: """ diff --git a/ipv8/messaging/interfaces/endpoint.py b/ipv8/messaging/interfaces/endpoint.py index 0dfef63df..ed51113f0 100644 --- a/ipv8/messaging/interfaces/endpoint.py +++ b/ipv8/messaging/interfaces/endpoint.py @@ -6,7 +6,7 @@ import socket import struct import threading -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Awaitable, Iterable from .lan_addresses.interfaces import get_lan_addresses @@ -115,13 +115,13 @@ def send(self, socket_address: Address, packet: bytes) -> None: """ @abc.abstractmethod - async def open(self) -> bool: # noqa: A003 + async def open(self) -> bool: """ - Attempt to open this endpoint and return if this was succesful. + Attempt to open this endpoint and return if this was successful. """ @abc.abstractmethod - def close(self) -> None: + def close(self) -> None | Awaitable: """ Close this endpoint as quick as possible. """ diff --git a/ipv8/test/messaging/anonymization/test_community.py b/ipv8/test/messaging/anonymization/test_community.py index 25138a09e..bb307d618 100644 --- a/ipv8/test/messaging/anonymization/test_community.py +++ b/ipv8/test/messaging/anonymization/test_community.py @@ -15,7 +15,7 @@ PEER_FLAG_SPEED_TEST, ) from ....messaging.interfaces.udp.endpoint import DomainAddress, UDPEndpoint -from ....util import succeed +from ....util import maybe_coroutine, succeed from ...base import TestBase from ...mocking.endpoint import MockEndpointListener from ...mocking.exit_socket import MockTunnelExitSocket @@ -256,7 +256,7 @@ async def test_create_circuit_multiple_calls(self) -> None: await self.introduce_nodes() # Don't allow the exit node to answer, this keeps peer 0's circuit in EXTENDING state - self.endpoint(1).close() + await maybe_coroutine(self.endpoint(1).close) self.overlay(0).build_tunnels(1) # Node 0 should have 1 circuit in the CIRCUIT_STATE_EXTENDING state diff --git a/ipv8/test/messaging/interfaces/dispatcher/test_endpoint.py b/ipv8/test/messaging/interfaces/dispatcher/test_endpoint.py index abd4adf59..276b11f0d 100644 --- a/ipv8/test/messaging/interfaces/dispatcher/test_endpoint.py +++ b/ipv8/test/messaging/interfaces/dispatcher/test_endpoint.py @@ -77,7 +77,7 @@ def send(self, socket_address: Address, packet: bytes) -> None: self.bytes_up += len(packet) self.sent.append((socket_address, packet)) - async def open(self) -> bool: # noqa: A003 + async def open(self) -> bool: """ Do a fake open. """ @@ -189,7 +189,7 @@ async def test_is_open(self) -> None: self.assertTrue(endpoint.is_open()) # Close the Dispatcher Endpoint. - endpoint.close() + await endpoint.close() # The Child Endpoint is closed and the Dispatcher Endpoint propagates the child's status. self.assertFalse(child_endpoint.is_open()) diff --git a/ipv8/test/mocking/ipv8.py b/ipv8/test/mocking/ipv8.py index e80998d3a..5a92d8fde 100644 --- a/ipv8/test/mocking/ipv8.py +++ b/ipv8/test/mocking/ipv8.py @@ -105,7 +105,7 @@ async def stop(self) -> None: """ Stop all registered IPv8 strategies, unload all registered overlays and close the endpoint. """ - self.endpoint.close() + await maybe_coroutine(self.endpoint.close) await self.overlay.unload() if self.dht: await self.dht.unload() diff --git a/ipv8_service.py b/ipv8_service.py index b0cfaeadb..bdfdf521b 100644 --- a/ipv8_service.py +++ b/ipv8_service.py @@ -256,7 +256,7 @@ async def stop(self) -> None: with self.overlay_lock: unload_list = [self.unload_overlay(overlay) for overlay in self.overlays[:]] await gather(*unload_list) - self.endpoint.close() + await maybe_coroutine(self.endpoint.close) if __name__ == '__main__':