Skip to content

Commit

Permalink
fix(nyz): fix unittest bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Oct 31, 2023
1 parent 111bf24 commit 439680a
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ding/framework/middleware/tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_wandb_online_logger_gradient():
test_wandb_online_logger_gradient()


@pytest.mark.unittest
@pytest.mark.tmp
def test_wandb_offline_logger():
record_path = './video_pendulum_cql'
cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=True))
Expand Down
4 changes: 2 additions & 2 deletions ding/policy/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy inputs some training batch data from the replay buffer and then returns the output \
that the policy inputs some training batch data from the offline dataset and then returns the output \
result, including various training information such as loss, action, priority.
Arguments:
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
Expand Down Expand Up @@ -578,7 +578,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy inputs some training batch data from the replay buffer and then returns the output \
that the policy inputs some training batch data from the offline dataset and then returns the output \
result, including various training information such as loss, action, priority.
Arguments:
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
Expand Down
2 changes: 2 additions & 0 deletions ding/policy/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,8 @@ def _init_collect(self) -> None:
self._nstep = self._cfg.nstep
self._nstep_return = self._cfg.nstep_return
self._value_norm = self._cfg.learn.value_norm
if self._value_norm:
self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)

def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
"""
Expand Down
12 changes: 2 additions & 10 deletions ding/policy/tests/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
]

dtype_test = [
"int32",
"int64",
"float32",
"float64",
]

data_type_test = [
Expand All @@ -29,16 +27,10 @@ def get_action(shape, dtype, class_type):
if class_type == "numpy":
return np.random.randn(*shape).astype(dtype)
else:
if dtype == "int32":
dtype = torch.int32
elif dtype == "int64":
if dtype == "int64":
dtype = torch.int64
elif dtype == "float16":
dtype = torch.float16
elif dtype == "float32":
dtype = torch.float32
elif dtype == "float64":
dtype = torch.float64

if class_type == "torch":
return torch.randn(*shape).type(dtype)
Expand Down Expand Up @@ -72,7 +64,7 @@ def test_default_preprocess_learn_action():
data = default_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, ignore_done)

assert data['obs'].shape == torch.Size([10, 4, 84, 84])
if dtype in ["int32", "int64"] and shape[0] == 1:
if dtype in ["int64"] and shape[0] == 1:
assert data['action'].shape == torch.Size([10])
else:
assert data['action'].shape == torch.Size([10, *shape])
Expand Down
10 changes: 6 additions & 4 deletions ding/torch_utils/network/tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ class TestDiffusionNet:
def test_DiffusionNet1d(self):
diffusion = DiffusionUNet1d(transition_dim, dim, dim_mults)
input = torch.rand(batch_size, horizon, transition_dim)
t = torch.randint(0, 10, (batch_size, )).long()
output = diffusion(input, time=t)
t = torch.randint(0, horizon, (batch_size, )).long()
cond = {t: torch.randn(batch_size, 2) for t in range(horizon)}
output = diffusion(input, cond, time=t)
assert output.shape == (batch_size, horizon, transition_dim)

def test_TemporalValue(self):
value = TemporalValue(horizon, transition_dim, dim, dim_mults=dim_mults)
input = torch.rand(batch_size, horizon, transition_dim)
t = torch.randint(0, 10, (batch_size, )).long()
output = value(input, time=t)
t = torch.randint(0, horizon, (batch_size, )).long()
cond = {t: torch.randn(batch_size, 2) for t in range(horizon)}
output = value(input, cond, time=t)
assert output.shape == (batch_size, 1)

0 comments on commit 439680a

Please sign in to comment.