Skip to content

Commit

Permalink
fix(nyz): fix dt dataset compatibility bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Aug 20, 2023
1 parent 08c42fa commit 0be9555
Show file tree
Hide file tree
Showing 12 changed files with 23 additions and 397 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- [ACE](https://github.com/opendilab/ACE): [AAAI 2023] ACE: Cooperative Multi-agent Q-learning with Bidirectional Action-Dependency
- [GoBigger](https://github.com/opendilab/GoBigger): [ICLR 2023] Multi-Agent Decision Intelligence Environment
- [DOS](https://github.com/opendilab/DOS): [CVPR 2023] ReasonNet: End-to-End Driving with Temporal and Global Reasoning
- [LightZero](https://github.com/opendilab/LightZero): LightZero: A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
- [LightZero](https://github.com/opendilab/LightZero): A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
- Docs and Tutorials
- [DI-engine-docs](https://github.com/opendilab/DI-engine-docs): Tutorials, best practice and the API reference.
- [awesome-model-based-RL](https://github.com/opendilab/awesome-model-based-RL): A curated list of awesome Model-Based RL resources
Expand Down
Binary file added assets/wechat.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
88 changes: 0 additions & 88 deletions ding/entry/serial_entry_decision_transformer.py

This file was deleted.

2 changes: 1 addition & 1 deletion ding/example/dt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gym
from ditk import logging
from ding.model.template.decision_transformer import DecisionTransformer
from ding.model.template.dt import DecisionTransformer
from ding.policy import DTPolicy
from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
Expand Down
22 changes: 12 additions & 10 deletions ding/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,18 @@ def __len__(self) -> int:
return len(self._data['obs']) - self.context_len

def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
# return {k: self._data[k][idx] for k in self._data.keys()}
block_size = self.context_len
done_idx = idx + block_size
idx = done_idx - block_size
states = torch.as_tensor(np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32).view(block_size, -1)
actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
traj_mask = torch.ones(self.context_len, dtype=torch.long)
return timesteps, states, actions, rtgs, traj_mask
if self.context_len == 0: # for other offline RL algorithms
return {k: self._data[k][idx] for k in self._data.keys()}
else: # for decision transformer
block_size = self.context_len
done_idx = idx + block_size
idx = done_idx - block_size
states = torch.as_tensor(np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32).view(block_size, -1)
actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
traj_mask = torch.ones(self.context_len, dtype=torch.long)
return timesteps, states, actions, rtgs, traj_mask

def _load_data(self, dataset: Dict[str, np.ndarray]) -> None:
self._data = {}
Expand Down
10 changes: 8 additions & 2 deletions ding/utils/pytorch_ddp_dist_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:

def to_ddp_config(cfg: EasyDict) -> EasyDict:
w = get_world_size()
cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
# cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample) / w)
if 'batch_size' in cfg.policy:
cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
if 'batch_size' in cfg.policy.learn:
cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
if 'n_sample' in cfg.policy.collect:
cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
if 'n_episode' in cfg.policy.collect:
cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / w))
return cfg
5 changes: 0 additions & 5 deletions dizoo/atari/config/serial/pong/pong_dt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,3 @@
)
Pong_dt_create_config = EasyDict(Pong_dt_create_config)
create_config = Pong_dt_create_config

if __name__ == "__main__":
from ding.entry import serial_pipeline_dt
config = deepcopy([main_config, create_config])
serial_pipeline_dt(config, seed=0, max_train_iter=1000)
7 changes: 1 addition & 6 deletions dizoo/box2d/lunarlander/config/lunarlander_dt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
learn=dict(
dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO
learning_rate=3e-4,
batch_size=64, # training batch size
batch_size=64, # training batch size
target_update_freq=100,
),
collect=dict(
Expand All @@ -62,8 +62,3 @@
)
lunarlander_dt_create_config = EasyDict(lunarlander_dt_create_config)
create_config = lunarlander_dt_create_config

if __name__ == "__main__":
from ding.entry import serial_pipeline_dt, collect_demo_data, eval, serial_pipeline
config = deepcopy([main_config, create_config])
serial_pipeline_dt(config, seed=0, max_train_iter=1000)
65 changes: 0 additions & 65 deletions dizoo/classic_control/cartpole/config/cartpole_dt_config.py

This file was deleted.

This file was deleted.

50 changes: 0 additions & 50 deletions dizoo/classic_control/cartpole/offline_data/cartpole_show_data.py

This file was deleted.

Loading

0 comments on commit 0be9555

Please sign in to comment.