From b372e9a429882300fd52c139edac3a9d1eef3a10 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 3 Oct 2022 17:36:08 +0200 Subject: [PATCH] Rename to RL-Zoo3 and better packaging (#291) * Rename and better packaging * Move plot scripts inside package --- .coveragerc | 2 +- CHANGELOG.md | 2 +- Makefile | 6 +- README.md | 20 +- docker/Dockerfile | 2 +- enjoy.py | 2 +- hyperparams/her.yml | 6 +- hyperparams/ppo.yml | 2 +- hyperparams/ppo_lstm.yml | 4 +- hyperparams/sac.yml | 16 +- hyperparams/tqc.yml | 8 +- requirements.txt | 1 - rl_zoo/cli.py | 15 - rl_zoo/version.txt | 1 - {rl_zoo => rl_zoo3}/__init__.py | 2 +- {rl_zoo => rl_zoo3}/benchmark.py | 2 +- {rl_zoo => rl_zoo3}/callbacks.py | 0 rl_zoo3/cli.py | 22 + {rl_zoo => rl_zoo3}/enjoy.py | 12 +- {rl_zoo => rl_zoo3}/exp_manager.py | 19 +- {rl_zoo => rl_zoo3}/hyperparams_opt.py | 2 +- {rl_zoo => rl_zoo3}/import_envs.py | 2 +- {rl_zoo => rl_zoo3}/load_from_hub.py | 2 +- rl_zoo3/plots/__init__.py | 3 + rl_zoo3/plots/all_plots.py | 250 ++++++++++ rl_zoo3/plots/plot_from_file.py | 428 ++++++++++++++++++ rl_zoo3/plots/plot_train.py | 98 ++++ .../plots}/score_normalization.py | 0 {rl_zoo => rl_zoo3}/push_to_hub.py | 8 +- {rl_zoo => rl_zoo3}/py.typed | 0 {rl_zoo => rl_zoo3}/record_training.py | 8 +- {rl_zoo => rl_zoo3}/record_video.py | 4 +- {rl_zoo => rl_zoo3}/train.py | 6 +- {rl_zoo => rl_zoo3}/utils.py | 2 +- rl_zoo3/version.txt | 1 + {rl_zoo => rl_zoo3}/wrappers.py | 0 scripts/__init__.py | 0 scripts/all_plots.py | 246 +--------- scripts/migrate_to_hub.py | 4 +- scripts/plot_from_file.py | 425 +---------------- scripts/plot_train.py | 95 +--- scripts/run_docker_cpu.sh | 4 +- scripts/run_docker_gpu.sh | 4 +- setup.cfg | 12 +- setup.py | 30 +- tests/test_callbacks.py | 4 +- tests/test_enjoy.py | 8 +- tests/test_train.py | 2 +- tests/test_wrappers.py | 8 +- train.py | 2 +- 50 files changed, 926 insertions(+), 876 deletions(-) delete mode 100644 rl_zoo/cli.py delete mode 100644 rl_zoo/version.txt rename {rl_zoo => rl_zoo3}/__init__.py (92%) rename {rl_zoo => rl_zoo3}/benchmark.py (98%) rename {rl_zoo => rl_zoo3}/callbacks.py (100%) create mode 100644 rl_zoo3/cli.py rename {rl_zoo => rl_zoo3}/enjoy.py (96%) rename {rl_zoo => rl_zoo3}/exp_manager.py (98%) rename {rl_zoo => rl_zoo3}/hyperparams_opt.py (99%) rename {rl_zoo => rl_zoo3}/import_envs.py (96%) rename {rl_zoo => rl_zoo3}/load_from_hub.py (99%) create mode 100644 rl_zoo3/plots/__init__.py create mode 100644 rl_zoo3/plots/all_plots.py create mode 100644 rl_zoo3/plots/plot_from_file.py create mode 100644 rl_zoo3/plots/plot_train.py rename {scripts => rl_zoo3/plots}/score_normalization.py (100%) rename {rl_zoo => rl_zoo3}/push_to_hub.py (98%) rename {rl_zoo => rl_zoo3}/py.typed (100%) rename {rl_zoo => rl_zoo3}/record_training.py (95%) rename {rl_zoo => rl_zoo3}/record_video.py (97%) rename {rl_zoo => rl_zoo3}/train.py (98%) rename {rl_zoo => rl_zoo3}/utils.py (99%) create mode 100644 rl_zoo3/version.txt rename {rl_zoo => rl_zoo3}/wrappers.py (100%) create mode 100644 scripts/__init__.py diff --git a/.coveragerc b/.coveragerc index 785267de2..b455b76fc 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,7 +2,7 @@ branch = False omit = tests/* - rl_zoo/utils/plot.py + rl_zoo3/utils/plot.py [report] exclude_lines = diff --git a/CHANGELOG.md b/CHANGELOG.md index ef76b847c..992e6783d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ - low pass filter was removed ### New Features -- RL Zoo cli: `rl_zoo train` and `rl_zoo enjoy` +- RL Zoo cli: `rl_zoo3 train` and `rl_zoo3 enjoy` ### Bug fixes diff --git a/Makefile b/Makefile index 8f20105df..7139a61e4 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -LINT_PATHS = *.py tests/ scripts/ rl_zoo/ +LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ # Run pytest and coverage report pytest: @@ -10,7 +10,7 @@ check-trained-agents: # Type check type: - pytype -j auto rl_zoo/ tests/ scripts/ -d import-error + pytype -j auto rl_zoo3/ tests/ scripts/ -d import-error lint: # stop the build if there are Python syntax errors or undefined names @@ -44,12 +44,14 @@ docker-gpu: # PyPi package release release: + # rm -r build/* dist/* python setup.py sdist python setup.py bdist_wheel twine upload dist/* # Test PyPi package release test-release: + # rm -r build/* dist/* python setup.py sdist python setup.py bdist_wheel twine upload --repository-url https://test.pypi.org/legacy/ dist/* diff --git a/README.md b/README.md index ecb870168..ef2fa4d86 100644 --- a/README.md +++ b/README.md @@ -154,13 +154,13 @@ python enjoy.py --algo algo_name --env env_id -f logs/ --exp-id 1 --load-last-ch Upload model to hub (same syntax as for `enjoy.py`): ``` -python -m rl_zoo.push_to_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 -m "Initial commit" +python -m rl_zoo3.push_to_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 -m "Initial commit" ``` you can choose custom `repo-name` (default: `{algo}-{env_id}`) by passing a `--repo-name` argument. Download model from hub: ``` -python -m rl_zoo.load_from_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 +python -m rl_zoo3.load_from_hub --algo ppo --env CartPole-v1 -f logs/ -orga sb3 ``` ## Hyperparameter yaml syntax @@ -255,7 +255,7 @@ for multiple, specify a list: ```yaml env_wrapper: - - rl_zoo.wrappers.DoneOnSuccessWrapper: + - rl_zoo3.wrappers.DoneOnSuccessWrapper: reward_offset: 1.0 - sb3_contrib.common.wrappers.TimeFeatureWrapper ``` @@ -279,7 +279,7 @@ Following the same syntax as env wrappers, you can also add custom callbacks to ```yaml callback: - - rl_zoo.callbacks.ParallelTrainCallback: + - rl_zoo3.callbacks.ParallelTrainCallback: gradient_steps: 256 ``` @@ -306,19 +306,19 @@ Note: if you want to pass a string, you need to escape it like that: `my_string: Record 1000 steps with the latest saved model: ``` -python -m rl_zoo.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 +python -m rl_zoo3.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 ``` Use the best saved model instead: ``` -python -m rl_zoo.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 --load-best +python -m rl_zoo3.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 --load-best ``` Record a video of a checkpoint saved during training (here the checkpoint name is `rl_model_10000_steps.zip`): ``` -python -m rl_zoo.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 --load-checkpoint 10000 +python -m rl_zoo3.record_video --algo ppo --env BipedalWalkerHardcore-v3 -n 1000 --load-checkpoint 10000 ``` ## Record a Video of a Training Experiment @@ -328,18 +328,18 @@ Apart from recording videos of specific saved models, it is also possible to rec Record 1000 steps for each checkpoint, latest and best saved models: ``` -python -m rl_zoo.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --deterministic +python -m rl_zoo3.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --deterministic ``` The previous command will create a `mp4` file. To convert this file to `gif` format as well: ``` -python -m rl_zoo.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --deterministic --gif +python -m rl_zoo3.record_training --algo ppo --env CartPole-v1 -n 1000 -f logs --deterministic --gif ``` ## Current Collection: 195+ Trained Agents! -Final performance of the trained agents can be found in [`benchmark.md`](./benchmark.md). To compute them, simply run `python -m rl_zoo.benchmark`. +Final performance of the trained agents can be found in [`benchmark.md`](./benchmark.md). To compute them, simply run `python -m rl_zoo3.benchmark`. List and videos of trained agents can be found on our Huggingface page: https://huggingface.co/sb3 diff --git a/docker/Dockerfile b/docker/Dockerfile index 2baa348c1..7ec68ed3c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -18,7 +18,7 @@ COPY requirements.txt /tmp/ RUN \ - mkdir -p ${CODE_DIR}/rl_zoo && \ + mkdir -p ${CODE_DIR}/rl_zoo3 && \ pip uninstall -y stable-baselines3 && \ pip install -r /tmp/requirements.txt && \ pip install pip install highway-env==1.5.0 && \ diff --git a/enjoy.py b/enjoy.py index 926b70141..784abe3cf 100644 --- a/enjoy.py +++ b/enjoy.py @@ -1,4 +1,4 @@ -from rl_zoo.enjoy import enjoy +from rl_zoo3.enjoy import enjoy if __name__ == "__main__": enjoy() diff --git a/hyperparams/her.yml b/hyperparams/her.yml index 45c0710fe..a8249f46d 100644 --- a/hyperparams/her.yml +++ b/hyperparams/her.yml @@ -59,7 +59,7 @@ FetchSlide-v1: FetchPickAndPlace-v1: env_wrapper: - sb3_contrib.common.wrappers.TimeFeatureWrapper - # - rl_zoo.wrappers.DoneOnSuccessWrapper: + # - rl_zoo3.wrappers.DoneOnSuccessWrapper: # reward_offset: 0 # n_successes: 4 # - stable_baselines3.common.monitor.Monitor @@ -96,7 +96,7 @@ FetchReach-v1: NeckGoalEnvRelativeSparse-v2: model_class: 'sac' # env_wrapper: - # - rl_zoo.wrappers.HistoryWrapper: + # - rl_zoo3.wrappers.HistoryWrapper: # horizon: 2 # - sb3_contrib.common.wrappers.TimeFeatureWrapper n_timesteps: !!float 1e6 @@ -122,7 +122,7 @@ NeckGoalEnvRelativeSparse-v2: NeckGoalEnvRelativeDense-v2: model_class: 'sac' env_wrapper: - - rl_zoo.wrappers.HistoryWrapperObsDict: + - rl_zoo3.wrappers.HistoryWrapperObsDict: horizon: 2 # - sb3_contrib.common.wrappers.TimeFeatureWrapper n_timesteps: !!float 1e6 diff --git a/hyperparams/ppo.yml b/hyperparams/ppo.yml index 26d90d5f1..10909fbe8 100644 --- a/hyperparams/ppo.yml +++ b/hyperparams/ppo.yml @@ -319,7 +319,7 @@ MiniGrid-FourRooms-v0: CarRacing-v0: env_wrapper: - - rl_zoo.wrappers.FrameSkip: + - rl_zoo3.wrappers.FrameSkip: skip: 2 - gym.wrappers.resize_observation.ResizeObservation: shape: 64 diff --git a/hyperparams/ppo_lstm.yml b/hyperparams/ppo_lstm.yml index 4043d7d4c..872755f94 100644 --- a/hyperparams/ppo_lstm.yml +++ b/hyperparams/ppo_lstm.yml @@ -132,7 +132,7 @@ BipedalWalker-v3: # TO BE TUNED BipedalWalkerHardcore-v3: # env_wrapper: - # - rl_zoo.wrappers.FrameSkip: + # - rl_zoo3.wrappers.FrameSkip: # skip: 2 normalize: true n_envs: 32 @@ -285,7 +285,7 @@ InvertedPendulumSwingupBulletEnv-v0: CarRacing-v0: env_wrapper: - # - rl_zoo.wrappers.FrameSkip: + # - rl_zoo3.wrappers.FrameSkip: # skip: 2 - gym.wrappers.resize_observation.ResizeObservation: shape: 64 diff --git a/hyperparams/sac.yml b/hyperparams/sac.yml index d4475fdce..9d41262e4 100644 --- a/hyperparams/sac.yml +++ b/hyperparams/sac.yml @@ -16,7 +16,7 @@ MountainCarContinuous-v0: Pendulum-v1: # callback: - # - rl_zoo.callbacks.ParallelTrainCallback + # - rl_zoo3.callbacks.ParallelTrainCallback n_timesteps: 20000 policy: 'MlpPolicy' learning_rate: !!float 1e-3 @@ -74,9 +74,9 @@ BipedalWalkerHardcore-v3: HalfCheetahBulletEnv-v0: &pybullet-defaults # env_wrapper: # - sb3_contrib.common.wrappers.TimeFeatureWrapper - # - rl_zoo.wrappers.DelayedRewardWrapper: + # - rl_zoo3.wrappers.DelayedRewardWrapper: # delay: 10 - # - rl_zoo.wrappers.HistoryWrapper: + # - rl_zoo3.wrappers.HistoryWrapper: # horizon: 10 n_timesteps: !!float 1e6 policy: 'MlpPolicy' @@ -163,12 +163,12 @@ MinitaurBulletDuckEnv-v0: # To be tuned CarRacing-v0: env_wrapper: - - rl_zoo.wrappers.FrameSkip: + - rl_zoo3.wrappers.FrameSkip: skip: 2 # wrapper from https://github.com/araffin/aae-train-donkeycar - ae.wrapper.AutoencoderWrapper: ae_path: "logs/car_racing_rgb_160.pkl" - - rl_zoo.wrappers.HistoryWrapper: + - rl_zoo3.wrappers.HistoryWrapper: horizon: 2 # frame_stack: 4 normalize: True @@ -238,7 +238,7 @@ donkey-generated-track-v0: env_wrapper: - gym.wrappers.time_limit.TimeLimit: max_episode_steps: 500 - - rl_zoo.wrappers.HistoryWrapper: + - rl_zoo3.wrappers.HistoryWrapper: horizon: 5 n_timesteps: !!float 1e6 policy: 'MlpPolicy' @@ -262,9 +262,9 @@ donkey-generated-track-v0: NeckEnvRelative-v2: <<: *pybullet-defaults env_wrapper: - - rl_zoo.wrappers.HistoryWrapper: + - rl_zoo3.wrappers.HistoryWrapper: horizon: 2 - # - rl_zoo.wrappers.LowPassFilterWrapper: + # - rl_zoo3.wrappers.LowPassFilterWrapper: # freq: 2.0 # df: 25.0 n_timesteps: !!float 1e6 diff --git a/hyperparams/tqc.yml b/hyperparams/tqc.yml index 720eb91bf..d26289bef 100644 --- a/hyperparams/tqc.yml +++ b/hyperparams/tqc.yml @@ -258,12 +258,12 @@ parking-v0: # Tuned CarRacing-v0: env_wrapper: - - rl_zoo.wrappers.FrameSkip: + - rl_zoo3.wrappers.FrameSkip: skip: 2 # wrapper from https://github.com/araffin/aae-train-donkeycar - ae.wrapper.AutoencoderWrapper: ae_path: "logs/car_racing_rgb_160.pkl" - - rl_zoo.wrappers.HistoryWrapper: + - rl_zoo3.wrappers.HistoryWrapper: horizon: 2 # frame_stack: 4 normalize: True @@ -280,7 +280,7 @@ RocketLander-v0: n_timesteps: !!float 3e6 policy: 'MlpPolicy' env_wrapper: - - rl_zoo.wrappers.FrameSkip: + - rl_zoo3.wrappers.FrameSkip: skip: 4 - - rl_zoo.wrappers.HistoryWrapper: + - rl_zoo3.wrappers.HistoryWrapper: horizon: 2 diff --git a/requirements.txt b/requirements.txt index 2739c4914..a74b21d97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ gym-minigrid scikit-optimize optuna pytablewriter~=0.64 -seaborn pyyaml>=5.1 cloudpickle>=1.5.0 plotly diff --git a/rl_zoo/cli.py b/rl_zoo/cli.py deleted file mode 100644 index 5902224f3..000000000 --- a/rl_zoo/cli.py +++ /dev/null @@ -1,15 +0,0 @@ -import sys - -from rl_zoo.enjoy import enjoy -from rl_zoo.train import train - - -def main(): - script_name = sys.argv[1] - # Remove script name - del sys.argv[1] - # Execute known script - { - "train": train, - "enjoy": enjoy, - }[script_name]() diff --git a/rl_zoo/version.txt b/rl_zoo/version.txt deleted file mode 100644 index fdd3be6df..000000000 --- a/rl_zoo/version.txt +++ /dev/null @@ -1 +0,0 @@ -1.6.2 diff --git a/rl_zoo/__init__.py b/rl_zoo3/__init__.py similarity index 92% rename from rl_zoo/__init__.py rename to rl_zoo3/__init__.py index 7df789c63..aa6d64d95 100644 --- a/rl_zoo/__init__.py +++ b/rl_zoo3/__init__.py @@ -1,6 +1,6 @@ import os -from rl_zoo.utils import ( +from rl_zoo3.utils import ( ALGOS, create_test_env, get_latest_run_id, diff --git a/rl_zoo/benchmark.py b/rl_zoo3/benchmark.py similarity index 98% rename from rl_zoo/benchmark.py rename to rl_zoo3/benchmark.py index 01626fcec..274bab3a9 100644 --- a/rl_zoo/benchmark.py +++ b/rl_zoo3/benchmark.py @@ -9,7 +9,7 @@ import pytablewriter from stable_baselines3.common.results_plotter import load_results, ts2xy -from rl_zoo.utils import get_hf_trained_models, get_latest_run_id, get_saved_hyperparams, get_trained_models +from rl_zoo3.utils import get_hf_trained_models, get_latest_run_id, get_saved_hyperparams, get_trained_models parser = argparse.ArgumentParser() parser.add_argument("--log-dir", help="Root log folder", default="rl-trained-agents/", type=str) diff --git a/rl_zoo/callbacks.py b/rl_zoo3/callbacks.py similarity index 100% rename from rl_zoo/callbacks.py rename to rl_zoo3/callbacks.py diff --git a/rl_zoo3/cli.py b/rl_zoo3/cli.py new file mode 100644 index 000000000..dea074756 --- /dev/null +++ b/rl_zoo3/cli.py @@ -0,0 +1,22 @@ +import sys + +from rl_zoo3.enjoy import enjoy +from rl_zoo3.plots import all_plots, plot_from_file, plot_train +from rl_zoo3.train import train + + +def main(): + script_name = sys.argv[1] + # Remove script name + del sys.argv[1] + # Execute known script + known_scripts = { + "train": train, + "enjoy": enjoy, + "plot_train": plot_train, + "plot_from_file": plot_from_file, + "all_plots": all_plots, + } + if script_name not in known_scripts.keys(): + raise ValueError(f"The script {script_name} is unknown, please use one of {known_scripts.keys()}") + known_scripts[script_name]() diff --git a/rl_zoo/enjoy.py b/rl_zoo3/enjoy.py similarity index 96% rename from rl_zoo/enjoy.py rename to rl_zoo3/enjoy.py index 9fe0b5218..881b53c7b 100644 --- a/rl_zoo/enjoy.py +++ b/rl_zoo3/enjoy.py @@ -9,12 +9,12 @@ from huggingface_sb3 import EnvironmentName from stable_baselines3.common.utils import set_random_seed -import rl_zoo.import_envs # noqa: F401 pylint: disable=unused-import -from rl_zoo import ALGOS, create_test_env, get_saved_hyperparams -from rl_zoo.callbacks import tqdm -from rl_zoo.exp_manager import ExperimentManager -from rl_zoo.load_from_hub import download_from_hub -from rl_zoo.utils import StoreDict, get_model_path +import rl_zoo3.import_envs # noqa: F401 pylint: disable=unused-import +from rl_zoo3 import ALGOS, create_test_env, get_saved_hyperparams +from rl_zoo3.callbacks import tqdm +from rl_zoo3.exp_manager import ExperimentManager +from rl_zoo3.load_from_hub import download_from_hub +from rl_zoo3.utils import StoreDict, get_model_path def enjoy(): # noqa: C901 diff --git a/rl_zoo/exp_manager.py b/rl_zoo3/exp_manager.py similarity index 98% rename from rl_zoo/exp_manager.py rename to rl_zoo3/exp_manager.py index 8a661aee4..22b459b97 100644 --- a/rl_zoo/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -14,7 +14,6 @@ import torch as th import yaml from huggingface_sb3 import EnvironmentName -from optuna.integration.skopt import SkoptSampler from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner from optuna.samplers import BaseSampler, RandomSampler, TPESampler from optuna.study import MaxTrialsCallback @@ -45,10 +44,10 @@ from torch import nn as nn # noqa: F401 # Register custom envs -import rl_zoo.import_envs # noqa: F401 pytype: disable=import-error -from rl_zoo.callbacks import SaveVecNormalizeCallback, TQDMCallback, TrialEvalCallback -from rl_zoo.hyperparams_opt import HYPERPARAMS_SAMPLER -from rl_zoo.utils import ALGOS, get_callback_list, get_latest_run_id, get_wrapper_class, linear_schedule +import rl_zoo3.import_envs # noqa: F401 pytype: disable=import-error +from rl_zoo3.callbacks import SaveVecNormalizeCallback, TQDMCallback, TrialEvalCallback +from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER +from rl_zoo3.utils import ALGOS, get_callback_list, get_latest_run_id, get_wrapper_class, linear_schedule class ExperimentManager: @@ -102,7 +101,13 @@ def __init__( self.env_name = EnvironmentName(env_id) # Custom params self.custom_hyperparams = hyperparams - default_path = Path(__file__).parent.parent + if (Path(__file__).parent / "hyperparams").is_dir(): + # Package version + default_path = Path(__file__).parent + else: + # Take the root folder + default_path = Path(__file__).parent.parent + self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml") self.env_kwargs = {} if env_kwargs is None else env_kwargs self.n_timesteps = n_timesteps @@ -631,6 +636,8 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler: elif sampler_method == "tpe": sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed, multivariate=True) elif sampler_method == "skopt": + from optuna.integration.skopt import SkoptSampler + # cf https://scikit-optimize.github.io/#skopt.Optimizer # GP: gaussian process # Gradient boosted regression: GBRT diff --git a/rl_zoo/hyperparams_opt.py b/rl_zoo3/hyperparams_opt.py similarity index 99% rename from rl_zoo/hyperparams_opt.py rename to rl_zoo3/hyperparams_opt.py index 1c7f86796..6eadec292 100644 --- a/rl_zoo/hyperparams_opt.py +++ b/rl_zoo3/hyperparams_opt.py @@ -5,7 +5,7 @@ from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise from torch import nn as nn -from rl_zoo import linear_schedule +from rl_zoo3 import linear_schedule def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: diff --git a/rl_zoo/import_envs.py b/rl_zoo3/import_envs.py similarity index 96% rename from rl_zoo/import_envs.py rename to rl_zoo3/import_envs.py index 8eb1861ae..dc3a96fe3 100644 --- a/rl_zoo/import_envs.py +++ b/rl_zoo3/import_envs.py @@ -1,7 +1,7 @@ import gym from gym.envs.registration import register -from rl_zoo.wrappers import MaskVelocityWrapper +from rl_zoo3.wrappers import MaskVelocityWrapper try: import pybullet_envs # pytype: disable=import-error diff --git a/rl_zoo/load_from_hub.py b/rl_zoo3/load_from_hub.py similarity index 99% rename from rl_zoo/load_from_hub.py rename to rl_zoo3/load_from_hub.py index cb1aceba2..78e1310f8 100644 --- a/rl_zoo/load_from_hub.py +++ b/rl_zoo3/load_from_hub.py @@ -8,7 +8,7 @@ from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub from requests.exceptions import HTTPError -from rl_zoo import ALGOS, get_latest_run_id +from rl_zoo3 import ALGOS, get_latest_run_id def download_from_hub( diff --git a/rl_zoo3/plots/__init__.py b/rl_zoo3/plots/__init__.py new file mode 100644 index 000000000..753425e1b --- /dev/null +++ b/rl_zoo3/plots/__init__.py @@ -0,0 +1,3 @@ +from rl_zoo3.plots.all_plots import all_plots +from rl_zoo3.plots.plot_from_file import plot_from_file +from rl_zoo3.plots.plot_train import plot_train diff --git a/rl_zoo3/plots/all_plots.py b/rl_zoo3/plots/all_plots.py new file mode 100644 index 000000000..d1f931f89 --- /dev/null +++ b/rl_zoo3/plots/all_plots.py @@ -0,0 +1,250 @@ +import argparse +import os +import pickle +from copy import deepcopy + +import numpy as np +import pytablewriter +import seaborn +from matplotlib import pyplot as plt +from scipy.spatial import distance_matrix + + +def all_plots(): # noqa: C901 + parser = argparse.ArgumentParser("Gather results, plot them and create table") + parser.add_argument("-a", "--algos", help="Algorithms to include", nargs="+", type=str) + parser.add_argument("-e", "--env", help="Environments to include", nargs="+", type=str) + parser.add_argument("-f", "--exp-folders", help="Folders to include", nargs="+", type=str) + parser.add_argument("-l", "--labels", help="Label for each folder", nargs="+", type=str) + parser.add_argument( + "-k", + "--key", + help="Key from the `evaluations.npz` file to use to aggregate results " + "(e.g. reward, success rate, ...), it is 'results' by default (i.e., the episode reward)", + default="results", + type=str, + ) + parser.add_argument("-max", "--max-timesteps", help="Max number of timesteps to display", type=int, default=int(2e6)) + parser.add_argument("-min", "--min-timesteps", help="Min number of timesteps to keep a trial", type=int, default=-1) + parser.add_argument( + "-o", "--output", help="Output filename (pickle file), where to save the post-processed data", type=str + ) + parser.add_argument( + "-median", "--median", action="store_true", default=False, help="Display median instead of mean in the table" + ) + parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million") + parser.add_argument("--no-display", action="store_true", default=False, help="Do not show the plots") + parser.add_argument( + "-print", "--print-n-trials", action="store_true", default=False, help="Print the number of trial for each result" + ) + args = parser.parse_args() + + # Activate seaborn + seaborn.set() + results = {} + post_processed_results = {} + + args.algos = [algo.upper() for algo in args.algos] + + if args.labels is None: + args.labels = args.exp_folders + + for env in args.env: # noqa: C901 + plt.figure(f"Results {env}") + plt.title(f"{env}", fontsize=14) + + x_label_suffix = "" if args.no_million else "(in Million)" + plt.xlabel(f"Timesteps {x_label_suffix}", fontsize=14) + plt.ylabel("Score", fontsize=14) + results[env] = {} + post_processed_results[env] = {} + + for algo in args.algos: + for folder_idx, exp_folder in enumerate(args.exp_folders): + + log_path = os.path.join(exp_folder, algo.lower()) + + if not os.path.isdir(log_path): + continue + + results[env][f"{args.labels[folder_idx]}-{algo}"] = 0.0 + + dirs = [ + os.path.join(log_path, d) + for d in os.listdir(log_path) + if (env in d and os.path.isdir(os.path.join(log_path, d))) + ] + + max_len = 0 + merged_timesteps, merged_results = [], [] + last_eval = [] + timesteps = np.empty(0) + for _, dir_ in enumerate(dirs): + try: + log = np.load(os.path.join(dir_, "evaluations.npz")) + except FileNotFoundError: + print("Eval not found for", dir_) + continue + + mean_ = np.squeeze(log["results"].mean(axis=1)) + + if mean_.shape == (): + continue + + max_len = max(max_len, len(mean_)) + if len(log["timesteps"]) >= max_len: + timesteps = log["timesteps"] + + # For post-processing + merged_timesteps.append(log["timesteps"]) + merged_results.append(log[args.key]) + + # Truncate the plots + while timesteps[max_len - 1] > args.max_timesteps: + max_len -= 1 + timesteps = timesteps[:max_len] + + if len(log[args.key]) >= max_len: + last_eval.append(log[args.key][max_len - 1]) + else: + last_eval.append(log[args.key][-1]) + + # Merge runs with different eval freq: + # ex: (100,) eval vs (10,) + # in that case, downsample (100,) to match the (10,) samples + # Discard all jobs that are < min_timesteps + if args.min_timesteps > 0: + min_ = np.inf + for n_timesteps in merged_timesteps: + if n_timesteps[-1] >= args.min_timesteps: + min_ = min(min_, len(n_timesteps)) + if len(n_timesteps) == min_: + max_len = len(n_timesteps) + # Truncate the plots + while n_timesteps[max_len - 1] > args.max_timesteps: + max_len -= 1 + timesteps = n_timesteps[:max_len] + # Avoid modifying original aggregated results + merged_results_ = deepcopy(merged_results) + # Downsample if needed + for trial_idx, n_timesteps in enumerate(merged_timesteps): + # We assume they are the same, or they will be discarded in the next step + if len(n_timesteps) == min_ or n_timesteps[-1] < args.min_timesteps: + pass + else: + new_merged_results = [] + # Nearest neighbour + distance_mat = distance_matrix(n_timesteps.reshape(-1, 1), timesteps.reshape(-1, 1)) + closest_indices = distance_mat.argmin(axis=0) + for closest_idx in closest_indices: + new_merged_results.append(merged_results_[trial_idx][closest_idx]) + merged_results[trial_idx] = new_merged_results + last_eval[trial_idx] = merged_results_[trial_idx][closest_indices[-1]] + + # Remove incomplete runs + merged_results_tmp, last_eval_tmp = [], [] + for idx in range(len(merged_results)): + if len(merged_results[idx]) >= max_len: + merged_results_tmp.append(merged_results[idx][:max_len]) + last_eval_tmp.append(last_eval[idx]) + merged_results = merged_results_tmp + last_eval = last_eval_tmp + + # Post-process + if len(merged_results) > 0: + # shape: (n_trials, n_eval * n_eval_episodes) + merged_results = np.array(merged_results) + n_trials = len(merged_results) + n_eval = len(timesteps) + + if args.print_n_trials: + print(f"{env}-{algo}-{args.labels[folder_idx]}: {n_trials}") + + # reshape to (n_trials, n_eval, n_eval_episodes) + evaluations = merged_results.reshape((n_trials, n_eval, -1)) + # re-arrange to (n_eval, n_trials, n_eval_episodes) + evaluations = np.swapaxes(evaluations, 0, 1) + # (n_eval,) + mean_ = np.mean(evaluations, axis=(1, 2)) + # (n_eval, n_trials) + mean_per_eval = np.mean(evaluations, axis=-1) + # (n_eval,) + std_ = np.std(mean_per_eval, axis=-1) + # std: error: + std_error = std_ / np.sqrt(n_trials) + # Take last evaluation + # shape: (n_trials, n_eval_episodes) to (n_trials,) + last_evals = np.array(last_eval).squeeze().mean(axis=-1) + # Standard deviation of the mean performance for the last eval + std_last_eval = np.std(last_evals) + # Compute standard error + std_error_last_eval = std_last_eval / np.sqrt(n_trials) + + if args.median: + results[env][f"{algo}-{args.labels[folder_idx]}"] = f"{np.median(last_evals):.0f}" + else: + results[env][ + f"{algo}-{args.labels[folder_idx]}" + ] = f"{np.mean(last_evals):.0f} +/- {std_error_last_eval:.0f}" + + # x axis in Millions of timesteps + divider = 1e6 + if args.no_million: + divider = 1.0 + + post_processed_results[env][f"{algo}-{args.labels[folder_idx]}"] = { + "timesteps": timesteps, + "mean": mean_, + "std_error": std_error, + "last_evals": last_evals, + "std_error_last_eval": std_error_last_eval, + "mean_per_eval": mean_per_eval, + } + + plt.plot(timesteps / divider, mean_, label=f"{algo}-{args.labels[folder_idx]}", linewidth=3) + plt.fill_between(timesteps / divider, mean_ + std_error, mean_ - std_error, alpha=0.5) + + plt.legend() + + # Markdown Table + writer = pytablewriter.MarkdownTableWriter(max_precision=3) + writer.table_name = "results_table" + + headers = ["Environments"] + + # One additional row for the subheader + value_matrix = [[] for i in range(len(args.env) + 1)] + + headers = ["Environments"] + # Header and sub-header + value_matrix[0].append("") + for algo in args.algos: + for label in args.labels: + value_matrix[0].append(label) + headers.append(algo) + + writer.headers = headers + + for i, env in enumerate(args.env, start=1): + value_matrix[i].append(env) + for algo in args.algos: + for label in args.labels: + key = f"{algo}-{label}" + value_matrix[i].append(f'{results[env].get(key, "0.0 +/- 0.0")}') + + writer.value_matrix = value_matrix + writer.write_table() + + post_processed_results["results_table"] = {"headers": headers, "value_matrix": value_matrix} + + if args.output is not None: + print(f"Saving to {args.output}.pkl") + with open(f"{args.output}.pkl", "wb") as file_handler: + pickle.dump(post_processed_results, file_handler) + + if not args.no_display: + plt.show() + + +if __name__ == "__main__": + all_plots() diff --git a/rl_zoo3/plots/plot_from_file.py b/rl_zoo3/plots/plot_from_file.py new file mode 100644 index 000000000..163b21a12 --- /dev/null +++ b/rl_zoo3/plots/plot_from_file.py @@ -0,0 +1,428 @@ +import argparse +import itertools +import pickle +import warnings + +import numpy as np +import pandas as pd +import pytablewriter +import seaborn +from matplotlib import pyplot as plt + +try: + from rliable import library as rly # pytype: disable=import-error + from rliable import metrics, plot_utils # pytype: disable=import-error +except ImportError: + rly = None + +from rl_zoo3.plots.score_normalization import normalize_score + + +# From https://github.com/mwaskom/seaborn/blob/master/seaborn/categorical.py +def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5): + """Take a drawn matplotlib boxplot and make it look nice.""" + for box in artist_dict["boxes"]: + box.update(dict(facecolor=color, zorder=0.9, edgecolor=gray, linewidth=linewidth)) + + for whisk in artist_dict["whiskers"]: + whisk.update(dict(color=gray, linewidth=linewidth, linestyle="-")) + + for cap in artist_dict["caps"]: + cap.update(dict(color=gray, linewidth=linewidth)) + + for med in artist_dict["medians"]: + med.update(dict(color=gray, linewidth=linewidth)) + + for fly in artist_dict["fliers"]: + fly.update(dict(markerfacecolor=gray, marker="d", markeredgecolor=gray, markersize=fliersize)) + + +def plot_from_file(): # noqa: C901 + parser = argparse.ArgumentParser("Gather results, plot them and create table") + parser.add_argument("-i", "--input", help="Input filename (numpy archive)", type=str) + parser.add_argument("-skip", "--skip-envs", help="Environments to skip", nargs="+", default=[], type=str) + parser.add_argument("--keep-envs", help="Envs to keep", nargs="+", default=[], type=str) + parser.add_argument("--skip-keys", help="Keys to skip", nargs="+", default=[], type=str) + parser.add_argument("--keep-keys", help="Keys to keep", nargs="+", default=[], type=str) + parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million") + parser.add_argument("--skip-timesteps", action="store_true", default=False, help="Do not display learning curves") + parser.add_argument("-o", "--output", help="Output filename (image)", type=str) + parser.add_argument("--format", help="Output format", type=str, default="svg") + parser.add_argument("-loc", "--legend-loc", help="The location of the legend.", type=str, default="best") + parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=int, default=[6.4, 4.8]) + parser.add_argument("--fontsize", help="Font size", type=int, default=14) + parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+") + parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False) + parser.add_argument("-r", "--rliable", help="Enable rliable plots", action="store_true", default=False) + parser.add_argument("-vs", "--versus", help="Enable probability of improvement plot", action="store_true", default=False) + parser.add_argument("-iqm", "--iqm", help="Enable IQM sample efficiency plot", action="store_true", default=False) + parser.add_argument("-ci", "--ci-size", help="Confidence interval size (for rliable)", type=float, default=0.95) + parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False) + parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str) + + args = parser.parse_args() + + # Activate seaborn + seaborn.set() + # Seaborn style + seaborn.set(style="whitegrid") + + # Enable LaTeX support + if args.latex: + plt.rc("text", usetex=True) + + filename = args.input + + if not filename.endswith(".pkl"): + filename += ".pkl" + + with open(filename, "rb") as file_handler: + results = pickle.load(file_handler) + + # Plot table + writer = pytablewriter.MarkdownTableWriter(max_precision=3) + writer.table_name = "results_table" + writer.headers = results["results_table"]["headers"] + writer.value_matrix = results["results_table"]["value_matrix"] + writer.write_table() + + del results["results_table"] + + for filename in args.merge: + # Merge other files + with open(filename, "rb") as file_handler: + results_2 = pickle.load(file_handler) + del results_2["results_table"] + for key in results.keys(): + if key in results_2: + for new_key in results_2[key].keys(): + results[key][new_key] = results_2[key][new_key] + + keys = [key for key in results[list(results.keys())[0]].keys() if key not in args.skip_keys] + print(f"keys: {keys}") + if len(args.keep_keys) > 0: + keys = [key for key in keys if key in args.keep_keys] + envs = [env for env in results.keys() if env not in args.skip_envs] + + if len(args.keep_envs) > 0: + envs = [env for env in envs if env in args.keep_envs] + + labels = {key: key for key in keys} + if args.labels is not None: + for key, label in zip(keys, args.labels): + labels[key] = label + + if not args.skip_timesteps: + # Plot learning curves per env + for env in envs: + + plt.figure(f"Results {env}") + title = f"{env}" # BulletEnv-v0 + if "Mountain" in env: + title = "MountainCarContinuous-v0" + + plt.title(title, fontsize=args.fontsize) + + x_label_suffix = "" if args.no_million else "(1e6)" + plt.xlabel(f"Timesteps {x_label_suffix}", fontsize=args.fontsize) + plt.ylabel("Score", fontsize=args.fontsize) + + for key in keys: + # x axis in Millions of timesteps + divider = 1e6 + if args.no_million: + divider = 1.0 + + timesteps = results[env][key]["timesteps"] + mean_ = results[env][key]["mean"] + std_error = results[env][key]["std_error"] + + plt.xticks(fontsize=13) + plt.plot(timesteps / divider, mean_, label=labels[key], linewidth=3) + plt.fill_between(timesteps / divider, mean_ + std_error, mean_ - std_error, alpha=0.5) + + plt.legend(fontsize=args.fontsize) + plt.tight_layout() + + # Convert to pandas dataframe, in order to use seaborn + labels_df, envs_df, scores = [], [], [] + # Post-process to use it with rliable + # algo: (n_runs, n_envs) + normalized_score_dict = {} + # algo: (n_runs, n_envs, n_eval) + all_eval_normalized_scores_dict = {} + # Convert env key to env id for normalization + env_key_to_env_id = { + "Half": "HalfCheetahBulletEnv-v0", + "Ant": "AntBulletEnv-v0", + "Hopper": "HopperBulletEnv-v0", + "Walker": "Walker2DBulletEnv-v0", + } + # Backward compat + skip_all_algos_dict = False + + for key in keys: + algo_scores, all_algo_scores = [], [] + for env in envs: + if isinstance(results[env][key]["last_evals"], (np.float32, np.float64)): + # No enough timesteps + print(f"Skipping {env}-{key}") + continue + + for score in results[env][key]["last_evals"]: + labels_df.append(labels[key]) + # convert to int if needed + # labels_df.append(int(labels[key])) + envs_df.append(env) + scores.append(score) + + algo_scores.append(results[env][key]["last_evals"]) + + # Backward compat: mean_per_eval key may not be present + if "mean_per_eval" in results[env][key]: + all_algo_scores.append(results[env][key]["mean_per_eval"]) + else: + skip_all_algos_dict = True + + # Normalize score, env key must match env_id + if env in env_key_to_env_id: + algo_scores[-1] = normalize_score(algo_scores[-1], env_key_to_env_id[env]) + if not skip_all_algos_dict: + all_algo_scores[-1] = normalize_score(all_algo_scores[-1], env_key_to_env_id[env]) + elif env not in env_key_to_env_id and args.rliable: + warnings.warn(f"{env} not found for normalizing scores, you should update `env_key_to_env_id`") + + # Truncate to convert to matrix + min_runs = min(len(algo_score) for algo_score in algo_scores) + if min_runs > 0: + algo_scores = [algo_score[:min_runs] for algo_score in algo_scores] + # shape: (n_envs, n_runs) -> (n_runs, n_envs) + normalized_score_dict[labels[key]] = np.array(algo_scores).T + if not skip_all_algos_dict: + all_algo_scores = [all_algo_score[:, :min_runs] for all_algo_score in all_algo_scores] + # (n_envs, n_eval, n_runs) -> (n_runs, n_envs, n_eval) + all_eval_normalized_scores_dict[labels[key]] = np.array(all_algo_scores).transpose((2, 0, 1)) + + data_frame = pd.DataFrame(data=dict(Method=labels_df, Environment=envs_df, Score=scores)) + + # Rliable plots, see https://github.com/google-research/rliable + if args.rliable: + + if rly is None: + raise ImportError( + "You must install rliable package to use this feature. Note: Python 3.7+ is required in that case." + ) + + print("Computing bootstrap CI ...") + algorithms = list(labels.values()) + # Scores as a dictionary mapping algorithms to their normalized + # score matrices, each of which is of size `(num_runs x num_envs)`. + + aggregate_func = lambda x: np.array( # noqa: E731 + [ + metrics.aggregate_median(x), + metrics.aggregate_iqm(x), + metrics.aggregate_mean(x), + metrics.aggregate_optimality_gap(x), + ] + ) + aggregate_scores, aggregate_interval_estimates = rly.get_interval_estimates( + normalized_score_dict, + aggregate_func, + # Default was 50000 + reps=2000, # Number of bootstrap replications. + confidence_interval_size=args.ci_size, # Coverage of confidence interval. Defaults to 95%. + ) + + fig, axes = plot_utils.plot_interval_estimates( + aggregate_scores, + aggregate_interval_estimates, + metric_names=["Median", "IQM", "Mean", "Optimality Gap"], + algorithms=algorithms, + xlabel="Normalized Score", + xlabel_y_coordinate=0.02, + subfigure_width=5, + row_height=1, + max_ticks=4, + interval_height=0.6, + ) + fig.canvas.manager.set_window_title("Rliable metrics") + # Adjust margin to see the x label + plt.tight_layout() + plt.subplots_adjust(bottom=0.2) + + # Performance profiles + # Normalized score thresholds + normalized_score_thresholds = np.linspace(0.0, 1.5, 50) + score_distributions, score_distributions_cis = rly.create_performance_profile( + normalized_score_dict, + normalized_score_thresholds, + reps=2000, + confidence_interval_size=args.ci_size, + ) + # Plot score distributions + fig, ax = plt.subplots(ncols=1, figsize=(7, 5)) + plot_utils.plot_performance_profiles( + score_distributions, + normalized_score_thresholds, + performance_profile_cis=score_distributions_cis, + colors=dict(zip(algorithms, seaborn.color_palette("colorblind"))), + xlabel=r"Normalized Score $(\tau)$", + ax=ax, + ) + fig.canvas.manager.set_window_title("Performance profiles") + plt.legend() + + # Probability of improvement + # Scores as a dictionary containing pairs of normalized score + # matrices for pairs of algorithms we want to compare + algorithm_pairs_keys = itertools.combinations(algorithms, 2) + # algorithm_pairs = {.. , 'x,y': (score_x, score_y), ..} + algorithm_pairs = {} + for algo1, algo2 in algorithm_pairs_keys: + algorithm_pairs[f"{algo1}, {algo2}"] = (normalized_score_dict[algo1], normalized_score_dict[algo2]) + + if args.versus: + average_probabilities, average_prob_cis = rly.get_interval_estimates( + algorithm_pairs, + metrics.probability_of_improvement, + reps=1000, # Default was 50000 + confidence_interval_size=args.ci_size, + ) + plot_utils.plot_probability_of_improvement( + average_probabilities, + average_prob_cis, + figsize=(10, 8), + interval_height=0.6, + ) + plt.gcf().canvas.manager.set_window_title("Probability of Improvement") + plt.tight_layout() + + if args.iqm: + # Load scores as a dictionary mapping algorithms to their normalized + # score matrices across all evaluations, each of which is of size + # `(n_runs, n_envs, n_eval)` where scores are recorded every n steps. + # Only compute CI for 1/4 of the evaluations and keep the first and last eval + downsample_factor = 4 + n_evals = all_eval_normalized_scores_dict[algorithms[0]].shape[-1] + eval_indices = np.arange(n_evals - 1)[::downsample_factor] + eval_indices = np.concatenate((eval_indices, [n_evals - 1])) + eval_indices_scores_dict = { + algorithm: score[:, :, eval_indices] for algorithm, score in all_eval_normalized_scores_dict.items() + } + iqm = lambda scores: np.array( # noqa: E731 + [metrics.aggregate_iqm(scores[..., eval_idx]) for eval_idx in range(scores.shape[-1])] + ) + iqm_scores, iqm_cis = rly.get_interval_estimates( + eval_indices_scores_dict, + iqm, + reps=2000, + confidence_interval_size=args.ci_size, + ) + plot_utils.plot_sample_efficiency_curve( + eval_indices + 1, + iqm_scores, + iqm_cis, + algorithms=algorithms, + # TODO: convert to timesteps using the timesteps + xlabel=r"Number of Evaluations", + ylabel="IQM Normalized Score", + ) + plt.gcf().canvas.manager.set_window_title("IQM Normalized Score - Sample Efficiency Curve") + plt.legend() + plt.tight_layout() + + plt.show() + + # Plot final results with env as x axis + plt.figure("Sensitivity plot", figsize=args.figsize) + plt.title("Sensitivity plot", fontsize=args.fontsize) + # plt.title('Influence of the time feature', fontsize=args.fontsize) + # plt.title('Influence of the network architecture', fontsize=args.fontsize) + # plt.title('Influence of the exploration variance $log \sigma$', fontsize=args.fontsize) + # plt.title("Influence of the sampling frequency", fontsize=args.fontsize) + # plt.title('Parallel vs No Parallel Sampling', fontsize=args.fontsize) + # plt.title('Influence of the exploration function input', fontsize=args.fontsize) + plt.title("PyBullet envs", fontsize=args.fontsize) + plt.xticks(fontsize=13) + plt.xlabel("Environment", fontsize=args.fontsize) + plt.ylabel("Score", fontsize=args.fontsize) + + ax = seaborn.barplot(x="Environment", y="Score", hue="Method", data=data_frame) + # Custom legend title + handles, labels_legend = ax.get_legend_handles_labels() + # ax.legend(handles=handles, labels=labels_legend, title=r"$log \sigma$", loc=args.legend_loc) + # ax.legend(handles=handles, labels=labels_legend, title="Network Architecture", loc=args.legend_loc) + # ax.legend(handles=handles, labels=labels_legend, title="Interval", loc=args.legend_loc) + # Old error plot + # for key in keys: + # values = [np.mean(results[env][key]["last_evals"]) for env in envs] + # # Overwrite the labels + # # labels = {key:i for i, key in enumerate(keys, start=-6)} + # plt.errorbar( + # envs, + # values, + # yerr=results[env][key]["std_error"][-1], + # linewidth=3, + # fmt="-o", + # label=labels[key], + # capsize=5, + # capthick=2, + # elinewidth=2, + # ) + # plt.legend(fontsize=13, loc=args.legend_loc) + plt.tight_layout() + if args.output is not None: + plt.savefig(args.output, format=args.format) + + # Plot final results with env as labels and method as x axis + # plt.figure('Sensitivity plot inverted', figsize=args.figsize) + # plt.title('Sensitivity plot', fontsize=args.fontsize) + # plt.xticks(fontsize=13) + # # plt.xlabel('Method', fontsize=args.fontsize) + # plt.ylabel('Score', fontsize=args.fontsize) + # + # for env in envs: + # values = [np.mean(results[env][key]['last_evals']) for key in keys] + # # Overwrite the labels + # # labels = {key:i for i, key in enumerate(keys, start=-6)} + # plt.errorbar(labels.values(), values, yerr=results[env][key]['std_error'][-1], + # linewidth=3, fmt='-o', label=env, capsize=5, capthick=2, elinewidth=2) + # + # plt.legend(fontsize=13, loc=args.legend_loc) + # plt.tight_layout() + + if args.boxplot: + # Box plot + plt.figure("Sensitivity box plot", figsize=args.figsize) + plt.title("Sensitivity box plot", fontsize=args.fontsize) + # plt.title('Influence of the exploration variance $log \sigma$ on Hopper', fontsize=args.fontsize) + # plt.title('Influence of the sampling frequency on Walker2D', fontsize=args.fontsize) + # plt.title('Influence of the exploration function input on Hopper', fontsize=args.fontsize) + plt.xticks(fontsize=13) + # plt.xlabel('Exploration variance $log \sigma$', fontsize=args.fontsize) + # plt.xlabel("Sampling frequency", fontsize=args.fontsize) + # plt.xlabel('Method', fontsize=args.fontsize) + plt.ylabel("Score", fontsize=args.fontsize) + + data, labels_ = [], [] + for env in envs: + for key in keys: + data.append(results[env][key]["last_evals"]) + text = f"{env}-{labels[key]}" if len(envs) > 1 else labels[key] + labels_.append(text) + artist_dict = plt.boxplot(data, patch_artist=True) + # Make the boxplot looks nice + # see https://github.com/mwaskom/seaborn/blob/master/seaborn/categorical.py + color_palette = seaborn.color_palette() + # orange + boxplot_color = color_palette[1] + restyle_boxplot(artist_dict, color=boxplot_color) + plt.xticks(np.arange(1, len(data) + 1), labels_, rotation=0) + plt.tight_layout() + + plt.show() + + +if __name__ == "__main__": + plot_from_file() diff --git a/rl_zoo3/plots/plot_train.py b/rl_zoo3/plots/plot_train.py new file mode 100644 index 000000000..b46331e20 --- /dev/null +++ b/rl_zoo3/plots/plot_train.py @@ -0,0 +1,98 @@ +""" +Plot training reward/success rate +""" +import argparse +import os + +import numpy as np +import seaborn +from matplotlib import pyplot as plt +from stable_baselines3.common.monitor import LoadMonitorResultsError, load_results +from stable_baselines3.common.results_plotter import X_EPISODES, X_TIMESTEPS, X_WALLTIME, ts2xy, window_func + +# Activate seaborn +seaborn.set() + + +def plot_train(): + parser = argparse.ArgumentParser("Gather results, plot training reward/success") + parser.add_argument("-a", "--algo", help="Algorithm to include", type=str, required=True) + parser.add_argument("-e", "--env", help="Environment(s) to include", nargs="+", type=str, required=True) + parser.add_argument("-f", "--exp-folder", help="Folders to include", type=str, required=True) + parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=int, default=[6.4, 4.8]) + parser.add_argument("--fontsize", help="Font size", type=int, default=14) + parser.add_argument("-max", "--max-timesteps", help="Max number of timesteps to display", type=int) + parser.add_argument("-x", "--x-axis", help="X-axis", choices=["steps", "episodes", "time"], type=str, default="steps") + parser.add_argument("-y", "--y-axis", help="Y-axis", choices=["success", "reward", "length"], type=str, default="reward") + parser.add_argument("-w", "--episode-window", help="Rolling window size", type=int, default=100) + + args = parser.parse_args() + + algo = args.algo + envs = args.env + log_path = os.path.join(args.exp_folder, algo) + + x_axis = { + "steps": X_TIMESTEPS, + "episodes": X_EPISODES, + "time": X_WALLTIME, + }[args.x_axis] + x_label = { + "steps": "Timesteps", + "episodes": "Episodes", + "time": "Walltime (in hours)", + }[args.x_axis] + + y_axis = { + "success": "is_success", + "reward": "r", + "length": "l", + }[args.y_axis] + y_label = { + "success": "Training Success Rate", + "reward": "Training Episodic Reward", + "length": "Training Episode Length", + }[args.y_axis] + + dirs = [] + + for env in envs: + dirs.extend( + [ + os.path.join(log_path, folder) + for folder in os.listdir(log_path) + if (env in folder and os.path.isdir(os.path.join(log_path, folder))) + ] + ) + + plt.figure(y_label, figsize=args.figsize) + plt.title(y_label, fontsize=args.fontsize) + plt.xlabel(f"{x_label}", fontsize=args.fontsize) + plt.ylabel(y_label, fontsize=args.fontsize) + for folder in dirs: + try: + data_frame = load_results(folder) + except LoadMonitorResultsError: + continue + if args.max_timesteps is not None: + data_frame = data_frame[data_frame.l.cumsum() <= args.max_timesteps] + try: + y = np.array(data_frame[y_axis]) + except KeyError: + print(f"No data available for {folder}") + continue + x, _ = ts2xy(data_frame, x_axis) + + # Do not plot the smoothed curve at all if the timeseries is shorter than window size. + if x.shape[0] >= args.episode_window: + # Compute and plot rolling mean with window of size args.episode_window + x, y_mean = window_func(x, y, args.episode_window, np.mean) + plt.plot(x, y_mean, linewidth=2, label=folder.split("/")[-1]) + + plt.legend() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + plot_train() diff --git a/scripts/score_normalization.py b/rl_zoo3/plots/score_normalization.py similarity index 100% rename from scripts/score_normalization.py rename to rl_zoo3/plots/score_normalization.py diff --git a/rl_zoo/push_to_hub.py b/rl_zoo3/push_to_hub.py similarity index 98% rename from rl_zoo/push_to_hub.py rename to rl_zoo3/push_to_hub.py index 9c2ed556a..50323cd32 100644 --- a/rl_zoo/push_to_hub.py +++ b/rl_zoo3/push_to_hub.py @@ -19,10 +19,10 @@ from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize from wasabi import Printer -import rl_zoo.import_envs # noqa: F401 pylint: disable=unused-import -from rl_zoo import ALGOS, create_test_env, get_saved_hyperparams -from rl_zoo.exp_manager import ExperimentManager -from rl_zoo.utils import StoreDict, get_model_path +import rl_zoo3.import_envs # noqa: F401 pylint: disable=unused-import +from rl_zoo3 import ALGOS, create_test_env, get_saved_hyperparams +from rl_zoo3.exp_manager import ExperimentManager +from rl_zoo3.utils import StoreDict, get_model_path msg = Printer() diff --git a/rl_zoo/py.typed b/rl_zoo3/py.typed similarity index 100% rename from rl_zoo/py.typed rename to rl_zoo3/py.typed diff --git a/rl_zoo/record_training.py b/rl_zoo3/record_training.py similarity index 95% rename from rl_zoo/record_training.py rename to rl_zoo3/record_training.py index 043efb832..bc84772fc 100644 --- a/rl_zoo/record_training.py +++ b/rl_zoo3/record_training.py @@ -7,7 +7,7 @@ from huggingface_sb3 import EnvironmentName -from rl_zoo.utils import ALGOS, get_latest_run_id +from rl_zoo3.utils import ALGOS, get_latest_run_id if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser() @@ -79,19 +79,19 @@ args_final_model.append("--deterministic") if os.path.exists(os.path.join(log_path, f"{env_name}.zip")): - return_code = subprocess.call(["python", "-m", "rl_zoo.record_video"] + args_final_model) + return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video"] + args_final_model) assert return_code == 0, "Failed to record the final model" if os.path.exists(os.path.join(log_path, "best_model.zip")): args_best_model = args_final_model + ["--load-best"] - return_code = subprocess.call(["python", "-m", "rl_zoo.record_video"] + args_best_model) + return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video"] + args_best_model) assert return_code == 0, "Failed to record the best model" args_checkpoint = args_final_model + ["--load-checkpoint"] args_checkpoint.append("0") for checkpoint in checkpoints: args_checkpoint[-1] = str(checkpoint) - return_code = subprocess.call(["python", "-m", "rl_zoo.record_video"] + args_checkpoint) + return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video"] + args_checkpoint) assert return_code == 0, f"Failed to record the {checkpoint} checkpoint model" # add text to each video diff --git a/rl_zoo/record_video.py b/rl_zoo3/record_video.py similarity index 97% rename from rl_zoo/record_video.py rename to rl_zoo3/record_video.py index ea15b1cfb..b116068ce 100644 --- a/rl_zoo/record_video.py +++ b/rl_zoo3/record_video.py @@ -8,8 +8,8 @@ from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.vec_env import VecVideoRecorder -from rl_zoo.exp_manager import ExperimentManager -from rl_zoo.utils import ALGOS, StoreDict, create_test_env, get_model_path, get_saved_hyperparams +from rl_zoo3.exp_manager import ExperimentManager +from rl_zoo3.utils import ALGOS, StoreDict, create_test_env, get_model_path, get_saved_hyperparams if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser() diff --git a/rl_zoo/train.py b/rl_zoo3/train.py similarity index 98% rename from rl_zoo/train.py rename to rl_zoo3/train.py index 79ce938d4..19552cbd0 100644 --- a/rl_zoo/train.py +++ b/rl_zoo3/train.py @@ -11,9 +11,9 @@ from stable_baselines3.common.utils import set_random_seed # Register custom envs -import rl_zoo.import_envs # noqa: F401 pytype: disable=import-error -from rl_zoo.exp_manager import ExperimentManager -from rl_zoo.utils import ALGOS, StoreDict +import rl_zoo3.import_envs # noqa: F401 pytype: disable=import-error +from rl_zoo3.exp_manager import ExperimentManager +from rl_zoo3.utils import ALGOS, StoreDict def train(): diff --git a/rl_zoo/utils.py b/rl_zoo3/utils.py similarity index 99% rename from rl_zoo/utils.py rename to rl_zoo3/utils.py index 3da88eabf..26ed69158 100644 --- a/rl_zoo/utils.py +++ b/rl_zoo3/utils.py @@ -199,7 +199,7 @@ def create_test_env( :return: """ # Avoid circular import - from rl_zoo.exp_manager import ExperimentManager + from rl_zoo3.exp_manager import ExperimentManager # Create the environment and wrap it if necessary env_wrapper = get_wrapper_class(hyperparams) diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt new file mode 100644 index 000000000..607310c88 --- /dev/null +++ b/rl_zoo3/version.txt @@ -0,0 +1 @@ +1.6.2.post1 diff --git a/rl_zoo/wrappers.py b/rl_zoo3/wrappers.py similarity index 100% rename from rl_zoo/wrappers.py rename to rl_zoo3/wrappers.py diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/all_plots.py b/scripts/all_plots.py index 335b78829..32f249da6 100644 --- a/scripts/all_plots.py +++ b/scripts/all_plots.py @@ -1,244 +1,4 @@ -import argparse -import os -import pickle -from copy import deepcopy +from rl_zoo3.plots.all_plots import all_plots -import numpy as np -import pytablewriter -import seaborn -from matplotlib import pyplot as plt -from scipy.spatial import distance_matrix - -parser = argparse.ArgumentParser("Gather results, plot them and create table") -parser.add_argument("-a", "--algos", help="Algorithms to include", nargs="+", type=str) -parser.add_argument("-e", "--env", help="Environments to include", nargs="+", type=str) -parser.add_argument("-f", "--exp-folders", help="Folders to include", nargs="+", type=str) -parser.add_argument("-l", "--labels", help="Label for each folder", nargs="+", type=str) -parser.add_argument( - "-k", - "--key", - help="Key from the `evaluations.npz` file to use to aggregate results " - "(e.g. reward, success rate, ...), it is 'results' by default (i.e., the episode reward)", - default="results", - type=str, -) -parser.add_argument("-max", "--max-timesteps", help="Max number of timesteps to display", type=int, default=int(2e6)) -parser.add_argument("-min", "--min-timesteps", help="Min number of timesteps to keep a trial", type=int, default=-1) -parser.add_argument("-o", "--output", help="Output filename (pickle file), where to save the post-processed data", type=str) -parser.add_argument( - "-median", "--median", action="store_true", default=False, help="Display median instead of mean in the table" -) -parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million") -parser.add_argument("--no-display", action="store_true", default=False, help="Do not show the plots") -parser.add_argument( - "-print", "--print-n-trials", action="store_true", default=False, help="Print the number of trial for each result" -) -args = parser.parse_args() - -# Activate seaborn -seaborn.set() -results = {} -post_processed_results = {} - -args.algos = [algo.upper() for algo in args.algos] - -if args.labels is None: - args.labels = args.exp_folders - -for env in args.env: # noqa: C901 - plt.figure(f"Results {env}") - plt.title(f"{env}", fontsize=14) - - x_label_suffix = "" if args.no_million else "(in Million)" - plt.xlabel(f"Timesteps {x_label_suffix}", fontsize=14) - plt.ylabel("Score", fontsize=14) - results[env] = {} - post_processed_results[env] = {} - - for algo in args.algos: - for folder_idx, exp_folder in enumerate(args.exp_folders): - - log_path = os.path.join(exp_folder, algo.lower()) - - if not os.path.isdir(log_path): - continue - - results[env][f"{args.labels[folder_idx]}-{algo}"] = 0.0 - - dirs = [ - os.path.join(log_path, d) - for d in os.listdir(log_path) - if (env in d and os.path.isdir(os.path.join(log_path, d))) - ] - - max_len = 0 - merged_timesteps, merged_results = [], [] - last_eval = [] - timesteps = np.empty(0) - for _, dir_ in enumerate(dirs): - try: - log = np.load(os.path.join(dir_, "evaluations.npz")) - except FileNotFoundError: - print("Eval not found for", dir_) - continue - - mean_ = np.squeeze(log["results"].mean(axis=1)) - - if mean_.shape == (): - continue - - max_len = max(max_len, len(mean_)) - if len(log["timesteps"]) >= max_len: - timesteps = log["timesteps"] - - # For post-processing - merged_timesteps.append(log["timesteps"]) - merged_results.append(log[args.key]) - - # Truncate the plots - while timesteps[max_len - 1] > args.max_timesteps: - max_len -= 1 - timesteps = timesteps[:max_len] - - if len(log[args.key]) >= max_len: - last_eval.append(log[args.key][max_len - 1]) - else: - last_eval.append(log[args.key][-1]) - - # Merge runs with different eval freq: - # ex: (100,) eval vs (10,) - # in that case, downsample (100,) to match the (10,) samples - # Discard all jobs that are < min_timesteps - min_trials = [] - if args.min_timesteps > 0: - min_ = np.inf - for n_timesteps in merged_timesteps: - if n_timesteps[-1] >= args.min_timesteps: - min_ = min(min_, len(n_timesteps)) - if len(n_timesteps) == min_: - max_len = len(n_timesteps) - # Truncate the plots - while n_timesteps[max_len - 1] > args.max_timesteps: - max_len -= 1 - timesteps = n_timesteps[:max_len] - # Avoid modifying original aggregated results - merged_results_ = deepcopy(merged_results) - # Downsample if needed - for trial_idx, n_timesteps in enumerate(merged_timesteps): - # We assume they are the same, or they will be discarded in the next step - if len(n_timesteps) == min_ or n_timesteps[-1] < args.min_timesteps: - pass - else: - new_merged_results = [] - # Nearest neighbour - distance_mat = distance_matrix(n_timesteps.reshape(-1, 1), timesteps.reshape(-1, 1)) - closest_indices = distance_mat.argmin(axis=0) - for closest_idx in closest_indices: - new_merged_results.append(merged_results_[trial_idx][closest_idx]) - merged_results[trial_idx] = new_merged_results - last_eval[trial_idx] = merged_results_[trial_idx][closest_indices[-1]] - - # Remove incomplete runs - merged_results_tmp, last_eval_tmp = [], [] - for idx in range(len(merged_results)): - if len(merged_results[idx]) >= max_len: - merged_results_tmp.append(merged_results[idx][:max_len]) - last_eval_tmp.append(last_eval[idx]) - merged_results = merged_results_tmp - last_eval = last_eval_tmp - - # Post-process - if len(merged_results) > 0: - # shape: (n_trials, n_eval * n_eval_episodes) - merged_results = np.array(merged_results) - n_trials = len(merged_results) - n_eval = len(timesteps) - - if args.print_n_trials: - print(f"{env}-{algo}-{args.labels[folder_idx]}: {n_trials}") - - # reshape to (n_trials, n_eval, n_eval_episodes) - evaluations = merged_results.reshape((n_trials, n_eval, -1)) - # re-arrange to (n_eval, n_trials, n_eval_episodes) - evaluations = np.swapaxes(evaluations, 0, 1) - # (n_eval,) - mean_ = np.mean(evaluations, axis=(1, 2)) - # (n_eval, n_trials) - mean_per_eval = np.mean(evaluations, axis=-1) - # (n_eval,) - std_ = np.std(mean_per_eval, axis=-1) - # std: error: - std_error = std_ / np.sqrt(n_trials) - # Take last evaluation - # shape: (n_trials, n_eval_episodes) to (n_trials,) - last_evals = np.array(last_eval).squeeze().mean(axis=-1) - # Standard deviation of the mean performance for the last eval - std_last_eval = np.std(last_evals) - # Compute standard error - std_error_last_eval = std_last_eval / np.sqrt(n_trials) - - if args.median: - results[env][f"{algo}-{args.labels[folder_idx]}"] = f"{np.median(last_evals):.0f}" - else: - results[env][ - f"{algo}-{args.labels[folder_idx]}" - ] = f"{np.mean(last_evals):.0f} +/- {std_error_last_eval:.0f}" - - # x axis in Millions of timesteps - divider = 1e6 - if args.no_million: - divider = 1.0 - - post_processed_results[env][f"{algo}-{args.labels[folder_idx]}"] = { - "timesteps": timesteps, - "mean": mean_, - "std_error": std_error, - "last_evals": last_evals, - "std_error_last_eval": std_error_last_eval, - "mean_per_eval": mean_per_eval, - } - - plt.plot(timesteps / divider, mean_, label=f"{algo}-{args.labels[folder_idx]}", linewidth=3) - plt.fill_between(timesteps / divider, mean_ + std_error, mean_ - std_error, alpha=0.5) - - plt.legend() - - -# Markdown Table -writer = pytablewriter.MarkdownTableWriter(max_precision=3) -writer.table_name = "results_table" - -headers = ["Environments"] - -# One additional row for the subheader -value_matrix = [[] for i in range(len(args.env) + 1)] - -headers = ["Environments"] -# Header and sub-header -value_matrix[0].append("") -for algo in args.algos: - for label in args.labels: - value_matrix[0].append(label) - headers.append(algo) - -writer.headers = headers - -for i, env in enumerate(args.env, start=1): - value_matrix[i].append(env) - for algo in args.algos: - for label in args.labels: - key = f"{algo}-{label}" - value_matrix[i].append(f'{results[env].get(key, "0.0 +/- 0.0")}') - -writer.value_matrix = value_matrix -writer.write_table() - -post_processed_results["results_table"] = {"headers": headers, "value_matrix": value_matrix} - -if args.output is not None: - print(f"Saving to {args.output}.pkl") - with open(f"{args.output}.pkl", "wb") as file_handler: - pickle.dump(post_processed_results, file_handler) - -if not args.no_display: - plt.show() +if __name__ == "__main__": + all_plots() diff --git a/scripts/migrate_to_hub.py b/scripts/migrate_to_hub.py index 921952c1b..b905ee2e1 100644 --- a/scripts/migrate_to_hub.py +++ b/scripts/migrate_to_hub.py @@ -1,6 +1,6 @@ import subprocess -from rl_zoo.utils import get_hf_trained_models, get_trained_models +from rl_zoo3.utils import get_hf_trained_models, get_trained_models folder = "rl-trained-agents" orga = "sb3" @@ -16,4 +16,4 @@ if algo == "her": continue - return_code = subprocess.call(["python", "-m", "utils.push_to_hub"] + args) + return_code = subprocess.call(["python", "-m", "rl_zoo3.push_to_hub"] + args) diff --git a/scripts/plot_from_file.py b/scripts/plot_from_file.py index e0f1c95e3..8ff358415 100644 --- a/scripts/plot_from_file.py +++ b/scripts/plot_from_file.py @@ -1,423 +1,4 @@ -import argparse -import itertools -import pickle -import warnings +from rl_zoo3.plots.plot_from_file import plot_from_file -import numpy as np -import pandas as pd -import pytablewriter -import seaborn -from matplotlib import pyplot as plt - -try: - from rliable import library as rly # pytype: disable=import-error - from rliable import metrics, plot_utils # pytype: disable=import-error -except ImportError: - rly = None - -from score_normalization import normalize_score - - -# From https://github.com/mwaskom/seaborn/blob/master/seaborn/categorical.py -def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5): - """Take a drawn matplotlib boxplot and make it look nice.""" - for box in artist_dict["boxes"]: - box.update(dict(facecolor=color, zorder=0.9, edgecolor=gray, linewidth=linewidth)) - - for whisk in artist_dict["whiskers"]: - whisk.update(dict(color=gray, linewidth=linewidth, linestyle="-")) - - for cap in artist_dict["caps"]: - cap.update(dict(color=gray, linewidth=linewidth)) - - for med in artist_dict["medians"]: - med.update(dict(color=gray, linewidth=linewidth)) - - for fly in artist_dict["fliers"]: - fly.update(dict(markerfacecolor=gray, marker="d", markeredgecolor=gray, markersize=fliersize)) - - -parser = argparse.ArgumentParser("Gather results, plot them and create table") -parser.add_argument("-i", "--input", help="Input filename (numpy archive)", type=str) -parser.add_argument("-skip", "--skip-envs", help="Environments to skip", nargs="+", default=[], type=str) -parser.add_argument("--keep-envs", help="Envs to keep", nargs="+", default=[], type=str) -parser.add_argument("--skip-keys", help="Keys to skip", nargs="+", default=[], type=str) -parser.add_argument("--keep-keys", help="Keys to keep", nargs="+", default=[], type=str) -parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million") -parser.add_argument("--skip-timesteps", action="store_true", default=False, help="Do not display learning curves") -parser.add_argument("-o", "--output", help="Output filename (image)", type=str) -parser.add_argument("--format", help="Output format", type=str, default="svg") -parser.add_argument("-loc", "--legend-loc", help="The location of the legend.", type=str, default="best") -parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=int, default=[6.4, 4.8]) -parser.add_argument("--fontsize", help="Font size", type=int, default=14) -parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+") -parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False) -parser.add_argument("-r", "--rliable", help="Enable rliable plots", action="store_true", default=False) -parser.add_argument("-vs", "--versus", help="Enable probability of improvement plot", action="store_true", default=False) -parser.add_argument("-iqm", "--iqm", help="Enable IQM sample efficiency plot", action="store_true", default=False) -parser.add_argument("-ci", "--ci-size", help="Confidence interval size (for rliable)", type=float, default=0.95) -parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False) -parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str) - -args = parser.parse_args() - -# Activate seaborn -seaborn.set() -# Seaborn style -seaborn.set(style="whitegrid") - -# Enable LaTeX support -if args.latex: - plt.rc("text", usetex=True) - -filename = args.input - -if not filename.endswith(".pkl"): - filename += ".pkl" - -with open(filename, "rb") as file_handler: - results = pickle.load(file_handler) - -# Plot table -writer = pytablewriter.MarkdownTableWriter(max_precision=3) -writer.table_name = "results_table" -writer.headers = results["results_table"]["headers"] -writer.value_matrix = results["results_table"]["value_matrix"] -writer.write_table() - -del results["results_table"] - -for filename in args.merge: - # Merge other files - with open(filename, "rb") as file_handler: - results_2 = pickle.load(file_handler) - del results_2["results_table"] - for key in results.keys(): - if key in results_2: - for new_key in results_2[key].keys(): - results[key][new_key] = results_2[key][new_key] - - -keys = [key for key in results[list(results.keys())[0]].keys() if key not in args.skip_keys] -print(f"keys: {keys}") -if len(args.keep_keys) > 0: - keys = [key for key in keys if key in args.keep_keys] -envs = [env for env in results.keys() if env not in args.skip_envs] - -if len(args.keep_envs) > 0: - envs = [env for env in envs if env in args.keep_envs] - -labels = {key: key for key in keys} -if args.labels is not None: - for key, label in zip(keys, args.labels): - labels[key] = label - -if not args.skip_timesteps: - # Plot learning curves per env - for env in envs: - - plt.figure(f"Results {env}") - title = f"{env}" # BulletEnv-v0 - if "Mountain" in env: - title = "MountainCarContinuous-v0" - - plt.title(title, fontsize=args.fontsize) - - x_label_suffix = "" if args.no_million else "(1e6)" - plt.xlabel(f"Timesteps {x_label_suffix}", fontsize=args.fontsize) - plt.ylabel("Score", fontsize=args.fontsize) - - for key in keys: - # x axis in Millions of timesteps - divider = 1e6 - if args.no_million: - divider = 1.0 - - timesteps = results[env][key]["timesteps"] - mean_ = results[env][key]["mean"] - std_error = results[env][key]["std_error"] - - plt.xticks(fontsize=13) - plt.plot(timesteps / divider, mean_, label=labels[key], linewidth=3) - plt.fill_between(timesteps / divider, mean_ + std_error, mean_ - std_error, alpha=0.5) - - plt.legend(fontsize=args.fontsize) - plt.tight_layout() - -# Convert to pandas dataframe, in order to use seaborn -labels_df, envs_df, scores = [], [], [] -# Post-process to use it with rliable -# algo: (n_runs, n_envs) -normalized_score_dict = {} -# algo: (n_runs, n_envs, n_eval) -all_eval_normalized_scores_dict = {} -# Convert env key to env id for normalization -env_key_to_env_id = { - "Half": "HalfCheetahBulletEnv-v0", - "Ant": "AntBulletEnv-v0", - "Hopper": "HopperBulletEnv-v0", - "Walker": "Walker2DBulletEnv-v0", -} -# Backward compat -skip_all_algos_dict = False - -for key in keys: - algo_scores, all_algo_scores = [], [] - for env in envs: - if isinstance(results[env][key]["last_evals"], (np.float32, np.float64)): - # No enough timesteps - print(f"Skipping {env}-{key}") - continue - - for score in results[env][key]["last_evals"]: - labels_df.append(labels[key]) - # convert to int if needed - # labels_df.append(int(labels[key])) - envs_df.append(env) - scores.append(score) - - algo_scores.append(results[env][key]["last_evals"]) - - # Backward compat: mean_per_eval key may not be present - if "mean_per_eval" in results[env][key]: - all_algo_scores.append(results[env][key]["mean_per_eval"]) - else: - skip_all_algos_dict = True - - # Normalize score, env key must match env_id - if env in env_key_to_env_id: - algo_scores[-1] = normalize_score(algo_scores[-1], env_key_to_env_id[env]) - if not skip_all_algos_dict: - all_algo_scores[-1] = normalize_score(all_algo_scores[-1], env_key_to_env_id[env]) - elif env not in env_key_to_env_id and args.rliable: - warnings.warn(f"{env} not found for normalizing scores, you should update `env_key_to_env_id`") - - # Truncate to convert to matrix - min_runs = min(len(algo_score) for algo_score in algo_scores) - if min_runs > 0: - algo_scores = [algo_score[:min_runs] for algo_score in algo_scores] - # shape: (n_envs, n_runs) -> (n_runs, n_envs) - normalized_score_dict[labels[key]] = np.array(algo_scores).T - if not skip_all_algos_dict: - all_algo_scores = [all_algo_score[:, :min_runs] for all_algo_score in all_algo_scores] - # (n_envs, n_eval, n_runs) -> (n_runs, n_envs, n_eval) - all_eval_normalized_scores_dict[labels[key]] = np.array(all_algo_scores).transpose((2, 0, 1)) - -data_frame = pd.DataFrame(data=dict(Method=labels_df, Environment=envs_df, Score=scores)) - -# Rliable plots, see https://github.com/google-research/rliable -if args.rliable: - - if rly is None: - raise ImportError("You must install rliable package to use this feature. Note: Python 3.7+ is required in that case.") - - print("Computing bootstrap CI ...") - algorithms = list(labels.values()) - # Scores as a dictionary mapping algorithms to their normalized - # score matrices, each of which is of size `(num_runs x num_envs)`. - - aggregate_func = lambda x: np.array( # noqa: E731 - [ - metrics.aggregate_median(x), - metrics.aggregate_iqm(x), - metrics.aggregate_mean(x), - metrics.aggregate_optimality_gap(x), - ] - ) - aggregate_scores, aggregate_interval_estimates = rly.get_interval_estimates( - normalized_score_dict, - aggregate_func, - # Default was 50000 - reps=2000, # Number of bootstrap replications. - confidence_interval_size=args.ci_size, # Coverage of confidence interval. Defaults to 95%. - ) - - fig, axes = plot_utils.plot_interval_estimates( - aggregate_scores, - aggregate_interval_estimates, - metric_names=["Median", "IQM", "Mean", "Optimality Gap"], - algorithms=algorithms, - xlabel="Normalized Score", - xlabel_y_coordinate=0.02, - subfigure_width=5, - row_height=1, - max_ticks=4, - interval_height=0.6, - ) - fig.canvas.manager.set_window_title("Rliable metrics") - # Adjust margin to see the x label - plt.tight_layout() - plt.subplots_adjust(bottom=0.2) - - # Performance profiles - # Normalized score thresholds - normalized_score_thresholds = np.linspace(0.0, 1.5, 50) - score_distributions, score_distributions_cis = rly.create_performance_profile( - normalized_score_dict, - normalized_score_thresholds, - reps=2000, - confidence_interval_size=args.ci_size, - ) - # Plot score distributions - fig, ax = plt.subplots(ncols=1, figsize=(7, 5)) - plot_utils.plot_performance_profiles( - score_distributions, - normalized_score_thresholds, - performance_profile_cis=score_distributions_cis, - colors=dict(zip(algorithms, seaborn.color_palette("colorblind"))), - xlabel=r"Normalized Score $(\tau)$", - ax=ax, - ) - fig.canvas.manager.set_window_title("Performance profiles") - plt.legend() - - # Probability of improvement - # Scores as a dictionary containing pairs of normalized score - # matrices for pairs of algorithms we want to compare - algorithm_pairs_keys = itertools.combinations(algorithms, 2) - # algorithm_pairs = {.. , 'x,y': (score_x, score_y), ..} - algorithm_pairs = {} - for algo1, algo2 in algorithm_pairs_keys: - algorithm_pairs[f"{algo1}, {algo2}"] = (normalized_score_dict[algo1], normalized_score_dict[algo2]) - - if args.versus: - average_probabilities, average_prob_cis = rly.get_interval_estimates( - algorithm_pairs, - metrics.probability_of_improvement, - reps=1000, # Default was 50000 - confidence_interval_size=args.ci_size, - ) - plot_utils.plot_probability_of_improvement( - average_probabilities, - average_prob_cis, - figsize=(10, 8), - interval_height=0.6, - ) - plt.gcf().canvas.manager.set_window_title("Probability of Improvement") - plt.tight_layout() - - if args.iqm: - # Load scores as a dictionary mapping algorithms to their normalized - # score matrices across all evaluations, each of which is of size - # `(n_runs, n_envs, n_eval)` where scores are recorded every n steps. - # Only compute CI for 1/4 of the evaluations and keep the first and last eval - downsample_factor = 4 - n_evals = all_eval_normalized_scores_dict[algorithms[0]].shape[-1] - eval_indices = np.arange(n_evals - 1)[::downsample_factor] - eval_indices = np.concatenate((eval_indices, [n_evals - 1])) - eval_indices_scores_dict = { - algorithm: score[:, :, eval_indices] for algorithm, score in all_eval_normalized_scores_dict.items() - } - iqm = lambda scores: np.array( # noqa: E731 - [metrics.aggregate_iqm(scores[..., eval_idx]) for eval_idx in range(scores.shape[-1])] - ) - iqm_scores, iqm_cis = rly.get_interval_estimates( - eval_indices_scores_dict, - iqm, - reps=2000, - confidence_interval_size=args.ci_size, - ) - plot_utils.plot_sample_efficiency_curve( - eval_indices + 1, - iqm_scores, - iqm_cis, - algorithms=algorithms, - # TODO: convert to timesteps using the timesteps - xlabel=r"Number of Evaluations", - ylabel="IQM Normalized Score", - ) - plt.gcf().canvas.manager.set_window_title("IQM Normalized Score - Sample Efficiency Curve") - plt.legend() - plt.tight_layout() - - plt.show() - -# Plot final results with env as x axis -plt.figure("Sensitivity plot", figsize=args.figsize) -plt.title("Sensitivity plot", fontsize=args.fontsize) -# plt.title('Influence of the time feature', fontsize=args.fontsize) -# plt.title('Influence of the network architecture', fontsize=args.fontsize) -# plt.title('Influence of the exploration variance $log \sigma$', fontsize=args.fontsize) -# plt.title("Influence of the sampling frequency", fontsize=args.fontsize) -# plt.title('Parallel vs No Parallel Sampling', fontsize=args.fontsize) -# plt.title('Influence of the exploration function input', fontsize=args.fontsize) -plt.title("PyBullet envs", fontsize=args.fontsize) -plt.xticks(fontsize=13) -plt.xlabel("Environment", fontsize=args.fontsize) -plt.ylabel("Score", fontsize=args.fontsize) - - -ax = seaborn.barplot(x="Environment", y="Score", hue="Method", data=data_frame) -# Custom legend title -handles, labels_legend = ax.get_legend_handles_labels() -# ax.legend(handles=handles, labels=labels_legend, title=r"$log \sigma$", loc=args.legend_loc) -# ax.legend(handles=handles, labels=labels_legend, title="Network Architecture", loc=args.legend_loc) -# ax.legend(handles=handles, labels=labels_legend, title="Interval", loc=args.legend_loc) -# Old error plot -# for key in keys: -# values = [np.mean(results[env][key]["last_evals"]) for env in envs] -# # Overwrite the labels -# # labels = {key:i for i, key in enumerate(keys, start=-6)} -# plt.errorbar( -# envs, -# values, -# yerr=results[env][key]["std_error"][-1], -# linewidth=3, -# fmt="-o", -# label=labels[key], -# capsize=5, -# capthick=2, -# elinewidth=2, -# ) -# plt.legend(fontsize=13, loc=args.legend_loc) -plt.tight_layout() -if args.output is not None: - plt.savefig(args.output, format=args.format) - -# Plot final results with env as labels and method as x axis -# plt.figure('Sensitivity plot inverted', figsize=args.figsize) -# plt.title('Sensitivity plot', fontsize=args.fontsize) -# plt.xticks(fontsize=13) -# # plt.xlabel('Method', fontsize=args.fontsize) -# plt.ylabel('Score', fontsize=args.fontsize) -# -# for env in envs: -# values = [np.mean(results[env][key]['last_evals']) for key in keys] -# # Overwrite the labels -# # labels = {key:i for i, key in enumerate(keys, start=-6)} -# plt.errorbar(labels.values(), values, yerr=results[env][key]['std_error'][-1], -# linewidth=3, fmt='-o', label=env, capsize=5, capthick=2, elinewidth=2) -# -# plt.legend(fontsize=13, loc=args.legend_loc) -# plt.tight_layout() - -if args.boxplot: - # Box plot - plt.figure("Sensitivity box plot", figsize=args.figsize) - plt.title("Sensitivity box plot", fontsize=args.fontsize) - # plt.title('Influence of the exploration variance $log \sigma$ on Hopper', fontsize=args.fontsize) - # plt.title('Influence of the sampling frequency on Walker2D', fontsize=args.fontsize) - # plt.title('Influence of the exploration function input on Hopper', fontsize=args.fontsize) - plt.xticks(fontsize=13) - # plt.xlabel('Exploration variance $log \sigma$', fontsize=args.fontsize) - # plt.xlabel("Sampling frequency", fontsize=args.fontsize) - # plt.xlabel('Method', fontsize=args.fontsize) - plt.ylabel("Score", fontsize=args.fontsize) - - data, labels_ = [], [] - for env in envs: - for key in keys: - data.append(results[env][key]["last_evals"]) - text = f"{env}-{labels[key]}" if len(envs) > 1 else labels[key] - labels_.append(text) - artist_dict = plt.boxplot(data, patch_artist=True) - # Make the boxplot looks nice - # see https://github.com/mwaskom/seaborn/blob/master/seaborn/categorical.py - color_palette = seaborn.color_palette() - # orange - boxplot_color = color_palette[1] - restyle_boxplot(artist_dict, color=boxplot_color) - plt.xticks(np.arange(1, len(data) + 1), labels_, rotation=0) - plt.tight_layout() - -plt.show() +if __name__ == "__main__": + plot_from_file() diff --git a/scripts/plot_train.py b/scripts/plot_train.py index a3a04d68e..a3c98d1fd 100644 --- a/scripts/plot_train.py +++ b/scripts/plot_train.py @@ -1,93 +1,4 @@ -""" -Plot training reward/success rate -""" -import argparse -import os +from rl_zoo3.plots.plot_train import plot_train -import numpy as np -import seaborn -from matplotlib import pyplot as plt -from stable_baselines3.common.monitor import LoadMonitorResultsError, load_results -from stable_baselines3.common.results_plotter import X_EPISODES, X_TIMESTEPS, X_WALLTIME, ts2xy, window_func - -# Activate seaborn -seaborn.set() - -parser = argparse.ArgumentParser("Gather results, plot training reward/success") -parser.add_argument("-a", "--algo", help="Algorithm to include", type=str, required=True) -parser.add_argument("-e", "--env", help="Environment(s) to include", nargs="+", type=str, required=True) -parser.add_argument("-f", "--exp-folder", help="Folders to include", type=str, required=True) -parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=int, default=[6.4, 4.8]) -parser.add_argument("--fontsize", help="Font size", type=int, default=14) -parser.add_argument("-max", "--max-timesteps", help="Max number of timesteps to display", type=int) -parser.add_argument("-x", "--x-axis", help="X-axis", choices=["steps", "episodes", "time"], type=str, default="steps") -parser.add_argument("-y", "--y-axis", help="Y-axis", choices=["success", "reward", "length"], type=str, default="reward") -parser.add_argument("-w", "--episode-window", help="Rolling window size", type=int, default=100) - -args = parser.parse_args() - - -algo = args.algo -envs = args.env -log_path = os.path.join(args.exp_folder, algo) - -x_axis = { - "steps": X_TIMESTEPS, - "episodes": X_EPISODES, - "time": X_WALLTIME, -}[args.x_axis] -x_label = { - "steps": "Timesteps", - "episodes": "Episodes", - "time": "Walltime (in hours)", -}[args.x_axis] - -y_axis = { - "success": "is_success", - "reward": "r", - "length": "l", -}[args.y_axis] -y_label = { - "success": "Training Success Rate", - "reward": "Training Episodic Reward", - "length": "Training Episode Length", -}[args.y_axis] - -dirs = [] - -for env in envs: - dirs.extend( - [ - os.path.join(log_path, folder) - for folder in os.listdir(log_path) - if (env in folder and os.path.isdir(os.path.join(log_path, folder))) - ] - ) - -plt.figure(y_label, figsize=args.figsize) -plt.title(y_label, fontsize=args.fontsize) -plt.xlabel(f"{x_label}", fontsize=args.fontsize) -plt.ylabel(y_label, fontsize=args.fontsize) -for folder in dirs: - try: - data_frame = load_results(folder) - except LoadMonitorResultsError: - continue - if args.max_timesteps is not None: - data_frame = data_frame[data_frame.l.cumsum() <= args.max_timesteps] - try: - y = np.array(data_frame[y_axis]) - except KeyError: - print(f"No data available for {folder}") - continue - x, _ = ts2xy(data_frame, x_axis) - - # Do not plot the smoothed curve at all if the timeseries is shorter than window size. - if x.shape[0] >= args.episode_window: - # Compute and plot rolling mean with window of size args.episode_window - x, y_mean = window_func(x, y, args.episode_window, np.mean) - plt.plot(x, y_mean, linewidth=2, label=folder.split("/")[-1]) - -plt.legend() -plt.tight_layout() -plt.show() +if __name__ == "__main__": + plot_train() diff --git a/scripts/run_docker_cpu.sh b/scripts/run_docker_cpu.sh index f7be51c6c..854fe15d1 100755 --- a/scripts/run_docker_cpu.sh +++ b/scripts/run_docker_cpu.sh @@ -7,5 +7,5 @@ echo "Executing in the docker (cpu image):" echo $cmd_line docker run -it --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/rl_zoo,type=bind stablebaselines/rl-baselines3-zoo-cpu:latest\ - bash -c "cd /root/code/rl_zoo/ && $cmd_line" + --mount src=$(pwd),target=/root/code/rl_zoo3,type=bind stablebaselines/rl-baselines3-zoo-cpu:latest\ + bash -c "cd /root/code/rl_zoo3/ && $cmd_line" diff --git a/scripts/run_docker_gpu.sh b/scripts/run_docker_gpu.sh index 4fd530d5a..9bcb31f42 100755 --- a/scripts/run_docker_gpu.sh +++ b/scripts/run_docker_gpu.sh @@ -7,5 +7,5 @@ echo "Executing in the docker (gpu image):" echo $cmd_line docker run -it --runtime=nvidia --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/rl_zoo,type=bind stablebaselines/rl-baselines3-zoo:latest\ - bash -c "cd /root/code/rl_zoo/ && $cmd_line" + --mount src=$(pwd),target=/root/code/rl_zoo3,type=bind stablebaselines/rl-baselines3-zoo:latest\ + bash -c "cd /root/code/rl_zoo3/ && $cmd_line" diff --git a/setup.cfg b/setup.cfg index b2796ea5e..c8e3f8490 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,11 +15,11 @@ inputs = . ignore = W503,W504,E203,E231 # line breaks before and after binary operators # Ignore import not used when aliases are defined per-file-ignores = - ./rl_zoo/__init__.py:F401 - ./rl_zoo/import_envs.py:F401 - ./scripts/all_plots.py:E501 - ./scripts/plot_train.py:E501 - ./scripts/plot_training_success.py:E501 + ./rl_zoo3/__init__.py:F401 + ./rl_zoo3/plots/__init__.py:F401 + ./rl_zoo3/import_envs.py:F401 + ./rl_zoo3/plots/all_plots.py:E501 + ./rl_zoo3/plots/plot_train.py:E501 exclude = # No need to traverse our git directory @@ -33,4 +33,4 @@ max-line-length = 127 [isort] profile = black line_length = 127 -src_paths = stable_baselines3,rl_zoo +src_paths = stable_baselines3,rl_zoo3 diff --git a/setup.py b/setup.py index 3787af35a..a60e5eeb5 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,13 @@ import os +import shutil -from setuptools import find_packages, setup +from setuptools import setup -with open(os.path.join("rl_zoo", "version.txt")) as file_handler: +with open(os.path.join("rl_zoo3", "version.txt")) as file_handler: __version__ = file_handler.read().strip() +# Copy hyperparams files for packaging +shutil.copytree("hyperparams", os.path.join("rl_zoo3", "hyperparams")) long_description = """ # RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents @@ -13,22 +16,16 @@ """ setup( - name="rl_zoo", - packages=[package for package in find_packages() if package.startswith("rl_zoo")], + name="rl_zoo3", + packages=["rl_zoo3", "rl_zoo3.plots"], package_data={ - "rl_zoo": [ + "rl_zoo3": [ "py.typed", "version.txt", - "../scripts/*.py", - "../hyperparams/*.yml", + "hyperparams/*.yml", ] }, - scripts=[ - "./scripts/all_plots.py", - "./scripts/plot_train.py", - "./scripts/plot_from_file.py", - ], - entry_points={"console_scripts": ["rl_zoo_train=rl_zoo.train:train", "rl_zoo=rl_zoo.cli:main"]}, + entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]}, install_requires=[ "sb3-contrib>=1.6.1", "huggingface_sb3>=2.2.1, <3.*", @@ -39,6 +36,9 @@ "pytablewriter~=0.64", # TODO: add test dependencies ], + extras_require={ + "plots": ["seaborn", "rliable>=1.0.5", "scipy~=1.7.3"], + }, description="A Training Framework for Stable Baselines3 Reinforcement Learning Agents", author="Antonin Raffin", url="https://github.com/DLR-RM/rl-baselines3-zoo", @@ -59,6 +59,10 @@ ], ) +# Remove copied files after packaging +shutil.rmtree(os.path.join("rl_zoo3", "hyperparams")) + + # python setup.py sdist # python setup.py bdist_wheel # twine upload --repository-url https://test.pypi.org/legacy/ dist/* diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 08f8f316b..74feebf36 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -14,7 +14,7 @@ def test_raw_stat_callback(tmp_path): "--env", "CartPole-v1", "-params", - "callback:'rl_zoo.callbacks.RawStatisticsCallback'", + "callback:'rl_zoo3.callbacks.RawStatisticsCallback'", "--tensorboard-log", f"{tmp_path}", ] @@ -32,7 +32,7 @@ def test_tqdm_callback(tmp_path): "--env", "CartPole-v1", "-params", - "callback:'rl_zoo.callbacks.TQDMCallback'", + "callback:'rl_zoo3.callbacks.TQDMCallback'", "--tensorboard-log", f"{tmp_path}", ] diff --git a/tests/test_enjoy.py b/tests/test_enjoy.py index ab0aec2c5..0e66be78f 100644 --- a/tests/test_enjoy.py +++ b/tests/test_enjoy.py @@ -3,7 +3,7 @@ import pytest -from rl_zoo.utils import get_hf_trained_models, get_trained_models +from rl_zoo3.utils import get_hf_trained_models, get_trained_models def _assert_eq(left, right): @@ -46,7 +46,7 @@ def test_trained_agents(trained_model): def test_benchmark(tmp_path): args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode", "--no-hub"] - return_code = subprocess.call(["python", "-m", "rl_zoo.benchmark"] + args) + return_code = subprocess.call(["python", "-m", "rl_zoo3.benchmark"] + args) _assert_eq(return_code, 0) @@ -94,7 +94,7 @@ def test_record_video(tmp_path): # Skip if no X-Server pytest.importorskip("pyglet.gl") - return_code = subprocess.call(["python", "-m", "rl_zoo.record_video"] + args) + return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video"] + args) _assert_eq(return_code, 0) video_path = str(tmp_path / "final-model-sac-Pendulum-v1-step-0-to-step-100.mp4") # File is not empty @@ -135,7 +135,7 @@ def test_record_training(tmp_path): return_code = subprocess.call(["python", "train.py"] + args_training) _assert_eq(return_code, 0) - return_code = subprocess.call(["python", "-m", "rl_zoo.record_training"] + args_recording) + return_code = subprocess.call(["python", "-m", "rl_zoo3.record_training"] + args_recording) _assert_eq(return_code, 0) mp4_path = str(videos_tmp_path / "training.mp4") gif_path = str(videos_tmp_path / "training.gif") diff --git a/tests/test_train.py b/tests/test_train.py index ed32d144c..ec245497c 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -96,7 +96,7 @@ def test_parallel_train(tmp_path): "--log-folder", tmp_path, "-params", - "callback:'rl_zoo.callbacks.ParallelTrainCallback'", + "callback:'rl_zoo3.callbacks.ParallelTrainCallback'", ] return_code = subprocess.call(["python", "train.py"] + args) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 929c75592..c494a0076 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -5,8 +5,8 @@ from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import DummyVecEnv -from rl_zoo.utils import get_wrapper_class -from rl_zoo.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper +from rl_zoo3.utils import get_wrapper_class +from rl_zoo3.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper def test_wrappers(): @@ -22,8 +22,8 @@ def test_wrappers(): "env_wrapper", [ None, - {"rl_zoo.wrappers.HistoryWrapper": dict(horizon=2)}, - [{"rl_zoo.wrappers.HistoryWrapper": dict(horizon=3)}, "rl_zoo.wrappers.TimeFeatureWrapper"], + {"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=2)}, + [{"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"], ], ) def test_get_wrapper(env_wrapper): diff --git a/train.py b/train.py index 7316e13e1..9deb66e3a 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,4 @@ -from rl_zoo.train import train +from rl_zoo3.train import train if __name__ == "__main__": # noqa: C901 train()