Support environments with slash ('/') in their name (#257)
* Fix bug in `get_latest_run_id`, which would ignore runs when the env name contains a slash.

* Remove any slashes from the HF repo name.

* Fix formatting in utils.py

* Update CHANGELOG.md

* Construct correct repo id when loading from huggingface hub.

* Use environment name instead of environment id to ensure slashes are replaced by dashes (see the sketch below).

* Fix get_trained_models util by loading the env_id from metadata instead of parsing it from the model path.

* Fix get_hf_trained_models util by loading the env_id and algo from the model card instead of parsing it from the repo id.

* Remove unused lines from migrate_to_hub.py

* Fix formatting in utils.py

* Change help text of --env parameter back to `environment ID`

* Add comments to explain the naming scheme.

* Make `get_trained_models()` use the `args.yml` file instead of the monitor log file to determine the used environment.

* Introduce EnvironmentName in record_training.py and record_video.py

* Restrict huggingface_sb3 version to avoid breaking changes.

* Fix formatting in utils.py

* Add missing seaborn requirement.

* Pass gym_id instead of env_name to is_atari and create_test_env.

* Use EnvironmentName in the ExperimentManager to properly construct folder names.

* Fix formatting in exp_manager.py

* Disable slow check and fix recurrent ppo alias

Co-authored-by: Antonin RAFFIN <[email protected]>
ernestum and araffin committed Jul 30, 2022
1 parent f1064a7 commit c2f00ea
Showing 11 changed files with 143 additions and 110 deletions.
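
For reference, here is a minimal sketch of the huggingface_sb3 naming helpers this commit switches to, as they are used in the diffs below. The `ALE/Pong-v5` id and the printed values are illustrative only (expected behaviour with huggingface_sb3 >= 2.2.1, the version pinned in requirements.txt), not part of the commit.

```python
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId

# EnvironmentName keeps the original gym id but renders as a folder- and
# repo-safe string with slashes replaced by dashes.
env_name = EnvironmentName("ALE/Pong-v5")
print(str(env_name))    # ALE-Pong-v5  -> used for folder and repo names
print(env_name.gym_id)  # ALE/Pong-v5  -> passed to gym.make() / create_test_env

# Model name is {algo}-{env_name}, repo id is {organization}/{repo_name}.
model_name = ModelName("ppo", env_name)
print(str(model_name))      # ppo-ALE-Pong-v5
print(model_name.filename)  # ppo-ALE-Pong-v5.zip
print(str(ModelRepoId("sb3", model_name)))  # sb3/ppo-ALE-Pong-v5
```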
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@
- Fix `Reacher-v3` name in PPO hyperparameter file
- Pinned ale-py==0.7.4 until new SB3 version is released
- Fix enjoy / record videos with LSTM policy
- Fix bug with environments that have a slash in their name (@ernestum)
- Changed `optimize_memory_usage` to `False` for DQN/QR-DQN on Atari games,
if you want to save RAM, you need to deactivate `handle_timeout_termination`
in the `replay_buffer_kwargs`
19 changes: 10 additions & 9 deletions enjoy.py
@@ -6,6 +6,7 @@
import numpy as np
import torch as th
import yaml
from huggingface_sb3 import EnvironmentName
from stable_baselines3.common.utils import set_random_seed

import utils.import_envs # noqa: F401 pylint: disable=unused-import
@@ -17,7 +18,7 @@

def main(): # noqa: C901
parser = argparse.ArgumentParser()
parser.add_argument("--env", help="environment ID", type=str, default="CartPole-v1")
parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1")
parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents")
parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys()))
parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
@@ -67,7 +68,7 @@ def main(): # noqa: C901
for env_module in args.gym_packages:
importlib.import_module(env_module)

env_id = args.env
env_name: EnvironmentName = args.env
algo = args.algo
folder = args.folder

@@ -76,7 +77,7 @@ def main(): # noqa: C901
args.exp_id,
folder,
algo,
env_id,
env_name,
args.load_best,
args.load_checkpoint,
args.load_last_checkpoint,
@@ -91,7 +92,7 @@ def main(): # noqa: C901
# Auto-download
download_from_hub(
algo=algo,
env_id=env_id,
env_name=env_name,
exp_id=args.exp_id,
folder=folder,
organization="sb3",
@@ -103,7 +104,7 @@ def main(): # noqa: C901
args.exp_id,
folder,
algo,
env_id,
env_name,
args.load_best,
args.load_checkpoint,
args.load_last_checkpoint,
@@ -124,14 +125,14 @@ def main(): # noqa: C901
print(f"Setting torch.num_threads to {args.num_threads}")
th.set_num_threads(args.num_threads)

is_atari = ExperimentManager.is_atari(env_id)
is_atari = ExperimentManager.is_atari(env_name.gym_id)

stats_path = os.path.join(log_path, env_id)
stats_path = os.path.join(log_path, env_name)
hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)

# load env_kwargs if existing
env_kwargs = {}
args_path = os.path.join(log_path, env_id, "args.yml")
args_path = os.path.join(log_path, env_name, "args.yml")
if os.path.isfile(args_path):
with open(args_path) as f:
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
@@ -144,7 +145,7 @@ def main(): # noqa: C901
log_dir = args.reward_log if args.reward_log != "" else None

env = create_test_env(
env_id,
env_name.gym_id,
n_envs=args.n_envs,
stats_path=stats_path,
seed=args.seed,
4 changes: 2 additions & 2 deletions requirements.txt
@@ -15,5 +15,5 @@ panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2
rliable>=1.0.5
wandb
ale-py==0.7.4 # tmp fix: until new SB3 version is released
# TODO: replace with release
git+https://github.com/huggingface/huggingface_sb3
huggingface_sb3>=2.2.1, <3.*
seaborn
4 changes: 0 additions & 4 deletions scripts/migrate_to_hub.py
@@ -16,8 +16,4 @@
if algo == "her":
continue

# if model doesn't exist already
repo_name = f"{algo}-{env_id}"
repo_id = f"{orga}/{repo_name}"

return_code = subprocess.call(["python", "-m", "utils.push_to_hub"] + args)
31 changes: 18 additions & 13 deletions utils/exp_manager.py
@@ -12,6 +12,7 @@
import optuna
import torch as th
import yaml
from huggingface_sb3 import EnvironmentName
from optuna.integration.skopt import SkoptSampler
from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner
from optuna.samplers import BaseSampler, RandomSampler, TPESampler
@@ -95,7 +96,7 @@ def __init__(
):
super().__init__()
self.algo = algo
self.env_id = env_id
self.env_name = EnvironmentName(env_id)
# Custom params
self.custom_hyperparams = hyperparams
self.env_kwargs = {} if env_kwargs is None else env_kwargs
@@ -144,22 +145,22 @@ def __init__(
self.pruner = pruner
self.n_startup_trials = n_startup_trials
self.n_evaluations = n_evaluations
self.deterministic_eval = not self.is_atari(self.env_id)
self.deterministic_eval = not self.is_atari(env_id)
self.device = device

# Logging
self.log_folder = log_folder
self.tensorboard_log = None if tensorboard_log == "" else os.path.join(tensorboard_log, env_id)
self.tensorboard_log = None if tensorboard_log == "" else os.path.join(tensorboard_log, self.env_name)
self.verbose = verbose
self.args = args
self.log_interval = log_interval
self.save_replay_buffer = save_replay_buffer

self.log_path = f"{log_folder}/{self.algo}/"
self.save_path = os.path.join(
self.log_path, f"{self.env_id}_{get_latest_run_id(self.log_path, self.env_id) + 1}{uuid_str}"
self.log_path, f"{self.env_name}_{get_latest_run_id(self.log_path, self.env_name) + 1}{uuid_str}"
)
self.params_path = f"{self.save_path}/{self.env_id}"
self.params_path = f"{self.save_path}/{self.env_name}"

def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
"""
@@ -235,7 +236,7 @@ def save_trained_model(self, model: BaseAlgorithm) -> None:
:param model:
"""
print(f"Saving to {self.save_path}")
model.save(f"{self.save_path}/{self.env_id}")
model.save(f"{self.save_path}/{self.env_name}")

if hasattr(model, "save_replay_buffer") and self.save_replay_buffer:
print("Saving replay buffer")
@@ -267,12 +268,12 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Load hyperparameters from yaml file
with open(f"hyperparams/{self.algo}.yml") as f:
hyperparams_dict = yaml.safe_load(f)
if self.env_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_id]
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_id}")
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")

if self.custom_hyperparams is not None:
# Overwrite hyperparams if needed
@@ -486,7 +487,7 @@ def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
:return:
"""
# Pretrained model, load normalization
path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_id)
path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_name)
path_ = os.path.join(path_, "vecnormalize.pkl")

if os.path.exists(path_):
@@ -530,13 +531,17 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)

monitor_kwargs = {}
# Special case for GoalEnvs: log success rate too
if "Neck" in self.env_id or self.is_robotics_env(self.env_id) or "parking-v0" in self.env_id:
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
):
monitor_kwargs = dict(info_keywords=("is_success",))

# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
env = make_vec_env(
env_id=self.env_id,
env_id=self.env_name.gym_id,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
@@ -797,7 +802,7 @@ def hyperparameters_optimization(self) -> None:
print(f" {key}: {value}")

report_name = (
f"report_{self.env_id}_{self.n_trials}-trials-{self.n_timesteps}"
f"report_{self.env_name}_{self.n_trials}-trials-{self.n_timesteps}"
f"-{self.sampler}-{self.pruner}_{int(time.time())}"
)

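To make the folder-naming change in ExperimentManager concrete, here is a small sketch with hypothetical values (not part of the commit); it assumes the slash-to-dash behaviour of EnvironmentName described in the commit message.

```python
import os

from huggingface_sb3 import EnvironmentName

# Hypothetical values for illustration only.
log_folder, algo, env_id, run_id = "logs", "ppo", "ALE/Pong-v5", 1

env_name = EnvironmentName(env_id)  # "ALE-Pong-v5": slash replaced by a dash

log_path = os.path.join(log_folder, algo)
save_path = os.path.join(log_path, f"{env_name}_{run_id}")
print(save_path)  # logs/ppo/ALE-Pong-v5_1 -- one folder per run, as intended

# With the raw gym id, the run folder would have been logs/ppo/ALE/Pong-v5_1,
# an extra directory level that get_latest_run_id did not match.
```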
29 changes: 15 additions & 14 deletions utils/load_from_hub.py
@@ -5,15 +5,15 @@
from pathlib import Path
from typing import Optional

from huggingface_sb3 import load_from_hub
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub
from requests.exceptions import HTTPError

from utils import ALGOS, get_latest_run_id


def download_from_hub(
algo: str,
env_id: str,
env_name: EnvironmentName,
exp_id: int,
folder: str,
organization: str,
@@ -27,7 +27,7 @@ def download_from_hub(
where repo_name = {algo}-{env_id}
:param algo: Algorithm
:param env_id: Environment id
:param env_name: Environment name
:param exp_id: Experiment id
:param folder: Log folder
:param organization: Huggingface organization
@@ -36,15 +36,16 @@ def download_from_hub(
if it already exists.
"""

model_name = ModelName(algo, env_name)

if repo_name is None:
repo_name = f"{algo}-{env_id}"
repo_name = model_name # Note: model name is {algo}-{env_name}

repo_id = f"{organization}/{repo_name}"
# Note: repo id is {organization}/{repo_name}
repo_id = ModelRepoId(organization, repo_name)
print(f"Downloading from https://huggingface.co/{repo_id}")

model_name = f"{algo}-{env_id}"

checkpoint = load_from_hub(repo_id, f"{model_name}.zip")
checkpoint = load_from_hub(repo_id, model_name.filename)
config_path = load_from_hub(repo_id, "config.yml")

# If VecNormalize, download
@@ -59,10 +60,10 @@ def download_from_hub(
train_eval_metrics = load_from_hub(repo_id, "train_eval_metrics.zip")

if exp_id == 0:
exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) + 1
exp_id = get_latest_run_id(os.path.join(folder, algo), env_name) + 1
# Sanity checks
if exp_id > 0:
log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}")
log_path = os.path.join(folder, algo, f"{env_name}_{exp_id}")
else:
log_path = os.path.join(folder, algo)

@@ -82,11 +83,11 @@ def download_from_hub(
print(f"Saving to {log_path}")
# Create folder structure
os.makedirs(log_path, exist_ok=True)
config_folder = os.path.join(log_path, env_id)
config_folder = os.path.join(log_path, env_name)
os.makedirs(config_folder, exist_ok=True)

# Copy config files and saved stats
shutil.copy(checkpoint, os.path.join(log_path, f"{env_id}.zip"))
shutil.copy(checkpoint, os.path.join(log_path, f"{env_name}.zip"))
shutil.copy(saved_args, os.path.join(config_folder, "args.yml"))
shutil.copy(config_path, os.path.join(config_folder, "config.yml"))
shutil.copy(env_kwargs, os.path.join(config_folder, "env_kwargs.yml"))
@@ -100,7 +101,7 @@ def download_from_hub(

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--env", help="environment ID", type=str, required=True)
parser.add_argument("--env", help="environment ID", type=EnvironmentName, required=True)
parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True)
parser.add_argument("-orga", "--organization", help="Huggingface hub organization", default="sb3")
parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str)
@@ -114,7 +115,7 @@ def download_from_hub(

download_from_hub(
algo=args.algo,
env_id=args.env,
env_name=args.env,
exp_id=args.exp_id,
folder=args.folder,
organization=args.organization,