diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 91ecbfcda..02ad15a20 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -22,7 +22,8 @@ class Task(BlueapiBaseModel): def prepare_params(self, ctx: BlueskyContext) -> Mapping[str, Any]: model = _lookup_params(ctx, self) - return _model_to_kwargs(model) + # Re-create dict manually to avoid nesting in model_dump output + return {field: getattr(model, field) for field in model.__pydantic_fields__} def do_task(self, ctx: BlueskyContext) -> None: LOGGER.info(f"Asked to run plan {self.name} with {self.params}") @@ -49,22 +50,3 @@ def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: model = plan.model adapter = TypeAdapter(model) return adapter.validate_python(task.params) - - -def _model_to_kwargs(model: BaseModel) -> Mapping[str, Any]: - """ - Converts an instance of BaseModel back to a dictionary that - can be passed as **kwargs. - Used instead of BaseModel.model_dump() because we don't want - the dumping to be nested and because it fires UserWarnings - about data types it is unfamiliar with - (such as ophyd devices). - - Args: - model: Pydantic model to convert to kwargs - - Returns: - Mapping[str, Any]: Dictionary that can be passed as **kwargs - """ - - return {name: getattr(model, name) for name in model.model_fields_set} diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 1eae7280c..d593821dc 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -575,3 +575,26 @@ def test_begin_task_span_ok( task_id = worker.submit_task(_SIMPLE_TASK) with asserting_span_exporter(exporter, "begin_task", "task_id"): worker.begin_task(task_id) + + +def test_injected_devices_are_found( + fake_device: FakeDevice, + context: BlueskyContext, +): + def injected_device_plan(dev: FakeDevice = fake_device.name) -> MsgGenerator: # type: ignore + yield from () + + context.register_plan(injected_device_plan) + params = Task(name="injected_device_plan").prepare_params(context) + assert params["dev"] == fake_device + + +def test_missing_injected_devices_fail_early( + context: BlueskyContext, +): + def missing_injection(dev: FakeDevice = "does_not_exist") -> MsgGenerator: # type: ignore + yield from () + + context.register_plan(missing_injection) + with pytest.raises(ValueError): + Task(name="missing_injection").prepare_params(context)