Skip to content

Commit

Permalink
Optimize the log of the entropy coeff instead of the entropy coeff (#56)
Browse files Browse the repository at this point in the history
* optimize the log of the entropy coeff instead of the entropy coeff

* Update log ent coef for SAC and derivates

* Reformat yaml

* Use uv for faster downloads

* Remove TODO

* Remove redundant call

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
jamesheald and araffin authored Nov 1, 2024
1 parent 19c85a1 commit 1c79684
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 40 deletions.
67 changes: 35 additions & 32 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
Expand All @@ -23,34 +23,37 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
pip install .[tests]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
# - name: Build the doc
# run: |
# make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# skip mypy, jax doesn't have its latest version for python 3.8
if: "!(matrix.python-version == '3.8')"
- name: Test with pytest
run: |
make pytest
uv pip install --system .[tests]
# Use headless version
uv pip install --system opencv-python-headless
- name: Lint with ruff
run: |
make lint
# - name: Build the doc
# run: |
# make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# skip mypy, jax doesn't have its latest version for python 3.8
if: "!(matrix.python-version == '3.8')"
- name: Test with pytest
run: |
make pytest
4 changes: 3 additions & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,10 @@ def actor_loss(
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
Expand Down
1 change: 0 additions & 1 deletion sbx/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
),
)

# TODO: jit qf.apply_fn too?
self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign]

return key
Expand Down
6 changes: 3 additions & 3 deletions sbx/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ def _setup_model(self) -> None:
ent_coef_init = float(self.ent_coef_init.split("_")[1])
assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0"

# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
self.ent_coef = EntropyCoef(ent_coef_init)
else:
# This will throw an error if a malformed string (different from 'auto') is passed
Expand Down Expand Up @@ -325,8 +323,10 @@ def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
Expand Down
5 changes: 3 additions & 2 deletions sbx/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,10 @@ def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) ->
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
# ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean()
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
ent_coef_loss = jnp.log(ent_coef_value) * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
Expand Down
2 changes: 1 addition & 1 deletion sbx/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.17.0
0.18.0

0 comments on commit 1c79684

Please sign in to comment.