add cogman to online interaction #1488

Merged · 15 commits · Jun 28, 2023
@@ -317,7 +317,7 @@ def _learn_nsrt_sampler(self, nsrt_data: _OptionSamplerDataset,
ensemble_size=CFG.active_sampler_learning_num_ensemble_members,
member_cls=MLPBinaryClassifier,
balance_data=CFG.mlp_classifier_balance_data,
max_train_iters=CFG.predicate_mlp_classifier_max_itr,
max_train_iters=CFG.sampler_mlp_classifier_max_itr,
learning_rate=CFG.learning_rate,
n_iter_no_change=CFG.mlp_classifier_n_iter_no_change,
hid_sizes=CFG.mlp_classifier_hid_sizes,
71 changes: 65 additions & 6 deletions predicators/cogman.py
@@ -8,14 +8,16 @@

The name "CogMan" is due to Leslie Kaelbling.
"""
import logging
from typing import Callable, List, Optional, Sequence, Set

from predicators.approaches import BaseApproach
from predicators.execution_monitoring import BaseExecutionMonitor
from predicators.perception import BasePerceiver
from predicators.settings import CFG
from predicators.structs import Action, Dataset, EnvironmentTask, GroundAtom, \
InteractionRequest, InteractionResult, Metrics, Observation, State, Task
InteractionRequest, InteractionResult, LowLevelTrajectory, Metrics, \
Observation, State, Task


class CogMan:
@@ -28,34 +30,61 @@ def __init__(self, approach: BaseApproach, perceiver: BasePerceiver,
self._exec_monitor = execution_monitor
self._current_policy: Optional[Callable[[State], Action]] = None
self._current_goal: Optional[Set[GroundAtom]] = None
self._override_policy: Optional[Callable[[State], Action]] = None
self._termination_fn: Optional[Callable[[State], bool]] = None
self._episode_state_history: List[State] = []
self._episode_action_history: List[Action] = []

def reset(self, env_task: EnvironmentTask) -> None:
"""Start a new episode of environment interaction."""
logging.info("[CogMan] Reset called.")
task = self._perceiver.reset(env_task)
self._current_goal = task.goal
self._current_policy = self._approach.solve(task, timeout=CFG.timeout)
self._reset_policy(task)
self._exec_monitor.reset(task)
self._exec_monitor.update_approach_info(
self._approach.get_execution_monitoring_info())
self._episode_state_history = [task.init]
self._episode_action_history = []

def step(self, observation: Observation) -> Action:
"""Receive an observation and produce an action."""
def step(self, observation: Observation) -> Optional[Action]:
"""Receive an observation and produce an action, or None for done."""
logging.info("[CogMan] Step called.")
state = self._perceiver.step(observation)
# Replace the first step because the state was already added in reset().
if not self._episode_action_history:
self._episode_state_history[0] = state
else:
self._episode_state_history.append(state)
if self._termination_fn is not None and self._termination_fn(state):
logging.info("[CogMan] Termination triggered.")
return None
# Check if we should replan.
if self._exec_monitor.step(state):
logging.info("[CogMan] Replanning triggered.")
assert self._current_goal is not None
task = Task(state, self._current_goal)
new_policy = self._approach.solve(task, timeout=CFG.timeout)
self._current_policy = new_policy
self._reset_policy(task)
self._exec_monitor.reset(task)
self._exec_monitor.update_approach_info(
self._approach.get_execution_monitoring_info())
assert not self._exec_monitor.step(state)
assert self._current_policy is not None
act = self._current_policy(state)
self._exec_monitor.update_approach_info(
self._approach.get_execution_monitoring_info())
self._episode_action_history.append(act)
logging.info("[CogMan] Returning action.")
return act

def finish_episode(self, observation: Observation) -> None:
"""Called at the end of an episode."""
logging.info("[CogMan] Finishing episode.")
if len(self._episode_state_history) == len(
self._episode_action_history):
state = self._perceiver.step(observation)
self._episode_state_history.append(state)

# The methods below provide an interface to the approach. In the future,
# we may want to move some of these methods into cogman properly, e.g.,
# if we want the perceiver or execution monitor to learn from data.
@@ -90,3 +119,33 @@ def metrics(self) -> Metrics:
def reset_metrics(self) -> None:
"""See BaseApproach docstring."""
return self._approach.reset_metrics()

def set_override_policy(self, policy: Callable[[State], Action]) -> None:
"""Used during online interaction."""
self._override_policy = policy

def unset_override_policy(self) -> None:
"""Give control back to the approach."""
self._override_policy = None

def set_termination_function(
self, termination_fn: Callable[[State], bool]) -> None:
"""Used during online interaction."""
self._termination_fn = termination_fn

def unset_termination_function(self) -> None:
"""Reset to never willfully terminating."""
self._termination_fn = None

def get_current_history(self) -> LowLevelTrajectory:
"""Expose the most recent state, action history for learning."""
return LowLevelTrajectory(self._episode_state_history,
self._episode_action_history)

def _reset_policy(self, task: Task) -> None:
"""Call the approach or use the override policy."""
if self._override_policy is not None:
self._current_policy = self._override_policy
else:
self._current_policy = self._approach.solve(task,
timeout=CFG.timeout)
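
For orientation, a minimal sketch (not part of this diff) of how the new CogMan hooks fit together over one online-interaction episode; the `request`, `env_task`, and `max_num_steps` values, and the exact `env.reset`/`env.step` signatures, are assumptions standing in for what the main.py changes below supply.

# Illustrative only: drive one interaction episode through CogMan.
cogman.set_override_policy(request.act_policy)
cogman.set_termination_function(request.termination_function)
cogman.reset(env_task)
obs = env.reset("train", request.train_task_idx)  # assumed env API
for _ in range(max_num_steps):
    act = cogman.step(obs)
    if act is None:  # termination function fired
        break
    obs = env.step(act)  # assumed env API
cogman.finish_episode(obs)
cogman.unset_override_policy()
cogman.unset_termination_function()
traj = cogman.get_current_history()  # LowLevelTrajectory for learning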
12 changes: 7 additions & 5 deletions predicators/execution_monitoring/mpc_execution_monitor.py
@@ -1,8 +1,7 @@
"""A model-predictive control monitor that always suggests replanning."""

from predicators.execution_monitoring.base_execution_monitor import \
BaseExecutionMonitor
from predicators.structs import State, Task
from predicators.structs import State


class MpcExecutionMonitor(BaseExecutionMonitor):
@@ -12,8 +11,11 @@ class MpcExecutionMonitor(BaseExecutionMonitor):
def get_name(cls) -> str:
return "mpc"

def reset(self, task: Task) -> None:
pass

def step(self, state: State) -> bool:
# Don't trigger replanning on the 0th
# timestep.
if self._curr_plan_timestep == 0:
self._curr_plan_timestep += 1
return False
# Otherwise, trigger replanning.
return True
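
To make the monitor's behavior concrete, a small illustrative check; `task`, `s0`, and `s1` stand for an arbitrary Task and two States assumed to be in scope.

from predicators.execution_monitoring import create_execution_monitor

monitor = create_execution_monitor("mpc")
monitor.reset(task)          # task: the current Task, assumed available
assert not monitor.step(s0)  # timestep 0: keep executing the current plan
assert monitor.step(s1)      # every later timestep: request replanning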
56 changes: 40 additions & 16 deletions predicators/main.py
@@ -55,9 +55,9 @@
from predicators.ground_truth_models import get_gt_options, \
parse_config_included_options
from predicators.perception import create_perceiver
from predicators.settings import CFG
from predicators.settings import CFG, get_allowed_query_type_names
from predicators.structs import Action, Dataset, InteractionRequest, \
InteractionResult, Metrics, Observation, Task
InteractionResult, Metrics, Observation, Response, Task, Video
from predicators.teacher import Teacher, TeacherInteractionMonitorWithVideo

assert os.environ.get("PYTHONHASHSEED") == "0", \
@@ -171,7 +171,11 @@ def _run_pipeline(env: BaseEnv,
results["learning_time"] = learning_time
results.update(offline_learning_metrics)
_save_test_results(results, online_learning_cycle=None)
teacher = Teacher(train_tasks)
# Only create a teacher if there are possibly queries coming.
if get_allowed_query_type_names():
teacher = Teacher(train_tasks)
else:
teacher = None
# The online learning loop.
for i in range(CFG.num_online_learning_cycles):
if i < CFG.skip_until_cycle:
@@ -188,7 +192,7 @@ def _run_pipeline(env: BaseEnv,
"terminating")
break # agent doesn't want to learn anything more; terminate
interaction_results, query_cost = _generate_interaction_results(
env, teacher, interaction_requests, i)
cogman, env, teacher, interaction_requests, i)
num_online_transitions += sum(
len(result.actions) for result in interaction_results)
total_query_cost += query_cost
@@ -219,8 +223,9 @@


def _generate_interaction_results(
cogman: CogMan,
env: BaseEnv,
teacher: Teacher,
teacher: Optional[Teacher],
requests: Sequence[InteractionRequest],
cycle_num: Optional[int] = None
) -> Tuple[List[InteractionResult], float]:
@@ -230,32 +235,48 @@ def _generate_interaction_results(
results = []
query_cost = 0.0
if CFG.make_interaction_videos:
video = []
video: Video = []
for request in requests:
if request.train_task_idx < CFG.max_initial_demos and \
not CFG.allow_interaction_in_demo_tasks:
raise RuntimeError("Interaction requests cannot be on demo tasks "
"if allow_interaction_in_demo_tasks is False.")
monitor = TeacherInteractionMonitorWithVideo(env.render, request,
teacher)
traj, _ = utils.run_policy(
request.act_policy,
monitor: Optional[TeacherInteractionMonitorWithVideo] = None
if teacher is not None:
monitor = TeacherInteractionMonitorWithVideo(
env.render, request, teacher)
cogman.set_override_policy(request.act_policy)
cogman.set_termination_function(request.termination_function)
env_task = env.get_train_tasks()[request.train_task_idx]
cogman.reset(env_task)
observed_traj, _, _ = _run_episode(
cogman,
env,
"train",
request.train_task_idx,
request.termination_function,
max_num_steps=CFG.max_num_steps_interaction_request,
exceptions_to_break_on={
utils.EnvironmentFailure, utils.OptionExecutionFailure,
utils.RequestActPolicyFailure
utils.EnvironmentFailure,
utils.OptionExecutionFailure,
utils.RequestActPolicyFailure,
},
monitor=monitor)
request_responses = monitor.get_responses()
query_cost += monitor.get_query_cost()
cogman.unset_override_policy()
cogman.unset_termination_function()
traj = cogman.get_current_history()
request_responses: List[Optional[Response]] = [
None for _ in traj.states
]
if monitor is not None:
request_responses = monitor.get_responses()
query_cost += monitor.get_query_cost()
assert len(traj.states) == len(observed_traj[0])
assert len(traj.actions) == len(observed_traj[1])
result = InteractionResult(traj.states, traj.actions,
request_responses)
results.append(result)
if CFG.make_interaction_videos:
assert monitor is not None
video.extend(monitor.get_video())
if CFG.make_interaction_videos:
save_prefix = utils.get_config_path_str()
@@ -421,7 +442,7 @@ def _run_episode(
Note that the environment and cogman internal states are updated.

Terminates when any of these conditions hold:
(1) the termination_function returns True
(1) cogman.step returns None, indicating termination
(2) max_num_steps is reached
(3) cogman or env raise an exception of type in exceptions_to_break_on

@@ -450,6 +471,8 @@
start_time = time.perf_counter()
act = cogman.step(obs)
metrics["policy_call_time"] += time.perf_counter() - start_time
if act is None:
break
# Note: it's important to call monitor.observe() before
# env.step(), because the monitor may, for example, call
# env.render(), which outputs images of the current env
@@ -475,6 +498,7 @@
break
if monitor is not None and not exception_raised_in_step:
monitor.observe(obs, None)
cogman.finish_episode(obs)
traj = (observations, actions)
solved = env.goal_reached()
return traj, solved, metrics
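
Putting the main.py pieces together, the call pattern for a single CogMan-driven episode is roughly the following sketch; `task_idx` and the surrounding setup are assumed to come from the caller, as in `_generate_interaction_results` above.

# Sketch of one episode with the updated _run_episode signature.
cogman.reset(env.get_train_tasks()[task_idx])
(observations, actions), solved, metrics = _run_episode(
    cogman,
    env,
    "train",
    task_idx,
    max_num_steps=CFG.max_num_steps_interaction_request,
    exceptions_to_break_on={utils.EnvironmentFailure},
    monitor=None)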
3 changes: 2 additions & 1 deletion predicators/teacher.py
@@ -259,7 +259,8 @@ class TeacherInteractionMonitorWithVideo(TeacherInteractionMonitor,
"""

def observe(self, obs: Observation, action: Optional[Action]) -> None:
assert obs.allclose(self._teacher_env.get_observation())
if isinstance(obs, State):
assert obs.allclose(self._teacher_env.get_observation())
if action is not None:
self._teacher_env.step(action)
state = obs
8 changes: 7 additions & 1 deletion tests/approaches/test_active_sampler_learning_approach.py
@@ -4,10 +4,13 @@
from predicators import utils
from predicators.approaches.active_sampler_learning_approach import \
ActiveSamplerLearningApproach
from predicators.cogman import CogMan
from predicators.datasets import create_dataset
from predicators.envs.cover import BumpyCoverEnv
from predicators.execution_monitoring import create_execution_monitor
from predicators.ground_truth_models import get_gt_options
from predicators.main import _generate_interaction_results
from predicators.perception import create_perceiver
from predicators.settings import CFG
from predicators.structs import Dataset
from predicators.teacher import Teacher
@@ -60,8 +63,11 @@ def test_active_sampler_learning_approach(model_name, right_targets, num_demo):
approach.load(online_learning_cycle=None)
interaction_requests = approach.get_interaction_requests()
teacher = Teacher(train_tasks)
perceiver = create_perceiver("trivial")
exec_monitor = create_execution_monitor("trivial")
cogman = CogMan(approach, perceiver, exec_monitor)
interaction_results, _ = _generate_interaction_results(
env, teacher, interaction_requests)
cogman, env, teacher, interaction_requests)
approach.learn_from_interaction_results(interaction_results)
approach.load(online_learning_cycle=0)
with pytest.raises(FileNotFoundError):
8 changes: 7 additions & 1 deletion tests/approaches/test_bridge_policy_approach.py
@@ -12,9 +12,12 @@
from predicators.approaches import ApproachFailure, ApproachTimeout
from predicators.approaches.bridge_policy_approach import BridgePolicyApproach
from predicators.bridge_policies import BridgePolicyDone
from predicators.cogman import CogMan
from predicators.envs import get_or_create_env
from predicators.execution_monitoring import create_execution_monitor
from predicators.ground_truth_models import get_gt_options
from predicators.main import _generate_interaction_results
from predicators.perception import create_perceiver
from predicators.settings import CFG
from predicators.structs import Action, DemonstrationResponse, DummyOption, \
InteractionResult, LowLevelTrajectory, STRIPSOperator
@@ -192,8 +195,11 @@ def _mock_human_demonstratory_policy(*args, **kwargs):
m.side_effect = _mock_human_demonstratory_policy
interaction_requests = approach.get_interaction_requests()
teacher = Teacher(train_tasks)
perceiver = create_perceiver("trivial")
exec_monitor = create_execution_monitor("trivial")
cogman = CogMan(approach, perceiver, exec_monitor)
interaction_results, _ = _generate_interaction_results(
env, teacher, interaction_requests)
cogman, env, teacher, interaction_requests)
real_result = interaction_results[0]
# Add additional interaction result with no queries.
interaction_results.append(