diff --git a/src/solutions/epsilon_greedy.jl b/src/solutions/epsilon_greedy.jl
index 0cca5fe..f62df09 100644
--- a/src/solutions/epsilon_greedy.jl
+++ b/src/solutions/epsilon_greedy.jl
@@ -52,13 +52,15 @@ function get_action_probs(sol::EpsilonGreedyPolicy, state::State)
 end
 
 function get_action_prob(sol::EpsilonGreedyPolicy, state::State, action::Term)
-    n_actions = length(lazy_collect(available(sol.domain, state)))
+    actions = lazy_collect(available(sol.domain, state))
+    n_actions = length(actions)
     if n_actions == 0 return 0.0 end
     prob = sol.epsilon / n_actions
     best_act = best_action(sol.policy, state)
     if action == best_act
         return prob + (1 - sol.epsilon)
-    else
+    elseif action in actions
         return prob
     end
+    return 0.0
 end
diff --git a/test/solutions.jl b/test/solutions.jl
index f6ee29d..beace83 100644
--- a/test/solutions.jl
+++ b/test/solutions.jl
@@ -47,6 +47,8 @@ no_op = convert(Term, PDDL.no_op)
 @test get_action_probs(sol, trajectory[2]) == Dict(plan[2] => 1.0)
 @test get_action_probs(sol, trajectory[end]) == Dict(no_op => 1.0)
 @test get_action_prob(sol, trajectory[1], plan[1]) == 1.0
+@test get_action_prob(sol, trajectory[1], pddl"(pick-up b)") == 0.0
+@test get_action_prob(sol, trajectory[1], pddl"(pick-up z)") == 0.0
 @test copy(sol) == sol
@@ -62,6 +64,7 @@ n_actions = length(bw_init_actions)
 probs = Dict(a => 1.0 / n_actions for a in bw_init_actions)
 @test get_action_probs(sol, bw_state) == probs
 @test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0 / n_actions
+@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0
 @test copy(sol) == sol
@@ -83,8 +86,9 @@ sol.Q[hash(bw_state)] = bw_init_q
 probs = Dict(a => a == pddl"(pick-up a)" ? 1.0 : 0.0 for a in bw_init_actions)
 @test get_action_probs(sol, bw_state) == probs
-@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
-@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0
 @test copy(sol) == sol
@@ -109,8 +113,9 @@ end
 probs = Dict(a => a == pddl"(pick-up a)" ? 1.0 : 0.0 for a in bw_init_actions)
 @test get_action_probs(sol, bw_state) == probs
-@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
-@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0
 @test copy(sol) == sol
@@ -132,7 +137,7 @@ sol = FunctionalVPolicy(heuristic, blocksworld, bw_spec)
 probs = Dict(a => a == pddl"(pick-up a)" ? 1.0 : 0.0 for a in bw_init_actions)
 @test get_action_probs(sol, bw_state) == probs
-@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
+@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
 @test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
 @test copy(sol) == sol
@@ -159,6 +164,7 @@ probs = Dict(zip(keys(bw_init_q), probs))
 @test get_action_probs(sol, bw_state) == probs
 act_prob = probs[pddl"(pick-up a)"]
 @test get_action_prob(sol, bw_state, pddl"(pick-up a)") ≈ act_prob
+@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0
 @test copy(sol) == sol
@@ -184,6 +190,7 @@ probs[pddl"(pick-up a)"] += 0.9
 @test get_action_probs(sol, bw_state) == probs
 act_prob = probs[pddl"(pick-up a)"]
 @test get_action_prob(sol, bw_state, pddl"(pick-up a)") ≈ act_prob
+@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0
 @test copy(sol) == sol
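
For reference, a standalone Julia sketch of the probability rule the patched get_action_prob now follows; this is not the SymbolicPlanners.jl API, and epsilon_greedy_prob and its arguments are hypothetical stand-ins. The greedy action receives epsilon/n plus the exploitation mass (1 - epsilon), every other available action receives epsilon/n, and an action that is not available in the state now gets probability 0.0 instead of epsilon/n.

# Illustrative sketch only (hypothetical names), mirroring the corrected branch order.
function epsilon_greedy_prob(action, actions, best, epsilon)
    n = length(actions)
    n == 0 && return 0.0
    base = epsilon / n                              # uniform exploration mass
    action == best && return base + (1 - epsilon)   # exploitation bonus for the greedy action
    action in actions && return base                # other available actions
    return 0.0                                      # unavailable actions get zero probability
end

acts = ["(pick-up a)", "(pick-up b)", "(pick-up c)"]
p_best    = epsilon_greedy_prob("(pick-up a)", acts, "(pick-up a)", 0.1)  # 0.1/3 + 0.9
p_other   = epsilon_greedy_prob("(pick-up b)", acts, "(pick-up a)", 0.1)  # 0.1/3
p_unavail = epsilon_greedy_prob("(pick-up z)", acts, "(pick-up a)", 0.1)  # 0.0
@assert p_best + 2 * p_other ≈ 1.0   # probabilities over available actions still sum to 1
@assert p_unavail == 0.0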