Skip to content

Commit

Permalink
Initial Commit
Browse files Browse the repository at this point in the history
  • Loading branch information
grievejia committed Oct 6, 2019
1 parent 5f05315 commit 6b591fe
Show file tree
Hide file tree
Showing 6 changed files with 759 additions and 0 deletions.
126 changes: 126 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don’t work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Misc
models/
33 changes: 33 additions & 0 deletions event_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging

from tensorboardX import SummaryWriter


class EventLogger:
    """Combined TensorBoard scalar logger and console logger.

    Scalars are written through a ``tensorboardX.SummaryWriter`` rooted at
    ``root_dir``; text messages are forwarded to a standard ``logging``
    logger named after this module.
    """

    def __init__(self, root_dir):
        """Create the logger.

        Args:
            root_dir: Directory for the TensorBoard event files, given as a
                ``pathlib.Path`` (or any object with a compatible ``mkdir``)
                or as a plain string path. ``None`` disables TensorBoard
                output; only console logging is performed then.

        Raises:
            FileExistsError: If ``root_dir`` already exists. The directory is
                created with ``exist_ok=False``, so an existing directory is
                never reused.
        """
        if isinstance(root_dir, str):
            # Plain string paths have no ``mkdir``; normalize them so both
            # ``str`` and ``Path`` callers work.
            from pathlib import Path

            root_dir = Path(root_dir)

        self.root_dir = root_dir
        if root_dir is None:
            self.tensorboard_logger = None
        else:
            root_dir.mkdir(parents=True, exist_ok=False)
            self.tensorboard_logger = SummaryWriter(str(root_dir))
        self.console = logging.getLogger(__name__)

    def log_scalar(self, tag, value, iteration):
        """Record a scalar under ``tag`` at step ``iteration``.

        Does nothing when TensorBoard output is disabled.
        """
        if self.tensorboard_logger is not None:
            self.tensorboard_logger.add_scalar(tag, value, iteration)

    def close(self):
        """Close the TensorBoard writer, if there is one.

        Safe to call more than once and when TensorBoard output is disabled.
        """
        if self.tensorboard_logger is not None:
            self.tensorboard_logger.close()
            self.tensorboard_logger = None

    def debug(self, msg):
        """Log ``msg`` to the console logger at DEBUG level."""
        self.console.debug(msg)

    def info(self, msg):
        """Log ``msg`` to the console logger at INFO level."""
        self.console.info(msg)

    def warning(self, msg):
        """Log ``msg`` to the console logger at WARNING level."""
        self.console.warning(msg)

    def error(self, msg):
        """Log ``msg`` to the console logger at ERROR level."""
        self.console.error(msg)

    def critical(self, msg):
        """Log ``msg`` to the console logger at CRITICAL level."""
        self.console.critical(msg)
88 changes: 88 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical


# float32 machine epsilon as a Python float; added to the probabilities of
# permitted actions so they never become exactly zero.
EPS = np.finfo(np.float32).eps.item()


class Policy(nn.Module):
    """Feed-forward policy network producing a distribution over actions.

    The network maps a feature vector of size ``num_features`` to softmax
    probabilities over ``num_actions`` actions. Every prediction method takes
    a ``mask``: an iterable of the action indices that are currently
    permitted. The network contains dropout layers, so outputs are stochastic
    unless the module is switched to evaluation mode (see ``load_for_eval``).
    """

    def __init__(self, num_features, num_actions):
        """Build the network.

        Args:
            num_features: Size of the input state vector.
            num_actions: Number of actions (size of the output distribution).
        """
        super().__init__()

        self.num_features = num_features
        self.num_actions = num_actions

        # The layer order must not change: ``state_dict`` keys are derived
        # from the position of each layer inside the ``Sequential``.
        layer_sizes = [126, 64]
        dropout_probs = [0.5, 0.75]
        self.network = nn.Sequential(
            nn.Linear(num_features, layer_sizes[0]),
            nn.ReLU(),
            nn.Dropout(dropout_probs[0]),
            nn.Linear(layer_sizes[0], layer_sizes[1]),
            nn.ReLU(),
            nn.Dropout(dropout_probs[1]),
            nn.Linear(layer_sizes[1], num_actions),
            nn.Softmax(dim=-1),
        )

    def _expand_mask(self, mask):
        """Turn an iterable of permitted action indices into a 0/1 list."""
        expanded_mask = [0] * self.num_actions
        for index in mask:
            expanded_mask[index] = 1
        return expanded_mask

    def _action_distribution(self, state, mask):
        """Return a ``Categorical`` over the permitted actions for ``state``."""
        return Categorical(self.predict(state, mask))

    def predict(self, state, mask):
        """Return unnormalized action probabilities with masked actions zeroed.

        Args:
            state: Feature values accepted by ``torch.FloatTensor``.
            mask: Iterable of permitted action indices.

        Returns:
            A tensor with one entry per action. Entries outside ``mask`` are
            exactly zero; entries inside ``mask`` are the softmax output plus
            ``EPS``. The result is not renormalized.
        """
        action_probs = self.network(torch.FloatTensor(state))
        mask = torch.FloatTensor(self._expand_mask(mask))
        # Adding EPS before masking guards against all-zero probabilities
        # among the permitted actions while keeping masked actions at zero.
        return (action_probs + EPS) * mask

    def predict_masked_normalized(self, state, mask):
        """Return normalized probabilities of the permitted actions only.

        Args:
            state: Feature values accepted by ``torch.FloatTensor``.
            mask: Iterable of permitted action indices.

        Returns:
            A 1-D tensor with one entry per permitted action (in increasing
            action-index order) that sums to one.
        """
        action_probs = self.network(torch.FloatTensor(state))
        # ``masked_select`` requires a boolean mask; uint8 masks are
        # deprecated (and rejected by newer PyTorch releases).
        mask = torch.tensor(self._expand_mask(mask), dtype=torch.bool)
        # Guard against all-zero probabilities before normalizing.
        masked_probs = torch.masked_select(action_probs, mask) + EPS
        return masked_probs / masked_probs.sum()

    def sample_action(self, state, mask):
        """Sample a permitted action and return its index as a Python int."""
        return self._action_distribution(state, mask).sample().item()

    def sample_action_with_log_probability(self, state, mask):
        """Sample a permitted action together with its log-probability.

        Returns:
            A tuple ``(action, log_prob)`` of tensors; ``log_prob`` is
            computed by the distribution built from the network output.
        """
        distribution = self._action_distribution(state, mask)
        action = distribution.sample()
        return action, distribution.log_prob(action)

    @staticmethod
    def save(model, path):
        """Serialize ``model`` (its dimensions and weights) to ``path``."""
        model_descriptor = {
            'num_features': model.num_features,
            'num_actions': model.num_actions,
            'network': model.state_dict(),
        }
        torch.save(model_descriptor, path)

    @staticmethod
    def load(path):
        """Rebuild a ``Policy`` from a file written by ``save``.

        The checkpoint is mapped onto the CPU so that files saved from GPU
        tensors can also be loaded on CPU-only machines; the rebuilt model
        lives on the CPU.
        """
        # NOTE: ``torch.load`` relies on pickle; only load trusted files.
        model_descriptor = torch.load(path, map_location='cpu')
        model = Policy(
            model_descriptor['num_features'],
            model_descriptor['num_actions'],
        )
        model.load_state_dict(model_descriptor['network'])
        return model

    @staticmethod
    def load_for_eval(path):
        """Load a ``Policy`` and switch it to evaluation mode."""
        model = Policy.load(path)
        model.eval()
        return model
Loading

0 comments on commit 6b591fe

Please sign in to comment.