polish(pu): delete unused enable_fast_timestep argument (#855)
* polish(pu): delete unused enable_fast_timestep argument

* polish(pu): delete unused empty lines

* polish(pu): delete unused empty lines

* style(pu): polish comment's format

* style(pu): polish comment's format
puyuan1996 authored Jan 27, 2025
1 parent 3292384 commit 64efcb3
Showing 10 changed files with 86 additions and 87 deletions.
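For context: the removed key was dead weight because these recurrent models pass all inputs as a single dict, and the forward methods only ever read the keys they need. The toy sketch below is hypothetical, not DI-engine code (toy_forward is an invented stand-in for the DRQN forward shown later in this diff); it illustrates why dropping an unread key cannot change behavior.

import torch

def toy_forward(inputs: dict) -> dict:
    # Only 'obs' and 'prev_state' are read; any extra key is silently ignored.
    x, prev_state = inputs['obs'], inputs['prev_state']
    return {'logit': x, 'next_state': prev_state}

obs = torch.randn(4, 64)
old = toy_forward({'obs': obs, 'prev_state': None, 'enable_fast_timestep': True})
new = toy_forward({'obs': obs, 'prev_state': None})
assert torch.equal(old['logit'], new['logit'])  # the dead flag never mattered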
23 changes: 8 additions & 15 deletions ding/model/template/collaq.py
@@ -411,27 +411,20 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
  agent_alone_state = agent_alone_state.reshape(T, -1, *agent_alone_state.shape[3:])
  agent_alone_padding_state = agent_alone_padding_state.reshape(T, -1, *agent_alone_padding_state.shape[3:])

- colla_output = self._q_network(
-     {
-         'obs': agent_state,
-         'prev_state': colla_prev_state,
-         'enable_fast_timestep': True
-     }
- )
+ colla_output = self._q_network({
+     'obs': agent_state,
+     'prev_state': colla_prev_state,
+ })
  colla_alone_output = self._q_network(
      {
          'obs': agent_alone_padding_state,
          'prev_state': colla_alone_prev_state,
-         'enable_fast_timestep': True
      }
  )
- alone_output = self._q_alone_network(
-     {
-         'obs': agent_alone_state,
-         'prev_state': alone_prev_state,
-         'enable_fast_timestep': True
-     }
- )
+ alone_output = self._q_alone_network({
+     'obs': agent_alone_state,
+     'prev_state': alone_prev_state,
+ })

  agent_alone_q, alone_next_state = alone_output['logit'], alone_output['next_state']
  agent_colla_alone_q, colla_alone_next_state = colla_alone_output['logit'], colla_alone_output['next_state']
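A note on the reshape pattern that frames this hunk: CollaQ folds the batch and agent dimensions together so the recurrent Q-network sees one flat batch of agent trajectories. A minimal self-contained sketch, with illustrative shapes:

import torch

T, B, A, N = 8, 4, 3, 32  # timesteps, batch, agents, obs dim (illustrative)
agent_state = torch.randn(T, B, A, N)
# Same idiom as agent_state.reshape(T, -1, *agent_state.shape[3:]) above:
# fold (B, A) into one batch axis of size B*A.
agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
assert agent_state.shape == (T, B * A, N)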
2 changes: 1 addition & 1 deletion ding/model/template/coma.py
@@ -70,7 +70,7 @@ def forward(self, inputs: Dict) -> Dict:
  T, B, A = agent_state.shape[:3]
  agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
  prev_state = reduce(lambda x, y: x + y, prev_state)
- output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ output = self.main({'obs': agent_state, 'prev_state': prev_state})
  logit, next_state = output['logit'], output['next_state']
  next_state, _ = list_split(next_state, step=A)
  logit = logit.reshape(T, B, A, -1)
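The reduce call in this hunk flattens the nested per-sample, per-agent prev_state list into one flat list before the RNN call, and list_split later regroups next_state by agent count. A self-contained approximation of that round trip, using a local stand-in for ding.utils.list_split:

from functools import reduce

B, A = 4, 3  # batch size and agent count (illustrative)
prev_state = [[(b, a) for a in range(A)] for b in range(B)]  # nested states
flat = reduce(lambda x, y: x + y, prev_state)  # flat list of length B*A
assert len(flat) == B * A

# Stand-in for list_split(next_state, step=A): regroup into per-sample chunks.
regrouped = [flat[i:i + A] for i in range(0, len(flat), A)]
assert regrouped == prev_state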
119 changes: 65 additions & 54 deletions ding/model/template/q_learning.py
@@ -855,18 +855,22 @@ def reshape(d):
  class DRQN(nn.Module):
      """
      Overview:
-         The neural network structure and computation graph of DRQN (DQN + RNN = DRQN) algorithm, which is the most \
-         common DQN variant for sequential data and paratially observable environment. The DRQN is composed of three \
-         parts: ``encoder``, ``head`` and ``rnn``. The ``encoder`` is used to extract the feature from various \
-         observation, the ``rnn`` is used to process the sequential observation and other data, and the ``head`` is \
-         used to compute the Q value of each action dimension.
+         The DRQN (Deep Recurrent Q-Network) is a neural network model combining DQN with RNN to handle sequential
+         data and partially observable environments. It consists of three main components: ``encoder``, ``rnn``,
+         and ``head``.
+         - **Encoder**: Extracts features from various observation inputs.
+         - **RNN**: Processes sequential observations and other data.
+         - **Head**: Computes Q-values for each action dimension.
      Interfaces:
          ``__init__``, ``forward``.
      .. note::
-         Current ``DRQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \
-         ``DiscreteHead`` and ``DuelingHead``, three types of rnn: ``normal (LSTM with LayerNorm)``, ``pytorch`` and \
-         ``gru``. You can customize your own encoder, rnn or head by inheriting this class.
+         The current implementation supports:
+         - Two encoder types: ``FCEncoder`` and ``ConvEncoder``.
+         - Two head types: ``DiscreteHead`` and ``DuelingHead``.
+         - Three RNN types: ``normal (LSTM with LayerNorm)``, ``pytorch`` (PyTorch's native LSTM), and ``gru``.
+         You can extend the model by customizing your own encoder, RNN, or head by inheriting this class.
      """

  def __init__(
@@ -884,43 +888,48 @@ def __init__(
  ) -> None:
"""
Overview:
Initialize the DRQN Model according to the corresponding input arguments.
Initialize the DRQN model with specified parameters.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
the last element must match ``head_hidden_size``.
- dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``.
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
then it will be set to the last element of ``encoder_hidden_size_list``.
- head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
- lstm_type (:obj:`Optional[str]`): The type of RNN module, now support ['normal', 'pytorch', 'gru'].
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
- res_link (:obj:`bool`): Whether to enable the residual link, which is the skip connnection between \
single frame data and the sequential data, defaults to False.
- obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84].
- action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3].
- encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \
must match ``head_hidden_size``.
- dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``.
- head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network. Defaults to the last \
element of ``encoder_hidden_size_list`` if None.
- head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs.
- lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \
and ``gru``.
- activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to \
``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \
['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details.
- res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \
Defaults to False.
"""
      super(DRQN, self).__init__()
-     # For compatibility: 1, (1, ), [4, 32, 32]
+     # Compatibility for obs_shape/action_shape: Handles scalar, tuple, or multi-dimensional inputs.
      obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
      if head_hidden_size is None:
          head_hidden_size = encoder_hidden_size_list[-1]
-     # FC Encoder
+
+     # Encoder: Determines the encoder type based on the observation shape.
      if isinstance(obs_shape, int) or len(obs_shape) == 1:
+         # FC Encoder
          self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
-     # Conv Encoder
      elif len(obs_shape) == 3:
+         # Conv Encoder
          self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
      else:
          raise RuntimeError(
-             "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
+             f"Unsupported obs_shape for pre-defined encoder: {obs_shape}. Please customize your own DRQN."
          )
-     # LSTM Type
+
+     # RNN: Initializes the RNN module based on the specified lstm_type.
      self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size)
      self.res_link = res_link
-     # Head Type
+
+     # Head: Determines the head type (Dueling or Discrete) and its configuration.
      if dueling:
          head_cls = DuelingHead
      else:
@@ -943,31 +952,32 @@ def __init__(
  def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
      """
      Overview:
-         DRQN forward computation graph, input observation tensor to predict q_value.
+         Defines the forward pass of the DRQN model. Takes observation and previous RNN states as inputs \
+         and predicts Q-values.
      Arguments:
-         - inputs (:obj:`torch.Tensor`): The dict of input data, including observation and previous rnn state.
-         - inference: (:obj:'bool'): Whether to enable inference forward mode, if True, we unroll the one timestep \
-             transition, otherwise, we unroll the eentire sequence transitions.
-         - saved_state_timesteps: (:obj:'Optional[list]'): When inference is False, we unroll the sequence \
-             transitions, then we would use this list to indicate how to save and return hidden state.
+         - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state.
+         - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \
+             the entire sequence (used during training).
+         - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, specifies the timesteps \
+             whose hidden states are saved and returned.
      ArgumentsKeys:
-         - obs (:obj:`torch.Tensor`): The raw observation tensor.
-         - prev_state (:obj:`list`): The previous rnn state tensor, whose structure depends on ``lstm_type``.
+         - obs (:obj:`torch.Tensor`): Raw observation tensor.
+         - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``.
      Returns:
          - outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state.
      ReturnsKeys:
-         - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
-         - next_state (:obj:`list`): The next rnn state tensor, whose structure depends on ``lstm_type``.
+         - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension.
+         - next_state (:obj:`list`): Next RNN state tensor.
      Shapes:
-         - obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
-         - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+         - obs (:obj:`torch.Tensor`): :math:`(B, N)` where B is batch size and N is ``obs_shape``.
+         - logit (:obj:`torch.Tensor`): :math:`(B, M)` where B is batch size and M is ``action_shape``.
      Examples:
-         >>> # Init input's Keys:
+         >>> # Initialize input keys
          >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
          >>> obs = torch.randn(4,64)
          >>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
          >>> outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=True)
-         >>> # Check outputs's Keys
+         >>> # Validate output keys and shapes
          >>> assert isinstance(outputs, dict)
          >>> assert outputs['logit'].shape == (4, 64)
          >>> assert len(outputs['next_state']) == 4
@@ -976,9 +986,9 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
      """

      x, prev_state = inputs['obs'], inputs['prev_state']
-     # for both inference and other cases, the network structure is encoder -> rnn network -> head
-     # the difference is inference take the data with seq_len=1 (or T = 1)
-     # NOTE(rjy): in most situations, set inference=True when evaluate and inference=False when training
+     # Forward pass: Encoder -> RNN -> Head
+     # in most situations, set inference=True when evaluate and inference=False when training
+     # Inference mode: Processes one timestep (seq_len=1).
      if inference:
          x = self.encoder(x)
          if self.res_link:
@@ -992,27 +1002,28 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
          x = self.head(x)
          x['next_state'] = next_state
          return x
+     # Training mode: Processes the entire sequence.
      else:
          # In order to better explain why rnn needs saved_state and which states need to be stored,
          # let's take r2d2 as an example
          # in r2d2,
          # 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
          # 2) data['main_obs'] = data['obs'][bs:-self._nstep]
          # 3) data['target_obs'] = data['obs'][bs + self._nstep:]
-         # NOTE(rjy): (T, B, N) or (T, B, C, H, W)
-         assert len(x.shape) in [3, 5], x.shape
+         assert len(x.shape) in [3, 5], f"Expected shape (T, B, N) or (T, B, C, H, W), got {x.shape}"
          x = parallel_wrapper(self.encoder)(x)  # (T, B, N)
          if self.res_link:
              a = x
-         # NOTE(rjy) lstm_embedding stores all hidden_state
+         # lstm_embedding stores all hidden_state
          lstm_embedding = []
          # TODO(nyz) how to deal with hidden_size key-value
          hidden_state_list = []

          if saved_state_timesteps is not None:
              saved_state = []
-         for t in range(x.shape[0]):  # T timesteps
-             # NOTE(rjy) use x[t:t+1] but not x[t] can keep original dimension
-             output, prev_state = self.rnn(x[t:t + 1], prev_state)  # output: (1,B, head_hidden_size)
+         for t in range(x.shape[0]):  # Iterate over timesteps (T).
+             # use x[t:t+1] but not x[t] can keep the original dimension
+             output, prev_state = self.rnn(x[t:t + 1], prev_state)  # RNN step output: (1, B, hidden_size)
              if saved_state_timesteps is not None and t + 1 in saved_state_timesteps:
                  saved_state.append(prev_state)
              lstm_embedding.append(output)
@@ -1023,7 +1034,7 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
          if self.res_link:
              x = x + a
          x = parallel_wrapper(self.head)(x)  # (T, B, action_shape)
-         # NOTE(rjy): x['next_state'] is the hidden state of the last timestep inputted to lstm
+         # x['next_state'] is the hidden state of the last timestep inputted to lstm
          # the last timestep state including the hidden state (h) and the cell state (c)
          # shape: {list: B{dict: 2{Tensor:(1, 1, head_hidden_size}}}
          x['next_state'] = prev_state
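To tie the updated docstring together, here is a hedged usage sketch of the training-mode forward (inference=False) with saved_state_timesteps, the path r2d2-style policies rely on. The (T, B, N) input shape and the 'saved_state' output key follow the docstring and the ngu.py hunk below; passing prev_state=None to get zero-initialized hidden states is an assumption based on ding's LSTM helpers.

import torch
from ding.model.template.q_learning import DRQN

T, B, obs_dim, act_dim = 6, 4, 64, 6  # illustrative sizes
model = DRQN(obs_shape=obs_dim, action_shape=act_dim)
obs = torch.randn(T, B, obs_dim)  # (T, B, N) sequence for training mode
outputs = model(
    {'obs': obs, 'prev_state': None},  # assumed: None -> zero-initialized state
    inference=False,
    saved_state_timesteps=[2, 4],
)
assert outputs['logit'].shape == (T, B, act_dim)
assert len(outputs['saved_state']) == 2  # one state per requested timestep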
2 changes: 1 addition & 1 deletion ding/model/template/qmix.py
@@ -227,7 +227,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
      ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
      prev_state = reduce(lambda x, y: x + y, prev_state)
      agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
-     output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+     output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
      agent_q, next_state = output['logit'], output['next_state']
      next_state, _ = list_split(next_state, step=A)
      agent_q = agent_q.reshape(T, B, A, -1)
2 changes: 1 addition & 1 deletion ding/model/template/qtran.py
@@ -100,7 +100,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
      ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
      prev_state = reduce(lambda x, y: x + y, prev_state)
      agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
-     output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+     output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
      agent_q, next_state = output['logit'], output['next_state']
      next_state, _ = list_split(next_state, step=A)
      agent_q = agent_q.reshape(T, B, A, -1)
2 changes: 0 additions & 2 deletions ding/model/template/wqmix.py
@@ -177,7 +177,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
      {
          'obs': agent_state,
          'prev_state': prev_state,
-         'enable_fast_timestep': True
      }
  )  # here is the forward pass of the agent networks of Q_star
  agent_q, next_state = output['logit'], output['next_state']
@@ -223,7 +222,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
      {
          'obs': agent_state,
          'prev_state': prev_state,
-         'enable_fast_timestep': True
      }
  )  # here is the forward pass of the agent networks of Q
  agent_q, next_state = output['logit'], output['next_state']
3 changes: 0 additions & 3 deletions ding/policy/ngu.py
@@ -290,7 +290,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
          'action': data['burnin_nstep_action'],
          'reward': data['burnin_nstep_reward'],
          'beta': data['burnin_nstep_beta'],
-         'enable_fast_timestep': True
      }
      tmp = self._learn_model.forward(
          inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
@@ -304,7 +303,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
          'action': data['main_action'],
          'reward': data['main_reward'],
          'beta': data['main_beta'],
-         'enable_fast_timestep': True
      }
      self._learn_model.reset(data_id=None, state=tmp['saved_state'][0])
      q_value = self._learn_model.forward(inputs)['logit']
@@ -317,7 +315,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
          'action': data['target_action'],
          'reward': data['target_reward'],
          'beta': data['target_beta'],
-         'enable_fast_timestep': True
      }
      with torch.no_grad():
          target_q_value = self._target_model.forward(next_inputs)['logit']
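These three input dicts implement the r2d2-style slicing described in the q_learning.py comments above: a burn-in prefix to warm up the hidden state, a main segment, and an n-step-shifted target segment. A short numeric sketch of that slicing (sizes are illustrative):

import torch

burnin_step, nstep, unroll_len = 2, 3, 10  # illustrative hyperparameters
obs = torch.randn(burnin_step + unroll_len + nstep, 4, 64)  # (T, B, N)
bs = burnin_step
burnin_nstep_obs = obs[:bs + nstep]  # warms up the RNN hidden state
main_obs = obs[bs:-nstep]            # states whose Q-values are learned
target_obs = obs[bs + nstep:]        # n-step-shifted inputs for the target net
assert main_obs.shape == target_obs.shape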
(3 more changed files not shown)
