forked from utopia-group/ReCoeus
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add remote library and top-level readme
- Loading branch information
Showing
5 changed files
with
917 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Coeus Learner | ||
|
||
This repository contains the reinforcement learning source code for our paper Relational Verification using Reinforcement Learning published at OOPSLA'19. Source code for the relational verifier part is released as [a separated repository](https://github.com/utopia-group/Coeus). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .client import open_connection | ||
from .server import establish_simple_server |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import socket | ||
from contextlib import contextmanager | ||
from .sexpdata import loads, dumps | ||
|
||
|
||
class RemoteClient: | ||
'''A thin wrapper around socket client APIs''' | ||
|
||
def __init__(self, sock=None): | ||
if sock is None: | ||
self.sock = socket.socket( | ||
socket.AF_INET, socket.SOCK_STREAM) | ||
else: | ||
self.sock = sock | ||
|
||
def connect(self, host, port): | ||
self.sock.connect((host, port)) | ||
|
||
def disconnect(self): | ||
self.__raw_send('Stop') | ||
self.sock.close() | ||
|
||
def __raw_send(self, msg): | ||
# The server side won't terminate parsing until it sees a '\n' | ||
self.sock.sendall(str.encode(msg + '\n')) | ||
|
||
def __raw_recv(self): | ||
msg = '' | ||
while True: | ||
# Server side will use '\n' as message terminator | ||
raw_chunk = self.sock.recv(4096) | ||
if not raw_chunk: | ||
raise IOError('Remote server disconnected') | ||
chunk = raw_chunk.decode('utf-8') | ||
msg += chunk | ||
if msg.endswith('\n'): | ||
break | ||
return msg.strip() | ||
|
||
def __communicate(self, msg): | ||
self.__raw_send(msg) | ||
reply_msg = self.__raw_recv() | ||
try: | ||
reply_sexp = loads(reply_msg) | ||
if len(reply_sexp) == 0: | ||
raise IOError( | ||
'Unexpected parsing result of messsage "{}"'.format(reply_msg)) | ||
reply_sexp[0] = reply_sexp[0].value() | ||
if reply_sexp[0] == 'Error': | ||
raise IOError(reply_sexp[1]) | ||
return reply_sexp | ||
except AssertionError: | ||
raise IOError( | ||
'Sexp parsing error for message "{}"'.format(reply_msg)) | ||
except IndexError: | ||
raise IOError( | ||
'Sexp index out of bound for message "{}"'.format(reply_msg)) | ||
|
||
def __expect_ack(self, resp): | ||
try: | ||
if resp[0] != 'Ack': | ||
raise IOError('Protocol error: {}'.format(dumps(resp))) | ||
return int(resp[1]) | ||
except IndexError: | ||
raise IOError('Protocol error: {}'.format(dumps(resp))) | ||
|
||
def __expect_ackstring(self, resp): | ||
try: | ||
if resp[0] != 'AckString': | ||
raise IOError('Protocol error: {}'.format(dumps(resp))) | ||
return resp[1].value() | ||
except IndexError: | ||
raise IOError('Protocol error: {}'.format(dumps(resp))) | ||
|
||
def __expect_state(self, resp): | ||
try: | ||
ret = dict() | ||
if resp[0] == 'NextState': | ||
ret['features'] = resp[1] | ||
ret['available_actions'] = resp[2] | ||
ret['is_final'] = False | ||
elif resp[0] == 'Reward': | ||
ret['reward'] = float(resp[1]) | ||
ret['is_final'] = True | ||
else: | ||
raise IOError('Protocol error: {}'.format(dumps(resp))) | ||
return ret | ||
except IndexError: | ||
raise IOError('Protocol error: {}'.format(dumps(resp))) | ||
|
||
def get_num_actions(self): | ||
resp = self.__communicate('ActionCount') | ||
return self.__expect_ack(resp) | ||
|
||
def get_num_features(self): | ||
resp = self.__communicate('FeatureCount') | ||
return self.__expect_ack(resp) | ||
|
||
def get_num_training(self): | ||
resp = self.__communicate('TrainingBenchCount') | ||
return self.__expect_ack(resp) | ||
|
||
def get_num_testing(self): | ||
resp = self.__communicate('TestingBenchCount') | ||
return self.__expect_ack(resp) | ||
|
||
def get_action_name(self, idx): | ||
resp = self.__communicate('(ActionName {})'.format(idx)) | ||
return self.__expect_ackstring(resp) | ||
|
||
def get_bench_name(self, idx): | ||
resp = self.__communicate('(BenchName {})'.format(idx)) | ||
return self.__expect_ackstring(resp) | ||
|
||
def start_rollout(self, idx): | ||
resp = self.__communicate('(PickBench {})'.format(idx)) | ||
return self.__expect_state(resp) | ||
|
||
def restart_rollout(self): | ||
resp = self.__communicate('RestartBench') | ||
return self.__expect_state(resp) | ||
|
||
def take_action(self, idx): | ||
resp = self.__communicate('(TakeAction {})'.format(idx)) | ||
return self.__expect_state(resp) | ||
|
||
|
||
@contextmanager | ||
def open_connection(*args, **kwargs): | ||
sock = socket.create_connection(*args, **kwargs) | ||
client = RemoteClient(sock) | ||
yield client | ||
client.disconnect() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import logging | ||
import socket | ||
from .sexpdata import loads, dumps | ||
|
||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
def raw_recv(sock): | ||
msg = '' | ||
while True: | ||
# Server side will use '\n' as message terminator | ||
raw_chunk = sock.recv(4096) | ||
if not raw_chunk: | ||
raise IOError('Remote client is disconnected') | ||
chunk = raw_chunk.decode('utf-8') | ||
msg += chunk | ||
if msg.endswith('\n'): | ||
break | ||
return msg.strip() | ||
|
||
|
||
def raw_send(sock, msg): | ||
# The server side won't terminate parsing until it sees a '\n' | ||
sock.sendall(str.encode(msg + '\n')) | ||
|
||
|
||
def get_response(handler, request_sexp): | ||
if request_sexp[0] == 'Error': | ||
raise IOError(request_sexp[1]) | ||
elif request_sexp[0] == 'NextState': | ||
state_dict = dict() | ||
state_dict['features'] = request_sexp[1] | ||
state_dict['available_actions'] = request_sexp[2] | ||
|
||
num_actions = len(request_sexp[2]) | ||
prio_distr = handler(state_dict) | ||
if not isinstance(prio_distr, list): | ||
raise IOError('Protocol error: ' | ||
'server handler must return a list of numbers: {}' | ||
.format(prio_distr)) | ||
if len(prio_distr) != num_actions: | ||
raise IOError( | ||
'Protocol error: distribution contains {} items, but there are {} actions'.format( | ||
num_actions, len(prio_distr))) | ||
resp_msg = '(Probability ({}))'.format( | ||
' '.join([str(x) for x in prio_distr])) | ||
return resp_msg | ||
else: | ||
raise IOError('Protocol error: {}'.format(dumps(resp_msg))) | ||
|
||
|
||
def communicate(sock, handler): | ||
request_msg = raw_recv(sock) | ||
request_sexp = loads(request_msg) | ||
if len(request_sexp) == 0: | ||
raise IOError( | ||
'Unexpected parsing result of messsage "{}"'.format(request_msg)) | ||
request_sexp[0] = request_sexp[0].value() | ||
if request_sexp[0] == 'Error': | ||
raise IOError(request_sexp[1]) | ||
elif request_sexp[0] == 'Stop': | ||
return True | ||
response = get_response(handler, request_sexp) | ||
if response is not None: | ||
raw_send(sock, response) | ||
return False | ||
|
||
|
||
def establish_simple_server(addr, port, handler): | ||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
server_address = (addr, port) | ||
sock.bind(server_address) | ||
sock.listen(0) | ||
|
||
try: | ||
while True: | ||
# Wait for a connection | ||
connection, client_address = sock.accept() | ||
try: | ||
LOG.info('connection from {}'.format(client_address)) | ||
# Receive the data in small chunks and retransmit it | ||
while True: | ||
should_stop = communicate(connection, handler) | ||
if should_stop: | ||
break | ||
except AssertionError as e: | ||
LOG.warning('Sexp parsing error: {}'.format(e)) | ||
except IndexError as e: | ||
LOG.warning('Sexp index out of bound error: {}'.format(e)) | ||
except IOError as e: | ||
LOG.warning('I/O error: {}'.format(e)) | ||
finally: | ||
# Clean up the connection | ||
LOG.info('connection closed') | ||
connection.close() | ||
except KeyboardInterrupt: | ||
LOG.warning('Received stop request from user.') | ||
finally: | ||
sock.close() |
Oops, something went wrong.