Skip to content

Commit

Permalink
Ensure unavailable actions have zero probability.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Dec 22, 2023
1 parent af9cc7f commit 6d2214d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
6 changes: 4 additions & 2 deletions src/solutions/epsilon_greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ function get_action_probs(sol::EpsilonGreedyPolicy, state::State)
end

function get_action_prob(sol::EpsilonGreedyPolicy, state::State, action::Term)
n_actions = length(lazy_collect(available(sol.domain, state)))
actions = lazy_collect(available(sol.domain, state))
n_actions = length(actions)
if n_actions == 0 return 0.0 end
prob = sol.epsilon / n_actions
best_act = best_action(sol.policy, state)
if action == best_act
return prob + (1 - sol.epsilon)
else
elseif action in actions
return prob
end
return 0.0
end
17 changes: 12 additions & 5 deletions test/solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ no_op = convert(Term, PDDL.no_op)
@test get_action_probs(sol, trajectory[2]) == Dict(plan[2] => 1.0)
@test get_action_probs(sol, trajectory[end]) == Dict(no_op => 1.0)
@test get_action_prob(sol, trajectory[1], plan[1]) == 1.0
@test get_action_prob(sol, trajectory[1], pddl"(pick-up b)") == 0.0
@test get_action_prob(sol, trajectory[1], pddl"(pick-up z)") == 0.0

@test copy(sol) == sol

Expand All @@ -62,6 +64,7 @@ n_actions = length(bw_init_actions)
probs = Dict(a => 1.0 / n_actions for a in bw_init_actions)
@test get_action_probs(sol, bw_state) == probs
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0 / n_actions
@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0

@test copy(sol) == sol

Expand All @@ -83,8 +86,9 @@ sol.Q[hash(bw_state)] = bw_init_q

probs = Dict(a => a == pddl"(pick-up a)" ? 1.0 : 0.0 for a in bw_init_actions)
@test get_action_probs(sol, bw_state) == probs
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0

@test copy(sol) == sol

Expand All @@ -109,8 +113,9 @@ end

probs = Dict(a => a == pddl"(pick-up a)" ? 1.0 : 0.0 for a in bw_init_actions)
@test get_action_probs(sol, bw_state) == probs
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0
@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0

@test copy(sol) == sol

Expand All @@ -132,7 +137,7 @@ sol = FunctionalVPolicy(heuristic, blocksworld, bw_spec)

probs = Dict(a => a == pddl"(pick-up a)" ? 1.0 : 0.0 for a in bw_init_actions)
@test get_action_probs(sol, bw_state) == probs
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") == 1.0
@test get_action_prob(sol, bw_state, pddl"(pick-up b)") == 0.0

@test copy(sol) == sol
Expand All @@ -159,6 +164,7 @@ probs = Dict(zip(keys(bw_init_q), probs))
@test get_action_probs(sol, bw_state) == probs
act_prob = probs[pddl"(pick-up a)"]
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") ≈ act_prob
@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0

@test copy(sol) == sol

Expand All @@ -184,6 +190,7 @@ probs[pddl"(pick-up a)"] += 0.9
@test get_action_probs(sol, bw_state) == probs
act_prob = probs[pddl"(pick-up a)"]
@test get_action_prob(sol, bw_state, pddl"(pick-up a)") ≈ act_prob
@test get_action_prob(sol, bw_state, pddl"(pick-up z)") == 0.0

@test copy(sol) == sol

Expand Down

0 comments on commit 6d2214d

Please sign in to comment.