From 286061c046e517d3c21184f557bb96ec4a7051b6 Mon Sep 17 00:00:00 2001
From: Lorenz Wellhausen
Date: Tue, 23 Jan 2024 19:28:46 +0100
Subject: [PATCH 1/4] Generalized reward manager to multiple groups

---
 .../isaac/orbit/managers/manager_term_cfg.py |   6 +
 .../isaac/orbit/managers/reward_manager.py   | 219 ++++++++++++------
 2 files changed, 158 insertions(+), 67 deletions(-)

diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
index d6602b252f..fd74b5c6e3 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
@@ -199,6 +199,12 @@ class RandomizationTermCfg(ManagerTermBaseCfg):
 # Reward manager.
 ##
 
+@configclass
+class RewardGroupCfg:
+    # Reserved for future use.
+    # No parameters, yet.
+    pass
+
 
 @configclass
 class RewardTermCfg(ManagerTermBaseCfg):
diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
index 15b5b30e12..8d25fcda03 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
@@ -12,12 +12,15 @@
 from typing import TYPE_CHECKING, Sequence
 
 from .manager_base import ManagerBase, ManagerTermBase
-from .manager_term_cfg import RewardTermCfg
+from .manager_term_cfg import RewardTermCfg, RewardGroupCfg
 
 if TYPE_CHECKING:
     from omni.isaac.orbit.envs import RLTaskEnv
 
 
+DEFAULT_GROUP_NAME = "rewards"
+
+
 class RewardManager(ManagerBase):
     """Manager for computing reward signals for a given world.
 
@@ -26,7 +29,11 @@ class RewardManager(ManagerBase):
     terms configuration.
 
     The reward terms are parsed from a config class containing the manager's settings and each term's
-    parameters. Each reward term should instantiate the :class:`RewardTermCfg` class.
+    parameters.
+
+    Rewards are organized into groups, e.g. for multi-critic or CMDP use cases.
+    Each reward group should inherit from the :class:`RewardGroupCfg` class.
+    Within each group, each reward term should inherit from the :class:`RewardTermCfg` class.
 
     .. note::
 
@@ -36,6 +43,7 @@ class RewardManager(ManagerBase):
 
     """
 
+
     _env: RLTaskEnv
     """The environment instance."""
 
@@ -43,34 +51,50 @@ def __init__(self, cfg: object, env: RLTaskEnv):
         """Initialize the reward manager.
 
         Args:
-            cfg: The configuration object or dictionary (``dict[str, RewardTermCfg]``).
+            cfg: The configuration object or dictionary (``dict[str, RewardGroupCfg]``).
             env: The environment instance.
         """
+        # Variable to track whether we have reward groups or not.
+        # Needs to be set before we call super().__init__ because it's needed in _prepare_terms.
+        self.no_group = None
         super().__init__(cfg, env)
-        # prepare extra info to store individual reward term information
+
+        # Allocate storage for reward terms.
self._episode_sums = dict() - for term_name in self._term_names: - self._episode_sums[term_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) - # create buffer for managing reward per environment - self._reward_buf = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) + self._reward_buf = {} + self._term_names_flat = [] # flat list of all term names + for group_name, group_term_names in self._group_term_names.items(): + for term_name in group_term_names: + sum_term_name = term_name if self.no_group else f"{group_name}/{term_name}" + self._episode_sums[sum_term_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) + + self._term_names_flat.append(sum_term_name) + + # create buffer for managing reward per environment + self._reward_buf[group_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) + + def __str__(self) -> str: """Returns: A string representation for reward manager.""" - msg = f" contains {len(self._term_names)} active terms.\n" + # Get number of reward terms. + msg = f" contains {len(self._term_names_flat)} active terms.\n" # create table for term information - table = PrettyTable() - table.title = "Active Reward Terms" - table.field_names = ["Index", "Name", "Weight"] - # set alignment of table columns - table.align["Name"] = "l" - table.align["Weight"] = "r" - # add info on each term - for index, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)): - table.add_row([index, name, term_cfg.weight]) - # convert table to string - msg += table.get_string() - msg += "\n" + for group_name in self._group_term_names.keys(): + table = PrettyTable() + table.title = "Active Reward Terms In Group: " + group_name + table.field_names = ["Index", "Name", "Weight"] + # set alignment of table columns + table.align["Name"] = "l" + table.align["Weight"] = "r" + # add info on each term + for index, (name, term_cfg) in enumerate(zip(self._group_term_names[group_name], + self._group_term_cfgs[group_name])): + table.add_row([index, name, term_cfg.weight]) + # convert table to string + msg += table.get_string() + msg += "\n" return msg @@ -81,7 +105,7 @@ def __str__(self) -> str: @property def active_terms(self) -> list[str]: """Name of active reward terms.""" - return self._term_names + return self._term_names_flat """ Operations. @@ -110,8 +134,9 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor] # reset episodic sum self._episode_sums[key][env_ids] = 0.0 # reset all the reward terms - for term_cfg in self._class_term_cfgs: - term_cfg.func.reset(env_ids=env_ids) + for group_cfg in self._group_class_term_cfgs.values(): + for term_cfg in group_cfg: + term_cfg.func.reset(env_ids=env_ids) # return logged information return extras @@ -128,20 +153,28 @@ def compute(self, dt: float) -> torch.Tensor: The net reward signal of shape (num_envs,). 
""" # reset computation - self._reward_buf[:] = 0.0 - # iterate over all the reward terms - for name, term_cfg in zip(self._term_names, self._term_cfgs): - # skip if weight is zero (kind of a micro-optimization) - if term_cfg.weight == 0.0: - continue - # compute term's value - value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt - # update total reward - self._reward_buf += value - # update episodic sum - self._episode_sums[name] += value - - return self._reward_buf + for key in self._reward_buf.keys(): + self._reward_buf[key][:] = 0.0 + # iterate over all reward terms of all groups + for group_name in self._group_term_names.keys(): + # iterate over all the reward terms + for term_name, term_cfg in zip(self._group_term_names[group_name], self._group_term_cfgs[group_name]): + # skip if weight is zero (kind of a micro-optimization) + if term_cfg.weight == 0.0: + continue + # compute term's value + value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt + # update total reward + self._reward_buf[group_name] += value + # update episodic sum + name = term_name if self.no_group else f"{group_name}/{term_name}" + self._episode_sums[name] += value + + # Return only Tensor if config has no groups. + if self.no_group: + return self._reward_buf[DEFAULT_GROUP_NAME] + else: + return self._reward_buf """ Operations - Term settings. @@ -157,10 +190,19 @@ def set_term_cfg(self, term_name: str, cfg: RewardTermCfg): Raises: ValueError: If the term name is not found. """ - if term_name not in self._term_names: + # Split term_name at '/' if it has one. + if "/" in term_name: + group_name, term_name = term_name.split("/") + else: + group_name = DEFAULT_GROUP_NAME + + if group_name not in self._group_term_names: + raise ValueError(f"Reward group '{group_name}' not found.") + if term_name not in self._group_term_names[group_name]: raise ValueError(f"Reward term '{term_name}' not found.") # set the configuration - self._term_cfgs[self._term_names.index(term_name)] = cfg + self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] = cfg + def get_term_cfg(self, term_name: str) -> RewardTermCfg: """Gets the configuration for the specified term. @@ -174,10 +216,19 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg: Raises: ValueError: If the term name is not found. """ - if term_name not in self._term_names: + # Split term_name at '/' if it has one. + if "/" in term_name: + group_name, term_name = term_name.split("/") + else: + group_name = DEFAULT_GROUP_NAME + + if group_name not in self._group_term_names: + raise ValueError(f"Reward group '{group_name}' not found.") + if term_name not in self._group_term_names[group_name]: raise ValueError(f"Reward term '{term_name}' not found.") # return the configuration - return self._term_cfgs[self._term_names.index(term_name)] + return self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] + """ Helper functions. 
@@ -185,38 +236,72 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg:
 
     def _prepare_terms(self):
         """Prepares a list of reward functions."""
-        # parse remaining reward terms and decimate their information
-        self._term_names: list[str] = list()
-        self._term_cfgs: list[RewardTermCfg] = list()
-        self._class_term_cfgs: list[RewardTermCfg] = list()
+
+        self._group_term_names: dict[str, list[str]] = dict()
+        self._group_term_cfgs: dict[str, list[RewardTermCfg]] = dict()
+        self._group_class_term_cfgs: dict[str, list[RewardTermCfg]] = dict()
         # check if config is dict already
         if isinstance(self.cfg, dict):
             cfg_items = self.cfg.items()
         else:
             cfg_items = self.cfg.__dict__.items()
-        # iterate over all the terms
-        for term_name, term_cfg in cfg_items:
+
+        # Check whether we have a group or not and fail if we have a mix.
+        for name, cfg in cfg_items:
             # check for non config
-            if term_cfg is None:
+            if cfg is None:
                 continue
-            # check for valid config type
-            if not isinstance(term_cfg, RewardTermCfg):
+            if isinstance(cfg, RewardGroupCfg):
+                if self.no_group is None:
+                    self.no_group = False
+                elif self.no_group is True:
+                    raise ValueError("Cannot mix reward groups with reward terms.")
+            elif isinstance(cfg, RewardTermCfg):
+                if self.no_group is None:
+                    self.no_group = True
+                elif self.no_group is False:
+                    raise ValueError("Cannot mix reward groups with reward terms.")
+            else:
                 raise TypeError(
-                    f"Configuration for the term '{term_name}' is not of type RewardTermCfg."
-                    f" Received: '{type(term_cfg)}'."
+                    f"Configuration for the group or term '{name}' is not of type RewardGroupCfg or RewardTermCfg."
+                    f" Received: '{type(cfg)}'."
                 )
-            # check for valid weight type
-            if not isinstance(term_cfg.weight, (float, int)):
-                raise TypeError(
-                    f"Weight for the term '{term_name}' is not of type float or int."
-                    f" Received: '{type(term_cfg.weight)}'."
-                )
-            # resolve common parameters
-            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
-            # add function to list
-            self._term_names.append(term_name)
-            self._term_cfgs.append(term_cfg)
-            # check if the term is a class
-            if isinstance(term_cfg.func, ManagerTermBase):
-                self._class_term_cfgs.append(term_cfg)
+
+        # Make a group if we do not have one.
+        if self.no_group:
+            cfg_items = {DEFAULT_GROUP_NAME: dict(cfg_items)}.items()
+
+        # iterate over all the groups
+        for group_name, group_cfg in cfg_items:
+            self._group_term_names[group_name] = list()
+            self._group_term_cfgs[group_name] = list()
+            self._group_class_term_cfgs[group_name] = list()
+
+            # Iterate over all the terms in the group
+            for term_name, term_cfg in group_cfg.items():
+                # check for non config
+                if term_cfg is None:
+                    continue
+                # check for valid config type
+                if not isinstance(term_cfg, RewardTermCfg):
+                    raise TypeError(
+                        f"Configuration for the term '{term_name}' is not of type RewardTermCfg."
+                        f" Received: '{type(term_cfg)}'."
+                    )
+                # check for valid weight type
+                if not isinstance(term_cfg.weight, (float, int)):
+                    raise TypeError(
+                        f"Weight for the term '{term_name}' is not of type float or int."
+                        f" Received: '{type(term_cfg.weight)}'."
+                    )
+                # resolve common terms in the config
+                self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=1)
+                # add term config to list
+                self._group_term_names[group_name].append(term_name)
+                self._group_term_cfgs[group_name].append(term_cfg)
+                # add term to separate list if term is a class
+                if isinstance(term_cfg.func, ManagerTermBase):
+                    self._group_class_term_cfgs[group_name].append(term_cfg)
+                    # call reset on the term
+                    term_cfg.func.reset()

From 8d121b25509e3827ae784fd1ff3c2be78342d5ba Mon Sep 17 00:00:00 2001
From: Lorenz Wellhausen
Date: Thu, 25 Jan 2024 10:39:48 +0100
Subject: [PATCH 2/4] Changed default reward group name

---
 .../omni/isaac/orbit/managers/reward_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
index 8d25fcda03..91f0599512 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
@@ -18,7 +18,7 @@
     from omni.isaac.orbit.envs import RLTaskEnv
 
 
-DEFAULT_GROUP_NAME = "rewards"
+DEFAULT_GROUP_NAME = "reward"
 
 
 class RewardManager(ManagerBase):

From 51491668945bec8573f89443ed7460a36ea20c1a Mon Sep 17 00:00:00 2001
From: Lorenz Wellhausen
Date: Thu, 25 Jan 2024 19:49:09 +0100
Subject: [PATCH 3/4] Small fix in reward manager and formatting

---
 .../isaac/orbit/managers/reward_manager.py | 74 ++++++++++++-------
 1 file changed, 46 insertions(+), 28 deletions(-)

diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
index 91f0599512..2aad38e0d2 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py
@@ -43,7 +43,6 @@ class RewardManager(ManagerBase):
 
     """
 
-
     _env: RLTaskEnv
     """The environment instance."""
 
@@ -65,15 +64,19 @@ def __init__(self, cfg: object, env: RLTaskEnv):
         self._term_names_flat = []  # flat list of all term names
         for group_name, group_term_names in self._group_term_names.items():
             for term_name in group_term_names:
-                sum_term_name = term_name if self.no_group else f"{group_name}/{term_name}"
-                self._episode_sums[sum_term_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
+                sum_term_name = (
+                    term_name if self.no_group else f"{group_name}/{term_name}"
+                )
+                self._episode_sums[sum_term_name] = torch.zeros(
+                    self.num_envs, dtype=torch.float, device=self.device
+                )
 
                 self._term_names_flat.append(sum_term_name)
 
             # create buffer for managing reward per environment
-            self._reward_buf[group_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
-
-
+            self._reward_buf[group_name] = torch.zeros(
+                self.num_envs, dtype=torch.float, device=self.device
+            )
 
     def __str__(self) -> str:
         """Returns: A string representation for reward manager."""
@@ -89,8 +92,12 @@ def __str__(self) -> str:
             table.align["Name"] = "l"
             table.align["Weight"] = "r"
             # add info on each term
-            for index, (name, term_cfg) in enumerate(zip(self._group_term_names[group_name],
-                                                         self._group_term_cfgs[group_name])):
+            for index, (name, term_cfg) in enumerate(
+                zip(
+                    self._group_term_names[group_name],
+                    self._group_term_cfgs[group_name],
+                )
+            ):
                 table.add_row([index, name, term_cfg.weight])
             # convert table to string
             msg += table.get_string()
@@ -130,7 +137,9 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
         # store information
         # r_1 + r_2 + ... + r_n
         episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
-        extras["Episode Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s
+        extras["Episode Reward/" + key] = (
+            episodic_sum_avg / self._env.max_episode_length_s
+        )
         # reset episodic sum
         self._episode_sums[key][env_ids] = 0.0
         # reset all the reward terms
@@ -158,18 +167,22 @@ def compute(self, dt: float) -> torch.Tensor:
         # iterate over all reward terms of all groups
         for group_name in self._group_term_names.keys():
             # iterate over all the reward terms
-            for term_name, term_cfg in zip(self._group_term_names[group_name], self._group_term_cfgs[group_name]):
+            for term_name, term_cfg in zip(
+                self._group_term_names[group_name], self._group_term_cfgs[group_name]
+            ):
                 # skip if weight is zero (kind of a micro-optimization)
                 if term_cfg.weight == 0.0:
                     continue
                 # compute term's value
-                value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt
+                value = (
+                    term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt
+                )
                 # update total reward
                 self._reward_buf[group_name] += value
                 # update episodic sum
                 name = term_name if self.no_group else f"{group_name}/{term_name}"
                 self._episode_sums[name] += value
-
+
         # Return only a tensor if the config has no groups.
         if self.no_group:
             return self._reward_buf[DEFAULT_GROUP_NAME]
@@ -201,8 +214,9 @@ def set_term_cfg(self, term_name: str, cfg: RewardTermCfg):
         if term_name not in self._group_term_names[group_name]:
             raise ValueError(f"Reward term '{term_name}' not found.")
         # set the configuration
-        self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] = cfg
-
+        self._group_term_cfgs[group_name][
+            self._group_term_names[group_name].index(term_name)
+        ] = cfg
 
     def get_term_cfg(self, term_name: str) -> RewardTermCfg:
         """Gets the configuration for the specified term.
@@ -227,8 +241,9 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg:
         if term_name not in self._group_term_names[group_name]:
             raise ValueError(f"Reward term '{term_name}' not found.")
         # return the configuration
-        return self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)]
-
+        return self._group_term_cfgs[group_name][
+            self._group_term_names[group_name].index(term_name)
+        ]
 
     """
     Helper functions.
@@ -252,21 +267,16 @@ def _prepare_terms(self):
             # check for non config
             if cfg is None:
                 continue
-            if isinstance(cfg, RewardGroupCfg):
-                if self.no_group is None:
-                    self.no_group = False
-                elif self.no_group is True:
-                    raise ValueError("Cannot mix reward groups with reward terms.")
-            elif isinstance(cfg, RewardTermCfg):
+            if isinstance(cfg, RewardTermCfg):
                 if self.no_group is None:
                     self.no_group = True
                 elif self.no_group is False:
                     raise ValueError("Cannot mix reward groups with reward terms.")
             else:
-                raise TypeError(
-                    f"Configuration for the group or term '{name}' is not of type RewardGroupCfg or RewardTermCfg."
-                    f" Received: '{type(cfg)}'."
-                )
+                if self.no_group is None:
+                    self.no_group = False
+                elif self.no_group is True:
+                    raise ValueError("Cannot mix reward groups with reward terms.")
 
         # Make a group if we do not have one.
         if self.no_group:
@@ -278,8 +288,14 @@ def _prepare_terms(self):
             self._group_term_cfgs[group_name] = list()
             self._group_class_term_cfgs[group_name] = list()
 
+            # Get the group's term configs, whether the group is a dict or a config object.
+            if isinstance(group_cfg, dict):
+                group_cfg_items = group_cfg.items()
+            else:
+                group_cfg_items = group_cfg.__dict__.items()
+
             # Iterate over all the terms in the group
-            for term_name, term_cfg in group_cfg.items():
+            for term_name, term_cfg in group_cfg_items:
                 # check for non config
                 if term_cfg is None:
                     continue
@@ -296,7 +312,9 @@ def _prepare_terms(self):
                         f" Received: '{type(term_cfg.weight)}'."
                     )
                 # resolve common terms in the config
-                self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=1)
+                self._resolve_common_term_cfg(
+                    f"{group_name}/{term_name}", term_cfg, min_argc=1
+                )
                 # add term config to list
                 self._group_term_names[group_name].append(term_name)
                 self._group_term_cfgs[group_name].append(term_cfg)

From ea5055d4b387fbf8767639acfb41d73baf885930 Mon Sep 17 00:00:00 2001
From: Lorenz Wellhausen
Date: Mon, 29 Jan 2024 17:54:39 +0100
Subject: [PATCH 4/4] Add contributor and format

---
 CONTRIBUTORS.md                              |  1 +
 .../isaac/orbit/managers/manager_term_cfg.py |  1 +
 .../isaac/orbit/managers/reward_manager.py   | 38 +++++--------------
 3 files changed, 12 insertions(+), 28 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 113f5863c0..f9d44c90bc 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -37,6 +37,7 @@ Guidelines for modifications:
 * Chenyu Yang
 * Jia Lin Yuan
 * Jingzhou Liu
 * Kourosh Darvish
+* Lorenz Wellhausen
 * Qinxi Yu
 * René Zurbrügg
diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
index fd74b5c6e3..fbe5edd9fd 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/manager_term_cfg.py
@@ -199,6 +199,7 @@ class RandomizationTermCfg(ManagerTermBaseCfg):
 # Reward manager.
 ##
 
+
 @configclass
 class RewardGroupCfg:
     # Reserved for future use.
diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py index 2aad38e0d2..01efce0a9b 100644 --- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py +++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/managers/reward_manager.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Sequence from .manager_base import ManagerBase, ManagerTermBase -from .manager_term_cfg import RewardTermCfg, RewardGroupCfg +from .manager_term_cfg import RewardGroupCfg, RewardTermCfg if TYPE_CHECKING: from omni.isaac.orbit.envs import RLTaskEnv @@ -64,19 +64,13 @@ def __init__(self, cfg: object, env: RLTaskEnv): self._term_names_flat = [] # flat list of all term names for group_name, group_term_names in self._group_term_names.items(): for term_name in group_term_names: - sum_term_name = ( - term_name if self.no_group else f"{group_name}/{term_name}" - ) - self._episode_sums[sum_term_name] = torch.zeros( - self.num_envs, dtype=torch.float, device=self.device - ) + sum_term_name = term_name if self.no_group else f"{group_name}/{term_name}" + self._episode_sums[sum_term_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) self._term_names_flat.append(sum_term_name) # create buffer for managing reward per environment - self._reward_buf[group_name] = torch.zeros( - self.num_envs, dtype=torch.float, device=self.device - ) + self._reward_buf[group_name] = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) def __str__(self) -> str: """Returns: A string representation for reward manager.""" @@ -137,9 +131,7 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor] # store information # r_1 + r_2 + ... + r_n episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids]) - extras["Episode Reward/" + key] = ( - episodic_sum_avg / self._env.max_episode_length_s - ) + extras["Episode Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s # reset episodic sum self._episode_sums[key][env_ids] = 0.0 # reset all the reward terms @@ -167,16 +159,12 @@ def compute(self, dt: float) -> torch.Tensor: # iterate over all reward terms of all groups for group_name in self._group_term_names.keys(): # iterate over all the reward terms - for term_name, term_cfg in zip( - self._group_term_names[group_name], self._group_term_cfgs[group_name] - ): + for term_name, term_cfg in zip(self._group_term_names[group_name], self._group_term_cfgs[group_name]): # skip if weight is zero (kind of a micro-optimization) if term_cfg.weight == 0.0: continue # compute term's value - value = ( - term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt - ) + value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt # update total reward self._reward_buf[group_name] += value # update episodic sum @@ -214,9 +202,7 @@ def set_term_cfg(self, term_name: str, cfg: RewardTermCfg): if term_name not in self._group_term_names[group_name]: raise ValueError(f"Reward term '{term_name}' not found.") # set the configuration - self._group_term_cfgs[group_name][ - self._group_term_names[group_name].index(term_name) - ] = cfg + self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] = cfg def get_term_cfg(self, term_name: str) -> RewardTermCfg: """Gets the configuration for the specified term. 
@@ -241,9 +227,7 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg: if term_name not in self._group_term_names[group_name]: raise ValueError(f"Reward term '{term_name}' not found.") # return the configuration - return self._group_term_cfgs[group_name][ - self._group_term_names[group_name].index(term_name) - ] + return self._group_term_cfgs[group_name][self._group_term_names[group_name].index(term_name)] """ Helper functions. @@ -312,9 +296,7 @@ def _prepare_terms(self): f" Received: '{type(term_cfg.weight)}'." ) # resolve common terms in the config - self._resolve_common_term_cfg( - f"{group_name}/{term_name}", term_cfg, min_argc=1 - ) + self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=1) # add term config to list self._group_term_names[group_name].append(term_name) self._group_term_cfgs[group_name].append(term_cfg)
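
As a usage sketch of the grouped API introduced by this series: the reward functions, weights, group names, and the `env` wiring below are illustrative assumptions, not part of the patches; only `RewardGroupCfg`, `RewardTermCfg`, `RewardManager`, the `"<group>/<term>"` naming, and the tensor-vs-dict return behavior of `compute()` come from the changes above.

    import torch

    from omni.isaac.orbit.managers.manager_term_cfg import RewardGroupCfg, RewardTermCfg
    from omni.isaac.orbit.managers.reward_manager import RewardManager
    from omni.isaac.orbit.utils import configclass


    def alive_bonus(env) -> torch.Tensor:
        # Constant per-step bonus for every environment instance.
        return torch.ones(env.num_envs, device=env.device)


    def action_rate_penalty(env) -> torch.Tensor:
        # Squared L2 penalty on the change in actions between consecutive steps.
        return torch.sum(torch.square(env.action_manager.action - env.action_manager.prev_action), dim=1)


    @configclass
    class TaskRewardsCfg(RewardGroupCfg):
        # Terms optimized by the task critic.
        alive = RewardTermCfg(func=alive_bonus, weight=1.0)


    @configclass
    class CostRewardsCfg(RewardGroupCfg):
        # Terms consumed by a second critic, e.g. as CMDP costs.
        action_rate = RewardTermCfg(func=action_rate_penalty, weight=-0.01)


    @configclass
    class RewardsCfg:
        task = TaskRewardsCfg()
        cost = CostRewardsCfg()


    # Hypothetical wiring inside an RLTaskEnv (an `env` instance is assumed to exist):
    #   reward_manager = RewardManager(RewardsCfg(), env)
    #   rewards = reward_manager.compute(dt=env.step_dt)
    #   # -> {"task": Tensor(num_envs,), "cost": Tensor(num_envs,)}
    #   cfg = reward_manager.get_term_cfg("cost/action_rate")  # "<group>/<term>" lookup

A flat config that lists only `RewardTermCfg` entries keeps working unchanged: `_prepare_terms` wraps it into the default group and `compute()` still returns a single tensor, so existing environments are unaffected.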