Skip to content

Commit

Permalink
Merge pull request #63 from berenslab/feat-add-pytest
Browse files Browse the repository at this point in the history
Add Pytest & tests for sample factory
  • Loading branch information
alex404 authored Nov 26, 2024
2 parents 1dd6176 + e9d62f7 commit e199a6f
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 48 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/code_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ jobs:
run: |
bash tests/ci/copy_configs.sh ${{ env.sif_file }}
bash tests/ci/scan_configs.sh ${{ env.sif_file }}
- name: Build Scenario
if: always() && steps.cache-singularity.outputs.cache-hit == 'true'
run: bash tests/ci/build_scenario.sh ${{ env.sif_file }}

- name: Pytest
if: always() && steps.cache-singularity.outputs.cache-hit == 'true'
run: apptainer exec ${{ env.sif_file }} pytest tests/modules
3 changes: 2 additions & 1 deletion doom_creator/compile_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def make_parser():
running the first time one should use the --preload flag to download the
necessary resources into the --out_dir ('{Directories().CACHE_DIR}').
""",
epilog="Example: python -m exec.compile_scenario gathering apples",
epilog="Example: python -m doom_creator.compile_scenario gathering apples",
)
# Positional argument for scenario yaml files (required, can be multiple)
parser.add_argument(
Expand Down Expand Up @@ -106,6 +106,7 @@ def main():
args = parser.parse_args(argv)

dirs = Directories(args.out_dir)

cfg = load(args.yamls, dirs.SCENARIO_YAML_DIR)
# Check preload flag
do_load, do_make, do_list = args.preload, len(args.yamls) > 0, args.list_yamls
Expand Down
44 changes: 22 additions & 22 deletions doom_creator/util/directories.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,47 @@
import os.path as osp
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class Directories:
cache_dir: str = "cache"
resource_dir: str = osp.join("doom_creator", "resources") # noqa: RUF009 as this is not really a dynamical call
scenario_out_dir: Optional[str] = None
build_dir: Optional[str] = None
textures_dir: Optional[str] = None
assets_dir: Optional[str] = None
scenario_yaml_dir: Optional[str] = None
dataset_dir: Optional[str] = None
cache_dir: Path = "cache"
resource_dir: Path = Path("doom_creator", "resources")
scenario_out_dir: Optional[Path] = None
build_dir: Optional[Path] = None
textures_dir: Optional[Path] = None
assets_dir: Optional[Path] = None
scenario_yaml_dir: Optional[Path] = None
dataset_dir: Optional[Path] = None

def __post_init__(self):
self.CACHE_DIR = self.cache_dir
self.SCENARIO_OUT_DIR = (
osp.join(self.CACHE_DIR, "scenarios")
self.CACHE_DIR: Path = self.cache_dir
self.SCENARIO_OUT_DIR: Path = (
Path(self.CACHE_DIR, "scenarios")
if self.scenario_out_dir is None
else self.scenario_out_dir
)
self.BUILD_DIR = (
osp.join(self.SCENARIO_OUT_DIR, "build")
self.BUILD_DIR: Path = (
Path(self.SCENARIO_OUT_DIR, "build")
if self.build_dir is None
else self.build_dir
)
self.TEXTURES_DIR = (
osp.join(self.CACHE_DIR, "textures")
self.TEXTURES_DIR: Path = (
Path(self.CACHE_DIR, "textures")
if self.textures_dir is None
else self.textures_dir
)
self.RESOURCE_DIR = self.resource_dir
self.ASSETS_DIR = (
osp.join(self.RESOURCE_DIR, "assets")
self.RESOURCE_DIR: Path = self.resource_dir
self.ASSETS_DIR: Path = (
Path(self.RESOURCE_DIR, "assets")
if self.assets_dir is None
else self.assets_dir
)
self.SCENARIO_YAML_DIR = (
osp.join(self.RESOURCE_DIR, "config")
self.SCENARIO_YAML_DIR: Path = (
Path(self.RESOURCE_DIR, "config")
if self.scenario_yaml_dir is None
else self.scenario_yaml_dir
)
self.DATASET_DIR = (
self.DATASET_DIR: Path = (
self.TEXTURES_DIR if self.dataset_dir is None else self.dataset_dir
)
1 change: 1 addition & 0 deletions doom_creator/util/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def make_scenario(
scenario_name: Optional[str] = None,
):
# Create Zip for output
directories.SCENARIO_OUT_DIR.mkdir(parents=True, exist_ok=True)
out_file = osp.join(directories.SCENARIO_OUT_DIR, scenario_name) + ".zip"
if osp.exists(out_file):
os.remove(out_file)
Expand Down
6 changes: 3 additions & 3 deletions doom_creator/util/preload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import shutil
import struct
from glob import glob
from typing import Optional
from typing import Optional, Set, Tuple

from PIL import Image
from PIL.PngImagePlugin import PngInfo
Expand Down Expand Up @@ -105,8 +105,8 @@ def preload_dataset(
dataset_wrapper.clean(source_dir)


def check_preload(cfg: Config, test: bool):
needed_types = set()
def check_preload(cfg: Config, test: bool) -> Tuple[Config, Set[TextureType]]:
needed_types: Set[TextureType] = set()
for type_cfg in cfg.objects.values():
for actor in type_cfg.actors.values():
for i in range(len(actor.textures)):
Expand Down
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _program(cfg: DictConfig):
objective = instantiate(cfg.optimizer.objective, brain=brain)
# TODO: RL framework currently can't use objective
else:
objective = None
warnings.warn("No objective specified, is that wanted?")

if cfg.command == "scan":
Expand Down
3 changes: 2 additions & 1 deletion resources/retinal-rl.def
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ From: ubuntu:22.04
opencv-python==4.10.0.84 \
pygame==2.6.1 \
pycairo==1.26.1 \
pytest==8.3.3 \
git+https://github.com/pytorch/captum.git@fd758e025673100cb6a525d59a78893c558b825b \
torchinfo==1.8.0 \
num2words==0.5.13 \
omgifol==0.5.1 \
git+https://github.com/alex404/sample-factory.git@05465ef425530c3e2ef13dd1f9aa2d7313eb3d4a \
git+https://github.com/alex-petrenko/sample-factory.git@e9589e4218e838d8a1377380c27ad45feb44f4ba \
dpcpp-cpp-rt==2024.2.1 \
seaborn==0.13.2 \
hydra-core==1.3.2 \
Expand Down
9 changes: 6 additions & 3 deletions retinal_rl/rl/sample_factory/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,12 @@ def make_retinal_env_from_spec(


def register_retinal_env(scene_name: str, cache_dir: str, input_satiety: bool):
cfg_path = os.path.join(
cache_dir, "scenarios", scene_name + ".cfg"
) # TODO: Check if this stays
print(cache_dir)
if not os.path.isabs(cache_dir):
# make path absolute by making it relative to the path of this file
# TODO: Discuss whether this is desired behaviour...
cache_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", cache_dir)
cfg_path = os.path.join(cache_dir, "scenarios", scene_name + ".cfg")

env_spec = retinal_doomspec(scene_name, cfg_path, input_satiety)
make_env_func = functools.partial(make_retinal_env_from_spec, env_spec)
Expand Down
6 changes: 3 additions & 3 deletions retinal_rl/rl/sample_factory/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from enum import Enum
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -116,7 +116,7 @@ def forward_core(self, head_output, rnn_states):
return out, rnn_states

def forward_tail(
self, core_output, values_only: bool, sample_actions: bool
self, core_output, values_only: bool, sample_actions: bool, action_mask: Optional[Tensor] = None
) -> TensorDict:
out = self.brain.circuits[self.decoder_name](core_output)
out = torch.flatten(out, 1)
Expand All @@ -138,7 +138,7 @@ def forward_tail(
return result

def forward(
self, normalized_obs_dict, rnn_states, values_only: bool = False
self, normalized_obs_dict, rnn_states, values_only: bool = False, action_mask: Optional[Tensor] = None
) -> TensorDict:
head_out = self.forward_head(normalized_obs_dict)
core_out, new_rnn_states = self.forward_core(head_out, rnn_states)
Expand Down
36 changes: 24 additions & 12 deletions runner/frameworks/rl/sf_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def train(
objective: Optional[Objective[ContextT]] = None,
):
warnings.warn(
"device, brain, optimizer are initialized differently in sample_factory and thus there current state will be ignored"
"device, brain, optimizer are initialized differently in sample_factory and thus their current state will be ignored"
)
warnings.warn(
"objective is currently not supported for sample factory simulations"
Expand Down Expand Up @@ -113,25 +113,37 @@ def load_brain_and_config(
brain.to(device)
return brain

def to_sf_cfg(self, cfg: DictConfig) -> Config:
sf_cfg = self._get_default_cfg(cfg.dataset.env_name) # Load Defaults
@staticmethod
def to_sf_cfg(cfg: DictConfig) -> Config:
sf_cfg = SFFramework._get_default_cfg(cfg.dataset.env_name) # Load Defaults

# overwrite default values with those set in cfg
# TODO: which other parameters need to be set_
self._set_cfg_cli_argument(sf_cfg, "learning_rate", cfg.optimizer.optimizer.lr)
SFFramework._set_cfg_cli_argument(
sf_cfg, "learning_rate", cfg.optimizer.optimizer.lr
)
# Using this function is necessary to make sure that the parameters are not overwritten when sample_factory loads a checkpoint

self._set_cfg_cli_argument(sf_cfg, "res_h", cfg.dataset.vision_width)
self._set_cfg_cli_argument(sf_cfg, "res_w", cfg.dataset.vision_height)
self._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name)
self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.dataset.input_satiety)
self._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device)
SFFramework._set_cfg_cli_argument(sf_cfg, "res_h", cfg.dataset.vision_width)
SFFramework._set_cfg_cli_argument(sf_cfg, "res_w", cfg.dataset.vision_height)
SFFramework._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name)
SFFramework._set_cfg_cli_argument(
sf_cfg, "input_satiety", cfg.dataset.input_satiety
)
SFFramework._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device)
optimizer_name = str.lower(
str.split(cfg.optimizer.optimizer._target_, sep=".")[-1]
)
self._set_cfg_cli_argument(sf_cfg, "optimizer", optimizer_name)
SFFramework._set_cfg_cli_argument(sf_cfg, "optimizer", optimizer_name)

self._set_cfg_cli_argument(sf_cfg, "brain", OmegaConf.to_object(cfg.brain))
SFFramework._set_cfg_cli_argument(
sf_cfg, "brain", OmegaConf.to_object(cfg.brain)
)
SFFramework._set_cfg_cli_argument(
sf_cfg, "train_dir", os.path.join(cfg.path.run_dir, "train_dir")
)
SFFramework._set_cfg_cli_argument(sf_cfg, "with_wandb", cfg.logging.use_wandb)
SFFramework._set_cfg_cli_argument(sf_cfg, "wandb_dir", cfg.path.wandb_dir)
return sf_cfg

def analyze(
Expand All @@ -141,7 +153,7 @@ def analyze(
objective: Optional[Objective[ContextT]] = None,
):
warnings.warn(
"device, brain, optimizer are initialized differently in sample_factory and thus there current state will be ignored"
"device, brain, optimizer are initialized differently in sample_factory and thus their current state will be ignored"
)
enjoy(self.sf_cfg)
# TODO: Implement analyze function for sf framework
Expand Down
27 changes: 27 additions & 0 deletions runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,30 @@ def _resolve_output_shape(
raise ValueError(
f"Invalid format for output_shape: {output_shape}. Must be of the form 'circuit_name.property_name'"
)


def search_conf(config: DictConfig | dict, search_str: str) -> List:
"""
Recursively search for strings in a DictConfig.
Args:
config (omegaconf.DictConfig): The configuration to search.
Returns:
list: A list of all values containing the string.
"""
found_values = []

def traverse_config(cfg):
for key, value in cfg.items():
if isinstance(value, (dict, DictConfig)):
traverse_config(value)
elif isinstance(value, str) and search_str in value:
found_values.append(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, str) and search_str in item:
found_values.append(item)

traverse_config(config)
return found_values
13 changes: 13 additions & 0 deletions tests/ci/build_scenario.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
#===============================================================================
# Description: Builds the gathering apples scenario used for tests
#
# Arguments:
# $1 - Path to Singularity (.sif) container
#
# Usage:
# tests/ci/build_scenario.sh container.sif
# (run from top level directory!)
#===============================================================================

singularity exec "$1" python -m doom_creator.compile_scenario gathering apples
4 changes: 2 additions & 2 deletions tests/ci/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ fi

if [ -n "$changed_files" ]; then
# Format
apptainer exec "$CONTAINER" ruff format $changed_files $check
singularity exec "$CONTAINER" ruff format $changed_files $check
# Run ruff on changed files with any remaining arguments
apptainer exec "$CONTAINER" ruff check $changed_files "$@"
singularity exec "$CONTAINER" ruff check $changed_files "$@"
else
echo "No .py files changed"
fi
2 changes: 1 addition & 1 deletion tests/ci/scan_configs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@

for file in config/user/experiment/*.yaml; do
experiment=$(basename "$file" .yaml)
apptainer exec "$1" \
singularity exec "$1" \
python main.py +experiment="$experiment" command=scan system.device=cpu
done
47 changes: 47 additions & 0 deletions tests/modules/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import shutil
import sys
import time

import hydra
import pytest
from omegaconf import DictConfig, OmegaConf

sys.path.append(".")
from runner.util import search_conf

OmegaConf.register_new_resolver("eval", eval)


@pytest.fixture
def config() -> DictConfig:
with hydra.initialize(config_path="../../config/base", version_base=None):
experiment = "gathering-apples"
config = hydra.compose(
"config", overrides=[f"+experiment={experiment}", "system.device=cpu"]
)

# replace the paths that are normally set via HydraConfig
config.path.run_dir = f"tmp{hash(time.time())}"
config.sweep.command[-2] = experiment

# check whether there's still values to be interpolated through hydra
hydra_values = search_conf(
OmegaConf.to_container(config, resolve=False), "hydra:"
)

assert (
len(hydra_values) == 0
), "hydra: values can not be resolved here. Set them manually in this fixture for tests!"

OmegaConf.resolve(config)
yield config

# Cleanup: remove temporary dir
if os.path.exists(config.path.run_dir):
shutil.rmtree(config.path.run_dir)


@pytest.fixture
def data_root() -> str:
return "cache"
Loading

0 comments on commit e199a6f

Please sign in to comment.