2 init states
jbrea committed Jan 4, 2024
1 parent 8dbf7ad commit 8c224f0
Showing 1 changed file with 38 additions and 35 deletions.
73 changes: 38 additions & 35 deletions notebooks/rl.jl
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
-# v0.19.27
+# v0.19.32

using Markdown
using InteractiveUtils
@@ -52,7 +52,7 @@ In Monte Carlo estimation, the Q-values ``Q(s, a)`` of state-action pair ``(s, a
# ╔═╡ 7c756af7-08a0-4878-8f53-9eaa79c16c1a
mlcode("""
Base.@kwdef struct MCLearner # this defines a type MCLearner
-ns = 7 # default initial number of states
+ns = 8 # default initial number of states
na = 2 # default initial number of actions
N = zeros(Int, na, ns) # default initial counts
Q = zeros(na, ns) # default initial Q-values
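The update code of MCLearner is collapsed in this diff. For orientation only, here is a minimal editor's sketch (not the notebook's exact code) of the incremental Monte Carlo update that these counts N and Q-values Q support, assuming an episode is stored as a vector of (state, action, reward) tuples:

# Hypothetical helper (editor's sketch, not part of the commit).
function mc_update_sketch!(learner, episode)
    G = 0.0                              # return, accumulated backwards through the episode
    for (s, a, r) in reverse(episode)
        G += r
        learner.N[a, s] += 1
        learner.Q[a, s] += (G - learner.Q[a, s]) / learner.N[a, s]  # running mean of observed returns
    end
    learner
end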
@@ -194,7 +194,7 @@ md"""This allows us now to evaluate arbitrary policies for given the transition
For example, for the following policy, where in state 1 (="red room") deterministically action 1 (="left") and in state 2 (="blue room without guard") action 2 (="right") etc. is taken, we get the following Q-values:"""

# ╔═╡ 2920e1c8-64b0-4c60-8c47-1de38b5160b3
-silly_policy = M.DeterministicPolicy([1, 2, 1, 2, 1, 2, 1])
+silly_policy = M.DeterministicPolicy([1, 1, 2, 1, 2, 1, 2, 1])
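M.DeterministicPolicy and M.evaluate_policy are defined in parts of the notebook not shown in this diff. As a rough editor's sketch of what policy evaluation amounts to here, assuming T[a, s] is a probability vector over next states and Rexp[s, a] is the expected immediate reward, repeated Bellman backups ``Q(s, a) = Rexp(s, a) + sum_s' T[a, s][s'] Q(s', policy(s'))`` could be written as:

# Hypothetical sketch (not the notebook's implementation); `policy` maps a state to an action.
function evaluate_policy_sketch(policy, T, Rexp; sweeps = 1000)
    na, ns = size(T)
    Q = zeros(na, ns)
    for _ in 1:sweeps
        Qnew = zeros(na, ns)
        for s in 1:ns, a in 1:na
            Qnew[a, s] = Rexp[s, a] + sum(T[a, s][s2] * Q[policy(s2), s2] for s2 in 1:ns)
        end
        Q = Qnew
    end
    Q
end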

# ╔═╡ 11141e67-421a-4e02-a638-f03b47ffb53c
md"## Policy Iteration
@@ -239,7 +239,7 @@ Q-learning is an alternative method for updating Q-values that relies on the ide
mlcode(
"""
Base.@kwdef struct QLearner # this defines a type QLearner
-ns = 7 # default initial number of states
+ns = 8 # default initial number of states
na = 2 # default initial number of actions
Q = zeros(na, ns) # default initial Q-values
λ = 0.01
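The QLearner's update rule is in a collapsed part of the diff. For reference, the tabular Q-learning step such a learner typically performs after observing a transition (s, a, r, s') is, in an editor's sketch using the learning rate λ above and omitting a discount factor:

# Hypothetical helper illustrating the standard Q-learning update (not the notebook's exact code).
function q_update_sketch!(learner, s, a, r, snext)
    target = r + maximum(learner.Q[:, snext])          # bootstrapped target: r + max over a' of Q(s', a')
    learner.Q[a, s] += learner.λ * (target - learner.Q[a, s])
    learner
end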
@@ -274,12 +274,13 @@ Discrete state representations, where each state has its own state number (e.g.
mlcode(
"""
function distributed_representation(s)
-[s == 1; # 1 if in room 1
-s ∈ (2, 3); # 1 if in room 2
-s ∈ (4, 5); # 1 if in room 3
-s == 6; # 1 if in treasure room
-s == 7; # 1 if KO
-s ∈ (3, 5)] # 1 if guard present
+[s == 1; # 1 if in violet room
+s == 2; # 1 if in red room
+s ∈ (3, 4); # 1 if in blue room
+s ∈ (5, 6); # 1 if in green room
+s == 7; # 1 if in treasure room
+s == 8; # 1 if KO
+s ∈ (4, 6)] # 1 if guard present
end
distributed_representation(5) # green room with guard
"""
@@ -598,29 +599,31 @@ md"""Player Cross: $(@bind player1 Select(["human", "machine"])) Player Nought:
# ╔═╡ 761d690d-5c73-40dd-b38c-5af67ee837c0
begin
function onehot(i)
-x = zeros(7)
+x = zeros(8)
x[i] = 1
x
end
-act(state, action) = wsample(1:7, T[action, state])
+act(state, action) = wsample(1:8, T[action, state])
_reward(r::Number) = r
_reward(r::AbstractArray) = rand(r)
reward(state, action) = _reward(R[state, action])
-T = reshape([[0, .7, .3, 0, 0, 0, 0], [0, 0, 0, .3, .7, 0, 0],
-onehot(6), onehot(6),
-onehot(7), onehot(6),
-onehot(6), onehot(6),
-onehot(6), onehot(7),
-onehot(6), onehot(6),
-onehot(7), onehot(7)], 2, :)
-R = [-.5 -.3
+T = reshape([onehot(3), [0, 0, 0, 0, .5, .5, 0, 0],
+[0, 0, .7, .3, 0, 0, 0, 0], [0, 0, 0, 0, .3, .7, 0, 0],
+onehot(7), onehot(7),
+onehot(8), onehot(7),
+onehot(7), onehot(7),
+onehot(7), onehot(8),
+onehot(7), onehot(7),
+onehot(8), onehot(8)], 2, :)
+R = [-.1 -.7
+-.5 -.3
[1:4] 1
-5 1
1 [3:6]
1 -5
0 0
0 0]
function showQ(Q; states = ["red room", "blue without guard", "blue with guard", "green without guard", "green with guard", "treasure room", "K.O."], actions = ["left", "right"])
function showQ(Q; states = ["violet room", "red room", "blue without guard", "blue with guard", "green without guard", "green with guard", "treasure room", "K.O."], actions = ["left", "right"])
df = DataFrame(Q, states)
df.action = actions
df
@@ -643,15 +646,15 @@ begin
reward(env::ChasseAuTresorEnv) = env.reward
state(env::ChasseAuTresorEnv) = env.state
function reset!(env::ChasseAuTresorEnv)
-env.state = 1
+env.state = env.state > 2 ? env.episode_recorder[1][1] % 2 + 1 : env.state
env.reward = 0
empty!(env.episode_recorder)
end
# chasse = ChasseAuTresorEnv()
end;
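The new reset! appears to implement the "2 init states" of the commit title: once an episode has advanced beyond the two starting rooms (env.state > 2), the next episode starts in the other initial room, since episode_recorder[1][1] is the first state of the finished episode and the map s -> s % 2 + 1 flips 1 and 2. A small check of the index arithmetic (editor's note):

# Hypothetical illustration of the flip used in reset! above.
next_initial_state(first_state_of_last_episode) = first_state_of_last_episode % 2 + 1
next_initial_state(1)  # -> 2
next_initial_state(2)  # -> 1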

# ╔═╡ 712c2a9e-4413-4d7a-b729-cfb219723256
-let mclearner = M.MCLearner(na = 2, ns = 7),
+let mclearner = M.MCLearner(na = 2, ns = 8),
chasse = ChasseAuTresorEnv()
for _ in 1:10^6
reset!(chasse)
@@ -678,7 +681,7 @@ optimal_policy = M.policy_iteration(T, R)
showQ(M.evaluate_policy(optimal_policy, T, R))

# ╔═╡ bd8557bc-86f2-4ccc-93e9-a6bd843e80be
-let qlearner = M.QLearner(na = 2, ns = 7),
+let qlearner = M.QLearner(na = 2, ns = 8),
chasse = ChasseAuTresorEnv()
for _ in 1:10^6
reset!(chasse)
Expand All @@ -695,7 +698,7 @@ let qlearner = M.QLearner(na = 2, ns = 7),
end

# ╔═╡ af7015c4-e7ab-4e18-bd37-ccffe4ec2928
-dql = let deepqlearner = M.DeepQLearner(Qnetwork = M.Chain(M.Dense(6, 10, M.relu),
+dql = let deepqlearner = M.DeepQLearner(Qnetwork = M.Chain(M.Dense(7, 10, M.relu),
M.Dense(10, 2))),
chasse = ChasseAuTresorEnv()
for episode in 1:10^5
@@ -711,10 +714,10 @@ dql = let deepqlearner = M.DeepQLearner(Qnetwork = M.Chain(M.Dense(6, 10, M.relu
end;
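The DeepQLearner's training loop is collapsed here; note that the new Dense(7, 10, relu) input layer matches the 7-dimensional distributed representation above. Inside such a loop, actions are commonly chosen ε-greedily from the network's outputs; an editor's sketch (assumed, not the notebook's code):

# Hypothetical ε-greedy action selection on the distributed representation.
function epsilon_greedy_sketch(qnetwork, s; ϵ = 0.1)
    q = qnetwork(M.distributed_representation(s))   # vector of 2 Q-values (left, right)
    rand() < ϵ ? rand(1:length(q)) : argmax(q)
end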

# ╔═╡ 3a643502-7d78-4d0c-a53f-913f35306258
-deepqpolicy = M.DeterministicPolicy([argmax(dql.Qnetwork(M.distributed_representation(s))) for s in 1:7])
+deepqpolicy = M.DeterministicPolicy([argmax(dql.Qnetwork(M.distributed_representation(s))) for s in 1:8])

# ╔═╡ ce334e27-9b66-4692-becd-cfc24ff58cb1
-showQ(hcat([dql.Qnetwork(M.distributed_representation(s)) for s in 1:7]...))
+showQ(hcat([dql.Qnetwork(M.distributed_representation(s)) for s in 1:8]...))

# ╔═╡ 094bdfef-1c4d-45ef-9bf4-66f60ac3a500
act!(ttt_env, 1);
@@ -783,7 +786,7 @@ let
states = get!(all_states, user_id, create_initial_state())
chasse = states.chasse
chasse.action = chasse_actions[1] == 0 ? 3 : chasse_actions[2]
-if chasse.state > 5 || (isa(chasse.action, Int) && chasse.action > 2)
+if chasse.state > 6 || (isa(chasse.action, Int) && chasse.action > 2)
_learner = if learner == "mclearner"
if length(chasse.episode_recorder) > 0 && chasse.action == 3
M.update!(states.mclearner, chasse.episode_recorder)
@@ -805,13 +808,13 @@ end
end
d = M.distributed_representation(chasse.state)
room = findfirst(d)
-guard = d[6]
-guard_door = room > 2
-gold = d[4]
-room_colors = [:red, :blue, :green, :orange, :black]
+guard = d[7]
+guard_door = room > 3
+gold = d[5]
+room_colors = [:violet, :red, :lightblue, :lightgreen, :orange, :black]
plot(xlim = (0, 1), ylim = (0, 1), size = (600, 400),
bg = room_colors[room], framestyle = :none, legend = false)
-if room < 4
+if room < 5
plot!([.1, .1, .4, .4], [0, .7, .7, 0], w = 5, c = :black)
plot!([.14, .12, .17], [.37, .35, .35], w = 4, c = :black)
plot!(.5 .+ [.1, .1, .4, .4], [0, .7, .7, 0], w = 5, c = :black)
@@ -833,13 +836,13 @@ end
y = y[1:r]
scatter!(x, y, markerstrokewidth = 3, c = :yellow, markersize = 28)
end
-if room == 5
+if room == 6
scatter!([0], [0], c = :black, legend = false, markerstrokewidth = 0) # dummy
annotate!([(.5, .5, "K.O.", :red)])
end
rs = length(chasse.episode_recorder) == 0 ? [0] : getindex.(chasse.episode_recorder, 3)
annotate!([(.5, .9, "reward = $(chasse.reward)", :white),
(.5, .8, "cumulative reward = $(join(rs, " + ")) = $(sum(rs))", :white)
(.5, .8, "cumulative reward = $(join(rs, " + ")) = $(round(sum(rs), sigdigits = 2))", :white)
])
end
