Skip to content

Commit

Permalink
Support inject for default plan arguments again
Browse files Browse the repository at this point in the history
When preparing the parameters for a plan, pydantic creates an instance
of the dynamically generated BaseModel, relying on the 'Reference' types
to convert from strings to instances of the devices. This includes
strings that are used as default field values (default parameters
created using the `inject` method).

We then convert the model back to a dictionary to pass as kwargs to the
plan. The previously used `model_fields_set` field on the model only
iterates over the fields set via the input JSON (the user supplied
arguments) and skips the fields populated via the default factories in
the base model. Using the `__pydantic_fields__` class variable, allows
all fields to be used including the defaults. This was previously
avoided as it generated warnings for unknown types but it is possible to
opt-out of these warnings.
  • Loading branch information
tpoliaw authored and callumforrester committed Jan 15, 2025
1 parent 3e36c8c commit de4b5e2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
22 changes: 2 additions & 20 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class Task(BlueapiBaseModel):

def prepare_params(self, ctx: BlueskyContext) -> Mapping[str, Any]:
model = _lookup_params(ctx, self)
return _model_to_kwargs(model)
# Re-create dict manually to avoid nesting in model_dump output
return {field: getattr(model, field) for field in model.__pydantic_fields__}

def do_task(self, ctx: BlueskyContext) -> None:
LOGGER.info(f"Asked to run plan {self.name} with {self.params}")
Expand All @@ -49,22 +50,3 @@ def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel:
model = plan.model
adapter = TypeAdapter(model)
return adapter.validate_python(task.params)


def _model_to_kwargs(model: BaseModel) -> Mapping[str, Any]:
"""
Converts an instance of BaseModel back to a dictionary that
can be passed as **kwargs.
Used instead of BaseModel.model_dump() because we don't want
the dumping to be nested and because it fires UserWarnings
about data types it is unfamiliar with
(such as ophyd devices).
Args:
model: Pydantic model to convert to kwargs
Returns:
Mapping[str, Any]: Dictionary that can be passed as **kwargs
"""

return {name: getattr(model, name) for name in model.model_fields_set}
23 changes: 23 additions & 0 deletions tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,26 @@ def test_begin_task_span_ok(
task_id = worker.submit_task(_SIMPLE_TASK)
with asserting_span_exporter(exporter, "begin_task", "task_id"):
worker.begin_task(task_id)


def test_injected_devices_are_found(
fake_device: FakeDevice,
context: BlueskyContext,
):
def injected_device_plan(dev: FakeDevice = fake_device.name) -> MsgGenerator: # type: ignore
yield from ()

context.register_plan(injected_device_plan)
params = Task(name="injected_device_plan").prepare_params(context)
assert params["dev"] == fake_device


def test_missing_injected_devices_fail_early(
context: BlueskyContext,
):
def missing_injection(dev: FakeDevice = "does_not_exist") -> MsgGenerator: # type: ignore
yield from ()

context.register_plan(missing_injection)
with pytest.raises(ValueError):
Task(name="missing_injection").prepare_params(context)

0 comments on commit de4b5e2

Please sign in to comment.