fix(nyz): fix unittest bugs
PaParaZz1 committed Aug 21, 2023
1 parent 6e139b6 commit 7451f46
Showing 5 changed files with 26 additions and 30 deletions.
@@ -15,16 +15,16 @@

@pytest.mark.unittest
def test_serial_pipeline_trex_onpolicy():
exp_name = 'test_serial_pipeline_trex_onpolicy_expert'
exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
config[0].exp_name = exp_name
expert_policy = serial_pipeline_onpolicy(config, seed=0)

exp_name = 'test_serial_pipeline_trex_onpolicy_collect'
exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_collect'
config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
config[0].exp_name = exp_name
config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_onpolicy_expert'
config[0].reward_model.expert_model_path = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
config[0].reward_model.checkpoint_max = 100
config[0].reward_model.checkpoint_step = 100
config[0].reward_model.num_snippets = 100
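Context for the rename above: the expert run writes its checkpoints under its exp_name directory, and the TREX collect config must point reward_model.expert_model_path at that same directory, so both names change together; the 'trex_onpolicy_' prefix keeps this test's output directories from colliding with other tests that used the shorter names. A minimal sketch of keeping the two settings in sync (plain dicts standing in for the real DI-engine configs, not part of the diff):

# sketch only: the assumed convention is that the expert pipeline saves under ./<exp_name>/
EXPERT_EXP = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
COLLECT_EXP = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_collect'

expert_cfg = {'exp_name': EXPERT_EXP}
collect_cfg = {
    'exp_name': COLLECT_EXP,
    # must match the expert run's exp_name, where its checkpoints are assumed to live
    'reward_model': {'expert_model_path': EXPERT_EXP},
}
assert collect_cfg['reward_model']['expert_model_path'] == expert_cfg['exp_name']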
2 changes: 1 addition & 1 deletion ding/envs/env/tests/test_ding_env_wrapper.py
@@ -181,7 +181,7 @@ def test_hybrid(self):
print('random_action', action)
assert isinstance(action, dict)

@pytest.mark.unittest
@pytest.mark.envtest
def test_AllinObsWrapper(self):
env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs')
ding_env_aio = DingEnvWrapper(cfg=env_cfg)
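The marker change above (and the similar one in test_parallel.py below) moves this test out of the default unittest selection: test_AllinObsWrapper needs the Atari Pong environment, so it now only runs when the envtest marker is selected. A hedged sketch of how such markers are typically registered and selected with pytest (this conftest.py is illustrative, not necessarily DI-engine's actual setup):

# conftest.py (illustrative)
def pytest_configure(config):
    # register the custom markers so `pytest --strict-markers` accepts them
    for name, doc in [
        ('unittest', 'fast unit tests run in the default CI job'),
        ('envtest', 'tests that need external environments, e.g. Atari'),
        ('tmp', 'tests temporarily excluded from the default selection'),
    ]:
        config.addinivalue_line('markers', '{}: {}'.format(name, doc))

# selection examples:
#   pytest -m unittest   -> skips tests now marked envtest / tmp
#   pytest -m envtest    -> runs only the environment-dependent tests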
2 changes: 1 addition & 1 deletion ding/framework/tests/test_parallel.py
@@ -24,7 +24,7 @@ def test_callback(key):
time.sleep(0.7)


@pytest.mark.unittest
@pytest.mark.tmp
def test_parallel_run():
Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main)
9 changes: 3 additions & 6 deletions ding/model/template/tests/test_decision_transformer.py
@@ -4,7 +4,7 @@
import torch.nn.functional as F

from ding.model.template import DecisionTransformer
from ding.torch_utils import is_differentiable, one_hot
from ding.torch_utils import is_differentiable

args = ['continuous', 'discrete']

@@ -23,6 +23,7 @@ def test_decision_transformer(action_space):
context_len=T,
n_heads=2,
drop_p=0.1,
continuous=(action_space == 'continuous')
)

is_continuous = True if action_space == 'continuous' else False
@@ -40,15 +41,11 @@
# all ones since no padding
traj_mask = torch.ones([B, T], dtype=torch.long) # B x T

# if discrete
if not is_continuous:
actions = one_hot(actions.squeeze(-1), num=act_dim)

assert actions.shape == (B, T, act_dim)
if is_continuous:
assert action_target.shape == (B, T, act_dim)
else:
assert action_target.shape == (B, T, 1)
actions = actions.squeeze(-1)

returns_to_go = returns_to_go.float()
state_preds, action_preds, return_preds = DT_model.forward(
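The test change above drops the manual one-hot encoding: discrete actions are now passed as integer indices (squeezed to B x T), and the model is told via the new continuous flag which format to expect. The difference between the two input formats, as a small self-contained sketch (the hidden size is arbitrary; how DecisionTransformer embeds actions internally is not shown in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, act_dim, hidden = 4, 6, 3, 8
idx = torch.randint(0, act_dim, (B, T))                 # integer action indices, B x T

# old test input: one-hot vectors of shape B x T x act_dim, projected with a Linear layer
one_hot = F.one_hot(idx, num_classes=act_dim).float()
h_from_one_hot = nn.Linear(act_dim, hidden)(one_hot)    # B x T x hidden

# new test input: the raw indices, looked up in an embedding table
h_from_index = nn.Embedding(act_dim, hidden)(idx)       # B x T x hidden

assert h_from_one_hot.shape == h_from_index.shape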
37 changes: 18 additions & 19 deletions ding/policy/dt.py
@@ -69,8 +69,10 @@ def _init_learn(self) -> None:
self.act_dim = self._cfg.model.act_dim

self._learn_model = self._model
self._atari_env = 'state_mean' not in self._cfg
self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg

if 'state_mean' not in self._cfg:
if self._atari_env:
self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
else:
self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)
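The two flags introduced above are computed once from the config and reused throughout the policy: configs without a state_mean entry are treated as Atari-style, and configs with state_mean plus a discrete action space as the basic (classic-control) case. The optimizer branch that follows can be read as roughly this standalone sketch (simplified signature, not the actual DI-engine method):

import torch

def build_dt_optimizer(model, cfg, lr, wt_decay):
    # Atari-style configs carry no 'state_mean' and use the model's own
    # configure_optimizers helper; other configs fall back to plain AdamW.
    atari_env = 'state_mean' not in cfg
    if atari_env:
        return model.configure_optimizers(wt_decay, lr)
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)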
@@ -93,22 +95,18 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
self._learn_model.train()

timesteps, states, actions, returns_to_go, traj_mask = data
if actions.dtype is not torch.long:
actions = actions.to(torch.long)
action_target = torch.clone(actions).detach().to(self._device)

# The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1),
# and we need a 3-dim tensor
if len(returns_to_go.shape) == 2:
returns_to_go = returns_to_go.unsqueeze(-1)

# if discrete
if not self._cfg.model.continuous and 'state_mean' in self._cfg:
# actions = one_hot(actions.squeeze(-1), num=self.act_dim)
if self._basic_discrete_env:
actions = actions.to(torch.long)
actions = actions.squeeze(-1)
action_target = torch.clone(actions).detach().to(self._device)

if 'state_mean' not in self._cfg:
if self._atari_env:
state_preds, action_preds, return_preds = self._learn_model.forward(
timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
)
Expand All @@ -117,7 +115,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
)

if 'state_mean' not in self._cfg:
if self._atari_env:
action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
else:
traj_mask = traj_mask.view(-1, )
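In the Atari branch above, the action logits are flattened so that F.cross_entropy sees (B*T, n_actions) predictions against (B*T,) integer targets; the non-Atari branch (collapsed in this view) starts by flattening traj_mask, presumably to ignore padded timesteps. The reshape can be checked in isolation with a small sketch (shapes are arbitrary):

import torch
import torch.nn.functional as F

B, T, n_act = 4, 20, 6
action_preds = torch.randn(B, T, n_act)                  # logits from the model
action_target = torch.randint(0, n_act, (B, T, 1))       # long indices, as in the clone above

loss = F.cross_entropy(
    action_preds.reshape(-1, action_preds.size(-1)),     # (B*T, n_act)
    action_target.reshape(-1),                           # (B*T,)
)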
@@ -171,7 +169,9 @@ def _init_eval(self) -> None:
self.actions = torch.zeros(
(self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
)
if 'state_mean' not in self._cfg:
self._atari_env = 'state_mean' not in self._cfg
self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
if self._atari_env:
self.states = torch.zeros(
(
self.eval_batch_size,
@@ -201,7 +201,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:

self._eval_model.eval()
with torch.no_grad():
if 'state_mean' not in self._cfg:
if self._atari_env:
states = torch.zeros(
(
self.eval_batch_size,
@@ -228,15 +228,15 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
(self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
)
for i in data_id:
if 'state_mean' not in self._cfg:
if self._atari_env:
self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
else:
self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device)
self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]

if self.t[i] <= self.context_len:
if 'state_mean' not in self._cfg:
if self._atari_env:
timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
(1, 1), dtype=torch.int64
).to(self._device)
@@ -246,7 +246,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
actions[i] = self.actions[i, :self.context_len]
rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
else:
if 'state_mean' not in self._cfg:
if self._atari_env:
timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
(1, 1), dtype=torch.int64
).to(self._device)
@@ -255,15 +255,14 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
if not self._cfg.model.continuous and 'state_mean' in self._cfg:
# actions = one_hot(actions.squeeze(-1), num=self.act_dim)
if self._basic_discrete_env:
actions = actions.squeeze(-1)
_, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
del timesteps, states, actions, rewards_to_go

logits = act_preds[:, -1, :]
if not self._cfg.model.continuous:
if 'state_mean' not in self._cfg:
if self._atari_env:
probs = F.softmax(logits, dim=-1)
act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
for i in data_id:
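At the end of the eval forward pass above, the discrete/Atari branch converts the last-step logits into per-environment probabilities; the loop body that actually picks each action is collapsed right above. One common way this is done, as a hedged sketch only (the real implementation may sample or take an argmax):

import torch
import torch.nn.functional as F

eval_batch_size, n_act = 5, 6
logits = torch.randn(eval_batch_size, n_act)        # stands in for act_preds[:, -1, :]

probs = F.softmax(logits, dim=-1)
act = torch.zeros((eval_batch_size, 1), dtype=torch.long)
for i in range(eval_batch_size):
    act[i] = torch.multinomial(probs[i], num_samples=1)   # sampled action index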
@@ -297,7 +296,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
dtype=torch.float32,
device=self._device
)
if 'state_mean' not in self._cfg:
if self._atari_env:
self.states = torch.zeros(
(
self.eval_batch_size,
@@ -327,7 +326,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
self.actions[i] = torch.zeros(
(self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
)
if 'state_mean' not in self._cfg:
if self._atari_env:
self.states[i] = torch.zeros(
(self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
)
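The _init_eval and _reset_eval changes reuse the same _atari_env flag to decide the shape of the per-environment state buffers: in the Atari case state_dim is a shape tuple (an image), while in the basic case it is a single vector length. A minimal shape sketch (the concrete dimensions below are illustrative, not taken from the configs):

import torch

eval_batch_size, max_eval_ep_len = 5, 1000

# Atari-style: state_dim is a tuple such as (4, 84, 84), concatenated onto the buffer shape
state_dim_img = (4, 84, 84)
states_img = torch.zeros((eval_batch_size, max_eval_ep_len) + state_dim_img)

# basic case: state_dim is a single integer, e.g. 4 for CartPole
state_dim_vec = 4
states_vec = torch.zeros((eval_batch_size, max_eval_ep_len, state_dim_vec))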
