diff --git a/harness/determined/ipc.py b/harness/determined/ipc.py index d59f9ccd3f1..b9a7eef20f3 100644 --- a/harness/determined/ipc.py +++ b/harness/determined/ipc.py @@ -95,7 +95,7 @@ def __init__( ) -> None: self._num_connections = num_connections - context = zmq.Context() + context = zmq.Context() # type: ignore self._pub_socket = context.socket(zmq.PUB) self._pull_socket = context.socket(zmq.PULL) @@ -179,7 +179,7 @@ def _recv_one(self) -> Tuple[Any, type]: class ZMQBroadcastClient: def __init__(self, srv_pub_url: str, srv_pull_url: str) -> None: - context = zmq.Context() + context = zmq.Context() # type: ignore self._sub_socket = context.socket(zmq.SUB) # Subscriber always listens to ALL messages. @@ -245,7 +245,7 @@ def __init__( ports: Optional[List[int]] = None, port_range: Optional[Tuple[int, int]] = None, ) -> None: - self.context = zmq.Context() + self.context = zmq.Context() # type: ignore self.sockets = [] # type: List[zmq.Socket] self.ports = [] # type: List[int] @@ -297,11 +297,11 @@ def get_ports(self) -> List[int]: def send(self, py_obj: Any) -> None: for socket in self.sockets: - socket.send_pyobj(py_obj) + socket.send_pyobj(py_obj) # type: ignore def receive_blocking(self, send_rank: int) -> Any: check.lt(send_rank, len(self.sockets)) - message = self.sockets[send_rank].recv_pyobj() + message = self.sockets[send_rank].recv_pyobj() # type: ignore return message def receive_non_blocking( @@ -311,9 +311,9 @@ def receive_non_blocking( timeout = 1000 if not deadline else int(deadline - time.time()) * 1000 timeout = max(timeout, 100) - if self.sockets[send_rank].poll(timeout) == 0: + if self.sockets[send_rank].poll(timeout) == 0: # type: ignore return False, None - message = self.sockets[send_rank].recv_pyobj() + message = self.sockets[send_rank].recv_pyobj() # type: ignore return True, message def barrier( @@ -341,13 +341,13 @@ def barrier( check.is_instance(barrier_message, _OneSidedBarrier) messages.append(barrier_message.message) - self.sockets[0].send_pyobj(_OneSidedBarrier(message=message)) + self.sockets[0].send_pyobj(_OneSidedBarrier(message=message)) # type: ignore return messages def close(self) -> None: for socket in self.sockets: - socket.close() + socket.close() # type: ignore class ZMQClient: @@ -358,7 +358,7 @@ class ZMQClient: """ def __init__(self, ip_address: str, port: int) -> None: - self.context = zmq.Context() + self.context = zmq.Context() # type: ignore self.socket = self.context.socket(zmq.REQ) self.socket.connect(f"tcp://{ip_address}:{port}") diff --git a/requirements.txt b/requirements.txt index 1cc1554088d..8c1ab56fa2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,5 @@ isort==4.3.21 # pytest 6.0 is based on mypy 0.780 mypy==0.780 bump2version>=1.0.0 +# Earlier versions had different type annotations +pyzmq>=22