Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: DroQ and TD3+TQC jax implementation #272

Draft
wants to merge 31 commits into
base: master
Choose a base branch
from

Conversation

araffin
Copy link

@araffin araffin commented Sep 16, 2022

Description

FYI: unpolished jax implementation of TD3+DroQ and TD3+TQC implementations.
Related to #262 #258
My plan is to try to have sac in jax, but currently jax rely on tensorflow for probability distributions :/
So I adapted TD3 instead.
I also want to make it even faster but would need to tweak a bit the way the replay buffer is used.

EDIT: apparently tfd doesn't depends on tf anymore for latest version: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX

Reference:

EDIT: SBX = SB3 + JAX: https://github.com/araffin/sbx

Known difference with original implementation: qf are updated at the same time of the actor instead of after each gradient step.

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
    • I have updated the overview sections at the docs and the repo
  • I have updated the tests accordingly (if applicable).

@vercel
Copy link

vercel bot commented Sep 16, 2022

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name Status Preview Updated
cleanrl ✅ Ready (Inspect) Visit Preview Sep 24, 2022 at 4:50PM (UTC)

This reverts commit d5704b3.
Copy link
Owner

@vwxyzjn vwxyzjn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀 how does Adan perform?

@araffin
Copy link
Author

araffin commented Sep 18, 2022

eyes how does Adan perform?

Results are very preliminary, ADAN performs on par or slightly better than ADAM, but nothing significant yet.
The noticeable difference is the FPS though (adan slower, for instance 100 FPS vs 130 FPS).
Btw, I managed to JIT the for loop =) it goes 2x faster now but results are different than without jit 👀 (not worse/better, just different)

@vwxyzjn
Copy link
Owner

vwxyzjn commented Sep 19, 2022

FYI https://github.com/deepmind/distrax might be a better replacement for tensorflow probability

@joaogui1
Copy link
Collaborator

fwiw you can also use tensorflow_probability with a jax backend and then you don't need to use tensorflow at all (in one of their tutorials they even explicitly unninstall tf)

@araffin
Copy link
Author

araffin commented Sep 23, 2022

@vwxyzjn Good news, I've got a TQC + SAC version working =) (currently doing some runs)

@joaogui1 thanks, I gave distrax a try but it was giving me weird errors, and at the end it still depends on tf proba (which doesn't require tensorflow as I learned =)), so I switched to tf proba ;)

@araffin
Copy link
Author

araffin commented Sep 29, 2022

Fyi, I converted that single file to a proof of concept of SB3 + Jax (SBX): https://github.com/araffin/sbx
the nice thing is that I'm re-using SB3 base class, which means it has access to saving/loading/scikit interface/callbacks and soon the RL zoo =)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants