diff --git a/src/RLEnvGym.jl b/src/RLEnvGym.jl index 6da75dc..0029505 100644 --- a/src/RLEnvGym.jl +++ b/src/RLEnvGym.jl @@ -47,7 +47,7 @@ end interact!(action, env::GymEnv) = interactgym!(action, env) interact!(action::Int64, env::GymEnv) = interactgym!(action - 1, env) reset!(env::GymEnv) = env.pyobj[:reset]() -getstate(env::GymEnv) = (env.pyobj[:env][:state], false) # doesn't work for all envs +getstate(env::GymEnv) = (Float64[env.pyobj[:env][:state]...], false) # doesn't work for all envs plotenv(env::GymEnv, s, a, r, d) = env.pyobj[:render]() listallenvs() = gym.envs[:registry][:all]() diff --git a/test/runtests.jl b/test/runtests.jl index 5539bfd..a3d4d6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,10 @@ else using Test end -# write your own tests here -@test 1 == 2 +import RLEnvGym: reset!, interact!, getstate +for x in ["CartPole-v0"] + env = GymEnv(x) + reset!(env) + @test typeof(interact!(1, env)) == Tuple{Array{Float64, 1}, Float64, Bool} + @test typeof(getstate(env)) == Tuple{Array{Float64, 1}, Bool} +end