Skip to content

Commit

Permalink
fix(nyz): fix offline mem data fetcher unittest bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
PaParaZz1 committed Aug 21, 2023
1 parent cacab2e commit 6e139b6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def test_discrete_dt():
from ding.model.template.dt import DecisionTransformer
from ding.policy import DTPolicy
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
- offline_data_fetcher_from_mem_c, offline_logger, termination_checker
+ OfflineMemoryDataFetcher, offline_logger, termination_checker
config = compile_config(config[0], create_cfg=config[1], auto=True)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
Expand All @@ -675,7 +675,7 @@ def test_discrete_dt():

task.use(termination_checker(max_train_iter=1))
task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env))
- task.use(offline_data_fetcher_from_mem_c(config, dataset))
+ task.use(OfflineMemoryDataFetcher(config, dataset))
task.use(trainer(config, policy.learn_mode))
task.use(CkptSaver(policy, config.exp_name, train_freq=100))
task.use(offline_logger(config.exp_name))
Expand Down
2 changes: 1 addition & 1 deletion ding/framework/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
from .barrier import Barrier, BarrierRuntime
- from .data_fetcher import offline_data_fetcher_from_mem_c
+ from .data_fetcher import OfflineMemoryDataFetcher
11 changes: 5 additions & 6 deletions ding/framework/middleware/data_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@
from threading import Thread, Event
from queue import Queue
import time
import numpy as np
import torch
- import torch.distributed as dist
from easydict import EasyDict
from ding.framework import task
from ding.data import Dataset, DataLoader
- from ding.utils import get_rank
import numpy as np
+ from ding.utils import get_rank, get_world_size

if TYPE_CHECKING:
from ding.framework import OfflineRLContext


- class offline_data_fetcher_from_mem_c:
+ class OfflineMemoryDataFetcher:

def __new__(cls, *args, **kwargs):
if task.router.is_active and not task.has_role(task.role.FETCHER):
return task.void()
- return super(offline_data_fetcher_from_mem_c, cls).__new__(cls)
+ return super(OfflineMemoryDataFetcher, cls).__new__(cls)

def __init__(self, cfg: EasyDict, dataset: Dataset):
device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
Expand All @@ -30,7 +29,7 @@ def producer(queue, dataset, batch_size, device, event):
torch.set_num_threads(4)
if device != 'cpu':
nonlocal stream
- sbatch_size = batch_size * dist.get_world_size()
+ sbatch_size = batch_size * get_world_size()
rank = get_rank()
idx_list = np.random.permutation(len(dataset))
temp_idx_list = []
Expand Down
5 changes: 3 additions & 2 deletions dizoo/atari/entry/atari_dt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
- from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, offline_data_fetcher_from_mem_c, offline_data_fetcher
+ from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, \
+ OfflineMemoryDataFetcher
from ding.utils import set_pkg_seed, DDPContext, to_ddp_config
from dizoo.atari.envs import AtariEnv
from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config
Expand Down Expand Up @@ -43,7 +44,7 @@ def main():
policy = DTPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
- task.use(offline_data_fetcher_from_mem_c(cfg, dataset))
+ task.use(OfflineMemoryDataFetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
task.use(termination_checker(max_train_iter=3e4))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
Expand Down

0 comments on commit 6e139b6

Please sign in to comment.