Skip to content

Commit

Permalink
chore: fix type errors in IPC code (#1885)
Browse files Browse the repository at this point in the history
  • Loading branch information
mackrorysd authored Jan 28, 2021
1 parent 058ba7d commit fa85d33
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 10 additions & 10 deletions harness/determined/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
) -> None:
self._num_connections = num_connections

context = zmq.Context()
context = zmq.Context() # type: ignore

self._pub_socket = context.socket(zmq.PUB)
self._pull_socket = context.socket(zmq.PULL)
Expand Down Expand Up @@ -179,7 +179,7 @@ def _recv_one(self) -> Tuple[Any, type]:

class ZMQBroadcastClient:
def __init__(self, srv_pub_url: str, srv_pull_url: str) -> None:
context = zmq.Context()
context = zmq.Context() # type: ignore

self._sub_socket = context.socket(zmq.SUB)
# Subscriber always listens to ALL messages.
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
ports: Optional[List[int]] = None,
port_range: Optional[Tuple[int, int]] = None,
) -> None:
self.context = zmq.Context()
self.context = zmq.Context() # type: ignore
self.sockets = [] # type: List[zmq.Socket]
self.ports = [] # type: List[int]

Expand Down Expand Up @@ -297,11 +297,11 @@ def get_ports(self) -> List[int]:

def send(self, py_obj: Any) -> None:
for socket in self.sockets:
socket.send_pyobj(py_obj)
socket.send_pyobj(py_obj) # type: ignore

def receive_blocking(self, send_rank: int) -> Any:
check.lt(send_rank, len(self.sockets))
message = self.sockets[send_rank].recv_pyobj()
message = self.sockets[send_rank].recv_pyobj() # type: ignore
return message

def receive_non_blocking(
Expand All @@ -311,9 +311,9 @@ def receive_non_blocking(
timeout = 1000 if not deadline else int(deadline - time.time()) * 1000
timeout = max(timeout, 100)

if self.sockets[send_rank].poll(timeout) == 0:
if self.sockets[send_rank].poll(timeout) == 0: # type: ignore
return False, None
message = self.sockets[send_rank].recv_pyobj()
message = self.sockets[send_rank].recv_pyobj() # type: ignore
return True, message

def barrier(
Expand Down Expand Up @@ -341,13 +341,13 @@ def barrier(

check.is_instance(barrier_message, _OneSidedBarrier)
messages.append(barrier_message.message)
self.sockets[0].send_pyobj(_OneSidedBarrier(message=message))
self.sockets[0].send_pyobj(_OneSidedBarrier(message=message)) # type: ignore

return messages

def close(self) -> None:
for socket in self.sockets:
socket.close()
socket.close() # type: ignore


class ZMQClient:
Expand All @@ -358,7 +358,7 @@ class ZMQClient:
"""

def __init__(self, ip_address: str, port: int) -> None:
self.context = zmq.Context()
self.context = zmq.Context() # type: ignore
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(f"tcp://{ip_address}:{port}")

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ isort==4.3.21
# pytest 6.0 is based on mypy 0.780
mypy==0.780
bump2version>=1.0.0
# Earlier versions had different type annotations
pyzmq>=22

0 comments on commit fa85d33

Please sign in to comment.