From 6b96f124ae316838fc3fb032040d03232ce5d61f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 26 Apr 2020 15:11:20 +0800 Subject: [PATCH] fix pdqn --- .github/ISSUE_TEMPLATE.md | 8 ++--- .github/PULL_REQUEST_TEMPLATE.md | 8 ++--- README.md | 2 +- docs/index.rst | 2 +- test/base/test_buffer.py | 28 +++++++++++++++- test/base/test_prioritized_replay_buffer.py | 37 --------------------- tianshou/__init__.py | 2 +- tianshou/policy/modelfree/dqn.py | 11 +++--- 8 files changed, 45 insertions(+), 53 deletions(-) delete mode 100644 test/base/test_prioritized_replay_buffer.py diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 0d6b9e44a..fdc290884 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -4,13 +4,13 @@ + [ ] documentation request (i.e. "X is missing from the documentation.") + [ ] new feature request - [ ] I have visited the [source website], and in particular read the [known issues] -- [ ] I have searched through the [issue tracker] for duplicates +- [ ] I have searched through the [issue categories] for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python - import tianshou, sys - print(tianshou.__version__, sys.version, sys.platform) + import tianshou, torch, sys + print(tianshou.__version__, torch.__version__, sys.version, sys.platform) ``` [source website]: https://github.com/thu-ml/tianshou/ [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues - [issue tracker]: https://github.com/thu-ml/tianshou/projects/2 + [issue categories]: https://github.com/thu-ml/tianshou/projects/2 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5e9db0c18..3fef4b14a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,13 +8,13 @@ Less important but also useful: - [ ] I have visited the [source website], and in particular read the [known issues] -- [ ] I have searched through the [issue tracker] for duplicates +- [ ] I have searched through the [issue categories] for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python - import tianshou, sys - print(tianshou.__version__, sys.version, sys.platform) + import tianshou, torch, sys + print(tianshou.__version__, torch.__version__, sys.version, sys.platform) ``` [source website]: https://github.com/thu-ml/tianshou [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues - [issue tracker]: https://github.com/thu-ml/tianshou/projects/2 + [issue categories]: https://github.com/thu-ml/tianshou/projects/2 diff --git a/README.md b/README.md index fc10ed9a1..3d81e8e74 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf) - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns -- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)) +- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) diff --git a/docs/index.rst b/docs/index.rst index c355258c8..312ec79db 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,7 +11,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.PGPolicy` `Policy Gradient `_ * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network `_ * :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ with n-step returns -* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN `_ * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 5c34352a6..af6fbdc75 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,5 +1,5 @@ import numpy as np -from tianshou.data import ReplayBuffer +from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4): print(buf) +def test_priortized_replaybuffer(size=32, bufsize=15): + env = MyTestEnv(size) + buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) + obs = env.reset() + action_list = [1] * 5 + [0] * 10 + [1] * 10 + for i, a in enumerate(action_list): + obs_next, rew, done, info = env.step(a) + buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5) + obs = obs_next + assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]), + 1, rtol=1e-12) + data, indice = buf.sample(len(buf) // 2) + if len(buf) // 2 == 0: + assert len(data) == len(buf) + else: + assert len(data) == len(buf) // 2 + assert len(buf) == min(bufsize, i + 1), print(len(buf), i) + assert np.isclose(buf._weight_sum, (buf.weight).sum()) + data, indice = buf.sample(len(buf) // 2) + buf.update_weight(indice, -data.weight / 2) + assert np.isclose(buf.weight[indice], np.power( + np.abs(-data.weight / 2), buf._alpha)).all() + assert np.isclose(buf._weight_sum, (buf.weight).sum()) + + if __name__ == '__main__': test_replaybuffer() test_stack() + test_priortized_replaybuffer(233333, 200000) diff --git a/test/base/test_prioritized_replay_buffer.py b/test/base/test_prioritized_replay_buffer.py deleted file mode 100644 index cca4ace40..000000000 --- a/test/base/test_prioritized_replay_buffer.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np -from tianshou.data import PrioritizedReplayBuffer - -if __name__ == '__main__': - from env import MyTestEnv -else: # pytest - from test.base.env import MyTestEnv - - -def test_replaybuffer(size=32, bufsize=15): - env = MyTestEnv(size) - buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) - obs = env.reset() - action_list = [1] * 5 + [0] * 10 + [1] * 10 - for i, a in enumerate(action_list): - obs_next, rew, done, info = env.step(a) - buf.add(obs, a, rew, done, obs_next, info, np.random.randn()-0.5) - obs = obs_next - assert np.isclose(np.sum((buf.weight/buf._weight_sum)[:buf._size]), 1, - rtol=1e-12) - data, indice = buf.sample(len(buf) // 2) - if len(buf)//2 == 0: - assert len(data) == len(buf) - else: - assert len(data) == len(buf)//2 - assert len(buf) == min(bufsize, i + 1), print(len(buf), i) - assert np.isclose(buf._weight_sum, (buf.weight).sum()) - data, indice = buf.sample(len(buf) // 2) - buf.update_weight(indice, -data.weight/2) - assert np.isclose(buf.weight[indice], np.power( - np.abs(-data.weight/2), buf._alpha)).all() - assert np.isclose(buf._weight_sum, (buf.weight).sum()) - - -if __name__ == "__main__": - test_replaybuffer(233333, 200000) - print("pass") diff --git a/tianshou/__init__.py b/tianshou/__init__.py index f62a62df5..bef80b5cd 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, \ exploration -__version__ = '0.2.1' +__version__ = '0.2.2' __all__ = [ 'env', 'data', diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index d17b57b53..8150bf035 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -104,12 +104,15 @@ def process_fn(self, batch, buffer, indice): r = batch.returns if isinstance(r, np.ndarray): r = torch.tensor(r, device=q.device, dtype=q.dtype) - td = r-q - buffer.update_weight(indice, td.detach().numpy()) + td = r - q + buffer.update_weight(indice, td.detach().cpu().numpy()) impt_weight = torch.tensor(batch.impt_weight, device=q.device, dtype=torch.float) - loss = (td.pow(2)*impt_weight).mean() - batch.loss = loss + loss = (td.pow(2) * impt_weight).mean() + if not hasattr(batch, 'loss'): + batch.loss = loss + else: + batch.loss += loss return batch def forward(self, batch, state=None,