From 3de802156e63da7ae91f04785ff547f2b08f434b Mon Sep 17 00:00:00 2001 From: Benedikt Moneke <67148916+bmoneke@users.noreply.github.com> Date: Thu, 16 Feb 2023 10:50:11 +0100 Subject: [PATCH] Working version of Coordinator and message utils added. Message Layer is still under development. --- pyleco/coordinator.py | 372 ++++++++++++++++++++++++++++++++++++++ pyleco/timers.py | 60 ++++++ pyleco/utils.py | 214 ++++++++++++++++++++++ tests/test_coordinator.py | 177 ++++++++++++++++++ tests/test_utils.py | 91 ++++++++++ 5 files changed, 914 insertions(+) create mode 100644 pyleco/coordinator.py create mode 100644 pyleco/timers.py create mode 100644 pyleco/utils.py create mode 100644 tests/test_coordinator.py create mode 100644 tests/test_utils.py diff --git a/pyleco/coordinator.py b/pyleco/coordinator.py new file mode 100644 index 000000000..1b274e2e4 --- /dev/null +++ b/pyleco/coordinator.py @@ -0,0 +1,372 @@ +# +# This file is part of the PyLECO package. +# +# Copyright (c) 2023-2023 PyLECO Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# + +import logging +from random import random +from socket import gethostname +import sys +from time import perf_counter + +import zmq + +from .utils import (Commands, serialize_data, interpret_header, + create_message, split_name, deserialize_data, + divide_message + ) +from .timers import RepeatingTimer + + +log = logging.getLogger(__name__) +log.addHandler(logging.NullHandler()) + + +class Coordinator: + """A Coordinator program, routing messages among connected peers. + + .. code:: + + c = Coordinator() + c.routing() + + :param str node: Name of the node. Defaults to hostname + :param int port: Port to listen to. + :param timeout: Timeout waiting for messages in ms. + :param cleaning_interval: Interval between two addresses cleaning runs in s. + :param expiration_time: Time, when a stored address expires in s. + :param context: ZMQ context or similar. + """ + + def __init__(self, node=None, host=None, port=12300, timeout=50, cleaning_interval=5, expiration_time=15, + context=zmq.Context.instance(), + **kwargs): + if node is None: + self.node = gethostname().encode() + elif isinstance(node, str): + self.node = node.encode() + elif isinstance(node, bytes): + self.node = node + else: + raise ValueError("`node` must be str or bytes or None.") + self.fname = self.node + b".COORDINATOR" + log.info(f"Start Coordinator of node {self.node} at port {port}.") + self.address = (gethostname() if host is None else host, port) + # Connected Components + self.directory = {b'COORDINATOR': b""} # Component name: identity + self.heartbeats = {} # Component name: timestamp + # Connected Coordinators + self.node_identities = {} # identity: Namespace + self.node_heartbeats = {} # identity: time + self.dealers = {} # Namespace: DEALER socket + self.waiting_dealers = {} # Namespace, socket + self.node_addresses = {self.node: self.address} # Namespace: address + self.global_directory = {} # All Components + self.timeout = timeout + self.cleaner = RepeatingTimer(cleaning_interval, self.clean_addresses, args=(expiration_time,)) + + self.context = context + self.sock = self.context.socket(zmq.ROUTER) + self.cleaner.start() + try: + self.sock.bind(f"tcp://*:{port}") + except Exception: + raise + super().__init__(**kwargs) + + def __del__(self): + self.close() + + def close(self): + self.sock.close(1) + self.cleaner.cancel() + + def send_message(self, receiver, data=None, **kwargs): + """Send a message with any socket, including routing. + + :param identity: Connection identity to send to. + :param receiver: Receiver name + :param sender: Sender name + :param data: Object to send. + :param \\**kwargs: Keyword arguments for the header. + """ + payload = [serialize_data(data)] if data else None + frames = create_message(receiver, self.fname, payload=payload, **kwargs) + self.deliver_message(b"", frames) + + def send_message_raw(self, sender_identity, receiver, data=None, **kwargs): + """Send a message with the ROUTER socket. + + :param identity: Connection identity to send to. + :param receiver: Receiver name + :param sender: Sender name + :param data: Object to send. + :param \\**kwargs: Keyword arguments for the header. + """ + payload = [serialize_data(data)] if data else None + frames = create_message(receiver, self.fname, payload=payload, **kwargs) + self.sock.send_multipart((sender_identity, *frames)) + + def clean_addresses(self, expiration_time): + """Clean all expired addresses. + + :param float expiration_time: Expiration limit in s. + """ + log.debug("Cleaning addresses.") + now = perf_counter() + for name, time in list(self.heartbeats.items()): + if now > time + 2 * expiration_time: + del self.directory[name] + del self.heartbeats[name] + elif now > time + expiration_time: + self.send_message_raw(self.directory[name], receiver=b".".join((self.node, name)), + data=[[Commands.PING]]) + # Clean Coordinators + for identity, time in list(self.node_heartbeats.items()): + if now > time + 2 * expiration_time: + del self.node_heartbeats[identity] + node = self.node_identities.get(identity, None) + if node is not None: + log.debug(f"Node {node} at {identity} is unresponsive, removing.") + try: + self.dealers[node].close(1) + del self.dealers[node] + del self.waiting_dealers[node] + except KeyError: + pass + del self.node_identities[identity] + elif now > time + expiration_time: + node = self.node_identities.get(identity, None) + log.debug(f"Node {node} expired with identity {identity}, pinging.") + if node is None: + del self.node_heartbeats[identity] + continue + self.send_message(receiver=node + b".COORDINATOR", data=[[Commands.PING]]) + + def routing(self, coordinators=[]): + """Route all messages. + + Connect to Coordinators at the beginning. + :param list coordinators: list of coordinator addresses (host, port). + """ + for coordinator in coordinators: + self.add_coordinator(*coordinator) + self.running = True + while self.running: + if self.sock.poll(self.timeout): + self._routing() + for ns, sock in list(self.waiting_dealers.items()): + if sock.poll(0): + self.handle_dealer_message(sock, ns) + # Cleanup + log.info("Coordinator stopped.") + + def _routing(self): + """Do the routing of one message.""" + sender_identity, *msg = self.sock.recv_multipart() + # Handle different communication cases. + self.deliver_message(sender_identity, msg) + + def deliver_message(self, sender_identity, msg): + """Deliver a message to some recipient""" + try: + version, receiver, sender, header, payload = divide_message(msg) + except IndexError: + log.error(f"Less than two frames received! {msg}.") + return + conversation_id, message_id = interpret_header(header) + log.debug(f"From identity {sender_identity}, from {sender}, to {receiver}, {message_id}, {payload}") + r_node, r_name = split_name(receiver, self.node) + s_node, s_name = split_name(sender, self.node) + # Update heartbeat + if sender_identity: + if s_node == self.node: + if sender_identity == self.directory.get(s_name): + self.heartbeats[s_name] = perf_counter() + elif payload == [f'[["{Commands.SIGNIN}"]]'.encode()]: + pass # Signing in, no heartbeat yet + else: + log.error(f"Message {payload} from not signed in Component {sender}.") + self.send_message_raw(sender_identity, sender, conversation_id=message_id, + data=[[Commands.ERROR, "You did not sign in!"]]) + return + else: + # Message from another Coordinator + self.node_heartbeats[sender_identity] = perf_counter() + # Route the message + if r_node != self.node: + # remote connections. + try: + self.dealers[r_node].send_multipart(msg) + except KeyError: + self.send_message(receiver=sender, + data=[[Commands.ERROR, f"Node {r_node} is not known."]]) + elif r_name == b"COORDINATOR" or r_name == b"": + # Coordinator communication + self.handle_commands(sender_identity, sender, s_node, s_name, payload) + elif receiver_addr := self.directory.get(r_name): + # Local Receiver is known + self.sock.send_multipart((receiver_addr, *msg)) + else: + # Receiver is unknown + log.error(f"Receiver '{receiver}' is not in the addresses list.") + self.send_message(receiver=sender, + data=[[Commands.ERROR, f"Receiver '{receiver}' is not in addresses list."]]) + + def handle_commands(self, sender_identity, sender, s_node, s_name, payload): + """Handle commands for the Coordinator itself.""" + if not payload: + return # Empty payload, just heartbeat. + try: + data = deserialize_data(payload[0]) + except ValueError as exc: + log.exception("Payload decoding error.", exc_info=exc) + return # TODO error message + log.debug(f"Coordinator commands: {data}") + reply = [] + try: + for command in data: + if not command: + continue + elif command[0] == Commands.SIGNIN: + if s_name not in self.directory.keys(): + log.info(f"New Component {s_name} at {sender_identity}.") + reply.append([Commands.SIGNIN, self.node.decode()]) + self.directory[s_name] = sender_identity + self.heartbeats[s_name] = perf_counter() + else: + log.info(f"Another Component at {sender_identity} tries to log in as {s_name}.") + self.send_message_raw(sender_identity, receiver=sender, + data=[[Commands.ERROR, Commands.SIGNIN, "The name is already taken."]]) + return + elif command[0] == Commands.OFF: + self.running = False + reply.append([Commands.ACKNOWLEDGE]) + elif command[0] == Commands.CLEAR: + self.clean_addresses(0) + reply.append([Commands.ACKNOWLEDGE]) + elif command[0] == Commands.LIST: + reply.append(self.compose_local_directory()) + elif command[0] == Commands.SIGNOUT and sender_identity == self.directory.get(s_name): + try: + del self.directory[s_name] + del self.heartbeats[s_name] + except KeyError: + pass # no entry + reply.append([Commands.ACKNOWLEDGE]) + elif command[0] == Commands.CO_SIGNIN and s_node not in self.dealers.keys(): + self.node_identities[sender_identity] = s_node + self.send_message_raw(sender_identity, receiver=sender, + data=[[Commands.ACKNOWLEDGE]]) + return + elif command[0] == Commands.SET: + for key, value in command[1].items(): + if key == "directory": + self.global_directory[s_node] = value + elif key == "nodes": + for node, address in value.items(): + node = node.encode() + if node in self.dealers.keys() or node == self.node: + continue + self.add_coordinator(*address, node=node) + except Exception as exc: + log.exception("Handling commands failed.", exc_info=exc) + log.debug(f"Reply {reply} to {sender} at node {s_node}.") + if s_node == self.node: + self.send_message_raw(sender_identity, receiver=sender, data=reply) + else: + self.send_message(receiver=sender, data=reply) + + def add_coordinator(self, host, port=12300, node=None): + """Add another Coordinator to the connections. + + :param str host: Host name of address to connect to. + :param int port: Port number to connect to. + :param node: Namespace of the Node to add or 'None' for a temporary name. + """ + if node is None: + node = str(random()).encode() + log.debug(f"Add DEALER for Coordinator {node} at {host}:{port}.") + self.dealers[node] = d = self.context.socket(zmq.DEALER) + d.connect(f"tcp://{host}:{port}") + d.send_multipart(create_message(receiver=b"COORDINATOR", sender=self.fname, + payload=serialize_data([[Commands.CO_SIGNIN, + {'host': self.address[0], + 'port': self.address[1]}]]))) + self.node_addresses[node] = host, port + self.waiting_dealers[node] = d + + def handle_dealer_message(self, sock, ns): + """Handle a message at a DEALER socket. + + :param sock: DEALER socket. + :param ns: Temporary name of that socket. + """ + msg = sock.recv_multipart() + try: + version, receiver, sender, header, payload = divide_message(msg) + except IndexError: + log.error(f"Less than two frames received! {msg}.") + return + if deserialize_data(payload[0]) == [[Commands.ACKNOWLEDGE]]: + s_node, s_name = split_name(sender) + addr = self.node_addresses[ns] + del self.dealers[ns] + del self.waiting_dealers[ns] + del self.node_addresses[ns] + self.dealers[s_node] = sock + # Rename address + self.node_addresses[s_node] = addr + log.info(f"Renaming DEALER socket from temporary {ns} to {s_node}.") + self.send_message(receiver=sender, data=[self.compose_local_directory()]) + else: + log.warning(f"Unknown message {payload} from {sender} at DEALER socket {ns}.") + + def compose_local_directory(self): + """Send the local directory to the receiver.""" + return [Commands.SET, + {'directory': [key.decode() for key in self.directory.keys()], + 'nodes': {key.decode(): value for key, value in self.node_addresses.items()}}] + + +if __name__ == "__main__": + if "-v" in sys.argv: # Verbose log. + log.setLevel(logging.DEBUG) + if len(log.handlers) == 1: + log.addHandler(logging.StreamHandler()) + kwargs = {} + if "-h" in sys.argv: + try: + kwargs["host"] = sys.argv[sys.argv.index("-h") + 1] + except IndexError: + pass + coordinators = [] + if "-c" in sys.argv: # Coordinator hostname to connect to. + try: + coordinators.append([sys.argv[sys.argv.index("-c") + 1]]) + except IndexError: + pass + try: + r = Coordinator(**kwargs) + r.routing(coordinators) + except KeyboardInterrupt: + print("Stopped due to keyboard interrupt.") diff --git a/pyleco/timers.py b/pyleco/timers.py new file mode 100644 index 000000000..d2d05fe29 --- /dev/null +++ b/pyleco/timers.py @@ -0,0 +1,60 @@ +# +# This file is part of the PyLECO package. +# +# Copyright (c) 2023-2023 PyLECO Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# + + +from threading import Event, Timer + + +class RepeatingTimer(Timer): + """A timer timing out several times instead of just once. + + Note that the next time is called after the function has finished! + + :param float interval: Interval between readouts in s. + """ + + def __init__(self, interval, function, args=None, kwargs=None): + super().__init__(interval, function, args, kwargs) + self.daemon = True + + def run(self): + while not self.finished.wait(self.interval): + self.function(*self.args, **self.kwargs) + + +class SignallingTimer(RepeatingTimer): + """Repeating timer that sets an Event (:attr:`signal`) at timeout and continues counting. + + :param float interval: Interval in s. + """ + + def __init__(self, interval): + self.signal = Event() + super().__init__(interval, self._timeout, args=(self.signal,)) + + @staticmethod + def _timeout(signal): + """Set and clear the signal event.""" + signal.set() + signal.clear() diff --git a/pyleco/utils.py b/pyleco/utils.py new file mode 100644 index 000000000..2fac98cb4 --- /dev/null +++ b/pyleco/utils.py @@ -0,0 +1,214 @@ +# +# This file is part of the PyLECO package. +# +# Copyright (c) 2023-2023 PyLECO Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# + +try: + from enum import StrEnum +except ImportError: + # for versions <3.11 + from enum import Enum + + class StrEnum(str, Enum): + pass # just inherit +import json + + +# Current protocol version +VERSION = 0 +VERSION_B = VERSION.to_bytes(1, "big") + + +def create_header_frame(conversation_id=b"", message_id=b""): + """Create the header frame. + + :param bytes conversation_id: Message ID of the receiver, for example the ID of its request. + :param bytes message_id: Message ID of this message. + :return: header frame. + """ + return b";".join((conversation_id, message_id)) + + +def create_message(receiver, sender=b"", payload=None, **kwargs): + """Create a message. + + :param bytes receiver: To whom the message is going to be sent. + :param bytes sender: Name of the sender of the message. + :param list of bytes payload: Payload frames. + :param \\**kwargs: Keyword arguments for the header creation. + :return: list of byte messages, ready to send as frames. + """ + if payload: + if isinstance(payload, bytes): + payload = [payload] + return [VERSION_B, receiver, sender, create_header_frame(**kwargs)] + payload + else: + return [VERSION_B, receiver, sender, create_header_frame(**kwargs)] + + +def divide_message(msg): + """Return version, receiver, sender, header frame, and payload frames of a message""" + return msg[0], msg[1], msg[2], msg[3], msg[4:] + + +def split_name(name, node=b""): + """Split a sender/receiver name with given default node.""" + s = name.split(b".") + n = s.pop(-1) + return (s.pop() if s else node), n + + +def split_name_str(name, node=""): + """Split a sender/receiver name with given default node.""" + s = name.split(".") + n = s.pop(-1) + return (s.pop() if s else node), n + + +def interpret_header(header): + """Interpret the header frame.""" + try: + conversation_id, message_id = header.split(b";") + except (IndexError, ValueError): + conversation_id = b"" + message_id = b"" + return conversation_id, message_id + + +# Control content protocol +class Commands(StrEnum): + """Valid commands for the control protocol""" + ERROR = "E" # An error occurred. + GET = "G" + SET = "S" + ACKNOWLEDGE = "A" # Message received. + CALL = "C" + OFF = "O" # Turn off program + CLEAR = "X" + SIGNIN = "SI" + SIGNOUT = "D" + LOG = "L" # configure log level + LIST = "?" # List options + SAVE = "V" + CO_SIGNIN = "COS" # Sign in as a Coordinator + PING = "P" # Check, whether the other side is alive. + + +def serialize_data(data): + """Turn data into a bytes object.""" + return json.dumps(data).encode() + + +def deserialize_data(content): + """Turn received message content into python objects.""" + return json.loads(content.decode()) + + +# Convenience methods +def compose_message(receiver, sender="", conversation_id="", message_id="", + data=None, + ): + """Compose a message. + + :param str/bytes receiver: To whom the message is going to be sent. + :param str/bytes sender: Name of the sender of the message. + :param str/bytes conversation_id: Conversation ID of the receiver, for example the ID of its request. + :param str/bytes message_id: Message ID of this message. + :param data: Python object to send or bytes object. + :return: list of byte messages, sent as frames. + """ + if isinstance(receiver, str): + receiver = receiver.encode() + if isinstance(sender, str): + sender = sender.encode() + if isinstance(conversation_id, str): + conversation_id = conversation_id.encode() + if isinstance(message_id, str): + message_id = message_id.encode() + + if data is not None and not isinstance(data, bytes): + data = serialize_data(data) + return create_message(receiver, sender, payload=data, conversation_id=conversation_id, message_id=message_id) + + +def split_message(msg): + """Split the recieved message and return strings and the data object. + + :return: receiver, sender, conversation_id, message_id, data + """ + # Store necessary data like address and maybe conversation ID + version, receiver, sender, header, payload = divide_message(msg) + assert (v := int.from_bytes(version, "big")) <= VERSION, f"Version {v} is above current version {VERSION}." + conversation_id, message_id = interpret_header(header) + data = deserialize_data(payload[0]) if payload else None + return receiver.decode(), sender.decode(), conversation_id.decode(), message_id.decode(), data + + +# For tests +class FakeContext: + """A fake context instance, similar to the result of `zmq.Context.instance().""" + + def socket(self, socket_type): + return FakeSocket(socket_type) + + +class FakeSocket: + """A fake socket useful for unit tests. + + :attr list _s: contains a list of messages sent via this socket. + :attr list _r: List of messages which can be read. + """ + + def __init__(self, socket_type, *args): + self.socket_type = socket_type + self.addr = None + self._s = [] + self._r = [] + + def bind(self, addr, *args): + self.addr = addr + + def bind_to_random_port(self, addr, *args, **kwargs): + self.addr = addr + return 5 + + def unbind(self, linger=0): + self.addr = None + + def connect(self, addr, *args): + self.addr = addr + + def disconnect(self, linger=0): + self.addr = None + + def poll(self, timeout=0): + return len(self._r) + + def recv_multipart(self): + return self._r.pop() + + def send_multipart(self, parts): + print(parts) + self._s.append(list(parts)) + + def close(self, *args): + self.addr = None diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py new file mode 100644 index 000000000..600916115 --- /dev/null +++ b/tests/test_coordinator.py @@ -0,0 +1,177 @@ +# +# This file is part of the PyLECO package. +# +# Copyright (c) 2023-2023 PyLECO Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# + +import pytest + +from pyleco.utils import VERSION_B, FakeSocket, FakeContext +from pyleco.coordinator import Coordinator + + +@pytest.fixture +def coordinator(): + coordinator = Coordinator(node="N1", host="N1host", cleaning_interval=1e5, context=FakeContext()) + coordinator.directory = {b"send": b"321", b"rec": b"123"} + coordinator.dealers[b"N2"] = FakeSocket("zmq.DEALER") + coordinator.node_identities[b"n2"] = b"N2" + coordinator.node_addresses[b"N2"] = "N2host", 12300 + return coordinator + + +def fake_perf_counter(): + return 0 + + +@pytest.fixture() +def fake_counting(monkeypatch): + monkeypatch.setattr("coordinator.perf_counter", fake_perf_counter) + + +# TODO cleaning anpassen an neue Begebenheiten +# @pytest.fixture() +# def cleaning(fake_counting, coordinator): +# coordinator.heartbeats = {-2: -2, -1: -1.1, -0.5: -0.5, -0.1: -0.1} +# coordinator.addresses = {-2: 1, -1: 1, -0.5: 1, -0.1: 1} +# coordinator.clean_addresses(expiration_time=1) +# return coordinator + + +# def test_clean_addresses(cleaning): +# assert cleaning.addresses == {-0.5: 1, -0.1: 1} + + +# def test_clean_heartbeats(cleaning): +# assert cleaning.heartbeats == {-0.5: -0.5, -0.1: -0.1} + + +def test_heartbeat_local(coordinator, fake_counting): + coordinator.sock._r = [[b"321", VERSION_B, b"COORDINATOR", b"send", b";", b""]] + coordinator._routing() + assert coordinator.heartbeats[b"send"] == 0 + + +@pytest.mark.parametrize("i, o", ( + ([b"321", VERSION_B, b"COORDINATOR", b"send", b";", b""], None), # test heartbeat alone + ([b"321", VERSION_B, b"rec", b"send", b";", b"1"], [b"123", VERSION_B, b"rec", b"send", b";", b"1"]), # receiver known, sender given. + ([b"321", VERSION_B, b"x", b"send", b";", b""], [b"321", VERSION_B, b"send", b"N1.COORDINATOR", b";", b'[["E", "Receiver \'b\'x\'\' is not in addresses list."]]']), # receiver unknown, return to sender + ([b"321", VERSION_B, b"N3.CB", b"N1.send", b";"], [b"321", VERSION_B, b"N1.send", b"N1.COORDINATOR", b";", b'[["E", "Node b\'N3\' is not known."]]']), +)) +def test_routing(coordinator, i, o): + """Test whether some incoming message `i` is sent as `o`.""" + coordinator.sock._r = [i] + coordinator._routing() + if o is None: + assert coordinator.sock._s == [] + else: + assert coordinator.sock._s == [o] + + +def test_remote_routing(coordinator): + coordinator.sock._r = [[b"321", VERSION_B, b"N2.CB", b"N1.send", b";"]] + coordinator._routing() + assert coordinator.dealers[b"N2"]._s == [[VERSION_B, b"N2.CB", b"N1.send", b";"]] + + +def test_remote_heartbeat(coordinator, fake_counting): + coordinator.sock._r = [[b"1", VERSION_B, b"N2.CB", b"N3.CA", b";"]] + coordinator._routing() + assert coordinator.node_heartbeats[b"1"] == 0 + + +# Test Coordinator commands handling +# TODO test individual Coordinator commands and their execution. +def test_signin(coordinator): + coordinator.sock._r = [[b'cb', VERSION_B, b"COORDINATOR", b"CB", b";", b'[["SI"]]']] + coordinator._routing() + assert coordinator.sock._s == [[b"cb", VERSION_B, b"CB", b"N1.COORDINATOR", b";", b'[["SI", "N1"]]']] + + +def test_signout_clears_address(coordinator): + coordinator.sock._r = [[b'123', VERSION_B, b"N1.COORDINATOR", b"rec", b";", b'[["D"]]']] + coordinator._routing() + assert b"rec" not in coordinator.directory.keys() + assert coordinator.sock._s == [[b"123", VERSION_B, b"rec", b"N1.COORDINATOR", b";", b'[["A"]]']] + + +def test_co_signin_successful(coordinator): + coordinator.sock._r = [[b'n3', VERSION_B, b"COORDINATOR", b"N3.COORDINATOR", b";", b'[["COS"]]']] + coordinator._routing() + assert b'n3' in coordinator.node_identities.keys() + assert coordinator.sock._s[0] == [b'n3', VERSION_B, b"N3.COORDINATOR", b"N1.COORDINATOR", b";", b'[["A"]]'] + + +def test_set_directory(coordinator): + coordinator.sock._r = [[b"n2", VERSION_B, b"N1.COORDINATOR", + b"N2.COORDINATOR", b";", b'[["S", {"directory": ["send", "rec"], "nodes": {"N1": ["N1host", 12300], "N2": ["wrong_host", -7], "N3": ["N3host", 12300]}}]]']] + coordinator._routing() + assert coordinator.global_directory == {b"N2": ["send", "rec"]} + assert b"N1" not in coordinator.dealers.keys() # not created + assert coordinator.node_addresses[b"N2"] == ("N2host", 12300) # not changed + assert b"N3" in coordinator.dealers.keys() # newly created + + +class Test_add_coordinator: + @pytest.fixture + def coordinator_added(self, coordinator): + coordinator.add_coordinator("host", node=12345) + return coordinator + + def test_socket_created(self, coordinator_added): + assert coordinator_added.dealers[12345].addr == "tcp://host:12300" + + def test_COS_message_sent(self, coordinator_added): + assert coordinator_added.dealers[12345]._s == [ + [VERSION_B, b"COORDINATOR", b"N1.COORDINATOR", b";", b'[["COS", {"host": "N1host", "port": 12300}]]']] + + def test_address_added(self, coordinator_added): + assert coordinator_added.node_addresses[12345] == ("host", 12300) + + def test_waiting_dealer(self, coordinator_added): + assert 12345 in coordinator_added.waiting_dealers.keys() + + +class Test_handle_dealer_message: + @pytest.fixture + def c_message_handled(self, coordinator): + coordinator.add_coordinator("N3host", node=12345) + sock = coordinator.dealers[12345] + sock._s = [] # reset effects of add_coordinator + sock._r = [[VERSION_B, b"N1.COORDINATOR", b"N3.COORDINATOR", b";", b'[["A"]]']] + coordinator.handle_dealer_message(sock, 12345) + return coordinator + + def test_name_changed(self, c_message_handled): + assert b"N3" in c_message_handled.dealers.keys() + assert 12345 not in c_message_handled.dealers.keys() + + def test_socket_not_waiting_anymore(self, c_message_handled): + assert 12345 not in c_message_handled.waiting_dealers.keys() + + def test_address_name_changed(self, c_message_handled): + assert 12345 not in c_message_handled.node_addresses.keys() + assert c_message_handled.node_addresses[b"N3"] == ("N3host", 12300) + + def test_directory_sent(self, c_message_handled): + assert c_message_handled.dealers[b"N3"]._s == [ + [VERSION_B, b"N3.COORDINATOR", b"N1.COORDINATOR", b";", + b'[["S", {"directory": ["send", "rec"], "nodes": {"N1": ["N1host", 12300], "N2": ["N2host", 12300], "N3": ["N3host", 12300]}}]]']] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..33fb95773 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,91 @@ +# +# This file is part of the PyLECO package. +# +# Copyright (c) 2023-2023 PyLECO Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# + +import pytest + +from pyleco import utils +from pyleco.utils import VERSION_B + + +class Test_Publisher: + @pytest.fixture + def pub(self): + return utils.Publisher(host="localhost") + + def test_init(self, pub): + assert pub.host == "localhost" + + def test_setPort(self, pub): + pub.port = 12345 + assert pub._port == 12345 + + +message_tests = ( + ({'receiver': "broker", 'data': [["GET", [1, 2]], ["GET", 3]]}, + [VERSION_B, b"broker", b"", b";", b'[["GET", [1, 2]], ["GET", 3]]']), + ({'receiver': "someone", 'receiver_mid': "123", 'sender': "ego", 'sender_mid': "1"}, + [VERSION_B, b'someone', b'ego', b'123;1']), + ({'receiver': "router", 'sender': "origin"}, + [VERSION_B, b"router", b"origin", b";"]), +) + + +@pytest.mark.parametrize("kwargs, header", ( + ({}, b";"), +)) +def test_create_header_frame(kwargs, header): + assert utils.create_header_frame(**kwargs) == header + + +@pytest.mark.parametrize("kwargs, message", ( + ({'receiver': b"receiver"}, [VERSION_B, b"receiver", b"", b";"]), + ({'receiver': b"receiver", "payload": [b"abc"]}, [VERSION_B, b"receiver", b"", b";", b"abc"]), + ({'receiver': b"receiver", "payload": b"abc"}, [VERSION_B, b"receiver", b"", b";", b"abc"]), + ({'receiver': b"r", 'payload': [b"xyz"], "message_id": b"7"}, [VERSION_B, b"r", b"", b";7", b"xyz"]), +)) +def test_create_message(kwargs, message): + assert utils.create_message(**kwargs) == message + + +@pytest.mark.parametrize("full_name, node, name", ( + (b"local only", b"node", b"local only"), + (b"abc.def", b"abc", b"def"), +)) +def test_split_name(full_name, node, name): + assert utils.split_name(full_name, b"node") == (node, name) + + +@pytest.mark.parametrize("kwargs, message", message_tests) +def test_compose_message(kwargs, message): + assert utils.compose_message(**kwargs) == message + + +@pytest.mark.parametrize("kwargs, message", message_tests) +def test_split_message(kwargs, message): + receiver, sender, receiver_mid, sender_mid, data = utils.split_message(message) + assert receiver == kwargs.get('receiver') + assert receiver_mid == kwargs.get('receiver_mid', "") + assert sender == kwargs.get('sender', "") + assert sender_mid == kwargs.get('sender_mid', "") + assert data == kwargs.get("data")