-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
110 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ os: | |
- osx | ||
julia: | ||
- 0.6 | ||
- nightly | ||
notifications: | ||
email: false | ||
git: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
# RLEnvGym | ||
|
||
[![Build Status](https://travis-ci.com/Johanni Brea/RLEnvGym.jl.svg?branch=master)](https://travis-ci.com/Johanni Brea/RLEnvGym.jl) | ||
[![Build Status](https://travis-ci.com/JuliaReinforcementLearning/RLEnvGym.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/RLEnvGym.jl) | ||
|
||
[![Coverage Status](https://coveralls.io/repos/Johanni Brea/RLEnvGym.jl/badge.svg?branch=master&service=github)](https://coveralls.io/github/Johanni Brea/RLEnvGym.jl?branch=master) | ||
[![Coverage Status](https://coveralls.io/repos/JuliaReinforcementLearning/RLEnvGym.jl/badge.svg?branch=master&service=github)](https://coveralls.io/github/JuliaReinforcementLearning/RLEnvGym.jl?branch=master) | ||
|
||
[![codecov.io](http://codecov.io/github/Johanni Brea/RLEnvGym.jl/coverage.svg?branch=master)](http://codecov.io/github/Johanni Brea/RLEnvGym.jl?branch=master) | ||
[![codecov.io](http://codecov.io/github/JuliaReinforcementLearning/RLEnvGym.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaReinforcementLearning/RLEnvGym.jl?branch=master) | ||
|
||
Making the [OpenAI gym](https://github.com/openai/gym) environments available to the [Julia Reinforcement Learning](https://github.com/jbrea/ReinforcementLearning.jl) package. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
julia 0.6 | ||
PyCall | ||
Reexport | ||
ReinforcementLearning |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
using PyCall | ||
|
||
# Python packages to install into the Python used by PyCall.
# Change that to whatever packages you need.
const PACKAGES = ["gym"]

# Use eventual proxy info.
# `--proxy` is a pip option, not an option of the Python interpreter, so these
# arguments must be placed after the script / after `pip install` on the
# command line, never directly after the Python executable.
proxy_arg = String[]
if haskey(ENV, "http_proxy")
    push!(proxy_arg, "--proxy", ENV["http_proxy"])
end

# Import pip
try
    @pyimport pip
catch
    # If it is not found, install it: download get-pip.py next to this file
    # and run it with PyCall's Python.
    println("Pip not found on your system. Downloading it.")
    get_pip = joinpath(dirname(@__FILE__), "get-pip.py")
    download("https://bootstrap.pypa.io/get-pip.py", get_pip)
    run(`$(PyCall.python) $get_pip $(proxy_arg) --user`)
end

println("Installing required python packages using pip")
# Upgrade pip and setuptools first, then install the requested packages.
# An empty `proxy_arg` interpolates to no arguments at all.
run(`$(PyCall.python) -m pip install $(proxy_arg) --user --upgrade pip setuptools`)
run(`$(PyCall.python) -m pip install $(proxy_arg) --user $(PACKAGES)`)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
using RLEnvGym, Flux | ||
# Show every environment known to gym.
listallenvs()

# CartPole example: run an untrained DQN agent with visualization, train it
# without visualization, then run it again with visualization.
env = GymEnv("CartPole-v0")

# Network used by the learner: 4 inputs, two hidden relu layers, 2 outputs.
qnetwork = Chain(Dense(4, 48, relu), Dense(48, 24, relu), Dense(24, 2))
learner = DQN(qnetwork;
              updateevery = 1,
              updatetargetevery = 100,
              startlearningat = 50,
              minibatchsize = 32,
              doubledqn = false,
              replaysize = 10^3,
              opttype = x -> ADAM(x, .0005))

x = RLSetup(learner, env, ConstantNumberEpisodes(10);
            callbacks = [Progress(),
                         EvaluationPerEpisode(TimeSteps()),
                         Visualize(wait = 0)])

info("Before learning.")
run!(x)

# Drop the last callback (the visualizer) while learning for 400 episodes.
pop!(x.callbacks)
x.stoppingcriterion = ConstantNumberEpisodes(400)
@time learn!(x)

# Add a visualizer again and run 10 episodes with the trained learner.
x.stoppingcriterion = ConstantNumberEpisodes(10)
push!(x.callbacks, Visualize(wait = 0))
info("After learning.")
run!(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,57 @@ | ||
module RLEnvGym
using Reexport
@reexport using ReinforcementLearning
import ReinforcementLearning: interact!, reset!, getstate, plotenv
using PyCall
@pyimport gym
# @pyimport roboschool

"""
    getspace(space)

Convert a gym space object to a Julia representation.

A gym `Box` becomes a `ReinforcementLearning.Box` built from the space's `low`
and `high` attributes; a gym `Discrete` space with `n` elements becomes the
range `1:n`. Any other kind of space raises an error.
"""
function getspace(space)
    if pyisinstance(space, gym.spaces[:box][:Box])
        return ReinforcementLearning.Box(space[:low], space[:high])
    elseif pyisinstance(space, gym.spaces[:discrete][:Discrete])
        return 1:space[:n]
    else
        error("Don't know how to convert $(pytypeof(space)).")
    end
end

"""
    GymEnvState

Mutable bookkeeping for a `GymEnv`: `done` records whether the most recent
interaction reported the end of an episode.
"""
mutable struct GymEnvState
    done::Bool
end

"""
    GymEnv(name)

Wrapper around the gym environment created by `gym.make(name)`.

# Fields
- `pyobj`: the Python environment object.
- `observation_space`, `action_space`: the converted spaces (see `getspace`).
- `state`: a `GymEnvState` holding the `done` flag.
"""
struct GymEnv{TObject, TObsSpace, TActionSpace}
    pyobj::TObject
    observation_space::TObsSpace
    action_space::TActionSpace
    state::GymEnvState
end
function GymEnv(name::AbstractString)
    pyenv = gym.make(name)
    obsspace = getspace(pyenv[:observation_space])
    actspace = getspace(pyenv[:action_space])
    env = GymEnv(pyenv, obsspace, actspace, GymEnvState(false))
    # Reset once so the environment is ready for interaction.
    reset!(env)
    return env
end

"""
    interactgym!(action, env)

Pass `action` unchanged to the Python environment's `step` and return the
tuple `(s, r, d)` made of the first three values of its result.

If the previous interaction ended the episode (`env.state.done` is `true`),
the environment is reset instead: `action` is not applied and the result is
the state returned by the reset, a reward of `0` and `d = false`.
The `done` flag of `env` is updated with the returned `d`.
"""
function interactgym!(action, env)
    if env.state.done
        s = reset!(env)
        r = 0
        d = false
    else
        s, r, d = env.pyobj[:step](action)
    end
    env.state.done = d
    return s, r, d
end

"""
    interact!(action, env::GymEnv)

Apply `action` to `env` and return `(s, r, d)`; see `interactgym!`.

Integer actions are shifted by `-1` before being passed to gym, matching the
1-based range `1:n` that `getspace` produces for discrete spaces. The shift
applies to every `Integer` type (not only `Int64`), so actions are converted
consistently regardless of the platform's default integer width.
"""
interact!(action, env::GymEnv) = interactgym!(action, env)
interact!(action::Integer, env::GymEnv) = interactgym!(action - 1, env)

"""
    reset!(env::GymEnv)

Reset the Python environment and return the value of its `reset` call.

The `done` flag is cleared as well; otherwise an explicit reset after a
finished episode would make the next `interact!` reset a second time and
ignore its action.
"""
function reset!(env::GymEnv)
    s = env.pyobj[:reset]()
    env.state.done = false
    return s
end

"""
    getstate(env::GymEnv)

Return `(state, false)`, where `state` is read from the `state` attribute of
the wrapped Python environment's inner `env`. Doesn't work for all envs.
"""
getstate(env::GymEnv) = (env.pyobj[:env][:state], false) # doesn't work for all envs
# NOTE(review): the second element is always `false`, even when
# `env.state.done` is `true` — confirm this is what callers expect.

"""
    plotenv(env::GymEnv, s, a, r, d)

Render the environment by calling the Python object's `render`; the remaining
arguments are ignored.
"""
plotenv(env::GymEnv, s, a, r, d) = env.pyobj[:render]()

"""
    listallenvs()

Return the result of `all()` on gym's environment registry.
"""
listallenvs() = gym.envs[:registry][:all]()

export GymEnv, listallenvs

end # module