From 716350e400d0010be2c0fcf88a575a31bbb7fb46 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 31 Jul 2024 23:14:27 +0200 Subject: [PATCH] amend --- benchmarl/algorithms/common.py | 1 + benchmarl/models/common.py | 24 +++++++++++++++++++ .../extending/model/models/custommodel.py | 8 +++++++ 3 files changed, 33 insertions(+) diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index e40e65b9..f76e57ec 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -250,6 +250,7 @@ def process_env_fun( ) -> Callable[[], EnvBase]: """ This function can be used to wrap env_fun + Args: env_fun (callable): a function that takes no args and creates an enviornment diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index f8af08fe..a4371f04 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -74,6 +74,8 @@ class Model(TensorDictModuleBase, ABC): This is independent of the other options as it is possible to have different parameters for centralized critics with global input. action_spec (CompositeSpec): The action spec of the environment + model_index (int): the index of the model in a sequence + is_critic (bool): Whether the model is a critic """ def __init__( @@ -281,6 +283,7 @@ def get_model( This is independent of the other options as it is possible to have different parameters for centralized critics with global input. action_spec (CompositeSpec): The action spec of the environment + model_index (int): the index of the model in a sequence. Defaults to 0. 
Returns: the Model @@ -310,19 +313,40 @@ def associated_class(): @property def is_rnn(self) -> bool: + """ + Whether the model is an RNN + """ return False @property def is_critic(self): + """ + Whether the model is a critic + """ if not hasattr(self, "_is_critic"): self._is_critic = False return self._is_critic @is_critic.setter def is_critic(self, value): + """ + Set whether the model is a critic + """ self._is_critic = value def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + """Get additional specs needed by the model as input. + + This method is useful for adding recurrent states. + + The returned value should be key: spec with the desired ending shape. + + The batch and agent dimensions will automatically be added to the spec. + + Args: + model_index (int, optional): the index of the model. Defaults to 0. + + """ return CompositeSpec() @staticmethod diff --git a/examples/extending/model/models/custommodel.py b/examples/extending/model/models/custommodel.py index 7dea32cb..72d54ac6 100644 --- a/examples/extending/model/models/custommodel.py +++ b/examples/extending/model/models/custommodel.py @@ -176,6 +176,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: @dataclass class CustomModelConfig(ModelConfig): + # The config parameters for this class, these will be loaded from yaml custom_param: int = MISSING activation_class: Type[nn.Module] = MISSING @@ -184,3 +185,10 @@ class CustomModelConfig(ModelConfig): def associated_class(): # The associated algorithm class return CustomModel + + @property + def is_rnn(self) -> bool: + """ + Whether the model is an RNN + """ + return False