Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 31, 2024
1 parent fa6a858 commit 716350e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def process_env_fun(
) -> Callable[[], EnvBase]:
"""
This function can be used to wrap env_fun
Args:
env_fun (callable): a function that takes no args and creates an enviornment
Expand Down
24 changes: 24 additions & 0 deletions benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class Model(TensorDictModuleBase, ABC):
This is independent of the other options as it is possible to have different parameters
for centralized critics with global input.
action_spec (CompositeSpec): The action spec of the environment
model_index (int): the index of the model in a sequence
is_critic (bool): Whether the model is a critic
"""

def __init__(
Expand Down Expand Up @@ -281,6 +283,7 @@ def get_model(
This is independent of the other options as it is possible to have different parameters
for centralized critics with global input.
action_spec (CompositeSpec): The action spec of the environment
model_index (int): the index of the model in a sequence. Defaults to 0.
Returns: the Model
Expand Down Expand Up @@ -310,19 +313,40 @@ def associated_class():

@property
def is_rnn(self) -> bool:
"""
Whether the model is an RNN
"""
return False

@property
def is_critic(self):
"""
Whether the model is a critic
"""
if not hasattr(self, "_is_critic"):
self._is_critic = False
return self._is_critic

@is_critic.setter
def is_critic(self, value):
"""
Set whether the model is a critic
"""
self._is_critic = value

def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec:
"""Get additional specs needed by the model as input.
This method is useful for adding recurrent states.
The returned value should be key: spec with the desired ending shape.
The batch and agent dimensions will automatically be added to the spec.
Args:
model_index (int, optional): the index of the model. Defaults to 0.:
"""
return CompositeSpec()

@staticmethod
Expand Down
8 changes: 8 additions & 0 deletions examples/extending/model/models/custommodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:

@dataclass
class CustomModelConfig(ModelConfig):

# The config parameters for this class, these will be loaded from yaml
custom_param: int = MISSING
activation_class: Type[nn.Module] = MISSING
Expand All @@ -184,3 +185,10 @@ class CustomModelConfig(ModelConfig):
def associated_class():
# The associated algorithm class
return CustomModel

@property
def is_rnn(self) -> bool:
"""
Whether the model is an RNN
"""
return False

0 comments on commit 716350e

Please sign in to comment.