Skip to content

Commit

Permalink
Merge branch 'main' into export-additional-config-persistence-methods
Browse files Browse the repository at this point in the history
  • Loading branch information
bpkroth authored Jan 9, 2025
2 parents de330a6 + 6d91add commit a938280
Show file tree
Hide file tree
Showing 16 changed files with 41 additions and 27 deletions.
23 changes: 14 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [pre-commit]

# Note "require_serial" actually controls whether that particular hook's files
# are partitioned and the hook executable called in parallel across them, not
# whether hooks themselves are parallelized.
# As such, some hooks (e.g., pylint) which do internal parallelism need it set
# for effeciency and correctness anyways.

repos:
#
# Formatting
#
# NOTE: checks that adjust files are marked with the special "manual" stage
# and "require_serial" so that we can easily call them via `make`
# NOTE: checks that adjust files are marked with the special "manual" stage so
# that we can easily call them via `make`.
#
#
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand All @@ -18,48 +25,42 @@ repos:
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
require_serial: true
stages: [pre-commit, manual]
# TODO:
#- id: pretty-format-json
# args: [--autofix, --no-sort-keys]
# require_serial: true
# stages: [pre-commit, manual]
- id: trailing-whitespace
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/johann-petrak/licenseheaders
rev: v0.8.8
hooks:
- id: licenseheaders
files: '\.(sh|cmd|ps1|sql|py)$'
args: [-t, doc/mit-license.tmpl, -E, .py, .sh, .ps1, .sql, .cmd, -x, mlos_bench/setup.py, mlos_core/setup.py, mlos_viz/setup.py, -f]
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py310-plus]
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
require_serial: true
args: ["-j", "-1"]
stages: [pre-commit, manual]
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black
require_serial: true
stages: [pre-commit, manual]
- repo: https://github.com/PyCQA/docformatter
rev: 06907d0 # v1.7.5
hooks:
- id: docformatter
require_serial: true
stages: [pre-commit, manual]
#
# Linting
Expand All @@ -69,6 +70,8 @@ repos:
hooks:
- id: pydocstyle
types: [python]
additional_dependencies:
- tomli
# Use pylint and mypy from the local (conda) environment so that vscode can reuse them too.
- repo: local
hooks:
Expand All @@ -82,6 +85,7 @@ repos:
entry: pylint
language: system
types: [python]
require_serial: true
args: [
"-j0",
"--rcfile=pyproject.toml",
Expand All @@ -97,6 +101,7 @@ repos:
entry: mypy
language: system
types: [python]
require_serial: true
exclude: |
(?x)^(
doc/source/conf.py|
Expand Down
6 changes: 5 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,9 @@
"--log-level=DEBUG",
"."
],
"python.testing.unittestEnabled": false
"python.testing.unittestEnabled": false,
"debugpy.debugJustMyCode": false,
"python.analysis.autoImportCompletions": true,
"python.analysis.supportRestructuredText": true,
"python.analysis.typeCheckingMode": "standard"
}
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/environments/base_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import logging
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal
Expand All @@ -28,7 +29,7 @@
_LOG = logging.getLogger(__name__)


class Environment(metaclass=abc.ABCMeta):
class Environment(ContextManager, metaclass=abc.ABCMeta):
# pylint: disable=too-many-instance-attributes
"""An abstract base of all benchmark environments."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _template_from_to(self, config_key: str) -> list[tuple[Template, Template]]:
def _expand(
from_to: Iterable[tuple[Template, Template]],
params: Mapping[str, TunableValue],
) -> Generator[tuple[str, str], None, None]:
) -> Generator[tuple[str, str]]:
"""
Substitute $var parameters in from/to path templates.
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from contextlib import AbstractContextManager as ContextManager
from types import TracebackType
from typing import Literal

Expand All @@ -25,7 +26,7 @@
_LOG = logging.getLogger(__name__)


