Skip to content

Commit

Permalink
Add remote library and top-level readme
Browse files Browse the repository at this point in the history
  • Loading branch information
grievejia committed Oct 6, 2019
1 parent 6b591fe commit 8ac95e3
Show file tree
Hide file tree
Showing 5 changed files with 917 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Coeus Learner

This repository contains the reinforcement learning source code for our paper Relational Verification using Reinforcement Learning published at OOPSLA'19. Source code for the relational verifier part is released as [a separated repository](https://github.com/utopia-group/Coeus).

2 changes: 2 additions & 0 deletions remote/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .client import open_connection
from .server import establish_simple_server
133 changes: 133 additions & 0 deletions remote/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import socket
from contextlib import contextmanager
from .sexpdata import loads, dumps


class RemoteClient:
'''A thin wrapper around socket client APIs'''

def __init__(self, sock=None):
if sock is None:
self.sock = socket.socket(
socket.AF_INET, socket.SOCK_STREAM)
else:
self.sock = sock

def connect(self, host, port):
self.sock.connect((host, port))

def disconnect(self):
self.__raw_send('Stop')
self.sock.close()

def __raw_send(self, msg):
# The server side won't terminate parsing until it sees a '\n'
self.sock.sendall(str.encode(msg + '\n'))

def __raw_recv(self):
msg = ''
while True:
# Server side will use '\n' as message terminator
raw_chunk = self.sock.recv(4096)
if not raw_chunk:
raise IOError('Remote server disconnected')
chunk = raw_chunk.decode('utf-8')
msg += chunk
if msg.endswith('\n'):
break
return msg.strip()

def __communicate(self, msg):
self.__raw_send(msg)
reply_msg = self.__raw_recv()
try:
reply_sexp = loads(reply_msg)
if len(reply_sexp) == 0:
raise IOError(
'Unexpected parsing result of messsage "{}"'.format(reply_msg))
reply_sexp[0] = reply_sexp[0].value()
if reply_sexp[0] == 'Error':
raise IOError(reply_sexp[1])
return reply_sexp
except AssertionError:
raise IOError(
'Sexp parsing error for message "{}"'.format(reply_msg))
except IndexError:
raise IOError(
'Sexp index out of bound for message "{}"'.format(reply_msg))

def __expect_ack(self, resp):
try:
if resp[0] != 'Ack':
raise IOError('Protocol error: {}'.format(dumps(resp)))
return int(resp[1])
except IndexError:
raise IOError('Protocol error: {}'.format(dumps(resp)))

def __expect_ackstring(self, resp):
try:
if resp[0] != 'AckString':
raise IOError('Protocol error: {}'.format(dumps(resp)))
return resp[1].value()
except IndexError:
raise IOError('Protocol error: {}'.format(dumps(resp)))

def __expect_state(self, resp):
try:
ret = dict()
if resp[0] == 'NextState':
ret['features'] = resp[1]
ret['available_actions'] = resp[2]
ret['is_final'] = False
elif resp[0] == 'Reward':
ret['reward'] = float(resp[1])
ret['is_final'] = True
else:
raise IOError('Protocol error: {}'.format(dumps(resp)))
return ret
except IndexError:
raise IOError('Protocol error: {}'.format(dumps(resp)))

def get_num_actions(self):
resp = self.__communicate('ActionCount')
return self.__expect_ack(resp)

def get_num_features(self):
resp = self.__communicate('FeatureCount')
return self.__expect_ack(resp)

def get_num_training(self):
resp = self.__communicate('TrainingBenchCount')
return self.__expect_ack(resp)

def get_num_testing(self):
resp = self.__communicate('TestingBenchCount')
return self.__expect_ack(resp)

def get_action_name(self, idx):
resp = self.__communicate('(ActionName {})'.format(idx))
return self.__expect_ackstring(resp)

def get_bench_name(self, idx):
resp = self.__communicate('(BenchName {})'.format(idx))
return self.__expect_ackstring(resp)

def start_rollout(self, idx):
resp = self.__communicate('(PickBench {})'.format(idx))
return self.__expect_state(resp)

def restart_rollout(self):
resp = self.__communicate('RestartBench')
return self.__expect_state(resp)

def take_action(self, idx):
resp = self.__communicate('(TakeAction {})'.format(idx))
return self.__expect_state(resp)


@contextmanager
def open_connection(*args, **kwargs):
sock = socket.create_connection(*args, **kwargs)
client = RemoteClient(sock)
yield client
client.disconnect()
100 changes: 100 additions & 0 deletions remote/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
import socket
from .sexpdata import loads, dumps


LOG = logging.getLogger(__name__)


def raw_recv(sock):
msg = ''
while True:
# Server side will use '\n' as message terminator
raw_chunk = sock.recv(4096)
if not raw_chunk:
raise IOError('Remote client is disconnected')
chunk = raw_chunk.decode('utf-8')
msg += chunk
if msg.endswith('\n'):
break
return msg.strip()


def raw_send(sock, msg):
# The server side won't terminate parsing until it sees a '\n'
sock.sendall(str.encode(msg + '\n'))


def get_response(handler, request_sexp):
if request_sexp[0] == 'Error':
raise IOError(request_sexp[1])
elif request_sexp[0] == 'NextState':
state_dict = dict()
state_dict['features'] = request_sexp[1]
state_dict['available_actions'] = request_sexp[2]

num_actions = len(request_sexp[2])
prio_distr = handler(state_dict)
if not isinstance(prio_distr, list):
raise IOError('Protocol error: '
'server handler must return a list of numbers: {}'
.format(prio_distr))
if len(prio_distr) != num_actions:
raise IOError(
'Protocol error: distribution contains {} items, but there are {} actions'.format(
num_actions, len(prio_distr)))
resp_msg = '(Probability ({}))'.format(
' '.join([str(x) for x in prio_distr]))
return resp_msg
else:
raise IOError('Protocol error: {}'.format(dumps(resp_msg)))


def communicate(sock, handler):
request_msg = raw_recv(sock)
request_sexp = loads(request_msg)
if len(request_sexp) == 0:
raise IOError(
'Unexpected parsing result of messsage "{}"'.format(request_msg))
request_sexp[0] = request_sexp[0].value()
if request_sexp[0] == 'Error':
raise IOError(request_sexp[1])
elif request_sexp[0] == 'Stop':
return True
response = get_response(handler, request_sexp)
if response is not None:
raw_send(sock, response)
return False


def establish_simple_server(addr, port, handler):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_address = (addr, port)
sock.bind(server_address)
sock.listen(0)

try:
while True:
# Wait for a connection
connection, client_address = sock.accept()
try:
LOG.info('connection from {}'.format(client_address))
# Receive the data in small chunks and retransmit it
while True:
should_stop = communicate(connection, handler)
if should_stop:
break
except AssertionError as e:
LOG.warning('Sexp parsing error: {}'.format(e))
except IndexError as e:
LOG.warning('Sexp index out of bound error: {}'.format(e))
except IOError as e:
LOG.warning('I/O error: {}'.format(e))
finally:
# Clean up the connection
LOG.info('connection closed')
connection.close()
except KeyboardInterrupt:
LOG.warning('Received stop request from user.')
finally:
sock.close()
Loading

0 comments on commit 8ac95e3

Please sign in to comment.