fix pdqn

Trinkle23897 committed Apr 26, 2020
1 parent b237494 commit 6b96f12
Showing 8 changed files with 45 additions and 53 deletions.
8 changes: 4 additions & 4 deletions .github/ISSUE_TEMPLATE.md
@@ -4,13 +4,13 @@
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
- [ ] I have visited the [source website], and in particular read the [known issues]
- - [ ] I have searched through the [issue tracker] for duplicates
+ - [ ] I have searched through the [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
- import tianshou, sys
- print(tianshou.__version__, sys.version, sys.platform)
+ import tianshou, torch, sys
+ print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```

[source website]: https://github.com/thu-ml/tianshou/
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
- [issue tracker]: https://github.com/thu-ml/tianshou/projects/2
+ [issue categories]: https://github.com/thu-ml/tianshou/projects/2
8 changes: 4 additions & 4 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -8,13 +8,13 @@
Less important but also useful:

- [ ] I have visited the [source website], and in particular read the [known issues]
- - [ ] I have searched through the [issue tracker] for duplicates
+ - [ ] I have searched through the [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
- import tianshou, sys
- print(tianshou.__version__, sys.version, sys.platform)
+ import tianshou, torch, sys
+ print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```

[source website]: https://github.com/thu-ml/tianshou
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
- [issue tracker]: https://github.com/thu-ml/tianshou/projects/2
+ [issue categories]: https://github.com/thu-ml/tianshou/projects/2
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
- - [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf))
+ - [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
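For context on the Prioritized DQN entry this commit fixes: it refers to prioritized experience replay from the linked Schaul et al. paper, where a transition's priority is |TD error|^alpha and the sampling bias is corrected with importance weights (N * P(i))^(-beta). The snippet below is only a numeric sketch of that scheme, not Tianshou code; the alpha, beta, and TD-error values are made up.

```python
# Numeric sketch of prioritized experience replay (Schaul et al., linked above):
# sample transitions with probability proportional to priority^alpha and correct
# the induced bias with importance weights (N * P(i))^(-beta).
import numpy as np

td_errors = np.array([0.1, 0.5, 2.0, 0.05])       # |TD error| per stored transition (made up)
alpha, beta = 0.5, 0.5                             # arbitrary example values

priority = np.abs(td_errors) ** alpha
prob = priority / priority.sum()                   # sampling distribution P(i)
impt_weight = (len(td_errors) * prob) ** (-beta)   # importance-sampling correction
impt_weight /= impt_weight.max()                   # normalize for stability (common practice)

idx = np.random.choice(len(td_errors), size=2, p=prob)
print(prob, impt_weight[idx])
```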
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -11,7 +11,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
- * :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf`_
+ * :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
28 changes: 27 additions & 1 deletion test/base/test_buffer.py
@@ -1,5 +1,5 @@
import numpy as np
- from tianshou.data import ReplayBuffer
+ from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer

if __name__ == '__main__':
    from env import MyTestEnv
@@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4):
    print(buf)


def test_priortized_replaybuffer(size=32, bufsize=15):
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
                          1, rtol=1e-12)
        data, indice = buf.sample(len(buf) // 2)
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
        assert np.isclose(buf._weight_sum, (buf.weight).sum())
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    assert np.isclose(buf.weight[indice], np.power(
        np.abs(-data.weight / 2), buf._alpha)).all()
    assert np.isclose(buf._weight_sum, (buf.weight).sum())


if __name__ == '__main__':
    test_replaybuffer()
    test_stack()
    test_priortized_replaybuffer(233333, 200000)
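Stripped of its asserts, the new test exercises a small buffer API. Below is a minimal standalone sketch of that API, assuming the tianshou 0.2.2 signatures used in the test above (positional add(obs, act, rew, done, obs_next, info, weight), sample(), update_weight()); the dummy integer observations are only for illustration.

```python
# Minimal usage sketch of PrioritizedReplayBuffer, mirroring the calls in the
# new test above; not a test, just the API surface.
import numpy as np
from tianshou.data import PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(16, 0.5, 0.5)   # (buffer size, alpha, beta), as in the test
for t in range(20):
    # the last positional argument is the transition's initial priority weight
    buf.add(t, 0, 1.0, False, t + 1, {}, 1.0)

batch, indice = buf.sample(8)                 # sampling follows the stored priorities
buf.update_weight(indice, np.abs(np.random.randn(len(indice))))  # e.g. refresh with new TD errors
print(len(buf), batch.weight)
```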
37 changes: 0 additions & 37 deletions test/base/test_prioritized_replay_buffer.py

This file was deleted.

2 changes: 1 addition & 1 deletion tianshou/__init__.py
@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, \
    exploration

- __version__ = '0.2.1'
+ __version__ = '0.2.2'
__all__ = [
    'env',
    'data',
11 changes: 7 additions & 4 deletions tianshou/policy/modelfree/dqn.py
@@ -104,12 +104,15 @@ def process_fn(self, batch, buffer, indice):
            r = batch.returns
            if isinstance(r, np.ndarray):
                r = torch.tensor(r, device=q.device, dtype=q.dtype)
-             td = r-q
-             buffer.update_weight(indice, td.detach().numpy())
+             td = r - q
+             buffer.update_weight(indice, td.detach().cpu().numpy())
            impt_weight = torch.tensor(batch.impt_weight,
                                       device=q.device, dtype=torch.float)
-             loss = (td.pow(2)*impt_weight).mean()
-             batch.loss = loss
+             loss = (td.pow(2) * impt_weight).mean()
+             if not hasattr(batch, 'loss'):
+                 batch.loss = loss
+             else:
+                 batch.loss += loss
        return batch

def forward(self, batch, state=None,
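The dqn.py change above moves the detached TD error to the CPU before handing it back to the buffer as the new priority (stored as |TD error|^alpha, which the new test asserts) and accumulates the importance-weighted squared TD error into batch.loss. Below is a self-contained numeric sketch of just that arithmetic, with made-up Q-values, returns, and importance weights; it is not Tianshou's actual module.

```python
# Sketch of the prioritized-DQN loss arithmetic performed by the changed lines:
# weight the squared TD error by the importance-sampling weight, and detach the
# TD error (moved to CPU) as the new buffer priority.
import torch

q = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # predicted Q(s, a) for the sampled batch
r = torch.tensor([1.5, 1.0, 3.5])                       # (n-step) returns used as targets
impt_weight = torch.tensor([0.8, 1.0, 0.6])             # importance-sampling weights from the buffer

td = r - q                                              # per-transition TD error
new_priority = td.detach().cpu().numpy()                # what buffer.update_weight() receives
loss = (td.pow(2) * impt_weight).mean()                 # importance-weighted MSE loss
loss.backward()                                         # gradients flow through q only
print(new_priority, loss.item())
```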
