New Agent Abstraction [Part 1 of Multi-Agent RL] (facebookresearch#1169)
* Added TrainerAgent

* VER changes

* Consolidated into Agent

* Integrated VER trainer

* Fixed CI

* Added pop play trainer

* Refactored agent access interface

* Functioning

* Pop play running fine

* Pre-commit fixes

* Updated naming

* removed multi-agent

* Removed unnecessary MA file

* Added docstring

* PR comments

* Syntax in VER

* Addressed PR comments

* CI tests

* Swap to val for tests

* Fixed test
ASzot authored Mar 14, 2023
1 parent 7cc06be commit 12be06f
Showing 17 changed files with 790 additions and 327 deletions.
3 changes: 2 additions & 1 deletion habitat-baselines/habitat_baselines/__init__.py
@@ -6,12 +6,13 @@

from habitat_baselines.common.base_il_trainer import BaseILTrainer
from habitat_baselines.common.base_trainer import BaseRLTrainer, BaseTrainer
+from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.il.trainers.eqa_cnn_pretrain_trainer import (
    EQACNNPretrainTrainer,
)
from habitat_baselines.il.trainers.pacman_trainer import PACMANTrainer
from habitat_baselines.il.trainers.vqa_trainer import VQATrainer
-from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer, RolloutStorage
+from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer
from habitat_baselines.rl.ver.ver_trainer import VERTrainer
from habitat_baselines.version import VERSION as __version__  # noqa: F401

28 changes: 28 additions & 0 deletions habitat-baselines/habitat_baselines/common/baseline_registry.py
@@ -149,6 +149,34 @@ def register_storage(cls, to_register=None, *, name: Optional[str] = None):
    def get_storage(cls, name: str):
        return cls._get_impl("storage", name)

+    @classmethod
+    def register_agent_access_mgr(
+        cls, to_register=None, *, name: Optional[str] = None
+    ):
+        """
+        Registers an agent access manager for the trainer to interface with. Usage:
+        ```
+        @baseline_registry.register_agent_access_mgr
+        class ExampleAgentAccessMgr:
+            pass
+        ```
+        or override the name with `name`.
+        ```
+        @baseline_registry.register_agent_access_mgr(name="MyAgentAccessMgr")
+        class ExampleAgentAccessMgr:
+            pass
+        ```
+        """
+        from habitat_baselines.rl.ppo.agent_access_mgr import AgentAccessMgr
+
+        return cls._register_impl(
+            "agent", to_register, name, assert_type=AgentAccessMgr
+        )
+
+    @classmethod
+    def get_agent_access_mgr(cls, name: str):
+        return cls._get_impl("agent", name)
+
    @classmethod
    def register_updater(cls, to_register=None, *, name: Optional[str] = None):
        """
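The two new class methods mirror the registry's existing storage/updater pattern. A minimal sketch of how a custom manager might be registered and later resolved by name (the subclass body here is hypothetical; only the registry calls appear in this diff):

```python
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.rl.ppo.agent_access_mgr import AgentAccessMgr


@baseline_registry.register_agent_access_mgr(name="MyAgentAccessMgr")
class MyAgentAccessMgr(AgentAccessMgr):
    # A real implementation must fill in every abstract method of
    # AgentAccessMgr (see the new file later in this commit).
    ...


# e.g. inside a trainer, resolve the class from a config string:
agent_cls = baseline_registry.get_agent_access_mgr("MyAgentAccessMgr")
```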
17 changes: 17 additions & 0 deletions habitat-baselines/habitat_baselines/common/env_spec.py
@@ -0,0 +1,17 @@
import attr
import gym.spaces as spaces


@attr.s(auto_attribs=True, slots=True)
class EnvironmentSpec:
    """
    Stores information about the spaces of an environment.
    :property observation_space: Observation space of the environment.
    :property action_space: The potentially flattened version of the environment action space.
    :property orig_action_space: The non-flattened version of the environment action space.
    """

    observation_space: spaces.Space
    action_space: spaces.Space
    orig_action_space: spaces.Space
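As a quick illustration, constructing the spec directly (the spaces below are made up for the example; in the trainer they come from the vectorized environments):

```python
import gym.spaces as spaces
import numpy as np

from habitat_baselines.common.env_spec import EnvironmentSpec

env_spec = EnvironmentSpec(
    observation_space=spaces.Dict(
        {"rgb": spaces.Box(low=0, high=255, shape=(256, 256, 3), dtype=np.uint8)}
    ),
    action_space=spaces.Discrete(4),       # possibly flattened for the policy
    orig_action_space=spaces.Discrete(4),  # already flat here, so identical
)
```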
15 changes: 14 additions & 1 deletion habitat-baselines/habitat_baselines/common/rollout_storage.py
@@ -11,6 +11,7 @@
import torch

from habitat_baselines.common.baseline_registry import baseline_registry
+from habitat_baselines.common.storage import Storage
from habitat_baselines.common.tensor_dict import DictTree, TensorDict
from habitat_baselines.rl.models.rnn_state_encoder import (
    build_pack_info_from_dones,
@@ -20,7 +21,7 @@


@baseline_registry.register_storage
-class RolloutStorage:
+class RolloutStorage(Storage):
    r"""Class for storing rollout information for RL trainers."""

    def __init__(
@@ -258,3 +259,15 @@ def __getstate__(self) -> Dict[str, Any]:

    def __setstate__(self, state: Dict[str, Any]):
        self.__dict__.update(state)
+
+    def insert_first_observations(self, batch):
+        self.buffers["observations"][0] = batch  # type: ignore
+
+    def get_current_step(self, env_slice, buffer_index):
+        return self.buffers[
+            self.current_rollout_step_idxs[buffer_index],
+            env_slice,
+        ]
+
+    def get_last_step(self):
+        return self.buffers[self.current_rollout_step_idx]
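These accessors give trainers a uniform way to read the buffers without reaching into `buffers` directly. A hedged sketch of the intended call pattern; the helper name and its arguments are illustrative, not part of this commit:

```python
def seed_and_read(rollouts, batch, num_envs: int):
    # Seed step 0 with the observations from the environment reset.
    rollouts.insert_first_observations(batch)

    # While collecting, read the in-flight step for a slice of environments.
    step_data = rollouts.get_current_step(
        env_slice=slice(0, num_envs), buffer_index=0
    )

    # After collection, grab the final step, e.g. to bootstrap value estimates.
    return step_data, rollouts.get_last_step()
```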
52 changes: 52 additions & 0 deletions habitat-baselines/habitat_baselines/common/storage.py
@@ -0,0 +1,52 @@
import abc

import torch


class Storage(abc.ABC):
    """
    Storage interface.
    """

    @abc.abstractmethod
    def insert(
        self,
        next_observations=None,
        next_recurrent_hidden_states=None,
        actions=None,
        action_log_probs=None,
        value_preds=None,
        rewards=None,
        next_masks=None,
        buffer_index: int = 0,
        **kwargs,
    ):
        pass

    @abc.abstractmethod
    def to(self, device) -> None:
        pass

    @abc.abstractmethod
    def insert_first_observations(self, batch) -> None:
        pass

    @abc.abstractmethod
    def advance_rollout(self, buffer_index: int = 0) -> None:
        pass

    @abc.abstractmethod
    def compute_returns(
        self, next_value: torch.Tensor, use_gae: bool, gamma: float, tau: float
    ) -> None:
        pass

    @abc.abstractmethod
    def after_update(self) -> None:
        pass

    def get_last_step(self):
        pass

    def get_current_step(self, env_slice, buffer_index):
        pass
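For reference, a toy subclass showing the minimum surface a concrete storage must cover; everything beyond the method names is an illustrative placeholder, not how `RolloutStorage` works:

```python
import torch

from habitat_baselines.common.storage import Storage


class ListStorage(Storage):
    """Toy storage that appends transitions to a Python list."""

    def __init__(self) -> None:
        self._steps: list = []

    def insert(self, buffer_index: int = 0, **kwargs) -> None:
        self._steps.append(kwargs)

    def to(self, device) -> None:
        pass  # nothing tensor-backed to move in this toy example

    def insert_first_observations(self, batch) -> None:
        self._steps.append({"observations": batch})

    def advance_rollout(self, buffer_index: int = 0) -> None:
        pass

    def compute_returns(
        self, next_value: torch.Tensor, use_gae: bool, gamma: float, tau: float
    ) -> None:
        pass  # a real storage computes (GAE) returns here

    def after_update(self) -> None:
        self._steps.clear()
```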
@@ -355,10 +355,16 @@ class DDPPOConfig(HabitatBaselinesBaseConfig):
    force_distributed: bool = False


+@dataclass
+class AgentAccessMgrConfig(HabitatBaselinesBaseConfig):
+    type: str = "SingleAgentAccessMgr"
+
+
@dataclass
class RLConfig(HabitatBaselinesBaseConfig):
    """Reinforcement learning config"""

+    agent: AgentAccessMgrConfig = AgentAccessMgrConfig()
    preemption: PreemptionConfig = PreemptionConfig()
    policy: PolicyConfig = PolicyConfig()
    ppo: PPOConfig = PPOConfig()
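Since the new field defaults to `"SingleAgentAccessMgr"`, swapping in a custom manager is a one-field override. A sketch using OmegaConf directly, assuming the dataclass lives in `habitat_baselines.config.default_structured_configs` (the override value is hypothetical):

```python
from omegaconf import OmegaConf

from habitat_baselines.config.default_structured_configs import (
    AgentAccessMgrConfig,
)

cfg = OmegaConf.structured(AgentAccessMgrConfig)
print(cfg.type)  # "SingleAgentAccessMgr"

# Point the trainer at a manager registered under a different name.
cfg.type = "MyAgentAccessMgr"
```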
@@ -284,7 +284,7 @@ def act(
        # If any skills want to terminate, invoke the high-level policy to get
        # the next skill.
        hl_terminate = torch.zeros(self._num_envs, dtype=torch.bool)
-        hl_info: Dict[str, Any] = {}
+        hl_info: Dict[str, Any] = self._high_level_policy.create_hl_info()
        if call_high_level.sum() > 0:
            (
                new_skills,
@@ -90,6 +90,9 @@ def get_next_skill(
"""
raise NotImplementedError()

def create_hl_info(self) -> Dict[str, Any]:
return {}

def apply_mask(self, mask: torch.Tensor) -> None:
"""
Called before every step with the mask information at the current step.
@@ -81,6 +81,9 @@ def __init__(self, *args, **kwargs):
        self._policy = CategoricalNet(self._hidden_size, self._n_actions)
        self._critic = CriticHead(self._hidden_size)

+    def create_hl_info(self):
+        return {"actions": None}
+
    def _setup_actions(self) -> List[PddlAction]:
        all_actions = self._pddl_prob.get_possible_actions()
        all_actions = [
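Read together, the two hunks above show the hook's contract: the base class returns an empty dict, while a policy that emits per-step outputs (here, `actions`) returns matching default keys so `hl_info` is well-formed even on steps where no skill terminates. A hypothetical override following the same pattern:

```python
from typing import Any, Dict


class MyHighLevelPolicy:  # would subclass the high-level policy base class
    def create_hl_info(self) -> Dict[str, Any]:
        # Keys should mirror whatever get_next_skill later fills in;
        # "skill_id" is a made-up example field.
        return {"actions": None, "skill_id": None}
```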
139 changes: 139 additions & 0 deletions habitat-baselines/habitat_baselines/rl/ppo/agent_access_mgr.py
@@ -0,0 +1,139 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import gym.spaces as spaces

from habitat_baselines.common.env_spec import EnvironmentSpec
from habitat_baselines.common.storage import Storage
from habitat_baselines.rl.ppo.policy import Policy
from habitat_baselines.rl.ppo.updater import Updater

if TYPE_CHECKING:
    from omegaconf import DictConfig


class AgentAccessMgr(ABC):
    """
    Consists of:
    - Policy: How actions are selected from observations.
    - Data Storage: How data collected from the environment is stored.
    - Updater: How the Policy is updated.
    """

    @abstractmethod
    def __init__(
        self,
        config: "DictConfig",
        env_spec: EnvironmentSpec,
        is_distrib: bool,
        device,
        resume_state: Optional[Dict[str, Any]],
        num_envs: int,
        percent_done_fn: Callable[[], float],
        lr_schedule_fn: Optional[Callable[[float], float]] = None,
    ):
        pass

    @property
    @abstractmethod
    def nbuffers(self) -> int:
        """
        Number of storage buffers.
        """
        raise NotImplementedError()

    @abstractmethod
    def post_init(self, create_rollouts_fn: Optional[Callable] = None) -> None:
        """
        Called after the constructor. Sets up the rollout storage.
        :param create_rollouts_fn: Overrides the behavior for creating the
            rollout storage. The default behavior and expected call signature
            are given by `default_create_rollouts`.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def policy_action_space(self) -> spaces.Space:
        """
        The action space the policy acts in. This can be different from the
        environment action space for hierarchical policies.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def rollouts(self) -> Storage:
        """
        Gets the current rollout storage.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def actor_critic(self) -> Policy:
        """
        Gets the current policy.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def updater(self) -> Updater:
        """
        Gets the current policy updater.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_resume_state(self) -> Dict[str, Any]:
        raise NotImplementedError()

    @abstractmethod
    def get_save_state(self) -> Dict[str, Any]:
        raise NotImplementedError()

    @abstractmethod
    def eval(self) -> None:
        raise NotImplementedError()

    @abstractmethod
    def train(self) -> None:
        raise NotImplementedError()

    @abstractmethod
    def load_ckpt_state_dict(self, ckpt: Dict) -> None:
        """
        Loads a state dict for evaluation. The difference from
        `load_state_dict` is that this will not load the policy state if the
        policy does not request it.
        """
        raise NotImplementedError()

    @abstractmethod
    def load_state_dict(self, state: Dict) -> None:
        raise NotImplementedError()

    @property
    @abstractmethod
    def hidden_state_shape(self) -> Tuple[int]:
        """
        The shape of the tensor used to track the hidden state, such as the
        RNN hidden state.
        """
        raise NotImplementedError()

    @abstractmethod
    def after_update(self) -> None:
        """
        Must be called by the trainer after the updater's `update` and the
        rollout storage's `after_update` have been called.
        """
        raise NotImplementedError()

    @abstractmethod
    def pre_rollout(self) -> None:
        """
        Called before a rollout is collected.
        """
        raise NotImplementedError()
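From the trainer's side, the abstraction reduces the training loop to a handful of calls. A condensed, hedged sketch built only from the interface above (`agent` is any concrete `AgentAccessMgr`; the updater's `update` signature is assumed, and rollout collection is elided):

```python
def train_loop(agent, total_updates: int) -> None:
    """Hypothetical trainer loop driven purely through AgentAccessMgr."""
    agent.post_init()  # sets up agent.rollouts
    agent.train()
    for _ in range(total_updates):
        agent.pre_rollout()
        # ... act with agent.actor_critic and write each step into
        # agent.rollouts via Storage.insert / advance_rollout ...
        agent.updater.update(agent.rollouts)
        agent.rollouts.after_update()
        agent.after_update()  # must follow both calls above (see docstring)
```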
9 changes: 8 additions & 1 deletion habitat-baselines/habitat_baselines/rl/ppo/policy.py
@@ -120,7 +120,14 @@ def all_policy_tensors(self) -> Iterable[torch.Tensor]:
        for c in self._get_policy_components():
            yield from c.buffers()

-    def extract_policy_info(
+    def get_value(
+        self, observations, rnn_hidden_states, prev_actions, masks
+    ) -> torch.Tensor:
+        raise NotImplementedError(
+            "Get value is not supported in non-neural network policies."
+        )
+
+    def get_extra(
        self, action_data: PolicyActionData, infos, dones
    ) -> List[Dict[str, float]]:
        """
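Policies backed by a critic are expected to override `get_value`; the base class now fails loudly instead. A bare sketch of the override shape (the zero-valued critic is a stand-in, not real behavior):

```python
import torch

from habitat_baselines.rl.ppo.policy import Policy


class ZeroValuePolicy(Policy):
    def get_value(
        self, observations, rnn_hidden_states, prev_actions, masks
    ) -> torch.Tensor:
        # A real policy runs its critic head; return a zero per environment.
        return torch.zeros(masks.shape[0], 1)
```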
3 changes: 2 additions & 1 deletion habitat-baselines/habitat_baselines/rl/ppo/ppo.py
@@ -18,6 +18,7 @@
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.rl.ppo.policy import NetPolicy
+from habitat_baselines.rl.ppo.updater import Updater
from habitat_baselines.rl.ver.ver_rollout_storage import VERRolloutStorage
from habitat_baselines.utils.common import (
    LagrangeInequalityCoefficient,
@@ -28,7 +29,7 @@


@baseline_registry.register_updater
-class PPO(nn.Module):
+class PPO(nn.Module, Updater):
    entropy_coef: Union[float, LagrangeInequalityCoefficient]

    @classmethod
@classmethod
Expand Down