
Optimize the log of the entropy coeff instead of the entropy coeff #56

Merged: 6 commits into araffin:master on Nov 1, 2024

Conversation

jamesheald
Contributor

SB3 SAC optimizes the log of the entropy coeff instead of the entropy coeff (link), as discussed here. In contrast, SBX SAC optimizes the entropy coeff (link).

In my experience, this small change has a huge impact on the stability and performance of the algorithm (optimizing the log of the entropy coeff produces much better results than optimizing the entropy coeff).

@jan1854
Collaborator

jan1854 commented Oct 7, 2024

Hi, the SAC version in SBX already optimizes the log of the entropy coefficient. The EntropyCoef module contains a single parameter log_ent_coef, which is exponentiated when the module is called (see here). Therefore, ent_coef_value here is the exponential of the log parameter (which is optimized).
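
For reference, the module looks roughly like this (a minimal sketch in Flax rather than the exact SBX source; the ent_coef_init attribute name is an assumption):

import flax.linen as nn
import jax.numpy as jnp

class EntropyCoef(nn.Module):
    ent_coef_init: float = 1.0  # assumed name for the initial coefficient value

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        # The single trainable parameter is the log of the coefficient...
        log_ent_coef = self.param(
            "log_ent_coef",
            lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
        )
        # ...but the module returns its exponential, so callers receive ent_coef.
        return jnp.exp(log_ent_coef)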

@jamesheald
Contributor Author

jamesheald commented Oct 7, 2024

Hi @jan1854. I agree. Maybe what I should have said is that the losses are calculated differently in SB3 and SBX. In SB3, the loss is
loss = log(ent_coeff) * (entropy - target_entropy),
whereas in SBX, the loss is
loss = ent_coeff * (entropy - target_entropy).
The minimum of the loss is the same in both cases (so in that sense you can use either), but the gradients are different. Using the former loss can lead to better numerical stability, which is what I have found myself.
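
Concretely, a quick sketch (my own, with made-up numbers for entropy and target_entropy) showing that the gradients with respect to the log parameter differ by a factor of ent_coeff:

import jax
import jax.numpy as jnp

entropy, target_entropy = -5.0, -6.0  # made-up values for illustration

def sb3_style_loss(log_ent_coef):
    # loss = log(ent_coeff) * (entropy - target_entropy)
    return log_ent_coef * (entropy - target_entropy)

def sbx_style_loss(log_ent_coef):
    # loss = ent_coeff * (entropy - target_entropy)
    return jnp.exp(log_ent_coef) * (entropy - target_entropy)

log_ent_coef = jnp.log(0.1)
print(jax.grad(sb3_style_loss)(log_ent_coef))  # entropy - target_entropy = 1.0
print(jax.grad(sbx_style_loss)(log_ent_coef))  # ent_coeff * (entropy - target_entropy) = 0.1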

@jan1854
Collaborator

jan1854 commented Oct 7, 2024

I see, the loss is indeed different from SB3. I think it could be a bit cleaner to change EntropyCoef to directly return the log entropy coefficient (basically removing the jnp.exp()) and to exponentiate explicitly when updating the actor / critic. This would avoid accessing the parameters directly in temperature_loss (and it would be closer to the SB3 implementation). What do you think @araffin?
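
Something along these lines (a rough sketch of the idea, not a tested diff):

import flax.linen as nn
import jax.numpy as jnp

class EntropyCoef(nn.Module):
    ent_coef_init: float = 1.0

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        # Return the log value directly instead of exponentiating here;
        # temperature_loss can then use it as-is, and the actor / critic
        # updates apply jnp.exp() explicitly to recover ent_coef.
        return self.param(
            "log_ent_coef",
            lambda key: jnp.full((), jnp.log(self.ent_coef_init)),
        )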

@araffin
Owner

araffin commented Oct 7, 2024

@jamesheald thanks for finding that out. It is indeed a difference from SB3.

I have to say, I don't remember why I didn't directly store/return the ent coef (probably the intention was to use the log in the loss, but I forgot 🙈).
I think a backward compatible solution would be to store and return the ent coef directly (instead of using exp() all the time) and convert old models' log_ent_coef param automatically. What do you think @jan1854?

EDIT: the easiest option is just to take the log in the loss.

cc @danielpalen: this should not, but might, affect CrossQ results slightly (I don't expect much of an effect, judging from the Stable-Baselines-Team/stable-baselines3-contrib#243 results).

@jamesheald do you have access to some compute to compare against the current SBX master version? (I can grant you access to https://wandb.ai/openrlbenchmark/sbx if you want to push the runs too)

@jamesheald
Contributor Author

@araffin yes, I have access to compute. Do you just want me to perform a few runs of the modified SBX SAC on Walker2d-v4 using default parameters?

@araffin
Owner

araffin commented Oct 7, 2024

Do you just want me to perform a few of runs of the modified SBX SAC on Walker2d-v4 using default parameters?

I would do it on the 4 classic MuJoCo envs; see for instance https://wandb.ai/openrlbenchmark/sbx/reports/CrossQ-SBX-Perf-Report--Vmlldzo3MzQxOTAw

I'll also do some runs on my machine (I only have access to one machine currently) and prepare a report.

@araffin
Owner

araffin commented Oct 7, 2024

Update: I started the report here: https://wandb.ai/openrlbenchmark/sbx/reports/SAC-SB3-vs-SBX-Ent-coef-vs-log-ent-coef--Vmlldzo5NjI3NTQ5 and will start the runs soon (I'm using the RL Zoo, see the SBX readme).

My config looks like this:

default_hyperparams = dict(
    n_envs=1,
    n_timesteps=int(1e6),
    policy="MlpPolicy",
    # qf_learning_rate=1e-3,
    policy_kwargs={},
    learning_starts=10_000,
)

hyperparams = {}

for env_id in [
    "HalfCheetah-v4",
    "Humanoid-v4",
    "HalfCheetahBulletEnv-v0",
    "Ant-v4",
    "Hopper-v4",
    "Walker2d-v4",
    "Swimmer-v4",
]:
    hyperparams[env_id] = default_hyperparams

and the command line (I'm modifying the train freq/gradient steps for faster execution):

#!/bin/bash

# seeds=(3831217417 1756129809 4075310593 2904435568 4115729820 2726071262 1865059290 1408779145 3716099507 411100252)
# envs=(HalfCheetah-v4 Ant-v4 Hopper-v4 Walker2d-v4)
# note: for swimmer, we need to set gamma=0.999

# Note: we use n_envs=gradient_steps=30 to make things 10x faster thanks to JIT compilation
# the results are equivalent to n_envs=train_freq=gradient_steps=1 as we have a 1:1 ratio
# between gradient steps and collected data
for env_id in ${envs[*]}; do
  for seed in ${seeds[*]}; do
    OMP_NUM_THREADS=1 python train.py --algo sac --env $env_id --seed $seed \
      --eval-freq 25000 --verbose 0 --n-eval-envs 5 --eval-episodes 20 \
      -c hyperparams/sac.py -n 500000 \
      --log-interval 100 -param n_envs:30 gradient_steps:30 --vec-env subproc -P \
      --track --wandb-entity openrlbenchmark --wandb-project-name sbx -tags 0.18.0 ent-coef
  done
done
Launched with, for example:
seeds="3831217417 1756129809 4075310593" envs="HalfCheetah-v4 Ant-v4 Hopper-v4 Walker2d-v4" ./run_sac.sh

Note: for swimmer, you need to set gamma=0.999, see DLR-RM/rl-baselines3-zoo#447

@jan1854
Collaborator

jan1854 commented Oct 7, 2024

I think a backward compatible solution would be to store and return the ent coef directly (instead of using exp() all the time) and convert old models log_ent_coef param automatically. What do you think @jan1854 ?

EDIT: the easiest is just to take the log in the loss

For SAC, I think we would just need to call exp() here and here, and this solution should be backwards compatible (as the name of the parameter does not change). We can of course also just call log() in the loss, but that would mean that we essentially do log(exp(log_ent_coef)) in the entropy coefficient loss, which does not feel super clean.

@araffin
Owner

araffin commented Oct 7, 2024

For SAC, I think we would just need to call exp() here and here, and this solution should be backwards compatible (as the name of the parameter does not change).

I considered that, but the issue was both with naming and with a constant entropy coef (in that case, taking the exp is wrong).

We can of course also just call log() in the loss, but that would mean that we essentially do log(exp(log_ent_coef)) in the entropy coefficient loss, which does not feel super clean.

I thought the same, but I hope that jax jit may simplify this expression to the identity?
see https://jax.readthedocs.io/en/latest/faq.html#jit-changes-the-exact-numerics-of-outputs
That also means changing things in one place only.
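
A quick way to check (my own snippet, not from the FAQ) is to compare the jitted log(exp(x)) against the identity in both value and gradient:

import jax
import jax.numpy as jnp

@jax.jit
def log_exp(x):
    # if XLA simplifies this to x, value and gradient match the identity
    return jnp.log(jnp.exp(x))

x = jnp.float32(-3.5)
print(log_exp(x), jax.grad(log_exp)(x))  # expect approximately -3.5 and 1.0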

@jan1854
Collaborator

jan1854 commented Oct 7, 2024

True, we would probably need to rename ent_coef_state to log_ent_coef_state or something, which will cause problems when loading an old model. So your solution is probably the easiest.

@araffin
Owner

araffin commented Oct 16, 2024

I did many more runs, and looking at https://wandb.ai/openrlbenchmark/sbx/reports/SAC-SB3-vs-SBX-Ent-coef-vs-log-ent-coef--Vmlldzo5NjI3NTQ5 I cannot see any significant difference between SBX master (0.17.0) and this PR (0.18.0).

@jamesheald can you confirm that this PR solves your issue?
Can you provide a colab notebook or a minimal example (a minimal number of lines, but everything needed) to reproduce the original issue?

Owner

@araffin araffin left a comment


LGTM, thanks =)

@araffin araffin merged commit 1c79684 into araffin:master Nov 1, 2024
4 checks passed