polish(nyz): polish api doc details
PaParaZz1 committed Jul 6, 2024
1 parent 96ccaed commit d88ebe2
Showing 12 changed files with 25 additions and 29 deletions.
2 changes: 1 addition & 1 deletion ding/bonus/a2c.py
@@ -70,7 +70,7 @@ def __init__(
- model (:obj:`torch.nn.Module`): The model of A2C algorithm, which should be an instance of class \
:class:`ding.model.VAC`. \
If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of A2C algorithm, which is a dict. \
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of A2C algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/A2C/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
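The constructor documented here is shared, with only the model class swapped, by all of the ``ding.bonus`` agents touched in this commit (C51, DDPG, DQN, PG, off-policy PPO, SAC). A minimal usage sketch, assuming the class is exposed as ``ding.bonus.a2c.A2CAgent`` and that ``env_id`` is its first argument (neither is visible in this hunk):

    from ding.bonus.a2c import A2CAgent  # assumed import path

    agent = A2CAgent(
        env_id='LunarLander-v2',   # assumption: selects the matching default config
        exp_name=None,             # None -> log folder named "<env_id>-<algorithm>"
        model=None,                # None -> a default ding.model.VAC is built from the config
        cfg=None,                  # None -> ding/config/example/A2C/gym_lunarlander_v2.py is used
        policy_state_dict=None,    # or a path to a policy state dict saved with torch.save
    )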
5 changes: 2 additions & 3 deletions ding/bonus/c51.py
@@ -68,9 +68,8 @@ def __init__(
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
- model (:obj:`torch.nn.Module`): The model of C51 algorithm, which should be an instance of class \
- :class:`ding.model.C51DQN`. \
- If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of C51 algorithm, which is a dict. \
+ :class:`ding.model.C51DQN`. If not specified, a default model will be generated according to the config.
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of C51 algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/C51/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
2 changes: 1 addition & 1 deletion ding/bonus/ddpg.py
@@ -70,7 +70,7 @@ def __init__(
- model (:obj:`torch.nn.Module`): The model of DDPG algorithm, which should be an instance of class \
:class:`ding.model.ContinuousQAC`. \
If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of DDPG algorithm, which is a dict. \
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of DDPG algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/DDPG/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
2 changes: 1 addition & 1 deletion ding/bonus/dqn.py
@@ -70,7 +70,7 @@ def __init__(
- model (:obj:`torch.nn.Module`): The model of DQN algorithm, which should be an instance of class \
:class:`ding.model.DQN`. \
If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of DQN algorithm, which is a dict. \
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of DQN algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/DQN/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
2 changes: 1 addition & 1 deletion ding/bonus/pg.py
@@ -68,7 +68,7 @@ def __init__(
- model (:obj:`torch.nn.Module`): The model of PG algorithm, which should be an instance of class \
:class:`ding.model.PG`. \
If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of PG algorithm, which is a dict. \
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PG algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/PG/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
2 changes: 1 addition & 1 deletion ding/bonus/ppo_offpolicy.py
@@ -70,7 +70,7 @@ def __init__(
- model (:obj:`torch.nn.Module`): The model of PPO (offpolicy) algorithm, \
which should be an instance of class :class:`ding.model.VAC`. \
If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of PPO (offpolicy) algorithm, which is a dict. \
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO (offpolicy) algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/PPO (offpolicy)/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
2 changes: 1 addition & 1 deletion ding/bonus/sac.py
@@ -71,7 +71,7 @@ def __init__(
- model (:obj:`torch.nn.Module`): The model of SAC algorithm, which should be an instance of class \
:class:`ding.model.ContinuousQAC`. \
If not specified, a default model will be generated according to the configuration.
- - cfg (:obj:Union[EasyDict, dict]): The configuration of SAC algorithm, which is a dict. \
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of SAC algorithm, which is a dict. \
Default to None. If not specified, the default configuration will be used. \
The default configuration can be found in ``ding/config/example/SAC/gym_lunarlander_v2.py``.
- policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
6 changes: 3 additions & 3 deletions ding/envs/env_manager/base_env_manager.py
@@ -124,7 +124,7 @@ def __init__(
.. note::
For more details about how to merge config, please refer to the system document of DI-engine \
- (`en link <../03_system/config.html>`_).
+ (`en link1 <../03_system/config.html>`_).
"""
self._cfg = cfg
self._env_fn = env_fn
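As a rough illustration of the constructor shown above, a sketch of building a manager over several copies of one environment; the import paths and the ``DingEnvWrapper`` usage are assumptions, since only ``cfg`` and ``env_fn`` are visible in this hunk:

    import gym
    from easydict import EasyDict
    from ding.envs import BaseEnvManager, DingEnvWrapper  # assumed import paths

    env_num = 4
    env_fn = [lambda: DingEnvWrapper(gym.make('CartPole-v1')) for _ in range(env_num)]
    cfg = EasyDict(BaseEnvManager.default_config())  # user options are merged over this default config
    manager = BaseEnvManager(env_fn=env_fn, cfg=cfg)
    manager.seed(0)    # expanded to one seed per sub-environment, see the seed() hunk below
    manager.launch()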
@@ -484,7 +484,7 @@ def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool =
.. note::
For more details about ``dynamic_seed``, please refer to the best practice document of DI-engine \
- (`en link <../04_best_practice/random_seed.html>`_).
+ (`en link2 <../04_best_practice/random_seed.html>`_).
"""
if isinstance(seed, numbers.Integral):
seed = [seed + i for i in range(self.env_num)]
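The visible branch can be reproduced standalone to show how a scalar seed is spread across the sub-environments (``dynamic_seed`` itself is covered by the linked best-practice note):

    import numbers
    from typing import Dict, List, Union

    def expand_seed(seed: Union[Dict[int, int], List[int], int], env_num: int):
        # Mirror of the branch above: a single integer becomes [seed, seed + 1, ..., seed + env_num - 1].
        if isinstance(seed, numbers.Integral):
            return [seed + i for i in range(env_num)]
        return seed

    assert expand_seed(7, 3) == [7, 8, 9]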
@@ -580,7 +580,7 @@ class BaseEnvManagerV2(BaseEnvManager):
.. note::
For more details about new task pipeline, please refer to the system document of DI-engine \
- (`system en link <../03_system/index.html>`_).
+ (`system en link3 <../03_system/index.html>`_).
Interfaces:
reset, step, seed, close, enable_save_replay, launch, default_config, reward_shaping, enable_save_figure
4 changes: 2 additions & 2 deletions ding/model/template/qgpo.py
@@ -418,9 +418,9 @@ def q_loss_fn(self, a, s, r, s_, d, fake_a_, discount=0.99):
- a (:obj:`torch.Tensor`): The input action.
- s (:obj:`torch.Tensor`): The input state.
- r (:obj:`torch.Tensor`): The input reward.
- - s_ (:obj:`torch.Tensor`): The input next state.
+ - s\_ (:obj:`torch.Tensor`): The input next state.
- d (:obj:`torch.Tensor`): The input done.
- - fake_a_ (:obj:`torch.Tensor`): The input fake action.
+ - fake_a (:obj:`torch.Tensor`): The input fake action.
- discount (:obj:`float`): The discount factor.
"""

11 changes: 5 additions & 6 deletions ding/policy/qgpo.py
@@ -13,12 +13,11 @@
@POLICY_REGISTRY.register('qgpo')
class QGPOPolicy(Policy):
"""
- Overview:
- Policy class of QGPO algorithm
- Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning
- https://arxiv.org/abs/2304.12824
- Interfaces:
- ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
+ Overview:
+ Policy class of QGPO algorithm (https://arxiv.org/abs/2304.12824).
+ Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning
+ Interfaces:
+ ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
"""

config = dict(
8 changes: 4 additions & 4 deletions ding/rl_utils/value_rescale.py
@@ -5,11 +5,11 @@ def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
"""
Overview:
A function to reduce the scale of the action-value function.
- :math: `h(x) = sign(x)(\sqrt{(abs(x)+1)} - 1) + \eps * x` .
+ :math: `h(x) = sign(x)(\sqrt{(abs(x)+1)} - 1) + \epsilon * x` .
Arguments:
- x: (:obj:`torch.Tensor`) The input tensor to be normalized.
- eps: (:obj:`float`) The coefficient of the additive regularization term \
- to ensure h^{-1} is Lipschitz continuous
+ to ensure inverse function is Lipschitz continuous
Returns:
- (:obj:`torch.Tensor`) Normalized tensor.
@@ -23,11 +23,11 @@ def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
"""
Overview:
The inverse form of value rescale.
- :math: `h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\eps(|x|+1+\eps)}-1}{2\eps})}^2-1)` .
+ :math: `h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\epsilon(|x|+1+\epsilon)}-1}{2\epsilon})}^2-1)` .
Arguments:
- x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
- eps: (:obj:`float`) The coefficient of the additive regularization term \
- to ensure h^{-1} is Lipschitz continuous
+ to ensure inverse function is Lipschitz continuous
Returns:
- (:obj:`torch.Tensor`) Unnormalized tensor.
"""
8 changes: 3 additions & 5 deletions ding/torch_utils/network/gtrxl.py
@@ -167,8 +167,7 @@ def update(self, hidden_state: List[torch.Tensor]):
"""
Overview:
Update the memory given a sequence of hidden states.
- Example for single layer:
- memory_len=3, hidden_size_len=2, bs=3
+ Example for single layer: (memory_len=3, hidden_size_len=2, bs=3)
m00 m01 m02 h00 h01 h02 m20 m21 m22
m = m10 m11 m12 h = h10 h11 h12 => new_m = h00 h01 h02
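In code, the single-layer picture above amounts to concatenating along the time axis and keeping the last ``memory_len`` steps; a sketch only, since the real ``update`` handles a list of per-layer tensors and extra dimensions:

    import torch

    memory_len, hidden_size_len, bs = 3, 2, 3
    m = torch.arange(memory_len * bs).float().view(memory_len, bs)                   # rows m0, m1, m2
    h = torch.arange(hidden_size_len * bs).float().view(hidden_size_len, bs) + 100.  # rows h0, h1
    new_m = torch.cat([m, h], dim=0)[-memory_len:]                                   # rows m2, h0, h1, as in the picture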
@@ -264,9 +263,8 @@ def _rel_shift(self, x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor:
4) Mask out the upper triangle (optional)
.. note::
- See the following material for better understanding:
- https://github.com/kimiyoung/transformer-xl/issues/8
- https://arxiv.org/pdf/1901.02860.pdf (Appendix B)
+ See the following material for better understanding: https://github.com/kimiyoung/transformer-xl/issues/8 \
+ https://arxiv.org/pdf/1901.02860.pdf (Appendix B)
Arguments:
- x (:obj:`torch.Tensor`): The input tensor with shape (cur_seq, full_seq, bs, head_num).
- zero_upper (:obj:`bool`): If True, the upper-right triangle of the matrix is set to zero.
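A hedged sketch of the four steps, following the standard Transformer-XL relative shift for a tensor of shape (cur_seq, full_seq, bs, head_num); the library's own implementation may differ in its details:

    import torch

    def rel_shift_sketch(x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor:
        cur_seq, full_seq = x.shape[0], x.shape[1]
        # 1) pad one column of zeros along the full_seq dimension
        zero_pad = x.new_zeros(cur_seq, 1, *x.shape[2:])
        x_padded = torch.cat([zero_pad, x], dim=1)               # (cur_seq, full_seq + 1, bs, head_num)
        # 2) reshape so that the zero column shifts each query's row by one position
        x_padded = x_padded.view(full_seq + 1, cur_seq, *x.shape[2:])
        # 3) drop the first row and reshape back to the original layout
        out = x_padded[1:].view_as(x)
        # 4) optionally zero the upper-right triangle (assumption: keep keys up to each query's own position)
        if zero_upper:
            mask = torch.ones(cur_seq, full_seq, device=x.device).tril(diagonal=full_seq - cur_seq)
            out = out * mask.view(cur_seq, full_seq, *([1] * (x.dim() - 2)))
        return out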