doc(zjow): polish rl_utils doc (#724)

* polish doc

* polish doc

* polish overview

* polish overview

* polish note
zjowowen authored Oct 7, 2023
1 parent 08a6c52 commit 92ac919
Showing 15 changed files with 609 additions and 14 deletions.
2 changes: 1 addition & 1 deletion ding/rl_utils/__init__.py
@@ -1,6 +1,6 @@
from .exploration import get_epsilon_greedy_fn, create_noise_generator
from .ppo import ppo_data, ppo_loss, ppo_info, ppo_policy_data, ppo_policy_error, ppo_value_data, ppo_value_error,\
ppo_error, ppo_error_continuous, ppo_policy_error_continuous
ppo_error, ppo_error_continuous, ppo_policy_error_continuous, ppo_data_continuous, ppo_policy_data_continuous
from .ppg import ppg_data, ppg_joint_loss, ppg_joint_error
from .gae import gae_data, gae
from .a2c import a2c_data, a2c_error, a2c_error_continuous
22 changes: 21 additions & 1 deletion ding/rl_utils/a2c.py
@@ -25,6 +25,16 @@ def a2c_error(data: namedtuple) -> namedtuple:
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
- value_loss (:obj:`torch.FloatTensor`): :math:`()`
- entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
>>> data = a2c_data(
>>> logit=torch.randn(2, 3),
>>> action=torch.randint(0, 3, (2, )),
>>> value=torch.randn(2, ),
>>> adv=torch.randn(2, ),
>>> return_=torch.randn(2, ),
>>> weight=torch.ones(2, ),
>>> )
>>> loss = a2c_error(data)
"""
logit, action, value, adv, return_, weight = data
if weight is None:
@@ -47,14 +57,24 @@ def a2c_error_continuous(data: namedtuple) -> namedtuple:
- a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are the differentiable 0-dim tensor
Shapes:
- logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
- action (:obj:`torch.LongTensor`): :math:`(B, )`
- action (:obj:`torch.LongTensor`): :math:`(B, N)`
- value (:obj:`torch.FloatTensor`): :math:`(B, )`
- adv (:obj:`torch.FloatTensor`): :math:`(B, )`
- return (:obj:`torch.FloatTensor`): :math:`(B, )`
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
- value_loss (:obj:`torch.FloatTensor`): :math:`()`
- entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
>>> data = a2c_data(
>>> logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
>>> action=torch.randn(2, 3),
>>> value=torch.randn(2, ),
>>> adv=torch.randn(2, ),
>>> return_=torch.randn(2, ),
>>> weight=torch.ones(2, ),
>>> )
>>> loss = a2c_error_continuous(data)
"""
logit, action, value, adv, return_, weight = data
if weight is None:
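
A reference sketch (not from this commit): the three losses that a2c_error returns follow the standard advantage actor-critic objective. A minimal discrete-action version, with illustrative names rather than the library's internals:

    import torch
    import torch.nn.functional as F

    logit = torch.randn(2, 3)              # (B, N) policy logits
    action = torch.randint(0, 3, (2, ))    # (B, ) sampled actions
    value = torch.randn(2, )               # (B, ) critic predictions
    adv = torch.randn(2, )                 # (B, ) advantages
    return_ = torch.randn(2, )             # (B, ) return targets

    dist = torch.distributions.Categorical(logits=logit)
    policy_loss = -(dist.log_prob(action) * adv).mean()   # policy-gradient term weighted by advantage
    value_loss = F.mse_loss(value, return_)               # critic regression loss
    entropy_loss = dist.entropy().mean()                  # entropy bonus (subtracted from the total loss)
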
20 changes: 20 additions & 0 deletions ding/rl_utils/acer.py
@@ -37,6 +37,14 @@ def acer_policy_error(
- ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
- bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
>>> q_values=torch.randn(2, 3, 4)
>>> q_retraces=torch.randn(2, 3, 1)
>>> v_pred=torch.randn(2, 3, 1)
>>> target_pi=torch.randn(2, 3, 4)
>>> actions=torch.randint(0, 4, (2, 3))
>>> ratio=torch.randn(2, 3, 4)
>>> loss = acer_policy_error(q_values, q_retraces, v_pred, target_pi, actions, ratio)
"""
actions = actions.unsqueeze(-1)
with torch.no_grad():
@@ -69,6 +77,12 @@ def acer_value_error(q_values, q_retraces, actions):
- q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
- actions (:obj:`torch.LongTensor`): :math:`(T, B)`
- critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
>>> q_values=torch.randn(2, 3, 4)
>>> q_retraces=torch.randn(2, 3, 1)
>>> actions=torch.randint(0, 4, (2, 3))
>>> loss = acer_value_error(q_values, q_retraces, actions)
"""
actions = actions.unsqueeze(-1)
critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
@@ -92,6 +106,12 @@ def acer_trust_region_update(
Shapes:
- target_pi (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- avg_pi (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
Examples:
>>> actor_gradients=[torch.randn(2, 3, 4)]
>>> target_pi=torch.randn(2, 3, 4)
>>> avg_pi=torch.randn(2, 3, 4)
>>> loss = acer_trust_region_update(actor_gradients, target_pi, avg_pi, 0.1)
"""
with torch.no_grad():
KL_gradients = [torch.exp(avg_logit)]
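
For context (not from this commit), the gradient correction that acer_trust_region_update documents follows the trust-region projection of the ACER paper (Wang et al., 2017), where the actor gradient g is clipped against the gradient k of the KL divergence between the average policy and the current policy:

    \hat{g}^{\text{trust}} = \hat{g} - \max\left(0, \frac{k^{\top}\hat{g} - \delta}{\|k\|_{2}^{2}}\right) k,
    \qquad k = \nabla_{\phi_\theta(x)} D_{\mathrm{KL}}\left[ f(\cdot \mid \phi_{\theta_a}(x)) \,\|\, f(\cdot \mid \phi_\theta(x)) \right]
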
36 changes: 31 additions & 5 deletions ding/rl_utils/adder.py
@@ -4,7 +4,7 @@
import torch

from ding.utils import list_split, lists_to_dicts
from .gae import gae, gae_data
from ding.rl_utils.gae import gae, gae_data


class Adder(object):
@@ -25,11 +25,20 @@ def get_gae(cls, data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: fl
Arguments:
- data (:obj:`list`): Transitions list, each element is a transition dict with at least ['value', 'reward']
- last_value (:obj:`torch.Tensor`): The last value(i.e.: the T+1 timestep)
- gamma (:obj:`float`): The future discount factor
- gae_lambda (:obj:`float`): GAE lambda parameter
- gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
- gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
when lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
- cuda (:obj:`bool`): Whether use cuda in GAE computation
Returns:
- data (:obj:`list`): transitions list like input one, but each element owns extra advantage key 'adv'
Examples:
>>> B, T = 2, 3 # batch_size, timestep
>>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
>>> last_value = torch.randn(B)
>>> gamma = 0.99
>>> gae_lambda = 0.95
>>> cuda = False
>>> data = Adder.get_gae(data, last_value, gamma, gae_lambda, cuda)
"""
value = torch.stack([d['value'] for d in data])
next_value = torch.stack([d['value'] for d in data][1:] + [last_value])
@@ -60,12 +69,21 @@ def get_gae_with_default_last_value(cls, data: deque, done: bool, gamma: float,
- data (:obj:`deque`): Transitions list, each element is a transition dict with \
at least['value', 'reward']
- done (:obj:`bool`): Whether the transition reaches the end of an episode(i.e. whether the env is done)
- gamma (:obj:`float`): The future discount factor
- gae_lambda (:obj:`float`): GAE lambda parameter
- gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
- gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
when lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
- cuda (:obj:`bool`): Whether use cuda in GAE computation
Returns:
- data (:obj:`List[Dict[str, Any]]`): transitions list like input one, but each element owns \
extra advantage key 'adv'
Examples:
>>> B, T = 2, 3 # batch_size, timestep
>>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
>>> done = False
>>> gamma = 0.99
>>> gae_lambda = 0.95
>>> cuda = False
>>> data = Adder.get_gae_with_default_last_value(data, done, gamma, gae_lambda, cuda)
"""
if done:
last_value = torch.zeros_like(data[-1]['value'])
@@ -92,6 +110,14 @@ def get_nstep_return_data(
Otherwise update with nstep value.
Returns:
- data (:obj:`deque`): Transitions list like input one, but each element updated with nstep value.
Examples:
>>> data = [dict(
>>> obs=torch.randn(B),
>>> reward=torch.randn(1),
>>> next_obs=torch.randn(B),
>>> done=False) for _ in range(T)]
>>> nstep = 2
>>> data = Adder.get_nstep_return_data(data, nstep)
"""
if nstep == 1:
return data
14 changes: 13 additions & 1 deletion ding/rl_utils/coma.py
@@ -1,7 +1,7 @@
from collections import namedtuple
import torch
import torch.nn.functional as F
from .td import generalized_lambda_returns
from ding.rl_utils.td import generalized_lambda_returns

coma_data = namedtuple('coma_data', ['logit', 'action', 'q_value', 'target_q_value', 'reward', 'weight'])
coma_loss = namedtuple('coma_loss', ['policy_loss', 'q_value_loss', 'entropy_loss'])
@@ -26,6 +26,18 @@ def coma_error(data: namedtuple, gamma: float, lambda_: float) -> namedtuple:
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
- value_loss (:obj:`torch.FloatTensor`): :math:`()`
- entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
>>> action_dim = 4
>>> agent_num = 3
>>> data = coma_data(
>>> logit=torch.randn(2, 3, agent_num, action_dim),
>>> action=torch.randint(0, action_dim, (2, 3, agent_num)),
>>> q_value=torch.randn(2, 3, agent_num, action_dim),
>>> target_q_value=torch.randn(2, 3, agent_num, action_dim),
>>> reward=torch.randn(2, 3),
>>> weight=torch.ones(2, 3, agent_num),
>>> )
>>> loss = coma_error(data, 0.99, 0.99)
"""
logit, action, q_value, target_q_value, reward, weight = data
if weight is None:
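
For context (not from this commit), coma_error is built around the counterfactual baseline of COMA (Foerster et al., 2018): for each agent a, the joint action value is compared against an expectation over that agent's own actions,

    A^{a}(s, \mathbf{u}) = Q(s, \mathbf{u}) - \sum_{u'^{a}} \pi^{a}(u'^{a} \mid \tau^{a}) \, Q(s, (\mathbf{u}^{-a}, u'^{a}))

which is why logit and q_value in the example above carry the extra agent_num dimension.
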
4 changes: 4 additions & 0 deletions ding/rl_utils/exploration.py
@@ -172,6 +172,10 @@ def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> tor

@property
def x0(self) -> Union[float, torch.Tensor]:
"""
Overview:
Get ``self._x0``
"""
return self._x0

@x0.setter
6 changes: 6 additions & 0 deletions ding/rl_utils/gae.py
@@ -39,6 +39,12 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
- next_value (:obj:`torch.FloatTensor`): :math:`(T, B)`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
- adv (:obj:`torch.FloatTensor`): :math:`(T, B)`
Examples:
>>> value = torch.randn(2, 3)
>>> next_value = torch.randn(2, 3)
>>> reward = torch.randn(2, 3)
>>> data = gae_data(value, next_value, reward, None, None)
>>> adv = gae(data)
"""
value, next_value, reward, done, traj_flag = data
if done is None:
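
A reference sketch (not from this commit): the bias/variance remark about gae_lambda added in adder.py above refers to the standard generalized advantage estimator. A minimal version with the same (T, B) shapes, assuming done and traj_flag are None; illustrative only, not the library implementation:

    import torch

    def gae_reference(value, next_value, reward, gamma=0.99, lambda_=0.97):
        # value, next_value, reward: (T, B); delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = reward + gamma * next_value - value
        adv = torch.zeros_like(value)
        running = torch.zeros_like(value[0])
        for t in reversed(range(value.shape[0])):
            # A_t = delta_t + (gamma * lambda) * A_{t+1}; lambda -> 0 keeps only the one-step
            # (biased) residual, lambda -> 1 sums many residuals (higher variance).
            running = delta[t] + gamma * lambda_ * running
            adv[t] = running
        return adv

    adv = gae_reference(torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3))
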
5 changes: 5 additions & 0 deletions ding/rl_utils/isw.py
@@ -34,6 +34,11 @@ def compute_importance_weights(
- behaviour_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`
- action (:obj:`torch.LongTensor`): :math:`(T, B)`
- rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
Examples:
>>> target_output = torch.randn(2, 3, 4)
>>> behaviour_output = torch.randn(2, 3, 4)
>>> action = torch.randint(0, 4, (2, 3))
>>> rhos = compute_importance_weights(target_output, behaviour_output, action)
"""
grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
assert isinstance(action, torch.Tensor)
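
A reference sketch (not from this commit): for the discrete-action case documented above, the importance weights are the ratio of target to behaviour action probabilities, rho = pi_target(a|s) / pi_behaviour(a|s). A minimal version, not the library code:

    import torch
    import torch.nn.functional as F

    target_output = torch.randn(2, 3, 4)     # (T, B, N) target-policy logits
    behaviour_output = torch.randn(2, 3, 4)  # (T, B, N) behaviour-policy logits
    action = torch.randint(0, 4, (2, 3))     # (T, B)

    target_prob = F.softmax(target_output, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
    behaviour_prob = F.softmax(behaviour_output, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)
    rhos = target_prob / behaviour_prob      # (T, B) importance weights
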
32 changes: 32 additions & 0 deletions ding/rl_utils/ppg.py
@@ -12,6 +12,38 @@ def ppg_joint_error(
clip_ratio: float = 0.2,
use_value_clip: bool = True,
) -> Tuple[namedtuple, namedtuple]:
'''
Overview:
Get PPG joint loss
Arguments:
- data (:obj:`namedtuple`): ppg input data with fields shown in ``ppg_data``
- clip_ratio (:obj:`float`): clip value for ratio
- use_value_clip (:obj:`bool`): whether use value clip
Returns:
- ppg_joint_loss (:obj:`namedtuple`): the ppg loss item, all of them are the differentiable 0-dim tensor
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
- action (:obj:`torch.LongTensor`): :math:`(B,)`
- value_new (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- value_old (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- return_ (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B,)`
- auxiliary_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
- behavioral_cloning_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
>>> action_dim = 4
>>> data = ppg_data(
>>> logit_new=torch.randn(3, action_dim),
>>> logit_old=torch.randn(3, action_dim),
>>> action=torch.randint(0, action_dim, (3,)),
>>> value_new=torch.randn(3, 1),
>>> value_old=torch.randn(3, 1),
>>> return_=torch.randn(3, 1),
>>> weight=torch.ones(3),
>>> )
>>> loss = ppg_joint_error(data, 0.99, 0.99)
'''
logit_new, logit_old, action, value_new, value_old, return_, weight = data

if weight is None:
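
For context (not from this commit), the auxiliary and behavioral-cloning terms returned by ppg_joint_error correspond to the joint objective of Phasic Policy Gradient (Cobbe et al., 2020):

    L^{\text{joint}} = \frac{1}{2}\,\hat{\mathbb{E}}_t\big[(V_\theta(s_t) - \hat{V}_t^{\text{targ}})^2\big]
    + \beta_{\text{clone}}\,\hat{\mathbb{E}}_t\big[\mathrm{KL}\big(\pi_{\theta_{\text{old}}}(\cdot \mid s_t)\,\|\,\pi_\theta(\cdot \mid s_t)\big)\big]

where the first term may be replaced by its clipped variant when use_value_clip is enabled.
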