Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/#233 upgrade pydantic to v2 #235

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# advanced-analytics-framework 0.1.1, released ????-??-??
# 0.2.0 - 2024-12-03

Code name:

Expand All @@ -20,3 +20,4 @@ Code name:
* #217: Rename dataflow abstraction files
* #219: Applied PTB checks and fixes
* #221: Fixed mypy warnings
* #233: Upgraded pydantic to version 2
1 change: 1 addition & 0 deletions doc/changes/unreleased.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Unreleased
2 changes: 1 addition & 1 deletion exasol/analytics/udf/communication/broadcast_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _check_sequence_number(self, specific_message_obj: messages.Broadcast):
def _get_and_check_specific_message_obj(
self, message: messages.Message
) -> messages.Broadcast:
specific_message_obj = message.__root__
specific_message_obj = message.root
if not isinstance(specific_message_obj, messages.Broadcast):
raise TypeError(
f"Received the wrong message type. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DiscoverySocket:
def __init__(self, ip_address: IPAddress, port: Port):
self._port = port
self._ip_address = ip_address
self._logger = LOGGER.bind(ip_address=ip_address.dict(), port=port.dict())
self._logger = LOGGER.bind(ip_address=ip_address.model_dump(), port=port.model_dump())
self._logger.info("create")
self._udp_socket = socket.socket(
socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP
Expand Down
2 changes: 1 addition & 1 deletion exasol/analytics/udf/communication/gather_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _check_sequence_number(self, specific_message_obj: Gather):
)

def _get_and_check_specific_message_obj(self, message: messages.Message) -> Gather:
specific_message_obj = message.__root__
specific_message_obj = message.root
if not isinstance(specific_message_obj, Gather):
raise TypeError(
f"Received the wrong message type. "
Expand Down
8 changes: 4 additions & 4 deletions exasol/analytics/udf/communication/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field, RootModel

from exasol.analytics.udf.communication.connection_info import ConnectionInfo
from exasol.analytics.udf.communication.peer import Peer
Expand Down Expand Up @@ -133,8 +133,8 @@ class Broadcast(BaseMessage, frozen=True):
sequence_number: int


class Message(BaseModel, frozen=True):
__root__: Union[
class Message(RootModel, frozen=True):
root: Union[
Copy link
Collaborator

@tkilias tkilias Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also

https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-str-discriminators

Pydantic 2 uses this to determine against which model to validate if their is a common string attribute

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See latest Push:

Declared Message.root as union discriminated by message_type:

root = Field(discriminator='message_type')

Ping,
RegisterPeer,
AcknowledgeRegisterPeer,
Expand All @@ -156,4 +156,4 @@ class Message(BaseModel, frozen=True):
Timeout,
Gather,
Broadcast,
]
] = Field(discriminator='message_type')
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self._out_control_socket = out_control_socket
self._states = _States.INIT
self._logger = LOGGER.bind(
peer=peer.dict(), my_connection_info=my_connection_info.dict()
peer=peer.model_dump(), my_connection_info=my_connection_info.model_dump()
)

def stop(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(
self._peer = peer
self._send_attempt_count = 0
self._logger = LOGGER.bind(
peer=peer.dict(),
my_connection_info=my_connection_info.dict(),
peer=peer.model_dump(),
my_connection_info=my_connection_info.model_dump(),
needs_to_send_for_peer=self._needs_to_send_for_peer,
)
self._logger.debug("init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_my_connection_info(self) -> ConnectionInfo:
try:
received = self._out_control.socket.receive()
generic = deserialize_message(received, messages.Message)
message = generic.__root__
message = generic.root
if not isinstance(message, messages.MyConnectionInfo):
raise UnexpectedMessageError(
f"Unexpected message of type {type(message)}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _handle_control_message(self, frames: List[Frame]) -> Status:
message_obj: messages.Message = deserialize_message(
frames[0].to_bytes(), messages.Message
)
specific_message_obj = message_obj.__root__
specific_message_obj = message_obj.root
if isinstance(specific_message_obj, messages.Stop):
return BackgroundListenerThread.Status.STOPPED
elif isinstance(specific_message_obj, PrepareToStop):
Expand All @@ -251,13 +251,13 @@ def _handle_control_message(self, frames: List[Frame]) -> Status:
else:
self._logger.error(
"RegisterPeer message not allowed",
message_obj=specific_message_obj.dict(),
message_obj=specific_message_obj.model_dump(),
)
elif isinstance(specific_message_obj, messages.Payload):
self.send_payload(payload=specific_message_obj, frames=frames)
else:
self._logger.error(
"Unknown message type", message_obj=specific_message_obj.dict()
"Unknown message type", message_obj=specific_message_obj.model_dump()
)
except Exception as e:
self._logger.exception("Exception during handling message", message=frames)
Expand Down Expand Up @@ -291,8 +291,8 @@ def _add_peer(
):
self._logger.error(
"Peer belongs to a different group",
my_connection_info=self._my_connection_info.dict(),
peer=peer.dict(),
my_connection_info=self._my_connection_info.model_dump(),
peer=peer.model_dump(),
)
raise ValueError("Peer belongs to a different group")
if peer not in self._peer_state:
Expand Down Expand Up @@ -321,7 +321,7 @@ def _handle_listener_message(self, frames: List[Frame]):
message_obj: messages.Message = deserialize_message(
message_content_bytes, messages.Message
)
specific_message_obj = message_obj.__root__
specific_message_obj = message_obj.root
if isinstance(specific_message_obj, messages.SynchronizeConnection):
self._handle_synchronize_connection(specific_message_obj)
elif isinstance(specific_message_obj, messages.AcknowledgeConnection):
Expand All @@ -336,7 +336,7 @@ def _handle_listener_message(self, frames: List[Frame]):
else:
logger.error(
"RegisterPeer message not allowed",
message_obj=specific_message_obj.dict(),
message_obj=specific_message_obj.model_dump(),
)
elif isinstance(specific_message_obj, messages.AcknowledgeRegisterPeer):
self._handle_acknowledge_register_peer_message(specific_message_obj)
Expand All @@ -348,7 +348,7 @@ def _handle_listener_message(self, frames: List[Frame]):
self._handle_acknowledge_payload_message(specific_message_obj)
else:
logger.error(
"Unknown message type", message_obj=specific_message_obj.dict()
"Unknown message type", message_obj=specific_message_obj.model_dump()
)
except Exception as e:
logger.exception(
Expand Down Expand Up @@ -450,7 +450,7 @@ def _handle_acknowledge_register_peer_message(
if self._register_peer_connection.successor != message.source:
self._logger.error(
"AcknowledgeRegisterPeer message not from successor",
message_obj=message.dict(),
message_obj=message.model_dump(),
)
peer = message.peer
self._peer_state[peer].received_acknowledge_register_peer()
Expand All @@ -463,7 +463,7 @@ def _handle_register_peer_complete_message(
if self._register_peer_connection.predecessor != message.source:
self._logger.error(
"RegisterPeerComplete message not from predecessor",
message_obj=message.dict(),
message_obj=message.model_dump(),
)
peer = message.peer
self._peer_state[peer].received_register_peer_complete()
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(
self._sender = sender
self._prepare_to_stop = False
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._logger.debug("__init__")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self._send_attempt_count = 0
self._peer = peer
self._logger = LOGGER.bind(
peer=peer.dict(), my_connection_info=my_connection_info.dict()
peer=peer.model_dump(), my_connection_info=my_connection_info.model_dump()
)
self._logger.debug("init")

Expand All @@ -47,7 +47,7 @@ def _send(self):
else:
self._logger.warning("resend", send_attempt_count=self._send_attempt_count)
message = messages.Message(
__root__=messages.CloseConnection(
root=messages.CloseConnection(
source=self._my_connection_info,
destination=self._peer,
attempt=self._send_attempt_count,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def __init__(
self._peer = peer
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)

def received_close_connection(self):
self._logger.debug("received_synchronize_connection")
self._sender.send(
Message(
__root__=messages.AcknowledgeCloseConnection(
root=messages.AcknowledgeCloseConnection(
source=self._my_connection_info, destination=self._peer
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
self._out_control_socket = out_control_socket
self._states = _States.INIT
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=my_connection_info.model_dump(),
)
self._logger.debug("init")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(
self._peer = peer
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._send_initial_messages()

Expand All @@ -48,7 +48,7 @@ def received_synchronize_connection(self):
self._logger.debug("received_synchronize_connection")
self._sender.send(
Message(
__root__=messages.AcknowledgeConnection(
root=messages.AcknowledgeConnection(
source=self._my_connection_info, destination=self._peer
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
self._out_control_socket = out_control_socket
self._states = _States.INIT
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=my_connection_info.model_dump(),
)
self._logger.debug("init")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self._peer_register_forwarder_is_ready = False
self._sequence_number = 0
self._logger = LOGGER.bind(
peer=peer.dict(), my_connection_info=my_connection_info.dict()
peer=peer.model_dump(), my_connection_info=my_connection_info.model_dump()
)

def _next_sequence_number(self):
Expand Down Expand Up @@ -72,7 +72,7 @@ def send(self, payload: List[Frame]):
destination=self._peer,
sequence_number=self._next_sequence_number(),
)
self._logger.debug("send", message=message.dict())
self._logger.debug("send", message=message.model_dump())
self._background_listener.send_payload(message=message, payload=payload)
return message.sequence_number

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def __init__(
self._out_control_socket = out_control_socket
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._next_received_payload_sequence_number = 0
self._received_payload_dict: Dict[int, List[Frame]] = {}

def received_payload(self, message: messages.Payload, frames: List[Frame]):
self._logger.info("received_payload", message=message.dict())
self._logger.info("received_payload", message=message.model_dump())
self._send_acknowledge_payload_message(message.sequence_number)
if message.sequence_number == self._next_received_payload_sequence_number:
self._forward_new_message_directly(message, frames)
Expand All @@ -43,13 +43,13 @@ def received_payload(self, message: messages.Payload, frames: List[Frame]):
def _add_new_message_to_buffer(
self, message: messages.Payload, frames: List[Frame]
):
self._logger.info("put_to_buffer", message=message.dict())
self._logger.info("put_to_buffer", message=message.model_dump())
self._received_payload_dict[message.sequence_number] = frames

def _forward_new_message_directly(
self, message: messages.Payload, frames: List[Frame]
):
self._logger.info("forward_from_message", message=message.dict())
self._logger.info("forward_from_message", message=message.model_dump())
self._forward_received_payload(frames)

def _forward_messages_from_buffer(self):
Expand All @@ -74,10 +74,10 @@ def _send_acknowledge_payload_message(self, sequence_number: int):
)
self._logger.info(
"_send_acknowledge_payload_message",
message=acknowledge_payload_message.dict(),
message=acknowledge_payload_message.model_dump(),
)
self._sender.send(
message=messages.Message(__root__=acknowledge_payload_message)
message=messages.Message(root=acknowledge_payload_message)
)

def _forward_received_payload(self, frames: List[Frame]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(
self._payload_message_sender_factory = payload_message_sender_factory
self._sender = sender
self._logger = LOGGER.bind(
peer=self._peer.dict(),
my_connection_info=self._my_connection_info.dict(),
peer=self._peer.model_dump(),
my_connection_info=self._my_connection_info.model_dump(),
)
self._next_send_payload_sequence_number = 0
self._payload_message_sender_dict: Dict[int, PayloadMessageSender] = (
Expand All @@ -58,16 +58,16 @@ def try_send(self):
payload_sender.try_send()

def received_acknowledge_payload(self, message: messages.AcknowledgePayload):
self._logger.info("received_acknowledge_payload", message=message.dict())
self._logger.info("received_acknowledge_payload", message=message.model_dump())
if message.sequence_number in self._payload_message_sender_dict:
self._payload_message_sender_dict[message.sequence_number].stop()
del self._payload_message_sender_dict[message.sequence_number]
self._out_control_socket.send(
serialize_message(messages.Message(__root__=message))
serialize_message(messages.Message(root=message))
)

def send_payload(self, message: messages.Payload, frames: List[Frame]):
self._logger.info("send_payload", message=message.dict())
self._logger.info("send_payload", message=message.model_dump())
self._payload_message_sender_dict[message.sequence_number] = (
self._payload_message_sender_factory.create(
message=message,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
)
self._my_connection_info = self._background_listener.my_connection_info
self._logger = self._logger.bind(
my_connection_info=self._my_connection_info.dict()
my_connection_info=self._my_connection_info.model_dump()
)
self._logger.info("my_connection_info")
self._peer_states: Dict[Peer, FrontendPeerState] = {}
Expand All @@ -93,7 +93,7 @@ def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0):
for message_obj, frames in self._background_listener.receive_messages(
timeout_in_milliseconds
):
specific_message_obj = message_obj.__root__
specific_message_obj = message_obj.root
if isinstance(specific_message_obj, messages.ConnectionIsReady):
peer = specific_message_obj.peer
self._add_peer_state(peer)
Expand Down Expand Up @@ -122,7 +122,7 @@ def _handle_messages(self, timeout_in_milliseconds: Optional[int] = 0):
raise TimeoutError(specific_message_obj.reason)
else:
self._logger.error(
"Unknown message", message_obj=specific_message_obj.dict()
"Unknown message", message_obj=specific_message_obj.model_dump()
)

def _add_peer_state(self, peer: Peer):
Expand Down Expand Up @@ -170,7 +170,7 @@ def peers(self, timeout_in_milliseconds: Optional[int] = None) -> List[Peer]:

def register_peer(self, peer_connection_info: ConnectionInfo):
self._logger.info(
"register_peer", peer_connection_info=peer_connection_info.dict()
"register_peer", peer_connection_info=peer_connection_info.model_dump()
)
self._handle_messages()
if (
Expand Down
Loading
Loading