diff --git a/.copier-answers.yml b/.copier-answers.yml
index 878df0ca3f..3461332e26 100644
--- a/.copier-answers.yml
+++ b/.copier-answers.yml
@@ -1,9 +1,8 @@
 # Changes here will be overwritten by Copier
-_commit: 2.1.0
+_commit: 2.3.0
 _src_path: gh:DiamondLightSource/python-copier-template
 author_email: tom.cobb@diamond.ac.uk
 author_name: Tom Cobb
-component_owner: ''
 description: Asynchronous Bluesky hardware abstraction code, compatible with control
   systems like EPICS and Tango
 distribution_name: ophyd-async
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index a8f81fbc27..d2055bcd01 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -52,4 +52,4 @@
     "workspaceMount": "source=${localWorkspaceFolder}/..,target=/workspaces,type=bind",
     // After the container is created, install the python project in editable form
     "postCreateCommand": "pip install $([ -f dev-requirements.txt ] && echo '-c dev-requirements.txt') -e '.[dev]' && pre-commit install"
-}
\ No newline at end of file
+}
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index 43c7642f3c..0473b72142 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -7,4 +7,4 @@
 # Use isort to sort imports
 881a35b43584103ca572b6f4e472dd8b6fd6ea87
 # Replace flake8 and mypy with ruff and pyrite
-e2f8317e7584e4de788c2b39e5b5edaa98c1bc9e
\ No newline at end of file
+e2f8317e7584e4de788c2b39e5b5edaa98c1bc9e
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 1d6f7ce3ba..27f6450d97 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -24,4 +24,4 @@ It is recommended that developers use a [vscode devcontainer](https://code.visua
 
 This project was created using the [Diamond Light Source Copier Template](https://github.com/DiamondLightSource/python-copier-template) for Python projects.
 
-For more information on common tasks like setting up a developer environment, running the tests, and setting a pre-commit hook, see the template's [How-to guides](https://diamondlightsource.github.io/python-copier-template/2.1.0/how-to.html).
+For more information on common tasks like setting up a developer environment, running the tests, and setting a pre-commit hook, see the template's [How-to guides](https://diamondlightsource.github.io/python-copier-template/2.3.0/how-to.html).
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000000..aa65892f39
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,21 @@
+---
+name: Bug Report
+about: The template to use for reporting bugs and usability issues
+title: " "
+labels: 'bug'
+assignees: ''
+
+---
+
+Describe the bug, including a clear and concise description of the expected behavior, the actual behavior and the context in which you encountered it (ideally include details of your environment).
+
+## Steps To Reproduce
+Steps to reproduce the behavior:
+1. Go to '...'
+2. Click on '....'
+3. Scroll down to '....'
+4. See error
+
+
+## Acceptance Criteria
+- Specific criteria that will be used to judge if the issue is fixed
diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md
new file mode 100644
index 0000000000..52c84dd853
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/issue.md
@@ -0,0 +1,13 @@
+---
+name: Issue
+about: The standard template to use for feature requests, design discussions and tasks
+title: " "
+labels: ''
+assignees: ''
+
+---
+
+A brief description of the issue, including specific stakeholders and the business case where appropriate
+
+## Acceptance Criteria
+- Specific criteria that will be used to judge if the issue is fixed
diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
new file mode 100644
index 0000000000..8200afe5c4
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
@@ -0,0 +1,8 @@
+Fixes #ISSUE
+
+### Instructions to reviewer on how to test:
+1. Do thing x
+2. Confirm thing y happens
+
+### Checks for reviewer
+- [ ] Would the PR title make sense to a user on a set of release notes?
diff --git a/.github/pages/index.html b/.github/pages/index.html
index 80f0a00912..c495f39f2f 100644
--- a/.github/pages/index.html
+++ b/.github/pages/index.html
@@ -8,4 +8,4 @@
-</html>
\ No newline at end of file
+</html>
diff --git a/.github/pages/make_switcher.py b/.github/pages/make_switcher.py
index 0babd3c6bb..c06813afad 100755
--- a/.github/pages/make_switcher.py
+++ b/.github/pages/make_switcher.py
@@ -1,30 +1,32 @@
+"""Make switcher.json to allow docs to switch between different versions."""
+
 import json
 import logging
 from argparse import ArgumentParser
 from pathlib import Path
 from subprocess import CalledProcessError, check_output
-from typing import List, Optional
 
 
-def report_output(stdout: bytes, label: str) -> List[str]:
+def report_output(stdout: bytes, label: str) -> list[str]:
+    """Print and return something received from stdout."""
     ret = stdout.decode().strip().split("\n")
     print(f"{label}: {ret}")
     return ret
 
 
-def get_branch_contents(ref: str) -> List[str]:
+def get_branch_contents(ref: str) -> list[str]:
     """Get the list of directories in a branch."""
     stdout = check_output(["git", "ls-tree", "-d", "--name-only", ref])
     return report_output(stdout, "Branch contents")
 
 
-def get_sorted_tags_list() -> List[str]:
+def get_sorted_tags_list() -> list[str]:
     """Get a list of sorted tags in descending order from the repository."""
     stdout = check_output(["git", "tag", "-l", "--sort=-v:refname"])
     return report_output(stdout, "Tags list")
 
 
-def get_versions(ref: str, add: Optional[str]) -> List[str]:
+def get_versions(ref: str, add: str | None) -> list[str]:
     """Generate the file containing the list of all GitHub Pages builds."""
     # Get the directories (i.e. builds) from the GitHub Pages branch
     try:
@@ -41,7 +43,7 @@ def get_versions(ref: str, add: Optional[str]) -> List[str]:
     tags = get_sorted_tags_list()
 
     # Make the sorted versions list from main branches and tags
-    versions: List[str] = []
+    versions: list[str] = []
     for version in ["master", "main"] + tags:
         if version in builds:
             versions.append(version)
@@ -53,14 +55,12 @@ def get_versions(ref: str, add: Optional[str]) -> List[str]:
     return versions
 
 
-def write_json(path: Path, repository: str, versions: str):
+def write_json(path: Path, repository: str, versions: list[str]):
+    """Write the JSON switcher to path."""
     org, repo_name = repository.split("/")
-    pages_url = f"https://{org}.github.io"
-    if repo_name != f"{org}.github.io":
-        # Only add the repo name if it isn't the source for the org pages site
-        pages_url += f"/{repo_name}"
     struct = [
-        {"version": version, "url": f"{pages_url}/{version}/"} for version in versions
+        {"version": version, "url": f"https://{org}.github.io/{repo_name}/{version}/"}
+        for version in versions
     ]
     text = json.dumps(struct, indent=2)
     print(f"JSON switcher:\n{text}")
@@ -68,6 +68,7 @@
 
 
 def main(args=None):
+    """Parse args and write switcher."""
     parser = ArgumentParser(
         description="Make a versions.json file from gh-pages directories"
     )
diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml
index ce33811ff6..a1cafcaedf 100644
--- a/.github/workflows/_docs.yml
+++ b/.github/workflows/_docs.yml
@@ -51,4 +51,4 @@ jobs:
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
           publish_dir: .github/pages
-          keep_files: true
\ No newline at end of file
+          keep_files: true
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index cf5bc80bbc..265fea6424 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -43,14 +43,14 @@ jobs:
     needs: check
     if: needs.check.outputs.branch-pr == ''
     uses: ./.github/workflows/_dist.yml
-  
+
   pypi:
     if: github.ref_type == 'tag'
     needs: dist
     uses: ./.github/workflows/_pypi.yml
     permissions:
       id-token: write
-  
+
   release:
     if: github.ref_type == 'tag'
     needs: [dist, docs]
diff --git a/.gitignore b/.gitignore
index 0992cd727a..ef6d127ea9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,6 +55,7 @@ cov.xml
 
 # Sphinx documentation
 docs/_build/
+docs/_api
 
 # PyBuilder
 target/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5a4cbf7b41..60fc23f9a7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,6 +5,7 @@ repos:
       - id: check-added-large-files
      - id: check-yaml
       - id: check-merge-conflict
+      - id: end-of-file-fixer
 
   - repo: local
     hooks:
diff --git a/docs/_api.rst b/docs/_api.rst
new file mode 100644
index 0000000000..d0c587dcd4
--- /dev/null
+++ b/docs/_api.rst
@@ -0,0 +1,16 @@
+:orphan:
+
+..
+    This page is not included in the TOC tree, but must exist so that the
+    autosummary pages are generated for ophyd_async and all its
+    subpackages
+
+API
+===
+
+.. autosummary::
+    :toctree: _api
+    :template: custom-module-template.rst
+    :recursive:
+
+    ophyd_async
diff --git a/docs/_templates/README b/docs/_templates/README
deleted file mode 100644
index 1f9b817343..0000000000
--- a/docs/_templates/README
+++ /dev/null
@@ -1 +0,0 @@
-https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion
\ No newline at end of file
diff --git a/docs/_templates/custom-class-template.rst b/docs/_templates/custom-class-template.rst
deleted file mode 100644
index 236b77063c..0000000000
--- a/docs/_templates/custom-class-template.rst
+++ /dev/null
@@ -1,40 +0,0 @@
-.. note::
-
-    Ophyd async is included on a provisional basis until the v1.0 release and
-    may change API on minor release numbers before then
-
-{{ fullname | escape | underline}}
-
-.. currentmodule:: {{ module }}
-
-.. autoclass:: {{ objname }}
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :inherited-members:
-    :special-members: __call__, __add__, __mul__
-
-    {% block methods %}
-    {% if methods %}
-    .. rubric:: {{ _('Methods') }}
-
-    .. autosummary::
-        :nosignatures:
-    {% for item in methods %}
-        {%- if not item.startswith('_') %}
-        ~{{ name }}.{{ item }}
-        {%- endif -%}
-    {%- endfor %}
-    {% endif %}
-    {% endblock %}
-
-    {% block attributes %}
-    {% if attributes %}
-    .. rubric:: {{ _('Attributes') }}
-
-    .. autosummary::
-    {% for item in attributes %}
-        ~{{ name }}.{{ item }}
-    {%- endfor %}
-    {% endif %}
-    {% endblock %}
diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst
index 3fbef60a7e..726bf49435 100644
--- a/docs/_templates/custom-module-template.rst
+++ b/docs/_templates/custom-module-template.rst
@@ -1,72 +1,42 @@
 .. note::
 
-    Ophyd async is included on a provisional basis until the v1.0 release and
+    Ophyd async is considered experimental until the v1.0 release and
     may change API on minor release numbers before then
 
-{{ fullname | escape | underline}}
+{{ ('``' + fullname + '``') | underline }}
 
-.. automodule:: {{ fullname }}
-
-    {% block attributes %}
-    {% if attributes %}
-    .. rubric:: Module attributes
-
-    .. autosummary::
-        :toctree:
-    {% for item in attributes %}
-        {{ item }}
-    {%- endfor %}
-    {% endif %}
-    {% endblock %}
-
-    {% block functions %}
-    {% if functions %}
-    .. rubric:: {{ _('Functions') }}
-
-    .. autosummary::
-        :toctree:
-        :nosignatures:
-    {% for item in functions %}
-        {{ item }}
-    {%- endfor %}
-    {% endif %}
-    {% endblock %}
-
-    {% block classes %}
-    {% if classes %}
-    .. rubric:: {{ _('Classes') }}
-
-    .. autosummary::
-        :toctree:
-        :template: custom-class-template.rst
-        :nosignatures:
-    {% for item in classes %}
-        {{ item }}
-    {%- endfor %}
-    {% endif %}
-    {% endblock %}
-
-    {% block exceptions %}
-    {% if exceptions %}
-    .. rubric:: {{ _('Exceptions') }}
-
-    .. autosummary::
-        :toctree:
-        :nosignatures:
-    {% for item in exceptions %}
-        {{ item }}
-    {%- endfor %}
-    {% endif %}
-    {% endblock %}
-
-{% block modules %}
-{% if modules %}
-.. autosummary::
-    :toctree:
-    :template: custom-module-template.rst
-    :recursive:
-{% for item in modules %}
-    {{ item }}
+{%- set filtered_members = [] %}
+{%- for item in members %}
+    {%- if item in functions + classes + exceptions + attributes %}
+        {% set _ = filtered_members.append(item) %}
+    {%- endif %}
 {%- endfor %}
-{% endif %}
-{% endblock %}
+
+.. automodule:: {{ fullname }}
+    :members:
+
+    {% block modules %}
+    {% if modules %}
+    .. rubric:: Submodules
+
+    .. autosummary::
+        :toctree:
+        :template: custom-module-template.rst
+        :recursive:
+    {% for item in modules %}
+        {{ item }}
+    {%- endfor %}
+    {% endif %}
+    {% endblock %}
+
+    {% block members %}
+    {% if filtered_members %}
+    .. rubric:: Members
+
+    .. autosummary::
+        :nosignatures:
+    {% for item in filtered_members %}
+        {{ item }}
+    {%- endfor %}
+    {% endif %}
+    {% endblock %}
diff --git a/docs/conf.py b/docs/conf.py
index c96eb76efa..77565b386d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,8 +1,9 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
+"""Configuration file for the Sphinx documentation builder.
+
+This file only contains a selection of the most common options. For a full
+list see the documentation:
+https://www.sphinx-doc.org/en/master/usage/configuration.html
+"""
 
 import os
 import sys
@@ -38,7 +39,10 @@
     "sphinxcontrib.autodoc_pydantic",
     # Use this for generating API docs
     "sphinx.ext.autodoc",
+    # Not sure if this is still used?
     "sphinx.ext.doctest",
+    # and making summary tables at the top of API docs
+    "sphinx.ext.autosummary",
     # This can parse google style docstrings
     "sphinx.ext.napoleon",
     # For linking to external sphinx documentation
@@ -51,7 +55,6 @@
     "sphinx_copybutton",
     # For the card element
     "sphinx_design",
-    "sphinx.ext.autosummary",
     "sphinx.ext.mathjax",
     "sphinx.ext.githubpages",
     "IPython.sphinxext.ipython_directive",
@@ -88,16 +91,21 @@
     ("py:class", "typing_extensions.Literal"),
 ]
 
-# Both the class’ and the __init__ method’s docstring are concatenated and
-# inserted into the main body of the autoclass directive
-autoclass_content = "both"
-
 # Order the members by the order they appear in the source code
 autodoc_member_order = "bysource"
 
 # Don't inherit docstrings from baseclasses
 autodoc_inherit_docstrings = False
 
+# Add some more modules to the top level autosummary
+ophyd_async.__all__ += ["sim", "epics", "tango", "fastcs", "plan_stubs"]
+
+# Document only what is in __all__
+autosummary_ignore_module_all = False
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
 # Output graphviz directive produced images in a scalable format
 graphviz_output_format = "svg"
@@ -241,9 +249,11 @@
 # numpydoc config
 numpydoc_show_class_members = False
 
-# pydantic models
-autodoc_pydantic_model_show_json = True
+# Don't show config summary as it's not relevant
 autodoc_pydantic_model_show_config_summary = False
 
+# Show the fields in source order
+autodoc_pydantic_model_summary_list_order = "bysource"
+
 # Where to put Ipython savefigs
 ipython_savefig_dir = "../build/savefig"
diff --git a/docs/examples/foo_detector.py b/docs/examples/foo_detector.py
index bf3e41e6f6..1c4d824406 100644
--- a/docs/examples/foo_detector.py
+++ b/docs/examples/foo_detector.py
@@ -1,5 +1,4 @@
 import asyncio
-from typing import Optional
 
 from bluesky.protocols import HasHints, Hints
 
@@ -32,7 +31,7 @@ async def arm(
         self,
         num: int,
         trigger: DetectorTrigger = DetectorTrigger.internal,
-        exposure: Optional[float] = None,
+        exposure: float | None = None,
     ) -> AsyncStatus:
         await asyncio.gather(
             self._drv.num_images.set(num),
diff --git a/docs/explanations/decisions/0003-ophyd-async-migration.rst b/docs/explanations/decisions/0003-ophyd-async-migration.rst
index 2c33c132db..22a4fbc941 100644
--- a/docs/explanations/decisions/0003-ophyd-async-migration.rst
+++ b/docs/explanations/decisions/0003-ophyd-async-migration.rst
@@ -47,4 +47,4 @@ Consequences
 ------------
 
 This will require changing the repository structure of Ophyd Async; see
-the decision on repository structure :doc:`0004-repository-structure` for details.
\ No newline at end of file
+the decision on repository structure :doc:`0004-repository-structure` for details.
diff --git a/docs/explanations/decisions/0005-respect-black-line-length.rst b/docs/explanations/decisions/0005-respect-black-line-length.rst
index 2de0ab5603..fa0b0d1961 100644
--- a/docs/explanations/decisions/0005-respect-black-line-length.rst
+++ b/docs/explanations/decisions/0005-respect-black-line-length.rst
@@ -26,4 +26,4 @@ Consequences
 ------------
 
 Linting tools for this repository are configured to accept black's line length of 88 characters.
-Any additional linting tools should respect this.
\ No newline at end of file
+Any additional linting tools should respect this.
diff --git a/docs/explanations/design-goals.rst b/docs/explanations/design-goals.rst
index 8c6b963a5d..63bd74dd17 100644
--- a/docs/explanations/design-goals.rst
+++ b/docs/explanations/design-goals.rst
@@ -54,4 +54,4 @@ To view and contribute to discussions on outstanding decisions, please see the d
 .. _malcolm: https://github.com/dls-controls/pymalcolm
 .. _scanspec: https://github.com/dls-controls/scanspec
 .. _design: https://github.com/bluesky/ophyd-async/issues?q=is%3Aissue+is%3Aopen+label%3Adesign
-.. _pmac: https://github.com/dls-controls/pmac
\ No newline at end of file
+.. _pmac: https://github.com/dls-controls/pmac
diff --git a/docs/explanations/event-loop-choice.rst b/docs/explanations/event-loop-choice.rst
index 9ef4abb3db..64b3ff58cb 100644
--- a/docs/explanations/event-loop-choice.rst
+++ b/docs/explanations/event-loop-choice.rst
@@ -48,5 +48,3 @@
 
 * Run the :python:`DeviceCollector` first and pass the event-loop into the run-engine.
 * Initialize the run-engine first and run the :python:`DeviceCollector` using the bluesky event-loop.
-
-
diff --git a/docs/how-to/contribute.md b/docs/how-to/contribute.md
index f9c4ca1d75..6e41979708 100644
--- a/docs/how-to/contribute.md
+++ b/docs/how-to/contribute.md
@@ -1,2 +1,2 @@
 ```{include} ../../.github/CONTRIBUTING.md
-```
\ No newline at end of file
+```
diff --git a/docs/how-to/write-tests-for-devices.rst b/docs/how-to/write-tests-for-devices.rst
index b5da899599..0ddd0cd462 100644
--- a/docs/how-to/write-tests-for-devices.rst
+++ b/docs/how-to/write-tests-for-devices.rst
@@ -1,6 +1,6 @@
 .. note::
 
-    Ophyd async is included on a provisional basis until the v1.0 release and 
+    Ophyd async is included on a provisional basis until the v1.0 release and
     may change API on minor release numbers before then
 
 Write Tests for Devices
@@ -35,7 +35,7 @@ Mock Utility Functions
 Mock signals behave as simply as possible, holding a sensible default value when initialized and retaining any value (in memory) to which they are set. This model breaks down in the case of read-only signals, which cannot be set because there is an expectation of some external device setting them in the real world. There is a utility function, ``set_mock_value``, to mock-set values for mock signals, including read-only ones.
 
-In addition this example also utilizes helper functions like ``assert_reading`` and ``assert_value`` to ensure the validity of device readings and values. For more information see: :doc:`API.core<../generated/ophyd_async.core>`
+In addition, this example utilizes helper functions like ``assert_reading`` and ``assert_value`` to ensure the validity of device readings and values. For more information see: :doc:`API.core<../_api/ophyd_async.core>`
 
 .. literalinclude:: ../../tests/epics/demo/test_demo.py
     :pyobject: test_sensor_reading_shows_value
diff --git a/docs/reference.md b/docs/reference.md
index b9e0e3b834..77d9c6c4c5 100644
--- a/docs/reference.md
+++ b/docs/reference.md
@@ -6,7 +6,7 @@ Technical reference material including APIs and release notes.
 :maxdepth: 1
 :glob:
 
-reference/*
+API <_api/ophyd_async>
 genindex
 Release Notes <https://github.com/bluesky/ophyd-async/releases>
 ```
diff --git a/docs/reference/api.rst b/docs/reference/api.rst
deleted file mode 100644
index 74485ee1e8..0000000000
--- a/docs/reference/api.rst
+++ /dev/null
@@ -1,29 +0,0 @@
-.. note::
-
-    Ophyd async is included on a provisional basis until the v1.0 release and
-    may change API on minor release numbers before then
-
-API
-===
-
-.. automodule:: ophyd_async
-
-    ``ophyd_async``
-    -----------------------------------
-
-This is the internal API reference for ophyd_async
-
-.. data:: ophyd_async.__version__
-    :type: str
-
-    Version number as calculated by https://github.com/pypa/setuptools_scm
-
-
-.. autosummary::
-    :toctree: ../generated
-    :template: custom-module-template.rst
-    :recursive:
-
-    core
-    epics
-    fastcs
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 6d0e0324bd..23afbf19bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=64", "setuptools_scm[toml]>=6.2", "wheel"]
+requires = ["setuptools>=64", "setuptools_scm[toml]>=8"]
 build-backend = "setuptools.build_meta"
 
 [project]
@@ -53,6 +53,7 @@ dev = [
     "pre-commit",
     "pydata-sphinx-theme>=0.12",
     "pyepics>=3.4.2",
+    "pyright",
     "pyside6==6.7.0",
     "pytest",
     "pytest-asyncio",
@@ -83,9 +84,12 @@ GitHub = "https://github.com/bluesky/ophyd-async"
 email = "tom.cobb@diamond.ac.uk"
 name = "Tom Cobb"
 
-
 [tool.setuptools_scm]
-write_to = "src/ophyd_async/_version.py"
+version_file = "src/ophyd_async/_version.py"
+
+[tool.pyright]
+typeCheckingMode = "standard"
+reportMissingImports = false  # Ignore missing stubs in imported modules
 
 [tool.pytest.ini_options]
 # Run pytest with all our checkers, and don't spam us with massive tracebacks on error
@@ -105,6 +109,7 @@ markers = [
     "adsim: require the ADsim IOC to be running",
 ]
 asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
 
 [tool.coverage.run]
 data_file = "/tmp/ophyd_async.coverage"
@@ -127,13 +132,13 @@
 passenv = *
 allowlist_externals =
     pytest
     pre-commit
-    mypy
+    pyright
     sphinx-build
     sphinx-autobuild
 commands =
-    tests: pytest --cov=ophyd_async --cov-report term --cov-report xml:cov.xml {posargs}
-    type-checking: ruff check src tests {posargs}
     pre-commit: pre-commit run --all-files --show-diff-on-failure {posargs}
+    type-checking: pyright src {posargs}
+    tests: pytest --cov=ophyd_async --cov-report term --cov-report xml:cov.xml {posargs}
     docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html
 """
@@ -142,12 +147,14 @@
 src = ["src", "tests", "system_tests"]
 line-length = 88
 lint.select = [
-    "C4",   # flake8-comprehensions - https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4
-    "E",    # pycodestyle errors - https://beta.ruff.rs/docs/rules/#error-e
-    "F",    # pyflakes rules - https://beta.ruff.rs/docs/rules/#pyflakes-f
-    "W",    # pycodestyle warnings - https://beta.ruff.rs/docs/rules/#warning-w
-    "I001", # isort
-    "SLF",  # self - https://docs.astral.sh/ruff/settings/#lintflake8-self
+    "B",   # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
+    "C4",  # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
+    "E",   # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
+    "F",   # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
+    "W",   # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
+    "I",   # isort - https://docs.astral.sh/ruff/rules/#isort-i
+    "UP",  # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
+    "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self
 ]
 
 [tool.ruff.lint.per-file-ignores]
diff --git a/src/ophyd_async/__init__.py b/src/ophyd_async/__init__.py
index 26d23badb6..2c23d46c22 100644
--- a/src/ophyd_async/__init__.py
+++ b/src/ophyd_async/__init__.py
@@ -1,3 +1,12 @@
+"""Top level API.
+
+.. data:: __version__
+    :type: str
+
+    Version number as calculated by https://github.com/pypa/setuptools_scm
+"""
+
+from . import core
 from ._version import __version__
 
-__all__ = ["__version__"]
+__all__ = ["__version__", "core"]
diff --git a/src/ophyd_async/__main__.py b/src/ophyd_async/__main__.py
index c0c3f0c66d..b4546e3755 100644
--- a/src/ophyd_async/__main__.py
+++ b/src/ophyd_async/__main__.py
@@ -1,16 +1,24 @@
+"""Interface for ``python -m ophyd_async``."""
+
 from argparse import ArgumentParser
+from collections.abc import Sequence
 
 from . import __version__
 
 __all__ = ["main"]
 
 
-def main(args=None):
+def main(args: Sequence[str] | None = None) -> None:
+    """Argument parser for the CLI."""
     parser = ArgumentParser()
-    parser.add_argument("-v", "--version", action="version", version=__version__)
-    args = parser.parse_args(args)
+    parser.add_argument(
+        "-v",
+        "--version",
+        action="version",
+        version=__version__,
+    )
+    parser.parse_args(args)
 
 
-# test with: python -m ophyd_async
 if __name__ == "__main__":
     main()
diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py
index 1928c7aba4..be38555d10 100644
--- a/src/ophyd_async/core/__init__.py
+++ b/src/ophyd_async/core/__init__.py
@@ -70,9 +70,9 @@ from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
 from ._table import Table
 from ._utils import (
+    CALCULATE_TIMEOUT,
     DEFAULT_TIMEOUT,
     CalculatableTimeout,
-    CalculateTimeout,
     NotConnected,
     ReadingValueCallback,
     T,
@@ -80,6 +80,7 @@ from ._utils import (
     get_dtype,
     get_unique,
     in_micros,
+    is_pydantic_model,
     wait_for_connection,
 )
 
@@ -154,7 +155,7 @@
     "WatchableAsyncStatus",
     "DEFAULT_TIMEOUT",
     "CalculatableTimeout",
-    "CalculateTimeout",
+    "CALCULATE_TIMEOUT",
     "NotConnected",
     "ReadingValueCallback",
     "Table",
@@ -163,6 +164,7 @@
     "get_dtype",
     "get_unique",
     "in_micros",
+    "is_pydantic_model",
     "wait_for_connection",
     "completed_status",
 ]
diff --git a/src/ophyd_async/core/_detector.py b/src/ophyd_async/core/_detector.py
index 374c38adc6..bacd43a279 100644
--- a/src/ophyd_async/core/_detector.py
+++ b/src/ophyd_async/core/_detector.py
@@ -3,21 +3,14 @@
 import asyncio
 import time
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
 from enum import Enum
 from typing import (
-    AsyncGenerator,
-    AsyncIterator,
-    Callable,
-    Dict,
     Generic,
-    List,
-    Optional,
-    Sequence,
 )
 
 from bluesky.protocols import (
     Collectable,
-    DataKey,
     Flyable,
     Preparable,
     Reading,
@@ -26,10 +19,12 @@
     Triggerable,
     WritesStreamAssets,
 )
+from event_model import DataKey
 from pydantic import BaseModel, Field
 
 from ._device import Device
 from ._protocol import AsyncConfigurable, AsyncReadable
+from ._signal import SignalR
 from ._status import AsyncStatus, WatchableAsyncStatus
 from ._utils import DEFAULT_TIMEOUT, T, WatcherUpdate, merge_gathered_dicts
 
@@ -123,7 +118,7 @@ class DetectorWriter(ABC):
     (e.g. an HDF5 file)"""
 
     @abstractmethod
-    async def open(self, multiplier: int = 1) -> Dict[str, DataKey]:
+    async def open(self, multiplier: int = 1) -> dict[str, DataKey]:
         """Open writer and wait for it to be ready for data.
 
         Args:
@@ -174,7 +169,7 @@ def __init__(
         self,
         controller: DetectorControl,
         writer: DetectorWriter,
-        config_sigs: Sequence[AsyncReadable] = (),
+        config_sigs: Sequence[SignalR] = (),
         name: str = "",
     ) -> None:
         """
@@ -189,17 +184,17 @@ def __init__(
         """
         self._controller = controller
         self._writer = writer
-        self._describe: Dict[str, DataKey] = {}
+        self._describe: dict[str, DataKey] = {}
         self._config_sigs = list(config_sigs)
         # For prepare
-        self._arm_status: Optional[AsyncStatus] = None
-        self._trigger_info: Optional[TriggerInfo] = None
+        self._arm_status: AsyncStatus | None = None
+        self._trigger_info: TriggerInfo | None = None
         # For kickoff
-        self._watchers: List[Callable] = []
-        self._fly_status: Optional[WatchableAsyncStatus] = None
+        self._watchers: list[Callable] = []
+        self._fly_status: WatchableAsyncStatus | None = None
         self._fly_start: float
         self._iterations_completed: int = 0
-        self._intial_frame: int
+        self._initial_frame: int
         self._last_frame: int
         super().__init__(name)
@@ -213,7 +208,7 @@ def writer(self) -> DetectorWriter:
 
     @AsyncStatus.wrap
     async def stage(self) -> None:
-        # Disarm the detector, stop filewriting.
+        # Disarm the detector, stop file writing.
         await self._check_config_sigs()
         await asyncio.gather(self.writer.close(), self.controller.disarm())
         self._trigger_info = None
@@ -227,28 +222,28 @@ async def _check_config_sigs(self):
                 )
             try:
                 await signal.get_value()
-            except NotImplementedError:
+            except NotImplementedError as e:
                 raise Exception(
                     f"config signal {signal.name} must be connected before it is "
                     + "passed to the detector"
-                )
+                ) from e
 
     @AsyncStatus.wrap
     async def unstage(self) -> None:
         # Stop data writing.
         await asyncio.gather(self.writer.close(), self.controller.disarm())
 
-    async def read_configuration(self) -> Dict[str, Reading]:
+    async def read_configuration(self) -> dict[str, Reading]:
         return await merge_gathered_dicts(sig.read() for sig in self._config_sigs)
 
-    async def describe_configuration(self) -> Dict[str, DataKey]:
+    async def describe_configuration(self) -> dict[str, DataKey]:
         return await merge_gathered_dicts(sig.describe() for sig in self._config_sigs)
 
-    async def read(self) -> Dict[str, Reading]:
+    async def read(self) -> dict[str, Reading]:
         # All data is in StreamResources, not Events, so nothing to output here
         return {}
 
-    async def describe(self) -> Dict[str, DataKey]:
+    async def describe(self) -> dict[str, DataKey]:
         return self._describe
 
     @AsyncStatus.wrap
@@ -347,11 +342,11 @@ async def complete(self):
         if self._iterations_completed == self._trigger_info.iteration:
             await self.controller.wait_for_idle()
 
-    async def describe_collect(self) -> Dict[str, DataKey]:
+    async def describe_collect(self) -> dict[str, DataKey]:
         return self._describe
 
     async def collect_asset_docs(
-        self, index: Optional[int] = None
+        self, index: int | None = None
     ) -> AsyncIterator[StreamAsset]:
         # Collect stream datum documents for all indices written.
        # The index is optional, and provided for fly scans, however this needs to be
diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py
index 33f8ddb729..b4305c2bfa 100644
--- a/src/ophyd_async/core/_device.py
+++ b/src/ophyd_async/core/_device.py
@@ -2,17 +2,12 @@
 
 import asyncio
 import sys
+from collections.abc import Coroutine, Generator, Iterator
 from functools import cached_property
 from logging import LoggerAdapter, getLogger
 from typing import (
     Any,
-    Coroutine,
-    Dict,
-    Generator,
-    Iterator,
     Optional,
-    Set,
-    Tuple,
     TypeVar,
 )
 
@@ -32,7 +27,7 @@ class Device(HasName):
     #: The parent Device if it exists
     parent: Optional["Device"] = None
     # None if connect hasn't started, a Task if it has
-    _connect_task: Optional[asyncio.Task] = None
+    _connect_task: asyncio.Task | None = None
 
     # Used to check if the previous connect was mocked,
     # if the next mock value differs then we fail
@@ -52,7 +47,7 @@ def log(self):
             getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
         )
 
-    def children(self) -> Iterator[Tuple[str, "Device"]]:
+    def children(self) -> Iterator[tuple[str, "Device"]]:
         for attr_name, attr in self.__dict__.items():
             if attr_name != "parent" and isinstance(attr, Device):
                 yield attr_name, attr
@@ -127,7 +122,7 @@ async def connect(
 VT = TypeVar("VT", bound=Device)
 
 
-class DeviceVector(Dict[int, VT], Device):
+class DeviceVector(dict[int, VT], Device):
     """
     Defines device components with indices.
 
@@ -136,7 +131,7 @@ class DeviceVector(Dict[int, VT], Device):
     :class:`~ophyd_async.epics.demo.DynamicSensorGroup`
     """
 
-    def children(self) -> Generator[Tuple[str, Device], None, None]:
+    def children(self) -> Generator[tuple[str, Device], None, None]:
         for attr_name, attr in self.items():
             if isinstance(attr, Device):
                 yield str(attr_name), attr
@@ -182,8 +177,8 @@ def __init__(
         self._connect = connect
         self._mock = mock
         self._timeout = timeout
-        self._names_on_enter: Set[str] = set()
-        self._objects_on_exit: Dict[str, Any] = {}
+        self._names_on_enter: set[str] = set()
+        self._objects_on_exit: dict[str, Any] = {}
 
     def _caller_locals(self):
         """Walk up until we find a stack frame that doesn't have us as self"""
@@ -195,6 +190,9 @@ def _caller_locals(self):
         caller_frame = tb.tb_frame
         while caller_frame.f_locals.get("self", None) is self:
             caller_frame = caller_frame.f_back
+        assert (
+            caller_frame
+        ), "No previous frame to the one with self in it, this shouldn't happen"
         return caller_frame.f_locals
 
     def __enter__(self) -> "DeviceCollector":
@@ -207,7 +205,7 @@ async def __aenter__(self) -> "DeviceCollector":
 
     async def _on_exit(self) -> None:
         # Name and kick off connect for devices
-        connect_coroutines: Dict[str, Coroutine] = {}
+        connect_coroutines: dict[str, Coroutine] = {}
         for name, obj in self._objects_on_exit.items():
             if name not in self._names_on_enter and isinstance(obj, Device):
                 if self._set_name and not obj.name:
@@ -229,10 +227,10 @@ def __exit__(self, type_, value, traceback):
         self._objects_on_exit = self._caller_locals()
         try:
             fut = call_in_bluesky_event_loop(self._on_exit())
-        except RuntimeError:
+        except RuntimeError as e:
             raise NotConnected(
                 "Could not connect devices. Is the bluesky event loop running? See "
                 "https://blueskyproject.io/ophyd-async/main/"
                 "user/explanations/event-loop-choice.html for more info."
-            )
+            ) from e
         return fut
diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py
index d847caff69..95e936752b 100644
--- a/src/ophyd_async/core/_device_save_loader.py
+++ b/src/ophyd_async/core/_device_save_loader.py
@@ -1,5 +1,7 @@
+from collections.abc import Callable, Generator, Sequence
 from enum import Enum
-from typing import Any, Callable, Dict, Generator, List, Optional, Sequence
+from pathlib import Path
+from typing import Any
 
 import numpy as np
 import numpy.typing as npt
@@ -29,12 +31,12 @@ class OphydDumper(yaml.Dumper):
     def represent_data(self, data: Any) -> Any:
         if isinstance(data, Enum):
             return self.represent_data(data.value)
-        return super(OphydDumper, self).represent_data(data)
+        return super().represent_data(data)
 
 
 def get_signal_values(
-    signals: Dict[str, SignalRW[Any]], ignore: Optional[List[str]] = None
-) -> Generator[Msg, Sequence[Location[Any]], Dict[str, Any]]:
+    signals: dict[str, SignalRW[Any]], ignore: list[str] | None = None
+) -> Generator[Msg, Sequence[Location[Any]], dict[str, Any]]:
     """Get signal values in bulk.
 
     Used as part of saving the signals of a device to a yaml file.
@@ -66,13 +68,10 @@ def get_signal_values(
     }
     selected_values = yield Msg("locate", *selected_signals.values())
 
-    # TODO: investigate wrong type hints
-    if isinstance(selected_values, dict):
-        selected_values = [selected_values]  # type: ignore
-
     assert selected_values is not None, "No signalRW's were able to be located"
     named_values = {
-        key: value["setpoint"] for key, value in zip(selected_signals, selected_values)
+        key: value["setpoint"]
+        for key, value in zip(selected_signals, selected_values, strict=False)
     }
     # Ignored values place in with value None so we know which ones were ignored
     named_values.update({key: None for key in ignore})
@@ -80,8 +79,8 @@ def get_signal_values(
 
 
 def walk_rw_signals(
-    device: Device, path_prefix: Optional[str] = ""
-) -> Dict[str, SignalRW[Any]]:
+    device: Device, path_prefix: str | None = ""
+) -> dict[str, SignalRW[Any]]:
     """Retrieve all SignalRWs from a device.
 
     Stores retrieved signals with their dotted attribute paths in a dictionary. Used as
@@ -111,7 +110,7 @@ def walk_rw_signals(
     if not path_prefix:
         path_prefix = ""
 
-    signals: Dict[str, SignalRW[Any]] = {}
+    signals: dict[str, SignalRW[Any]] = {}
     for attr_name, attr in device.children():
         dot_path = f"{path_prefix}{attr_name}"
         if type(attr) is SignalRW:
@@ -121,7 +120,7 @@ def walk_rw_signals(
     return signals
 
 
-def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
+def save_to_yaml(phases: Sequence[dict[str, Any]], save_path: str | Path) -> None:
     """Plan which serialises a phase or set of phases of SignalRWs to a yaml file.
 
     Parameters
@@ -151,7 +150,7 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
         yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)
 
 
-def load_from_yaml(save_path: str) -> Sequence[Dict[str, Any]]:
+def load_from_yaml(save_path: str) -> Sequence[dict[str, Any]]:
     """Plan that returns a list of dicts with saved signal values from a yaml file.
 
     Parameters
@@ -164,12 +163,12 @@
     :func:`ophyd_async.core.save_to_yaml`
     :func:`ophyd_async.core.set_signal_values`
     """
-    with open(save_path, "r") as file:
+    with open(save_path) as file:
         return yaml.full_load(file)
 
 
 def set_signal_values(
-    signals: Dict[str, SignalRW[Any]], values: Sequence[Dict[str, Any]]
+    signals: dict[str, SignalRW[Any]], values: Sequence[dict[str, Any]]
 ) -> Generator[Msg, None, None]:
     """Maps signals from a yaml file into device signals.
 
@@ -229,7 +228,7 @@ def load_device(device: Device, path: str):
     yield from set_signal_values(signals_to_set, values)
 
 
-def all_at_once(values: Dict[str, Any]) -> Sequence[Dict[str, Any]]:
+def all_at_once(values: dict[str, Any]) -> Sequence[dict[str, Any]]:
     """Sort all the values into a single phase so they are set all at once"""
     return [values]
 
@@ -237,8 +236,8 @@ def all_at_once(values: Dict[str, Any]) -> Sequence[Dict[str, Any]]:
 def save_device(
     device: Device,
     path: str,
-    sorter: Callable[[Dict[str, Any]], Sequence[Dict[str, Any]]] = all_at_once,
-    ignore: Optional[List[str]] = None,
+    sorter: Callable[[dict[str, Any]], Sequence[dict[str, Any]]] = all_at_once,
+    ignore: list[str] | None = None,
 ):
     """Plan that saves the state of all PV's on a device using a sorter.
diff --git a/src/ophyd_async/core/_flyer.py b/src/ophyd_async/core/_flyer.py
index 79fff4029c..4414db36c9 100644
--- a/src/ophyd_async/core/_flyer.py
+++ b/src/ophyd_async/core/_flyer.py
@@ -1,7 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Sequence
+from collections.abc import Sequence
+from typing import Generic
 
-from bluesky.protocols import DataKey, Flyable, Preparable, Reading, Stageable
+from bluesky.protocols import Flyable, Preparable, Reading, Stageable
+from event_model import DataKey
 
 from ._device import Device
 from ._signal import SignalR
@@ -72,12 +74,12 @@ async def kickoff(self) -> None:
     async def complete(self) -> None:
         await self._trigger_logic.complete()
 
-    async def describe_configuration(self) -> Dict[str, DataKey]:
+    async def describe_configuration(self) -> dict[str, DataKey]:
         return await merge_gathered_dicts(
             [sig.describe() for sig in self._configuration_signals]
         )
 
-    async def read_configuration(self) -> Dict[str, Reading]:
+    async def read_configuration(self) -> dict[str, Reading]:
         return await merge_gathered_dicts(
             [sig.read() for sig in self._configuration_signals]
         )
diff --git a/src/ophyd_async/core/_hdf_dataset.py b/src/ophyd_async/core/_hdf_dataset.py
index 84f8814921..79cb9c432a 100644
--- a/src/ophyd_async/core/_hdf_dataset.py
+++ b/src/ophyd_async/core/_hdf_dataset.py
@@ -1,12 +1,13 @@
+from collections.abc import Iterator, Sequence
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Iterator, List, Sequence
 from urllib.parse import urlunparse
 
 from event_model import (
     ComposeStreamResource,
     ComposeStreamResourceBundle,
     StreamDatum,
+    StreamRange,
     StreamResource,
 )
 
@@ -19,6 +20,8 @@ class HDFDataset:
     dtype_numpy: str = ""
     multiplier: int = 1
     swmr: bool = False
+    # Represents explicit chunk size written to disk.
+    chunk_shape: tuple[int, ...] = ()
 
 
 SLICE_NAME = "AD_HDF5_SWMR_SLICE"
@@ -33,7 +36,7 @@ class HDFFile:
     def __init__(
         self,
         full_file_name: Path,
-        datasets: List[HDFDataset],
+        datasets: list[HDFDataset],
         hostname: str = "localhost",
     ) -> None:
         self._last_emitted = 0
@@ -56,7 +59,7 @@ def __init__(
             )
         )
 
-        self._bundles: List[ComposeStreamResourceBundle] = [
+        self._bundles: list[ComposeStreamResourceBundle] = [
             bundler_composer(
                 mimetype="application/x-hdf5",
                 uri=uri,
@@ -65,6 +68,7 @@ def __init__(
                     "dataset": ds.dataset,
                     "swmr": ds.swmr,
                     "multiplier": ds.multiplier,
+                    "chunk_shape": ds.chunk_shape,
                 },
                 uid=None,
                 validate=True,
@@ -79,15 +83,10 @@ def stream_resources(self) -> Iterator[StreamResource]:
     def stream_data(self, indices_written: int) -> Iterator[StreamDatum]:
         # Indices are relative to resource
         if indices_written > self._last_emitted:
-            indices = {
+            indices: StreamRange = {
                 "start": self._last_emitted,
                 "stop": indices_written,
             }
             self._last_emitted = indices_written
             for bundle in self._bundles:
                 yield bundle.compose_stream_datum(indices)
-        return None
-
-    def close(self) -> None:
-        for bundle in self._bundles:
-            bundle.close()
diff --git a/src/ophyd_async/core/_log.py b/src/ophyd_async/core/_log.py
index 2f40f2172d..cf015de95e 100644
--- a/src/ophyd_async/core/_log.py
+++ b/src/ophyd_async/core/_log.py
@@ -29,7 +29,7 @@ class ColoredFormatterWithDeviceName(colorlog.ColoredFormatter):
     def format(self, record):
         message = super().format(record)
         if hasattr(record, "ophyd_async_device_name"):
-            message = f"[{record.ophyd_async_device_name}]{message}"
+            message = f"[{record.ophyd_async_device_name}]{message}"  # type: ignore
         return message
 
 
@@ -39,6 +39,8 @@ def _validate_level(level) -> int:
         levelno = level
     elif isinstance(level, str):
         levelno = logging.getLevelName(level)
+    else:
+        raise TypeError(f"Level {level!r} is not an int or str")
 
     if isinstance(levelno, int):
         return levelno
diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py
index 97aae72a39..029881cd96 100644
--- a/src/ophyd_async/core/_mock_signal_backend.py
+++ b/src/ophyd_async/core/_mock_signal_backend.py
@@ -1,6 +1,6 @@
 import asyncio
+from collections.abc import Callable
 from functools import cached_property
-from typing import Callable, Optional, Type
 from unittest.mock import AsyncMock
 
 from bluesky.protocols import Descriptor, Reading
@@ -11,10 +11,12 @@
 
 
 class MockSignalBackend(SignalBackend[T]):
+    """Signal backend for testing, created by ``Device.connect(mock=True)``."""
+
     def __init__(
         self,
-        datatype: Optional[Type[T]] = None,
-        initial_backend: Optional[SignalBackend[T]] = None,
+        datatype: type[T] | None = None,
+        initial_backend: SignalBackend[T] | None = None,
     ) -> None:
         if isinstance(initial_backend, MockSignalBackend):
             raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackends")
@@ -55,7 +57,7 @@ def put_proceeds(self) -> asyncio.Event:
         put_proceeds.set()
         return put_proceeds
 
-    async def put(self, value: Optional[T], wait=True, timeout=None):
+    async def put(self, value: T | None, wait=True, timeout=None):
         await self.put_mock(value, wait=wait, timeout=timeout)
         await self.soft_backend.put(value, wait=wait, timeout=timeout)
 
@@ -78,5 +80,5 @@ async def get_setpoint(self) -> T:
     async def get_datakey(self, source: str) -> Descriptor:
         return await self.soft_backend.get_datakey(source)
 
-    def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None:
+    def set_callback(self, callback: ReadingValueCallback[T] | None) -> None:
         self.soft_backend.set_callback(callback)
diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py
index 76d1a04c12..33c0f677ba 100644
--- a/src/ophyd_async/core/_mock_signal_utils.py
+++ b/src/ophyd_async/core/_mock_signal_utils.py
@@ -1,5 +1,6 @@
+from collections.abc import Awaitable, Callable, Iterable
 from contextlib import asynccontextmanager, contextmanager
-from typing import Any, Awaitable, Callable, Iterable
+from typing import Any
 from unittest.mock import AsyncMock
 
 from ._mock_signal_backend import MockSignalBackend
diff --git a/src/ophyd_async/core/_protocol.py b/src/ophyd_async/core/_protocol.py
index 79e11a78ae..3978f39cc8 100644
--- a/src/ophyd_async/core/_protocol.py
+++ b/src/ophyd_async/core/_protocol.py
@@ -4,14 +4,14 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    Dict,
     Generic,
     Protocol,
     TypeVar,
     runtime_checkable,
 )
 
-from bluesky.protocols import DataKey, HasName, Reading
+from bluesky.protocols import HasName, Reading
+from event_model import DataKey
 
 if TYPE_CHECKING:
     from ._status import AsyncStatus
@@ -20,7 +20,7 @@
 @runtime_checkable
 class AsyncReadable(HasName, Protocol):
     @abstractmethod
-    async def read(self) -> Dict[str, Reading]:
+    async def read(self) -> dict[str, Reading]:
         """Return an OrderedDict mapping string field name(s) to dictionaries
         of values and timestamps and optional per-point metadata.
 
@@ -36,7 +36,7 @@ async def read(self) -> Dict[str, Reading]:
         ...
 
     @abstractmethod
-    async def describe(self) -> Dict[str, DataKey]:
+    async def describe(self) -> dict[str, DataKey]:
         """Return an OrderedDict with exactly the same keys as the ``read``
         method, here mapped to per-scan metadata about each field.
 
@@ -57,16 +57,16 @@ async def describe(self) -> Dict[str, DataKey]:
 
 
 @runtime_checkable
-class AsyncConfigurable(Protocol):
+class AsyncConfigurable(HasName, Protocol):
     @abstractmethod
-    async def read_configuration(self) -> Dict[str, Reading]:
+    async def read_configuration(self) -> dict[str, Reading]:
         """Same API as ``read`` but for slow-changing fields related to configuration.
         e.g., exposure time. These will typically be read only once per run.
         """
         ...
 
     @abstractmethod
-    async def describe_configuration(self) -> Dict[str, DataKey]:
+    async def describe_configuration(self) -> dict[str, DataKey]:
         """Same API as ``describe``, but corresponding to the keys in
         ``read_configuration``.
""" diff --git a/src/ophyd_async/core/_providers.py b/src/ophyd_async/core/_providers.py index ffb4565c90..bb42cb7ff8 100644 --- a/src/ophyd_async/core/_providers.py +++ b/src/ophyd_async/core/_providers.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from datetime import date from pathlib import Path -from typing import List, Optional, Protocol +from typing import Protocol @dataclass @@ -26,13 +26,13 @@ class PathInfo: class FilenameProvider(Protocol): @abstractmethod - def __call__(self, device_name: Optional[str] = None) -> str: + def __call__(self, device_name: str | None = None) -> str: """Get a filename to use for output data, w/o extension""" class PathProvider(Protocol): @abstractmethod - def __call__(self, device_name: Optional[str] = None) -> PathInfo: + def __call__(self, device_name: str | None = None) -> PathInfo: """Get the current directory to write files into""" @@ -40,7 +40,7 @@ class StaticFilenameProvider(FilenameProvider): def __init__(self, filename: str): self._static_filename = filename - def __call__(self, device_name: Optional[str] = None) -> str: + def __call__(self, device_name: str | None = None) -> str: return self._static_filename @@ -48,12 +48,12 @@ class UUIDFilenameProvider(FilenameProvider): def __init__( self, uuid_call_func: Callable = uuid.uuid4, - uuid_call_args: Optional[List] = None, + uuid_call_args: list | None = None, ): self._uuid_call_func = uuid_call_func self._uuid_call_args = uuid_call_args or [] - def __call__(self, device_name: Optional[str] = None) -> str: + def __call__(self, device_name: str | None = None) -> str: if ( self._uuid_call_func in [uuid.uuid3, uuid.uuid5] and len(self._uuid_call_args) < 2 @@ -82,7 +82,7 @@ def __init__( self._increment = increment self._inc_delimeter = inc_delimeter - def __call__(self, device_name: Optional[str] = None) -> str: + def __call__(self, device_name: str | None = None) -> str: if len(str(self._current_value)) > self._max_digits: raise ValueError( f"Auto incrementing filename counter \ @@ -108,7 +108,7 @@ def __init__( self._directory_path = directory_path self._create_dir_depth = create_dir_depth - def __call__(self, device_name: Optional[str] = None) -> PathInfo: + def __call__(self, device_name: str | None = None) -> PathInfo: filename = self._filename_provider(device_name) return PathInfo( @@ -129,7 +129,7 @@ def __init__( num_calls_per_inc: int = 1, increment: int = 1, inc_delimeter: str = "_", - base_name: str = None, + base_name: str | None = None, ) -> None: self._filename_provider = filename_provider self._base_directory_path = base_directory_path @@ -143,7 +143,7 @@ def __init__( self._increment = increment self._inc_delimeter = inc_delimeter - def __call__(self, device_name: Optional[str] = None) -> PathInfo: + def __call__(self, device_name: str | None = None) -> PathInfo: filename = self._filename_provider(device_name) padded_counter = f"{self._current_value:0{self._max_digits}}" @@ -181,7 +181,7 @@ def __init__( self._create_dir_depth = create_dir_depth self._device_name_as_base_dir = device_name_as_base_dir - def __call__(self, device_name: Optional[str] = None) -> PathInfo: + def __call__(self, device_name: str | None = None) -> PathInfo: sep = os.path.sep current_date = date.today().strftime(f"%Y{sep}%m{sep}%d") if device_name is None: diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index c63a0f5dcf..111a26d3b1 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -1,8 +1,9 @@ import warnings +from 
 from contextlib import contextmanager
-from typing import Callable, Dict, Generator, Optional, Sequence, Tuple, Type, Union
 
-from bluesky.protocols import DataKey, HasHints, Hints, Reading
+from bluesky.protocols import HasHints, Hints, Reading
+from event_model import DataKey
 
 from ._device import Device, DeviceVector
 from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable
@@ -10,10 +11,12 @@
 from ._status import AsyncStatus
 from ._utils import merge_gathered_dicts
 
-ReadableChild = Union[AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints]
-ReadableChildWrapper = Union[
-    Callable[[ReadableChild], ReadableChild], Type["ConfigSignal"], Type["HintedSignal"]
-]
+ReadableChild = AsyncReadable | AsyncConfigurable | AsyncStageable | HasHints
+ReadableChildWrapper = (
+    Callable[[ReadableChild], ReadableChild]
+    | type["ConfigSignal"]
+    | type["HintedSignal"]
+)
 
 
 class StandardReadable(
@@ -28,10 +31,10 @@ class StandardReadable(
 
     # These must be immutable types to avoid accidental sharing between
     # different instances of the class
-    _readables: Tuple[AsyncReadable, ...] = ()
-    _configurables: Tuple[AsyncConfigurable, ...] = ()
-    _stageables: Tuple[AsyncStageable, ...] = ()
-    _has_hints: Tuple[HasHints, ...] = ()
+    _readables: tuple[AsyncReadable, ...] = ()
+    _configurables: tuple[AsyncConfigurable, ...] = ()
+    _stageables: tuple[AsyncStageable, ...] = ()
+    _has_hints: tuple[HasHints, ...] = ()
 
     def set_readable_signals(
         self,
@@ -53,7 +56,8 @@ def set_readable_signals(
             DeprecationWarning(
                 "Migrate to `add_children_as_readables` context manager or "
                 "`add_readables` method"
-            )
+            ),
+            stacklevel=2,
         )
         self.add_readables(read, wrapper=HintedSignal)
         self.add_readables(config, wrapper=ConfigSignal)
@@ -69,20 +73,20 @@ async def stage(self) -> None:
         for sig in self._stageables:
             await sig.unstage().task
 
-    async def describe_configuration(self) -> Dict[str, DataKey]:
+    async def describe_configuration(self) -> dict[str, DataKey]:
         return await merge_gathered_dicts(
             [sig.describe_configuration() for sig in self._configurables]
         )
 
-    async def read_configuration(self) -> Dict[str, Reading]:
+    async def read_configuration(self) -> dict[str, Reading]:
         return await merge_gathered_dicts(
             [sig.read_configuration() for sig in self._configurables]
         )
 
-    async def describe(self) -> Dict[str, DataKey]:
+    async def describe(self) -> dict[str, DataKey]:
         return await merge_gathered_dicts([sig.describe() for sig in self._readables])
 
-    async def read(self) -> Dict[str, Reading]:
+    async def read(self) -> dict[str, Reading]:
         return await merge_gathered_dicts([sig.read() for sig in self._readables])
 
     @property
@@ -123,7 +127,7 @@ def hints(self) -> Hints:
     @contextmanager
     def add_children_as_readables(
         self,
-        wrapper: Optional[ReadableChildWrapper] = None,
+        wrapper: ReadableChildWrapper | None = None,
     ) -> Generator[None, None, None]:
         """Context manager to wrap adding Devices
 
@@ -167,8 +171,8 @@ def add_children_as_readables(
 
     def add_readables(
         self,
-        devices: Sequence[Device],
-        wrapper: Optional[ReadableChildWrapper] = None,
+        devices: Sequence[ReadableChild],
+        wrapper: ReadableChildWrapper | None = None,
     ) -> None:
         """Add the given devices to the lists of known Devices
 
@@ -216,12 +220,16 @@ def __init__(self, signal: ReadableChild) -> None:
         assert isinstance(signal, SignalR), f"Expected signal, got {signal}"
         self.signal = signal
 
-    async def read_configuration(self) -> Dict[str, Reading]:
+    async def read_configuration(self) -> dict[str, Reading]:
         return await self.signal.read()
 
-    async def describe_configuration(self) -> Dict[str, DataKey]:
+    async def describe_configuration(self) -> dict[str, DataKey]:
         return await self.signal.describe()
 
+    @property
+    def name(self) -> str:
+        return self.signal.name
+
 
 class HintedSignal(HasHints, AsyncReadable):
     def __init__(self, signal: ReadableChild, allow_cache: bool = True) -> None:
@@ -232,10 +240,10 @@ def __init__(self, signal: ReadableChild, allow_cache: bool = True) -> None:
         self.stage = signal.stage
         self.unstage = signal.unstage
 
-    async def read(self) -> Dict[str, Reading]:
+    async def read(self) -> dict[str, Reading]:
         return await self.signal.read(cached=self.cached)
 
-    async def describe(self) -> Dict[str, DataKey]:
+    async def describe(self) -> dict[str, DataKey]:
         return await self.signal.describe()
 
     @property
diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py
index 340298160b..d4e4d7bbb9 100644
--- a/src/ophyd_async/core/_signal.py
+++ b/src/ophyd_async/core/_signal.py
@@ -2,22 +2,10 @@
 
 import asyncio
 import functools
-from typing import (
-    Any,
-    AsyncGenerator,
-    Callable,
-    Dict,
-    Generic,
-    Mapping,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from collections.abc import AsyncGenerator, Callable, Mapping
+from typing import Any, Generic, TypeVar, cast
 
 from bluesky.protocols import (
-    DataKey,
     Locatable,
     Location,
     Movable,
@@ -25,6 +13,7 @@
     Status,
     Subscribable,
 )
+from event_model import DataKey
 
 from ._device import Device
 from ._mock_signal_backend import MockSignalBackend
@@ -32,7 +21,7 @@
 from ._signal_backend import SignalBackend
 from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
 from ._status import AsyncStatus
-from ._utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T
+from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T
 
 S = TypeVar("S")
 
@@ -45,13 +34,26 @@ async def wrapper(self: Signal, *args, **kwargs):
     return wrapper
 
 
+def _fail(*args, **kwargs):
+    raise RuntimeError("Signal has not been supplied a backend yet")
+
+
+class DisconnectedBackend(SignalBackend):
+    source = connect = put = get_datakey = get_reading = get_value = get_setpoint = (
+        set_callback
+    ) = _fail
+
+
+DISCONNECTED_BACKEND = DisconnectedBackend()
+
+
 class Signal(Device, Generic[T]):
     """A Device with the concept of a value, with R, RW, W and X flavours"""
 
     def __init__(
         self,
-        backend: Optional[SignalBackend[T]] = None,
-        timeout: Optional[float] = DEFAULT_TIMEOUT,
+        backend: SignalBackend[T] = DISCONNECTED_BACKEND,
+        timeout: float | None = DEFAULT_TIMEOUT,
         name: str = "",
     ) -> None:
         self._timeout = timeout
@@ -63,10 +65,13 @@ async def connect(
         mock=False,
         timeout=DEFAULT_TIMEOUT,
         force_reconnect: bool = False,
-        backend: Optional[SignalBackend[T]] = None,
+        backend: SignalBackend[T] | None = None,
     ):
         if backend:
-            if self._backend and backend is not self._backend:
+            if (
+                self._backend is not DISCONNECTED_BACKEND
+                and backend is not self._backend
+            ):
                 raise ValueError("Backend at connection different from previous one.")
             self._backend = backend
 
@@ -114,10 +119,10 @@ class _SignalCache(Generic[T]):
     def __init__(self, backend: SignalBackend[T], signal: Signal):
         self._signal = signal
         self._staged = False
-        self._listeners: Dict[Callback, bool] = {}
+        self._listeners: dict[Callback, bool] = {}
         self._valid = asyncio.Event()
-        self._reading: Optional[Reading] = None
-        self._value: Optional[T] = None
+        self._reading: Reading | None = None
+        self._value: T | None = None
 
         self.backend = backend
         signal.log.debug(f"Making subscription on source {signal.source}")
 
@@ -171,11 +176,9 @@ def set_staged(self, staged: bool):
 class SignalR(Signal[T], AsyncReadable, AsyncStageable, Subscribable):
     """Signal that can be read from and monitored"""
 
-    _cache: Optional[_SignalCache] = None
+    _cache: _SignalCache | None = None
 
-    def _backend_or_cache(
-        self, cached: Optional[bool]
-    ) -> Union[_SignalCache, SignalBackend]:
+    def _backend_or_cache(self, cached: bool | None) -> _SignalCache | SignalBackend:
         # If cached is None then calculate it based on whether we already have a cache
         if cached is None:
             cached = self._cache is not None
@@ -196,17 +199,17 @@ def _del_cache(self, needed: bool):
             self._cache = None
 
     @_add_timeout
-    async def read(self, cached: Optional[bool] = None) -> Dict[str, Reading]:
+    async def read(self, cached: bool | None = None) -> dict[str, Reading]:
         """Return a single item dict with the reading in it"""
         return {self.name: await self._backend_or_cache(cached).get_reading()}
 
     @_add_timeout
-    async def describe(self) -> Dict[str, DataKey]:
+    async def describe(self) -> dict[str, DataKey]:
         """Return a single item dict with the descriptor in it"""
         return {self.name: await self._backend.get_datakey(self.source)}
 
     @_add_timeout
-    async def get_value(self, cached: Optional[bool] = None) -> T:
+    async def get_value(self, cached: bool | None = None) -> T:
         """The current value"""
         value = await self._backend_or_cache(cached).get_value()
         self.log.debug(f"get_value() on source {self.source} returned {value}")
@@ -216,7 +219,7 @@ def subscribe_value(self, function: Callback[T]):
         """Subscribe to updates in value of a device"""
         self._get_cache().subscribe(function, want_value=True)
 
-    def subscribe(self, function: Callback[Dict[str, Reading]]) -> None:
+    def subscribe(self, function: Callback[dict[str, Reading]]) -> None:
         """Subscribe to updates in the reading"""
         self._get_cache().subscribe(function, want_value=False)
 
@@ -239,10 +242,10 @@ class SignalW(Signal[T], Movable):
     """Signal that can be set"""
 
     def set(
-        self, value: T, wait=True, timeout: CalculatableTimeout = CalculateTimeout
+        self, value: T, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT
     ) -> AsyncStatus:
         """Set the value and return a status saying when it's done"""
-        if timeout is CalculateTimeout:
+        if timeout is CALCULATE_TIMEOUT:
             timeout = self._timeout
 
         async def do_set():
@@ -270,18 +273,18 @@ class SignalX(Signal):
     """Signal that puts the default value"""
 
     def trigger(
-        self, wait=True, timeout: CalculatableTimeout = CalculateTimeout
+        self, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT
     ) -> AsyncStatus:
         """Trigger the action and return a status saying when it's done"""
-        if timeout is CalculateTimeout:
+        if timeout is CALCULATE_TIMEOUT:
             timeout = self._timeout
         coro = self._backend.put(None, wait=wait, timeout=timeout)
         return AsyncStatus(coro)
 
 
 def soft_signal_rw(
-    datatype: Optional[Type[T]] = None,
-    initial_value: Optional[T] = None,
+    datatype: type[T] | None = None,
+    initial_value: T | None = None,
     name: str = "",
     units: str | None = None,
     precision: int | None = None,
@@ -298,12 +301,12 @@ def soft_signal_rw(
 
 
 def soft_signal_r_and_setter(
-    datatype: Optional[Type[T]] = None,
-    initial_value: Optional[T] = None,
+    datatype: type[T] | None = None,
+    initial_value: T | None = None,
     name: str = "",
     units: str | None = None,
     precision: int | None = None,
-) -> Tuple[SignalR[T], Callable[[T], None]]:
+) -> tuple[SignalR[T], Callable[[T], None]]:
     """Returns a tuple of a read-only Signal and
a callable through which the signal can be internally modified within the device. May pass metadata, which are propagated into describe. @@ -316,9 +319,7 @@ def soft_signal_r_and_setter( return (signal, backend.set_value) -def _generate_assert_error_msg( - name: str, expected_result: str, actual_result: str -) -> str: +def _generate_assert_error_msg(name: str, expected_result, actual_result) -> str: WARNING = "\033[93m" FAIL = "\033[91m" ENDC = "\033[0m" @@ -484,14 +485,14 @@ async def get_value(): else: break else: - yield item + yield cast(T, item) finally: signal.clear_sub(q.put_nowait) class _ValueChecker(Generic[T]): def __init__(self, matcher: Callable[[T], bool], matcher_name: str): - self._last_value: Optional[T] = None + self._last_value: T | None = None self._matcher = matcher self._matcher_name = matcher_name @@ -501,7 +502,7 @@ async def _wait_for_value(self, signal: SignalR[T]): if self._matcher(value): return - async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]): + async def wait_for_value(self, signal: SignalR[T], timeout: float | None): try: await asyncio.wait_for(self._wait_for_value(signal), timeout) except asyncio.TimeoutError as e: @@ -513,8 +514,8 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]): async def wait_for_value( signal: SignalR[T], - match: Union[T, Callable[[T], bool]], - timeout: Optional[float], + match: T | Callable[[T], bool], + timeout: float | None, ): """Wait for a signal to have a matching value. @@ -540,7 +541,7 @@ async def wait_for_value( wait_for_value(device.num_captured, lambda v: v > 45, timeout=1) """ if callable(match): - checker = _ValueChecker(match, match.__name__) + checker = _ValueChecker(match, match.__name__) # type: ignore else: checker = _ValueChecker(lambda v: v == match, repr(match)) await checker.wait_for_value(signal, timeout) @@ -552,7 +553,7 @@ async def set_and_wait_for_other_value( read_signal: SignalR[S], read_value: S, timeout: float = DEFAULT_TIMEOUT, - set_timeout: Optional[float] = None, + set_timeout: float | None = None, ) -> AsyncStatus: """Set a signal and monitor another signal until it has the specified value. @@ -610,7 +611,7 @@ async def set_and_wait_for_value( signal: SignalRW[T], value: T, timeout: float = DEFAULT_TIMEOUT, - status_timeout: Optional[float] = None, + status_timeout: float | None = None, ) -> AsyncStatus: """Set a signal and monitor it until it has that value. 
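# A minimal usage sketch of the signal helpers changed above, assuming
# soft_signal_r_and_setter and wait_for_value are re-exported from
# ophyd_async.core (other hunks in this diff import wait_for_value that way).
# No connect() step is shown because, per the soft backend below,
# "Connection isn't required for soft signals."
import asyncio

from ophyd_async.core import soft_signal_r_and_setter, wait_for_value


async def demo():
    # A read-only signal plus the private setter a device keeps for itself
    num_captured, set_num_captured = soft_signal_r_and_setter(
        int, initial_value=0, name="num_captured"
    )
    set_num_captured(46)
    # match may be a plain value or a predicate, as in the docstring above
    await wait_for_value(num_captured, lambda v: v > 45, timeout=1)
    assert await num_captured.get_value() == 46


asyncio.run(demo())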
diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 594863ef2a..035936f32c 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,15 +1,15 @@ from abc import abstractmethod from typing import ( TYPE_CHECKING, + Any, ClassVar, Generic, Literal, - Optional, - Tuple, - Type, ) -from ._protocol import DataKey, Reading +from bluesky.protocols import Reading +from event_model import DataKey + from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T @@ -17,11 +17,11 @@ class SignalBackend(Generic[T]): """A read/write/monitor backend for a Signal""" #: Datatype of the signal value - datatype: Optional[Type[T]] = None + datatype: type[T] | None = None @classmethod @abstractmethod - def datatype_allowed(cls, dtype: type): + def datatype_allowed(cls, dtype: Any) -> bool: """Check if a given datatype is acceptable for this signal backend.""" #: Like ca://PV_PREFIX:SIGNAL @@ -35,7 +35,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): """Connect to underlying hardware""" @abstractmethod - async def put(self, value: Optional[T], wait=True, timeout=None): + async def put(self, value: T | None, wait=True, timeout=None): """Put a value to the PV, if wait then wait for completion for up to timeout""" @abstractmethod @@ -55,14 +55,14 @@ async def get_setpoint(self) -> T: """The point that a signal was requested to move to.""" @abstractmethod - def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None: + def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: """Observe changes to the current value, timestamp and severity""" class _RuntimeSubsetEnumMeta(type): def __str__(cls): if hasattr(cls, "choices"): - return f"SubsetEnum{list(cls.choices)}" + return f"SubsetEnum{list(cls.choices)}" # type: ignore return "SubsetEnum" def __getitem__(cls, _choices): @@ -85,7 +85,7 @@ class _RuntimeSubsetEnum(cls): class RuntimeSubsetEnum(metaclass=_RuntimeSubsetEnumMeta): - choices: ClassVar[Tuple[str, ...]] + choices: ClassVar[tuple[str, ...]] def __init__(self): raise RuntimeError("SubsetEnum cannot be instantiated") diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 1e895e60cc..eb4aa47d71 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -2,13 +2,14 @@ import inspect import time -from abc import ABCMeta from collections import abc from enum import Enum -from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin +from typing import Generic, cast, get_origin import numpy as np -from bluesky.protocols import DataKey, Dtype, Reading +from bluesky.protocols import Reading +from event_model import DataKey +from event_model.documents.event_descriptor import Dtype from pydantic import BaseModel from typing_extensions import TypedDict @@ -16,9 +17,15 @@ RuntimeSubsetEnum, SignalBackend, ) -from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype +from ._utils import ( + DEFAULT_TIMEOUT, + ReadingValueCallback, + T, + get_dtype, + is_pydantic_model, +) -primitive_dtypes: Dict[type, Dtype] = { +primitive_dtypes: dict[type, Dtype] = { str: "string", int: "integer", float: "number", @@ -27,8 +34,8 @@ class SignalMetadata(TypedDict): - units: str | None = None - precision: int | None = None + units: str | None + precision: int | None class SoftConverter(Generic[T]): @@ -46,7 +53,7 @@ def reading(self, value: T, timestamp:
float, severity: int) -> Reading: ) def get_datakey(self, source: str, value, **metadata) -> DataKey: - dk = {"source": source, "shape": [], **metadata} + dk: DataKey = {"source": source, "shape": [], **metadata} # type: ignore dtype = type(value) if np.issubdtype(dtype, np.integer): dtype = int @@ -56,13 +63,14 @@ def get_datakey(self, source: str, value, **metadata) -> DataKey: dtype in primitive_dtypes ), f"invalid converter for value of type {type(value)}" dk["dtype"] = primitive_dtypes[dtype] + # type ignore until https://github.com/bluesky/event-model/issues/308 try: - dk["dtype_numpy"] = np.dtype(dtype).descr[0][1] + dk["dtype_numpy"] = np.dtype(dtype).descr[0][1] # type: ignore except TypeError: - dk["dtype_numpy"] = "" + dk["dtype_numpy"] = "" # type: ignore return dk - def make_initial_value(self, datatype: Optional[Type[T]]) -> T: + def make_initial_value(self, datatype: type[T] | None) -> T: if datatype is None: return cast(T, None) @@ -81,12 +89,12 @@ def get_datakey(self, source: str, value, **metadata) -> DataKey: return { "source": source, "dtype": "array", - "dtype_numpy": dtype_numpy, + "dtype_numpy": dtype_numpy, # type: ignore "shape": [len(value)], **metadata, } - def make_initial_value(self, datatype: Optional[Type[T]]) -> T: + def make_initial_value(self, datatype: type[T] | None) -> T: if datatype is None: return cast(T, None) @@ -97,28 +105,29 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: class SoftEnumConverter(SoftConverter): - choices: Tuple[str, ...] + choices: tuple[str, ...] - def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]): - if issubclass(datatype, Enum): + def __init__(self, datatype: RuntimeSubsetEnum | type[Enum]): + if issubclass(datatype, Enum): # type: ignore self.choices = tuple(v.value for v in datatype) else: self.choices = datatype.choices - def write_value(self, value: Union[Enum, str]) -> str: - return value + def write_value(self, value: Enum | str) -> str: + return value # type: ignore def get_datakey(self, source: str, value, **metadata) -> DataKey: return { "source": source, "dtype": "string", - "dtype_numpy": "|S40", + # type ignore until https://github.com/bluesky/event-model/issues/308 + "dtype_numpy": "|S40", # type: ignore "shape": [], "choices": self.choices, **metadata, } - def make_initial_value(self, datatype: Optional[Type[T]]) -> T: + def make_initial_value(self, datatype: type[T] | None) -> T: if datatype is None: return cast(T, None) @@ -128,7 +137,7 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: class SoftPydanticModelConverter(SoftConverter): - def __init__(self, datatype: Type[BaseModel]): + def __init__(self, datatype: type[BaseModel]): self.datatype = datatype def write_value(self, value): @@ -144,19 +153,12 @@ def make_converter(datatype): issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) - is_pydantic_model = ( - inspect.isclass(datatype) - # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ - and isinstance(datatype, ABCMeta) - and issubclass(datatype, BaseModel) - ) - if is_array or is_sequence: return SoftArrayConverter() if is_enum: - return SoftEnumConverter(datatype) - if is_pydantic_model: - return SoftPydanticModelConverter(datatype) + return SoftEnumConverter(datatype) # type: ignore + if is_pydantic_model(datatype): + return SoftPydanticModelConverter(datatype) # type: ignore return SoftConverter() @@ -165,19 +167,19 @@ class SoftSignalBackend(SignalBackend[T]): """An backend to a soft Signal, for test signals see 
``MockSignalBackend``.""" _value: T - _initial_value: Optional[T] + _initial_value: T | None _timestamp: float _severity: int @classmethod - def datatype_allowed(cls, datatype: Type) -> bool: + def datatype_allowed(cls, dtype: type) -> bool: return True # Any value allowed in a soft signal def __init__( self, - datatype: Optional[Type[T]], - initial_value: Optional[T] = None, - metadata: SignalMetadata = None, + datatype: type[T] | None, + initial_value: T | None = None, + metadata: SignalMetadata = None, # type: ignore ) -> None: self.datatype = datatype self._initial_value = initial_value @@ -186,11 +188,11 @@ def __init__( if self._initial_value is None: self._initial_value = self.converter.make_initial_value(self.datatype) else: - self._initial_value = self.converter.write_value(self._initial_value) + self._initial_value = self.converter.write_value(self._initial_value) # type: ignore - self.callback: Optional[ReadingValueCallback[T]] = None + self.callback: ReadingValueCallback[T] | None = None self._severity = 0 - self.set_value(self._initial_value) + self.set_value(self._initial_value) # type: ignore def source(self, name: str) -> str: return f"soft://{name}" @@ -199,14 +201,14 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: """Connection isn't required for soft signals.""" pass - async def put(self, value: Optional[T], wait=True, timeout=None): + async def put(self, value: T | None, wait=True, timeout=None): write_value = ( self.converter.write_value(value) if value is not None else self._initial_value ) - self.set_value(write_value) + self.set_value(write_value) # type: ignore def set_value(self, value: T): """Method to bypass asynchronous logic.""" @@ -232,7 +234,7 @@ async def get_setpoint(self) -> T: """For a soft signal, the setpoint and readback values are the same.""" return await self.get_value() - def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None: + def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: if callback: assert not self.callback, "Cannot set a callback when one is already set" reading: Reading = self.converter.reading( diff --git a/src/ophyd_async/core/_status.py b/src/ophyd_async/core/_status.py index ca35362ace..93b9888404 100644 --- a/src/ophyd_async/core/_status.py +++ b/src/ophyd_async/core/_status.py @@ -3,14 +3,10 @@ import asyncio import functools import time +from collections.abc import AsyncIterator, Callable, Coroutine from dataclasses import asdict, replace from typing import ( - AsyncIterator, - Awaitable, - Callable, Generic, - Optional, - Type, TypeVar, cast, ) @@ -27,7 +23,7 @@ class AsyncStatusBase(Status): """Convert asyncio awaitable to bluesky Status interface""" - def __init__(self, awaitable: Awaitable): + def __init__(self, awaitable: Coroutine | asyncio.Task): if isinstance(awaitable, asyncio.Task): self.task = awaitable else: @@ -86,8 +82,10 @@ def __repr__(self) -> str: class AsyncStatus(AsyncStatusBase): + """Convert asyncio awaitable to bluesky Status interface""" + @classmethod - def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]: + def wrap(cls: type[AS], f: Callable[P, Coroutine]) -> Callable[P, AS]: """Wrap an async function in an AsyncStatus.""" @functools.wraps(f) @@ -131,7 +129,7 @@ def watch(self, watcher: Watcher): @classmethod def wrap( - cls: Type[WAS], + cls: type[WAS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]], ) -> Callable[P, WAS]: """Wrap an AsyncIterator in a WatchableAsyncStatus.""" @@ -144,7 +142,7 @@ def wrap_f(*args: 
P.args, **kwargs: P.kwargs) -> WAS: @AsyncStatus.wrap -async def completed_status(exception: Optional[Exception] = None): +async def completed_status(exception: Exception | None = None): if exception: raise exception return None diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index 6f10dfe537..f36b60dceb 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -1,9 +1,11 @@ from enum import Enum -from typing import get_args, get_origin +from typing import TypeVar, get_args, get_origin import numpy as np from pydantic import BaseModel, ConfigDict, model_validator +TableSubclass = TypeVar("TableSubclass", bound="Table") + def _concat(value1, value2): if isinstance(value1, np.ndarray): @@ -17,10 +19,10 @@ class Table(BaseModel): model_config = ConfigDict(validate_assignment=True, strict=False) - @classmethod - def row(cls, sub_cls, **kwargs) -> "Table": + @staticmethod + def row(cls: type[TableSubclass], **kwargs) -> TableSubclass: # type: ignore arrayified_kwargs = {} - for field_name, field_value in sub_cls.model_fields.items(): + for field_name, field_value in cls.model_fields.items(): value = kwargs.pop(field_name) if field_value.default_factory is None: raise ValueError( @@ -40,21 +42,20 @@ def row(cls, sub_cls, **kwargs) -> "Table": ) if kwargs: raise TypeError( - f"Unexpected keyword arguments {kwargs.keys()} for {sub_cls.__name__}." + f"Unexpected keyword arguments {kwargs.keys()} for {cls.__name__}." ) + return cls(**arrayified_kwargs) - return sub_cls(**arrayified_kwargs) - - def __add__(self, right: "Table") -> "Table": + def __add__(self, right: TableSubclass) -> TableSubclass: """Concatenate the arrays in field values.""" - if not isinstance(right, type(self)): + if type(right) is not type(self): raise RuntimeError( f"{right} is not a `Table`, or is not the same " f"type of `Table` as {self}." ) - return type(self)( + return type(right)( **{ field_name: _concat( getattr(self, field_name), getattr(right, field_name) @@ -85,7 +86,8 @@ def numpy_table(self): # but it defaults to the largest dtype for everything. dtype = self.numpy_dtype() transposed_list = [ - np.array(tuple(row), dtype=dtype) for row in zip(*self.numpy_columns()) + np.array(tuple(row), dtype=dtype) + for row in zip(*self.numpy_columns(), strict=False) ] transposed = np.array(transposed_list, dtype=dtype) return transposed @@ -130,7 +132,7 @@ def validate_arrays(self) -> "Table": # or if the value is a string enum. 
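# A standalone sketch of the column-wise concatenation that Table.__add__
# performs above, reusing the same _concat helper as the hunk; the field
# names and data here are hypothetical.
import numpy as np


def _concat(value1, value2):
    if isinstance(value1, np.ndarray):
        return np.concatenate((value1, value2))
    else:
        return value1 + value2


left = {"repeats": np.array([1]), "trigger": ["Immediate"]}
right = {"repeats": np.array([3]), "trigger": ["BITA=1"]}
combined = {name: _concat(left[name], right[name]) for name in left}
assert combined["repeats"].tolist() == [1, 3]
assert combined["trigger"] == ["Immediate", "BITA=1"]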
np.issubdtype(getattr(self, field_name).dtype, default_array.dtype) if isinstance( - default_array := self.model_fields[field_name].default_factory(), + default_array := self.model_fields[field_name].default_factory(), # type: ignore np.ndarray, ) else issubclass(get_args(field_value.annotation)[0], Enum) diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index d081ed008f..8c90639e21 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -2,23 +2,13 @@ import asyncio import logging +from collections.abc import Awaitable, Callable, Iterable from dataclasses import dataclass -from typing import ( - Awaitable, - Callable, - Dict, - Generic, - Iterable, - List, - Optional, - ParamSpec, - Type, - TypeVar, - Union, -) +from typing import Generic, Literal, ParamSpec, TypeVar, get_origin import numpy as np from bluesky.protocols import Reading +from pydantic import BaseModel T = TypeVar("T") P = ParamSpec("P") @@ -28,18 +18,18 @@ #: monitor updates ReadingValueCallback = Callable[[Reading, T], None] DEFAULT_TIMEOUT = 10.0 -ErrorText = Union[str, Dict[str, Exception]] +ErrorText = str | dict[str, Exception] -class CalculateTimeout: - """Sentinel class used to implement ``myfunc(timeout=CalculateTimeout)`` +CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT" +"""Sentinel used to implement ``myfunc(timeout=CalculateTimeout)`` - This signifies that the function should calculate a suitable non-zero - timeout itself - """ +This signifies that the function should calculate a suitable non-zero +timeout itself +""" -CalculatableTimeout = float | None | Type[CalculateTimeout] +CalculatableTimeout = float | None | Literal["CALCULATE_TIMEOUT"] class NotConnected(Exception): @@ -115,7 +105,7 @@ async def wait_for_connection(**coros: Awaitable[None]): results = await asyncio.gather(*coros.values(), return_exceptions=True) exceptions = {} - for name, result in zip(coros, results): + for name, result in zip(coros, results, strict=False): if isinstance(result, Exception): exceptions[name] = result if not isinstance(result, NotConnected): @@ -129,7 +119,7 @@ async def wait_for_connection(**coros: Awaitable[None]): raise NotConnected(exceptions) -def get_dtype(typ: Type) -> Optional[np.dtype]: +def get_dtype(typ: type) -> np.dtype | None: """Get the runtime dtype from a numpy ndarray type annotation >>> import numpy.typing as npt @@ -144,7 +134,7 @@ def get_dtype(typ: Type) -> Optional[np.dtype]: return None -def get_unique(values: Dict[str, T], types: str) -> T: +def get_unique(values: dict[str, T], types: str) -> T: """If all values are the same, return that value, otherwise raise TypeError >>> get_unique({"a": 1, "b": 1}, "integers") @@ -162,21 +152,21 @@ def get_unique(values: Dict[str, T], types: str) -> T: async def merge_gathered_dicts( - coros: Iterable[Awaitable[Dict[str, T]]], -) -> Dict[str, T]: + coros: Iterable[Awaitable[dict[str, T]]], +) -> dict[str, T]: """Merge dictionaries produced by a sequence of coroutines. Can be used for merging ``read()`` or ``describe``. 
For instance:: combined_read = await merge_gathered_dicts(s.read() for s in signals) """ - ret: Dict[str, T] = {} + ret: dict[str, T] = {} for result in await asyncio.gather(*coros): ret.update(result) return ret -async def gather_list(coros: Iterable[Awaitable[T]]) -> List[T]: +async def gather_list(coros: Iterable[Awaitable[T]]) -> list[T]: return await asyncio.gather(*coros) @@ -195,3 +185,9 @@ def in_micros(t: float) -> int: if t < 0: raise ValueError(f"Expected a positive time in seconds, got {t!r}") return int(np.ceil(t * 1e6)) + + +def is_pydantic_model(datatype) -> bool: + while origin := get_origin(datatype): + datatype = origin + return datatype and issubclass(datatype, BaseModel) diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 894a46c008..80b826db2d 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -1,5 +1,5 @@ import asyncio -from typing import Literal, Tuple +from typing import Literal from ophyd_async.core import ( DetectorControl, @@ -26,7 +26,7 @@ def __init__(self, driver: AravisDriverIO, gpio_number: GPIO_NUMBER) -> None: self.gpio_number = gpio_number self._arm_status: AsyncStatus | None = None - def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: return _HIGHEST_POSSIBLE_DEADTIME async def prepare(self, trigger_info: TriggerInfo): @@ -56,7 +56,7 @@ async def wait_for_idle(self): def _get_trigger_info( self, trigger: DetectorTrigger - ) -> Tuple[AravisTriggerMode, AravisTriggerSource]: + ) -> tuple[AravisTriggerMode, AravisTriggerSource]: supported_trigger_types = ( DetectorTrigger.constant_gate, DetectorTrigger.edge_trigger, @@ -71,7 +71,7 @@ def _get_trigger_info( if trigger == DetectorTrigger.internal: return AravisTriggerMode.off, "Freerun" else: - return (AravisTriggerMode.on, f"Line{self.gpio_number}") + return (AravisTriggerMode.on, f"Line{self.gpio_number}") # type: ignore async def disarm(self): await adcore.stop_busy_record(self._drv.acquire, False, timeout=1) diff --git a/src/ophyd_async/epics/adaravis/_aravis_io.py b/src/ophyd_async/epics/adaravis/_aravis_io.py index 83da702a4f..27c2898513 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_io.py +++ b/src/ophyd_async/epics/adaravis/_aravis_io.py @@ -38,6 +38,7 @@ def __init__(self, prefix: str, name: str = "") -> None: AravisTriggerMode, prefix + "TriggerMode" ) self.trigger_source = epics_signal_rw_rbv( - AravisTriggerSource, prefix + "TriggerSource" + AravisTriggerSource, # type: ignore + prefix + "TriggerSource", ) super().__init__(prefix, name=name) diff --git a/src/ophyd_async/epics/adcore/_core_io.py b/src/ophyd_async/epics/adcore/_core_io.py index f15d48cd2e..7968579117 100644 --- a/src/ophyd_async/epics/adcore/_core_io.py +++ b/src/ophyd_async/epics/adcore/_core_io.py @@ -135,4 +135,6 @@ def __init__(self, prefix: str, name="") -> None: self.array_size0 = epics_signal_r(int, prefix + "ArraySize0") self.array_size1 = epics_signal_r(int, prefix + "ArraySize1") self.create_directory = epics_signal_rw(int, prefix + "CreateDirectory") + self.num_frames_chunks = epics_signal_r(int, prefix + "NumFramesChunks_RBV") + self.chunk_size_auto = epics_signal_rw_rbv(bool, prefix + "ChunkSizeAuto") super().__init__(prefix, name) diff --git a/src/ophyd_async/epics/adcore/_core_logic.py b/src/ophyd_async/epics/adcore/_core_logic.py index 21b07406fb..3c12c293f9 100644 --- 
a/src/ophyd_async/epics/adcore/_core_logic.py +++ b/src/ophyd_async/epics/adcore/_core_logic.py @@ -1,5 +1,4 @@ import asyncio -from typing import FrozenSet, Set from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -14,7 +13,7 @@ # Default set of states that we should consider "good" i.e. the acquisition # is complete and went well -DEFAULT_GOOD_STATES: FrozenSet[DetectorState] = frozenset( +DEFAULT_GOOD_STATES: frozenset[DetectorState] = frozenset( [DetectorState.Idle, DetectorState.Aborted] ) @@ -66,7 +65,7 @@ async def set_exposure_time_and_acquire_period_if_supplied( async def start_acquiring_driver_and_ensure_status( driver: ADBaseIO, - good_states: Set[DetectorState] = set(DEFAULT_GOOD_STATES), + good_states: frozenset[DetectorState] = frozenset(DEFAULT_GOOD_STATES), timeout: float = DEFAULT_TIMEOUT, ) -> AsyncStatus: """ diff --git a/src/ophyd_async/epics/adcore/_hdf_writer.py b/src/ophyd_async/epics/adcore/_hdf_writer.py index 2f0e3a766e..bfffa67b89 100644 --- a/src/ophyd_async/epics/adcore/_hdf_writer.py +++ b/src/ophyd_async/epics/adcore/_hdf_writer.py @@ -1,9 +1,10 @@ import asyncio +from collections.abc import AsyncGenerator, AsyncIterator from pathlib import Path -from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from xml.etree import ElementTree as ET -from bluesky.protocols import DataKey, Hints, StreamAsset +from bluesky.protocols import Hints, StreamAsset +from event_model import DataKey from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -42,19 +43,22 @@ def __init__( self._dataset_describer = dataset_describer self._plugins = plugins - self._capture_status: Optional[AsyncStatus] = None - self._datasets: List[HDFDataset] = [] - self._file: Optional[HDFFile] = None + self._capture_status: AsyncStatus | None = None + self._datasets: list[HDFDataset] = [] + self._file: HDFFile | None = None self._multiplier = 1 - async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: + async def open(self, multiplier: int = 1) -> dict[str, DataKey]: self._file = None - info = self._path_provider(device_name=self.hdf.name) + info = self._path_provider(device_name=self._name_provider()) # Set the directory creation depth first, since dir creation callback happens # when directory path PV is processed. 
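# Why good_states above moved from a mutable set default to a frozenset:
# a def-time default is built once, so one caller mutating it pollutes every
# later call. A minimal, self-contained illustration with hypothetical names:
def acquire(good_states=set((1, 2))):
    good_states.add(99)  # mutates the single shared default object
    return good_states


assert acquire() is acquire()  # the same, now-polluted set on every call


def acquire_safely(good_states=frozenset((1, 2))):
    return good_states | {99}  # derives a new set; the default cannot change


assert acquire_safely.__defaults__[0] == frozenset({1, 2})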
await self.hdf.create_directory.set(info.create_dir_depth) + # Make sure we are using chunk auto-sizing + await asyncio.gather(self.hdf.chunk_size_auto.set(True)) + await asyncio.gather( self.hdf.num_extra_dims.set(0), self.hdf.lazy_open.set(True), @@ -83,6 +87,9 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: self._multiplier = multiplier outer_shape = (multiplier,) if multiplier > 1 else () + # Determine number of frames that will be saved per HDF chunk + frames_per_chunk = await self.hdf.num_frames_chunks.get_value() + # Add the main data self._datasets = [ HDFDataset( @@ -91,6 +98,7 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: shape=detector_shape, dtype_numpy=np_dtype, multiplier=multiplier, + chunk_shape=(frames_per_chunk, *detector_shape), ) ] # And all the scalar datasets @@ -117,6 +125,9 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: (), np_datatype, multiplier, + # NDAttributes appear to always be configured with + # this chunk size + chunk_shape=(16384,), ) ) @@ -125,7 +136,7 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: source=self.hdf.full_file_name.source, shape=outer_shape + tuple(ds.shape), dtype="array" if ds.shape else "number", - dtype_numpy=ds.dtype_numpy, + dtype_numpy=ds.dtype_numpy, # type: ignore external="STREAM:", ) for ds in self._datasets diff --git a/src/ophyd_async/epics/adcore/_single_trigger.py b/src/ophyd_async/epics/adcore/_single_trigger.py index 8cad9420b5..c39e4e46ab 100644 --- a/src/ophyd_async/epics/adcore/_single_trigger.py +++ b/src/ophyd_async/epics/adcore/_single_trigger.py @@ -1,5 +1,5 @@ import asyncio -from typing import Sequence +from collections.abc import Sequence from bluesky.protocols import Triggerable diff --git a/src/ophyd_async/epics/adcore/_utils.py b/src/ophyd_async/epics/adcore/_utils.py index a1dc25b934..a1a21b6071 100644 --- a/src/ophyd_async/epics/adcore/_utils.py +++ b/src/ophyd_async/epics/adcore/_utils.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional from ophyd_async.core import DEFAULT_TIMEOUT, SignalRW, T, wait_for_value from ophyd_async.core._signal import SignalR @@ -51,8 +50,8 @@ def convert_pv_dtype_to_np(datatype: str) -> str: else: try: np_datatype = convert_ad_dtype_to_np(_pvattribute_to_ad_datatype[datatype]) - except KeyError: - raise ValueError(f"Invalid dbr type {datatype}") + except KeyError as e: + raise ValueError(f"Invalid dbr type {datatype}") from e return np_datatype @@ -69,8 +68,8 @@ def convert_param_dtype_to_np(datatype: str) -> str: np_datatype = convert_ad_dtype_to_np( _paramattribute_to_ad_datatype[datatype] ) - except KeyError: - raise ValueError(f"Invalid datatype {datatype}") + except KeyError as e: + raise ValueError(f"Invalid datatype {datatype}") from e return np_datatype @@ -126,7 +125,7 @@ async def stop_busy_record( signal: SignalRW[T], value: T, timeout: float = DEFAULT_TIMEOUT, - status_timeout: Optional[float] = None, + status_timeout: float | None = None, ) -> None: await signal.set(value, wait=False, timeout=status_timeout) await wait_for_value(signal, value, timeout=timeout) diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py index 70d32e6a78..acf95850ef 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py @@ -23,7 +23,7 @@ def __init__( self._drv = driver self._arm_status: AsyncStatus | None = None - 
def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: return 0.001 async def prepare(self, trigger_info: TriggerInfo): diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py index 54e0d41d5d..9e8bd54aef 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py @@ -29,7 +29,7 @@ def __init__( self._readout_time = readout_time self._arm_status: AsyncStatus | None = None - def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: return self._readout_time async def prepare(self, trigger_info: TriggerInfo): diff --git a/src/ophyd_async/epics/adsimdetector/_sim.py b/src/ophyd_async/epics/adsimdetector/_sim.py index c007c72ffc..5a23b47744 100644 --- a/src/ophyd_async/epics/adsimdetector/_sim.py +++ b/src/ophyd_async/epics/adsimdetector/_sim.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence from ophyd_async.core import PathProvider, SignalR, StandardDetector from ophyd_async.epics import adcore diff --git a/src/ophyd_async/epics/adsimdetector/_sim_controller.py b/src/ophyd_async/epics/adsimdetector/_sim_controller.py index 6561ee24f1..10b8516ece 100644 --- a/src/ophyd_async/epics/adsimdetector/_sim_controller.py +++ b/src/ophyd_async/epics/adsimdetector/_sim_controller.py @@ -1,5 +1,4 @@ import asyncio -from typing import Set from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -15,14 +14,14 @@ class SimController(DetectorControl): def __init__( self, driver: adcore.ADBaseIO, - good_states: Set[adcore.DetectorState] = set(adcore.DEFAULT_GOOD_STATES), + good_states: frozenset[adcore.DetectorState] = adcore.DEFAULT_GOOD_STATES, ) -> None: self.driver = driver self.good_states = good_states self.frame_timeout: float self._arm_status: AsyncStatus | None = None - def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: return 0.002 async def prepare(self, trigger_info: TriggerInfo): diff --git a/src/ophyd_async/epics/advimba/_vimba_controller.py b/src/ophyd_async/epics/advimba/_vimba_controller.py index 9b87d37872..f9ce2a8d02 100644 --- a/src/ophyd_async/epics/advimba/_vimba_controller.py +++ b/src/ophyd_async/epics/advimba/_vimba_controller.py @@ -30,7 +30,7 @@ def __init__( self._drv = driver self._arm_status: AsyncStatus | None = None - def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: return 0.001 async def prepare(self, trigger_info: TriggerInfo): diff --git a/src/ophyd_async/epics/demo/_mover.py b/src/ophyd_async/epics/demo/_mover.py index 31aa198e3d..72266de846 100644 --- a/src/ophyd_async/epics/demo/_mover.py +++ b/src/ophyd_async/epics/demo/_mover.py @@ -4,10 +4,10 @@ from bluesky.protocols import Movable, Stoppable from ophyd_async.core import ( + CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, AsyncStatus, CalculatableTimeout, - CalculateTimeout, ConfigSignal, Device, HintedSignal, @@ -44,9 +44,8 @@ def set_name(self, name: str): self.readback.set_name(name) @WatchableAsyncStatus.wrap - async def set( - self, new_position: float, timeout: CalculatableTimeout = CalculateTimeout - ): + async def set(self, value: float, timeout: CalculatableTimeout = CALCULATE_TIMEOUT): + new_position = value self._set_success = True old_position, units, precision, velocity = await asyncio.gather( self.setpoint.get_value(), @@ -54,7 
+53,7 @@ async def set( self.precision.get_value(), self.velocity.get_value(), ) - if timeout is CalculateTimeout: + if timeout == CALCULATE_TIMEOUT: assert velocity > 0, "Mover has zero velocity" timeout = abs(new_position - old_position) / velocity + DEFAULT_TIMEOUT # Make an Event that will be set on completion, and a Status that will diff --git a/src/ophyd_async/epics/demo/sensor.db b/src/ophyd_async/epics/demo/sensor.db index 9912bb3cae..95cba4b872 100644 --- a/src/ophyd_async/epics/demo/sensor.db +++ b/src/ophyd_async/epics/demo/sensor.db @@ -17,4 +17,3 @@ record(calc, "$(P)Value") { field(EGU, "$(EGU=cts/s)") field(PREC, "$(PREC=3)") } - diff --git a/src/ophyd_async/epics/eiger/_eiger.py b/src/ophyd_async/epics/eiger/_eiger.py index e7d60786ec..bc898d3660 100644 --- a/src/ophyd_async/epics/eiger/_eiger.py +++ b/src/ophyd_async/epics/eiger/_eiger.py @@ -38,6 +38,6 @@ def __init__( ) @AsyncStatus.wrap - async def prepare(self, value: EigerTriggerInfo) -> None: + async def prepare(self, value: EigerTriggerInfo) -> None: # type: ignore await self._controller.set_energy(value.energy_ev) await super().prepare(value) diff --git a/src/ophyd_async/epics/eiger/_eiger_controller.py b/src/ophyd_async/epics/eiger/_eiger_controller.py index c7542bc741..bed28c2d49 100644 --- a/src/ophyd_async/epics/eiger/_eiger_controller.py +++ b/src/ophyd_async/epics/eiger/_eiger_controller.py @@ -25,7 +25,7 @@ def __init__( ) -> None: self._drv = driver - def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: # See https://media.dectris.com/filer_public/30/14/3014704e-5f3b-43ba-8ccf-8ef720e60d2a/240202_usermanual_eiger2.pdf return 0.0001 diff --git a/src/ophyd_async/epics/eiger/_odin_io.py b/src/ophyd_async/epics/eiger/_odin_io.py index 0d1b3516d9..c5a38a669b 100644 --- a/src/ophyd_async/epics/eiger/_odin_io.py +++ b/src/ophyd_async/epics/eiger/_odin_io.py @@ -1,9 +1,9 @@ import asyncio +from collections.abc import AsyncGenerator, AsyncIterator from enum import Enum -from typing import AsyncGenerator, AsyncIterator, Dict from bluesky.protocols import StreamAsset -from event_model.documents.event_descriptor import DataKey +from event_model import DataKey from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -77,7 +77,7 @@ def __init__( self._name_provider = name_provider super().__init__() - async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: + async def open(self, multiplier: int = 1) -> dict[str, DataKey]: info = self._path_provider(device_name=self._name_provider()) await asyncio.gather( @@ -93,7 +93,7 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: return await self._describe() - async def _describe(self) -> Dict[str, DataKey]: + async def _describe(self) -> dict[str, DataKey]: data_shape = await asyncio.gather( self._drv.image_height.get_value(), self._drv.image_width.get_value() ) @@ -103,7 +103,8 @@ async def _describe(self) -> Dict[str, DataKey]: source=self._drv.file_name.source, shape=data_shape, dtype="array", - dtype_numpy=" None: self._set_success = True # end_position of a fly move, with run_up_distance added on. 
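# A simplified stand-in for the timeout rule used by Mover.set (above) and
# Motor.set (below). CALCULATE_TIMEOUT and DEFAULT_TIMEOUT (10.0 s) match the
# _utils.py hunk earlier in this diff; move_timeout itself is illustrative.
CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT"
DEFAULT_TIMEOUT = 10.0


def move_timeout(old_position, new_position, velocity, timeout=CALCULATE_TIMEOUT):
    if timeout == CALCULATE_TIMEOUT:
        # travel time at constant velocity plus a fixed safety margin
        timeout = abs(new_position - old_position) / velocity + DEFAULT_TIMEOUT
    return timeout


assert move_timeout(0.0, 5.0, velocity=2.5) == 12.0  # 2 s travel + 10 s margin
assert move_timeout(0.0, 5.0, velocity=2.5, timeout=3.0) == 3.0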
- self._fly_completed_position: Optional[float] = None + self._fly_completed_position: float | None = None # Set on kickoff(), complete when motor reaches self._fly_completed_position - self._fly_status: Optional[WatchableAsyncStatus] = None + self._fly_status: WatchableAsyncStatus | None = None # Set during prepare - self._fly_timeout: Optional[CalculatableTimeout] = CalculateTimeout + self._fly_timeout: CalculatableTimeout | None = CALCULATE_TIMEOUT super().__init__(name=name) @@ -138,9 +137,8 @@ def complete(self) -> WatchableAsyncStatus: return self._fly_status @WatchableAsyncStatus.wrap - async def set( - self, new_position: float, timeout: CalculatableTimeout = CalculateTimeout - ): + async def set(self, value: float, timeout: CalculatableTimeout = CALCULATE_TIMEOUT): + new_position = value self._set_success = True ( old_position, @@ -155,7 +153,7 @@ async def set( self.velocity.get_value(), self.acceleration_time.get_value(), ) - if timeout is CalculateTimeout: + if timeout is CALCULATE_TIMEOUT: assert velocity > 0, "Motor has zero velocity" timeout = ( abs(new_position - old_position) / velocity diff --git a/src/ophyd_async/epics/pvi/_pvi.py b/src/ophyd_async/epics/pvi/_pvi.py index a2d8cf5f24..8182bc8f29 100644 --- a/src/ophyd_async/epics/pvi/_pvi.py +++ b/src/ophyd_async/epics/pvi/_pvi.py @@ -1,15 +1,11 @@ import re +import types +from collections.abc import Callable from dataclasses import dataclass from inspect import isclass from typing import ( Any, - Callable, - Dict, - FrozenSet, Literal, - Optional, - Tuple, - Type, Union, get_args, get_origin, @@ -32,23 +28,24 @@ epics_signal_x, ) -Access = FrozenSet[ - Union[Literal["r"], Literal["w"], Literal["rw"], Literal["x"], Literal["d"]] +Access = frozenset[ + Literal["r"] | Literal["w"] | Literal["rw"] | Literal["x"] | Literal["d"] ] -def _strip_number_from_string(string: str) -> Tuple[str, Optional[int]]: +def _strip_number_from_string(string: str) -> tuple[str, int | None]: match = re.match(r"(.*?)(\d*)$", string) assert match name = match.group(1) number = match.group(2) or None - if number: - number = int(number) - return name, number + if number is None: + return name, None + else: + return name, int(number) -def _split_subscript(tp: T) -> Union[Tuple[Any, Tuple[Any]], Tuple[T, None]]: +def _split_subscript(tp: T) -> tuple[Any, tuple[Any]] | tuple[T, None]: """Split a subscripted type into the its origin and args. If `tp` is not a subscripted type, then just return the type and None as args. @@ -60,8 +57,8 @@ def _split_subscript(tp: T) -> Union[Tuple[Any, Tuple[Any]], Tuple[T, None]]: return tp, None -def _strip_union(field: Union[Union[T], T]) -> Tuple[T, bool]: - if get_origin(field) is Union: +def _strip_union(field: T | T) -> tuple[T, bool]: + if get_origin(field) in [Union, types.UnionType]: args = get_args(field) is_optional = type(None) in args for arg in args: @@ -70,7 +67,7 @@ def _strip_union(field: Union[Union[T], T]) -> Tuple[T, bool]: return field, False -def _strip_device_vector(field: Union[Type[Device]]) -> Tuple[bool, Type[Device]]: +def _strip_device_vector(field: type[Device]) -> tuple[bool, type[Device]]: if get_origin(field) is DeviceVector: return True, get_args(field)[0] return False, field @@ -83,13 +80,13 @@ class _PVIEntry: This could either be a signal or a sub-table. 
""" - sub_entries: Dict[str, Union[Dict[int, "_PVIEntry"], "_PVIEntry"]] - pvi_pv: Optional[str] = None - device: Optional[Device] = None - common_device_type: Optional[Type[Device]] = None + sub_entries: dict[str, Union[dict[int, "_PVIEntry"], "_PVIEntry"]] + pvi_pv: str | None = None + device: Device | None = None + common_device_type: type[Device] | None = None -def _verify_common_blocks(entry: _PVIEntry, common_device: Type[Device]): +def _verify_common_blocks(entry: _PVIEntry, common_device: type[Device]): if not entry.sub_entries: return common_sub_devices = get_type_hints(common_device) @@ -107,12 +104,12 @@ def _verify_common_blocks(entry: _PVIEntry, common_device: Type[Device]): _verify_common_blocks(sub_sub_entry, sub_device) # type: ignore else: _verify_common_blocks( - entry.sub_entries[sub_name], + entry.sub_entries[sub_name], # type: ignore sub_device, # type: ignore ) -_pvi_mapping: Dict[FrozenSet[str], Callable[..., Signal]] = { +_pvi_mapping: dict[frozenset[str], Callable[..., Signal]] = { frozenset({"r", "w"}): lambda dtype, read_pv, write_pv: epics_signal_rw( dtype, "pva://" + read_pv, "pva://" + write_pv ), @@ -129,8 +126,8 @@ def _verify_common_blocks(entry: _PVIEntry, common_device: Type[Device]): def _parse_type( is_pvi_table: bool, - number_suffix: Optional[int], - common_device_type: Optional[Type[Device]], + number_suffix: int | None, + common_device_type: type[Device] | None, ): if common_device_type: # pre-defined type @@ -159,7 +156,7 @@ def _parse_type( return is_device_vector, is_signal, signal_dtype, device_cls -def _mock_common_blocks(device: Device, stripped_type: Optional[Type] = None): +def _mock_common_blocks(device: Device, stripped_type: type | None = None): device_t = stripped_type or type(device) sub_devices = ( (field, field_type) @@ -173,11 +170,10 @@ def _mock_common_blocks(device: Device, stripped_type: Optional[Type] = None): device_cls, device_args = _split_subscript(device_cls) assert issubclass(device_cls, Device) - is_signal = issubclass(device_cls, Signal) signal_dtype = device_args[0] if device_args is not None else None if is_device_vector: - if is_signal: + if issubclass(device_cls, Signal): sub_device_1 = device_cls(SoftSignalBackend(signal_dtype)) sub_device_2 = device_cls(SoftSignalBackend(signal_dtype)) sub_device = DeviceVector({1: sub_device_1, 2: sub_device_2}) @@ -198,7 +194,7 @@ def _mock_common_blocks(device: Device, stripped_type: Optional[Type] = None): for value in sub_device.values(): value.parent = sub_device else: - if is_signal: + if issubclass(device_cls, Signal): sub_device = device_cls(SoftSignalBackend(signal_dtype)) else: sub_device = getattr(device, device_name, device_cls()) @@ -271,7 +267,8 @@ def _set_device_attributes(entry: _PVIEntry): # Set the device vector entry to have the device vector as a parent device_vector_sub_entry.device.parent = sub_device # type: ignore else: - sub_device = sub_entry.device # type: ignore + sub_device = sub_entry.device + assert sub_device, f"Device of {sub_entry} is None" if sub_entry.pvi_pv: _set_device_attributes(sub_entry) @@ -308,8 +305,8 @@ async def fill_pvi_entries( def create_children_from_annotations( device: Device, - included_optional_fields: Tuple[str, ...] = (), - device_vectors: Optional[Dict[str, int]] = None, + included_optional_fields: tuple[str, ...] 
= (), + device_vectors: dict[str, int] | None = None, ): """For intializing blocks at __init__ of ``device``.""" for name, device_type in get_type_hints(type(device)).items(): @@ -328,7 +325,7 @@ def create_children_from_annotations( if is_device_vector: n_device_vector = DeviceVector( - {i: device_type() for i in range(1, device_vectors[name] + 1)} + {i: device_type() for i in range(1, device_vectors[name] + 1)} # type: ignore ) setattr(device, name, n_device_vector) for sub_device in n_device_vector.values(): diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index ef8a5693e2..bdac6d878f 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,10 +1,11 @@ import inspect import logging import sys +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin +from typing import Any, get_origin import numpy as np from aioca import ( @@ -18,8 +19,10 @@ caput, ) from aioca.types import AugmentedValue, Dbr, Format -from bluesky.protocols import DataKey, Dtype, Reading +from bluesky.protocols import Reading from epicscorelibs.ca import dbr +from event_model import DataKey +from event_model.documents.event_descriptor import Dtype from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -35,7 +38,7 @@ from ._common import LimitPair, Limits, common_meta, get_supported_values -dbr_to_dtype: Dict[Dbr, Dtype] = { +dbr_to_dtype: dict[Dbr, Dtype] = { dbr.DBR_STRING: "string", dbr.DBR_SHORT: "integer", dbr.DBR_FLOAT: "number", @@ -48,8 +51,8 @@ def _data_key_from_augmented_value( value: AugmentedValue, *, - choices: Optional[List[str]] = None, - dtype: Optional[Dtype] = None, + choices: list[str] | None = None, + dtype: Dtype | None = None, ) -> DataKey: """Use the return value of get with FORMAT_CTRL to construct a DataKey describing the signal. 
See docstring of AugmentedValue for expected @@ -67,14 +70,15 @@ def _data_key_from_augmented_value( assert value.ok, f"Error reading {source}: {value}" scalar = value.element_count == 1 - dtype = dtype or dbr_to_dtype[value.datatype] + dtype = dtype or dbr_to_dtype[value.datatype] # type: ignore dtype_numpy = np.dtype(dbr.DbrCodeToType[value.datatype].dtype).descr[0][1] d = DataKey( source=source, dtype=dtype if scalar else "array", - dtype_numpy=dtype_numpy, + # Ignore until https://github.com/bluesky/event-model/issues/308 + dtype_numpy=dtype_numpy, # type: ignore # strictly value.element_count >= len(value) shape=[] if scalar else [len(value)], ) @@ -84,10 +88,10 @@ def _data_key_from_augmented_value( d[key] = attr if choices is not None: - d["choices"] = choices + d["choices"] = choices # type: ignore if limits := _limits_from_augmented_value(value): - d["limits"] = limits + d["limits"] = limits # type: ignore return d @@ -110,8 +114,8 @@ def get_limits(limit: str) -> LimitPair: @dataclass class CaConverter: - read_dbr: Optional[Dbr] - write_dbr: Optional[Dbr] + read_dbr: Dbr | None + write_dbr: Dbr | None def write_value(self, value) -> Any: return value @@ -120,9 +124,9 @@ def value(self, value: AugmentedValue): # for channel access ca_xxx classes, this # invokes __pos__ operator to return an instance of # the builtin base class - return +value + return +value # type: ignore - def reading(self, value: AugmentedValue): + def reading(self, value: AugmentedValue) -> Reading: return { "value": self.value(value), "timestamp": value.timestamp, @@ -157,14 +161,14 @@ class CaEnumConverter(CaConverter): choices: dict[str, str] - def write_value(self, value: Union[Enum, str]): + def write_value(self, value: Enum | str): if isinstance(value, Enum): return value.value else: return value def value(self, value: AugmentedValue): - return self.choices[value] + return self.choices[value] # type: ignore def get_datakey(self, value: AugmentedValue) -> DataKey: # Sometimes DBR_TYPE returns as String, must pass choices still @@ -186,7 +190,7 @@ def __getattribute__(self, __name: str) -> Any: def make_converter( - datatype: Optional[Type], values: Dict[str, AugmentedValue] + datatype: type | None, values: dict[str, AugmentedValue] ) -> CaConverter: pv = list(values)[0] pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes") @@ -202,7 +206,7 @@ def make_converter( raise TypeError(f"{pv} has type [str] not {datatype.__name__}") return CaArrayConverter(pv_dbr, None) elif is_array: - pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") + pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") # type: ignore # This is an array if datatype: # Check we wanted an array of this type @@ -211,7 +215,7 @@ def make_converter( raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") if dtype != pv_dtype: raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - return CaArrayConverter(pv_dbr, None) + return CaArrayConverter(pv_dbr, None) # type: ignore elif pv_dbr == dbr.DBR_ENUM and datatype is bool: # Database can't do bools, so are often representated as enums, # CA can do int @@ -243,7 +247,7 @@ def make_converter( f"{pv} has type {type(value).__name__.replace('ca_', '')} " + f"not {datatype.__name__}" ) - return CaConverter(pv_dbr, None) + return CaConverter(pv_dbr, None) # type: ignore _tried_pyepics = False @@ -271,24 +275,24 @@ class CaSignalBackend(SignalBackend[T]): ) @classmethod - def datatype_allowed(cls, datatype: 
Optional[Type]) -> bool: - stripped_origin = get_origin(datatype) or datatype - if datatype is None: + def datatype_allowed(cls, dtype: Any) -> bool: + stripped_origin = get_origin(dtype) or dtype + if dtype is None: return True return inspect.isclass(stripped_origin) and issubclass( stripped_origin, cls._ALLOWED_DATATYPES ) - def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): + def __init__(self, datatype: type[T] | None, read_pv: str, write_pv: str): self.datatype = datatype if not CaSignalBackend.datatype_allowed(self.datatype): raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") self.read_pv = read_pv self.write_pv = write_pv - self.initial_values: Dict[str, AugmentedValue] = {} + self.initial_values: dict[str, AugmentedValue] = {} self.converter: CaConverter = DisconnectedCaConverter(None, None) - self.subscription: Optional[Subscription] = None + self.subscription: Subscription | None = None def source(self, name: str): return f"ca://{self.read_pv}" @@ -315,7 +319,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): await self._store_initial_value(self.read_pv, timeout=timeout) self.converter = make_converter(self.datatype, self.initial_values) - async def put(self, value: Optional[T], wait=True, timeout=None): + async def put(self, value: T | None, wait=True, timeout=None): if value is None: write_value = self.initial_values[self.write_pv] else: @@ -357,7 +361,7 @@ async def get_setpoint(self) -> T: ) return self.converter.value(value) - def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None: + def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: if callback: assert ( not self.subscription diff --git a/src/ophyd_async/epics/signal/_common.py b/src/ophyd_async/epics/signal/_common.py index 34d344ba07..ae40e93029 100644 --- a/src/ophyd_async/epics/signal/_common.py +++ b/src/ophyd_async/epics/signal/_common.py @@ -1,6 +1,5 @@ import inspect from enum import Enum -from typing import Dict, Optional, Tuple, Type from typing_extensions import TypedDict @@ -16,9 +15,6 @@ class LimitPair(TypedDict): high: float | None low: float | None - def __bool__(self) -> bool: - return self.low is None and self.high is None - class Limits(TypedDict): alarm: LimitPair @@ -26,15 +22,12 @@ class Limits(TypedDict): display: LimitPair warning: LimitPair - def __bool__(self) -> bool: - return any(self.alarm, self.control, self.display, self.warning) - def get_supported_values( pv: str, - datatype: Optional[Type[str]], - pv_choices: Tuple[str, ...], -) -> Dict[str, str]: + datatype: type[str] | None, + pv_choices: tuple[str, ...], +) -> dict[str, str]: if inspect.isclass(datatype) and issubclass(datatype, RuntimeSubsetEnum): if not set(datatype.choices).issubset(set(pv_choices)): raise TypeError( diff --git a/src/ophyd_async/epics/signal/_epics_transport.py b/src/ophyd_async/epics/signal/_epics_transport.py index b6954c9d1c..4737de704f 100644 --- a/src/ophyd_async/epics/signal/_epics_transport.py +++ b/src/ophyd_async/epics/signal/_epics_transport.py @@ -4,22 +4,25 @@ from enum import Enum + +def _make_unavailable_class(error: Exception) -> type: + class TransportNotAvailable: + def __init__(*args, **kwargs): + raise NotImplementedError("Transport not available") from error + + return TransportNotAvailable + + try: from ._aioca import CaSignalBackend except ImportError as ca_error: - - class CaSignalBackend: # type: ignore - def __init__(*args, ca_error=ca_error, **kwargs): - raise NotImplementedError("CA support 
not available") from ca_error + CaSignalBackend = _make_unavailable_class(ca_error) try: from ._p4p import PvaSignalBackend except ImportError as pva_error: - - class PvaSignalBackend: # type: ignore - def __init__(*args, pva_error=pva_error, **kwargs): - raise NotImplementedError("PVA support not available") from pva_error + PvaSignalBackend = _make_unavailable_class(pva_error) class _EpicsTransport(Enum): diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index c7d0b5240d..6fe13d0e2c 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -3,14 +3,16 @@ import inspect import logging import time -from abc import ABCMeta +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin +from typing import Any, get_origin import numpy as np -from bluesky.protocols import DataKey, Dtype, Reading +from bluesky.protocols import Reading +from event_model import DataKey +from event_model.documents.event_descriptor import Dtype from p4p import Value from p4p.client.asyncio import Context, Subscription from pydantic import BaseModel @@ -24,13 +26,14 @@ T, get_dtype, get_unique, + is_pydantic_model, wait_for_connection, ) from ._common import LimitPair, Limits, common_meta, get_supported_values # https://mdavidsaver.github.io/p4p/values.html -specifier_to_dtype: Dict[str, Dtype] = { +specifier_to_dtype: dict[str, Dtype] = { "?": "integer", # bool "b": "integer", # int8 "B": "integer", # uint8 @@ -45,7 +48,7 @@ "s": "string", } -specifier_to_np_dtype: Dict[str, str] = { +specifier_to_np_dtype: dict[str, str] = { "?": " DataKey: """ Args: @@ -108,7 +111,8 @@ def _data_key_from_value( d = DataKey( source=source, dtype=dtype, - dtype_numpy=dtype_numpy, + # type ignore until https://github.com/bluesky/event-model/issues/308 + dtype_numpy=dtype_numpy, # type: ignore shape=shape, ) if display_data is not None: @@ -118,10 +122,12 @@ def _data_key_from_value( d[key] = attr if choices is not None: - d["choices"] = choices + # type ignore until https://github.com/bluesky/event-model/issues/309 + d["choices"] = choices # type: ignore if limits := _limits_from_value(value): - d["limits"] = limits + # type ignore until https://github.com/bluesky/event-model/issues/309 + d["limits"] = limits # type: ignore return d @@ -152,7 +158,7 @@ def write_value(self, value): def value(self, value): return value["value"] - def reading(self, value): + def reading(self, value) -> Reading: ts = value["timeStamp"] sv = value["alarm"]["severity"] return { @@ -164,13 +170,13 @@ def reading(self, value): def get_datakey(self, source: str, value) -> DataKey: return _data_key_from_value(source, value) - def metadata_fields(self) -> List[str]: + def metadata_fields(self) -> list[str]: """ Fields to request from PVA for metadata. """ return ["alarm", "timeStamp"] - def value_fields(self) -> List[str]: + def value_fields(self) -> list[str]: """ Fields to request from PVA for the value. 
""" @@ -185,11 +191,11 @@ def get_datakey(self, source: str, value) -> DataKey: class PvaNDArrayConverter(PvaConverter): - def metadata_fields(self) -> List[str]: + def metadata_fields(self) -> list[str]: return super().metadata_fields() + ["dimension"] - def _get_dimensions(self, value) -> List[int]: - dimensions: List[Value] = value["dimension"] + def _get_dimensions(self, value) -> list[int]: + dimensions: list[Value] = value["dimension"] dims = [dim.size for dim in dimensions] # Note: dimensions in NTNDArray are in fortran-like order # with first index changing fastest. @@ -224,7 +230,7 @@ class PvaEnumConverter(PvaConverter): def __init__(self, choices: dict[str, str]): self.choices = tuple(choices.values()) - def write_value(self, value: Union[Enum, str]): + def write_value(self, value: Enum | str): if isinstance(value, Enum): return value.value else: @@ -253,7 +259,7 @@ def value(self, value): def get_datakey(self, source: str, value) -> DataKey: # This is wrong, but defer until we know how to actually describe a table - return _data_key_from_value(source, value, dtype="object") + return _data_key_from_value(source, value, dtype="object") # type: ignore class PvaPydanticModelConverter(PvaConverter): @@ -261,16 +267,16 @@ def __init__(self, datatype: BaseModel): self.datatype = datatype def value(self, value: Value): - return self.datatype(**value.todict()) + return self.datatype(**value.todict()) # type: ignore - def write_value(self, value: Union[BaseModel, Dict[str, Any]]): - if isinstance(value, self.datatype): - return value.model_dump(mode="python") + def write_value(self, value: BaseModel | dict[str, Any]): + if isinstance(value, self.datatype): # type: ignore + return value.model_dump(mode="python") # type: ignore return value class PvaDictConverter(PvaConverter): - def reading(self, value): + def reading(self, value) -> Reading: ts = time.time() value = value.todict() # Alarm severity is vacuously 0 for a table @@ -282,13 +288,13 @@ def value(self, value: Value): def get_datakey(self, source: str, value) -> DataKey: raise NotImplementedError("Describing Dict signals not currently supported") - def metadata_fields(self) -> List[str]: + def metadata_fields(self) -> list[str]: """ Fields to request from PVA for metadata. """ return [] - def value_fields(self) -> List[str]: + def value_fields(self) -> list[str]: """ Fields to request from PVA for the value. 
""" @@ -300,7 +306,7 @@ def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") -def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConverter: +def make_converter(datatype: type | None, values: dict[str, Any]) -> PvaConverter: pv = list(values)[0] typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") typ = get_unique( @@ -349,7 +355,7 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve and issubclass(datatype, RuntimeSubsetEnum) ): return PvaEnumConverter( - get_supported_values(pv, datatype, datatype.choices) + get_supported_values(pv, datatype, datatype.choices) # type: ignore ) elif datatype and not issubclass(typ, datatype): # Allow int signals to represent float records when prec is 0 @@ -364,15 +370,8 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") return PvaConverter() elif "NTTable" in typeid: - if ( - datatype - and inspect.isclass(datatype) - and - # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ - isinstance(datatype, ABCMeta) - and issubclass(datatype, BaseModel) - ): - return PvaPydanticModelConverter(datatype) + if is_pydantic_model(datatype): + return PvaPydanticModelConverter(datatype) # type: ignore return PvaTableConverter() elif "structure" in typeid: return PvaDictConverter() @@ -381,7 +380,7 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve class PvaSignalBackend(SignalBackend[T]): - _ctxt: Optional[Context] = None + _ctxt: Context | None = None _ALLOWED_DATATYPES = ( bool, @@ -397,24 +396,24 @@ class PvaSignalBackend(SignalBackend[T]): ) @classmethod - def datatype_allowed(cls, datatype: Optional[Type]) -> bool: - stripped_origin = get_origin(datatype) or datatype - if datatype is None: + def datatype_allowed(cls, dtype: Any) -> bool: + stripped_origin = get_origin(dtype) or dtype + if dtype is None: return True return inspect.isclass(stripped_origin) and issubclass( stripped_origin, cls._ALLOWED_DATATYPES ) - def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): + def __init__(self, datatype: type[T] | None, read_pv: str, write_pv: str): self.datatype = datatype if not PvaSignalBackend.datatype_allowed(self.datatype): raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") self.read_pv = read_pv self.write_pv = write_pv - self.initial_values: Dict[str, Any] = {} + self.initial_values: dict[str, Any] = {} self.converter: PvaConverter = DisconnectedPvaConverter() - self.subscription: Optional[Subscription] = None + self.subscription: Subscription | None = None def source(self, name: str): return f"pva://{self.read_pv}" @@ -454,7 +453,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): await self._store_initial_value(self.read_pv, timeout=timeout) self.converter = make_converter(self.datatype, self.initial_values) - async def put(self, value: Optional[T], wait=True, timeout=None): + async def put(self, value: T | None, wait=True, timeout=None): if value is None: write_value = self.initial_values[self.write_pv] else: @@ -474,7 +473,7 @@ async def get_datakey(self, source: str) -> DataKey: value = await self.ctxt.get(self.read_pv) return self.converter.get_datakey(source, value) - def _pva_request_string(self, fields: List[str]) -> str: + def _pva_request_string(self, fields: list[str]) -> str: """ Converts a list of 
- def _pva_request_string(self, fields: List[str]) -> str: + def _pva_request_string(self, fields: list[str]) -> str: """ Converts a list of requested fields into a PVA request string which can be passed to p4p. @@ -497,7 +496,7 @@ async def get_setpoint(self) -> T: value = await self.ctxt.get(self.write_pv, "field(value)") return self.converter.value(value) - def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None: + def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: if callback: assert ( not self.subscription diff --git a/src/ophyd_async/epics/signal/_signal.py index 2e50c67b65..6711ac734e 100644 --- a/src/ophyd_async/epics/signal/_signal.py +++ b/src/ophyd_async/epics/signal/_signal.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional, Tuple, Type - from ophyd_async.core import ( SignalBackend, SignalR, @@ -19,7 +17,7 @@ _default_epics_transport = _EpicsTransport.ca -def _transport_pv(pv: str) -> Tuple[_EpicsTransport, str]: +def _transport_pv(pv: str) -> tuple[_EpicsTransport, str]: split = pv.split("://", 1) if len(split) > 1: # We got something like pva://mydevice, so use specified comms mode @@ -32,7 +30,7 @@ def _transport_pv(pv: str) -> Tuple[_EpicsTransport, str]: def _epics_signal_backend( - datatype: Optional[Type[T]], read_pv: str, write_pv: str + datatype: type[T] | None, read_pv: str, write_pv: str ) -> SignalBackend[T]: """Create an epics signal backend.""" r_transport, r_pv = _transport_pv(read_pv) @@ -42,7 +40,7 @@ def _epics_signal_backend( def epics_signal_rw( - datatype: Type[T], read_pv: str, write_pv: Optional[str] = None, name: str = "" + datatype: type[T], read_pv: str, write_pv: str | None = None, name: str = "" ) -> SignalRW[T]: """Create a `SignalRW` backed by 1 or 2 EPICS PVs Parameters @@ -60,7 +58,7 @@ def epics_signal_rw( def epics_signal_rw_rbv( - datatype: Type[T], write_pv: str, read_suffix: str = "_RBV", name: str = "" + datatype: type[T], write_pv: str, read_suffix: str = "_RBV", name: str = "" ) -> SignalRW[T]: """Create a `SignalRW` backed by 1 or 2 EPICS PVs, with a suffix on the readback pv @@ -76,7 +74,7 @@ def epics_signal_rw_rbv( return epics_signal_rw(datatype, f"{write_pv}{read_suffix}", write_pv, name) -def epics_signal_r(datatype: Type[T], read_pv: str, name: str = "") -> SignalR[T]: +def epics_signal_r(datatype: type[T], read_pv: str, name: str = "") -> SignalR[T]: """Create a `SignalR` backed by 1 EPICS PV Parameters @@ -90,7 +88,7 @@ def epics_signal_r(datatype: Type[T], read_pv: str, name: str = "") -> SignalR[T return SignalR(backend, name=name) -def epics_signal_w(datatype: Type[T], write_pv: str, name: str = "") -> SignalW[T]: +def epics_signal_w(datatype: type[T], write_pv: str, name: str = "") -> SignalW[T]: """Create a `SignalW` backed by 1 EPICS PVs Parameters diff --git a/src/ophyd_async/fastcs/panda/_control.py index aeb8e750cd..212f616376 100644 --- a/src/ophyd_async/fastcs/panda/_control.py +++ b/src/ophyd_async/fastcs/panda/_control.py @@ -16,7 +16,7 @@ def __init__(self, pcap: PcapBlock) -> None: self.pcap = pcap self._arm_status: AsyncStatus | None = None - def get_deadtime(self, exposure: float) -> float: + def get_deadtime(self, exposure: float | None) -> float: return 0.000000008 async def prepare(self, trigger_info: TriggerInfo):
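# Aside (illustrative usage sketch for the signal helpers touched above; the
# PV names are invented):
#
#     from ophyd_async.epics.signal import (
#         epics_signal_r, epics_signal_rw, epics_signal_rw_rbv
#     )
#
#     current = epics_signal_r(float, "ca://DEMO:CURRENT")    # read-only, explicit CA
#     voltage = epics_signal_rw(float, "pva://DEMO:VOLTAGE")  # one PV for read and write
#     position = epics_signal_rw_rbv(float, "DEMO:POS")       # writes DEMO:POS, reads DEMO:POS_RBV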
diff --git a/src/ophyd_async/fastcs/panda/_hdf_panda.py index 3469ccf639..5045d7b27f 100644 --- a/src/ophyd_async/fastcs/panda/_hdf_panda.py +++ b/src/ophyd_async/fastcs/panda/_hdf_panda.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from ophyd_async.core import DEFAULT_TIMEOUT, PathProvider, SignalR, StandardDetector from ophyd_async.epics.pvi import create_children_from_annotations, fill_pvi_entries @@ -36,7 +36,14 @@ def __init__( ) async def connect( - self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT - ) -> None: + self, + mock: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect: bool = False, + ): + # TODO: this doesn't support caching + # https://github.com/bluesky/ophyd-async/issues/472 await fill_pvi_entries(self, self._prefix + "PVI", timeout=timeout, mock=mock) - await super().connect(mock=mock, timeout=timeout) + await super().connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect ) diff --git a/src/ophyd_async/fastcs/panda/_table.py index ece1a1f683..a021d23fa8 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,6 +1,6 @@ -import inspect +from collections.abc import Sequence from enum import Enum -from typing import Annotated, Sequence +from typing import Annotated import numpy as np import numpy.typing as npt @@ -38,14 +38,14 @@ class SeqTrigger(str, Enum): PydanticNp1DArrayInt32 = Annotated[ - np.ndarray[tuple[int], np.int32], + np.ndarray[tuple[int], np.dtype[np.int32]], NpArrayPydanticAnnotation.factory( data_type=np.int32, dimensions=1, strict_data_typing=False ), Field(default_factory=lambda: np.array([], np.int32)), ] PydanticNp1DArrayBool = Annotated[ - np.ndarray[tuple[int], np.bool_], + np.ndarray[tuple[int], np.dtype[np.bool_]], NpArrayPydanticAnnotation.factory( data_type=np.bool_, dimensions=1, strict_data_typing=False ), @@ -74,7 +74,7 @@ class SeqTable(Table): outf2: PydanticNp1DArrayBool @classmethod - def row( + def row( # type: ignore cls, *, repeats: int = 1, @@ -95,9 +95,7 @@ def row( oute2: bool = False, outf2: bool = False, ) -> "SeqTable": - sig = inspect.signature(cls.row) - kwargs = {k: v for k, v in locals().items() if k in sig.parameters} - return Table.row(cls, **kwargs) + return Table.row(**locals()) @model_validator(mode="after") def validate_max_length(self) -> "SeqTable": diff --git a/src/ophyd_async/fastcs/panda/_trigger.py index c79988a381..7abe7b6456 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -1,5 +1,4 @@ import asyncio -from typing import Optional from pydantic import BaseModel, Field @@ -82,7 +81,7 @@ async def kickoff(self) -> None: await self.pcomp.enable.set("ONE") await wait_for_value(self.pcomp.active, True, timeout=1) - async def complete(self, timeout: Optional[float] = None) -> None: + async def complete(self, timeout: float | None = None) -> None: await wait_for_value(self.pcomp.active, False, timeout=timeout) async def stop(self):
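# Aside (illustrative sketch, not part of this patch): the SeqTable.row change
# above works because, at the top of the classmethod body, locals() is exactly
# cls plus the keyword arguments, so forwarding it wholesale replaces the old
# inspect.signature filtering. E.g. a single row with everything else left at
# its default:
#
#     row = SeqTable.row(repeats=2, time1=100, outa1=True)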
diff --git a/src/ophyd_async/fastcs/panda/_utils.py index 38bb015a9c..e960b5c7dd 100644 --- a/src/ophyd_async/fastcs/panda/_utils.py +++ b/src/ophyd_async/fastcs/panda/_utils.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, Sequence +from collections.abc import Sequence +from typing import Any -def phase_sorter(panda_signal_values: Dict[str, Any]) -> Sequence[Dict[str, Any]]: +def phase_sorter(panda_signal_values: dict[str, Any]) -> Sequence[dict[str, Any]]: # Panda has two load phases. If the signal name ends in the string "UNITS", # it needs to be loaded first so put in first phase phase_1, phase_2 = {}, {} diff --git a/src/ophyd_async/fastcs/panda/_writer.py index 65ca186fe3..100af8b10e 100644 --- a/src/ophyd_async/fastcs/panda/_writer.py +++ b/src/ophyd_async/fastcs/panda/_writer.py @@ -1,8 +1,9 @@ import asyncio +from collections.abc import AsyncGenerator, AsyncIterator from pathlib import Path -from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional -from bluesky.protocols import DataKey, StreamAsset +from bluesky.protocols import StreamAsset +from event_model import DataKey from p4p.client.thread import Context from ophyd_async.core import ( @@ -20,7 +21,7 @@ class PandaHDFWriter(DetectorWriter): - _ctxt: Optional[Context] = None + _ctxt: Context | None = None def __init__( self, @@ -33,12 +34,12 @@ def __init__( self._prefix = prefix self._path_provider = path_provider self._name_provider = name_provider - self._datasets: List[HDFDataset] = [] - self._file: Optional[HDFFile] = None + self._datasets: list[HDFDataset] = [] + self._file: HDFFile | None = None self._multiplier = 1 # Triggered on PCAP arm - async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: + async def open(self, multiplier: int = 1) -> dict[str, DataKey]: """Retrieve and get descriptor of all PandA signals marked for capture""" # Ensure flushes are immediate @@ -76,7 +77,7 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: return await self._describe() - async def _describe(self) -> Dict[str, DataKey]: + async def _describe(self) -> dict[str, DataKey]: """ Return a describe based on the datasets PV """ @@ -85,9 +86,11 @@ async def _describe(self) -> Dict[str, DataKey]: describe = { ds.data_key: DataKey( source=self.panda_data_block.hdf_directory.source, - shape=ds.shape, + shape=list(ds.shape), dtype="array" if ds.shape != [1] else "number", - dtype_numpy="<f8", + # PandA data should be written as Float64 + dtype_numpy="<f8", external="STREAM:", ) for ds in self._datasets } return describe @@ async def _update_datasets(self) -> None: capture_table = await self.panda_data_block.datasets.get_value() self._datasets = [ - HDFDataset(dataset_name, "/" + dataset_name, [1], multiplier=1) + # TODO: Update chunk size to read signal once available in IOC + # Currently PandA IOC sets chunk size to 1024 points per chunk + HDFDataset( dataset_name, "/" + dataset_name, [1], multiplier=1, chunk_shape=(1024,) ) for dataset_name in capture_table["name"] ] @@ -118,9 +125,7 @@ async def _update_datasets(self) -> None: # Next few functions are exactly the same as AD writer. Could move as default # StandardDetector behavior
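# Aside (illustrative sketch, not part of this patch): chunk_shape recorded on
# each HDFDataset above is surfaced in the stream resource parameters, letting
# readers match the writer's HDF5 chunking. For a scalar PandA dataset written
# 1024 points per chunk (dataset name invented):
#
#     ds = HDFDataset("COUNTER1-OUT-VALUE", "/COUNTER1-OUT-VALUE", [1],
#                     multiplier=1, chunk_shape=(1024,))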
- async def wait_for_index( - self, index: int, timeout: Optional[float] = DEFAULT_TIMEOUT - ): + async def wait_for_index(self, index: int, timeout: float | None = DEFAULT_TIMEOUT): def matcher(value: int) -> bool: return value >= index diff --git a/src/ophyd_async/plan_stubs/_fly.py index 2cf6f5499e..a9f0003dec 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -1,5 +1,3 @@ -from typing import List, Optional - import bluesky.plan_stubs as bps from bluesky.utils import short_uid @@ -20,7 +18,7 @@ def prepare_static_pcomp_flyer_and_detectors( flyer: StandardFlyer[PcompInfo], - detectors: List[StandardDetector], + detectors: list[StandardDetector], pcomp_info: PcompInfo, trigger_info: TriggerInfo, ): @@ -39,13 +37,13 @@ def prepare_static_pcomp_flyer_and_detectors( def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( flyer: StandardFlyer[SeqTableInfo], - detectors: List[StandardDetector], + detectors: list[StandardDetector], number_of_frames: int, exposure: float, shutter_time: float, repeats: int = 1, period: float = 0.0, - frame_timeout: Optional[float] = None, + frame_timeout: float | None = None, iteration: int = 1, ): """Prepare a hardware triggered flyable and one or more detectors. @@ -107,7 +105,7 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( def fly_and_collect( stream_name: str, flyer: StandardFlyer[SeqTableInfo] | StandardFlyer[PcompInfo], - detectors: List[StandardDetector], + detectors: list[StandardDetector], ): """Kickoff, complete and collect with a flyer and multiple detectors. @@ -147,7 +145,7 @@ def fly_and_collect( def fly_and_collect_with_static_pcomp( stream_name: str, flyer: StandardFlyer[PcompInfo], - detectors: List[StandardDetector], + detectors: list[StandardDetector], number_of_pulses: int, pulse_width: int, rising_edge_step: int, @@ -173,7 +171,7 @@ def fly_and_collect_with_static_pcomp( def time_resolved_fly_and_collect_with_static_seq_table( stream_name: str, flyer: StandardFlyer[SeqTableInfo], - detectors: List[StandardDetector], + detectors: list[StandardDetector], number_of_frames: int, exposure: float, shutter_time: float, diff --git a/src/ophyd_async/plan_stubs/_nd_attributes.py index 986dd86db6..95a473033d 100644 --- a/src/ophyd_async/plan_stubs/_nd_attributes.py +++ b/src/ophyd_async/plan_stubs/_nd_attributes.py @@ -1,14 +1,15 @@ -from typing import Sequence -from xml.etree import cElementTree as ET +from collections.abc import Sequence +from xml.etree import ElementTree as ET import bluesky.plan_stubs as bps -from ophyd_async.core._device import Device -from ophyd_async.epics.adcore._core_io import NDArrayBaseIO -from ophyd_async.epics.adcore._utils import ( +from ophyd_async.core import Device +from ophyd_async.epics.adcore import ( + NDArrayBaseIO, NDAttributeDataType, NDAttributeParam, NDAttributePv, + NDFileHDFIO, ) @@ -48,9 +49,14 @@ def setup_ndattributes( def setup_ndstats_sum(detector: Device): + hdf = getattr(detector, "hdf", None) + assert isinstance(hdf, NDFileHDFIO), ( + f"Expected {detector.name} to have 'hdf' attribute that is an NDFileHDFIO, " + f"got {hdf}" + ) yield from ( setup_ndattributes( - detector.hdf, + hdf, [ NDAttributeParam( name=f"{detector.name}-sum", diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector.py index
ede7e03fee..baa97cbd58 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector.py @@ -1,10 +1,10 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Sequence from ophyd_async.core import ( - AsyncReadable, FilenameProvider, PathProvider, + SignalR, StandardDetector, StaticFilenameProvider, StaticPathProvider, @@ -19,7 +19,7 @@ class PatternDetector(StandardDetector): def __init__( self, path: Path, - config_sigs: Sequence[AsyncReadable] = [], + config_sigs: Sequence[SignalR] = (), name: str = "", ) -> None: fp: FilenameProvider = StaticFilenameProvider(name) diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py index 756b4177ed..45dcddc9c0 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py @@ -1,5 +1,4 @@ import asyncio -from typing import Optional from ophyd_async.core import DetectorControl, PathProvider from ophyd_async.core._detector import TriggerInfo @@ -12,17 +11,15 @@ def __init__( self, pattern_generator: PatternGenerator, path_provider: PathProvider, - exposure: Optional[float] = 0.1, + exposure: float = 0.1, ) -> None: self.pattern_generator: PatternGenerator = pattern_generator self.pattern_generator.set_exposure(exposure) self.path_provider: PathProvider = path_provider - self.task: Optional[asyncio.Task] = None + self.task: asyncio.Task | None = None super().__init__() - async def prepare( - self, trigger_info: TriggerInfo = TriggerInfo(number=1, livetime=0.01) - ): + async def prepare(self, trigger_info: TriggerInfo): self._trigger_info = trigger_info if self._trigger_info.livetime is None: self._trigger_info.livetime = 0.01 diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_writer.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_writer.py index 83178eb768..16dda6f69f 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_writer.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_writer.py @@ -1,8 +1,8 @@ -from typing import AsyncGenerator, AsyncIterator, Dict +from collections.abc import AsyncGenerator, AsyncIterator -from bluesky.protocols import DataKey +from event_model import DataKey -from ophyd_async.core import DetectorWriter, NameProvider, PathProvider +from ophyd_async.core import DEFAULT_TIMEOUT, DetectorWriter, NameProvider, PathProvider from ._pattern_generator import PatternGenerator @@ -20,7 +20,7 @@ def __init__( self.path_provider = path_provider self.name_provider = name_provider - async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: + async def open(self, multiplier: int = 1) -> dict[str, DataKey]: return await self.pattern_generator.open_file( self.path_provider, self.name_provider(), multiplier ) @@ -31,8 +31,11 @@ async def close(self) -> None: def collect_stream_docs(self, indices_written: int) -> AsyncIterator: return self.pattern_generator.collect_stream_docs(indices_written) - def observe_indices_written(self, timeout=...) 
-> AsyncGenerator[int, None]: - return self.pattern_generator.observe_indices_written() + async def observe_indices_written( + self, timeout=DEFAULT_TIMEOUT + ) -> AsyncGenerator[int, None]: + async for index in self.pattern_generator.observe_indices_written(timeout): + yield index async def get_indices_written(self) -> int: return self.pattern_generator.image_counter diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py index 0031a931e4..01c54c8245 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py @@ -1,9 +1,10 @@ +from collections.abc import AsyncGenerator, AsyncIterator from pathlib import Path -from typing import AsyncGenerator, AsyncIterator, Dict, Optional import h5py import numpy as np -from bluesky.protocols import DataKey, StreamAsset +from bluesky.protocols import StreamAsset +from event_model import DataKey from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -60,19 +61,22 @@ def __init__( generate_gaussian_blob(width=detector_width, height=detector_height) * MAX_UINT8_VALUE ) - self._hdf_stream_provider: Optional[HDFFile] = None - self._handle_for_h5_file: Optional[h5py.File] = None - self.target_path: Optional[Path] = None + self._hdf_stream_provider: HDFFile | None = None + self._handle_for_h5_file: h5py.File | None = None + self.target_path: Path | None = None - async def write_image_to_file(self) -> None: + def write_data_to_dataset(self, path: str, data_shape: tuple[int, ...], data): + """Write data to named dataset, resizing to fit and flushing after.""" assert self._handle_for_h5_file, "no file has been opened!" - # prepare - resize the fixed hdf5 data structure - # so that the new image can be written - self._handle_for_h5_file[DATA_PATH].resize( - (self.image_counter + 1, self.height, self.width) - ) - self._handle_for_h5_file[SUM_PATH].resize((self.image_counter + 1,)) + dset = self._handle_for_h5_file[path] + assert isinstance( + dset, h5py.Dataset + ), f"Expected {path} to be dataset, got {dset}" + dset.resize((self.image_counter + 1,) + data_shape) + dset[self.image_counter] = data + dset.flush() + async def write_image_to_file(self) -> None: # generate the simulated data intensity: float = generate_interesting_pattern(self.x, self.y) detector_data = ( @@ -82,14 +86,9 @@ async def write_image_to_file(self) -> None: / self.saturation_exposure_time ).astype(np.uint8) - # write data to disc (intermediate step) - self._handle_for_h5_file[DATA_PATH][self.image_counter] = detector_data - sum = np.sum(detector_data) - self._handle_for_h5_file[SUM_PATH][self.image_counter] = sum - - # save metadata - so that it's discoverable - self._handle_for_h5_file[DATA_PATH].flush() - self._handle_for_h5_file[SUM_PATH].flush() + # Write the data and sum + self.write_data_to_dataset(DATA_PATH, (self.height, self.width), detector_data) + self.write_data_to_dataset(SUM_PATH, (), np.sum(detector_data)) # counter increment is last # as only at this point the new data is visible from the outside @@ -107,7 +106,7 @@ def set_y(self, value: float) -> None: async def open_file( self, path_provider: PathProvider, name: str, multiplier: int = 1 - ) -> Dict[str, DataKey]: + ) -> dict[str, DataKey]: await self.counter_signal.connect() self.target_path = self._get_new_path(path_provider) @@ -156,7 +155,7 @@ async def open_file( describe = { ds.data_key: DataKey( source="sim://pattern-generator-hdf-file", - 
shape=outer_shape + tuple(ds.shape), + shape=list(outer_shape) + list(ds.shape), dtype="array" if ds.shape else "number", external="STREAM:", ) diff --git a/src/ophyd_async/sim/demo/_sim_motor.py b/src/ophyd_async/sim/demo/_sim_motor.py index e2a63b1657..eaca21e45d 100644 --- a/src/ophyd_async/sim/demo/_sim_motor.py +++ b/src/ophyd_async/sim/demo/_sim_motor.py @@ -63,10 +63,11 @@ async def _move(self, old_position: float, new_position: float, move_time: float await asyncio.sleep(0.1) @WatchableAsyncStatus.wrap - async def set(self, new_position: float): + async def set(self, value: float): """ Asynchronously move the motor to a new position. """ + new_position = value # Make sure any existing move tasks are stopped await self.stop() old_position, units, velocity = await asyncio.gather( diff --git a/system_tests/epics/eiger/README.md b/system_tests/epics/eiger/README.md index 8080a53395..f2c1e50bd2 100644 --- a/system_tests/epics/eiger/README.md +++ b/system_tests/epics/eiger/README.md @@ -4,5 +4,3 @@ This system test runs against the eiger tickit sim. To run it: 1. Run `podman run --rm -it -v /dev/shm:/dev/shm -v /tmp:/tmp --net=host ghcr.io/diamondlightsource/eiger-detector-runtime:1.16.0beta5` this will bring up the simulator itself. 2. In a separate terminal load a python environment with `ophyd-async` in it 3. `cd system_tests/epics/eiger` and `./start_iocs_and_run_tests.sh` - - diff --git a/tests/conftest.py b/tests/conftest.py index 6f042ac19b..d71738ae85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,9 @@ import subprocess import sys import time +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any import pytest from bluesky.run_engine import RunEngine, TransitionError @@ -35,11 +36,16 @@ if os.getenv("PYTEST_RAISE", "0") == "1": @pytest.hookimpl(tryfirst=True) - def pytest_exception_interact(call): - raise call.excinfo.value + def pytest_exception_interact(call: pytest.CallInfo[Any]): + if call.excinfo is not None: + raise call.excinfo.value + else: + raise RuntimeError( + f"{call} has no exception data, an unknown error has occurred" + ) @pytest.hookimpl(tryfirst=True) - def pytest_internalerror(excinfo): + def pytest_internalerror(excinfo: pytest.ExceptionInfo[Any]): raise excinfo.value diff --git a/tests/core/test_device_collector.py b/tests/core/test_device_collector.py index 75d8550ac2..efebc6a48f 100644 --- a/tests/core/test_device_collector.py +++ b/tests/core/test_device_collector.py @@ -45,7 +45,7 @@ async def test_device_collector_handles_top_level_errors(caplog): ] # In some environments the asyncio teardown will be logged as an error too assert len(device_log) == 1 - device_log[0].levelname == "ERROR" + assert device_log[0].levelname == "ERROR" def test_sync_device_connector_no_run_engine_raises_error(): diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index b265b86137..03dd73d2c0 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from enum import Enum from os import path -from typing import Any, Dict, List, Sequence +from typing import Any from unittest.mock import patch import numpy as np @@ -105,7 +106,7 @@ async def device_all_types() -> DummyDeviceGroupAllTypes: # Dummy function to check different phases save properly -def sort_signal_by_phase(values: Dict[str, Any]) -> List[Dict[str, Any]]: +def sort_signal_by_phase(values: dict[str, Any]) -> 
list[dict[str, Any]]: phase_1 = {"child1.sig1": values["child1.sig1"]} phase_2 = {"child2.sig1": values["child2.sig1"]} return [phase_1, phase_2] @@ -114,7 +115,7 @@ def sort_signal_by_phase(values: Dict[str, Any]) -> List[Dict[str, Any]]: async def test_enum_yaml_formatting(tmp_path): enums = [EnumTest.VAL1, EnumTest.VAL2] save_to_yaml(enums, path.join(tmp_path, "test_file.yaml")) - with open(path.join(tmp_path, "test_file.yaml"), "r") as file: + with open(path.join(tmp_path, "test_file.yaml")) as file: saved_enums = yaml.load(file, yaml.Loader) # check that save/load reduces from enum to str assert all(isinstance(value, str) for value in saved_enums) @@ -182,7 +183,7 @@ def save_my_device(): RE(save_my_device()) actual_file_path = path.join(tmp_path, "test_file.yaml") - with open(actual_file_path, "r") as actual_file: + with open(actual_file_path) as actual_file: with open("tests/test_data/test_yaml_save.yml") as expected_file: assert actual_file.read() == expected_file.read() @@ -222,7 +223,7 @@ def save_my_device(): RE(save_my_device()) - with open(path.join(tmp_path, "test_file.yaml"), "r") as file: + with open(path.join(tmp_path, "test_file.yaml")) as file: yaml_content = yaml.load(file, yaml.Loader)[0] assert len(yaml_content) == 4 assert yaml_content["child1.sig1"] == "test_string" @@ -243,7 +244,7 @@ async def test_yaml_formatting(RE: RunEngine, device, tmp_path): await device.child2.sig1.set(table_pv) RE(save_device(device, file_path, sorter=sort_signal_by_phase)) - with open(file_path, "r") as file: + with open(file_path) as file: expected = """\ - child1.sig1: test_string - child2.sig1: diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index 6d9c9142aa..b9a3186134 100644 --- a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -1,13 +1,14 @@ import time +from collections.abc import AsyncGenerator, AsyncIterator, Sequence from enum import Enum -from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional, Sequence +from typing import Any from unittest.mock import Mock import bluesky.plan_stubs as bps import pytest -from bluesky.protocols import DataKey, StreamAsset +from bluesky.protocols import StreamAsset from bluesky.run_engine import RunEngine -from event_model import ComposeStreamResourceBundle, compose_stream_resource +from event_model import ComposeStreamResourceBundle, DataKey, compose_stream_resource from pydantic import ValidationError from ophyd_async.core import ( @@ -54,11 +55,11 @@ def __init__(self, name: str, shape: Sequence[int]): self.dummy_signal = epics_signal_rw(int, "pva://read_pv") self._shape = shape self._name = name - self._file: Optional[ComposeStreamResourceBundle] = None + self._file: ComposeStreamResourceBundle | None = None self._last_emitted = 0 self.index = 0 - async def open(self, multiplier: int = 1) -> Dict[str, DataKey]: + async def open(self, multiplier: int = 1) -> dict[str, DataKey]: return { self._name: DataKey( source="soft://some-source", @@ -173,7 +174,7 @@ def flying_plan(): assert flyer._trigger_logic.state == TriggerState.preparing for detector in detectors: - detector.controller.disarm.assert_called_once # type: ignore + detector.controller.disarm.assert_called_once() # type: ignore yield from bps.open_run() yield from bps.declare_stream(*detectors, name="main_stream", collect=True) diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py index 00c11a2708..b0ac8012b3 100644 --- a/tests/core/test_mock_signal_backend.py +++ b/tests/core/test_mock_signal_backend.py @@ 
-286,7 +286,8 @@ async def test_set_mock_values_exhausted_passes(mock_signals): repeat(iter(["second_value", "third_value"]), 6), require_all_consumed=False, ) - for calls, value_set in enumerate(iterator, start=1): + calls = 0 + for calls, value_set in enumerate(iterator, start=1): # noqa: B007 assert await signal2.get_value() == value_set assert calls == 6 diff --git a/tests/core/test_providers.py b/tests/core/test_providers.py index ce3d92d9ab..06874019f8 100644 --- a/tests/core/test_providers.py +++ b/tests/core/test_providers.py @@ -64,7 +64,7 @@ def test_auto_increment_path_provider(static_filename_provider, tmp_path): static_filename_provider, tmp_path, num_calls_per_inc=3, increment=2 ) - for i in range(3): + for _ in range(3): info = auto_inc_path_provider() assert os.path.basename(info.directory_path) == "00000" info = auto_inc_path_provider() diff --git a/tests/core/test_readable.py b/tests/core/test_readable.py index 7d13f308c3..59586602c7 100644 --- a/tests/core/test_readable.py +++ b/tests/core/test_readable.py @@ -1,5 +1,5 @@ from inspect import ismethod -from typing import List, get_type_hints +from typing import get_type_hints from unittest.mock import MagicMock import pytest @@ -58,7 +58,7 @@ def test_standard_readable_hints_raises_when_overriding_string_literal(): ) with pytest.raises(AssertionError): - sr.hints + sr.hints # noqa: B018 def test_standard_readable_hints_raises_when_overriding_sequence(): @@ -76,7 +76,7 @@ def test_standard_readable_hints_raises_when_overriding_sequence(): ) with pytest.raises(AssertionError): - sr.hints + sr.hints # noqa: B018 @pytest.mark.parametrize("invalid_type", [1, 1.0, {"abc": "def"}, {1, 2, 3}]) @@ -89,7 +89,7 @@ def test_standard_readable_hints_invalid_types(invalid_type): sr._has_hints = (hint1,) with pytest.raises(TypeError): - sr.hints + sr.hints # noqa: B018 def test_standard_readable_add_children_context_manager(): @@ -184,7 +184,7 @@ def test_standard_readable_add_readables_adds_to_expected_attrs( ], ) def test_standard_readable_add_readables_adds_wrapped_to_expected_attr( - wrapper, expected_attrs: List[str] + wrapper, expected_attrs: list[str] ): sr = StandardReadable() diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index d498e67531..133cfd5ee4 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -37,9 +37,11 @@ async def test_signal_can_be_given_backend_on_connect(): + from ophyd_async.core._signal import DISCONNECTED_BACKEND + sim_signal = SignalR() backend = MockSignalBackend(int) - assert sim_signal._backend is None + assert sim_signal._backend is DISCONNECTED_BACKEND await sim_signal.connect(mock=False, backend=backend) assert await sim_signal.get_value() == 0 diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 16bf23567e..65316d18a8 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -1,7 +1,8 @@ import asyncio import time +from collections.abc import Callable, Sequence from enum import Enum -from typing import Any, Callable, Sequence, Tuple, Type +from typing import Any import numpy as np import numpy.typing as npt @@ -40,7 +41,7 @@ def waveform_d(value): class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend - self.updates: asyncio.Queue[Tuple[Reading, Any]] = asyncio.Queue() + self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() backend.set_callback(self.add_reading_value) def add_reading_value(self, reading: Reading, value): @@ -88,7 
+89,7 @@ def close(self): ], ) async def test_soft_signal_backend_get_put_monitor( - datatype: Type[T], + datatype: type[T], initial_value: T, put_value: T, descriptor: Callable[[Any], dict], diff --git a/tests/core/test_status.py b/tests/core/test_status.py index d2ec206bc8..263a39dca6 100644 --- a/tests/core/test_status.py +++ b/tests/core/test_status.py @@ -147,8 +147,8 @@ async def test_status_propogates_traceback_under_RE(RE) -> None: async def test_async_status_exception_timeout(): + st = AsyncStatus(asyncio.sleep(0.1)) try: - st = AsyncStatus(asyncio.sleep(0.1)) with pytest.raises( ValueError, match=( diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py index 8c638d2770..3512075e7f 100644 --- a/tests/core/test_subset_enum.py +++ b/tests/core/test_subset_enum.py @@ -79,12 +79,12 @@ async def test_runtime_enum_signal(): signal_rw_ca = epics_signal_rw(SubsetEnum["A2", "B2"], "ca://RW_PV", name="signal") await signal_rw_pva.connect(mock=True) await signal_rw_ca.connect(mock=True) - await signal_rw_pva.get_value() == "A1" - await signal_rw_ca.get_value() == "A2" + assert await signal_rw_pva.get_value() == "A1" + assert await signal_rw_ca.get_value() == "A2" await signal_rw_pva.set("B1") await signal_rw_ca.set("B2") - await signal_rw_pva.get_value() == "B1" - await signal_rw_ca.get_value() == "B2" + assert await signal_rw_pva.get_value() == "B1" + assert await signal_rw_ca.get_value() == "B2" # Will accept string values even if they're not in the runtime enum # Though type checking should compain diff --git a/tests/core/test_watchable_async_status.py b/tests/core/test_watchable_async_status.py index eaf23bb61b..d3f954d239 100644 --- a/tests/core/test_watchable_async_status.py +++ b/tests/core/test_watchable_async_status.py @@ -1,6 +1,6 @@ import asyncio +from collections.abc import AsyncIterator from functools import partial -from typing import AsyncIterator import bluesky.plan_stubs as bps import pytest @@ -96,7 +96,7 @@ async def set(self, val, timeout=0.01): class ASTestDeviceIteratorSet(ASTestDevice): def __init__( - self, name: str = "", values=[1, 2, 3, 4, 5], complete_set: bool = True + self, name: str = "", values=(1, 2, 3, 4, 5), complete_set: bool = True ) -> None: self.values = values self.complete_set = complete_set diff --git a/tests/epics/adaravis/test_aravis.py b/tests/epics/adaravis/test_aravis.py index 3c34fa49eb..341ad280ed 100644 --- a/tests/epics/adaravis/test_aravis.py +++ b/tests/epics/adaravis/test_aravis.py @@ -1,11 +1,9 @@ import re import pytest -from bluesky.run_engine import RunEngine from ophyd_async.core import ( DetectorTrigger, - DeviceCollector, PathProvider, TriggerInfo, set_mock_value, @@ -14,14 +12,8 @@ @pytest.fixture -async def test_adaravis( - RE: RunEngine, - static_path_provider: PathProvider, -) -> adaravis.AravisDetector: - async with DeviceCollector(mock=True): - test_adaravis = adaravis.AravisDetector("ADARAVIS:", static_path_provider) - - return test_adaravis +def test_adaravis(ad_standard_det_factory) -> adaravis.AravisDetector: + return ad_standard_det_factory(adaravis.AravisDetector) @pytest.mark.parametrize("exposure_time", [0.0, 0.1, 1.0, 10.0, 100.0]) @@ -80,7 +72,7 @@ def test_gpio_pin_limited(test_adaravis: adaravis.AravisDetector): async def test_hints_from_hdf_writer(test_adaravis: adaravis.AravisDetector): - assert test_adaravis.hints == {"fields": ["test_adaravis"]} + assert test_adaravis.hints == {"fields": ["test_adaravis1"]} async def test_can_read(test_adaravis: adaravis.AravisDetector): @@ -98,9 +90,9 
@@ async def test_decribe_describes_writer_dataset( await test_adaravis.stage() await test_adaravis.prepare(one_shot_trigger_info) assert await test_adaravis.describe() == { - "test_adaravis": { - "source": "mock+ca://ADARAVIS:HDF1:FullFileName_RBV", - "shape": (0, 0), + "test_adaravis1": { + "source": "mock+ca://ARAVIS1:HDF1:FullFileName_RBV", + "shape": (10, 10), "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", @@ -125,12 +117,13 @@ async def test_can_collect( assert docs[0][0] == "stream_resource" stream_resource = docs[0][1] sr_uid = stream_resource["uid"] - assert stream_resource["data_key"] == "test_adaravis" + assert stream_resource["data_key"] == "test_adaravis1" assert stream_resource["uri"] == "file://localhost" + str(full_file_name) assert stream_resource["parameters"] == { "dataset": "/entry/data/data", "swmr": False, "multiplier": 1, + "chunk_shape": (1, 10, 10), } assert docs[1][0] == "stream_datum" stream_datum = docs[1][1] @@ -148,9 +141,9 @@ async def test_can_decribe_collect( await test_adaravis.stage() await test_adaravis.prepare(one_shot_trigger_info) assert (await test_adaravis.describe_collect()) == { - "test_adaravis": { - "source": "mock+ca://ADARAVIS:HDF1:FullFileName_RBV", - "shape": (0, 0), + "test_adaravis1": { + "source": "mock+ca://ARAVIS1:HDF1:FullFileName_RBV", + "shape": (10, 10), "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", diff --git a/tests/epics/adcore/test_scans.py b/tests/epics/adcore/test_scans.py index 6889fbbe3b..f5264e9786 100644 --- a/tests/epics/adcore/test_scans.py +++ b/tests/epics/adcore/test_scans.py @@ -1,6 +1,6 @@ import asyncio from pathlib import Path -from typing import Any, Optional +from typing import Any from unittest.mock import AsyncMock, patch import bluesky.plan_stubs as bps @@ -97,7 +97,7 @@ def test_hdf_writer_fails_on_timeout_with_flyscan( controller = DummyController() set_mock_value(writer.hdf.file_path_exists, True) - detector: StandardDetector[Optional[TriggerInfo]] = StandardDetector( + detector: StandardDetector[TriggerInfo | None] = StandardDetector( controller, writer ) trigger_logic = DummyTriggerLogic() diff --git a/tests/epics/adcore/test_writers.py b/tests/epics/adcore/test_writers.py index 601504beea..af32f86667 100644 --- a/tests/epics/adcore/test_writers.py +++ b/tests/epics/adcore/test_writers.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import patch import pytest @@ -47,6 +46,9 @@ async def hdf_writer_with_stats( hdf = adcore.NDFileHDFIO("HDF:") stats = adcore.NDPluginStatsIO("FOO:") + # Set number of frames per chunk to something reasonable + set_mock_value(hdf.num_frames_chunks, 2) + return adcore.ADHDFWriter( hdf, static_path_provider, @@ -59,7 +61,7 @@ async def hdf_writer_with_stats( @pytest.fixture async def detectors( static_path_provider: PathProvider, -) -> List[StandardDetector]: +) -> list[StandardDetector]: detectors = [] async with DeviceCollector(mock=True): detectors.append(advimba.VimbaDetector("VIMBA:", static_path_provider)) @@ -83,7 +85,7 @@ async def test_stats_describe_when_plugin_configured( set_mock_value(hdf_writer_with_stats.hdf.file_path_exists, True) set_mock_value( hdf_writer_with_stats._plugins[0].nd_attributes_file, - str(""" + """ -"""), +""", ) with patch("ophyd_async.core._signal.wait_for_value", return_value=None): descriptor = await hdf_writer_with_stats.open() @@ -135,7 +137,7 @@ async def test_stats_describe_raises_error_with_dbr_native( set_mock_value(hdf_writer_with_stats.hdf.file_path_exists, True) set_mock_value( 
hdf_writer_with_stats._plugins[0].nd_attributes_file, - str(""" + """ -"""), +""", ) with pytest.raises(ValueError) as e: with patch("ophyd_async.core._signal.wait_for_value", return_value=None): diff --git a/tests/epics/adkinetix/test_kinetix.py b/tests/epics/adkinetix/test_kinetix.py index a17be5e5b3..ae2f72462e 100644 --- a/tests/epics/adkinetix/test_kinetix.py +++ b/tests/epics/adkinetix/test_kinetix.py @@ -1,25 +1,18 @@ import pytest -from bluesky.run_engine import RunEngine from ophyd_async.core import ( DetectorTrigger, - DeviceCollector, StaticPathProvider, set_mock_value, ) from ophyd_async.core._detector import TriggerInfo from ophyd_async.epics import adkinetix +from ophyd_async.epics.adkinetix._kinetix_io import KinetixTriggerMode @pytest.fixture -async def test_adkinetix( - RE: RunEngine, - static_path_provider: StaticPathProvider, -) -> adkinetix.KinetixDetector: - async with DeviceCollector(mock=True): - test_adkinetix = adkinetix.KinetixDetector("KINETIX:", static_path_provider) - - return test_adkinetix +def test_adkinetix(ad_standard_det_factory): + return ad_standard_det_factory(adkinetix.KinetixDetector) async def test_get_deadtime( @@ -30,7 +23,7 @@ async def test_get_deadtime( async def test_trigger_modes(test_adkinetix: adkinetix.KinetixDetector): - set_mock_value(test_adkinetix.drv.trigger_mode, "Internal") + set_mock_value(test_adkinetix.drv.trigger_mode, KinetixTriggerMode.internal) async def setup_trigger_mode(trig_mode: DetectorTrigger): await test_adkinetix.controller.prepare( @@ -58,7 +51,7 @@ async def setup_trigger_mode(trig_mode: DetectorTrigger): async def test_hints_from_hdf_writer(test_adkinetix: adkinetix.KinetixDetector): - assert test_adkinetix.hints == {"fields": ["test_adkinetix"]} + assert test_adkinetix.hints == {"fields": ["test_adkinetix1"]} async def test_can_read(test_adkinetix: adkinetix.KinetixDetector): @@ -76,9 +69,9 @@ async def test_decribe_describes_writer_dataset( await test_adkinetix.stage() await test_adkinetix.prepare(one_shot_trigger_info) assert await test_adkinetix.describe() == { - "test_adkinetix": { - "source": "mock+ca://KINETIX:HDF1:FullFileName_RBV", - "shape": (0, 0), + "test_adkinetix1": { + "source": "mock+ca://KINETIX1:HDF1:FullFileName_RBV", + "shape": (10, 10), "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", @@ -103,12 +96,13 @@ async def test_can_collect( assert docs[0][0] == "stream_resource" stream_resource = docs[0][1] sr_uid = stream_resource["uid"] - assert stream_resource["data_key"] == "test_adkinetix" + assert stream_resource["data_key"] == "test_adkinetix1" assert stream_resource["uri"] == "file://localhost" + str(full_file_name) assert stream_resource["parameters"] == { "dataset": "/entry/data/data", "swmr": False, "multiplier": 1, + "chunk_shape": (1, 10, 10), } assert docs[1][0] == "stream_datum" stream_datum = docs[1][1] @@ -126,9 +120,9 @@ async def test_can_decribe_collect( await test_adkinetix.stage() await test_adkinetix.prepare(one_shot_trigger_info) assert (await test_adkinetix.describe_collect()) == { - "test_adkinetix": { - "source": "mock+ca://KINETIX:HDF1:FullFileName_RBV", - "shape": (0, 0), + "test_adkinetix1": { + "source": "mock+ca://KINETIX1:HDF1:FullFileName_RBV", + "shape": (10, 10), "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", diff --git a/tests/epics/adpilatus/test_pilatus.py b/tests/epics/adpilatus/test_pilatus.py index 72145ff0cc..7d6145f5a5 100644 --- a/tests/epics/adpilatus/test_pilatus.py +++ b/tests/epics/adpilatus/test_pilatus.py @@ -1,5 +1,5 @@ 
import asyncio -from typing import Awaitable, Callable +from collections.abc import Awaitable, Callable from unittest.mock import patch import pytest diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py index 9e69733680..fd99fd0497 100644 --- a/tests/epics/adsimdetector/test_sim.py +++ b/tests/epics/adsimdetector/test_sim.py @@ -3,7 +3,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import List, cast +from typing import cast import bluesky.plan_stubs as bps import bluesky.preprocessors as bpp @@ -43,7 +43,7 @@ def _set_full_file_name(val, *args, **kwargs): return det -def count_sim(dets: List[StandardDetector], times: int = 1): +def count_sim(dets: list[StandardDetector], times: int = 1): """Test plan to do the equivalent of bp.count for a sim detector.""" yield from bps.stage_all(*dets) @@ -110,7 +110,7 @@ async def two_detectors(tmp_path: Path): async def test_two_detectors_fly_different_rate( - two_detectors: List[adsimdetector.SimDetector], RE: RunEngine + two_detectors: list[adsimdetector.SimDetector], RE: RunEngine ): trigger_info = TriggerInfo( number=15, @@ -174,7 +174,7 @@ def fly_plan(): async def test_two_detectors_step( - two_detectors: List[StandardDetector], + two_detectors: list[StandardDetector], RE: RunEngine, ): names = [] @@ -235,12 +235,8 @@ def plan(): assert descriptor["data_keys"]["testb"]["shape"] == (769, 1025) assert sda["stream_resource"] == sra["uid"] assert sdb["stream_resource"] == srb["uid"] - assert srb["uri"] == str("file://localhost") + str( - info_b.directory_path / file_name_b - ) - assert sra["uri"] == str("file://localhost") + str( - info_a.directory_path / file_name_a - ) + assert srb["uri"] == "file://localhost" + str(info_b.directory_path / file_name_b) + assert sra["uri"] == "file://localhost" + str(info_a.directory_path / file_name_a) assert event["data"] == {} diff --git a/tests/epics/advimba/test_vimba.py b/tests/epics/advimba/test_vimba.py index ec93cc07d3..a8990502c3 100644 --- a/tests/epics/advimba/test_vimba.py +++ b/tests/epics/advimba/test_vimba.py @@ -1,25 +1,22 @@ import pytest -from bluesky.run_engine import RunEngine from ophyd_async.core import ( DetectorTrigger, - DeviceCollector, PathProvider, set_mock_value, ) from ophyd_async.core._detector import TriggerInfo from ophyd_async.epics import advimba +from ophyd_async.epics.advimba._vimba_io import ( + VimbaExposeOutMode, + VimbaOnOff, + VimbaTriggerSource, +) @pytest.fixture -async def test_advimba( - RE: RunEngine, - static_path_provider: PathProvider, -) -> advimba.VimbaDetector: - async with DeviceCollector(mock=True): - test_advimba = advimba.VimbaDetector("VIMBA:", static_path_provider) - - return test_advimba +def test_advimba(ad_standard_det_factory) -> advimba.VimbaDetector: + return ad_standard_det_factory(advimba.VimbaDetector) async def test_get_deadtime( @@ -30,9 +27,9 @@ async def test_get_deadtime( async def test_arming_trig_modes(test_advimba: advimba.VimbaDetector): - set_mock_value(test_advimba.drv.trigger_source, "Freerun") - set_mock_value(test_advimba.drv.trigger_mode, "Off") - set_mock_value(test_advimba.drv.exposure_mode, "Timed") + set_mock_value(test_advimba.drv.trigger_source, VimbaTriggerSource.freerun) + set_mock_value(test_advimba.drv.trigger_mode, VimbaOnOff.off) + set_mock_value(test_advimba.drv.exposure_mode, VimbaExposeOutMode.timed) async def setup_trigger_mode(trig_mode: DetectorTrigger): await test_advimba.controller.prepare(TriggerInfo(number=1, trigger=trig_mode)) @@ 
-68,7 +65,7 @@ async def setup_trigger_mode(trig_mode: DetectorTrigger): async def test_hints_from_hdf_writer(test_advimba: advimba.VimbaDetector): - assert test_advimba.hints == {"fields": ["test_advimba"]} + assert test_advimba.hints == {"fields": ["test_advimba1"]} async def test_can_read(test_advimba: advimba.VimbaDetector): @@ -86,9 +83,9 @@ async def test_decribe_describes_writer_dataset( await test_advimba.stage() await test_advimba.prepare(one_shot_trigger_info) assert await test_advimba.describe() == { - "test_advimba": { - "source": "mock+ca://VIMBA:HDF1:FullFileName_RBV", - "shape": (0, 0), + "test_advimba1": { + "source": "mock+ca://VIMBA1:HDF1:FullFileName_RBV", + "shape": (10, 10), "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", @@ -113,12 +110,13 @@ async def test_can_collect( assert docs[0][0] == "stream_resource" stream_resource = docs[0][1] sr_uid = stream_resource["uid"] - assert stream_resource["data_key"] == "test_advimba" + assert stream_resource["data_key"] == "test_advimba1" assert stream_resource["uri"] == "file://localhost" + str(full_file_name) assert stream_resource["parameters"] == { "dataset": "/entry/data/data", "swmr": False, "multiplier": 1, + "chunk_shape": (1, 10, 10), } assert docs[1][0] == "stream_datum" stream_datum = docs[1][1] @@ -136,9 +134,9 @@ async def test_can_decribe_collect( await test_advimba.stage() await test_advimba.prepare(one_shot_trigger_info) assert (await test_advimba.describe_collect()) == { - "test_advimba": { - "source": "mock+ca://VIMBA:HDF1:FullFileName_RBV", - "shape": (0, 0), + "test_advimba1": { + "source": "mock+ca://VIMBA1:HDF1:FullFileName_RBV", + "shape": (10, 10), "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", diff --git a/tests/epics/conftest.py b/tests/epics/conftest.py new file mode 100644 index 0000000000..6d8f331765 --- /dev/null +++ b/tests/epics/conftest.py @@ -0,0 +1,38 @@ +from collections.abc import Callable + +import pytest +from bluesky.run_engine import RunEngine + +from ophyd_async.core._detector import StandardDetector +from ophyd_async.core._device import DeviceCollector +from ophyd_async.core._mock_signal_utils import set_mock_value + + +@pytest.fixture +def ad_standard_det_factory( + RE: RunEngine, + static_path_provider, +) -> Callable: + def generate_ad_standard_det( + ad_standard_detector_class, number=1 + ) -> StandardDetector: + # Dynamically generate a name based on the class of detector + detector_name = ad_standard_detector_class.__name__ + if detector_name.endswith("Detector"): + detector_name = detector_name[: -len("Detector")] + + with DeviceCollector(mock=True): + test_adstandard_det = ad_standard_detector_class( + f"{detector_name.upper()}{number}:", + static_path_provider, + name=f"test_ad{detector_name.lower()}{number}", + ) + + # Set number of frames per chunk and frame dimensions to something reasonable + set_mock_value(test_adstandard_det.hdf.num_frames_chunks, 1) + set_mock_value(test_adstandard_det.drv.array_size_x, 10) + set_mock_value(test_adstandard_det.drv.array_size_y, 10) + + return test_adstandard_det + + return generate_ad_standard_det diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 5fe63215d0..82039c661e 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -1,7 +1,6 @@ import asyncio import subprocess from collections import defaultdict -from typing import Dict from unittest.mock import ANY, Mock, call, patch import pytest @@ -218,7 +217,7 @@ async def test_read_mover(mock_mover: 
demo.Mover): async def test_set_velocity(mock_mover: demo.Mover) -> None: v = mock_mover.velocity - q: asyncio.Queue[Dict[str, Reading]] = asyncio.Queue() + q: asyncio.Queue[dict[str, Reading]] = asyncio.Queue() v.subscribe(q.put_nowait) assert (await q.get())["mock_mover-velocity"]["value"] == 1.0 await v.set(2.0) @@ -229,7 +228,7 @@ async def test_set_velocity(mock_mover: demo.Mover) -> None: assert q.empty() -async def test_mover_disconncted(): +async def test_mover_disconnected(): with pytest.raises(NotConnected): async with DeviceCollector(timeout=0.1): m = demo.Mover("ca://PRE:", name="mover") diff --git a/tests/epics/pvi/test_pvi.py b/tests/epics/pvi/test_pvi.py index 69152dc343..caa9468526 100644 --- a/tests/epics/pvi/test_pvi.py +++ b/tests/epics/pvi/test_pvi.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from ophyd_async.core import ( @@ -28,7 +26,7 @@ class Block2(Device): class Block3(Device): - device_vector: Optional[DeviceVector[Block2]] + device_vector: DeviceVector[Block2] | None device: Block2 signal_device: Block1 signal_x: SignalX @@ -44,7 +42,7 @@ def __init__(self, prefix: str, name: str = ""): self._prefix = prefix super().__init__(name) - async def connect( + async def connect( # type: ignore self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT ) -> None: await fill_pvi_entries( @@ -67,9 +65,9 @@ async def test_fill_pvi_entries_mock_mode(pvi_test_device_t): # elements of device vectors are typed recursively assert test_device.device_vector[1].signal_rw._backend.datatype is int assert isinstance(test_device.device_vector[1].device, Block1) - assert test_device.device_vector[1].device.signal_rw._backend.datatype is int + assert test_device.device_vector[1].device.signal_rw._backend.datatype is int # type: ignore assert ( - test_device.device_vector[1].device.device_vector_signal_rw[1]._backend.datatype + test_device.device_vector[1].device.device_vector_signal_rw[1]._backend.datatype # type: ignore is float ) @@ -78,9 +76,9 @@ async def test_fill_pvi_entries_mock_mode(pvi_test_device_t): assert isinstance(test_device.device, Block2) # elements of top level blocks are typed recursively - assert test_device.device.signal_rw._backend.datatype is int + assert test_device.device.signal_rw._backend.datatype is int # type: ignore assert isinstance(test_device.device.device, Block1) - assert test_device.device.device.signal_rw._backend.datatype is int + assert test_device.device.device.signal_rw._backend.datatype is int # type: ignore assert test_device.signal_rw.parent == test_device assert test_device.device_vector.parent == test_device @@ -108,7 +106,7 @@ def __init__(self, prefix: str, name: str = ""): super().__init__(name) create_children_from_annotations(self) - async def connect( + async def connect( # type: ignore self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT ) -> None: await fill_pvi_entries( @@ -152,9 +150,9 @@ def pvi_test_device_with_device_vectors_t(): class TestBlock(Device): device_vector: DeviceVector[Block1] - device: Optional[Block1] + device: Block1 | None signal_x: SignalX - signal_rw: Optional[SignalRW[int]] + signal_rw: SignalRW[int] | None class TestDevice(TestBlock): def __init__(self, prefix: str, name: str = ""): @@ -166,7 +164,7 @@ def __init__(self, prefix: str, name: str = ""): ) super().__init__(name) - async def connect( + async def connect( # type: ignore self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT ) -> None: await fill_pvi_entries( diff --git a/tests/epics/signal/test_signals.py 
b/tests/epics/signal/test_signals.py index 8e346f324b..e2e5c20f7d 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -5,19 +5,21 @@ import subprocess import sys import time +from collections.abc import Sequence from contextlib import closing from dataclasses import dataclass from enum import Enum from pathlib import Path from types import GenericAlias -from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Type +from typing import Any, Literal from unittest.mock import ANY import numpy as np import numpy.typing as npt import pytest from aioca import CANothing, purge_channel_caches -from bluesky.protocols import DataKey, Reading +from bluesky.protocols import Reading +from event_model import DataKey from typing_extensions import TypedDict from ophyd_async.core import ( @@ -49,23 +51,23 @@ class IOC: protocol: Literal["ca", "pva"] async def make_backend( - self, typ: Optional[Type], suff: str, connect=True + self, typ: type | None, suff: str, connect=True ) -> SignalBackend: # Calculate the pv pv = f"{PV_PREFIX}:{self.protocol}:{suff}" # Make and connect the backend cls = _EpicsTransport[self.protocol].value - backend = cls(typ, pv, pv) + backend = cls(typ, pv, pv) # type: ignore if connect: - await asyncio.wait_for(backend.connect(), 10) - return backend + await asyncio.wait_for(backend.connect(), 10) # type: ignore + return backend # type: ignore # Use a module level fixture per protocol so it's fast to run tests. This means # we need to add a record for every PV that we will modify in tests to stop # tests interfering with each other @pytest.fixture(scope="module", params=["pva", "ca"]) -def ioc(request): +def ioc(request: pytest.FixtureRequest): protocol = request.param process = subprocess.Popen( [ @@ -85,7 +87,7 @@ def ioc(request): start_time = time.monotonic() while "iocRun: All initialization complete" not in ( - process.stdout.readline().strip() + process.stdout.readline().strip() # type: ignore ): if time.monotonic() - start_time > 10: raise TimeoutError("IOC did not start in time") @@ -127,7 +129,7 @@ class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend self.subscription = backend.set_callback(self.add_reading_value) - self.updates: asyncio.Queue[Tuple[Reading, Any]] = asyncio.Queue() + self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() def add_reading_value(self, reading: Reading, value): self.updates.put_nowait((reading, value)) @@ -165,8 +167,8 @@ async def assert_monitor_then_put( datakey: dict, initial_value: T, put_value: T, - datatype: Optional[Type[T]] = None, - check_type: Optional[bool] = True, + datatype: type[T] | None = None, + check_type: bool | None = True, ): backend = await ioc.make_backend(datatype, suffix) # Make a monitor queue that will monitor for updates @@ -193,7 +195,7 @@ async def put_error( ioc: IOC, suffix: str, put_value: T, - datatype: Optional[Type[T]] = None, + datatype: type[T] | None = None, ): backend = await ioc.make_backend(datatype, suffix) # The below will work without error @@ -211,7 +213,7 @@ class MyEnum(str, Enum): MySubsetEnum = SubsetEnum["Aaa", "Bbb", "Ccc"] -_metadata: Dict[str, Dict[str, Dict[str, Any]]] = { +_metadata: dict[str, dict[str, dict[str, Any]]] = { "ca": { "boolean": {"units": ANY, "limits": ANY}, "integer": {"units": ANY, "limits": ANY}, @@ -249,7 +251,7 @@ def get_dtype(suffix: str) -> str: return "string" return get_internal_dtype(suffix) - def get_dtype_numpy(suffix: str) -> str: + def get_dtype_numpy(suffix: str) -> 
str: # type: ignore if "float32" in suffix: return " str: d = { "dtype": dtype, "dtype_numpy": dtype_numpy, - "shape": [len(value)] if dtype == "array" else [], + "shape": [len(value)] if dtype == "array" else [], # type: ignore } if get_internal_dtype(suffix) == "enum": if issubclass(type(value), Enum): - d["choices"] = [e.value for e in type(value)] + d["choices"] = [e.value for e in type(value)] # type: ignore else: - d["choices"] = list(value.choices) + d["choices"] = list(value.choices) # type: ignore d.update(_metadata[protocol].get(get_internal_dtype(suffix), {})) - return d + return d # type: ignore ls1 = "a string that is just longer than forty characters" @@ -396,11 +398,11 @@ def get_dtype_numpy(suffix: str) -> str: ) async def test_backend_get_put_monitor( ioc: IOC, - datatype: Type[T], + datatype: type[T], suffix: str, initial_value: T, put_value: T, - tmp_path, + tmp_path: Path, supported_backends: set[str], ): # ca can't support all the types @@ -413,7 +415,7 @@ async def test_backend_get_put_monitor( await assert_monitor_then_put( ioc, suffix, - datakey(ioc.protocol, suffix, initial_value), + datakey(ioc.protocol, suffix, initial_value), # type: ignore initial_value, put_value, datatype, @@ -422,7 +424,7 @@ async def test_backend_get_put_monitor( await assert_monitor_then_put( ioc, suffix, - datakey(ioc.protocol, suffix, put_value), + datakey(ioc.protocol, suffix, put_value), # type: ignore put_value, initial_value, datatype=None, @@ -435,7 +437,7 @@ async def test_backend_get_put_monitor( @pytest.mark.parametrize("suffix", ["bool", "bool_unnamed"]) -async def test_bool_conversion_of_enum(ioc: IOC, suffix: str, tmp_path) -> None: +async def test_bool_conversion_of_enum(ioc: IOC, suffix: str, tmp_path: Path) -> None: """Booleans are converted to Short Enumerations with values 0,1 as database does not support boolean natively. The flow of test_backend_get_put_monitor Gets a value with a dtype of None: we @@ -619,7 +621,18 @@ async def test_pva_table(ioc: IOC) -> None: enum=[MyEnum.c, MyEnum.b], ) # TODO: what should this be for a variable length table? 
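# Aside (illustrative, not part of this patch): the expected datakey below
# grows dtype_numpy and limits because DataKey now comes from event_model;
# for an NTTable the converter still reports dtype="object" (see the
# PvaTableConverter note earlier), so dtype_numpy stays "" and every limit
# is None. In a live session the same structure comes back from
#
#     await backend.get_datakey("test-source")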
- datakey = {"dtype": "object", "shape": [], "source": "test-source"} + datakey = { + "dtype": "object", + "shape": [], + "source": "test-source", + "dtype_numpy": "", + "limits": { + "alarm": {"high": None, "low": None}, + "control": {"high": None, "low": None}, + "display": {"high": None, "low": None}, + "warning": {"high": None, "low": None}, + }, + } # Make and connect the backend for t, i, p in [(MyTable, initial, put), (None, put, initial)]: backend = await ioc.make_backend(t, "table") @@ -627,7 +640,7 @@ async def test_pva_table(ioc: IOC) -> None: q = MonitorQueue(backend) try: # Check datakey - datakey == await backend.get_datakey("test-source") + assert datakey == await backend.get_datakey("test-source") # Check initial value await q.assert_updates(approx_table(i)) # Put to new value and check that @@ -642,7 +655,7 @@ async def test_pvi_structure(ioc: IOC) -> None: # CA can't do structure return # Make and connect the backend - backend = await ioc.make_backend(Dict[str, Any], "pvi") + backend = await ioc.make_backend(dict[str, Any], "pvi") # Make a monitor queue that will monitor for updates q = MonitorQueue(backend) @@ -726,8 +739,11 @@ def test_make_backend_fails_for_different_transports(): with pytest.raises(TypeError) as err: epics_signal_rw(str, read_pv, write_pv) - assert err.args[0] == f"Differing transports: {read_pv} has EpicsTransport.ca," - +" {write_pv} has EpicsTransport.pva" + assert ( + err.args[0] + == f"Differing transports: {read_pv} has EpicsTransport.ca," + + " {write_pv} has EpicsTransport.pva" + ) def test_signal_helpers(): diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index 51d778007d..12fdfd8c57 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -1,14 +1,13 @@ import asyncio import time -from typing import Dict from unittest.mock import AsyncMock, MagicMock, Mock, call import pytest from bluesky.protocols import Reading from ophyd_async.core import ( + CALCULATE_TIMEOUT, AsyncStatus, - CalculateTimeout, DeviceCollector, MockSignalBackend, SignalRW, @@ -184,7 +183,7 @@ async def test_read_motor(sim_motor: motor.Motor): async def test_set_velocity(sim_motor: motor.Motor) -> None: v = sim_motor.velocity - q: asyncio.Queue[Dict[str, Reading]] = asyncio.Queue() + q: asyncio.Queue[dict[str, Reading]] = asyncio.Queue() v.subscribe(q.put_nowait) assert (await q.get())["sim_motor-velocity"]["value"] == 1.0 await v.set(2.0) @@ -316,7 +315,7 @@ async def test_kickoff(sim_motor: motor.Motor): await sim_motor.kickoff() sim_motor._fly_completed_position = 20 await sim_motor.kickoff() - sim_motor.set.assert_called_once_with(20, timeout=CalculateTimeout) + sim_motor.set.assert_called_once_with(20, timeout=CALCULATE_TIMEOUT) async def test_complete(sim_motor: motor.Motor) -> None: diff --git a/tests/fastcs/panda/test_hdf_panda.py b/tests/fastcs/panda/test_hdf_panda.py index ab5acd5c20..41f50d648d 100644 --- a/tests/fastcs/panda/test_hdf_panda.py +++ b/tests/fastcs/panda/test_hdf_panda.py @@ -1,5 +1,4 @@ import os -from typing import Dict from unittest.mock import ANY import bluesky.plan_stubs as bps @@ -147,7 +146,7 @@ def flying_plan(): ) # test descriptor - data_key_names: Dict[str, str] = docs["descriptor"][0]["object_keys"]["panda"] + data_key_names: dict[str, str] = docs["descriptor"][0]["object_keys"]["panda"] assert data_key_names == ["x", "y"] for data_key_name in data_key_names: assert ( @@ -157,25 +156,22 @@ def flying_plan(): # test stream resources for dataset_name, stream_resource, data_key_name in zip( - ("x", "y"), 
docs["stream_resource"], data_key_names + ("x", "y"), docs["stream_resource"], data_key_names, strict=False ): - - def assert_resource_document(): - assert stream_resource == { - "run_start": docs["start"][0]["uid"], - "uid": ANY, - "data_key": data_key_name, - "mimetype": "application/x-hdf5", - "uri": "file://localhost" + str(tmp_path / "test-panda.h5"), - "parameters": { - "dataset": f"/{dataset_name}", - "swmr": False, - "multiplier": 1, - }, - } - assert "test-panda.h5" in stream_resource["uri"] - - assert_resource_document() + assert stream_resource == { + "run_start": docs["start"][0]["uid"], + "uid": ANY, + "data_key": data_key_name, + "mimetype": "application/x-hdf5", + "uri": "file://localhost" + str(tmp_path / "test-panda.h5"), + "parameters": { + "dataset": f"/{dataset_name}", + "swmr": False, + "multiplier": 1, + "chunk_shape": (1024,), + }, + } + assert "test-panda.h5" in stream_resource["uri"] # test stream datum for stream_datum in docs["stream_datum"]: @@ -228,7 +224,7 @@ def flying_plan(): yield from bps.declare_stream(mock_hdf_panda, name="main_stream", collect=True) - for i in range(iteration): + for _ in range(iteration): set_mock_value(flyer.trigger_logic.seq.active, 1) yield from bps.kickoff(flyer, wait=True) yield from bps.kickoff(mock_hdf_panda) @@ -267,7 +263,7 @@ def flying_plan(): ) # test descriptor - data_key_names: Dict[str, str] = docs["descriptor"][0]["object_keys"]["panda"] + data_key_names: dict[str, str] = docs["descriptor"][0]["object_keys"]["panda"] assert data_key_names == ["x", "y"] for data_key_name in data_key_names: assert ( @@ -277,25 +273,22 @@ def flying_plan(): # test stream resources for dataset_name, stream_resource, data_key_name in zip( - ("x", "y"), docs["stream_resource"], data_key_names + ("x", "y"), docs["stream_resource"], data_key_names, strict=False ): - - def assert_resource_document(): - assert stream_resource == { - "run_start": docs["start"][0]["uid"], - "uid": ANY, - "data_key": data_key_name, - "mimetype": "application/x-hdf5", - "uri": "file://localhost" + str(tmp_path / "test-panda.h5"), - "parameters": { - "dataset": f"/{dataset_name}", - "swmr": False, - "multiplier": 1, - }, - } - assert "test-panda.h5" in stream_resource["uri"] - - assert_resource_document() + assert stream_resource == { + "run_start": docs["start"][0]["uid"], + "uid": ANY, + "data_key": data_key_name, + "mimetype": "application/x-hdf5", + "uri": "file://localhost" + str(tmp_path / "test-panda.h5"), + "parameters": { + "dataset": f"/{dataset_name}", + "swmr": False, + "multiplier": 1, + "chunk_shape": (1024,), + }, + } + assert "test-panda.h5" in stream_resource["uri"] # test stream datum for stream_datum in docs["stream_datum"]: diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index b6dcbe2b00..61cad0dddc 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -1,7 +1,6 @@ """Used to test setting up signals for a PandA""" import copy -from typing import Dict import numpy as np import pytest @@ -33,7 +32,7 @@ def todict(self): class MockPvi: - def __init__(self, pvi: Dict[str, _PVIEntry]) -> None: + def __init__(self, pvi: dict[str, _PVIEntry]) -> None: self.pvi = pvi def get(self, item: str): @@ -41,7 +40,7 @@ def get(self, item: str): class MockCtxt: - def __init__(self, pvi: Dict[str, _PVIEntry]) -> None: + def __init__(self, pvi: dict[str, _PVIEntry]) -> None: self.pvi = copy.copy(pvi) def get(self, pv: str, timeout: float = 0.0): diff --git 
diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py
index d8b9a01269..b10e076a0f 100644
--- a/tests/fastcs/panda/test_panda_utils.py
+++ b/tests/fastcs/panda/test_panda_utils.py
@@ -58,7 +58,7 @@ def check_equal_with_seq_tables(actual, expected):
     )

     # Load the YAML content as a string
-    with open(str(tmp_path / "panda.yaml"), "r") as file:
+    with open(str(tmp_path / "panda.yaml")) as file:
         yaml_content = file.read()

     # Parse the YAML content
diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py
index c115ea1865..ed963c91f5 100644
--- a/tests/fastcs/panda/test_table.py
+++ b/tests/fastcs/panda/test_table.py
@@ -200,19 +200,25 @@ def _assert_col_equal(column1, column2):
         assert all(isinstance(x, SeqTrigger) for x in column2)

     seq_table_from_pva_dict = SeqTable(**pva_dict)
-    for (_, column1), column2 in zip(seq_table_from_pva_dict, pva_dict.values()):
+    for (_, column1), column2 in zip(
+        seq_table_from_pva_dict, pva_dict.values(), strict=False
+    ):
         _assert_col_equal(column1, column2)

     seq_table_from_rows = reduce(
         lambda x, y: x + y,
         [SeqTable.row(**row_kwargs) for row_kwargs in row_wise_dicts],
     )
-    for (_, column1), column2 in zip(seq_table_from_rows, pva_dict.values()):
+    for (_, column1), column2 in zip(
+        seq_table_from_rows, pva_dict.values(), strict=False
+    ):
         _assert_col_equal(column1, column2)

     # Idempotency
     applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python")
-    for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()):
+    for column1, column2 in zip(
+        applied_twice_to_pva_dict.values(), pva_dict.values(), strict=False
+    ):
         _assert_col_equal(column1, column2)

     assert np.array_equal(
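The `seq_table_from_rows` context above builds one table by folding the `SeqTable.row(...)` results together with `functools.reduce` and `+`. The same fold works for any type that defines `__add__`; a toy version using plain lists in place of single-row tables (the names are illustrative):

```python
from functools import reduce

# Each "row" is a one-element list standing in for a single-row table
rows = [[1], [2], [3]]

# reduce folds pairwise with +, mirroring the SeqTable.row concatenation above
table = reduce(lambda acc, row: acc + row, rows)
assert table == [1, 2, 3]
```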
diff --git a/tests/fastcs/panda/test_writer.py b/tests/fastcs/panda/test_writer.py
index dc26787cbf..e7298568d1 100644
--- a/tests/fastcs/panda/test_writer.py
+++ b/tests/fastcs/panda/test_writer.py
@@ -142,7 +142,7 @@ async def test_open_returns_correct_descriptors(
         assert "DATASETS table is empty!" in caplog.text

     for key, entry, expected_key in zip(
-        description.keys(), description.values(), table["name"]
+        description.keys(), description.values(), table["name"], strict=False
     ):
         assert key == expected_key
         assert entry == {
@@ -209,7 +209,12 @@ def assert_resource_document(name, resource_doc):
         "data_key": name,
         "mimetype": "application/x-hdf5",
         "uri": "file://localhost" + str(tmp_path / "mock_panda" / "data.h5"),
-        "parameters": {"dataset": f"/{name}", "swmr": False, "multiplier": 1},
+        "parameters": {
+            "dataset": f"/{name}",
+            "swmr": False,
+            "multiplier": 1,
+            "chunk_shape": (1024,),
+        },
     }
     assert "mock_panda/data.h5" in resource_doc["uri"]
diff --git a/tests/plan_stubs/test_fly.py b/tests/plan_stubs/test_fly.py
index 06e40e29db..866ee5763b 100644
--- a/tests/plan_stubs/test_fly.py
+++ b/tests/plan_stubs/test_fly.py
@@ -1,12 +1,12 @@
 import time
-from typing import AsyncGenerator, AsyncIterator, Dict, Optional, Sequence
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from unittest.mock import Mock

 import bluesky.plan_stubs as bps
 import pytest
-from bluesky.protocols import DataKey, StreamAsset
+from bluesky.protocols import StreamAsset
 from bluesky.run_engine import RunEngine
-from event_model import ComposeStreamResourceBundle, compose_stream_resource
+from event_model import ComposeStreamResourceBundle, DataKey, compose_stream_resource

 from ophyd_async.core import (
     DEFAULT_TIMEOUT,
@@ -42,12 +42,12 @@ def __init__(self, name: str, shape: Sequence[int]):
         self.dummy_signal = epics_signal_rw(int, "pva://read_pv")
         self._shape = shape
         self._name = name
-        self._file: Optional[ComposeStreamResourceBundle] = None
+        self._file: ComposeStreamResourceBundle | None = None
         self._last_emitted = 0
         self.index = 0
         self.observe_indices_written_timeout_log = []

-    async def open(self, multiplier: int = 1) -> Dict[str, DataKey]:
+    async def open(self, multiplier: int = 1) -> dict[str, DataKey]:
         return {
             self._name: DataKey(
                 source="soft://some-source",
@@ -264,7 +264,7 @@ def flying_plan():
         )

         for detector in detector_list:
-            detector.controller.disarm.assert_called_once  # type: ignore
+            detector.controller.disarm.assert_called_once()  # type: ignore

     yield from bps.open_run()
     yield from bps.declare_stream(*detector_list, name="main_stream", collect=True)
diff --git a/tests/sim/test_sim_writer.py b/tests/sim/test_sim_writer.py
index 89f2540e86..5ee2bba97a 100644
--- a/tests/sim/test_sim_writer.py
+++ b/tests/sim/test_sim_writer.py
@@ -21,13 +21,13 @@ async def test_correct_descriptor_doc_after_open(writer: PatternDetectorWriter):
     assert descriptor == {
         "NAME": {
             "source": "sim://pattern-generator-hdf-file",
-            "shape": (240, 320),
+            "shape": [240, 320],
             "dtype": "array",
             "external": "STREAM:",
         },
         "NAME-sum": {
             "source": "sim://pattern-generator-hdf-file",
-            "shape": (),
+            "shape": [],
             "dtype": "number",
             "external": "STREAM:",
         },
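Finally, the `test_fly` hunk above fixes a classic `unittest.mock` trap: without parentheses, `assert_called_once` is a plain attribute lookup that asserts nothing, so the old test passed whether or not `disarm` was ever called. A minimal demonstration:

```python
from unittest.mock import Mock

mock = Mock()

# Pitfall: attribute access only; no call, no assertion, nothing can fail here
mock.assert_called_once

# The real assertion: raises AssertionError because mock was never called
try:
    mock.assert_called_once()
except AssertionError:
    print("caught the missing call")
```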