-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(lyd): refactor dt_policy in new pipeline and add img input …
…support (#693) * Revise old version dt pipeline * Add new dt pipeline * Add DT in new pipeline * Add img input to support atari * Fix according to comment * Fix dt config files * Fix abs path * Accelerate DT train iter by replacing dataloader * Simplify dt model and policy and config * reformat * Reformat * Change data fetcher func to class * Add threading shift data to gpu * Change action sample func * Add configure optimizers * Add multi gpu support * Add dt policy test serial * Fix multi gpu support and data fetcher * Reformat
- Loading branch information
Showing
109 changed files
with
2,217 additions
and
1,445 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import gym | ||
from ditk import logging | ||
from ding.model.template.decision_transformer import DecisionTransformer | ||
from ding.policy import DTPolicy | ||
from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2 | ||
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper | ||
from ding.data import create_dataset | ||
from ding.config import compile_config | ||
from ding.framework import task, ding_init | ||
from ding.framework.context import OfflineRLContext | ||
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \ | ||
offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver | ||
from ding.utils import set_pkg_seed | ||
from dizoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv | ||
from dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config | ||
|
||
|
||
def main():
    """
    Overview:
        Entry point for offline Decision Transformer training on LunarLander
        using the DI-engine task/middleware pipeline: evaluate, fetch offline
        data, train, checkpoint, and log until the iteration limit is reached.
    """
    # If you don't have offline data, you need to prepare it first and set the data_path in config.
    # For demonstration, we also can train a RL policy (e.g. SAC) and collect some data.
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OfflineRLContext()):
        # Evaluation envs; AllinObsWrapper packs reward/done into the observation
        # as required by the Decision Transformer input format.
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        # NOTE: dataset statistics must be written into the config BEFORE the
        # model/policy are built, since they consume these values.
        dataset = create_dataset(cfg)
        cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
        model = DecisionTransformer(**cfg.policy.model)
        policy = DTPolicy(cfg.policy, model=model)

        # Middleware registration order defines execution order per iteration.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(offline_data_fetcher(cfg, dataset))
        task.use(trainer(cfg, policy.learn_mode))
        task.use(termination_checker(max_train_iter=1e5))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(offline_logger())
        task.run()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import TYPE_CHECKING | ||
from threading import Thread, Event | ||
from queue import Queue | ||
import time | ||
import torch | ||
import torch.distributed as dist | ||
from easydict import EasyDict | ||
from ding.framework import task | ||
from ding.data import Dataset, DataLoader | ||
from ding.utils import get_rank | ||
import numpy as np | ||
|
||
if TYPE_CHECKING: | ||
from ding.framework import OfflineRLContext | ||
|
||
|
||
class offline_data_fetcher_from_mem_c:
    """
    Overview:
        Offline data-fetcher middleware that pre-assembles training batches on
        a background producer thread (through a dedicated CUDA stream when
        ``cfg.policy.cuda`` is set) so the training loop never blocks on
        dataset indexing or host-to-device transfer.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode, only processes holding the FETCHER role run
        # this middleware; others get a no-op.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(offline_data_fetcher_from_mem_c, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, dataset: Dataset) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): config; reads ``policy.cuda`` and ``policy.batch_size``.
            - dataset (:obj:`Dataset`): offline dataset supporting ``__len__`` and ``__getitem__``,
              where each item is a tuple/list of tensors with a common per-field shape.
        """
        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        if device != 'cpu':
            stream = torch.cuda.Stream()

        def producer(queue, dataset, batch_size, device, event):
            torch.set_num_threads(4)
            # NOTE(review): dist.get_world_size() is called even on the CPU path,
            # so the process group is assumed to be initialized — confirm.
            sbatch_size = batch_size * dist.get_world_size()
            rank = get_rank()
            idx_list = np.random.permutation(len(dataset))
            temp_idx_list = []
            for i in range(len(dataset) // sbatch_size):
                # BUGFIX: each super-batch starts at i * sbatch_size; the original
                # offset of plain ``i`` produced overlapping slices and left most
                # of the permutation unused.
                start = i * sbatch_size + rank * batch_size
                temp_idx_list.extend(idx_list[start:start + batch_size])
            idx_iter = iter(temp_idx_list)

            def fill_queue():
                # Shared producer loop for both CPU and CUDA paths.
                nonlocal idx_iter
                while True:
                    # BUGFIX: check the shutdown event every iteration (the original
                    # only checked after a successful put, so a producer stuck on a
                    # full queue could never exit).
                    if event.is_set():
                        break
                    if queue.full():
                        time.sleep(0.1)
                        continue
                    data = []
                    for _ in range(batch_size):
                        try:
                            data.append(dataset[next(idx_iter)])
                        except StopIteration:
                            # Rank-local index list exhausted: reshuffle the whole
                            # dataset and keep going (matches original behavior).
                            idx_iter = iter(np.random.permutation(len(dataset)))
                            data.append(dataset[next(idx_iter)])
                    # Transpose list-of-samples into list-of-fields, then stack
                    # each field into a batch tensor.
                    data = [[sample[j] for sample in data] for j in range(len(data[0]))]
                    if device != 'cpu':
                        data = [torch.stack(x).to(device) for x in data]
                    else:
                        data = [torch.stack(x) for x in data]
                    queue.put(data)

            if device != 'cpu':
                # Issue transfers on a side stream so they overlap with compute
                # on the default stream.
                with torch.cuda.stream(stream):
                    fill_queue()
            else:
                fill_queue()

        self.queue = Queue(maxsize=50)
        self.event = Event()
        self.producer_thread = Thread(
            target=producer,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
            name='cuda_fetcher_producer'
        )

    def __call__(self, ctx: "OfflineRLContext") -> None:
        """
        Overview:
            Pop the next pre-built batch into ``ctx.train_data``, lazily
            starting the producer thread on first use.
        """
        if not self.producer_thread.is_alive():
            # NOTE(review): presumably gives distributed setup time to settle
            # before the producer calls dist.get_world_size() — confirm.
            time.sleep(5)
            self.producer_thread.start()
        # Busy-wait briefly until the producer has filled at least one batch.
        while self.queue.empty():
            time.sleep(0.001)
        ctx.train_data = self.queue.get()

    def __del__(self):
        # Signal the producer to stop and give it a bounded chance to exit
        # (bounded join so interpreter shutdown cannot hang).
        if self.producer_thread.is_alive():
            self.event.set()
            self.producer_thread.join(timeout=1.0)
        del self.queue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.