From daf30f03d2952345f310bfe6e5fbfc952764cd37 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 26 Sep 2023 12:51:42 -0400 Subject: [PATCH 1/3] fix expected atoms monitor to work with multistep options --- .../approaches/bilevel_planning_approach.py | 5 +-- predicators/cogman.py | 3 +- .../base_execution_monitor.py | 10 ++++- .../expected_atoms_monitor.py | 29 +++++++++++---- .../test_execution_monitoring.py | 37 +++++++++++++++++++ 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 6466e5bde4..bce45adb09 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -201,8 +201,5 @@ def get_last_nsrt_plan(self) -> List[_GroundNSRT]: def get_execution_monitoring_info(self) -> List[Set[GroundAtom]]: if self._plan_without_sim: - remaining_atoms_seq = list(self._last_atoms_seq) - if remaining_atoms_seq: - self._last_atoms_seq.pop(0) - return remaining_atoms_seq + return list(self._last_atoms_seq) return [] diff --git a/predicators/cogman.py b/predicators/cogman.py index 121f681291..aaae98aeb4 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -70,8 +70,7 @@ def step(self, observation: Observation) -> Optional[Action]: assert not self._exec_monitor.step(state) assert self._current_policy is not None act = self._current_policy(state) - self._exec_monitor.update_approach_info( - self._approach.get_execution_monitoring_info()) + self._exec_monitor.update_with_action(act) self._episode_action_history.append(act) return act diff --git a/predicators/execution_monitoring/base_execution_monitor.py b/predicators/execution_monitoring/base_execution_monitor.py index 6e1a75c287..8b29c08e96 100644 --- a/predicators/execution_monitoring/base_execution_monitor.py +++ b/predicators/execution_monitoring/base_execution_monitor.py @@ -1,9 +1,9 @@ """Base class for execution monitors.""" import abc -from typing import Any, List +from typing import Any, List, Optional -from predicators.structs import State, Task +from predicators.structs import State, Task, Action class BaseExecutionMonitor(abc.ABC): @@ -12,6 +12,7 @@ class BaseExecutionMonitor(abc.ABC): def __init__(self) -> None: self._approach_info: List[Any] = [] self._curr_plan_timestep = 0 + self._prev_action: Optional[Action] = None @classmethod @abc.abstractmethod @@ -22,11 +23,16 @@ def reset(self, task: Task) -> None: """Reset after replanning.""" del task # unused self._curr_plan_timestep = 0 + self._prev_action = None @abc.abstractmethod def step(self, state: State) -> bool: """Return true if the agent should replan.""" + def update_with_action(self, action: Action) -> None: + """Called after each action is executed.""" + self._prev_action = action + def update_approach_info(self, info: List[Any]) -> None: """Update internal info received from approach.""" self._approach_info = info diff --git a/predicators/execution_monitoring/expected_atoms_monitor.py b/predicators/execution_monitoring/expected_atoms_monitor.py index 884483b15e..c9beb1a9b1 100644 --- a/predicators/execution_monitoring/expected_atoms_monitor.py +++ b/predicators/execution_monitoring/expected_atoms_monitor.py @@ -3,10 +3,12 @@ import logging +from typing import Set + from predicators.execution_monitoring.base_execution_monitor import \ BaseExecutionMonitor from predicators.settings import CFG -from predicators.structs import State +from predicators.structs import Action, State, GroundAtom class ExpectedAtomsExecutionMonitor(BaseExecutionMonitor): @@ -25,15 +27,28 @@ def step(self, state: State) -> bool: # If the approach info is empty, don't replan. if not self._approach_info: # pragma: no cover return False - next_expected_atoms = self._approach_info[0] - assert isinstance(next_expected_atoms, set) + # If the previous action just terminated, advance the expected atoms. + if self._prev_action is not None and self._prev_action.has_option(): + option = self._prev_action.get_option() + if option.terminal(state): + self._advance_expected_atoms() + next_expected_atoms = self._get_expected_atoms() self._curr_plan_timestep += 1 # If the expected atoms are a subset of the current atoms, then # we don't have to replan. unsat_atoms = {a for a in next_expected_atoms if not a.holds(state)} if not unsat_atoms: return False - logging.info( - "Expected atoms execution monitor triggered replanning " - f"because of these atoms: {unsat_atoms}") # pragma: no cover - return True # pragma: no cover + logging.info("Expected atoms execution monitor triggered replanning " + f"because of these atoms: {unsat_atoms}") + return True + + def _get_expected_atoms(self) -> Set[GroundAtom]: + expected_atoms = self._approach_info[0] + assert isinstance(expected_atoms, set) + return expected_atoms + + def _advance_expected_atoms(self) -> None: + expected_atom_seq = self._approach_info + assert isinstance(expected_atom_seq, list) + expected_atom_seq.pop(0) diff --git a/tests/execution_monitoring/test_execution_monitoring.py b/tests/execution_monitoring/test_execution_monitoring.py index 6c7c5acfe2..982364bf9a 100644 --- a/tests/execution_monitoring/test_execution_monitoring.py +++ b/tests/execution_monitoring/test_execution_monitoring.py @@ -9,6 +9,12 @@ MpcExecutionMonitor from predicators.execution_monitoring.trivial_execution_monitor import \ TrivialExecutionMonitor +from predicators.cogman import CogMan +from predicators.perception import create_perceiver +from predicators.envs import get_or_create_env +from predicators import utils +from predicators.approaches import create_approach +from predicators.ground_truth_models import get_gt_options def test_create_execution_monitor(): @@ -25,3 +31,34 @@ def test_create_execution_monitor(): with pytest.raises(NotImplementedError) as e: create_execution_monitor("not a real monitor") assert "Unrecognized execution monitor" in str(e) + + +def test_expected_atoms_execution_monitor(): + """Tests for ExpectedAtomsExecutionMonitor.""" + # Test that the monitor works in an environment where options take + # multiple steps. + env_name = "cover_multistep_options" + utils.reset_config({ + "env": env_name, + "approach": "oracle", + "bilevel_plan_without_sim": True, + }) + env = get_or_create_env(env_name) + options = get_gt_options(env.get_name()) + train_tasks = [t.task for t in env.get_train_tasks()] + approach = create_approach("oracle", env.predicates, options, env.types, env.action_space, train_tasks) + perceiver = create_perceiver("trivial") + exec_monitor = create_execution_monitor("expected_atoms") + cogman = CogMan(approach, perceiver, exec_monitor) + env_task = env.get_test_tasks()[0] + cogman.reset(env_task) + obs = env.reset("test", 0) + # Check that the actions are not ever repeated, since re-planning should + # cause re-sampling. + prev_act = None + for _ in range(10): + act = cogman.step(obs) + obs = env.step(act) + if prev_act is not None: + assert prev_act != act + prev_act = act From 8b2989440417b1a03fab9cf7faa8bc28c3d4b809 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 26 Sep 2023 12:52:26 -0400 Subject: [PATCH 2/3] clean --- .../execution_monitoring/base_execution_monitor.py | 2 +- .../execution_monitoring/expected_atoms_monitor.py | 3 +-- .../test_execution_monitoring.py | 13 +++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/predicators/execution_monitoring/base_execution_monitor.py b/predicators/execution_monitoring/base_execution_monitor.py index 8b29c08e96..93047d4346 100644 --- a/predicators/execution_monitoring/base_execution_monitor.py +++ b/predicators/execution_monitoring/base_execution_monitor.py @@ -3,7 +3,7 @@ import abc from typing import Any, List, Optional -from predicators.structs import State, Task, Action +from predicators.structs import Action, State, Task class BaseExecutionMonitor(abc.ABC): diff --git a/predicators/execution_monitoring/expected_atoms_monitor.py b/predicators/execution_monitoring/expected_atoms_monitor.py index c9beb1a9b1..82eda78f06 100644 --- a/predicators/execution_monitoring/expected_atoms_monitor.py +++ b/predicators/execution_monitoring/expected_atoms_monitor.py @@ -2,13 +2,12 @@ suggest replanning when the expected atoms check is not met.""" import logging - from typing import Set from predicators.execution_monitoring.base_execution_monitor import \ BaseExecutionMonitor from predicators.settings import CFG -from predicators.structs import Action, State, GroundAtom +from predicators.structs import GroundAtom, State class ExpectedAtomsExecutionMonitor(BaseExecutionMonitor): diff --git a/tests/execution_monitoring/test_execution_monitoring.py b/tests/execution_monitoring/test_execution_monitoring.py index 982364bf9a..b059fa65b8 100644 --- a/tests/execution_monitoring/test_execution_monitoring.py +++ b/tests/execution_monitoring/test_execution_monitoring.py @@ -2,6 +2,10 @@ import pytest +from predicators import utils +from predicators.approaches import create_approach +from predicators.cogman import CogMan +from predicators.envs import get_or_create_env from predicators.execution_monitoring import create_execution_monitor from predicators.execution_monitoring.expected_atoms_monitor import \ ExpectedAtomsExecutionMonitor @@ -9,12 +13,8 @@ MpcExecutionMonitor from predicators.execution_monitoring.trivial_execution_monitor import \ TrivialExecutionMonitor -from predicators.cogman import CogMan -from predicators.perception import create_perceiver -from predicators.envs import get_or_create_env -from predicators import utils -from predicators.approaches import create_approach from predicators.ground_truth_models import get_gt_options +from predicators.perception import create_perceiver def test_create_execution_monitor(): @@ -46,7 +46,8 @@ def test_expected_atoms_execution_monitor(): env = get_or_create_env(env_name) options = get_gt_options(env.get_name()) train_tasks = [t.task for t in env.get_train_tasks()] - approach = create_approach("oracle", env.predicates, options, env.types, env.action_space, train_tasks) + approach = create_approach("oracle", env.predicates, options, env.types, + env.action_space, train_tasks) perceiver = create_perceiver("trivial") exec_monitor = create_execution_monitor("expected_atoms") cogman = CogMan(approach, perceiver, exec_monitor) From 7b7eb3d3621c9598a4fa55d161a9d84aab62f1c7 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 26 Sep 2023 15:15:54 -0400 Subject: [PATCH 3/3] try again... --- .../approaches/bilevel_planning_approach.py | 21 ++++++++++++-- predicators/cogman.py | 3 +- .../base_execution_monitor.py | 10 ++----- .../expected_atoms_monitor.py | 28 +++++-------------- 4 files changed, 30 insertions(+), 32 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index bce45adb09..ac80e301ca 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -48,6 +48,8 @@ def __init__(self, self._last_plan: List[_Option] = [] # used if plan WITH sim self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim self._last_atoms_seq: List[Set[GroundAtom]] = [] # plan WITHOUT sim + self._last_executed_option: Optional[_Option] = None + self._last_executed_option_terminated = False def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._num_calls += 1 @@ -55,6 +57,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: seed = self._seed + self._num_calls nsrts = self._get_current_nsrts() preds = self._get_current_predicates() + self._last_executed_option = None + self._last_executed_option_terminated = False # Run task planning only and then greedily sample and execute in the # policy. @@ -63,6 +67,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: task, nsrts, preds, timeout, seed) self._last_nsrt_plan = nsrt_plan self._last_atoms_seq = atoms_seq + # Always pop the first element because it's already achieved. + # self._last_atoms_seq.pop(0) policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal, self._rng) logging.debug("Current Task Plan:") @@ -80,8 +86,15 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._save_metrics(metrics, nsrts, preds) def _policy(s: State) -> Action: + self._last_executed_option_terminated = False try: - return policy(s) + # Record for execution monitoring. + act = policy(s) + option = act.get_option() + if option is not self._last_executed_option: + self._last_executed_option_terminated = True + self._last_executed_option = option + return act except utils.OptionExecutionFailure as e: raise ApproachFailure(e.args[0], e.info) @@ -201,5 +214,9 @@ def get_last_nsrt_plan(self) -> List[_GroundNSRT]: def get_execution_monitoring_info(self) -> List[Set[GroundAtom]]: if self._plan_without_sim: - return list(self._last_atoms_seq) + remaining_atoms_seq = list(self._last_atoms_seq) + if remaining_atoms_seq: + if self._last_executed_option_terminated: + self._last_atoms_seq.pop(0) + return remaining_atoms_seq return [] diff --git a/predicators/cogman.py b/predicators/cogman.py index aaae98aeb4..121f681291 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -70,7 +70,8 @@ def step(self, observation: Observation) -> Optional[Action]: assert not self._exec_monitor.step(state) assert self._current_policy is not None act = self._current_policy(state) - self._exec_monitor.update_with_action(act) + self._exec_monitor.update_approach_info( + self._approach.get_execution_monitoring_info()) self._episode_action_history.append(act) return act diff --git a/predicators/execution_monitoring/base_execution_monitor.py b/predicators/execution_monitoring/base_execution_monitor.py index 93047d4346..6e1a75c287 100644 --- a/predicators/execution_monitoring/base_execution_monitor.py +++ b/predicators/execution_monitoring/base_execution_monitor.py @@ -1,9 +1,9 @@ """Base class for execution monitors.""" import abc -from typing import Any, List, Optional +from typing import Any, List -from predicators.structs import Action, State, Task +from predicators.structs import State, Task class BaseExecutionMonitor(abc.ABC): @@ -12,7 +12,6 @@ class BaseExecutionMonitor(abc.ABC): def __init__(self) -> None: self._approach_info: List[Any] = [] self._curr_plan_timestep = 0 - self._prev_action: Optional[Action] = None @classmethod @abc.abstractmethod @@ -23,16 +22,11 @@ def reset(self, task: Task) -> None: """Reset after replanning.""" del task # unused self._curr_plan_timestep = 0 - self._prev_action = None @abc.abstractmethod def step(self, state: State) -> bool: """Return true if the agent should replan.""" - def update_with_action(self, action: Action) -> None: - """Called after each action is executed.""" - self._prev_action = action - def update_approach_info(self, info: List[Any]) -> None: """Update internal info received from approach.""" self._approach_info = info diff --git a/predicators/execution_monitoring/expected_atoms_monitor.py b/predicators/execution_monitoring/expected_atoms_monitor.py index 82eda78f06..884483b15e 100644 --- a/predicators/execution_monitoring/expected_atoms_monitor.py +++ b/predicators/execution_monitoring/expected_atoms_monitor.py @@ -2,12 +2,11 @@ suggest replanning when the expected atoms check is not met.""" import logging -from typing import Set from predicators.execution_monitoring.base_execution_monitor import \ BaseExecutionMonitor from predicators.settings import CFG -from predicators.structs import GroundAtom, State +from predicators.structs import State class ExpectedAtomsExecutionMonitor(BaseExecutionMonitor): @@ -26,28 +25,15 @@ def step(self, state: State) -> bool: # If the approach info is empty, don't replan. if not self._approach_info: # pragma: no cover return False - # If the previous action just terminated, advance the expected atoms. - if self._prev_action is not None and self._prev_action.has_option(): - option = self._prev_action.get_option() - if option.terminal(state): - self._advance_expected_atoms() - next_expected_atoms = self._get_expected_atoms() + next_expected_atoms = self._approach_info[0] + assert isinstance(next_expected_atoms, set) self._curr_plan_timestep += 1 # If the expected atoms are a subset of the current atoms, then # we don't have to replan. unsat_atoms = {a for a in next_expected_atoms if not a.holds(state)} if not unsat_atoms: return False - logging.info("Expected atoms execution monitor triggered replanning " - f"because of these atoms: {unsat_atoms}") - return True - - def _get_expected_atoms(self) -> Set[GroundAtom]: - expected_atoms = self._approach_info[0] - assert isinstance(expected_atoms, set) - return expected_atoms - - def _advance_expected_atoms(self) -> None: - expected_atom_seq = self._approach_info - assert isinstance(expected_atom_seq, list) - expected_atom_seq.pop(0) + logging.info( + "Expected atoms execution monitor triggered replanning " + f"because of these atoms: {unsat_atoms}") # pragma: no cover + return True # pragma: no cover