From 2c5852cd9b47eb8b6ca1302e50135e86ba008467 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 7 Jan 2025 18:45:58 +0100 Subject: [PATCH 1/8] ensemble model --- benchmarl/models/__init__.py | 8 ++++- benchmarl/models/common.py | 53 +++++++++++++++++++++++++++++ benchmarl/utils.py | 6 ++-- examples/ensemble/ensemble_model.py | 40 ++++++++++++++++++++++ 4 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 examples/ensemble/ensemble_model.py diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py index ada32ec7..f72586bf 100644 --- a/benchmarl/models/__init__.py +++ b/benchmarl/models/__init__.py @@ -5,7 +5,13 @@ # from .cnn import Cnn, CnnConfig -from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig +from .common import ( + EnsembleModelConfig, + Model, + ModelConfig, + SequenceModel, + SequenceModelConfig, +) from .deepsets import Deepsets, DeepsetsConfig from .gnn import Gnn, GnnConfig from .gru import Gru, GruConfig diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index d17d0c1c..4cfc04aa 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -349,6 +349,11 @@ def get_model_state_spec(self, model_index: int = 0) -> Composite: """ return Composite() + def _get_model_state_spec_inner( + self, model_index: int = 0, group: str = None + ) -> Composite: + return self.get_model_state_spec(model_index) + @staticmethod def _load_from_yaml(name: str) -> Dict[str, Any]: yaml_path = ( @@ -522,3 +527,51 @@ def is_rnn(self) -> bool: @classmethod def get_from_yaml(cls, path: Optional[str] = None): raise NotImplementedError + + +@dataclass +class EnsembleModelConfig(ModelConfig): + + model_configs_map: Dict[str, ModelConfig] + + def get_model(self, agent_group: str, **kwargs) -> Model: + return self.model_configs_map[agent_group].get_model( + **kwargs, agent_group=agent_group + ) + + @staticmethod + def associated_class(): + class EnsembleModel(Model): + pass + + return EnsembleModel + + @property + def is_critic(self): + if not hasattr(self, "_is_critic"): + self._is_critic = False + return self._is_critic + + @is_critic.setter + def is_critic(self, value): + self._is_critic = value + for model_config in self.model_configs_map.values(): + model_config.is_critic = value + + def _get_model_state_spec_inner( + self, model_index: int = 0, group: str = None + ) -> Composite: + return self.model_configs_map[group].get_model_state_spec( + model_index=model_index + ) + + @property + def is_rnn(self) -> bool: + is_rnn = False + for model_config in self.model_configs_map.values(): + is_rnn += model_config.is_rnn + return is_rnn + + @classmethod + def get_from_yaml(cls, path: Optional[str] = None): + raise NotImplementedError diff --git a/benchmarl/utils.py b/benchmarl/utils.py index efe36e27..b5686ffb 100644 --- a/benchmarl/utils.py +++ b/benchmarl/utils.py @@ -80,11 +80,13 @@ def _add_rnn_transforms( def model_fun(): env = env_fun() - spec_actor = model_config.get_model_state_spec() spec_actor = Composite( { group: Composite( - spec_actor.expand(len(agents), *spec_actor.shape), + model_config._get_model_state_spec_inner(group=group).expand( + len(agents), + *model_config._get_model_state_spec_inner(group=group).shape + ), shape=(len(agents),), ) for group, agents in group_map.items() diff --git a/examples/ensemble/ensemble_model.py b/examples/ensemble/ensemble_model.py new file mode 100644 index 00000000..bee1f7b5 --- /dev/null +++ b/examples/ensemble/ensemble_model.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + + +from benchmarl.algorithms import MappoConfig +from benchmarl.environments import VmasTask +from benchmarl.experiment import Experiment, ExperimentConfig +from benchmarl.models import EnsembleModelConfig, GnnConfig, MlpConfig + + +if __name__ == "__main__": + + # Loads from "benchmarl/conf/experiment/base_experiment.yaml" + experiment_config = ExperimentConfig.get_from_yaml() + + # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml" + task = VmasTask.SIMPLE_TAG.get_from_yaml() + + # Loads from "benchmarl/conf/algorithm/mappo.yaml" + algorithm_config = MappoConfig.get_from_yaml() + + # Loads from "benchmarl/conf/model/layers/mlp.yaml" + critic_model_config = MlpConfig.get_from_yaml() + + model_config = EnsembleModelConfig( + {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()} + ) + + experiment = Experiment( + task=task, + algorithm_config=algorithm_config, + model_config=model_config, + critic_model_config=critic_model_config, + seed=0, + config=experiment_config, + ) + experiment.run() From 71d46cdea5461e3e07f18e4c96e4c2502f416ecc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 7 Jan 2025 18:58:54 +0100 Subject: [PATCH 2/8] error check --- benchmarl/models/common.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index 4cfc04aa..eb277d9f 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -426,6 +426,13 @@ class SequenceModelConfig(ModelConfig): model_configs: Sequence[ModelConfig] intermediate_sizes: Sequence[int] + def __post_init__(self): + for model_config in self.model_configs: + if isinstance(model_config, EnsembleModelConfig): + raise TypeError( + "SequenceModelConfig cannot contain EnsembleModelConfig layers, but the opposite can be done." + ) + def get_model( self, input_spec: Composite, From 441d897e47819574e080d21d7c14f4690a9b3b89 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 7 Jan 2025 19:04:34 +0100 Subject: [PATCH 3/8] error --- benchmarl/models/common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index eb277d9f..d7976c2b 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -542,6 +542,10 @@ class EnsembleModelConfig(ModelConfig): model_configs_map: Dict[str, ModelConfig] def get_model(self, agent_group: str, **kwargs) -> Model: + if agent_group not in self.model_configs_map.keys(): + raise ValueError( + f"Environment contains agent group '{agent_group}' not present in the EnsembleModelConfig configuration." + ) return self.model_configs_map[agent_group].get_model( **kwargs, agent_group=agent_group ) From 811396232fa965098b1ea475fca00e128c479d86 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 7 Jan 2025 23:00:30 +0100 Subject: [PATCH 4/8] ensemble algorithm --- benchmarl/algorithms/__init__.py | 1 + benchmarl/algorithms/ensemble.py | 128 ++++++++++++++++++ examples/ensemble/ensemble_algorithm.py | 38 ++++++ .../ensemble/ensemble_algorithm_and_model.py | 44 ++++++ 4 files changed, 211 insertions(+) create mode 100644 benchmarl/algorithms/ensemble.py create mode 100644 examples/ensemble/ensemble_algorithm.py create mode 100644 examples/ensemble/ensemble_algorithm_and_model.py diff --git a/benchmarl/algorithms/__init__.py b/benchmarl/algorithms/__init__.py index f0e2d20a..5b6d79ce 100644 --- a/benchmarl/algorithms/__init__.py +++ b/benchmarl/algorithms/__init__.py @@ -5,6 +5,7 @@ # from .common import Algorithm, AlgorithmConfig +from .ensemble import EnsembleAlgorithm, EnsembleAlgorithmConfig from .iddpg import Iddpg, IddpgConfig from .ippo import Ippo, IppoConfig from .iql import Iql, IqlConfig diff --git a/benchmarl/algorithms/ensemble.py b/benchmarl/algorithms/ensemble.py new file mode 100644 index 00000000..29577451 --- /dev/null +++ b/benchmarl/algorithms/ensemble.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + + +from dataclasses import dataclass +from typing import Dict, Iterable, Optional, Tuple, Type + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModule + +from torchrl.objectives import LossModule + +from benchmarl.algorithms.common import Algorithm, AlgorithmConfig + +from benchmarl.models.common import ModelConfig + + +class EnsembleAlgorithm(Algorithm): + def __init__(self, algorithms_map, **kwargs): + super().__init__(**kwargs) + self.algorithms_map = algorithms_map + + def _get_loss( + self, group: str, policy_for_loss: TensorDictModule, continuous: bool + ) -> Tuple[LossModule, bool]: + return self.algorithms_map[group]._get_loss(group, policy_for_loss, continuous) + + def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]: + return self.algorithms_map[group]._get_parameters(group, loss) + + def _get_policy_for_loss( + self, group: str, model_config: ModelConfig, continuous: bool + ) -> TensorDictModule: + return self.algorithms_map[group]._get_policy_for_loss( + group, model_config, continuous + ) + + def _get_policy_for_collection( + self, policy_for_loss: TensorDictModule, group: str, continuous: bool + ) -> TensorDictModule: + return self.algorithms_map[group]._get_policy_for_collection( + policy_for_loss, group, continuous + ) + + def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase: + return self.algorithms_map[group].process_batch(group, batch) + + def process_loss_vals( + self, group: str, loss_vals: TensorDictBase + ) -> TensorDictBase: + return self.algorithms_map[group].process_loss_vals(group, loss_vals) + + +@dataclass +class EnsembleAlgorithmConfig(AlgorithmConfig): + + algorithm_configs_map: Dict[str, AlgorithmConfig] + + def __post_init__(self): + algorithm_configs = list(self.algorithm_configs_map.values()) + self._on_policy = algorithm_configs[0].on_policy() + + for algorithm_config in algorithm_configs[1:]: + if algorithm_config.on_policy() != self._on_policy: + raise ValueError( + "Algorithms in EnsembleAlgorithmConfig must either be all on_policy or all off_policy" + ) + + if ( + not self.supports_discrete_actions() + and not self.supports_continuous_actions() + ): + raise ValueError( + "Ensemble algorithm does not support discrete actions nor continuous actions." + " Make sure that at least one type of action is supported across all the algorithms used." + ) + + def get_algorithm(self, experiment) -> Algorithm: + return self.associated_class()( + algorithms_map={ + group: algorithm_config.get_algorithm(experiment) + for group, algorithm_config in self.algorithm_configs_map.items() + }, + experiment=experiment, + ) + + @classmethod + def get_from_yaml(cls, path: Optional[str] = None): + raise NotImplementedError + + @staticmethod + def associated_class() -> Type[Algorithm]: + return EnsembleAlgorithm + + def on_policy(self) -> bool: + return self._on_policy + + def supports_continuous_actions(self) -> bool: + supports_continuous_actions = True + for algorithm_config in self.algorithm_configs_map.values(): + supports_continuous_actions *= ( + algorithm_config.supports_continuous_actions() + ) + return supports_continuous_actions + + def supports_discrete_actions(self) -> bool: + supports_discrete_actions = True + for algorithm_config in self.algorithm_configs_map.values(): + supports_discrete_actions *= algorithm_config.supports_discrete_actions() + return supports_discrete_actions + + def has_independent_critic(self) -> bool: + has_independent_critic = False + for algorithm_config in self.algorithm_configs_map.values(): + has_independent_critic += algorithm_config.has_independent_critic() + return has_independent_critic + + def has_centralized_critic(self) -> bool: + has_centralized_critic = False + for algorithm_config in self.algorithm_configs_map.values(): + has_centralized_critic += algorithm_config.has_centralized_critic() + return has_centralized_critic + + def has_critic(self) -> bool: + return self.has_centralized_critic() or self.has_independent_critic() diff --git a/examples/ensemble/ensemble_algorithm.py b/examples/ensemble/ensemble_algorithm.py new file mode 100644 index 00000000..b7affa7c --- /dev/null +++ b/examples/ensemble/ensemble_algorithm.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + + +from benchmarl.algorithms import EnsembleAlgorithmConfig, IsacConfig, MaddpgConfig +from benchmarl.environments import VmasTask +from benchmarl.experiment import Experiment, ExperimentConfig +from benchmarl.models import MlpConfig + + +if __name__ == "__main__": + + # Loads from "benchmarl/conf/experiment/base_experiment.yaml" + experiment_config = ExperimentConfig.get_from_yaml() + + # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml" + task = VmasTask.SIMPLE_TAG.get_from_yaml() + + # Loads from "benchmarl/conf/model/layers/mlp.yaml" + model_config = MlpConfig.get_from_yaml() + critic_model_config = MlpConfig.get_from_yaml() + + algorithm_config = EnsembleAlgorithmConfig( + {"agent": MaddpgConfig.get_from_yaml(), "adversary": IsacConfig.get_from_yaml()} + ) + + experiment = Experiment( + task=task, + algorithm_config=algorithm_config, + model_config=model_config, + critic_model_config=critic_model_config, + seed=0, + config=experiment_config, + ) + experiment.run() diff --git a/examples/ensemble/ensemble_algorithm_and_model.py b/examples/ensemble/ensemble_algorithm_and_model.py new file mode 100644 index 00000000..f097bb04 --- /dev/null +++ b/examples/ensemble/ensemble_algorithm_and_model.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + + +from benchmarl.algorithms import EnsembleAlgorithmConfig, IppoConfig, MappoConfig +from benchmarl.environments import VmasTask +from benchmarl.experiment import Experiment, ExperimentConfig +from benchmarl.models import MlpConfig +from models import DeepsetsConfig, EnsembleModelConfig, GnnConfig + +if __name__ == "__main__": + + # Loads from "benchmarl/conf/experiment/base_experiment.yaml" + experiment_config = ExperimentConfig.get_from_yaml() + + # Loads from "benchmarl/conf/task/vmas/simple_tag.yaml" + task = VmasTask.SIMPLE_TAG.get_from_yaml() + + algorithm_config = EnsembleAlgorithmConfig( + {"agent": MappoConfig.get_from_yaml(), "adversary": IppoConfig.get_from_yaml()} + ) + + model_config = EnsembleModelConfig( + {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()} + ) + critic_model_config = EnsembleModelConfig( + { + "agent": DeepsetsConfig.get_from_yaml(), + "adversary": MlpConfig.get_from_yaml(), + } + ) + + experiment = Experiment( + task=task, + algorithm_config=algorithm_config, + model_config=model_config, + critic_model_config=critic_model_config, + seed=0, + config=experiment_config, + ) + experiment.run() From 44a735e6dc7c48427a2c6665781d6c967aaa936d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 22 Jan 2025 18:09:06 +0000 Subject: [PATCH 5/8] test --- benchmarl/models/gnn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 855ba81b..c238b17a 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -466,6 +466,7 @@ class GnnConfig(ModelConfig): """Dataclass config for a :class:`~benchmarl.models.Gnn`.""" topology: str = MISSING + self_loops: bool = MISSING gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING From 985c3b0855a36e1d35f5c1b8968087d078dde5c8 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 22 Jan 2025 18:09:20 +0000 Subject: [PATCH 6/8] test --- benchmarl/models/gnn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index c238b17a..855ba81b 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -466,7 +466,6 @@ class GnnConfig(ModelConfig): """Dataclass config for a :class:`~benchmarl.models.Gnn`.""" topology: str = MISSING - self_loops: bool = MISSING gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING From e2f83223d0912ea71ada077de552063a45adb09c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 26 Jan 2025 14:06:48 +0000 Subject: [PATCH 7/8] docs --- benchmarl/algorithms/ensemble.py | 5 ++++ docs/source/concepts/features.rst | 46 +++++++++++++++++++++++++++++++ examples/ensemble/README.md | 34 +++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 examples/ensemble/README.md diff --git a/benchmarl/algorithms/ensemble.py b/benchmarl/algorithms/ensemble.py index 29577451..9e8360f9 100644 --- a/benchmarl/algorithms/ensemble.py +++ b/benchmarl/algorithms/ensemble.py @@ -79,6 +79,11 @@ def __post_init__(self): ) def get_algorithm(self, experiment) -> Algorithm: + if set(self.algorithm_configs_map.keys()) != set(experiment.group_map.keys()): + raise ValueError( + f"EnsembleAlgorithm group names {self.algorithm_configs_map.keys()} do not match " + f"environment group names {experiment.group_map.keys()}" + ) return self.associated_class()( algorithms_map={ group: algorithm_config.get_algorithm(experiment) diff --git a/docs/source/concepts/features.rst b/docs/source/concepts/features.rst index cec3ce8d..873d4524 100644 --- a/docs/source/concepts/features.rst +++ b/docs/source/concepts/features.rst @@ -106,3 +106,49 @@ as: .. python_example_button:: https://github.com/facebookresearch/BenchMARL/blob/main/examples/callback/custom_callback.py + +Ensemble models and algorithms +------------------------------ + +It is possible to use different algorithms and models for different agent groups. + +Ensemble algorithm +^^^^^^^^^^^^^^^^^^ + +Ensemble algorithms take as input a dictionary mapping group names to algorithm configs: + +.. code-block:: python + + from benchmarl.algorithms import EnsembleAlgorithmConfig, IsacConfig, MaddpgConfig + + algorithm_config = EnsembleAlgorithmConfig( + {"agent": MaddpgConfig.get_from_yaml(), "adversary": IsacConfig.get_from_yaml()} + ) + +.. note:: + All algorithms need to be on-policy or off-policy, it is not possible to mix the two paradigms. + + +.. python_example_button:: + https://github.com/facebookresearch/BenchMARL/blob/main/examples/ensemble/ensemble_algorithm.py + + +Ensemble model +^^^^^^^^^^^^^^ + +Ensemble models take as input a dictionary mapping group names to model configs: + +.. code-block:: python + + from benchmarl.models import EnsembleModelConfig, GnnConfig, MlpConfig + + model_config = EnsembleModelConfig( + {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()} + ) + + +.. note:: + If you use ensemble models with sequence models, make sure the ensemble is the outer layer (you cannot make a sequence of ensembles, but an ensemble of sequences yes). + +.. python_example_button:: + https://github.com/facebookresearch/BenchMARL/blob/main/examples/ensemble/ensemble_algorithm.py diff --git a/examples/ensemble/README.md b/examples/ensemble/README.md new file mode 100644 index 00000000..55eec026 --- /dev/null +++ b/examples/ensemble/README.md @@ -0,0 +1,34 @@ +# Different components for different groups + +It is possible to use different algorithms and models for different agent groups. + +In this folder, we provide examples on how to do this. + +## Ensemble algorithm + +Ensemble algorithms take as input a dictionary mapping group names to algorithm configs: + +```pyhton +from benchmarl.algorithms import EnsembleAlgorithmConfig, IsacConfig, MaddpgConfig + +algorithm_config = EnsembleAlgorithmConfig( + {"agent": MaddpgConfig.get_from_yaml(), "adversary": IsacConfig.get_from_yaml()} +) +``` + +**Important: All algorithms need to be on-policy or off-policy, it is not possible to mix the two paradigms.** + +## Ensemble model + +Ensemble models take as input a dictionary mapping group names to model configs: + +```pyhton +from benchmarl.models import EnsembleModelConfig, GnnConfig, MlpConfig + +model_config = EnsembleModelConfig( + {"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()} +) +``` + +**Important: if you use ensemble models with sequence models, make sure the ensemble is the outer layer (you cannot make a +sequence of ensembles, but an ensemble of sequences yes).** From afdfb4cbccd3cf535ce29f4bcb037c08dbb03b54 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 26 Jan 2025 14:11:56 +0000 Subject: [PATCH 8/8] docs --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index cdef823e..873737bd 100644 --- a/README.md +++ b/README.md @@ -442,12 +442,13 @@ a script [![Example](https://img.shields.io/badge/Example-blue.svg)](examples/co BenchMARL has several features: - A test CI with integration and training test routines that are run for all simulators and algorithms - Integration in the official TorchRL ecosystem for dedicated support +- Possibility of using different algorithms and models for different agent groups (see [`examples/ensemble`](examples/ensemble)) ### Logging BenchMARL is compatible with the [TorchRL loggers](https://github.com/pytorch/rl/tree/main/torchrl/record/loggers). -A list of logger names can be provided in the [experiment config](benchmarl/conf/experiment/base_experiment.yaml). +A list of logger names can be provided in the [experiment config])(benchmarl/conf/experiment/base_experiment.yaml. Example of available options are: `wandb`, `csv`, `mflow`, `tensorboard` or any other option available in TorchRL. You can specify the loggers in the yaml config files or in the script arguments like so: ```bash