Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore analysis_config.output, raise if SHMDIR isn't set #42

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
Changelog
=========

Version 0.11.0
--------------

Improvements
~~~~~~~~~~~~

- Allow accessing the cache config and the output path of individual analyses with ``analysis_config.cache`` and ``analysis_config.output``, where ``output`` is a shortcut to ``analysis_config.cache.path``.
- Raise an error if the environment variable ``SHMDIR`` isn't set, instead of logging a warning.


Version 0.10.1
--------------

Expand Down
8 changes: 1 addition & 7 deletions src/blueetl/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from blueetl.cache import CacheManager
from blueetl.campaign.config import SimulationCampaign
from blueetl.config.analysis import init_multi_analysis_configuration
from blueetl.config.analysis_model import CacheConfig, MultiAnalysisConfig, SingleAnalysisConfig
from blueetl.config.analysis_model import MultiAnalysisConfig, SingleAnalysisConfig
from blueetl.features import FeaturesCollection
from blueetl.repository import Repository
from blueetl.resolver import AttrResolver, Resolver
Expand Down Expand Up @@ -46,19 +46,16 @@ def from_config(
cls,
analysis_config: SingleAnalysisConfig,
simulations_config: SimulationCampaign,
cache_config: CacheConfig,
resolver: Resolver,
) -> "Analyzer":
"""Initialize the Analyzer from the given configuration.

Args:
analysis_config: analysis configuration.
simulations_config: simulation campaign configuration.
cache_config: cache configuration.
resolver: resolver instance.
"""
cache_manager = CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand Down Expand Up @@ -214,9 +211,6 @@ def _init_analyzers(self) -> dict[str, Analyzer]:
name: Analyzer.from_config(
analysis_config=analysis_config,
simulations_config=simulations_config,
cache_config=self.global_config.cache.model_copy(
update={"path": self.global_config.cache.path / name}
),
resolver=resolver,
)
for name, analysis_config in self.global_config.analysis.items()
Expand Down
8 changes: 4 additions & 4 deletions src/blueetl/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from blueetl_core.utils import is_subfilter

from blueetl.campaign.config import SimulationCampaign
from blueetl.config.analysis_model import CacheConfig, FeaturesConfig, SingleAnalysisConfig
from blueetl.config.analysis_model import FeaturesConfig, SingleAnalysisConfig
from blueetl.store.base import BaseStore
from blueetl.store.feather import FeatherStore
from blueetl.store.parquet import ParquetStore
Expand Down Expand Up @@ -143,17 +143,17 @@ class CacheManager:

def __init__(
self,
cache_config: CacheConfig,
analysis_config: SingleAnalysisConfig,
simulations_config: SimulationCampaign,
) -> None:
"""Initialize the object.

Args:
cache_config: cache configuration dict.
analysis_config: analysis configuration dict.
analysis_config: analysis configuration.
simulations_config: simulations campaign configuration.
"""
cache_config = analysis_config.cache
assert cache_config is not None
self._output_dir = cache_config.path
if cache_config.clear:
self._clear_cache()
Expand Down
4 changes: 3 additions & 1 deletion src/blueetl/config/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def expand_zip(params: dict, params_zip: dict) -> Iterator[dict]:


def _resolve_analysis_configs(global_config: MultiAnalysisConfig) -> None:
for config in global_config.analysis.values():
global_cache_path = global_config.cache.path
for name, config in global_config.analysis.items():
config.cache = global_config.cache.model_copy(update={"path": global_cache_path / name})
config.simulations_filter = global_config.simulations_filter
config.simulations_filter_in_memory = global_config.simulations_filter_in_memory
config.features = _resolve_features(config.features)
Expand Down
6 changes: 6 additions & 0 deletions src/blueetl/config/analysis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ class FeaturesConfig(BaseModel):
class SingleAnalysisConfig(BaseModel):
"""SingleAnalysisConfig Model."""

cache: Optional[CacheConfig] = None
simulations_filter: dict[str, Any] = {}
simulations_filter_in_memory: dict[str, Any] = {}
extraction: ExtractionConfig
Expand All @@ -204,6 +205,11 @@ def handle_deprecated_fields(cls, data: Any) -> Any:
data.pop("output", None)
return data

