Skip to content

Commit

Permalink
remove unnecessary sockets (#306)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #306

Removing ZMQ socket because it is no longer necesssary (all clients can communicate over pysocket), and the pyzmq dependency is a burden.

Differential Revision: D47733703

fbshipit-source-id: 379a3f686626b415b087d19701f5b8e1ad5971ac
  • Loading branch information
crasanders authored and facebook-github-bot committed Jul 25, 2023
1 parent f0ac680 commit 7caeff8
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 174 deletions.
2 changes: 1 addition & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The canonical way of using AEPsych is to launch it in server mode (you can run `
aepsych_server --port 5555 --ip 0.0.0.0 --db mydatabase.db
```

The server accepts messages over either a unix socket or [ZMQ](https://zeromq.org/), and
The server accepts messages over a unix socket, and
all messages are formatted using [JSON](https://www.json.org/json-en.html). All messages
have the following format:

Expand Down
4 changes: 1 addition & 3 deletions aepsych/server/message_handlers/handle_exit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ def handle_exit(server, request):
termination_type = "Normal termination"
logger.info("Got termination message!")
server.write_strats(termination_type)
if not server.is_using_thrift:
server.exit_server_loop = True
server.exit_server_loop = True

# If using thrift, it will add 'Terminate' to the queue and pass it to thrift server level
return "Terminate"
39 changes: 14 additions & 25 deletions aepsych/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
replay,
)

from aepsych.server.sockets import BAD_REQUEST, createSocket, DummySocket
from aepsych.server.sockets import BAD_REQUEST, DummySocket, PySocket

logger = utils_logging.getLogger(logging.INFO)
DEFAULT_DESC = "default description"
Expand All @@ -48,7 +48,7 @@ def get_next_filename(folder, fname, ext):


class AEPsychServer(object):
def __init__(self, socket=None, database_path=None, thrift=False):
def __init__(self, socket=None, database_path=None):
"""Server for doing black box optimization using gaussian processes.
Keyword Arguments:
socket -- socket object that implements `send` and `receive` for json
Expand Down Expand Up @@ -79,7 +79,6 @@ def __init__(self, socket=None, database_path=None, thrift=False):
self.enable_pregen = False

self.debug = False
self.is_using_thrift = thrift
self.receive_thread = threading.Thread(
target=self._receive_send, args=(self.exit_server_loop,), daemon=True
)
Expand Down Expand Up @@ -153,20 +152,15 @@ def serve(self) -> None:
# yeah we're not sanitizing input at all

# Start the method to accept a client connection

if self.is_using_thrift is True:
self.queue.append(self.socket.receive())
self.socket.accept_client()
self.receive_thread.start()
while True:
self._handle_queue()
else:
self.socket.accept_client()
self.receive_thread.start()
while True:
self._handle_queue()
if self.exit_server_loop:
break
# Close the socket and terminate with code 0
self.cleanup()
sys.exit(0)
if self.exit_server_loop:
break
# Close the socket and terminate with code 0
self.cleanup()
sys.exit(0)

def _unpack_strat_buffer(self, strat_buffer):
if isinstance(strat_buffer, io.BytesIO):
Expand Down Expand Up @@ -394,12 +388,7 @@ def parse_argument():
parser.add_argument(
"--port", metavar="N", type=int, default=5555, help="port to serve on"
)
parser.add_argument(
"--socket_type",
choices=["zmq", "pysocket"],
default="pysocket",
help="method to serve over",
)

parser.add_argument(
"--ip",
metavar="M",
Expand Down Expand Up @@ -449,7 +438,7 @@ def start_server(server_class, args):
if "replay" in args and args.replay is not None:
logger.info(f"Attempting to replay {args.replay}")
if args.resume is True:
sock = createSocket(socket_type=args.socket_type, port=args.port)
sock = PySocket(socket_type=args.socket_type, port=args.port)
logger.info(f"Will resume {args.replay}")
else:
sock = None
Expand All @@ -462,15 +451,15 @@ def start_server(server_class, args):
)
else:
logger.info(f"Setting the database path {database_path}")
sock = createSocket(socket_type=args.socket_type, port=args.port)
sock = PySocket(socket_type=args.socket_type, port=args.port)
startServerAndRun(
server_class,
database_path=database_path,
socket=sock,
config_path=args.stratconfig,
)
else:
sock = createSocket(socket_type=args.socket_type, port=args.port)
sock = PySocket(socket_type=args.socket_type, port=args.port)
startServerAndRun(server_class, socket=sock, config_path=args.stratconfig)

except (KeyboardInterrupt, SystemExit):
Expand Down
81 changes: 1 addition & 80 deletions aepsych/server/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import aepsych.utils_logging as utils_logging
import numpy as np
import zmq

logger = utils_logging.getLogger(logging.INFO)
BAD_REQUEST = "bad request"
Expand All @@ -30,60 +29,11 @@ def SimplifyArrays(message):
}


def createSocket(socket_type="pysocket", port=5555, msg_queue=None):
logger.info(f"socket_type = {socket_type} port = {port}")

if socket_type == "pysocket":
sock = PySocket(port=port)
elif socket_type == "zmq":
sock = ZMQSocket(port=port)
elif socket_type == "thrift":
sock = ThriftSocketWrapper(msg_queue)

return sock


class DummySocket(object):
def close(self):
pass


class ZMQSocket(object):
def __init__(self, port, ip="*"):
"""sends/receives json-formated messages over ZMQ
Arguments:
port {int} -- port to listen over
"""

self.context = zmq.Context()
self.socket = self.context.socket(zmq.REP)
self.socket.bind(f"tcp://{ip}:{port}")

def close(self):
self.socket.close()

def receive(self):
while True:
try:
msg = self.socket.recv_json()
break
except Exception as e:
logger.info(
"Exception caught while trying to receive a message from the client. "
f"Ignoring message and trying again. The caught exception was: {e}."
)
return msg

def send(self, message):
if type(message) == str:
self.socket.send_string(message)
elif type(message) == int:
self.socket.send_string(str(message))
else:
self.socket.send_json(SimplifyArrays(message))


class PySocket(object):
def __init__(self, port, ip=""):

Expand Down Expand Up @@ -138,7 +88,7 @@ def receive(self, server_exiting):
logger.debug(f"receive : result = {recv_result}")
logger.info(f"Got: {msg}")
return msg
except Exception as e:
except Exception:
return BAD_REQUEST

def send(self, message):
Expand All @@ -155,32 +105,3 @@ def send(self, message):
logger.info(f"Sending: {message}")
sys.stdout.flush()
self.conn.sendall(bytes(message, "utf-8"))


class ThriftSocketWrapper(object):
def __init__(self, msg_queue=None):
self.msg_queue = msg_queue

def close(self):
# it's not a real socket so no close function need?
pass

def receive(self):
# Remove and return an item from the queue. If queue is empty, wait until an item is available.
message = self.msg_queue.get()
logger.info(f"thrift socket got msg: {message}")
return message

def send(self, message):
# add responds to msg_queue
if self.msg_queue is None:
logger.exception("There is no msg_queue!")
raise RuntimeError("There is no message to send from server!")
if type(message) == str:
pass # keep it as-is
elif type(message) == int:
message = str(message)
else:
message = json.dumps(SimplifyArrays(message))
logger.info(f"Sending: {message}")
self.msg_queue.put(message, block=True)
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
REQUIRES = [
"matplotlib",
"torch",
"pyzmq==19.0.2",
"scipy",
"SQLAlchemy==1.4.46",
"dill",
Expand Down
64 changes: 0 additions & 64 deletions tests/test_ThriftSocketWrapper.py

This file was deleted.

0 comments on commit 7caeff8

Please sign in to comment.