Skip to content

Commit

Permalink
evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Sep 9, 2024
1 parent 97aae08 commit b7c3a84
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
20 changes: 20 additions & 0 deletions benchmarl/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Evaluates the experiment from a checkpoint file."
)
parser.add_argument(
"checkpoint_file", type=str, help="The name of the checkpoint file"
)
args = parser.parse_args()
checkpoint_file = args.checkpoint_file
experiment = reload_experiment_from_file(checkpoint_file)
experiment.evaluate()
4 changes: 4 additions & 0 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@ def evaluate(self):
"""Run just the evaluation loop once."""
self._evaluation_loop()
self.logger.commit()
print(
f"Evaluation results logged to loggers={self.config.loggers}"
f"{' and to a json file in the experiment folder.' if self.config.create_json else ''}"
)

def _collection_loop(self):
pbar = tqdm(
Expand Down
56 changes: 56 additions & 0 deletions benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
import importlib
from dataclasses import is_dataclass
from pathlib import Path

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, task_config_registry
Expand All @@ -16,6 +17,7 @@
_has_hydra = importlib.util.find_spec("hydra") is not None

if _has_hydra:
from hydra import compose, initialize, initialize_config_dir
from omegaconf import DictConfig, OmegaConf


Expand Down Expand Up @@ -121,3 +123,57 @@ def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig:
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)
)


def _find_hydra_folder(restore_file: str) -> str:
"""Given the restore file, look for the .hydra folder max three levels above it."""
current_folder = Path(restore_file).parent.resolve()
for _ in range(3):
hydra_dir = current_folder / ".hydra"
if hydra_dir.exists() and hydra_dir.is_dir():
return str(hydra_dir)
current_folder = current_folder.parent
raise ValueError(
".hydra folder not found (should be max 3 levels above checkpoint file"
)


def reload_experiment_from_file(restore_file: str) -> Experiment:
"""Reloads the experiment from a given restore file.
Requires a ``.hydra`` folder containing ``config.yaml``, ``hydra.yaml``, and ``overrides.yaml``
at max three directory levels higher than the checkpoint file. This should be automatically created by hydra.
Args:
restore_file (str): The checkpoint file of the experiment reload.
"""
hydra_folder = _find_hydra_folder(restore_file)
with initialize(
version_base=None,
config_path="conf",
):
cfg = compose(
config_name="config",
overrides=OmegaConf.load(Path(hydra_folder) / "overrides.yaml"),
return_hydra_config=True,
)
task_name = cfg.hydra.runtime.choices.task
algorithm_name = cfg.hydra.runtime.choices.algorithm
with initialize_config_dir(version_base=None, config_dir=hydra_folder):
cfg_loaded = dict(compose(config_name="config"))

for key in ("experiment", "algorithm", "task", "model", "critic_model"):
cfg[key].update(cfg_loaded[key])
cfg_loaded.pop(key)

cfg.update(cfg_loaded)
del cfg.hydra
cfg.experiment.restore_file = restore_file

print("\nReloaded experiment with:")
print(f"\nAlgorithm: {algorithm_name}, Task: {task_name}")
print("\nLoaded config:\n")
print(OmegaConf.to_yaml(cfg))

return load_experiment_from_hydra(cfg, task_name=task_name)
23 changes: 23 additions & 0 deletions benchmarl/resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2024.
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Resumes the experiment from a checkpoint file."
)
parser.add_argument(
"checkpoint_file", type=str, help="The name of the checkpoint file"
)
args = parser.parse_args()
checkpoint_file = args.checkpoint_file

experiment = reload_experiment_from_file(checkpoint_file)
experiment.run()

0 comments on commit b7c3a84

Please sign in to comment.