Increase coverage to > 90%, added Categorical test
mossr committed Sep 10, 2020
1 parent 599d052 commit aff3dd5
Showing 6 changed files with 176 additions and 8 deletions.
3 changes: 1 addition & 2 deletions notebooks/Walk1D.ipynb
@@ -1234,8 +1234,7 @@
}
],
"source": [
"tree = search!(planner; return_tree=true) # re-run the search, this time outputting the tree\n",
"d3tree = visualize(tree)"
"d3tree = visualize(planner) # re-runs the search to output the tree, then visualizes it"
]
},
{
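Usage note: the notebook change collapses the two-step tree visualization into a single call. A minimal sketch, assuming a gray-box simulation type like the notebook's (the `Walk1DSim` name and the solver keyword values here are illustrative, not taken from this commit):

    using POMDPStressTesting

    mdp = ASTMDP{ASTSeedAction}(Walk1DSim())   # `Walk1DSim` is a hypothetical stand-in
    planner = solve(MCTSPWSolver(n_iterations=1000, depth=30), mdp)

    # One call now re-runs the search with return_tree=true and
    # converts the resulting MCTS tree into a D3Tree for display.
    d3tree = visualize(planner)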
2 changes: 1 addition & 1 deletion src/solvers/drl/policies.jl
@@ -11,7 +11,7 @@ end

function CategoricalPolicy(solver)
    policy_net = Chain(Dense(solver.state_size, solver.hidden_layer_size, relu; initW=_random_normal, initb=constant_init),
-                      Dense(solver.hidden_layer_size, solver.ACTION_SIZE; initW=_random_normal, initb=constant_init),
+                      Dense(solver.hidden_layer_size, solver.action_size; initW=_random_normal, initb=constant_init),
                       x -> softmax(x))

    value_net = Chain(Dense(solver.state_size, solver.hidden_layer_size, relu; initW=_random_normal),
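For context, the fix points the action dimension at the lower-case `action_size` field. A minimal sketch of the two-headed network this constructor builds, using plain Flux defaults rather than the `initW`/`initb` initializers above (the `ToySolver` type is a hypothetical stand-in for illustration):

    using Flux

    # Hypothetical solver stand-in; mirrors the fields used by CategoricalPolicy.
    struct ToySolver
        state_size::Int
        hidden_layer_size::Int
        action_size::Int
    end

    solver = ToySolver(4, 32, 3)

    # Policy head: state -> categorical action probabilities.
    policy_net = Chain(Dense(solver.state_size, solver.hidden_layer_size, relu),
                       Dense(solver.hidden_layer_size, solver.action_size),
                       softmax)

    # Value head: state -> scalar value estimate.
    value_net = Chain(Dense(solver.state_size, solver.hidden_layer_size, relu),
                      Dense(solver.hidden_layer_size, 1))

    probs = policy_net(rand(Float32, solver.state_size))   # sums to ~1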
6 changes: 5 additions & 1 deletion src/solvers/drl/train.jl
@@ -40,7 +40,11 @@ function train!(planner::Union{TRPOPlanner, PPOPlanner})

    # Create or load policy
    if solver.resume
-        policy = load_policy(solver, "weights", DiagonalGaussianPolicy) # TODO: parameterize path
+        if solver.policy_type == :discrete
+            policy = load_policy(solver, "weights", CategoricalPolicy) # TODO: parameterize path
+        elseif solver.policy_type == :continuous
+            policy = load_policy(solver, "weights", DiagonalGaussianPolicy) # TODO: parameterize path
+        end
    else
        policy = get_policy(solver)
    end
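Usage note: this is the path exercised by the new tests below. A sketch of the save/resume round-trip, assuming the keyword names used there (`save`, `resume`, `policy_type`) and an `mdp` from an earlier setup:

    solver = PPOSolver(num_episodes=100, episode_length=30,
                       policy_type=:discrete, save=true)
    planner = solve(solver, mdp)   # `mdp` built beforehand
    search!(planner)               # trains and saves the policy weights

    # Resuming now reloads a CategoricalPolicy rather than a
    # DiagonalGaussianPolicy, because policy_type == :discrete.
    solver.resume = true
    search!(planner)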
8 changes: 4 additions & 4 deletions src/visualization/tree_visualization.jl
@@ -3,7 +3,7 @@ full_width_notebook(width=100) = display(HTML("<style>.container { width:$width%

# Display of action nodes.
function MCTS.node_tag(s::AST.ASTState)
-    state_str::String = "0x"*string(s.hash, base=16)
+    state_str::String = "0x"*string(s.hash, base=16)
    if s.terminal
        return "Terminal [$state_str]."
    else
@@ -24,10 +24,10 @@ end
"""
Visualize MCTS tree structure for AST MDP.
"""
-function visualize(policy::MCTS.DPWPlanner)
-    tree = search!(policy; return_tree=true)
+function visualize(planner::MCTS.DPWPlanner)
+    tree = search!(planner; return_tree=true)
    d3 = visualize(tree)
-    return d3::D3Tree
+    return d3::D3Tree
end

function visualize(tree::MCTS.DPWTree)
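Both entry points remain: `visualize(planner)` re-runs the search, while `visualize(tree)` renders an existing `MCTS.DPWTree`. To actually open the returned `D3Tree`, the D3Trees.jl viewers should apply (a hedged sketch; `inchrome`/`inbrowser` come from that package, not this commit):

    using D3Trees

    d3tree = visualize(planner)       # re-runs search!, returns a D3Tree
    inchrome(d3tree)                  # or: inbrowser(d3tree, "firefox")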
142 changes: 142 additions & 0 deletions test/CategoricalWalk1D.jl
@@ -0,0 +1,142 @@
# using Revise
using POMDPStressTesting
using Distributions
using Parameters


@with_kw mutable struct CategoricalWalk1DParams
    startx::Float64 = 0 # Starting x-position
    threshx::Float64 = 10 # ± boundary threshold
    endtime::Int64 = 30 # Simulation end time
end


# Implement abstract GrayBox.Simulation
@with_kw mutable struct CategoricalWalk1DSim <: GrayBox.Simulation
    params::CategoricalWalk1DParams = CategoricalWalk1DParams() # Parameters
    x::Float64 = 0 # Current x-position
    t::Int64 = 0 # Current time
    distribution::Distribution = Categorical([0.95, 0.025, 0.025]) # Transition distribution over steps {1,2,3}, so x drifts toward +threshx
end


# Override from GrayBox
GrayBox.environment(sim::CategoricalWalk1DSim) = GrayBox.Environment(:x => sim.distribution)


# Override from GrayBox (NOTE: used with ASTSeedAction)
function GrayBox.transition!(sim::CategoricalWalk1DSim)
    # We sample the environment and apply the transition
    environment::GrayBox.Environment = GrayBox.environment(sim) # Get the environment distributions
    sample::GrayBox.EnvironmentSample = rand(environment) # Sample from the environment
    sim.t += 1 # Keep track of time
    sim.x += sample[:x].value # Move agent using sampled value from input
    return logpdf(sample)::Real # Summation handled by `logpdf()`
end
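# (Aside on the summation noted above) `logpdf(sample)` appears to sum the
# log-probabilities of every variable in the EnvironmentSample; with the
# single :x variable here, it presumably reduces to
# logpdf(Categorical([0.95, 0.025, 0.025]), sample[:x].value).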


# Override from GrayBox (NOTE: used with ASTSampleAction)
function GrayBox.transition!(sim::CategoricalWalk1DSim, sample::GrayBox.EnvironmentSample)
    # The environment was sampled for us, and we just apply the transition
    sim.t += 1 # Keep track of time
    sim.x += sample[:x].value # Move agent using sampled value from input
    return logpdf(sample)::Real # Summation handled by `logpdf()`
end


# Override from BlackBox
function BlackBox.initialize!(sim::CategoricalWalk1DSim)
    sim.t = 0
    sim.x = sim.params.startx
end


# Override from BlackBox
BlackBox.distance(sim::CategoricalWalk1DSim) = max(sim.params.threshx - abs(sim.x), 0)


# Override from BlackBox
BlackBox.isevent(sim::CategoricalWalk1DSim) = abs(sim.x) >= sim.params.threshx


# Override from BlackBox
BlackBox.isterminal(sim::CategoricalWalk1DSim) = BlackBox.isevent(sim) || sim.t >= sim.params.endtime


# Override from BlackBox (NOTE: used with ASTSeedAction)
function BlackBox.evaluate!(sim::CategoricalWalk1DSim)
    logprob::Real = GrayBox.transition!(sim) # Step simulation
    d::Real = BlackBox.distance(sim) # Calculate miss distance
    event::Bool = BlackBox.isevent(sim) # Check event indication
    return (logprob::Real, d::Real, event::Bool)
end


# Override from BlackBox (NOTE: used with ASTSampleAction)
function BlackBox.evaluate!(sim::CategoricalWalk1DSim, sample::GrayBox.EnvironmentSample)
    logprob::Real = GrayBox.transition!(sim, sample) # Step simulation given input sample
    d::Real = BlackBox.distance(sim) # Calculate miss distance
    event::Bool = BlackBox.isevent(sim) # Check event indication
    return (logprob::Real, d::Real, event::Bool)
end


function setup_ast(seed=AST.DEFAULT_SEED; solver=PPOSolver)
    # Create gray-box simulation object
    sim::GrayBox.Simulation = CategoricalWalk1DSim()

    # AST MDP formulation object
    # NOTE: Use either {ASTSeedAction} or {ASTSampleAction} (when using TRPO/PPO/CEM, use ASTSampleAction)
    if solver in [TRPOSolver, PPOSolver, CEMSolver]
        mdp::ASTMDP = ASTMDP{ASTSampleAction}(sim)
    else
        mdp = ASTMDP{ASTSeedAction}(sim)
    end
    mdp.params.debug = true # record metrics
    mdp.params.top_k = 10 # record top k best trajectories
    mdp.params.seed = seed # set RNG seed for determinism
    n_iterations = 1000 # number of algorithm iterations

    # Choose a solver (examples of each)
    if solver == RandomSearchSolver
        solver = RandomSearchSolver(n_iterations=n_iterations,
                                    episode_length=sim.params.endtime)
    elseif solver == MCTSPWSolver
        solver = MCTSPWSolver(n_iterations=n_iterations,
                              exploration_constant=1.0, # UCT exploration
                              k_action=1.0, # action widening
                              alpha_action=0.5, # action widening
                              depth=sim.params.endtime) # tree depth (i.e. episode length)
    elseif solver == CEMSolver
        solver = CEMSolver(n_iterations=n_iterations,
                           episode_length=sim.params.endtime)
    elseif solver == TRPOSolver
        solver = TRPOSolver(num_episodes=n_iterations,
                            episode_length=sim.params.endtime,
                            policy_type=:discrete)
    elseif solver == PPOSolver
        solver = PPOSolver(num_episodes=n_iterations,
                           episode_length=sim.params.endtime,
                           policy_type=:discrete)
    end

    # Get online planner (no work done, yet)
    planner = solve(solver, mdp)

    return planner
end


function run_ast(seed=AST.DEFAULT_SEED; kwargs...)
    planner = setup_ast(seed; kwargs...)

    action_trace::Vector{ASTAction} = search!(planner) # work done here
    final_state::ASTState = playback(planner, action_trace, sim->sim.x)
    failure_rate::Float64 = print_metrics(planner)

    return planner, action_trace::Vector{ASTAction}, failure_rate::Float64
end

(planner, action_trace, failure_rate) = run_ast()

nothing # Suppress REPL
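Usage note: `run_ast` forwards its keyword arguments to `setup_ast`, so the same scenario can be driven by any of the solvers branched on above. A quick sketch (the alternative solver choice is illustrative):

    # Default: PPO with a discrete (categorical) policy
    planner, action_trace, failure_rate = run_ast()

    # Same scenario under CEM, which also uses ASTSampleActions
    planner, action_trace, failure_rate = run_ast(solver=CEMSolver)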
23 changes: 23 additions & 0 deletions test/runtests.jl
@@ -39,6 +39,7 @@ end
@test begin
    @info "Extra functions test"
    (planner, _, _) = run_ast(solver=MCTSPWSolver)
+    d3tree = visualize(planner)
    action_trace = search!(planner; verbose=true)
    actions = online_path(planner.mdp, planner)
    x_trace = playback(planner, actions, sim->sim.x; return_trace=true)
@@ -56,6 +57,10 @@
    solver.resume = true
    search!(planner)

+    # No top_k action trace
+    solver.resume = false
+    planner.mdp.params.top_k = 0
+    ast_action = search!(planner)
    true
end

@@ -77,3 +82,21 @@ test_solvers(skip_trpo=true) # TRPO slows down the entire test suite, so skip it
    include("EpisodicWalk1D.jl")
    true
end


+@test begin
+    @info "CategoricalWalk1D"
+    include("CategoricalWalk1D.jl")
+
+    # Policy saving
+    (planner, _, _) = run_ast(solver=PPOSolver)
+    solver = PPOSolver(num_episodes=100, episode_length=30, save=true, verbose=true, policy_type=:discrete)
+    planner = solve(solver, planner.mdp)
+    search!(planner)
+
+    # Policy loading
+    solver.resume = true
+    search!(planner)
+
+    true
+end
