diff --git a/pyproject.toml b/pyproject.toml index dde5b0816..c2d7fa354 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,10 @@ classifiers = [ ] description = "Lightweight bluesky-as-a-service wrapper application. Also usable as a library." dependencies = [ + "tiled", + "json_merge_patch", + "jsonpatch", + "pyarrow", "bluesky>=1.13", "ophyd", "nslsii", @@ -95,7 +99,8 @@ addopts = """ --ignore=src/blueapi/startup """ # https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings -filterwarnings = ["error", "ignore::DeprecationWarning"] +# Unignore UserWarning after Pydantic warning removed from bluesky/bluesky and release +filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::UserWarning"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" asyncio_mode = "auto" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e4581a663..1055eea8b 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -41,6 +41,15 @@ class StompConfig(BaseModel): auth: BasicAuthentication | None = None +class TiledConfig(BaseModel): + """ + Config for connecting to a tiled instance + """ + + host: str + port: int + + class WorkerEventConfig(BlueapiBaseModel): """ Config for event broadcasting via the message bus @@ -160,6 +169,7 @@ class ApplicationConfig(BlueapiBaseModel): """ stomp: StompConfig | None = None + tiled: TiledConfig | None = None env: EnvironmentConfig = Field(default_factory=EnvironmentConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 803841964..162064c14 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -13,6 +13,7 @@ from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask +from blueapi.worker.tiled import TiledConnection """This module provides interface between web application and underlying Bluesky context and worker""" @@ -40,9 +41,11 @@ def context() -> BlueskyContext: @cache def worker() -> TaskWorker: + conf = config() worker = TaskWorker( context(), - broadcast_statuses=config().env.events.broadcast_status_events, + broadcast_statuses=conf.env.events.broadcast_status_events, + tiled_inserter=TiledConnection(conf.tiled) if conf.tiled else None, ) worker.start() return worker diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 546c5f3b5..67c22119c 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -9,6 +9,7 @@ from typing import Any, Generic, TypeVar from bluesky.protocols import Status +from httpx import Headers from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -32,6 +33,7 @@ from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.utils.base_model import BlueapiBaseModel from blueapi.utils.thread_exception import handle_all_exceptions +from blueapi.worker.tiled import TiledConnection from .event import ( ProgressEvent, @@ -112,9 +114,11 @@ def __init__( ctx: BlueskyContext, start_stop_timeout: float = DEFAULT_START_STOP_TIMEOUT, broadcast_statuses: bool = True, + tiled_inserter: TiledConnection | None = None, ) -> None: self._ctx = ctx self._start_stop_timeout = start_stop_timeout + self._tiled_inserter = tiled_inserter self._tasks = {} @@ -194,13 +198,25 @@ def get_active_task(self) -> TrackableTask[Task] | None: return current @start_as_current_span(TRACER, "task_id") - def begin_task(self, task_id: str) -> None: + def begin_task(self, task_id: str, headers: Headers | None) -> None: task = self._tasks.get(task_id) + data_subs: list[int] = [] if task is not None: - self._submit_trackable_task(task) + if self._tiled_inserter: + data_subs.append(self._authorize_running_task(headers)) + self._submit_trackable_task(task, data_subs) + else: raise KeyError(f"No pending task with ID {task_id}") + def _authorize_running_task(self, headers: Headers | None) -> int: + assert self._tiled_inserter + # https://github.com/DiamondLightSource/blueapi/issues/774 + # If users should only be able to run their own scans, pass headers + # as part of submitting a task, cache in TrackableTask field and check + # that token belongs to same user (but may be newer token!) + return self.data_events.subscribe(self._tiled_inserter(headers)) + @start_as_current_span(TRACER, "task.name", "task.params") def submit_task(self, task: Task) -> str: task.prepare_params(self._ctx) # Will raise if parameters are invalid @@ -218,7 +234,9 @@ def submit_task(self, task: Task) -> str: "trackable_task.task.name", "trackable_task.task.params", ) - def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + def _submit_trackable_task( + self, trackable_task: TrackableTask, data_subs: list[int] | None = None + ) -> None: if self.state is not WorkerState.IDLE: raise WorkerBusyError(f"Worker is in state {self.state}") @@ -235,17 +253,18 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: sub = self.worker_events.subscribe(mark_task_as_started) try: self._current_task_otel_context = get_current() - sub = self.worker_events.subscribe(mark_task_as_started) """ Cache the current trace context as the one for this task id """ self._task_channel.put_nowait(trackable_task) - task_started.wait(timeout=5.0) - if not task_started.is_set(): + if not task_started.wait(timeout=5.0): raise TimeoutError("Failed to start plan within timeout") except Full as f: LOGGER.error("Cannot submit task while another is running") raise WorkerBusyError("Cannot submit task while another is running") from f finally: self.worker_events.unsubscribe(sub) + if data_subs: + for data_sub in data_subs: + self.data_events.unsubscribe(data_sub) @start_as_current_span(TRACER) def start(self) -> None: diff --git a/src/blueapi/worker/tiled.py b/src/blueapi/worker/tiled.py new file mode 100644 index 000000000..816f0156a --- /dev/null +++ b/src/blueapi/worker/tiled.py @@ -0,0 +1,23 @@ +from bluesky.callbacks.tiled_writer import TiledWriter +from httpx import Headers +from tiled.client import from_context +from tiled.client.context import Context as TiledContext + +from blueapi.config import TiledConfig +from blueapi.core.bluesky_types import DataEvent + + +class TiledConverter: + def __init__(self, tiled_context: TiledContext): + self._writer: TiledWriter = TiledWriter(from_context(tiled_context)) + + def __call__(self, data: DataEvent, _: str | None = None) -> None: + self._writer(data.name, data.doc) + + +class TiledConnection: + def __init__(self, config: TiledConfig): + self.uri = f"{config.host}:{config.port}" + + def __call__(self, headers: Headers | None): + return TiledConverter(TiledContext(self.uri, headers=headers)) diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 67517a691..3caf99a19 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -224,6 +224,7 @@ def temp_yaml_config_file( "logging": {"level": "INFO"}, "api": {"host": "0.0.0.0", "port": 8000, "protocol": "http"}, "scratch": None, + "tiled": None, }, ], indirect=True, @@ -285,6 +286,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "tiled": None, }, { "stomp": { @@ -318,6 +320,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "tiled": None, }, ], indirect=True,