feature(nyz): add new pipeline impala demo
PaParaZz1 committed Jul 19, 2023
1 parent 4f4e98a commit e33c958
Showing 5 changed files with 65 additions and 31 deletions.
2 changes: 2 additions & 0 deletions ding/config/config.py
@@ -452,6 +452,8 @@ def compile_config(
if len(world_model_config) > 0:
default_config['world_model'] = world_model_config
cfg = deep_merge_dicts(default_config, cfg)
if 'unroll_len' in cfg.policy:
cfg.policy.collect.unroll_len = cfg.policy.unroll_len
cfg.seed = seed
# check important key in config
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
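As an aside, a minimal sketch of what this new override does, assuming cfg behaves like an EasyDict as elsewhere in DI-engine (the concrete values below are made up for illustration, not taken from the commit):

from easydict import EasyDict

# a top-level policy.unroll_len now takes precedence over the nested collect.unroll_len
cfg = EasyDict(dict(policy=dict(unroll_len=8, collect=dict(unroll_len=32, n_sample=16))))
if 'unroll_len' in cfg.policy:
    cfg.policy.collect.unroll_len = cfg.policy.unroll_len
assert cfg.policy.collect.unroll_len == 8

This lets a config declare the trajectory length once at the policy level, as the CartPole IMPALA config at the end of this diff now does, instead of repeating it under learn and collect.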
44 changes: 44 additions & 0 deletions ding/example/impala.py
@@ -0,0 +1,44 @@
import gym
from ditk import logging
from ding.model import VAC
from ding.policy import IMPALAPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
CkptSaver
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_impala_config import main_config, create_config


def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
cfg=cfg.env.manager
)
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = VAC(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = IMPALAPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(data_pusher(cfg, buffer_, group_by_env=True))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.run()


if __name__ == "__main__":
main()
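Assuming DI-engine and gym's CartPole dependencies are installed, the demo above is a plain script and should run directly, e.g. with python ding/example/impala.py. For comparison, the same config pair is conventionally runnable through the older serial entry point as well; the snippet below is a sketch based on the existing DI-engine API and is not part of this commit:

from ding.entry import serial_pipeline
from dizoo.classic_control.cartpole.config.cartpole_impala_config import main_config, create_config

if __name__ == "__main__":
    # old-style serial pipeline for the same CartPole IMPALA config
    serial_pipeline((main_config, create_config), seed=0)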
4 changes: 2 additions & 2 deletions ding/policy/edac.py
@@ -36,8 +36,8 @@ class EDACPolicy(SACPolicy):
| ``embedding_size`` | network. |
5 | ``model.soft_q_`` int 256 | Linear layer size for soft q |
| ``embedding_size`` | network. |
6 | ``model.emsemble_`` int 10 | Number of Q-ensemble network |
| ``num`` | |
6 | ``model.ensemble`` int 10 | Number of Q-ensemble network |
| ``_num`` | |
| | | is False.
7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when
| ``_rate_q`` | network. | model.value_network
40 changes: 15 additions & 25 deletions ding/policy/impala.py
@@ -2,12 +2,13 @@
from typing import List, Dict, Any, Tuple

import torch
import treetensor.torch as ttorch

from ding.model import model_wrap
from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action, get_train_sample
from ding.torch_utils import Adam, RMSprop, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from ding.utils.data import default_collate, default_decollate, ttorch_collate
from ding.policy.base_policy import Policy


@@ -55,7 +56,6 @@ class IMPALAPolicy(Policy):
# (bool) Whether to need policy data in process transition
transition_with_policy_data=True,
learn=dict(

# (int) collect n_sample data, train model update_per_collect times
# here we follow ppo serial pipeline
update_per_collect=4,
@@ -158,7 +158,13 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
- done (:obj:`torch.FloatTensor`): :math:`(T, B)`
- weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
"""
data = default_collate(data)
elem = data[0]
if isinstance(elem, dict): # old pipeline
data = default_collate(data)
elif isinstance(elem, list): # new task pipeline
data = default_collate(default_collate(data))
else:
raise TypeError("not support element type ({}) in IMPALA".format(type(elem)))
if self._cuda:
data = to_device(data, self._device)
if self._priority_IS_weight:
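As a rough sketch of the two incoming layouts this branch distinguishes (the nesting is inferred from the collate calls above and from data_pusher(..., group_by_env=True) in the new demo; field names and values are placeholders, not taken from the commit):

# Hypothetical layouts for illustration only.
# Old serial pipeline: each sampled item is one unroll segment, a dict whose
# fields hold per-timestep values, so the batch is a list of dicts.
old_batch = [
    {"obs": list(range(8)), "done": [False] * 8},  # segment from sample 1
    {"obs": list(range(8)), "done": [False] * 8},  # segment from sample 2
]
# New task pipeline: each sampled item is a list of per-timestep transition
# dicts, so the batch is a list of lists.
new_batch = [
    [{"obs": t, "done": False} for t in range(8)],  # trajectory from env 0
    [{"obs": t, "done": False} for t in range(8)],  # trajectory from env 1
]
# The type of the first element (dict vs list) is what the branch keys on.
assert isinstance(old_batch[0], dict) and isinstance(new_batch[0], list)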
@@ -167,27 +173,11 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
data['weight'] = data['IS']
else:
data['weight'] = data.get('weight', None)
data['obs_plus_1'] = torch.cat((data['obs'] + data['next_obs'][-1:]), dim=0) # shape (T+1)*B,env_obs_shape
if self._action_space == 'continuous':
data['logit']['mu'] = torch.cat(
data['logit']['mu'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
data['logit']['sigma'] = torch.cat(
data['logit']['sigma'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
data['action'] = torch.cat(
data['action'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
elif self._action_space == 'discrete':
data['logit'] = torch.cat(
data['logit'], dim=0
).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1) # shape T,B,
data['done'] = torch.cat(data['done'], dim=0).reshape(self._unroll_len, -1).float() # shape T,B,
data['reward'] = torch.cat(data['reward'], dim=0).reshape(self._unroll_len, -1) # shape T,B,
data['weight'] = torch.cat(
data['weight'], dim=0
).reshape(self._unroll_len, -1) if data['weight'] else None # shape T,B
if isinstance(elem, dict): # old pipeline
for k in data:
if isinstance(data[k], list):
data[k] = default_collate(data[k])
data['obs_plus_1'] = torch.cat([data['obs'], data['next_obs'][-1:]], dim=0) # shape (T+1)*B,env_obs_shape
return data

def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
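A toy shape check for the obs_plus_1 construction above; the (T, B, obs_shape) layout is an assumption for illustration, not taken from the commit:

import torch

T, B, obs_shape = 4, 2, 3
obs = torch.randn(T, B, obs_shape)       # stacked observations of one unroll
next_obs = torch.randn(T, B, obs_shape)  # observations shifted by one step
# appending the final next_obs yields T + 1 time steps, so the value network
# can also be evaluated on the bootstrap state needed by the v-trace target
obs_plus_1 = torch.cat([obs, next_obs[-1:]], dim=0)
assert obs_plus_1.shape == (T + 1, B, obs_shape)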
@@ -276,7 +266,7 @@ def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple[A
actions = data['action'] # shape T,B for discrete # shape T,B,env_action_shape for continuous
values = output['value'].reshape(self._unroll_len + 1, -1) # shape T+1,B,env_action_shape
rewards = data['reward'] # shape T,B
weights_ = 1 - data['done'] # shape T,B
weights_ = 1 - data['done'].float() # shape T,B
weights = torch.ones_like(rewards) # shape T,B
values[1:] = values[1:] * weights_
weights[1:] = weights_[:-1]
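A small numeric illustration of the masking above (toy tensors, not from the commit): weights_ zeroes the bootstrap value that follows a terminal step, and weights shifts that mask by one step so the transition after a terminal state carries no loss weight.

import torch

T, B = 4, 1
done = torch.tensor([[0.], [0.], [1.], [0.]])     # episode ends at t = 2
rewards = torch.ones(T, B)
values = torch.arange(1., T + 2.).view(T + 1, B)  # toy values for t = 0..T

weights_ = 1 - done                     # [[1], [1], [0], [1]]
weights = torch.ones_like(rewards)
values[1:] = values[1:] * weights_      # value after the terminal step is zeroed
weights[1:] = weights_[:-1]             # the step following it gets weight 0
assert values[3, 0] == 0 and weights[3, 0] == 0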
6 changes: 2 additions & 4 deletions dizoo/classic_control/cartpole/config/cartpole_impala_config.py
@@ -15,6 +15,8 @@
action_shape=2,
encoder_hidden_size_list=[64, 64],
),
# (int) the trajectory length to calculate v-trace target
unroll_len=8,
learn=dict(
# (int) collect n_sample data, train model update_per_collect times
# here we follow ppo serial pipeline
@@ -30,8 +32,6 @@
discount_factor=0.9,
# (float) additional discounting parameter
lambda_=0.95,
# (int) the trajectory length to calculate v-trace target
unroll_len=32,
# (float) clip ratio of importance weights
rho_clip_ratio=1.0,
# (float) clip ratio of importance weights
@@ -42,8 +42,6 @@
collect=dict(
# (int) collect n_sample data, train model n_iteration times
n_sample=16,
# (int) the trajectory length to calculate v-trace target
unroll_len=32,
# (float) discount factor for future reward, defaults int [0, 1]
discount_factor=0.9,
gae_lambda=0.95,
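For reference, the v-trace-related fields in this config (discount_factor, lambda_, rho_clip_ratio, c_clip_ratio) map onto the v-trace target from the IMPALA paper (Espeholt et al., 2018); the formula below is quoted from that paper, not from this diff, with gamma = discount_factor, lambda = lambda_, rho-bar = rho_clip_ratio and c-bar = c_clip_ratio:

v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \Big( \prod_{i=s}^{t-1} c_i \Big) \delta_t V,
\qquad \delta_t V = \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big),
\qquad \rho_t = \min\Big( \bar{\rho}, \tfrac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)} \Big),
\qquad c_i = \lambda \min\Big( \bar{c}, \tfrac{\pi(a_i \mid x_i)}{\mu(a_i \mid x_i)} \Big).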
