Fix step function to not reset every step when using auto-reset #36

Open
wants to merge 1 commit into
base: main
from

Conversation

ALjone
Contributor

@ALjone ALjone commented Nov 8, 2024

Currently, when auto-reset is enabled, reset_env is called on every single step, and its result is discarded unless the episode is done:

if self.auto_reset:
    obs_re, state_re = self.reset_env(key_reset, params)
    # Use lax.cond to efficiently choose between obs_re and obs_st
    obs = jax.lax.cond(
        done,
        lambda: obs_re,
        lambda: obs_st
    )
    state = jax.lax.cond(
        done,
        lambda: state_re,
        lambda: state_st
    )

This is fairly expensive. It can be avoided by moving the call to the reset function inside the jax.lax.cond branch, so that it only runs when done is true. This saves around 504 calls to the reset function per game:

if self.auto_reset:
    # Reset the env only if done to avoid generating new state every step
    obs, state = jax.lax.cond(
        done,
        lambda: self.reset_env(key_reset, params),
        lambda: (obs_st, state_st),
    )

I'm not a JAX expert, but as far as I can tell, the above example should work.
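For anyone who wants to check the pattern in isolation, here is a minimal, self-contained sketch. It is not the project's environment: MAX_STEPS, reset_env, step_env, and step are hypothetical stand-ins (a counter that ends the episode after three steps), used only to show that calling the reset inside the jax.lax.cond branch returns the reset state exactly when done is true.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment: the state is a step counter and the
# episode ends once the counter reaches MAX_STEPS.
MAX_STEPS = 3


def reset_env(key):
    # Stand-in for an expensive reset: fresh random observation, counter at 0.
    state = jnp.zeros((), dtype=jnp.int32)
    obs = jax.random.uniform(key, (4,))
    return obs, state


def step_env(state):
    # Advance the counter and report whether the episode has finished.
    new_state = state + 1
    obs = jnp.full((4,), new_state, dtype=jnp.float32)
    done = new_state >= MAX_STEPS
    return obs, new_state, done


@jax.jit
def step(key_reset, state):
    obs_st, state_st, done = step_env(state)
    # reset_env is traced inside the true branch, so it is only
    # evaluated when done is true.
    obs, new_state = jax.lax.cond(
        done,
        lambda: reset_env(key_reset),
        lambda: (obs_st, state_st),
    )
    return obs, new_state, done


key = jax.random.PRNGKey(0)
state = jnp.zeros((), dtype=jnp.int32)
states = []
for _ in range(4):
    key, key_reset = jax.random.split(key)
    obs, state, done = step(key_reset, state)
    states.append(int(state))

print(states)  # counter returns to 0 on the step where done is true
```

One caveat from the JAX documentation: when a cond with a batched predicate is transformed by vmap, it is converted to a select that evaluates both branches, so the saving may not carry over unchanged to vmapped environments.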

With this change, I observe an increase of more than 30% in steps per second.
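A rough way to reproduce this kind of measurement is sketched below. It is an assumption-laden harness, not the benchmark used for the number above: expensive_reset, step_eager, step_lazy, and steps_per_second are hypothetical names, and the reset is simulated with a matrix product. The two step functions mirror the before and after versions of the auto-reset logic.

```python
import time

import jax
import jax.numpy as jnp


def expensive_reset(key):
    # Hypothetical stand-in for a costly environment reset.
    x = jax.random.normal(key, (256, 256))
    return (x @ x.T).sum()


def step_eager(key, x, done):
    # Before: the reset is computed on every step, then selected or discarded.
    fresh = expensive_reset(key)
    return jax.lax.cond(done, lambda: fresh, lambda: x + 1.0)


def step_lazy(key, x, done):
    # After: the reset only runs inside the branch taken when done is true.
    return jax.lax.cond(done, lambda: expensive_reset(key), lambda: x + 1.0)


def steps_per_second(step_fn, n_steps=200):
    step = jax.jit(step_fn)
    key = jax.random.PRNGKey(0)
    x = jnp.zeros(())
    done = jnp.asarray(False)
    # Warm-up call so compilation time is excluded from the measurement.
    step(key, x, done).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_steps):
        x = step(key, x, done)
    # Wait for asynchronous dispatch to finish before stopping the clock.
    x.block_until_ready()
    return n_steps / (time.perf_counter() - start)


if __name__ == "__main__":
    print("eager reset:", steps_per_second(step_eager))
    print("lazy reset: ", steps_per_second(step_lazy))
```

The actual ratio depends on hardware, on how costly the real reset is, and on whether the step is vmapped, so the figures printed here should not be expected to match the 30% reported above.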


netlify bot commented Nov 8, 2024

Deploy Preview for lux-eye-s3 canceled.

Name Link
🔨 Latest commit 2368823
🔍 Latest deploy log https://app.netlify.com/sites/lux-eye-s3/deploys/672e43c4bbbc480008feee87

@StoneT2000
Member

I am surprised JAX can't optimize this part out. I'll take a look and benchmark as well.
