diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..258c853
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,126 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don’t work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Misc
+models/
diff --git a/event_logger.py b/event_logger.py
new file mode 100644
index 0000000..cc7df3e
--- /dev/null
+++ b/event_logger.py
@@ -0,0 +1,33 @@
+import logging
+
+from tensorboardX import SummaryWriter
+
+
+class EventLogger:
+    def __init__(self, root_dir):
+        self.root_dir = root_dir
+        if root_dir is None:
+            self.tensorboard_logger = None
+        else:
+            root_dir.mkdir(parents=True, exist_ok=False)
+            self.tensorboard_logger = SummaryWriter(str(root_dir))
+        self.console = logging.getLogger(__name__)
+
+    def log_scalar(self, tag, value, iteration):
+        if self.tensorboard_logger is not None:
+            self.tensorboard_logger.add_scalar(tag, value, iteration)
+
+    def debug(self, msg):
+        self.console.debug(msg)
+
+    def info(self, msg):
+        self.console.info(msg)
+
+    def warning(self, msg):
+        self.console.warning(msg)
+
+    def error(self, msg):
+        self.console.error(msg)
+
+    def critical(self, msg):
+        self.console.critical(msg)
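A minimal usage sketch of the logger for reviewers (the `runs/experiment1` path and the scalar values are invented for illustration). Note that the `mkdir(exist_ok=False)` call is what surfaces as the `FileExistsError` handled in `training.py` further down.

```python
from pathlib import Path

from event_logger import EventLogger

# Hypothetical output directory for this example.
logger = EventLogger(Path('runs/experiment1'))   # raises FileExistsError if the directory exists
logger.info('starting')                          # console logging via the standard logging module
logger.log_scalar('Training_Loss', 0.42, 0)      # (tag, value, iteration) forwarded to tensorboardX

# Passing None disables TensorBoard event files but keeps console logging.
console_only = EventLogger(None)
console_only.log_scalar('Training_Loss', 0.42, 0)  # silently ignored
```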
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..c1e6b96
--- /dev/null
+++ b/model.py
@@ -0,0 +1,88 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.distributions import Categorical
+
+
+EPS = np.finfo(np.float32).eps.item()
+
+
+class Policy(nn.Module):
+    def __init__(self, num_features, num_actions):
+        super().__init__()
+
+        self.num_features = num_features
+        self.num_actions = num_actions
+
+        layer_sizes = [126, 64]
+        dropout_probs = [0.5, 0.75]
+        self.network = nn.Sequential(
+            nn.Linear(num_features, layer_sizes[0]),
+            nn.ReLU(),
+            nn.Dropout(dropout_probs[0]),
+            nn.Linear(layer_sizes[0], layer_sizes[1]),
+            nn.ReLU(),
+            nn.Dropout(dropout_probs[1]),
+            nn.Linear(layer_sizes[1], num_actions),
+            nn.Softmax(dim=-1)
+        )
+
+    def _expand_mask(self, mask):
+        expanded_mask = [0 for x in range(self.num_actions)]
+        for i in mask:
+            expanded_mask[i] = 1
+        return expanded_mask
+
+    def predict(self, state, mask):
+        action_probs = self.network(torch.FloatTensor(state))
+        mask = torch.FloatTensor(self._expand_mask(mask))
+        masked_probs = action_probs * mask
+        # Guard against all-zero probabilities
+        guard_probs = torch.full((self.num_actions,), EPS) * mask
+        return masked_probs + guard_probs
+
+    def predict_masked_normalized(self, state, mask):
+        action_probs = self.network(torch.FloatTensor(state))
+        mask = torch.BoolTensor(self._expand_mask(mask))
+        masked_probs = torch.masked_select(action_probs, mask)
+        # Guard against all-zero probabilities
+        masked_probs += torch.full((len(masked_probs),), EPS)
+        normalized_probs = masked_probs / masked_probs.sum()
+        return normalized_probs
+
+    def sample_action(self, state, mask):
+        probs = self.predict(state, mask)
+        distribution = Categorical(probs)
+        action = distribution.sample()
+        return action.item()
+
+    def sample_action_with_log_probability(self, state, mask):
+        probs = self.predict(state, mask)
+        distribution = Categorical(probs)
+        action = distribution.sample()
+        log_prob = distribution.log_prob(action)
+        return action, log_prob
+
+    @staticmethod
+    def save(model, path):
+        model_descriptor = {
+            'num_features': model.num_features,
+            'num_actions': model.num_actions,
+            'network': model.state_dict()
+        }
+        torch.save(model_descriptor, path)
+
+    @staticmethod
+    def load(path):
+        model_descriptor = torch.load(path)
+        num_features = model_descriptor['num_features']
+        num_actions = model_descriptor['num_actions']
+        model = Policy(num_features, num_actions)
+        model.load_state_dict(model_descriptor['network'])
+        return model
+
+    @staticmethod
+    def load_for_eval(path):
+        model = Policy.load(path)
+        model.eval()
+        return model
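To make the masking behaviour concrete, here is a small sketch; the feature/action counts and the feature values are invented, since the real ones come from the verifier at runtime. `predict` keeps a full-size distribution with zero mass on unavailable actions (the `EPS` guard only applies to available ones), while `predict_masked_normalized` returns a distribution over the available actions only.

```python
import torch
from model import Policy

torch.manual_seed(0)
policy = Policy(num_features=4, num_actions=6)   # hypothetical sizes
policy.eval()                                    # disable dropout for inference

state = [0.1, 0.0, 2.0, 1.5]   # one feature vector
available = [0, 2, 5]          # indices of the currently legal actions

probs = policy.predict(state, available)
assert probs.shape == (6,)
assert policy.sample_action(state, available) in available  # zero mass elsewhere

norm = policy.predict_masked_normalized(state, available)
assert norm.shape == (3,)                     # one entry per available action
assert abs(norm.sum().item() - 1.0) < 1e-5    # renormalized to sum to one
```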
diff --git a/policy_gradient.py b/policy_gradient.py
new file mode 100644
index 0000000..178565d
--- /dev/null
+++ b/policy_gradient.py
@@ -0,0 +1,232 @@
+from collections import deque, namedtuple
+import json
+from statistics import mean
+import torch
+
+from event_logger import EventLogger
+from model import Policy, EPS
+
+Parameters = namedtuple(
+    'Parameters',
+    [
+        'seed',
+        'num_training',
+        'num_episodes',
+        'batch_size',
+        'restart_count',
+        'discount_factor',
+        'learning_rate',
+        'tracking_window',
+        'save_interval',
+    ]
+)
+
+RolloutTrace = namedtuple(
+    'RolloutTrace',
+    [
+        'actions',
+        'success',
+        'log_probs',
+        'rewards',
+    ]
+)
+
+
+RolloutStats = namedtuple(
+    'RolloutStats',
+    [
+        'success',
+        'length',
+        'action_count',
+    ]
+)
+
+
+def random_int(low, high):
+    return torch.randint(low, high, (1,))[0].item()
+
+
+class StatsTracker:
+    def __init__(self, window_size):
+        self.episode_history = deque(maxlen=window_size)
+
+    def _current_episode(self):
+        return self.episode_history[-1]
+
+    def _rollout_history(self):
+        return [y for x in self.episode_history for y in x]
+
+    def track(self, stats):
+        self._current_episode().append(stats)
+
+    def new_episode(self):
+        self.episode_history.append([])
+
+    def success_rate(self):
+        rollout_history = self._rollout_history()
+        if len(rollout_history) == 0:
+            return 0.0
+        success_count = sum(1 for stats in rollout_history if stats.success)
+        return float(success_count) / float(len(rollout_history))
+
+    def average_length(self):
+        rollout_history = self._rollout_history()
+        if len(rollout_history) == 0:
+            return 0.0
+        return mean(map(lambda x: x.length, rollout_history))
+
+    def average_action_diversity(self):
+        rollout_history = self._rollout_history()
+        if len(rollout_history) == 0:
+            return 0.0
+        return mean(map(lambda x: x.action_count, rollout_history))
+
+
+class Reinforce:
+
+    def __init__(self, env, policy, params, root_dir=None):
+        self.env = env
+        self.policy = policy
+        self.params = params
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=params.learning_rate)
+        self.root_dir = root_dir
+        self.event_logger = EventLogger(root_dir)
+        self.stats_tracker = StatsTracker(params.tracking_window)
+        torch.manual_seed(params.seed)
+
+    def save_policy(self, tag):
+        if self.root_dir is not None:
+            model_path = self.root_dir / f'policy_{tag}.pt'
+            Policy.save(self.policy, model_path)
+            self.event_logger.info(f'Current model saved to {model_path}')
+
+    def rollout(self, bench_id):
+        actions, log_probs, rewards = [], [], []
+
+        if bench_id is None:
+            resp = self.env.restart_rollout()
+        else:
+            resp = self.env.start_rollout(bench_id)
+
+        while 'reward' not in resp:
+            features = resp['features']
+            self.event_logger.debug('Current Features: {}'.format(features))
+            available_actions = resp['available_actions']
+            self.event_logger.debug('Available Actions: {}'.format(available_actions))
+
+            next_action, log_prob = self.policy.sample_action_with_log_probability(
+                features, available_actions)
+            actions.append(next_action)
+            log_probs.append(log_prob)
+            # We don't get any reward until the end
+            rewards.append(0)
+
+            self.event_logger.debug(f'Taking action {next_action}')
+            resp = self.env.take_action(next_action)
+
+        assert len(actions) > 0
+        if resp['reward'] == 1:
+            success = True
+            self.event_logger.info('Rollout succeeded')
+            # Slightly favors proofs with shorter length
+            # Slightly favors proofs with diverse actions
+            reward = 1 + 1 / (len(actions) ** 0.1) + 0.01 * (len(set(actions)) ** 0.5)
+        else:
+            success = False
+            self.event_logger.info('Rollout failed')
+            reward = -0.01
+        self.event_logger.info(f'Final reward = {reward}')
+        rewards[-1] = reward
+        return RolloutTrace(actions, success, log_probs, rewards)
+
+    def optimize_loss(self, rollout_traces):
+        batch_rewards = []
+        batch_log_probs = []
+        for rollout_trace in rollout_traces:
+            rewards = []
+            cumulative_reward = 0
+            for reward in reversed(rollout_trace.rewards):
+                cumulative_reward = reward + self.params.discount_factor * cumulative_reward
+                rewards.append(cumulative_reward)
+            batch_rewards.extend(reversed(rewards))
+            batch_log_probs.extend(rollout_trace.log_probs)
+
+        reward_tensor = torch.FloatTensor(batch_rewards)
+        reward_tensor = (reward_tensor - reward_tensor.mean()) / (reward_tensor.std() + EPS)
+        losses = []
+        for log_prob, reward in zip(batch_log_probs, reward_tensor):
+            losses.append(-log_prob.reshape(1) * reward)
+        total_loss = torch.cat(losses).sum()
+
+        self.optimizer.zero_grad()
+        total_loss.backward()
+        self.optimizer.step()
+
+        return total_loss.item()
+
+    def batch_rollout(self, episode_id):
+        def single_rollout(rollout_id, restart_id):
+            if restart_id > 0:
+                self.event_logger.info(f'Restart #{restart_id} on previous benchmark')
+                rollout_trace = self.rollout(None)
+            else:
+                bench_id = random_int(0, self.params.num_training)
+                self.event_logger.info(f'Start rollout on benchmark {bench_id}')
+                rollout_trace = self.rollout(bench_id)
+            return rollout_trace
+
+        rollout_traces = []
+        for rollout_id in range(self.params.batch_size):
+            self.event_logger.info(f'Batching rollout {rollout_id}...')
+            for restart_id in range(self.params.restart_count):
+                rollout_trace = single_rollout(rollout_id, restart_id)
+                rollout_traces.append(rollout_trace)
+                rollout_stats = RolloutStats(
+                    length=len(rollout_trace.actions),
+                    success=rollout_trace.success,
+                    action_count=len(set(rollout_trace.actions)),
+                )
+                self.stats_tracker.track(rollout_stats)
+                if rollout_trace.success:
+                    break
+        return rollout_traces
+
+    def train_episode(self, episode_id):
+        self.event_logger.warning(f'Starting episode {episode_id}...')
+
+        rollout_traces = self.batch_rollout(episode_id)
+        loss = self.optimize_loss(rollout_traces)
+
+        self.event_logger.log_scalar('Training_Success_Rate',
+                                     self.stats_tracker.success_rate(), episode_id)
+        self.event_logger.log_scalar('Average_Rollout_Length',
+                                     self.stats_tracker.average_length(), episode_id)
+        self.event_logger.log_scalar('Average_Action_Diversity',
+                                     self.stats_tracker.average_action_diversity(), episode_id)
+        self.event_logger.log_scalar('Training_Loss', loss, episode_id)
+
+        if self.params.save_interval is not None and \
+                episode_id > 0 and \
+                episode_id % self.params.save_interval == 0:
+            self.save_policy(str(episode_id))
+
+        self.event_logger.warning(f'Finished episode {episode_id}')
+
+    def save_params(self):
+        if self.root_dir is not None:
+            out_path = self.root_dir / 'parameters.json'
+            with open(out_path, 'w') as f:
+                json.dump(self.params._asdict(), f)
+
+    def train(self):
+        # Save the training parameters first for post-training inspection
+        self.save_params()
+
+        try:
+            for episode_id in range(self.params.num_episodes):
+                self.stats_tracker.new_episode()
+                self.train_episode(episode_id)
+        except KeyboardInterrupt:
+            self.event_logger.warning('Training terminated by user')
+        finally:
+            self.save_policy('final')
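Since a rollout only receives a reward at its final step, credit assignment happens entirely in `Reinforce.optimize_loss`: per-step rewards are folded backwards into discounted returns, standardized across the whole batch, and each return is weighted against the stored log-probability of its action. A standalone sketch of just the return computation, with made-up numbers (a 3-step rollout whose terminal reward is roughly 1.9):

```python
# Mirrors the inner loop of Reinforce.optimize_loss for a single trace.
discount_factor = 0.99
rewards = [0, 0, 1.9]   # per-step rewards; only the terminal step is non-zero

returns = []
cumulative_reward = 0
for reward in reversed(rewards):
    cumulative_reward = reward + discount_factor * cumulative_reward
    returns.append(cumulative_reward)
returns = list(reversed(returns))
print(returns)   # approximately [1.862, 1.881, 1.9]: earlier actions get slightly discounted credit
```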
diff --git a/search_server.py b/search_server.py
new file mode 100644
index 0000000..13d0935
--- /dev/null
+++ b/search_server.py
@@ -0,0 +1,92 @@
+import argparse
+import functools
+import logging
+from pathlib import Path
+import re
+import sys
+
+from model import Policy
+from remote import establish_simple_server
+
+
+LOG = logging.getLogger(__name__)
+
+
+def handler(policy, state):
+    features = state['features']
+    available_actions = state['available_actions']
+    prob_tensor = policy.predict_masked_normalized(features, available_actions)
+    return prob_tensor.tolist()
+
+
+def find_policy(args):
+    if args.iteration is not None:
+        return args.model / f'policy_{str(args.iteration)}.pt'
+
+    final_policy_path = args.model / 'policy_final.pt'
+    if final_policy_path.is_file():
+        return final_policy_path
+
+    # We need to enumerate the model files and find the last one
+    max_iteration = 0
+    for file_path in args.model.iterdir():
+        if file_path.suffix != '.pt':
+            continue
+        match = re.match(r'policy_(\d+)', file_path.name)
+        if match is not None:
+            iteration = int(match.group(1))
+            if iteration > max_iteration:
+                max_iteration = iteration
+                final_policy_path = file_path
+    return final_policy_path
+
+
+def main(args):
+    policy_path = find_policy(args)
+    if not policy_path.is_file():
+        LOG.error(f'Cannot find model at: {policy_path}')
+        sys.exit(1)
+
+    policy = Policy.load_for_eval(policy_path)
+    LOG.info(f'Policy loaded from {policy_path}')
+
+    LOG.info(f'Server starting at {args.addr}:{args.port}')
+    establish_simple_server(args.addr, args.port, functools.partial(handler, policy))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='Demo server that can talk to Coeus search client')
+    parser.add_argument(
+        'model',
+        type=Path,
+        help='Directory that holds the model file')
+    parser.add_argument(
+        '-i',
+        '--iteration',
+        metavar='I',
+        type=int,
+        help='Specify which version of the model to use (indexed by iteration number). '
+             'By default, use the one that has been trained the longest')
+    parser.add_argument(
+        '-a',
+        '--addr',
+        metavar='HOST',
+        type=str,
+        default='localhost',
+        help='Host name of the server')
+    parser.add_argument(
+        '-p',
+        '--port',
+        metavar='PORT',
+        type=int,
+        default=12345,
+        help='Remote port of the server')
+    args = parser.parse_args()

+    logging.basicConfig(
+        level=logging.INFO,
+        format="[%(asctime)s] [%(levelname)s] %(message)s",
+        datefmt="%H:%M:%S",
+    )
+    main(args)
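The wire protocol lives in the `remote` module, which is not part of this diff; the handler itself only needs a dict with `features` and `available_actions` and replies with the normalized probabilities of the available actions. It can therefore be exercised directly, as in the sketch below (the policy is freshly initialized instead of loaded, purely for illustration):

```python
from model import Policy
from search_server import handler

policy = Policy(num_features=4, num_actions=6)   # stand-in for Policy.load_for_eval(...)
policy.eval()

# The exact request framing depends on `remote`; the handler only reads these two keys.
request = {'features': [0.1, 0.0, 2.0, 1.5], 'available_actions': [0, 2, 5]}
reply = handler(policy, request)
print(len(reply), sum(reply))   # 3 probabilities, summing to ~1.0
```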
diff --git a/training.py b/training.py
new file mode 100644
index 0000000..be6b154
--- /dev/null
+++ b/training.py
@@ -0,0 +1,188 @@
+import argparse
+import logging
+from pathlib import Path
+import torch
+import sys
+
+from remote import open_connection
+from model import Policy
+from policy_gradient import Reinforce, Parameters
+
+LOG = logging.getLogger(__name__)
+
+
+def train(verifier, args):
+    num_actions = verifier.get_num_actions()
+    num_features = verifier.get_num_features()
+    num_training = verifier.get_num_training()
+    if num_actions <= 0:
+        LOG.error(f'Illegal action count: {num_actions}')
+        sys.exit(1)
+    if num_features <= 0:
+        LOG.error(f'Illegal feature count: {num_features}')
+        sys.exit(1)
+    if num_training <= 0:
+        LOG.error(f'Illegal training example count: {num_training}')
+        sys.exit(1)
+
+    LOG.warning(
+        'Verifier connected '
+        f'(actions count = {num_actions}, '
+        f'features count = {num_features}, '
+        f'training example count = {num_training})'
+    )
+
+    policy = Policy(num_features, num_actions)
+    params = Parameters(
+        seed=args.seed if args.seed is not None else torch.initial_seed() & ((1 << 63) - 1),
+        num_training=num_training,
+        num_episodes=args.num_episodes,
+        batch_size=args.batch_size,
+        restart_count=args.restart_count,
+        discount_factor=args.gamma,
+        learning_rate=args.learning_rate,
+        tracking_window=args.tracking_window,
+        save_interval=args.save_interval,
+    )
+    trainer = Reinforce(verifier, policy, params, args.output)
+    trainer.train()
+
+
+def main(args):
+    if args.output is not None:
+        LOG.info(f'Setting output path to {args.output}')
+    if args.num_episodes <= 0:
+        LOG.error(f'Episode count must be positive: {args.num_episodes}')
+        sys.exit(1)
+    if args.save_interval is not None and args.save_interval <= 0:
+        LOG.error(f'Save interval must be positive: {args.save_interval}')
+        sys.exit(1)
+    if args.batch_size <= 0:
+        LOG.error(f'Batch size must be positive: {args.batch_size}')
+        sys.exit(1)
+    if args.restart_count <= 0:
+        LOG.error(f'Restart count must be positive: {args.restart_count}')
+        sys.exit(1)
+    if args.tracking_window <= 0:
+        LOG.error(f'Tracking window must be positive: {args.tracking_window}')
+        sys.exit(1)
+    if args.gamma > 1.0 or args.gamma <= 0.0:
+        LOG.error(f'Discount factor must be in (0, 1]: {args.gamma}')
+        sys.exit(1)
+
+    LOG.info(f'Connecting to verifier at remote host {args.addr}:{args.port}...')
+    try:
+        with open_connection((args.addr, args.port)) as verifier:
+            train(verifier, args)
+    except FileExistsError:
+        LOG.error('Model output directory already exists.')
+        LOG.error(
+            'To prevent accidental overwrites, please remove or rename the existing directory first.'
+        )
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Coeus learning engine')
+    parser.add_argument(
+        '-a',
+        '--addr',
+        metavar='HOST',
+        type=str,
+        default='localhost',
+        help='Host name of the server')
+    parser.add_argument(
+        '-p',
+        '--port',
+        metavar='PORT',
+        type=int,
+        default=12345,
+        help='Remote port of the server')
+    parser.add_argument(
+        '-n',
+        '--num-episodes',
+        type=int,
+        default=1000,
+        metavar='N',
+        help='Max number of training episodes (default: 1000)')
+    parser.add_argument(
+        '-o',
+        '--output',
+        type=Path,
+        metavar='PATH',
+        help='Directory where output files (trained models, event logs) are stored. '
+             'Note that nothing will be saved if this argument is absent')
+    parser.add_argument(
+        '-g',
+        '--gamma',
+        type=float,
+        default=0.99,
+        metavar='G',
+        help='Discount factor (default: 0.99)')
+    parser.add_argument(
+        '-l',
+        '--learning-rate',
+        type=float,
+        default=1e-3,
+        metavar='L',
+        help='Learning rate (default: 0.001)')
+    parser.add_argument(
+        '-b',
+        '--batch-size',
+        type=int,
+        default=32,
+        metavar='B',
+        help='Batch size (default: 32)'
+    )
+    parser.add_argument(
+        '-r',
+        '--restart-count',
+        type=int,
+        default=1,
+        metavar='R',
+        help='Number of rollouts on the same benchmark. '
+             'Setting r>1 may benefit from conflict analysis if it is enabled '
+             'on the server side. (default: 1)'
+    )
+    parser.add_argument(
+        '-s',
+        '--seed',
+        type=int,
+        metavar='SEED',
+        help='Random seed (default: auto chosen)')
+    parser.add_argument(
+        '-w',
+        '--tracking-window',
+        type=int,
+        default=250,
+        metavar='W',
+        help='How many episodes are considered when tracking training statistics (default: 250)')
+    parser.add_argument(
+        '-i',
+        '--save-interval',
+        type=int,
+        metavar='I',
+        help='Interval between saving trained models. '
+             'By default models are only saved after the training is done')
+    parser.add_argument(
+        '-v',
+        '--verbose',
+        dest='verbose_count',
+        action="count",
+        default=0,
+        help="Increases log verbosity for each occurrence.")
+    args = parser.parse_args()
+
+    if args.verbose_count == 0:
+        log_level = logging.WARNING
+    elif args.verbose_count == 1:
+        log_level = logging.INFO
+    else:
+        log_level = logging.DEBUG
+    logging.basicConfig(
+        level=log_level,
+        format="[%(asctime)s] [%(levelname)s] %(message)s",
+        datefmt="%H:%M:%S",
+    )
+
+    main(args)
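The `remote` module is not included in this diff, but the interface the trainer relies on is visible from the calls above: the verifier connection answers `get_num_actions`/`get_num_features`/`get_num_training`, and each rollout step returns a dict with `features` and `available_actions` until a terminal dict carrying `reward` comes back. A hypothetical in-process stub that satisfies this protocol, useful for smoke-testing the training loop without a server (every class name and number below is invented, not part of the codebase):

```python
import random

from model import Policy
from policy_gradient import Parameters, Reinforce


class StubVerifier:
    """In-memory stand-in for the remote verifier: every rollout 'succeeds' after 3 steps."""

    def get_num_actions(self):
        return 6

    def get_num_features(self):
        return 4

    def get_num_training(self):
        return 10

    def start_rollout(self, bench_id):
        self.steps = 0
        return self._state()

    def restart_rollout(self):
        return self.start_rollout(None)

    def take_action(self, action):
        self.steps += 1
        if self.steps >= 3:
            return {'reward': 1}   # terminal response carries only the reward
        return self._state()

    def _state(self):
        return {
            'features': [random.random() for _ in range(4)],
            'available_actions': [0, 2, 5],
        }


# Arbitrary hyperparameters; root_dir is left as None so nothing is written to disk.
params = Parameters(seed=0, num_training=10, num_episodes=2, batch_size=4,
                    restart_count=1, discount_factor=0.99, learning_rate=1e-3,
                    tracking_window=10, save_interval=None)
Reinforce(StubVerifier(), Policy(4, 6), params).train()
```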