Skip to content

Commit

Permalink
fix(nyz): fix mlp dropout if condition bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Aug 28, 2023
1 parent 0968250 commit d6f3020
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion ding/policy/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DQNPolicy(Policy):
8 | ``model.dueling`` bool True | dueling head architecture
9 | ``model.encoder`` list [32, 64, | Sequence of ``hidden_size`` of | default kernel_size
| ``_hidden`` (int) 64, 128] | subsequent conv layers and the | is [8, 4, 3]
| ``_size_list`` | final dense layer. | default stride is
| ``_size_list`` | final dense layer. | default stride is
| [4, 2, 1]
10 | ``model.dropout`` float None | Dropout rate for dropout layers. | [0,1]
| If set to ``None``
Expand Down
11 changes: 5 additions & 6 deletions ding/policy/mbpolicy/mbsac.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from typing import Dict, Any, List
from functools import partial
import copy

import torch
from torch import Tensor
from torch import nn
from torch.distributions import Normal, Independent, TransformedDistribution, TanhTransform
from easydict import EasyDict
from torch.distributions import Normal, Independent

from ding.torch_utils import to_device, fold_batch, unfold_batch, unsqueeze_repeat
from ding.utils import POLICY_REGISTRY, deep_merge_dicts
from ding.utils import POLICY_REGISTRY
from ding.policy import SACPolicy
from ding.rl_utils import generalized_lambda_returns
from ding.policy.common_utils import default_preprocess_learn
Expand All @@ -33,11 +31,12 @@ class MBSACPolicy(SACPolicy):
== ==================== ======== ============= ==================================
1 ``learn._lambda`` float 0.8 | Lambda for TD-lambda return.
2 ``learn.grad_clip`` float 100.0 | Max norm of gradients.
3 ``learn.sample_`` bool True | Whether to sample states or tra-
``state`` | nsitions from env buffer.
3 | ``learn.sample`` bool True | Whether to sample states or
| ``_state`` | transitions from env buffer.
== ==================== ======== ============= ==================================
.. note::
For other configs, please refer to ding.policy.sac.SACPolicy.
"""

Expand Down
2 changes: 1 addition & 1 deletion ding/torch_utils/network/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def MLP(
# The last layer uses the same activation as front layers.
if activation is not None:
block.append(activation)
if use_dropout is not None:
if use_dropout:
block.append(nn.Dropout(dropout_probability))

if last_linear_layer_init_zero:
Expand Down

0 comments on commit d6f3020

Please sign in to comment.