@property
def output(self) -> Optional[Path]:
"""Shortcut to the base output path of the analysis."""
return self.cache.path if self.cache else None


class MultiAnalysisConfig(BaseModel):
"""MultiAnalysisConfig Model."""
Expand Down
18 changes: 10 additions & 8 deletions src/blueetl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import time
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from enum import Enum
from functools import cache, cached_property
from pathlib import Path
from typing import Any, Callable, Optional, Union

import pandas as pd
import yaml
from pydantic import BaseModel

from blueetl.constants import DTYPES
from blueetl.types import StrOrPath
Expand Down Expand Up @@ -190,12 +192,9 @@ def checksum_json(obj: Any) -> str:
@cache
def _get_internal_yaml_dumper() -> type[yaml.SafeDumper]:
"""Return the custom internal yaml dumper class."""
# pylint: disable=import-outside-toplevel
# imported here because optional
from pydantic import BaseModel

_representers = {
Path: str,
Enum: lambda data: data.value,
BaseModel: lambda data: data.dict(),
}

Expand Down Expand Up @@ -336,12 +335,15 @@ def copy_config(src: StrOrPath, dst: StrOrPath) -> None:
dump_yaml(dst, config, default_flow_style=None)


def get_shmdir() -> Optional[Path]:
"""Return the shared memory directory, or None if not set."""
def get_shmdir() -> Path:
"""Return the shared memory directory, or raise an error if not set."""
shmdir = os.getenv("SHMDIR")
if not shmdir:
L.warning("SHMDIR should be set to the shared memory directory")
return None
raise RuntimeError(
"SHMDIR must be set to the shared memory directory. "
"The variable should be automatically set when running on an allocated node, "
"but it's not set when connecting via SSH to a pre-allocated node."
)
shmdir = Path(shmdir)
if not shmdir.is_dir():
raise RuntimeError("SHMDIR must be set to an existing directory")
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/config/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from blueetl.config import analysis as test_module
from blueetl.config.analysis_model import FeaturesConfig, MultiAnalysisConfig
from blueetl.config.analysis_model import FeaturesConfig, MultiAnalysisConfig, SingleAnalysisConfig
from blueetl.utils import load_yaml
from tests.functional.utils import TEST_DATA_PATH as TEST_DATA_PATH_FUNCTIONAL
from tests.unit.utils import TEST_DATA_PATH as TEST_DATA_PATH_UNIT
Expand Down Expand Up @@ -191,3 +191,9 @@ def test_init_multi_analysis_configuration(config_file):
config_dict, base_path=base_path, extra_params={}
)
assert isinstance(result, MultiAnalysisConfig)
assert result.cache.path == base_path / config_dict["cache"]["path"]
assert len(result.analysis) > 0
for name, analysis_config in result.analysis.items():
assert isinstance(analysis_config, SingleAnalysisConfig)
assert analysis_config.cache is not None
assert analysis_config.output == result.cache.path / name
29 changes: 8 additions & 21 deletions tests/unit/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
def cache_manager(global_config):
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
analysis_config = global_config.analysis["spikes"]
cache_config = global_config.cache
instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand Down Expand Up @@ -62,10 +60,8 @@ def test_lock_manager_shared(tmp_path):
def test_cache_manager_init_and_close(global_config):
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
analysis_config = global_config.analysis["spikes"]
cache_config = global_config.cache

instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand All @@ -79,10 +75,8 @@ def test_cache_manager_init_and_close(global_config):
def test_cache_manager_to_readonly(global_config):
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
analysis_config = global_config.analysis["spikes"]
cache_config = global_config.cache

instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand All @@ -108,24 +102,20 @@ def test_cache_manager_to_readonly(global_config):
def test_cache_manager_concurrency_is_not_allowed_when_locked(global_config):
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
analysis_config = global_config.analysis["spikes"]
cache_config = global_config.cache

instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
# verify that a new instance cannot be created when the old instance is keeping the lock
with pytest.raises(test_module.CacheError, match="Another process is locking"):
test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
# verify that a new instance can be created after closing the old instance
instance.close()
instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand All @@ -135,12 +125,10 @@ def test_cache_manager_concurrency_is_not_allowed_when_locked(global_config):
def test_cache_manager_concurrency_is_allowed_when_readonly(global_config):
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
analysis_config = global_config.analysis["spikes"]
cache_config = global_config.cache.model_copy(update={"readonly": False})
cache_config_readonly = global_config.cache.model_copy(update={"readonly": True})
cache_config = analysis_config.cache

# init the cache that will be used later
instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand All @@ -149,8 +137,9 @@ def test_cache_manager_concurrency_is_allowed_when_readonly(global_config):
# use the same cache in multiple cache managers
instances = [
test_module.CacheManager(
cache_config=cache_config_readonly,
analysis_config=analysis_config,
analysis_config=analysis_config.model_copy(
update={"cache": cache_config.model_copy(update={"readonly": True})}
),
simulations_config=simulations_config,
)
for _ in range(3)
Expand All @@ -162,15 +151,13 @@ def test_cache_manager_concurrency_is_allowed_when_readonly(global_config):
def test_cache_manager_clear_cache(global_config, tmp_path):
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
analysis_config = global_config.analysis["spikes"]
cache_config = global_config.cache.model_copy(update={"clear": False})
cache_config_clear = global_config.cache.model_copy(update={"clear": True})
cache_config = analysis_config.cache

output = cache_config.path
sentinel = output / "sentinel"

assert output.exists() is False
instance = test_module.CacheManager(
cache_config=cache_config_clear,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand All @@ -181,7 +168,6 @@ def test_cache_manager_clear_cache(global_config, tmp_path):

# reuse the cache
instance = test_module.CacheManager(
cache_config=cache_config,
analysis_config=analysis_config,
simulations_config=simulations_config,
)
Expand All @@ -191,8 +177,9 @@ def test_cache_manager_clear_cache(global_config, tmp_path):

# delete the cache
instance = test_module.CacheManager(
cache_config=cache_config_clear,
analysis_config=analysis_config,
analysis_config=analysis_config.model_copy(
update={"cache": cache_config.model_copy(update={"clear": True})}
),
simulations_config=simulations_config,
)
instance.close()
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from enum import Enum
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -95,11 +96,16 @@ def test_resolve_path(tmp_path):


def test_dump_yaml(tmp_path):
class TestEnum(str, Enum):
v0 = "v0"
v1 = "v1"

data = {
"dict": {"str": "mystr", "int": 123},
"list_of_int": [1, 2, 3],
"list_of_str": ["1", "2", "3"],
"path": Path("/custom/path"),
"enum": TestEnum.v0,
}
expected = """
dict:
Expand All @@ -114,6 +120,7 @@ def test_dump_yaml(tmp_path):
- '2'
- '3'
path: /custom/path
enum: v0
"""
filepath = tmp_path / "test.yaml"

Expand Down Expand Up @@ -150,7 +157,7 @@ def test_load_yaml(tmp_path):
assert loaded_data == expected


def test_dump_jaon_load_json_roundtrip(tmp_path):
def test_dump_json_load_json_roundtrip(tmp_path):
data = {
"dict": {"str": "mystr", "int": 123},
"list_of_int": [1, 2, 3],
Expand Down Expand Up @@ -286,8 +293,8 @@ def test_get_shmdir(monkeypatch, tmp_path):
assert shmdir == tmp_path

monkeypatch.delenv("SHMDIR")
shmdir = test_module.get_shmdir()
assert shmdir is None
with pytest.raises(RuntimeError, match="SHMDIR must be set to the shared memory directory"):
test_module.get_shmdir()

monkeypatch.setenv("SHMDIR", str(tmp_path / "non-existent"))
with pytest.raises(RuntimeError, match="SHMDIR must be set to an existing directory"):
Expand Down
5 changes: 2 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ minversion = 4

[testenv]
setenv =
TMPDIR={env:TMPDIR:/tmp}
SHMDIR={env:SHMDIR:{env:TMPDIR}}
# Run serially
BLUEETL_JOBLIB_JOBS=1
passenv =
SHMDIR
TMPDIR
extras =
all
deps =
Expand Down