-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 44dfa90
Showing
370 changed files
with
494,322 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# Distribution / packaging | ||
.Python | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
__pycache__ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# misc | ||
*.urdf | ||
*.mp4 | ||
*.avi | ||
*.mpeg | ||
*.mpg | ||
*.png | ||
|
||
# prevent from storing data files in the git because they are big | ||
*.ptx | ||
*.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
BSD 3-Clause License | ||
|
||
Copyright (c) 2023 Julian Whitman | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
3. Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
**Overview** | ||
Code used in the publication "Learning modular robot control policies." (see https://arxiv.org/abs/2105.10049) Trains and runs modular robot policies with model based reinforcement learning, from the Biorobotics Laboratory at Carnegie Mellon University. Written and maintained by Julian Whitman. | ||
|
||
**System requirements** | ||
- Training: NVIDIA GPU with minimum 8 Gb VRAM, ideally multiple GPUs with >12 Gb. | ||
- Running policies: most CPUs can run the policy, but we have only verified computers with at least four Intel i7 cores running it in real-time. | ||
|
||
**Dependencies** | ||
- python3 | ||
- pybullet for simulation: pip3 install pybullet | ||
- pytorch for deep neural networks: see https://pytorch.org/, install the version corresponding to your OS and GPU. | ||
- scipy for interpolation utility: pip3 install scipy | ||
- If you want to compile the modular robot urdfs from xacros, this requires a ROS verison of at least kinetic and with at least the xacro command installed. | ||
- If you are using a joystick to control the trained policy, get pygame for joystick reading: pip install pygame | ||
- The physical robot control (run_robot_policy.py) uses the hebi python API: pip install hebi-py, but is not needed for training or simulation, so most users will not need to install this package. | ||
- Some analysis scripts (with file extension .ipynb) use jupyter notebook, but it is not necessary to run the training or simulation tests. | ||
|
||
**Running** | ||
- The first step after installing dependencies is to simulate a pre-trained policy with modular_policy/simulate_policy.py. | ||
- If you would like to train the modular policy from scratch, the main modular policy training script is modular_policy/mbrl.py. See "Learning modular robot control policies" for more information about the training process and compute time. | ||
|
||
**Repository contents** | ||
- modular_policy contains scripts and utilities for training and executing modular policies. | ||
- mpl_policy contains scripts and utilities for training and executing multi-layer perceptron policies, which serve as a basis of comparison. | ||
- urdf contains the robot models used in simulations. | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
''' | ||
Code adapted from: | ||
https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb | ||
''' | ||
import torch | ||
import torch.nn as nn | ||
# import torchvision.transforms as transforms | ||
# import torchvision.datasets as dsets | ||
# from torch.autograd import Variable | ||
# from torch.nn import Parameter | ||
# from torch import Tensor | ||
import torch.nn.functional as F | ||
import math | ||
|
||
class LSTMCell(nn.Module): | ||
|
||
""" | ||
An implementation of Hochreiter & Schmidhuber: | ||
'Long-Short Term Memory' cell. | ||
http://www.bioinf.jku.at/publications/older/2604.pdf | ||
""" | ||
|
||
def __init__(self, input_size, hidden_size, bias=True): | ||
super(LSTMCell, self).__init__() | ||
self.input_size = input_size | ||
self.hidden_size = hidden_size | ||
self.bias = bias | ||
self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) | ||
self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) | ||
self.reset_parameters() | ||
|
||
|
||
|
||
def reset_parameters(self): | ||
std = 1.0 / math.sqrt(self.hidden_size) | ||
for w in self.parameters(): | ||
w.data.uniform_(-std, std) | ||
|
||
def forward(self, x, hidden): | ||
|
||
hx, cx = hidden | ||
|
||
x = x.view(-1, x.size(1)) | ||
|
||
gates = self.x2h(x) + self.h2h(hx) | ||
|
||
gates = gates.squeeze() | ||
|
||
# ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) # this breaks when batch_size=1 | ||
ingate, forgetgate, cellgate, outgate = gates.chunk(4, dim=-1) # appears to work for all batch_size | ||
|
||
ingate = torch.sigmoid(ingate) | ||
forgetgate = torch.sigmoid(forgetgate) | ||
cellgate = torch.tanh(cellgate) | ||
outgate = torch.sigmoid(outgate) | ||
|
||
|
||
cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate) | ||
|
||
hy = torch.mul(outgate, torch.tanh(cy)) | ||
|
||
return (hy, cy) |
Oops, something went wrong.