Skip to content

Commit

Permalink
fix(nyz): fix drex collecting data unittest bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Aug 31, 2023
1 parent 23fac67 commit efa59b2
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 337 deletions.
1 change: 0 additions & 1 deletion ding/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import serial_pipeline_preference_based_irl
from .serial_entry_preference_based_irl_onpolicy \
import serial_pipeline_preference_based_irl_onpolicy
- from .application_entry_drex_collect_data import drex_collecting_data
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc import serial_pipeline_pc
257 changes: 0 additions & 257 deletions ding/entry/application_entry_drex_collect_data.py

This file was deleted.

74 changes: 0 additions & 74 deletions ding/entry/tests/test_application_entry_drex_collect_data.py

This file was deleted.

5 changes: 3 additions & 2 deletions ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def test_discrete_dt():
config = [deepcopy(cartpole_discrete_dt_config), deepcopy(cartpole_discrete_dt_create_config)]
config[0].policy.eval.evaluator.eval_freq = 5
try:
- from ding.framework import task
+ from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.envs import SubprocessEnvManagerV2, BaseEnvManagerV2
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
Expand All @@ -659,6 +659,7 @@ def test_discrete_dt():
from ding.policy import DTPolicy
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
OfflineMemoryDataFetcher, offline_logger, termination_checker
+ ding_init(config[0])
config = compile_config(config[0], create_cfg=config[1], auto=True)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
Expand All @@ -678,7 +679,7 @@ def test_discrete_dt():
task.use(OfflineMemoryDataFetcher(config, dataset))
task.use(trainer(config, policy.learn_mode))
task.use(CkptSaver(policy, config.exp_name, train_freq=100))
- task.use(offline_logger(config.exp_name))
+ task.use(offline_logger())
task.run()
except Exception:
assert False, "pipeline fail"
Expand Down
2 changes: 1 addition & 1 deletion dizoo/atari/entry/atari_dt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main():
task.use(trainer(cfg, policy.learn_mode))
task.use(termination_checker(max_train_iter=3e4))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
- task.use(offline_logger(cfg.exp_name))
+ task.use(offline_logger())
task.run()


Expand Down
4 changes: 2 additions & 2 deletions dizoo/d4rl/entry/d4rl_dt_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def main():
# For demonstration, we can also train an RL policy (e.g. SAC) and collect some data
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
- # ding_init(cfg)
+ ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
Expand All @@ -42,7 +42,7 @@ def main():
task.use(trainer(cfg, policy.learn_mode))
task.use(termination_checker(max_train_iter=1e5))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
- task.use(offline_logger(cfg.exp_name))
+ task.use(offline_logger())
task.run()


Expand Down

0 comments on commit efa59b2

Please sign in to comment.