diff --git a/.copier-answers.yml b/.copier-answers.yml index 9e8e4cf69..a15480550 100644 --- a/.copier-answers.yml +++ b/.copier-answers.yml @@ -16,4 +16,4 @@ github_org: DiamondLightSource package_name: blueapi pypi: true repo_name: blueapi -type_checker: mypy +type_checker: pyright diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d0fae1bda..8e7d56dca 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -27,6 +27,7 @@ "redhat.vscode-yaml", "ryanluker.vscode-coverage-gutters", "charliermarsh.ruff", + "ms-pyright.pyright", "ms-azuretools.vscode-docker" ] } diff --git a/.gitignore b/.gitignore index 472e30e96..7120af3e1 100644 --- a/.gitignore +++ b/.gitignore @@ -44,7 +44,6 @@ nosetests.xml coverage.xml cov.xml .pytest_cache/ -.mypy_cache/ # Translations *.mo diff --git a/pyproject.toml b/pyproject.toml index 8658d979a..f467c2579 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "event-model==1.22.1", # https://github.com/DiamondLightSource/blueapi/issues/684 "opentelemetry-distro>=0.48b0", "opentelemetry-instrumentation-fastapi>=0.48b0", - "observability-utils>=0.1.4" + "observability-utils>=0.1.4", ] dynamic = ["version"] license.file = "LICENSE" @@ -49,8 +49,8 @@ dev = [ "pipdeptree", "pre-commit>=3.8.0", "pydata-sphinx-theme>=0.15.4", - "mypy", "pytest", + "pyright", "pytest-cov", "pytest-asyncio", "responses", @@ -82,11 +82,9 @@ name = "Callum Forrester" [tool.setuptools_scm] version_file = "src/blueapi/_version.py" -[tool.mypy] -ignore_missing_imports = true # Ignore missing stubs in imported modules - -# necessary for tracing sdk to work with mypy, set false once migraion to pyright complete -namespace_packages = true +[tool.pyright] +typeCheckingMode = "standard" +reportMissingImports = false # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error @@ -122,12 +120,12 @@ passenv = * allowlist_externals = pytest pre-commit - mypy + pyright sphinx-build sphinx-autobuild commands = pre-commit: pre-commit run --all-files --show-diff-on-failure {posargs} - type-checking: mypy src tests {posargs} + type-checking: pyright src tests {posargs} tests: pytest --cov=blueapi --cov-report term --cov-report xml:cov.xml tests/unit_tests {posargs} docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html system-test: pytest tests/system_tests {posargs} diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c98800c20..05592ba2f 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -163,19 +163,16 @@ def get_devices(obj: dict) -> None: def listen_to_events(obj: dict) -> None: """Listen to events output by blueapi""" config: ApplicationConfig = obj["config"] - if config.stomp is not None: - event_bus_client = EventBusClient( - StompClient.for_broker( - broker=Broker( - host=config.stomp.host, - port=config.stomp.port, - auth=config.stomp.auth, - ) + assert config.stomp is not None, "Message bus needs to be configured" + event_bus_client = EventBusClient( + StompClient.for_broker( + broker=Broker( + host=config.stomp.host, + port=config.stomp.port, + auth=config.stomp.auth, ) ) - else: - raise RuntimeError("Message bus needs to be configured") - + ) fmt = obj["fmt"] def on_event( diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 2e805ea74..4edba0621 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -46,17 +46,16 @@ def __init__( @classmethod def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": rest = BlueapiRestClient(config.api) - if config.stomp is not None: - stomp_client = StompClient.for_broker( - broker=Broker( - host=config.stomp.host, - port=config.stomp.port, - auth=config.stomp.auth, - ) + if config.stomp is None: + return cls(rest) + client = StompClient.for_broker( + broker=Broker( + host=config.stomp.host, + port=config.stomp.port, + auth=config.stomp.auth, ) - events = EventBusClient(stomp_client) - else: - events = None + ) + events = EventBusClient(client) return cls(rest, events) @start_as_current_span(TRACER) diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index c92c38113..6dc4057f1 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -143,10 +143,10 @@ def my_plan(a: int, b: str): if not is_bluesky_plan_generator(plan): raise TypeError(f"{plan} is not a valid plan generator function") - model = create_model( # type: ignore + model = create_model( plan.__name__, __config__=BlueapiPlanModelConfig, - **self._type_spec_for_function(plan), + **self._type_spec_for_function(plan), # type: ignore ) self.plans[plan.__name__] = Plan( name=plan.__name__, model=model, description=plan.__doc__ @@ -284,7 +284,7 @@ def _convert_type(self, typ: type | Any) -> type: root = get_origin(typ) if root == UnionType: root = Union - return root[new_types] if root else typ + return root[new_types] if root else typ # type: ignore return typ diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py index 1bace1676..72dff32f0 100644 --- a/src/blueapi/core/device_lookup.py +++ b/src/blueapi/core/device_lookup.py @@ -1,12 +1,9 @@ -from typing import Any, TypeVar +from typing import Any from .bluesky_types import Device, is_bluesky_compatible_device -#: Device obeying Bluesky protocols -D = TypeVar("D", bound=Device) - -def find_component(obj: Any, addr: list[str]) -> D | None: +def find_component(obj: Any, addr: list[str]) -> Device | None: """ Best effort function to locate a child device, either in a dictionary of devices or a device with child attributes. diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index f968fd925..f4d819c26 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -6,7 +6,7 @@ from bluesky_stomp.messaging import StompClient from bluesky_stomp.models import Broker, DestinationBase, MessageTopic -from blueapi.config import ApplicationConfig +from blueapi.config import ApplicationConfig, StompConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.service.model import DeviceModel, PlanModel, WorkerTask @@ -50,11 +50,13 @@ def worker() -> TaskWorker: @cache def stomp_client() -> StompClient | None: - stomp_config = config().stomp + stomp_config: StompConfig | None = config().stomp if stomp_config is not None: - stomp_client = StompClient.for_broker( + client = StompClient.for_broker( broker=Broker( - host=stomp_config.host, port=stomp_config.port, auth=stomp_config.auth + host=stomp_config.host, + port=stomp_config.port, + auth=stomp_config.auth, # type: ignore ) ) @@ -68,8 +70,8 @@ def stomp_client() -> StompClient | None: task_worker.data_events: event_topic, } ) - stomp_client.connect() - return stomp_client + client.connect() + return client else: return None diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 11ad7271f..86aa08b59 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -101,7 +101,7 @@ def get_app(): TRACER = get_tracer("interface") -async def on_key_error_404(_: Request, __: KeyError): +async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, content={"detail": "Item not found"}, @@ -181,8 +181,8 @@ def submit_task( runner: WorkerDispatcher = Depends(_runner), ): """Submit a task to the worker.""" + plan_model = runner.run(interface.get_plan, task.name) try: - plan_model = runner.run(interface.get_plan, task.name) task_id: str = runner.run(interface.submit_task, task) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) @@ -193,7 +193,7 @@ def submit_task( ) error_detail_response = f""" Input validation failed: {formatted_errors}, - suppplied params {task.params}, + supplied params {task.params}, do not match the expected params: {plan_model.parameter_schema} """ raise HTTPException( diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index 658fe8a5e..be4c49ba6 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -69,7 +69,7 @@ def reload(self): @start_as_current_span(TRACER) def start(self): add_span_attributes( - {"_use_subprocess": self._use_subprocess, "_config": self._config} + {"_use_subprocess": self._use_subprocess, "_config": str(self._config)} ) try: if self._use_subprocess: @@ -176,7 +176,7 @@ def _rpc( ctx = get_global_textmap().extract(carrier) attach(ctx) mod = import_module(module_name) - func: Callable[P, T] = _validate_function( + func: Callable[..., T] = _validate_function( mod.__dict__.get(function_name, None), function_name ) value = func(*args, **kwargs) diff --git a/src/blueapi/startup/example_plans.py b/src/blueapi/startup/example_plans.py index 2915707a5..2f353c7b3 100644 --- a/src/blueapi/startup/example_plans.py +++ b/src/blueapi/startup/example_plans.py @@ -5,11 +5,14 @@ from blueapi.core import MsgGenerator +TEMP: Movable = inject("sample_temperature") +PRESS: Movable = inject("sample_pressure") + def stp_snapshot( detectors: list[Readable], - temperature: Movable = inject("sample_temperature"), - pressure: Movable = inject("sample_pressure"), + temperature: Movable = TEMP, + pressure: Movable = PRESS, ) -> MsgGenerator: """ Moves devices for pressure and temperature (defaults fetched from the context) @@ -26,5 +29,5 @@ def stp_snapshot( Yields: Iterator[MsgGenerator]: Bluesky messages """ - yield from move({temperature: 0, pressure: 10**5}) - yield from count(detectors, 1) + yield from move({temperature: 0, pressure: 10**5}) # type: ignore + yield from count(set(detectors), 1) diff --git a/src/blueapi/startup/simmotor.py b/src/blueapi/startup/simmotor.py index 1d8f9c4a6..5b947ef4b 100644 --- a/src/blueapi/startup/simmotor.py +++ b/src/blueapi/startup/simmotor.py @@ -3,7 +3,7 @@ from collections.abc import Callable from ophyd.sim import SynAxis -from ophyd.status import MoveStatus, Status +from ophyd.status import MoveStatus class SynAxisWithMotionEvents(SynAxis): @@ -27,8 +27,8 @@ def __init__( super().__init__( name=name, readback_func=readback_func, - value=value, - delay=delay, + value=value, # type: ignore + delay=delay, # type: ignore precision=precision, parent=parent, labels=labels, @@ -38,7 +38,7 @@ def __init__( self._events_per_move = events_per_move self.egu = egu - def set(self, value: float) -> None: + def set(self, value: float) -> MoveStatus: old_setpoint = self.sim_state["setpoint"] distance = value - old_setpoint self.sim_state["setpoint"] = value @@ -90,5 +90,5 @@ def __init__(self, *, timeout: float, **kwargs) -> None: super().__init__(**kwargs) self._timeout = timeout - def set(self, value: float) -> Status: - return Status(timeout=self._timeout) + def set(self, value: float) -> MoveStatus: + return MoveStatus(positioner=self, target=value, timeout=self._timeout) diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b871f842a..c81762542 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -7,7 +7,6 @@ __all__ = [ "handle_all_exceptions", "load_module_all", - "ConfigLoader", "serialize", "BlueapiBaseModel", "BlueapiModelConfig", diff --git a/src/blueapi/worker/__init__.py b/src/blueapi/worker/__init__.py index 7862912cc..85ae49b45 100644 --- a/src/blueapi/worker/__init__.py +++ b/src/blueapi/worker/__init__.py @@ -6,7 +6,6 @@ __all__ = [ "TaskWorker", "Task", - "Worker", "WorkerEvent", "WorkerState", "StatusView", diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 29c815bdc..8bd99d63a 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -8,7 +8,8 @@ from blueapi.utils import BlueapiBaseModel # The RunEngine can return any of these three types as its state -RawRunEngineState = type[PropertyMachine | ProxyString | str] +# RawRunEngineState = type[PropertyMachine | ProxyString | str] +RawRunEngineState = PropertyMachine | ProxyString | str # NOTE this is interim until refactor diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index db760f421..546c5f3b5 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -86,6 +86,14 @@ class TaskWorker: _state: WorkerState _errors: list[str] _warnings: list[str] + + # The queue is actually a channel between 2 threads + # most programming languages have a separate abstraction for this + # but Python reuses Queue + # So it's not used as a standard queue, + # but as a box in which to put the "current" task and nothing else + # So the calling thread can only ever submit one plan at a time. + _task_channel: Queue # type: ignore _current: TrackableTask | None _status_lock: RLock @@ -110,7 +118,9 @@ def __init__( self._tasks = {} - self._state = WorkerState.from_bluesky_state(ctx.run_engine.state) + assert ctx.run_engine.state is not None, "RunEngine state is not set" + state: RawRunEngineState = str(ctx.run_engine.state) + self._state = WorkerState.from_bluesky_state(state) self._errors = [] self._warnings = [] self._task_channel = Queue(maxsize=1) @@ -141,15 +151,17 @@ def cancel_active_task( reason: str | None = None, ) -> str: if self._current is None: - # Persuades mypy that self._current is not None + # Persuades type checker that self._current is not None # We only allow this method to be called if a Plan is active raise TransitionError("Attempted to cancel while no active Task") if failure: - self._ctx.run_engine.abort(reason) - add_span_attributes({"Task aborted": reason}) + default_reason = "Task failed for unknown reason" + self._ctx.run_engine.abort(reason or default_reason) + add_span_attributes({"Task aborted": reason or default_reason}) else: self._ctx.run_engine.stop() - add_span_attributes({"Task stopped": reason}) + default_reason = "Cancellation successful: Task stopped without error" + add_span_attributes({"Task stopped": reason or default_reason}) return self._current.task_id @start_as_current_span(TRACER) @@ -220,6 +232,7 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: task_started.set() LOGGER.info(f"Submitting: {trackable_task}") + sub = self.worker_events.subscribe(mark_task_as_started) try: self._current_task_otel_context = get_current() sub = self.worker_events.subscribe(mark_task_as_started) @@ -276,10 +289,10 @@ def state(self) -> WorkerState: @start_as_current_span(TRACER) def run(self) -> None: LOGGER.info("Worker starting") - self._ctx.run_engine.state_hook = self._on_state_change + self._ctx.run_engine.state_hook = self._on_state_change # type: ignore self._ctx.run_engine.subscribe(self._on_document) if self._broadcast_statuses: - self._ctx.run_engine.waiting_hook = self._waiting_hook + self._ctx.run_engine.waiting_hook = self._waiting_hook # type: ignore self._stopped.clear() self._started.set() @@ -319,10 +332,7 @@ def _cycle(self) -> None: kind=SpanKind.SERVER, ): LOGGER.info(f"Got new task: {next_task}") - self._current = ( - next_task # Informing mypy that the task is not None - ) - + self._current = next_task self._current_task_otel_context = get_current() add_span_attributes({"next_task.task_id": next_task.task_id}) diff --git a/tests/conftest.py b/tests/conftest.py index b61c09e9d..a564d7980 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,8 @@ # Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # noqa: E501 import pytest -from bluesky import RunEngine -from bluesky.run_engine import TransitionError +from bluesky._vendor.super_state_machine.errors import TransitionError +from bluesky.run_engine import RunEngine from observability_utils.tracing import JsonObjectSpanExporter, setup_tracing from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor @@ -32,7 +32,7 @@ def clean_event_loop(): @pytest.fixture(scope="session") -def exporter() -> TracerProvider: +def exporter() -> JsonObjectSpanExporter: setup_tracing("test", False) exporter = JsonObjectSpanExporter() provider = cast(TracerProvider, get_tracer_provider()) diff --git a/tests/unit_tests/cli/test_scratch.py b/tests/unit_tests/cli/test_scratch.py index 7e4ca3b9e..29c5f282d 100644 --- a/tests/unit_tests/cli/test_scratch.py +++ b/tests/unit_tests/cli/test_scratch.py @@ -1,6 +1,7 @@ import os import stat import uuid +from collections.abc import Generator from pathlib import Path from tempfile import TemporaryDirectory from unittest.mock import Mock, call, patch @@ -12,14 +13,14 @@ @pytest.fixture -def directory_path() -> Path: # type: ignore +def directory_path() -> Generator[Path]: temporary_directory = TemporaryDirectory() yield Path(temporary_directory.name) temporary_directory.cleanup() @pytest.fixture -def file_path(directory_path: Path) -> Path: # type: ignore +def file_path(directory_path: Path) -> Generator[Path]: file_path = directory_path / str(uuid.uuid4()) with file_path.open("w") as stream: stream.write("foo") diff --git a/tests/unit_tests/core/test_context.py b/tests/unit_tests/core/test_context.py index d01db462e..908fb142b 100644 --- a/tests/unit_tests/core/test_context.py +++ b/tests/unit_tests/core/test_context.py @@ -49,7 +49,10 @@ def has_typeless_params(foo, bar) -> MsgGenerator: # type: ignore ... -def has_default_reference(m: Movable = inject(SIM_MOTOR_NAME)) -> MsgGenerator: +MOTOR: Movable = inject(SIM_MOTOR_NAME) + + +def has_default_reference(m: Movable = MOTOR) -> MsgGenerator: yield from [] diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index 8faf91332..ebfe1f59d 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -3,12 +3,12 @@ from unittest.mock import ANY, MagicMock, Mock, patch import pytest +from bluesky.utils import MsgGenerator from bluesky_stomp.messaging import StompClient from ophyd.sim import SynAxis from stomp.connect import StompConnection11 as Connection from blueapi.config import ApplicationConfig, StompConfig -from blueapi.core import MsgGenerator from blueapi.core.context import BlueskyContext from blueapi.service import interface from blueapi.service.model import DeviceModel, PlanModel, WorkerTask diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 8a1fe5219..578dea13e 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -184,7 +184,7 @@ class MyModel(BaseModel): assert response.json() == { "detail": ( "\n Input validation failed: id: Field required,\n" - " suppplied params {},\n" + " supplied params {},\n" " do not match the expected params: {'properties': {'id': " "{'title': 'Id', 'type': 'string'}}, 'required': ['id'], 'title': " "'MyModel', 'type': 'object'}\n " diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index b6c75eea9..fc9ddeab8 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -136,7 +136,7 @@ def test_submit_plan(runner: CliRunner): def test_invalid_stomp_config_for_listener(runner: CliRunner): result = runner.invoke(main, ["controller", "listen"]) - assert isinstance(result.exception, RuntimeError) + assert isinstance(result.exception, AssertionError) assert str(result.exception) == "Message bus needs to be configured" diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index a5e0b32ad..69351170b 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -126,20 +126,20 @@ def test_error_thrown_if_schema_does_not_match_yaml(nested_config_yaml: Path) -> @mock.patch.dict(os.environ, {"FOO": "bar"}, clear=True) def test_auth_from_env(): - auth = BasicAuthentication(username="${FOO}", password="baz") + auth = BasicAuthentication(username="${FOO}", password="baz") # type: ignore assert auth.username == "bar" @mock.patch.dict(os.environ, {"FOO": "bar", "BAZ": "qux"}, clear=True) def test_auth_from_env_repeated_key(): - auth = BasicAuthentication(username="${FOO}", password="${FOO}") + auth = BasicAuthentication(username="${FOO}", password="${FOO}") # type: ignore assert auth.username == "bar" assert auth.password.get_secret_value() == "bar" @mock.patch.dict(os.environ, {"FOO": "bar"}, clear=True) def test_auth_from_env_ignore_case(): - auth = BasicAuthentication(username="${FOO}", password="${foo}") + auth = BasicAuthentication(username="${FOO}", password="${foo}") # type: ignore assert auth.username == "bar" assert auth.password.get_secret_value() == "bar" @@ -148,9 +148,9 @@ def test_auth_from_env_ignore_case(): def test_auth_from_env_throws_when_not_available(): # Eagerly throws an exception, will fail during initial loading with pytest.raises(KeyError): - BasicAuthentication(username="${BAZ}", password="baz") + BasicAuthentication(username="${BAZ}", password="baz") # type: ignore with pytest.raises(KeyError): - BasicAuthentication(username="${baz}", passcode="baz") + BasicAuthentication(username="${baz}", passcode="baz") # type: ignore def is_subset(subset: Mapping[str, Any], superset: Mapping[str, Any]) -> bool: diff --git a/tests/unit_tests/worker/devices.py b/tests/unit_tests/worker/devices.py index 7d34fe6e5..e6dadc53c 100644 --- a/tests/unit_tests/worker/devices.py +++ b/tests/unit_tests/worker/devices.py @@ -1,7 +1,8 @@ # Devices to use for worker tests -from bluesky.protocols import Movable, Status +from bluesky.protocols import Movable from ophyd import Device, DeviceStatus +from ophyd.status import Status class AdditionalUpdateStatus(DeviceStatus): @@ -39,9 +40,9 @@ def _run_callbacks(self) -> None: class AdditionalStatusDevice(Device, Movable): - def set(self, value: float) -> Status: + def set(self, value: float) -> Status: # type: ignore status = AdditionalUpdateStatus(self) - return status + return status # type: ignore def additional_status_device(name="additional_status_device") -> AdditionalStatusDevice: