Skip to content

Commit

Permalink
Fix warning when loading a RecurrentPPO model (#255)
Browse files Browse the repository at this point in the history
* Reformat configs

* Fix warning when loading RecurrentPPO agent
  • Loading branch information
araffin authored Aug 13, 2024
1 parent 5c81398 commit 42595a5
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 50 deletions.
76 changes: 38 additions & 38 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
Expand All @@ -22,42 +22,42 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install master version
# and dependencies for docs and tests
pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
pip install .
# Use headless version
pip install opencv-python-headless
# Install master version
# and dependencies for docs and tests
pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
pip install .
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Check codestyle
run: |
make check-codestyle
- name: Build the doc
run: |
make doc
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
- name: Lint with ruff
run: |
make lint
- name: Check codestyle
run: |
make check-codestyle
- name: Build the doc
run: |
make doc
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.4.0a4 (WIP)
Release 2.4.0a8 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -18,6 +18,7 @@ Bug Fixes:
^^^^^^^^^^
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Updated QR-DQN paper link in docs (@corentinlger)
- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)

Deprecations:
^^^^^^^^^^^^^
Expand Down
16 changes: 8 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ignore = ["B028", "RUF013"]

[tool.ruff.lint.per-file-ignores]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py"= ["RUF012", "RUF013"]
"./tests/*.py" = ["RUF012", "RUF013"]

[tool.ruff.lint.mccabe]
# Unlike Flake8, ruff default to a complexity level of 10.
Expand All @@ -35,22 +35,22 @@ exclude = """(?x)(

[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = [
"PYTHONHASHSEED=0"
]
env = ["PYTHONHASHSEED=0"]

filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]

[tool.coverage.run]
disable_warnings = ["couldnt-parse"]
branch = false
omit = ["tests/*", "setup.py"]

[tool.coverage.report]
exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
exclude_lines = [
"pragma: no cover",
"raise NotImplementedError()",
"if typing.TYPE_CHECKING:",
]
2 changes: 1 addition & 1 deletion sb3_contrib/common/maskable/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def predict(
with th.no_grad():
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
# Convert to numpy
actions = actions.cpu().numpy()
actions = actions.cpu().numpy() # type: ignore[assignment]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
Expand Down
5 changes: 4 additions & 1 deletion sb3_contrib/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union

import numpy as np
import torch as th
Expand Down Expand Up @@ -455,3 +455,6 @@ def learn(
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["_last_lstm_states"] # noqa: RUF005
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a4
2.4.0a8

0 comments on commit 42595a5

Please sign in to comment.