Skip to content

Commit

Permalink
Pyright Improvements (#932)
Browse files Browse the repository at this point in the history
# Pull Request

## Title

Pyright improvements

______________________________________________________________________

## Description

Pyright is a type checker that ships with VSCode's Pylance by default.

It is billed as a faster, though less complete, version of mypy.

As such it gets a few things a little differently that mypy and alerts
in VSCode.

This PR fixes those ("standard") alerts and removes the mypy extension
from VSCode's default extensions for MLOS in favor of just using pyright
(there's no sense in running both interactively). We do not enable
pyright's "strict" mode.

Additionally, it enables pyright in pre-commit rules to ensure those
fixes remain.

We leave the rest of the mypy checks as well since they are still
useful.

A list of some of the types of fixes:

- TypeDict initialization checks for Tunables
- Check that json.loads() returns a dict and not a list (e.g.)
- Replace ConcreteOptimizer TypeVar with a TypeAlias
- Add BoundMethod protocol for checking __self__ attribute
- Ensure correct type inference in a number of places
- Add `...` to Protocol methods to make pyright aware of the lack of
method body.
- Fix a few type annotations

______________________________________________________________________

## Type of Change

- 🛠️ Bug fix
- 🔄 Refactor
______________________________________________________________________

## Testing

- Additional CI checks as described above.

______________________________________________________________________

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
bpkroth and pre-commit-ci[bot] authored Jan 17, 2025
1 parent c66f793 commit e91546e
Show file tree
Hide file tree
Showing 50 changed files with 323 additions and 106 deletions.
1 change: 0 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
"huntertran.auto-markdown-toc",
"ibm.output-colorizer",
"lextudio.restructuredtext",
"matangover.mypy",
"ms-azuretools.vscode-docker",
"ms-python.black-formatter",
"ms-python.pylint",
Expand Down
17 changes: 16 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ci:
# Let pre-commit.ci automatically update PRs with formatting fixes.
autofix_prs: true
# skip local hooks - they should be managed manually via conda-envs/*.yml
skip: [mypy, pylint, pycodestyle]
skip: [mypy, pylint, pycodestyle, pyright]
autoupdate_schedule: monthly
autoupdate_commit_msg: |
[pre-commit.ci] pre-commit autoupdate
Expand All @@ -15,6 +15,7 @@ ci:
See Also:
- https://github.com/microsoft/MLOS/blob/main/conda-envs/mlos.yml
- https://pypi.org/project/mypy/
- https://pypi.org/project/pyright/
- https://pypi.org/project/pylint/
- https://pypi.org/project/pycodestyle/
Expand Down Expand Up @@ -140,6 +141,20 @@ repos:
(?x)^(
doc/source/conf.py
)$
- id: pyright
name: pyright
entry: pyright
language: system
types: [python]
require_serial: true
exclude: |
(?x)^(
doc/source/conf.py|
mlos_core/setup.py|
mlos_bench/setup.py|
mlos_viz/setup.py|
conftest.py
)$
- id: mypy
name: mypy
entry: mypy
Expand Down
1 change: 0 additions & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"huntertran.auto-markdown-toc",
"ibm.output-colorizer",
"lextudio.restructuredtext",
"matangover.mypy",
"ms-azuretools.vscode-docker",
"ms-python.black-formatter",
"ms-python.pylint",
Expand Down
3 changes: 1 addition & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,5 @@
"python.testing.unittestEnabled": false,
"debugpy.debugJustMyCode": false,
"python.analysis.autoImportCompletions": true,
"python.analysis.supportRestructuredText": true,
"python.analysis.typeCheckingMode": "standard"
"python.analysis.supportRestructuredText": true
}
1 change: 1 addition & 0 deletions conda-envs/mlos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies:
- pylint==3.3.3
- tomlkit
- mypy==1.14.1
- pyright==1.1.392.post0
- pandas-stubs
- types-beautifulsoup4
- types-colorama
Expand Down
4 changes: 4 additions & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from sphinx.application import Sphinx as SphinxApp
from sphinx.environment import BuildEnvironment

# Note: doc requirements aren't installed by default.
# To install them, run `pip install -r doc/requirements.txt`


sys.path.insert(0, os.path.abspath("../../mlos_core/mlos_core"))
sys.path.insert(1, os.path.abspath("../../mlos_bench/mlos_bench"))
sys.path.insert(1, os.path.abspath("../../mlos_viz/mlos_viz"))
Expand Down
6 changes: 3 additions & 3 deletions mlos_bench/mlos_bench/environments/mock_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__( # pylint: disable=too-many-arguments
seed = int(self.config.get("mock_env_seed", -1))
self._run_random = random.Random(seed or None) if seed >= 0 else None
self._status_random = random.Random(seed or None) if seed >= 0 else None
self._range = self.config.get("mock_env_range")
self._metrics = self.config.get("mock_env_metrics", ["score"])
self._range: tuple[int, int] | None = self.config.get("mock_env_range")
self._metrics: list[str] | None = self.config.get("mock_env_metrics", ["score"])
self._is_ready = True

def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue]:
Expand All @@ -80,7 +80,7 @@ def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue
if self._range:
score = self._range[0] + score * (self._range[1] - self._range[0])

return {metric: score for metric in self._metrics}
return {metric: float(score) for metric in self._metrics or []}

def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
"""
Expand Down
2 changes: 1 addition & 1 deletion mlos_bench/mlos_bench/optimizers/mock_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self._random: dict[str, Callable[[Tunable], TunableValue]] = {
"categorical": lambda tunable: rnd.choice(tunable.categories),
"float": lambda tunable: rnd.uniform(*tunable.range),
"int": lambda tunable: rnd.randint(*tunable.range),
"int": lambda tunable: rnd.randint(*(int(x) for x in tunable.range)),
}

def bulk_register(
Expand Down
3 changes: 2 additions & 1 deletion mlos_bench/mlos_bench/services/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Literal

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.types.bound_method import BoundMethod
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.util import instantiate_from_config

Expand Down Expand Up @@ -278,7 +279,7 @@ def register(self, services: dict[str, Callable] | list[Callable]) -> None:
for _, svc_method in self._service_methods.items()
# Note: some methods are actually stand alone functions, so we need
# to filter them out.
if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
if isinstance(svc_method, BoundMethod) and isinstance(svc_method.__self__, Service)
}

def export(self) -> dict[str, Callable]:
Expand Down
2 changes: 2 additions & 0 deletions mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class CopyMode(Enum):
class SshFileShareService(FileShareService, SshService):
"""A collection of functions for interacting with SSH servers as file shares."""

# pylint: disable=too-many-ancestors

async def _start_file_copy(
self,
params: dict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
"""Helper methods to manage machines via SSH."""

# pylint: disable=too-many-ancestors
# pylint: disable=too-many-instance-attributes

def __init__(
Expand Down
6 changes: 6 additions & 0 deletions mlos_bench/mlos_bench/services/types/authenticator_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
class SupportsAuth(Protocol[T_co]):
"""Protocol interface for authentication for the cloud services."""

# Needed by pyright
# pylint: disable=unnecessary-ellipsis,redundant-returns-doc

def get_access_token(self) -> str:
"""
Get the access token for cloud services.
Expand All @@ -23,6 +26,7 @@ def get_access_token(self) -> str:
access_token : str
Access token.
"""
...

def get_auth_headers(self) -> dict:
"""
Expand All @@ -33,6 +37,7 @@ def get_auth_headers(self) -> dict:
access_header : dict
HTTP header containing the access token.
"""
...

def get_credential(self) -> T_co:
"""
Expand All @@ -43,3 +48,4 @@ def get_credential(self) -> T_co:
credential : T_co
Cloud-specific credential object.
"""
...
24 changes: 24 additions & 0 deletions mlos_bench/mlos_bench/services/types/bound_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Protocol representing a bound method."""

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class BoundMethod(Protocol):
"""A callable method bound to an object."""

# pylint: disable=too-few-public-methods
# pylint: disable=unnecessary-ellipsis

@property
def __self__(self) -> Any:
"""The self object of the bound method."""
...

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call the bound method."""
...
10 changes: 10 additions & 0 deletions mlos_bench/mlos_bench/services/types/config_loader_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
class SupportsConfigLoading(Protocol):
"""Protocol interface for helper functions to lookup and load configs."""

# Needed by pyright
# pylint: disable=unnecessary-ellipsis,redundant-returns-doc

def get_config_paths(self) -> list[str]:
"""
Gets the list of config paths this service will search for config files.
Expand All @@ -31,6 +34,7 @@ def get_config_paths(self) -> list[str]:
-------
list[str]
"""
...

def resolve_path(self, file_path: str, extra_paths: Iterable[str] | None = None) -> str:
"""
Expand All @@ -49,6 +53,7 @@ def resolve_path(self, file_path: str, extra_paths: Iterable[str] | None = None)
path : str
An actual path to the config or script.
"""
...

def load_config(
self,
Expand All @@ -71,6 +76,7 @@ def load_config(
config : Union[dict, list[dict]]
Free-format dictionary that contains the configuration.
"""
...

def build_environment( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -108,6 +114,7 @@ def build_environment( # pylint: disable=too-many-arguments
env : Environment
An instance of the `Environment` class initialized with `config`.
"""
...

def load_environment(
self,
Expand Down Expand Up @@ -140,6 +147,7 @@ def load_environment(
env : Environment
A new benchmarking environment.
"""
...

def load_environment_list(
self,
Expand Down Expand Up @@ -173,6 +181,7 @@ def load_environment_list(
env : list[Environment]
A list of new benchmarking environments.
"""
...

def load_services(
self,
Expand All @@ -198,3 +207,4 @@ def load_services(
service : Service
A collection of service methods.
"""
...
6 changes: 6 additions & 0 deletions mlos_bench/mlos_bench/services/types/host_ops_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
class SupportsHostOps(Protocol):
"""Protocol interface for Host/VM boot operations."""

# pylint: disable=unnecessary-ellipsis

def start_host(self, params: dict) -> tuple["Status", dict]:
"""
Start a Host/VM.
Expand All @@ -29,6 +31,7 @@ def start_host(self, params: dict) -> tuple["Status", dict]:
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
...

def stop_host(self, params: dict, force: bool = False) -> tuple["Status", dict]:
"""
Expand All @@ -47,6 +50,7 @@ def stop_host(self, params: dict, force: bool = False) -> tuple["Status", dict]:
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
...

def restart_host(self, params: dict, force: bool = False) -> tuple["Status", dict]:
"""
Expand All @@ -65,6 +69,7 @@ def restart_host(self, params: dict, force: bool = False) -> tuple["Status", dic
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
...

def wait_host_operation(self, params: dict) -> tuple["Status", dict]:
"""
Expand All @@ -85,3 +90,4 @@ def wait_host_operation(self, params: dict) -> tuple["Status", dict]:
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
...
6 changes: 6 additions & 0 deletions mlos_bench/mlos_bench/services/types/host_provisioner_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
class SupportsHostProvisioning(Protocol):
"""Protocol interface for Host/VM provisioning operations."""

# pylint: disable=unnecessary-ellipsis

def provision_host(self, params: dict) -> tuple["Status", dict]:
"""
Check if Host/VM is ready. Deploy a new Host/VM, if necessary.
Expand All @@ -31,6 +33,7 @@ def provision_host(self, params: dict) -> tuple["Status", dict]:
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
...

def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status", dict]:
"""
Expand All @@ -52,6 +55,7 @@ def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
...

def deprovision_host(self, params: dict) -> tuple["Status", dict]:
"""
Expand All @@ -68,6 +72,7 @@ def deprovision_host(self, params: dict) -> tuple["Status", dict]:
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
...

def deallocate_host(self, params: dict) -> tuple["Status", dict]:
"""
Expand All @@ -88,3 +93,4 @@ def deallocate_host(self, params: dict) -> tuple["Status", dict]:
A pair of Status and result. The result is always {}.
Status is one of {PENDING, SUCCEEDED, FAILED}
"""
...
7 changes: 6 additions & 1 deletion mlos_bench/mlos_bench/services/types/local_exec_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class SupportsLocalExec(Protocol):
vs the target environment. Used in LocalEnv and provided by LocalExecService.
"""

# Needed by pyright
# pylint: disable=unnecessary-ellipsis,redundant-returns-doc

def local_exec(
self,
script_lines: Iterable[str],
Expand All @@ -49,6 +52,7 @@ def local_exec(
(return_code, stdout, stderr) : (int, str, str)
A 3-tuple of return code, stdout, and stderr of the script process.
"""
...

def temp_dir_context(
self,
Expand All @@ -59,11 +63,12 @@ def temp_dir_context(
Parameters
----------
path : str
path : str | None
A path to the temporary directory. Create a new one if None.
Returns
-------
temp_dir_context : tempfile.TemporaryDirectory
Temporary directory context to use in the `with` clause.
"""
...
Loading

0 comments on commit e91546e

Please sign in to comment.