Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix expected atoms monitor for multistep options #1560

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions predicators/approaches/bilevel_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,17 @@ def __init__(self,
self._last_plan: List[_Option] = [] # used if plan WITH sim
self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim
self._last_atoms_seq: List[Set[GroundAtom]] = [] # plan WITHOUT sim
self._last_executed_option: Optional[_Option] = None
self._last_executed_option_terminated = False

def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
self._num_calls += 1
# ensure random over successive calls
seed = self._seed + self._num_calls
nsrts = self._get_current_nsrts()
preds = self._get_current_predicates()
self._last_executed_option = None
self._last_executed_option_terminated = False

# Run task planning only and then greedily sample and execute in the
# policy.
Expand All @@ -63,6 +67,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
task, nsrts, preds, timeout, seed)
self._last_nsrt_plan = nsrt_plan
self._last_atoms_seq = atoms_seq
# Always pop the first element because it's already achieved.
# self._last_atoms_seq.pop(0)
policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal,
self._rng)
logging.debug("Current Task Plan:")
Expand All @@ -80,8 +86,15 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
self._save_metrics(metrics, nsrts, preds)

def _policy(s: State) -> Action:
self._last_executed_option_terminated = False
try:
return policy(s)
# Record for execution monitoring.
act = policy(s)
option = act.get_option()
if option is not self._last_executed_option:
self._last_executed_option_terminated = True
self._last_executed_option = option
return act
except utils.OptionExecutionFailure as e:
raise ApproachFailure(e.args[0], e.info)

Expand Down Expand Up @@ -203,6 +216,7 @@ def get_execution_monitoring_info(self) -> List[Set[GroundAtom]]:
if self._plan_without_sim:
remaining_atoms_seq = list(self._last_atoms_seq)
if remaining_atoms_seq:
self._last_atoms_seq.pop(0)
if self._last_executed_option_terminated:
self._last_atoms_seq.pop(0)
return remaining_atoms_seq
return []
38 changes: 38 additions & 0 deletions tests/execution_monitoring/test_execution_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@

import pytest

from predicators import utils
from predicators.approaches import create_approach
from predicators.cogman import CogMan
from predicators.envs import get_or_create_env
from predicators.execution_monitoring import create_execution_monitor
from predicators.execution_monitoring.expected_atoms_monitor import \
ExpectedAtomsExecutionMonitor
from predicators.execution_monitoring.mpc_execution_monitor import \
MpcExecutionMonitor
from predicators.execution_monitoring.trivial_execution_monitor import \
TrivialExecutionMonitor
from predicators.ground_truth_models import get_gt_options
from predicators.perception import create_perceiver


def test_create_execution_monitor():
Expand All @@ -25,3 +31,35 @@ def test_create_execution_monitor():
with pytest.raises(NotImplementedError) as e:
create_execution_monitor("not a real monitor")
assert "Unrecognized execution monitor" in str(e)


def test_expected_atoms_execution_monitor():
"""Tests for ExpectedAtomsExecutionMonitor."""
# Test that the monitor works in an environment where options take
# multiple steps.
env_name = "cover_multistep_options"
utils.reset_config({
"env": env_name,
"approach": "oracle",
"bilevel_plan_without_sim": True,
})
env = get_or_create_env(env_name)
options = get_gt_options(env.get_name())
train_tasks = [t.task for t in env.get_train_tasks()]
approach = create_approach("oracle", env.predicates, options, env.types,
env.action_space, train_tasks)
perceiver = create_perceiver("trivial")
exec_monitor = create_execution_monitor("expected_atoms")
cogman = CogMan(approach, perceiver, exec_monitor)
env_task = env.get_test_tasks()[0]
cogman.reset(env_task)
obs = env.reset("test", 0)
# Check that the actions are not ever repeated, since re-planning should
# cause re-sampling.
prev_act = None
for _ in range(10):
act = cogman.step(obs)
obs = env.step(act)
if prev_act is not None:
assert prev_act != act
prev_act = act