Skip to content

Commit

Permalink
Refactor/#233 upgrade pydantic to v2 (#235)
Browse files Browse the repository at this point in the history
* #233: Upgraded pydantic to version 2
* Prepared release 0.2.0
* Declared Message.root as union discriminated by message_type
  • Loading branch information
ckunki authored Dec 3, 2024
1 parent 2214bc4 commit 09f1196
Show file tree
Hide file tree
Showing 36 changed files with 228 additions and 142 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# advanced-analytics-framework 0.1.1, released ????-??-??
# 0.2.0 - 2024-12-03

Code name:

Expand All @@ -20,3 +20,4 @@ Code name:
* #217: Rename dataflow abstraction files
* #219: Applied PTB checks and fixes
* #221: Fixed mypy warnings
* #233: Upgraded pydantic to version 2
1 change: 1 addition & 0 deletions doc/changes/unreleased.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Unreleased
2 changes: 1 addition & 1 deletion exasol/analytics/udf/communication/broadcast_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _check_sequence_number(self, specific_message_obj: messages.Broadcast):
def _get_and_check_specific_message_obj(
self, message: messages.Message
) -> messages.Broadcast:
specific_message_obj = message.__root__
specific_message_obj = message.root
if not isinstance(specific_message_obj, messages.Broadcast):
raise TypeError(
f"Received the wrong message type. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DiscoverySocket:
def __init__(self, ip_address: IPAddress, port: Port):
self._port = port
self._ip_address = ip_address
self._logger = LOGGER.bind(ip_address=ip_address.dict(), port=port.dict())
self._logger = LOGGER.bind(ip_address=ip_address.model_dump(), port=port.model_dump())
self._logger.info("create")
self._udp_socket = socket.socket(
socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP
Expand Down
2 changes: 1 addition & 1 deletion exasol/analytics/udf/communication/gather_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _check_sequence_number(self, specific_message_obj: Gather):
)

def _get_and_check_specific_message_obj(self, message: messages.Message) -> Gather:
specific_message_obj = message.__root__
specific_message_obj = message.root
if not isinstance(specific_message_obj, Gather):
raise TypeError(
f"Received the wrong message type. "
Expand Down
8 changes: 4 additions & 4 deletions exasol/analytics/udf/communication/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field, RootModel

from exasol.analytics.udf.communication.connection_info import ConnectionInfo
from exasol.analytics.udf.communication.peer import Peer
Expand Down Expand Up @@ -133,8 +133,8 @@ class Broadcast(BaseMessage, frozen=True):
sequence_number: int


class Message(BaseModel, frozen=True):
__root__: Union[
class Message(RootModel, frozen=True):
root: Union[
Ping,
RegisterPeer,
AcknowledgeRegisterPeer,
Expand All @@ -156,4 +156,4 @@ class Message(BaseModel, frozen=True):
Timeout,
Gather,
Broadcast,
]
] = Field(discriminator='message_type')
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self._out_control_socket = out_control_socket
self._states = _States.INIT
self._logger = LOGGER.bind(
peer=peer.dict(), my_connection_info=my_connection_info.dict()
peer=peer.model_dump(), my_connection_info=my_connection_info.model_dump()
)

def stop(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(
self._peer = peer
self._send_attempt_count = 0
self._logger = LOGGER.bind(
peer=peer.dict(),
my_connection_info=my_connection_info.dict(),
peer=peer.model_dump(),
my_connection_info=my_connection_info.model_dump(),
needs_to_send_for_peer=self._needs_to_send_for_peer,
)
self._logger.debug("init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_my_connection_info(self) -> ConnectionInfo:
try:
received = self._out_control.socket.receive()
generic = deserialize_message(received, messages.Message)
message = generic.__root__
message = generic.root
if not isinstance(message, messages.MyConnectionInfo):
raise UnexpectedMessageError(
f"Unexpected message of type {type(message)}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _handle_control_message(self, frames: List[Frame]) -> Status:
message_obj: messages.Message = deserialize_message(
frames[0].to_bytes(), messages.Message
)
specific_message_obj = message_obj.__root__
specific_message_obj = message_obj.root
if isinstance(specific_message_obj, messages.Stop):
return BackgroundListenerThread.Status.STOPPED
elif isinstance(specific_message_obj, PrepareToStop):
Expand All @@ -251,13 +251,13 @@ def _handle_control_message(self, frames: List[Frame]) -> Status:
else:
self._logger.error(
"RegisterPeer message not allowed",
message_obj=specific_message_obj.dict(),
message_obj=specific_message_obj.model_dump(),
)
elif isinstance(specific_message_obj, messages.Payload):
self.send_payload(payload=specific_message_obj, frames=frames)
else:
self._logger.error(
"Unknown message type", message_obj=specific_message_obj.dict()
"Unknown message type", message_obj=specific_message_obj.model_dump()
)
except Exception as e:
self._logger.exception("Exception during handling message", message=frames)
Expand Down Expand Up @@ -291,8 +291,8 @@ def _add_peer(
):
self._logger.error(
"Peer belongs to a different group",
my_connection_info=self._my_connection_info.dict(),
peer=peer.dict(),
my_connection_info=self._my_connection_info.model_dump(),
peer=peer.model_dump(),
)
raise ValueError("Peer belongs to a different group")
if peer not in self._peer_state:
Expand Down Expand Up @@ -321,7 +321,7 @@ def _handle_listener_message(self, frames: List[Frame]):
message_obj: messages.Message = deserialize_message(
message_content_bytes, messages.Message
)
specific_message_obj = message_obj.__root__
specific_message_obj = message_obj.root
if isinstance(specific_message_obj, messages.SynchronizeConnection):
self._handle_synchronize_connection(specific_message_obj)
elif isinstance(specific_message_obj, messages.AcknowledgeConnection):
Expand All @@ -336,7 +336,7 @@ def _handle_listener_message(self, frames: List[Frame]):
else:
logger.error(
"RegisterPeer message not allowed",
message_obj=specific_message_obj.dict(),
message_obj=specific_message_obj.model_dump(),
)
elif isinstance(specific_message_obj, messages.AcknowledgeRegisterPeer):
self._handle_acknowledge_register_peer_message(specific_message_obj)
Expand All @@ -348,7 +348,7 @@ def _handle_listener_message(self, frames: List[Frame]):
self._handle_acknowledge_payload_message(specific_message_obj)
else:
logger.error(
"Unknown message type", message_obj=specific_message_obj.dict()
"Unknown message type", message_obj=specific_message_obj.model_dump()
)
except Exception as e:
logger.exception(
Expand Down Expand Up @@ -450,7 +450,7 @@ def _handle_acknowledge_register_peer_message(
if self._register_peer_connection.successor != message.source:
self._logger.error(
"AcknowledgeRegisterPeer message not from successor",
message_obj=message.dict(),
message_obj=message.model_dump(),
)
peer = message.peer
self._peer_state[peer].received_acknowledge_register_peer()
Expand All @@ -463,7 +463,7 @@ def _handle_register_peer_complete_message(
if self._register_peer_connection.predecessor != message.source:
self._logger.error(
"RegisterPeerComplete message not from predecessor",
message_obj=message.dict(),
message_obj=message.model_dump(),
)
peer = message.peer
self._peer_state[peer].received_register_peer_complete()
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(
self._sender = sender
self._prepare_to_stop = False
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._logger.debug("__init__")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self._send_attempt_count = 0
self._peer = peer
self._logger = LOGGER.bind(
peer=peer.dict(), my_connection_info=my_connection_info.dict()
peer=peer.model_dump(), my_connection_info=my_connection_info.model_dump()
)
self._logger.debug("init")

Expand All @@ -47,7 +47,7 @@ def _send(self):
else:
self._logger.warning("resend", send_attempt_count=self._send_attempt_count)
message = messages.Message(
__root__=messages.CloseConnection(
root=messages.CloseConnection(
source=self._my_connection_info,
destination=self._peer,
attempt=self._send_attempt_count,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def __init__(
self._peer = peer
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)

def received_close_connection(self):
self._logger.debug("received_synchronize_connection")
self._sender.send(
Message(
__root__=messages.AcknowledgeCloseConnection(
root=messages.AcknowledgeCloseConnection(
source=self._my_connection_info, destination=self._peer
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
self._out_control_socket = out_control_socket
self._states = _States.INIT
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=my_connection_info.model_dump(),
)
self._logger.debug("init")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(
self._peer = peer
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._send_initial_messages()

Expand All @@ -48,7 +48,7 @@ def received_synchronize_connection(self):
self._logger.debug("received_synchronize_connection")
self._sender.send(
Message(
__root__=messages.AcknowledgeConnection(
root=messages.AcknowledgeConnection(
source=self._my_connection_info, destination=self._peer
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
self._out_control_socket = out_control_socket
self._states = _States.INIT
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=my_connection_info.model_dump(),
)
self._logger.debug("init")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self._peer_register_forwarder_is_ready = False
self._sequence_number = 0
self._logger = LOGGER.bind(
peer=peer.dict(), my_connection_info=my_connection_info.dict()
peer=peer.model_dump(), my_connection_info=my_connection_info.model_dump()
)

def _next_sequence_number(self):
Expand Down Expand Up @@ -72,7 +72,7 @@ def send(self, payload: List[Frame]):
destination=self._peer,
sequence_number=self._next_sequence_number(),
)
self._logger.debug("send", message=message.dict())
self._logger.debug("send", message=message.model_dump())
self._background_listener.send_payload(message=message, payload=payload)
return message.sequence_number

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def __init__(
self._out_control_socket = out_control_socket
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._next_received_payload_sequence_number = 0
self._received_payload_dict: Dict[int, List[Frame]] = {}

def received_payload(self, message: messages.Payload, frames: List[Frame]):
self._logger.info("received_payload", message=message.dict())
self._logger.info("received_payload", message=message.model_dump())
self._send_acknowledge_payload_message(message.sequence_number)
if message.sequence_number == self._next_received_payload_sequence_number:
self._forward_new_message_directly(message, frames)
Expand All @@ -43,13 +43,13 @@ def received_payload(self, message: messages.Payload, frames: List[Frame]):
def _add_new_message_to_buffer(
self, message: messages.Payload, frames: List[Frame]
):
self._logger.info("put_to_buffer", message=message.dict())
self._logger.info("put_to_buffer", message=message.model_dump())
self._received_payload_dict[message.sequence_number] = frames

def _forward_new_message_directly(
self, message: messages.Payload, frames: List[Frame]
):
self._logger.info("forward_from_message", message=message.dict())
self._logger.info("forward_from_message", message=message.model_dump())
self._forward_received_payload(frames)

def _forward_messages_from_buffer(self):
Expand All @@ -74,10 +74,10 @@ def _send_acknowledge_payload_message(self, sequence_number: int):
)
self._logger.info(
"_send_acknowledge_payload_message",
message=acknowledge_payload_message.dict(),
message=acknowledge_payload_message.model_dump(),
)
self._sender.send(
message=messages.Message(__root__=acknowledge_payload_message)
message=messages.Message(root=acknowledge_payload_message)
)

def _forward_received_payload(self, frames: List[Frame]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(
self._payload_message_sender_factory = payload_message_sender_factory
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._next_send_payload_sequence_number = 0
self._payload_message_sender_dict: Dict[int, PayloadMessageSender] = (
Expand All @@ -58,16 +58,16 @@ def try_send(self):
payload_sender.try_send()

def received_acknowledge_payload(self, message: messages.AcknowledgePayload):
self._logger.info("received_acknowledge_payload", message=message.dict())
self._logger.info("received_acknowledge_payload", message=message.model_dump())
if message.sequence_number in self._payload_message_sender_dict:
self._payload_message_sender_dict[message.sequence_number].stop()
del self._payload_message_sender_dict[message.sequence_number]
self._out_control_socket.send(
serialize_message(messages.Message(__root__=message))
serialize_message(messages.Message(root=message))
)

def send_payload(self, message: messages.Payload, frames: List[Frame]):
self._logger.info("send_payload", message=message.dict())
self._logger.info("send_payload", message=message.model_dump())
self._payload_message_sender_dict[message.sequence_number] = (
self._payload_message_sender_factory.create(
message=message,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
)
self._my_connection_info = self._background_listener.my_connection_info
self._logger = self._logger.bind(
my_connection_info=self._my_connection_info.dict()
my_connection_info=self._my_connection_info.model_dump()
)
self._logger.info("my_connection_info")
self._peer_states: Dict[Peer, FrontendPeerState] = {}
Expand All @@ -93,7 +93,7 @@ def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0):
for message_obj, frames in self._background_listener.receive_messages(
timeout_in_milliseconds
):
specific_message_obj = message_obj.__root__
specific_message_obj = message_obj.root
if isinstance(specific_message_obj, messages.ConnectionIsReady):
peer = specific_message_obj.peer
self._add_peer_state(peer)
Expand Down Expand Up @@ -122,7 +122,7 @@ def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0):
raise TimeoutError(specific_message_obj.reason)
else:
self._logger.error(
"Unknown message", message_obj=specific_message_obj.dict()
"Unknown message", message_obj=specific_message_obj.model_dump()
)

def _add_peer_state(self, peer: Peer):
Expand Down Expand Up @@ -170,7 +170,7 @@ def peers(self, timeout_in_milliseconds: Optional[int] = None) -> List[Peer]:

def register_peer(self, peer_connection_info: ConnectionInfo):
self._logger.info(
"register_peer", peer_connection_info=peer_connection_info.dict()
"register_peer", peer_connection_info=peer_connection_info.model_dump()
)
self._handle_messages()
if (
Expand Down
Loading

0 comments on commit 09f1196

Please sign in to comment.