From c2f00ea81c42daa0af0e5d131eb51b16552d5d8b Mon Sep 17 00:00:00 2001
From: "M. Ernestus"
Date: Sat, 30 Jul 2022 14:11:50 +0200
Subject: [PATCH] Support environments with slash ('/') in their name (#257)

* Fix bug in get_latest_run_id which would ignore runs when the env name contains a slash.
* Remove any slashes from the HF repo name.
* Fix formatting in utils.py
* Update CHANGELOG.md
* Construct correct repo id when loading from huggingface hub.
* Use environment name instead of environment id to ensure slashes are replaced by dashes.
* Fix get_trained_models util by loading the env_id from metadata instead of parsing it from the model path.
* Fix get_hf_trained_models util by loading the env_id and algo from the model card instead of parsing it from the repo id.
* Remove unused lines from migrate_to_hub.py
* Fix formatting in utils.py
* Change help text of --env parameter back to `environment ID`
* Add comments to explain the naming scheme.
* Make `get_trained_models()` use the `args.yml` file instead of the monitor log file to determine the used environment.
* Introduce usage of EnvironmentName in record_training.py and record_video.py
* Restrict huggingface_sb3 version to avoid breaking changes.
* Fix formatting in utils.py
* Add missing seaborn requirement.
* Pass gym_id instead of env_name to is_atari and create_test_env.
* Use EnvironmentName in the ExperimentManager to properly construct folder names.
* Fix formatting in exp_manager.py
* Disable slow check and fix recurrent ppo alias

Co-authored-by: Antonin RAFFIN
---
 CHANGELOG.md              |  1 +
 enjoy.py                  | 19 +++++----
 requirements.txt          |  4 +-
 rl-trained-agents         |  2 +-
 scripts/migrate_to_hub.py |  4 --
 utils/exp_manager.py      | 31 ++++++++------
 utils/load_from_hub.py    | 29 +++++++------
 utils/push_to_hub.py      | 44 ++++++++++---------
 utils/record_training.py  | 14 +++---
 utils/record_video.py     | 15 ++++---
 utils/utils.py            | 90 +++++++++++++++++++++++++--------------
 11 files changed, 143 insertions(+), 110 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 227b6b322..1d58b1a48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@
 - Fix `Reacher-v3` name in PPO hyperparameter file
 - Pinned ale-py==0.7.4 until new SB3 version is released
 - Fix enjoy / record videos with LSTM policy
+- Fix bug with environments that have a slash in their name (@ernestum)
 - Changed `optimize_memory_usage` to `False` for DQN/QR-DQN on Atari games,
   if you want to save RAM, you need to deactivate `handle_timeout_termination`
   in the `replay_buffer_kwargs`
diff --git a/enjoy.py b/enjoy.py
index b2379b3df..8e6c33298 100644
--- a/enjoy.py
+++ b/enjoy.py
@@ -6,6 +6,7 @@
 import numpy as np
 import torch as th
 import yaml
+from huggingface_sb3 import EnvironmentName
 from stable_baselines3.common.utils import set_random_seed

 import utils.import_envs  # noqa: F401 pylint: disable=unused-import
@@ -17,7 +18,7 @@
 def main():  # noqa: C901
     parser = argparse.ArgumentParser()
-    parser.add_argument("--env", help="environment ID", type=str, default="CartPole-v1")
+    parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1")
     parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents")
     parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys()))
     parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
@@ -67,7 +68,7 @@ def main():  # noqa: C901
     for env_module in args.gym_packages:
importlib.import_module(env_module) - env_id = args.env + env_name: EnvironmentName = args.env algo = args.algo folder = args.folder @@ -76,7 +77,7 @@ def main(): # noqa: C901 args.exp_id, folder, algo, - env_id, + env_name, args.load_best, args.load_checkpoint, args.load_last_checkpoint, @@ -91,7 +92,7 @@ def main(): # noqa: C901 # Auto-download download_from_hub( algo=algo, - env_id=env_id, + env_name=env_name, exp_id=args.exp_id, folder=folder, organization="sb3", @@ -103,7 +104,7 @@ def main(): # noqa: C901 args.exp_id, folder, algo, - env_id, + env_name, args.load_best, args.load_checkpoint, args.load_last_checkpoint, @@ -124,14 +125,14 @@ def main(): # noqa: C901 print(f"Setting torch.num_threads to {args.num_threads}") th.set_num_threads(args.num_threads) - is_atari = ExperimentManager.is_atari(env_id) + is_atari = ExperimentManager.is_atari(env_name.gym_id) - stats_path = os.path.join(log_path, env_id) + stats_path = os.path.join(log_path, env_name) hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True) # load env_kwargs if existing env_kwargs = {} - args_path = os.path.join(log_path, env_id, "args.yml") + args_path = os.path.join(log_path, env_name, "args.yml") if os.path.isfile(args_path): with open(args_path) as f: loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr @@ -144,7 +145,7 @@ def main(): # noqa: C901 log_dir = args.reward_log if args.reward_log != "" else None env = create_test_env( - env_id, + env_name.gym_id, n_envs=args.n_envs, stats_path=stats_path, seed=args.seed, diff --git a/requirements.txt b/requirements.txt index 32c74be21..5645637d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,5 +15,5 @@ panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2 rliable>=1.0.5 wandb ale-py==0.7.4 # tmp fix: until new SB3 version is released -# TODO: replace with release -git+https://github.com/huggingface/huggingface_sb3 +huggingface_sb3>=2.2.1, <3.* +seaborn \ No newline at end of file diff --git a/rl-trained-agents b/rl-trained-agents index 72feeb8c2..f0b0efc31 160000 --- a/rl-trained-agents +++ b/rl-trained-agents @@ -1 +1 @@ -Subproject commit 72feeb8c2e8985e5382ee61f3542ee023ec81922 +Subproject commit f0b0efc31a9b41953085158c0c57183ba1467b28 diff --git a/scripts/migrate_to_hub.py b/scripts/migrate_to_hub.py index d2622a43a..81bef4d6d 100644 --- a/scripts/migrate_to_hub.py +++ b/scripts/migrate_to_hub.py @@ -16,8 +16,4 @@ if algo == "her": continue - # if model doesn't exist already - repo_name = f"{algo}-{env_id}" - repo_id = f"{orga}/{repo_name}" - return_code = subprocess.call(["python", "-m", "utils.push_to_hub"] + args) diff --git a/utils/exp_manager.py b/utils/exp_manager.py index 620d325f7..da1512304 100644 --- a/utils/exp_manager.py +++ b/utils/exp_manager.py @@ -12,6 +12,7 @@ import optuna import torch as th import yaml +from huggingface_sb3 import EnvironmentName from optuna.integration.skopt import SkoptSampler from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner from optuna.samplers import BaseSampler, RandomSampler, TPESampler @@ -95,7 +96,7 @@ def __init__( ): super().__init__() self.algo = algo - self.env_id = env_id + self.env_name = EnvironmentName(env_id) # Custom params self.custom_hyperparams = hyperparams self.env_kwargs = {} if env_kwargs is None else env_kwargs @@ -144,12 +145,12 @@ def __init__( self.pruner = pruner self.n_startup_trials = n_startup_trials self.n_evaluations = n_evaluations - 
self.deterministic_eval = not self.is_atari(self.env_id) + self.deterministic_eval = not self.is_atari(env_id) self.device = device # Logging self.log_folder = log_folder - self.tensorboard_log = None if tensorboard_log == "" else os.path.join(tensorboard_log, env_id) + self.tensorboard_log = None if tensorboard_log == "" else os.path.join(tensorboard_log, self.env_name) self.verbose = verbose self.args = args self.log_interval = log_interval @@ -157,9 +158,9 @@ def __init__( self.log_path = f"{log_folder}/{self.algo}/" self.save_path = os.path.join( - self.log_path, f"{self.env_id}_{get_latest_run_id(self.log_path, self.env_id) + 1}{uuid_str}" + self.log_path, f"{self.env_name}_{get_latest_run_id(self.log_path, self.env_name) + 1}{uuid_str}" ) - self.params_path = f"{self.save_path}/{self.env_id}" + self.params_path = f"{self.save_path}/{self.env_name}" def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]: """ @@ -235,7 +236,7 @@ def save_trained_model(self, model: BaseAlgorithm) -> None: :param model: """ print(f"Saving to {self.save_path}") - model.save(f"{self.save_path}/{self.env_id}") + model.save(f"{self.save_path}/{self.env_name}") if hasattr(model, "save_replay_buffer") and self.save_replay_buffer: print("Saving replay buffer") @@ -267,12 +268,12 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Load hyperparameters from yaml file with open(f"hyperparams/{self.algo}.yml") as f: hyperparams_dict = yaml.safe_load(f) - if self.env_id in list(hyperparams_dict.keys()): - hyperparams = hyperparams_dict[self.env_id] + if self.env_name.gym_id in list(hyperparams_dict.keys()): + hyperparams = hyperparams_dict[self.env_name.gym_id] elif self._is_atari: hyperparams = hyperparams_dict["atari"] else: - raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_id}") + raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}") if self.custom_hyperparams is not None: # Overwrite hyperparams if needed @@ -486,7 +487,7 @@ def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv: :return: """ # Pretrained model, load normalization - path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_id) + path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_name) path_ = os.path.join(path_, "vecnormalize.pkl") if os.path.exists(path_): @@ -530,13 +531,17 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) monitor_kwargs = {} # Special case for GoalEnvs: log success rate too - if "Neck" in self.env_id or self.is_robotics_env(self.env_id) or "parking-v0" in self.env_id: + if ( + "Neck" in self.env_name.gym_id + or self.is_robotics_env(self.env_name.gym_id) + or "parking-v0" in self.env_name.gym_id + ): monitor_kwargs = dict(info_keywords=("is_success",)) # On most env, SubprocVecEnv does not help and is quite memory hungry # therefore we use DummyVecEnv by default env = make_vec_env( - env_id=self.env_id, + env_id=self.env_name.gym_id, n_envs=n_envs, seed=self.seed, env_kwargs=self.env_kwargs, @@ -797,7 +802,7 @@ def hyperparameters_optimization(self) -> None: print(f" {key}: {value}") report_name = ( - f"report_{self.env_id}_{self.n_trials}-trials-{self.n_timesteps}" + f"report_{self.env_name}_{self.n_trials}-trials-{self.n_timesteps}" f"-{self.sampler}-{self.pruner}_{int(time.time())}" ) diff --git a/utils/load_from_hub.py b/utils/load_from_hub.py index 490ef1f8f..a1cc39591 100644 --- a/utils/load_from_hub.py +++ b/utils/load_from_hub.py @@ -5,7 +5,7 @@ 
from pathlib import Path from typing import Optional -from huggingface_sb3 import load_from_hub +from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub from requests.exceptions import HTTPError from utils import ALGOS, get_latest_run_id @@ -13,7 +13,7 @@ def download_from_hub( algo: str, - env_id: str, + env_name: EnvironmentName, exp_id: int, folder: str, organization: str, @@ -27,7 +27,7 @@ def download_from_hub( where repo_name = {algo}-{env_id} :param algo: Algorithm - :param env_id: Environment id + :param env_name: Environment name :param exp_id: Experiment id :param folder: Log folder :param organization: Huggingface organization @@ -36,15 +36,16 @@ def download_from_hub( if it already exists. """ + model_name = ModelName(algo, env_name) + if repo_name is None: - repo_name = f"{algo}-{env_id}" + repo_name = model_name # Note: model name is {algo}-{env_name} - repo_id = f"{organization}/{repo_name}" + # Note: repo id is {organization}/{repo_name} + repo_id = ModelRepoId(organization, repo_name) print(f"Downloading from https://huggingface.co/{repo_id}") - model_name = f"{algo}-{env_id}" - - checkpoint = load_from_hub(repo_id, f"{model_name}.zip") + checkpoint = load_from_hub(repo_id, model_name.filename) config_path = load_from_hub(repo_id, "config.yml") # If VecNormalize, download @@ -59,10 +60,10 @@ def download_from_hub( train_eval_metrics = load_from_hub(repo_id, "train_eval_metrics.zip") if exp_id == 0: - exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) + 1 + exp_id = get_latest_run_id(os.path.join(folder, algo), env_name) + 1 # Sanity checks if exp_id > 0: - log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}") + log_path = os.path.join(folder, algo, f"{env_name}_{exp_id}") else: log_path = os.path.join(folder, algo) @@ -82,11 +83,11 @@ def download_from_hub( print(f"Saving to {log_path}") # Create folder structure os.makedirs(log_path, exist_ok=True) - config_folder = os.path.join(log_path, env_id) + config_folder = os.path.join(log_path, env_name) os.makedirs(config_folder, exist_ok=True) # Copy config files and saved stats - shutil.copy(checkpoint, os.path.join(log_path, f"{env_id}.zip")) + shutil.copy(checkpoint, os.path.join(log_path, f"{env_name}.zip")) shutil.copy(saved_args, os.path.join(config_folder, "args.yml")) shutil.copy(config_path, os.path.join(config_folder, "config.yml")) shutil.copy(env_kwargs, os.path.join(config_folder, "env_kwargs.yml")) @@ -100,7 +101,7 @@ def download_from_hub( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--env", help="environment ID", type=str, required=True) + parser.add_argument("--env", help="environment ID", type=EnvironmentName, required=True) parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True) parser.add_argument("-orga", "--organization", help="Huggingface hub organization", default="sb3") parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str) @@ -114,7 +115,7 @@ def download_from_hub( download_from_hub( algo=args.algo, - env_id=args.env, + env_name=args.env, exp_id=args.exp_id, folder=args.folder, organization=args.organization, diff --git a/utils/push_to_hub.py b/utils/push_to_hub.py index 9e518427c..79fb6f902 100644 --- a/utils/push_to_hub.py +++ b/utils/push_to_hub.py @@ -12,6 +12,7 @@ import yaml from huggingface_hub import HfApi, Repository from huggingface_hub.repocard import metadata_save +from huggingface_sb3 import EnvironmentName, 
ModelName, ModelRepoId from huggingface_sb3.push_to_hub import _evaluate_agent, _generate_replay, generate_metadata from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.utils import set_random_seed @@ -56,7 +57,7 @@ def generate_model_card( Generate the model card for the Hub :param algo_class_name: name of the algorithm class - :param env_id: name of the environment + :param env_id: gym id of the environment :param mean_reward: mean reward of the agent :param std_reward: standard deviation of the mean reward of the agent :return: Model card (readme) and metadata (performance, algo/env id, tags) @@ -114,15 +115,15 @@ def generate_model_card( def package_to_hub( model: BaseAlgorithm, - model_name: str, + model_name: ModelName, algo_name: str, algo_class_name: str, log_path: Path, hyperparams: Dict[str, Any], env_kwargs: Dict[str, Any], - env_id: str, + env_name: EnvironmentName, eval_env: VecEnv, - repo_id: str, + repo_id: ModelRepoId, commit_message: str, is_deterministic: bool = True, n_eval_episodes=10, @@ -143,7 +144,7 @@ def package_to_hub( use `push_to_hub` method. :param model: trained model - :param model_name: name of the model zip file + :param model_name: name of the model :param algo_name: alias used in the zoo for the algorithm, usually lower case of the class (a2c, ars, ppo, ppo_lstm) :param algo_class_name: name of the architecture of your model @@ -154,7 +155,7 @@ def package_to_hub( includes wrappers. :param env_kwargs: Additional keyword arguments that were passed to the environment. - :param env_id: name of the environment + :param env_name: name of the environment :param eval_env: environment used to evaluate the agent :param repo_id: id of the model repository from the Hugging Face Hub :param commit_message: commit message @@ -192,6 +193,7 @@ def package_to_hub( repo.lfs_track(["*.mp4"]) # Step 1: Save the model + print("Saving model to:", repo_local_path / model_name) model.save(repo_local_path / model_name) # Retrieve VecNormalize wrapper if it exists @@ -207,12 +209,12 @@ def package_to_hub( maybe_vec_normalize.norm_reward = False # Unzip the model - with zipfile.ZipFile(repo_local_path / f"{model_name}.zip", "r") as zip_ref: + with zipfile.ZipFile(repo_local_path / model_name.filename, "r") as zip_ref: zip_ref.extractall(repo_local_path / model_name) # Step 2: Copy config files - args_path = log_path / env_id / "args.yml" - config_path = log_path / env_id / "config.yml" + args_path = log_path / env_name / "args.yml" + config_path = log_path / env_name / "config.yml" shutil.copy(args_path, repo_local_path / "args.yml") shutil.copy(config_path, repo_local_path / "config.yml") @@ -246,7 +248,7 @@ def package_to_hub( algo_name, algo_class_name, organization, - env_id, + env_name.gym_id, mean_reward, std_reward, hyperparams, @@ -264,7 +266,7 @@ def package_to_hub( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--env", help="environment ID", type=str, required=True) + parser.add_argument("--env", help="environment ID", type=EnvironmentName, required=True) parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True) parser.add_argument("--algo", help="RL Algorithm", type=str, required=True, choices=list(ALGOS.keys())) parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int) @@ -298,11 +300,11 @@ def package_to_hub( "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor" ) 
parser.add_argument("-orga", "--organization", help="Huggingface hub organization", type=str, required=True) - parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str) + parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env'", type=str) parser.add_argument("-m", "--commit-message", help="Commit message", default="Initial commit", type=str) args = parser.parse_args() - env_id = args.env + env_name: EnvironmentName = args.env algo = args.algo _, model_path, log_path = get_model_path( @@ -330,14 +332,14 @@ def package_to_hub( print(f"Setting torch.num_threads to {args.num_threads}") th.set_num_threads(args.num_threads) - is_atari = ExperimentManager.is_atari(env_id) + is_atari = ExperimentManager.is_atari(env_name.gym_id) - stats_path = os.path.join(log_path, env_id) + stats_path = os.path.join(log_path, env_name) hyperparams, stats_path = get_saved_hyperparams(stats_path, test_mode=True) # load env_kwargs if existing env_kwargs = {} - args_path = os.path.join(log_path, env_id, "args.yml") + args_path = os.path.join(log_path, env_name, "args.yml") if os.path.isfile(args_path): with open(args_path) as f: loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr @@ -348,7 +350,7 @@ def package_to_hub( env_kwargs.update(args.env_kwargs) eval_env = create_test_env( - env_id, + env_name.gym_id, n_envs=args.n_envs, stats_path=stats_path, seed=args.seed, @@ -372,13 +374,13 @@ def package_to_hub( stochastic = args.stochastic or is_atari and not args.deterministic deterministic = not stochastic - # Default model name, the model will be saved under "{algo}-{env_id}.zip" - model_name = f"{algo}-{env_id}" + # Default model name, the model will be saved under "{algo}-{env_name}.zip" + model_name = ModelName(algo, env_name) if args.repo_name is None: args.repo_name = model_name - repo_id = f"{args.organization}/{args.repo_name}" + repo_id = ModelRepoId(args.organization, args.repo_name) print(f"Uploading to {repo_id}, make sure to have the rights") package_to_hub( @@ -389,7 +391,7 @@ def package_to_hub( Path(log_path), hyperparams, env_kwargs, - env_id, + env_name, eval_env, repo_id=repo_id, commit_message=args.commit_message, diff --git a/utils/record_training.py b/utils/record_training.py index 9e9a68fd4..f6a3a043f 100644 --- a/utils/record_training.py +++ b/utils/record_training.py @@ -5,11 +5,13 @@ import subprocess from copy import deepcopy +from huggingface_sb3 import EnvironmentName + from utils.utils import ALGOS, get_latest_run_id if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser() - parser.add_argument("--env", help="environment ID", type=str, default="CartPole-v1") + parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1") parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents") parser.add_argument("-o", "--output-folder", help="Output folder", type=str) parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys())) @@ -21,7 +23,7 @@ parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int) args = parser.parse_args() - env_id = args.env + env_name: EnvironmentName = args.env algo = args.algo folder = args.folder n_timesteps = args.n_timesteps @@ -32,11 +34,11 @@ convert_to_gif = args.gif if args.exp_id == 0: - args.exp_id = 
get_latest_run_id(os.path.join(folder, algo), env_id) + args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_name) print(f"Loading latest experiment, id={args.exp_id}") # Sanity checks if args.exp_id > 0: - log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}") + log_path = os.path.join(folder, algo, f"{env_name}_{args.exp_id}") else: log_path = os.path.join(folder, algo) @@ -55,7 +57,7 @@ args_final_model = [ "--env", - env_id, + env_name.gym_id, "--algo", algo, "--exp-id", @@ -76,7 +78,7 @@ if deterministic is not None: args_final_model.append("--deterministic") - if os.path.exists(os.path.join(log_path, f"{env_id}.zip")): + if os.path.exists(os.path.join(log_path, f"{env_name}.zip")): return_code = subprocess.call(["python", "-m", "utils.record_video"] + args_final_model) assert return_code == 0, "Failed to record the final model" diff --git a/utils/record_video.py b/utils/record_video.py index 0bb66b114..a73942d72 100644 --- a/utils/record_video.py +++ b/utils/record_video.py @@ -4,6 +4,7 @@ import numpy as np import yaml +from huggingface_sb3 import EnvironmentName from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.vec_env import VecVideoRecorder @@ -12,7 +13,7 @@ if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser() - parser.add_argument("--env", help="environment ID", type=str, default="CartPole-v1") + parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1") parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents") parser.add_argument("-o", "--output-folder", help="Output folder", type=str) parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys())) @@ -44,7 +45,7 @@ ) args = parser.parse_args() - env_id = args.env + env_name: EnvironmentName = args.env algo = args.algo folder = args.folder video_folder = args.output_folder @@ -56,7 +57,7 @@ args.exp_id, folder, algo, - env_id, + env_name, args.load_best, args.load_checkpoint, args.load_last_checkpoint, @@ -67,14 +68,14 @@ set_random_seed(args.seed) - is_atari = ExperimentManager.is_atari(env_id) + is_atari = ExperimentManager.is_atari(env_name.gym_id) - stats_path = os.path.join(log_path, env_id) + stats_path = os.path.join(log_path, env_name) hyperparams, stats_path = get_saved_hyperparams(stats_path) # load env_kwargs if existing env_kwargs = {} - args_path = os.path.join(log_path, env_id, "args.yml") + args_path = os.path.join(log_path, env_name, "args.yml") if os.path.isfile(args_path): with open(args_path) as f: loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr @@ -85,7 +86,7 @@ env_kwargs.update(args.env_kwargs) env = create_test_env( - env_id, + env_name.gym_id, n_envs=n_envs, stats_path=stats_path, seed=seed, diff --git a/utils/utils.py b/utils/utils.py index f92fc8ae0..89ad5e332 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -2,7 +2,6 @@ import glob import importlib import os -import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union import gym @@ -10,6 +9,7 @@ import torch as th # noqa: F401 import yaml from huggingface_hub import HfApi +from huggingface_sb3 import EnvironmentName, ModelName from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.callbacks import BaseCallback @@ -286,51 +286,73 @@ def get_trained_models(log_folder: str) -> Dict[str, 
Tuple[str, str]]:
     for algo in os.listdir(log_folder):
         if not os.path.isdir(os.path.join(log_folder, algo)):
             continue
-        for env_id in os.listdir(os.path.join(log_folder, algo)):
-            # Retrieve env name
-            env_id = env_id.split("_")[0]
-            trained_models[f"{algo}-{env_id}"] = (algo, env_id)
+        for model_folder in os.listdir(os.path.join(log_folder, algo)):
+            args_files = glob.glob(os.path.join(log_folder, algo, model_folder, "*/args.yml"))
+            if len(args_files) != 1:
+                continue  # we expect only one sub-folder with an args.yml file
+            with open(args_files[0], "r") as fh:
+                env_id = yaml.load(fh, Loader=yaml.UnsafeLoader)["env"]
+
+            model_name = ModelName(algo, EnvironmentName(env_id))
+            trained_models[model_name] = (algo, env_id)
     return trained_models


-def get_hf_trained_models(organization: str = "sb3") -> Dict[str, Tuple[str, str]]:
+def get_hf_trained_models(organization: str = "sb3", check_filename: bool = False) -> Dict[str, Tuple[str, str]]:
     """
     Get pretrained models, available on the Huggingface hub for a given organization.

-    :param organization:
+    :param organization: Huggingface organization;
+        the Stable-Baselines (SB3) one is the default.
+    :param check_filename: Perform an additional check per model
+        to be sure it matches the RL Zoo convention
+        (this slows things down, as it requires one API call per model).
     :return: Dict representing the trained agents
     """
     api = HfApi()
-    models = api.list_models(author=organization)
-    regex = re.compile(r"^(?P<algo>[a-z_0-9]+)-(?P<env_id>[a-zA-Z0-9]+-v[0-9]+)$")
+    models = api.list_models(author=organization, cardData=True)
+
     trained_models = {}
     for model in models:
-        # Remove organization
-        repo_id = model.modelId.split(f"{organization}/")[1]
-        result = regex.match(repo_id)
-        # Skip demo repo that does not fit the pattern
-        if result is not None:
-            algo, env_id = result.group("algo"), result.group("env_id")
-            trained_models[f"{algo}-{env_id}"] = (algo, env_id)
+        # Try to extract algorithm and environment id from the model card
+        try:
+            env_id = model.cardData["model-index"][0]["results"][0]["dataset"]["name"]
+            algo = model.cardData["model-index"][0]["name"].lower()
+            # RecurrentPPO alias is "ppo_lstm" in the rl zoo
+            if algo == "recurrentppo":
+                algo = "ppo_lstm"
+        except (KeyError, IndexError):
+            print(f"Skipping {model.modelId}")
+            continue  # skip the model if the env id or algo name could not be found
+
+        env_name = EnvironmentName(env_id)
+        model_name = ModelName(algo, env_name)
+
+        # check if there is a model file in the repo
+        if check_filename and not any(f.rfilename == model_name.filename for f in api.model_info(model.modelId).siblings):
+            continue  # skip model if the repo contains no properly named model file
+
+        trained_models[model_name] = (algo, env_id)
+
     return trained_models


-def get_latest_run_id(log_path: str, env_id: str) -> int:
+def get_latest_run_id(log_path: str, env_name: EnvironmentName) -> int:
     """
     Returns the latest run number for the given log name and log path,
     by finding the greatest number in the directories.
    :param log_path: path to log folder
-    :param env_id:
+    :param env_name: environment name, used as the prefix of the run folders
     :return: latest run number
     """
     max_run_id = 0
-    for path in glob.glob(os.path.join(log_path, env_id + "_[0-9]*")):
-        file_name = os.path.basename(path)
-        ext = file_name.split("_")[-1]
-        if env_id == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
-            max_run_id = int(ext)
+    for path in glob.glob(os.path.join(log_path, env_name + "_[0-9]*")):
+        run_id = path.split("_")[-1]
+        path_without_run_id = path[: -len(run_id) - 1]
+        if path_without_run_id.endswith(env_name) and run_id.isdigit() and int(run_id) > max_run_id:
+            max_run_id = int(run_id)
     return max_run_id
@@ -397,33 +419,35 @@ def get_model_path(
     exp_id: int,
     folder: str,
     algo: str,
-    env_id: str,
+    env_name: EnvironmentName,
     load_best: bool = False,
     load_checkpoint: Optional[str] = None,
     load_last_checkpoint: bool = False,
 ) -> Tuple[str, str, str]:
     if exp_id == 0:
-        exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
+        exp_id = get_latest_run_id(os.path.join(folder, algo), env_name)
         print(f"Loading latest experiment, id={exp_id}")
     # Sanity checks
     if exp_id > 0:
-        log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}")
+        log_path = os.path.join(folder, algo, f"{env_name}_{exp_id}")
     else:
         log_path = os.path.join(folder, algo)

     assert os.path.isdir(log_path), f"The {log_path} folder was not found"

+    model_name = ModelName(algo, env_name)
+
     if load_best:
         model_path = os.path.join(log_path, "best_model.zip")
-        name_prefix = f"best-model-{algo}-{env_id}"
+        name_prefix = f"best-model-{model_name}"
     elif load_checkpoint is not None:
         model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip")
-        name_prefix = f"checkpoint-{load_checkpoint}-{algo}-{env_id}"
+        name_prefix = f"checkpoint-{load_checkpoint}-{model_name}"
     elif load_last_checkpoint:
         checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip"))
         if len(checkpoints) == 0:
-            raise ValueError(f"No checkpoint found for {algo} on {env_id}, path: {log_path}")
+            raise ValueError(f"No checkpoint found for {algo} on {env_name}, path: {log_path}")

         def step_count(checkpoint_path: str) -> int:
             # path follows the pattern "rl_model_*_steps.zip", we count from the back to ignore any other _ in the path
@@ -431,14 +455,14 @@ def step_count(checkpoint_path: str) -> int:
         checkpoints = sorted(checkpoints, key=step_count)
         model_path = checkpoints[-1]
-        name_prefix = f"checkpoint-{step_count(model_path)}-{algo}-{env_id}"
+        name_prefix = f"checkpoint-{step_count(model_path)}-{model_name}"
     else:
         # Default: load latest model
-        model_path = os.path.join(log_path, f"{env_id}.zip")
-        name_prefix = f"final-model-{algo}-{env_id}"
+        model_path = os.path.join(log_path, f"{env_name}.zip")
+        name_prefix = f"final-model-{model_name}"

     found = os.path.isfile(model_path)
     if not found:
-        raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")
+        raise ValueError(f"No model found for {algo} on {env_name}, path: {model_path}")

     return name_prefix, model_path, log_path
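
Note on the naming scheme used throughout this patch: run folders, zip files, and Hub repo ids are now derived from the huggingface_sb3 naming helpers instead of the raw gym id. Below is a minimal sketch of the expected behaviour, assuming the `huggingface_sb3>=2.2.1, <3.*` pin from requirements.txt; the environment id `ALE/Pong-v5` is only an illustrative example.

```python
# Sketch of the naming helpers this patch relies on (not part of the diff itself).
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId

# EnvironmentName behaves like a str with slashes replaced by dashes, so it is
# safe in folder and repo names; the original id stays available as .gym_id
# for gym.make(), is_atari() and hyperparameter lookup.
env_name = EnvironmentName("ALE/Pong-v5")
print(str(env_name))    # ALE-Pong-v5
print(env_name.gym_id)  # ALE/Pong-v5

# Model name is {algo}-{env_name}; {model_name}.zip is the model file name.
model_name = ModelName("ppo", env_name)
print(str(model_name))      # ppo-ALE-Pong-v5
print(model_name.filename)  # ppo-ALE-Pong-v5.zip

# Repo id is {organization}/{repo_name}.
repo_id = ModelRepoId("sb3", model_name)
print(str(repo_id))  # sb3/ppo-ALE-Pong-v5

# Run folders keep the {env_name}_{run_id} pattern, now slash-free.
print(f"{env_name}_{1}")  # ALE-Pong-v5_1
```

Because these helpers are effectively strings, the existing `os.path.join(...)` and f-string call sites keep working unchanged while automatically producing the dash-separated form; only code that needs the real gym id (e.g. `create_test_env`, `is_atari`) asks for `env_name.gym_id` explicitly, which is exactly the split the diffs above implement.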