class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
class Optimizer(ContextManager, metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes
"""An abstract interface between the benchmarking framework and mlos_core
optimizers.
"""
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/schedulers/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
from abc import ABCMeta, abstractmethod
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal
Expand All @@ -23,7 +24,7 @@
_LOG = logging.getLogger(__name__)


class Scheduler(metaclass=ABCMeta):
class Scheduler(ContextManager, metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""Base class for the optimization loop scheduling policies."""

Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/services/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
import logging
from collections.abc import Callable
from contextlib import AbstractContextManager as ContextManager
from types import TracebackType
from typing import Any, Literal

Expand All @@ -19,7 +20,7 @@
_LOG = logging.getLogger(__name__)


class Service:
class Service(ContextManager):
"""An abstract base of all Environment Services and used to build up mix-ins."""

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/services/config_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
Free-format dictionary of global parameters.
parent : Service
An optional parent service that can provide mixin functions.
methods : Union[dict[str, Callable], list[Callable], None]
methods : dict[str, Callable] | list[Callable] | None
New methods to register with the service.
"""
super().__init__(
Expand Down Expand Up @@ -166,7 +166,7 @@ def load_config(
Returns
-------
config : Union[dict, list[dict]]
config : dict | list[dict]
Free-format dictionary that contains the configuration.
"""
assert isinstance(json, str)
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/storage/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal
Expand Down Expand Up @@ -132,7 +133,7 @@ def experiment( # pylint: disable=too-many-arguments
the results of the experiment and related data.
"""

class Experiment(metaclass=ABCMeta):
class Experiment(ContextManager, metaclass=ABCMeta):
# pylint: disable=too-many-instance-attributes
"""
Base interface for storing the results of the experiment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def ssh_test_server(
ssh_test_server_hostname: str,
docker_compose_project_name: str,
locked_docker_services: DockerServices,
) -> Generator[SshTestServerInfo, None, None]:
) -> Generator[SshTestServerInfo]:
"""
Fixture for getting the ssh test server services setup via docker-compose using
pytest-docker.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]:
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper]:
"""
Provides a context manager for a temporary file that can be closed and still
unlinked.
Expand Down
6 changes: 3 additions & 3 deletions mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def storage() -> SqlStorage:
def exp_storage(
storage: SqlStorage,
tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
) -> Generator[SqlStorage.Experiment]:
"""
Test fixture for Experiment using in-memory SQLite3 storage.
Expand All @@ -60,7 +60,7 @@ def exp_storage(
@pytest.fixture
def exp_no_tunables_storage(
storage: SqlStorage,
) -> Generator[SqlStorage.Experiment, None, None]:
) -> Generator[SqlStorage.Experiment]:
"""
Test fixture for Experiment using in-memory SQLite3 storage.
Expand All @@ -84,7 +84,7 @@ def exp_no_tunables_storage(
def mixed_numerics_exp_storage(
storage: SqlStorage,
mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
) -> Generator[SqlStorage.Experiment]:
"""
Test fixture for an Experiment with mixed numerics tunables using in-memory SQLite3
storage.
Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/tunables/tunable_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ def __setitem__(
self._index[name][name] = value
return self._index[name][name]

def __iter__(self) -> Generator[tuple[Tunable, CovariantTunableGroup], None, None]:
def __iter__(self) -> Generator[tuple[Tunable, CovariantTunableGroup]]:
"""
An iterator over all tunables in the group.
Returns
-------
[(tunable, group), ...] : Generator[tuple[Tunable, CovariantTunableGroup], None, None]
[(tunable, group), ...] : Generator[tuple[Tunable, CovariantTunableGroup]]
An iterator over all tunables in all groups. Each element is a 2-tuple
of an instance of the Tunable parameter and covariant group it belongs to.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_configspace_quant_repatch() -> None:
new_meta[QUANTIZATION_BINS_META_KEY] = 21
hp.meta = new_meta
monkey_patch_hp_quantization(hp)
samples_set = set(hp.sample_value(100, seed=RandomState(SEED)))
samples_set: set[int] = set(hp.sample_value(100, seed=RandomState(SEED)))
quantized_values_new = set(range(5, 96, 10))
assert samples_set.issubset(set(range(0, 101, 5)))
assert len(samples_set - quantized_values_new) < len(samples_set)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ disable = [
"consider-using-assignment-expr",
"docstring-first-line-empty",
"missing-raises-doc",
"unnecessary-default-type-args", # affects Generator type hints, but we still support python 3.8
]

[tool.pylint.string]
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ exclude_also =
#

[mypy]
cache_fine_grained = True
#ignore_missing_imports = True
warn_unused_configs = True
warn_unused_ignores = True
Expand Down

0 comments on commit a938280

Please sign in to comment.