From 8c224f0bfc6ffdbfbaf99025b631568580d47ea6 Mon Sep 17 00:00:00 2001
From: Johanni Brea
Date: Thu, 4 Jan 2024 22:20:03 +0100
Subject: [PATCH] 2 init states

The treasure hunt now starts in one of two initial rooms (violet or red)
instead of always in the red room, which turns the 7-state environment
into an 8-state one. The transition matrix, the rewards, the distributed
representation and the plotting code are renumbered accordingly, and
reset! alternates between the two initial states.

---
 notebooks/rl.jl | 73 +++++++++++++++++++++++++------------------------
 1 file changed, 38 insertions(+), 35 deletions(-)

diff --git a/notebooks/rl.jl b/notebooks/rl.jl
index bd999a3..e833f14 100644
--- a/notebooks/rl.jl
+++ b/notebooks/rl.jl
@@ -1,5 +1,5 @@
 ### A Pluto.jl notebook ###
-# v0.19.27
+# v0.19.32
 
 using Markdown
 using InteractiveUtils
@@ -52,7 +52,7 @@ In Monte Carlo estimation, the Q-values ``Q(s, a)`` of state-action pair ``(s, a
 # ╔═╡ 7c756af7-08a0-4878-8f53-9eaa79c16c1a
 mlcode("""
 Base.@kwdef struct MCLearner # this defines a type MCLearner
-    ns = 7                  # default initial number of states
+    ns = 8                  # default initial number of states
     na = 2                  # default initial number of actions
     N = zeros(Int, na, ns)  # default initial counts
     Q = zeros(na, ns)       # default initial Q-values
@@ -194,7 +194,7 @@ md"""This allows us now to evaluate arbitrary policies for given the transition
-For example, for the following policy, where in state 1 (="red room") deterministically action 1 (="left") and in state 2 (="blue room without guard") action 2 (="right") etc. is taken, we get the following Q-values:"""
+For example, for the following policy, where action 1 (="left") is deterministically taken in state 1 (="violet room") and state 2 (="red room"), action 2 (="right") in state 3 (="blue room without guard"), etc., we get the following Q-values:"""
 
 # ╔═╡ 2920e1c8-64b0-4c60-8c47-1de38b5160b3
-silly_policy = M.DeterministicPolicy([1, 2, 1, 2, 1, 2, 1])
+silly_policy = M.DeterministicPolicy([1, 1, 2, 1, 2, 1, 2, 1])
 
 # ╔═╡ 11141e67-421a-4e02-a638-f03b47ffb53c
 md"## Policy Iteration
@@ -239,7 +239,7 @@ Q-learning is an alternative method for updating Q-values that relies on the ide
 mlcode(
 """
 Base.@kwdef struct QLearner # this defines a type QLearner
-    ns = 7             # default initial number of states
+    ns = 8             # default initial number of states
     na = 2             # default initial number of actions
     Q = zeros(na, ns)  # default initial Q-values
     λ = 0.01
@@ -274,12 +274,13 @@ Discrete state representations, where each state has its own state number (e.g.
 mlcode(
 """
 function distributed_representation(s)
-    [s == 1;      # 1 if in room 1
-     s ∈ (2, 3);  # 1 if in room 2
-     s ∈ (4, 5);  # 1 if in room 3
-     s == 6;      # 1 if in treasure room
-     s == 7;      # 1 if KO
-     s ∈ (3, 5)]  # 1 if guard present
+    [s == 1;      # 1 if in violet room
+     s == 2;      # 1 if in red room
+     s ∈ (3, 4);  # 1 if in blue room
+     s ∈ (5, 6);  # 1 if in green room
+     s == 7;      # 1 if in treasure room
+     s == 8;      # 1 if KO
+     s ∈ (4, 6)]  # 1 if guard present
 end
-distributed_representation(5) # green room with guard
+distributed_representation(6) # green room with guard
 """
@@ -598,29 +599,31 @@ md"""Player Cross: $(@bind player1 Select(["human", "machine"])) Player Nought:
 # ╔═╡ 761d690d-5c73-40dd-b38c-5af67ee837c0
 begin
     function onehot(i)
-        x = zeros(7)
+        x = zeros(8)
         x[i] = 1
         x
     end
-    act(state, action) = wsample(1:7, T[action, state])
+    act(state, action) = wsample(1:8, T[action, state])
     _reward(r::Number) = r
     _reward(r::AbstractArray) = rand(r)
     reward(state, action) = _reward(R[state, action])
-    T = reshape([[0, .7, .3, 0, 0, 0, 0], [0, 0, 0, .3, .7, 0, 0],
-                 onehot(6), onehot(6),
-                 onehot(7), onehot(6),
-                 onehot(6), onehot(6),
-                 onehot(6), onehot(7),
-                 onehot(6), onehot(6),
-                 onehot(7), onehot(7)], 2, :)
-    R = [-.5 -.3
+    T = reshape([onehot(3), [0, 0, 0, 0, .5, .5, 0, 0],  # T[a, s]: distribution over successor states
+                 [0, 0, .7, .3, 0, 0, 0, 0], [0, 0, 0, 0, .3, .7, 0, 0],
+                 onehot(7), onehot(7),
+                 onehot(8), onehot(7),
+                 onehot(7), onehot(7),
+                 onehot(7), onehot(8),
+                 onehot(7), onehot(7),
+                 onehot(8), onehot(8)], 2, :)
+    R = [-.1 -.7  # state 1 = violet room
+         -.5 -.3
          [1:4] 1
          -5 1
          1 [3:6]
          1 -5
          0 0
          0 0]
-    function showQ(Q; states = ["red room", "blue without guard", "blue with guard", "green without guard", "green with guard", "treasure room", "K.O."], actions = ["left", "right"])
+    function showQ(Q; states = ["violet room", "red room", "blue without guard", "blue with guard", "green without guard", "green with guard", "treasure room", "K.O."], actions = ["left", "right"])
         df = DataFrame(Q, states)
         df.action = actions
         df
@@ -643,7 +646,7 @@ begin
     reward(env::ChasseAuTresorEnv) = env.reward
     state(env::ChasseAuTresorEnv) = env.state
     function reset!(env::ChasseAuTresorEnv)
-        env.state = 1
+        env.state = env.state > 2 ? env.episode_recorder[1][1] % 2 + 1 : env.state # start the next episode in the other initial room
         env.reward = 0
         empty!(env.episode_recorder)
     end
@@ -651,7 +654,7 @@
 end;
 
 # ╔═╡ 712c2a9e-4413-4d7a-b729-cfb219723256
-let mclearner = M.MCLearner(na = 2, ns = 7),
+let mclearner = M.MCLearner(na = 2, ns = 8),
     chasse = ChasseAuTresorEnv()
     for _ in 1:10^6
         reset!(chasse)
@@ -678,7 +681,7 @@ optimal_policy = M.policy_iteration(T, R)
 showQ(M.evaluate_policy(optimal_policy, T, R))
 
 # ╔═╡ bd8557bc-86f2-4ccc-93e9-a6bd843e80be
-let qlearner = M.QLearner(na = 2, ns = 7),
+let qlearner = M.QLearner(na = 2, ns = 8),
     chasse = ChasseAuTresorEnv()
     for _ in 1:10^6
         reset!(chasse)
@@ -695,7 +698,7 @@ let qlearner = M.QLearner(na = 2, ns = 7),
 end
 
 # ╔═╡ af7015c4-e7ab-4e18-bd37-ccffe4ec2928
-dql = let deepqlearner = M.DeepQLearner(Qnetwork = M.Chain(M.Dense(6, 10, M.relu),
+dql = let deepqlearner = M.DeepQLearner(Qnetwork = M.Chain(M.Dense(7, 10, M.relu),
                                         M.Dense(10, 2))),
     chasse = ChasseAuTresorEnv()
     for episode in 1:10^5
@@ -711,10 +714,10 @@ dql = let deepqlearner = M.DeepQLearner(Qnetwork = M.Chain(M.Dense(6, 10, M.relu
 end;
 
 # ╔═╡ 3a643502-7d78-4d0c-a53f-913f35306258
-deepqpolicy = M.DeterministicPolicy([argmax(dql.Qnetwork(M.distributed_representation(s))) for s in 1:7])
+deepqpolicy = M.DeterministicPolicy([argmax(dql.Qnetwork(M.distributed_representation(s))) for s in 1:8])
 
 # ╔═╡ ce334e27-9b66-4692-becd-cfc24ff58cb1
-showQ(hcat([dql.Qnetwork(M.distributed_representation(s)) for s in 1:7]...))
+showQ(hcat([dql.Qnetwork(M.distributed_representation(s)) for s in 1:8]...))
 
 # ╔═╡ 094bdfef-1c4d-45ef-9bf4-66f60ac3a500
 act!(ttt_env, 1);
@@ -783,7 +786,7 @@ let
     states = get!(all_states, user_id, create_initial_state())
     chasse = states.chasse
     chasse.action = chasse_actions[1] == 0 ? 3 : chasse_actions[2]
-    if chasse.state ≤ 5 || (isa(chasse.action, Int) && chasse.action > 2)
+    if chasse.state ≤ 6 || (isa(chasse.action, Int) && chasse.action > 2)
         _learner = if learner == "mclearner"
             if length(chasse.episode_recorder) > 0 && chasse.action == 3
                 M.update!(states.mclearner, chasse.episode_recorder)
@@ -805,13 +808,13 @@ let
     end
     d = M.distributed_representation(chasse.state)
     room = findfirst(d)
-    guard = d[6]
-    guard_door = room > 2
-    gold = d[4]
-    room_colors = [:red, :blue, :green, :orange, :black]
+    guard = d[7]
+    guard_door = room > 3
+    gold = d[5]
+    room_colors = [:violet, :red, :lightblue, :lightgreen, :orange, :black]
     plot(xlim = (0, 1), ylim = (0, 1), size = (600, 400), bg = room_colors[room],
          framestyle = :none, legend = false)
-    if room < 4
+    if room < 5
         plot!([.1, .1, .4, .4], [0, .7, .7, 0], w = 5, c = :black)
         plot!([.14, .12, .17], [.37, .35, .35], w = 4, c = :black)
         plot!(.5 .+ [.1, .1, .4, .4], [0, .7, .7, 0], w = 5, c = :black)
@@ -833,13 +836,13 @@ let
         y = y[1:r]
         scatter!(x, y, markerstrokewidth = 3, c = :yellow, markersize = 28)
     end
-    if room == 5
+    if room == 6
         scatter!([0], [0], c = :black, legend = false, markerstrokewidth = 0) # dummy
         annotate!([(.5, .5, "K.O.", :red)])
     end
     rs = length(chasse.episode_recorder) == 0 ? [0] : getindex.(chasse.episode_recorder, 3)
     annotate!([(.5, .9, "reward = $(chasse.reward)", :white),
-               (.5, .8, "cumulative reward = $(join(rs, " + ")) = $(sum(rs))", :white)
+               (.5, .8, "cumulative reward = $(join(rs, " + ")) = $(round(sum(rs), sigdigits = 2))", :white)
               ])
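
Reviewer note: the renumbering couples three definitions (the transition matrix T, the
reward matrix R, and distributed_representation), so a standalone sanity check is handy.
The sketch below is illustrative, not part of the patch: onehot, T, and
distributed_representation are copied from the hunks above, wsample is assumed to come
from StatsBase (as in the notebook), and act is redefined locally for the test.

using StatsBase  # provides wsample

function onehot(i)
    x = zeros(8)  # 8 states: violet, red, blue±guard, green±guard, treasure, K.O.
    x[i] = 1
    x
end

# As in the patch: T[action, state] is a probability vector over the 8 successor states.
T = reshape([onehot(3), [0, 0, 0, 0, .5, .5, 0, 0],
             [0, 0, .7, .3, 0, 0, 0, 0], [0, 0, 0, 0, .3, .7, 0, 0],
             onehot(7), onehot(7),
             onehot(8), onehot(7),
             onehot(7), onehot(7),
             onehot(7), onehot(8),
             onehot(7), onehot(7),
             onehot(8), onehot(8)], 2, :)

distributed_representation(s) =
    [s == 1; s == 2; s ∈ (3, 4); s ∈ (5, 6); s == 7; s == 8; s ∈ (4, 6)]

# Every state-action pair must define a proper distribution over the 8 states.
for s in 1:8, a in 1:2
    p = T[a, s]
    @assert length(p) == 8 && sum(p) ≈ 1 "bad distribution for a=$a, s=$s"
end

# Exactly one of the six room indicators (components 1-6) is active per state.
for s in 1:8
    @assert sum(distributed_representation(s)[1:6]) == 1 "bad indicators for s=$s"
end

# Random episodes from both initial rooms (1 = violet, 2 = red) must end in one
# of the absorbing states, 7 (treasure) or 8 (K.O.).
act(state, action) = wsample(1:8, T[action, state])
for s0 in (1, 2), _ in 1:1000
    s = s0
    while s < 7
        s = act(s, rand(1:2))
    end
    @assert s ∈ (7, 8)
end

All three checks pass with the definitions above; in particular, state 6 ("green room
with guard") is the example used in the mlcode cell, whose indicator vector has both
the green-room and the guard components set.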