Skip to content

Commit

Permalink
Finalize RL swingup setting.
Browse files Browse the repository at this point in the history
  • Loading branch information
obrusvit committed Apr 24, 2022
1 parent 21acd17 commit 5fe86e6
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 21 deletions.
1 change: 1 addition & 0 deletions assets/rl_ctrl_2022-04-23_16:30:10.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[0.4717133,0.7440337,-1.7736977,-2.0201738,-0.31323344,0.9730072,-0.5329789,0.92873204,-0.6848246,-1.9004071,2.143752,-1.1378529,0.3115067,1.1259874,3.0911608,-6.0154552,0.05651751,1.0153601,-0.77643025,0.269369,0.18687288,-0.16074464,1.2439964,-0.15892066,-1.1393435,0.9317976,1.3223119,-0.6772952,-0.3536504,0.94268227,1.3424321,-0.046352644,1.1503859,-0.675419,-0.26867515,-0.6363417,0.15652771,-0.8910839,-0.333146,-0.29503053,0.78153825,0.17307802,-0.6166021,-0.7012435,-0.32341886]
12 changes: 6 additions & 6 deletions main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function main_LQR(cp_params::CartPoleParams, init_state::CartPoleState; make_plo
plot_sol_force(sol, saved_values; dest_dir="output", dest_name="main_LQR.png")
end
if make_gif
make_gif(sol, CartPole(cp_params, init_state); dest_dir="output", dest_name="main_lqr.gif")
gif_sol(sol, CartPole(cp_params, init_state); dest_dir="output", dest_name="main_lqr.gif")
end
end

Expand All @@ -216,7 +216,7 @@ function main_swingup_optim(cp_params::CartPoleParams, init_state::CartPoleState
plot_sol_force(sol, saved_values; dest_dir="output", dest_name="main_swingup_optim.png")
end
if make_gif
make_gif(sol, CartPole(cp_params, init_state); dest_dir="output", dest_name="main_swingup_optim.gif")
gif_sol(sol, CartPole(cp_params, init_state); dest_dir="output", dest_name="main_swingup_optim.gif")
end
end

Expand All @@ -226,15 +226,15 @@ function main_swingup_rl(cp_params::CartPoleParams, init_state::CartPoleState; m
sys = CartPole(cp_params, init_state)
sys_LTI_upper = cartpole_LTI_sys(sys, final)

T_N = 1.0 # final time of the maneuver
T_N = 4.0 # final time of the maneuver
N = 401
if isempty(nn_params)
nn_params = train_cartpole_rl_controller(T_N, N, cp_params, init_state, final; saveToJson = true)
println("training finished")
end

f(x, t) =
if t >= T_N || abs(x[3] - pi) <= 0.01
if t >= T_N || abs(x[3] - pi) <= 0.1
force_LQR(x, sys_LTI_upper)
else
get_control_input(x, nn_params)
Expand All @@ -246,7 +246,7 @@ function main_swingup_rl(cp_params::CartPoleParams, init_state::CartPoleState; m
plot_sol_force(sol, saved_values; dest_dir="output", dest_name="main_swingup_rl.png")
end
if make_gif
make_gif(sol, CartPole(cp_params, init_state); dest_dir="output", dest_name="main_swingup_rl.gif")
gif_sol(sol, CartPole(cp_params, init_state); dest_dir="output", dest_name="main_swingup_rl.gif")
end
return nn_params
end
Expand All @@ -267,5 +267,5 @@ function main()
# Plot
plot_sol_force(sol, saved_values; dest_name = "main.png")
sys = CartPole(cp_params, init)
make_gif(sol, sys; dest_name = "main.gif")
gif_sol(sol, sys; dest_name = "main.gif")
end
Binary file added output/swingup_rl_2022-04-23_16:30:10.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output/swingup_rl_2022-04-23_16:30:10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions src/postprocess_visualize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function vis_cartpole(cartpole::CartPole, time::Float64)
end


function make_gif(t, states, cartpole::CartPole)
function gif_sol(t, states, cartpole::CartPole)
# previously used to visualize the swing-up control
p = plot(reuse = false)
anim = @animate for idx = 1:1:length(t)
Expand All @@ -145,7 +145,7 @@ function make_gif(t, states, cartpole::CartPole)
end


function make_gif(sol, cartpole::CartPole; dest_dir::String="output", dest_name::String="cart_pole_sim.gif")
function gif_sol(sol, cartpole::CartPole; dest_dir::String="output", dest_name::String="cart_pole_sim.gif")
# p = plot(reuse = false)
anim = @animate for t = sol.t[begin]:0.1:sol.t[end]
v = sol(t)
Expand Down
20 changes: 7 additions & 13 deletions src/swingup_rl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@ using Dates
include("diffeq_simulation.jl")


# controller = FastChain((x, p) -> x, FastDense(3, 3, relu), FastDense(3,3,relu), FastDense(3,1)) # good
# Controller network used by `get_control_input`: a pass-through first stage
# `(x, p) -> x`, then two `FastDense(4, 4, relu)` layers and a final `FastDense(4, 1)`.
# NOTE(review): assumes FastDense(in, out[, activation]) argument order, i.e. a
# 4-element input and a single output — confirm against the DiffEqFlux version in use.
controller = FastChain((x, p) -> x, FastDense(4, 4, relu), FastDense(4,4,relu), FastDense(4,1))
# controller = FastChain((x, p) -> x, FastDense(5, 5, relu), FastDense(5,5,relu), FastDense(5,1))

# get_control_input(u, nn_params) = controller([cos(u[3]), sin(u[3]), u[4]], nn_params)[1]
# Evaluate the controller network for state `u` with parameters `nn_params`.
# The network input is built from u[1], cos(u[3]), sin(u[3]) and u[4]; the first
# element of the network output is returned.
function get_control_input(u, nn_params)
    features = [u[1], cos(u[3]), sin(u[3]), u[4]]
    output = controller(features, nn_params)
    return output[1]
end
# get_control_input(u, nn_params) = controller([u[1], u[2], cos(u[3]), sin(u[3]), u[4]], nn_params)[1]

# Wrap an angle in radians into the interval [-pi, pi).
# The angle is shifted by pi, reduced with mod2pi, and shifted back.
function modpi(theta)
    shifted = theta + pi
    return mod2pi(shifted) - pi
end
Expand Down Expand Up @@ -41,7 +37,7 @@ function train_cartpole_rl_controller(T_N::Float64, N::Int64, params::CartPolePa
function loss_neuralode(p)
f(x,t) = get_control_input(x, p)
ode_params = [params.mₜ, params.mₚ, params.L, params.bₜ, params.bₚ, f]
sol = solve(remake(prob, p=ode_params), Tsit5(), saveat = tsteps)
sol = DifferentialEquations.solve(remake(prob, p=ode_params), Tsit5(), saveat = tsteps)
x = sol[1, :]
dx = sol[2, :]
theta = modpi.(sol[3, :])
Expand All @@ -51,15 +47,13 @@ function train_cartpole_rl_controller(T_N::Float64, N::Int64, params::CartPolePa
state_vec = [[u[1], u[2], u[3], u[4]] for u in sol.u]
force = [get_control_input(u,p) for u in state_vec]

# 2022-04-23_16:30:10, default CartPoleParams but with L=2.0; T_N=4.0, N = 401
# controller with 4 inputs, 1 hl, 1 output
# x0 = (0, 0, 0, 0); xN = (0, 0, pi, 0)
loss = 100*(theta[end]-xN.ϕ)^2 + 10*(dtheta[end]-xN.ϕ̇)^2 + 50*(x[end]-xN.x)^2 + (dx[end]-xN.ẋ)^2 + 0.01 * sum(abs2, force) / N

# good objective functions
# loss = 100*(theta[end]-pi)^2 + dtheta[end]^2 + dx[end]^2 + 0.01 * sum(abs2, force) / N # best so far with tspan=(0,1), length of pole=1,

loss = 100*(theta[end]-pi)^2 + dtheta[end]^2 + 50*x[end]^2 + dx[end]^2 + 0.01 * sum(abs2, force) / N

# loss = 1000*(theta[end]-pi)^2 + 10*sum(abs2, x) / N + dtheta[end]^2 + dx[end]^2 + 0.1 * sum(abs2, force) / N # big dtheta at the end

# loss = 10*(theta[end]-pi)^2 + dtheta[end]^2 + 0.01 * sum(abs2, force) / N

return loss, sol
end

Expand All @@ -84,7 +78,7 @@ function train_cartpole_rl_controller(T_N::Float64, N::Int64, params::CartPolePa
nn_pinit,
# ADAM(0.05),
cb=callback_1,
maxiters=2400,
maxiters=3000,
)

if saveToJson
Expand Down

0 comments on commit 5fe86e6

Please sign in to comment.