feature(zjow): add middleware for ape-x structure pipeline (#696)
* Support priority collected in the collector; add PeriodicalModelExchanger middleware

* polish code
zjowowen authored Aug 11, 2023
1 parent 49fc489 commit d905ca8
Showing 11 changed files with 318 additions and 6 deletions.
5 changes: 4 additions & 1 deletion ding/data/buffer/middleware/priority.py
@@ -54,7 +54,10 @@ def __init__(

     def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData:
         if meta is None:
-            meta = {'priority': self.max_priority}
+            if 'priority' in data:
+                meta = {'priority': data.pop('priority')}
+            else:
+                meta = {'priority': self.max_priority}
         else:
             if 'priority' not in meta:
                 meta['priority'] = self.max_priority
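With this change, a `priority` key attached to collected data takes precedence over the buffer's `max_priority` default. A minimal sketch of the resulting flow (assuming the `DequeBuffer` and `PriorityExperienceReplay` import paths used by this repo's tests; the data dict is a hypothetical transition):

```python
from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import PriorityExperienceReplay

buffer = DequeBuffer(size=10)
buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))

# The middleware pops 'priority' out of the data dict and stores it in meta.
buffer.push({'obs': [0.0], 'action': 1, 'priority': 2.0})
assert buffer.sample(size=1)[0].meta['priority'] == 2.0
```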
32 changes: 32 additions & 0 deletions ding/data/buffer/tests/test_middleware.py
@@ -110,6 +110,38 @@ def test_priority():
    assert buffer.count() == 0


@pytest.mark.unittest
def test_priority_from_collector():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(tmp_data)
    assert buffer.count() == N
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(tmp_data)
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    for item in data:
        item_data, index, meta = item.data, item.index, item.meta
        buffer.update(index, item_data, meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0


@pytest.mark.unittest
def test_padding():
    buffer = DequeBuffer(size=10)
2 changes: 1 addition & 1 deletion ding/framework/middleware/__init__.py
@@ -2,5 +2,5 @@
 from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
 from .learner import OffPolicyLearner, HERLearner
 from .ckpt_handler import CkptSaver
-from .distributer import ContextExchanger, ModelExchanger
+from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
 from .barrier import Barrier, BarrierRuntime
126 changes: 126 additions & 0 deletions ding/framework/middleware/distributer.py
@@ -1,3 +1,4 @@
+import numpy as np
 from time import sleep, time
 from dataclasses import fields
 from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
@@ -287,3 +288,128 @@ def _send_callback(self, storage: Storage):
    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()


class PeriodicalModelExchanger:

    def __init__(
            self,
            model: "Module",
            mode: str,
            period: int = 1,
            delay_toleration: float = np.inf,
            stale_toleration: int = 1,
            event_name: str = "model_exchanger",
            model_loader: Optional[ModelLoader] = None
    ) -> None:
        """
        Overview:
            Exchange models between processes periodically. Set ``mode`` to "send" or "receive" to specify
            the role of the process. If you are using a shared model on a single host, there is no need to
            use this middleware.
        Arguments:
            - model (:obj:`torch.nn.Module`): PyTorch module.
            - mode (:obj:`str`): "send" or "receive".
            - period (:obj:`int`): The period (in iterations) of model exchange.
            - delay_toleration (:obj:`float`): The maximum permitted time interval between a model being sent \
                and being used by the receiver.
            - stale_toleration (:obj:`int`): The maximum number of iterations the receiver may keep using a \
                stale model before waiting for a new one.
            - event_name (:obj:`str`): The event name used for model exchange.
            - model_loader (:obj:`ModelLoader`): The ModelLoader for this PeriodicalModelExchanger to use.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = event_name
        self._period = period
        self._mode = mode
        if self._mode == "receive":
            self._id_counter = -1
            self._model_id = -1
            self._time = time()  # ensure _time exists before the first message is cached
        else:
            self._id_counter = 0
        self._stale_toleration = stale_toleration
        self._model_stale = stale_toleration
        self._delay_toleration = delay_toleration
        self._state_dict_cache: Optional[Union[object, Storage]] = None

        if self._mode == "receive":
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, msg: Dict[str, Any]):
        if msg['id'] % self._period == 0:
            self._state_dict_cache = msg['model']
            self._id_counter = msg['id']
            self._time = msg['time']

    def __new__(cls, *args, **kwargs):
        return super(PeriodicalModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if self._mode == "receive":
            if ctx.total_step != 0:  # Skip first iteration
                self._update_model()
        elif self._mode == "send":
            yield
            if self._id_counter % self._period == 0:
                self._send_model(id=self._id_counter)
            self._id_counter += 1
        else:
            raise NotImplementedError

    def _update_model(self):
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                self._model_stale += 1
                break
            if self._state_dict_cache is None:
                if self._model_stale < self._stale_toleration and time() - self._time < self._delay_toleration:
                    self._model_stale += 1
                    break
                else:
                    sleep(0.01)
            else:
                if self._id_counter > self._model_id and time() - self._time < self._delay_toleration:
                    if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                        try:
                            self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                            self._state_dict_cache = None
                            self._model_id = self._id_counter
                            self._model_stale = 1
                            break
                        except FileNotFoundError:
                            logging.warning(
                                "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
                                    task.router.node_id
                                )
                            )
                            self._state_dict_cache = None
                            continue
                    else:
                        self._model.load_state_dict(self._state_dict_cache)
                        self._state_dict_cache = None
                        self._model_id = self._id_counter
                        self._model_stale = 1
                        break
                else:
                    self._model_stale += 1

    def _send_model(self, id: int):
        if self._model_loader:
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True)

    def _send_callback(self, storage: Storage):
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()
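A minimal wiring sketch for the two roles (assumes it is launched under `Parallel.runner` as in the test below; `model` is a placeholder module):

```python
import torch
from ding.framework import task
from ding.framework.middleware import PeriodicalModelExchanger

model = torch.nn.Linear(4, 2)  # placeholder for the policy network

with task.start():
    if task.router.node_id == 0:  # learner: broadcast every 3rd model version
        task.use(PeriodicalModelExchanger(model, mode="send", period=3))
    else:  # collector: accept every broadcast, reuse a stale model for at most 3 iterations
        task.use(PeriodicalModelExchanger(model, mode="receive", period=1, stale_toleration=3))
    # ... task.use(train) / task.use(pred) middleware here, then:
    task.run(8)
```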
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/__init__.py
@@ -11,5 +11,5 @@
 from .explorer import eps_greedy_handler, eps_greedy_masker
 from .advantage_estimator import gae_estimator, ppof_adv_estimator
 from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer
-
+from .priority import priority_calculator
 from .timer import epoch_timer
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/enhancer.py
@@ -73,7 +73,7 @@ def _fetch_and_enhance(ctx: "OnlineRLContext"):

 def nstep_reward_enhancer(cfg: EasyDict) -> Callable:

-    if task.router.is_active and not task.has_role(task.role.LEARNER):
+    if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
         return task.void()

     def _enhance(ctx: "OnlineRLContext"):
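With collectors no longer voided here, an Ape-X style collector can compute n-step rewards locally and derive priorities before pushing data. A hypothetical collector pipeline order (middleware names are from this repo; the composition and the lambda adapter are illustrative, not a fixed API):

```python
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(nstep_reward_enhancer(cfg))  # now also active on COLLECTOR processes
task.use(priority_calculator(lambda t: policy.calculate_priority(t)['priority']))
```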
24 changes: 24 additions & 0 deletions ding/framework/middleware/functional/priority.py
@@ -0,0 +1,24 @@
from typing import TYPE_CHECKING, Callable
from ding.framework import task
if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


def priority_calculator(priority_calculation_fn: Callable) -> Callable:
    """
    Overview:
        The middleware that calculates the priority of the collected data.
    Arguments:
        - priority_calculation_fn (:obj:`Callable`): The function that calculates the priority of the collected data.
    """

    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:

        priority = priority_calculation_fn(ctx.trajectories)
        for i in range(len(priority)):
            ctx.trajectories[i]['priority'] = priority[i]

    return _priority_calculator
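A runnable sketch of the middleware's mechanics, mirroring the unit test below (`random_priority` is a stand-in for a real priority function such as the TD-error-based one added to `ding/policy/dqn.py` in this commit):

```python
import numpy as np
from ding.framework import OnlineRLContext
from ding.framework.middleware.functional import priority_calculator

def random_priority(trajectories):
    # Stand-in priority function; returns one priority per trajectory.
    return np.random.rand(len(trajectories))

ctx = OnlineRLContext()
ctx.trajectories = [{'obs': np.random.rand(2), 'reward': np.random.rand(1)} for _ in range(4)]
priority_calculator(priority_calculation_fn=random_priority)(ctx)
assert all('priority' in traj for traj in ctx.trajectories)
```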
2 changes: 1 addition & 1 deletion ding/framework/middleware/learner.py
@@ -43,7 +43,7 @@ def __init__(
"""
self.cfg = cfg
self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
self._trainer = task.wrap(trainer(cfg, policy))
self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
if reward_model is not None:
self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
else:
48 changes: 47 additions & 1 deletion ding/framework/middleware/tests/test_distributer.py
@@ -9,7 +9,7 @@
 from ding.data.storage_loader import FileStorageLoader
 from ding.framework import task
 from ding.framework.context import OnlineRLContext
-from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger
+from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
 from ding.framework.parallel import Parallel
 from ding.utils.default_helper import set_pkg_seed
 from os import path
@@ -221,3 +221,49 @@ def pred(ctx):
@pytest.mark.tmp
def test_model_exchanger_with_model_loader():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader)


def periodical_model_exchanger_main():
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
            task.use(PeriodicalModelExchanger(policy._model, mode="send", period=3))
        else:
            task.add_role(task.role.COLLECTOR)
            task.use(PeriodicalModelExchanger(policy._model, mode="receive", period=1, stale_toleration=3))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)
            print("y_pred1: ", y_pred1)
            stale = 1

            def pred(ctx):
                nonlocal stale
                y_pred2 = policy.predict(X)
                print("y_pred2: ", y_pred2)
                stale += 1
                assert stale <= 3 or all(y_pred1 == y_pred2)
                if any(y_pred1 != y_pred2):
                    stale = 1

                sleep(0.3)

            task.use(pred)
        task.run(8)


@pytest.mark.tmp
def test_periodical_model_exchanger():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main)
33 changes: 33 additions & 0 deletions ding/framework/middleware/tests/test_priority.py
@@ -0,0 +1,33 @@
# Unit test for priority_calculator

import unittest
import pytest
import numpy as np
from unittest.mock import Mock, patch
from ding.framework import OnlineRLContext, OfflineRLContext
from ding.framework import task, Parallel
from ding.framework.middleware.functional import priority_calculator


class MockPolicy(Mock):

    def priority_fun(self, data):
        return np.random.rand(len(data))


@pytest.mark.unittest
def test_priority_calculator():
    policy = MockPolicy()
    ctx = OnlineRLContext()
    ctx.trajectories = [
        {
            'obs': np.random.rand(2, 2),
            'next_obs': np.random.rand(2, 2),
            'reward': np.random.rand(1),
            'info': {}
        } for _ in range(10)
    ]
    priority_calculator_middleware = priority_calculator(priority_calculation_fn=policy.priority_fun)
    priority_calculator_middleware(ctx)
    assert len(ctx.trajectories) == 10
    assert all([isinstance(traj['priority'], float) for traj in ctx.trajectories])
48 changes: 48 additions & 0 deletions ding/policy/dqn.py
@@ -418,6 +418,54 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def calculate_priority(self, data: Dict[str, Any], update_target_model: bool = False) -> Dict[str, Any]:
        """
        Overview:
            Calculate priority for replay buffer.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training.
            - update_target_model (:obj:`bool`): Whether to sync the target model with the latest learn model \
                before calculating priority.
        Returns:
            - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
            - optional: ``value_gamma``
        ReturnsKeys:
            - necessary: ``priority``
        """

        if update_target_model:
            self._target_model.load_state_dict(self._learn_model.state_dict())

        data = default_preprocess_learn(
            data,
            use_priority=False,
            use_priority_IS_weight=False,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.eval()
        self._target_model.eval()
        with torch.no_grad():
            # Current q value (main model)
            q_value = self._learn_model.forward(data['obs'])['logit']
            # Target q value
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model), i.e. Double DQN
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
            data_n = q_nstep_td_data(
                q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
            )
            value_gamma = data.get('value_gamma')
            loss, td_error_per_sample = q_nstep_td_error(
                data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
            )
        return {'priority': td_error_per_sample.abs().tolist()}
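A hedged usage sketch of the new method (the batch below is illustrative only; shapes assume a discrete-action DQN with `nstep=3`, and `policy` is assumed to be an initialized `DQNPolicy`):

```python
import torch

# Illustrative fake batch; real data would come from a collector with n-step rewards attached.
batch = [
    {
        'obs': torch.rand(4),
        'next_obs': torch.rand(4),
        'action': torch.tensor(0),
        'reward': torch.rand(3),  # one reward per n-step
        'done': False,
    } for _ in range(8)
]
priorities = policy.calculate_priority(batch)['priority']
for sample, p in zip(batch, priorities):
    sample['priority'] = p  # later popped into meta by PriorityExperienceReplay.push
```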


@POLICY_REGISTRY.register('dqn_stdim')
class DQNSTDIMPolicy(DQNPolicy):
