Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate node and graph parameters type #84

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions qualibrate/qualibration_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(
*,
modes: Optional[RunModes] = None,
):
if not isinstance(parameters, GraphParameters):
raise ValueError("Graph parameters must be of type GraphParameters")
super().__init__(name, parameters, description=description, modes=modes)
self._nodes = self._validate_nodes_names_mapping(nodes)
self._connectivity = connectivity
Expand Down
73 changes: 25 additions & 48 deletions qualibrate/qualibration_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@
external_parameters_ctx: ContextVar[Optional[tuple[str, Any]]] = ContextVar(
"external_parameters", default=None
)
last_executed_node_ctx: ContextVar[Optional["QualibrationNode[Any]"]] = (
ContextVar("last_executed_node", default=None)
last_executed_node_ctx: ContextVar[Optional["QualibrationNode[Any]"]] = ContextVar(
"last_executed_node", default=None
)


Expand All @@ -99,9 +99,7 @@ class QualibrationNode(
StopInspection: Raised if the node is instantiated in inspection mode.
"""

storage_manager: Optional[
StorageManager["QualibrationNode[NodeParameters]"]
] = None
storage_manager: Optional[StorageManager["QualibrationNode[NodeParameters]"]] = None
active_node: Optional["QualibrationNode[ParametersType]"] = None

def __init__(
Expand All @@ -128,9 +126,7 @@ def __init__(
self.machine = None

if self.modes.inspection:
raise StopInspection(
"Node instantiated in inspection mode", instance=self
)
raise StopInspection("Node instantiated in inspection mode", instance=self)
self.__class__.active_node = self
last_executed_node_ctx.set(self)

Expand All @@ -149,7 +145,9 @@ def _validate_passed_parameters_options(
parameters_class: Optional[type[ParametersType]],
) -> ParametersType:
"""
Validates passed parameters and parameters class. If parameters
Validates passed parameters and parameters class.

If parameters
passed then the instance will be used. If parameters class is passed,
an attempt will be made to instantiate it. If neither parameters nor
parameter class are passed, then the default base parameters will be
Expand All @@ -166,7 +164,10 @@ def _validate_passed_parameters_options(
Raises:
ValueError: If parameters class instantiation fails.
"""
params_type_error = ValueError("Node parameters must be of type NodeParameters")
if parameters is not None:
if not isinstance(parameters, NodeParameters):
raise params_type_error
if parameters_class is not None:
logger.warning(
"Passed both parameters and parameters_class to the node "
Expand All @@ -175,8 +176,7 @@ def _validate_passed_parameters_options(
return parameters
if parameters_class is None:
fields = {
name: copy(field)
for name, field in NodeParameters.model_fields.items()
name: copy(field) for name, field in NodeParameters.model_fields.items()
}
# Create subclass of NodeParameters. It's needed because otherwise
# there will be an issue with type checking of subclasses.
Expand All @@ -186,16 +186,15 @@ def _validate_passed_parameters_options(
__doc__=NodeParameters.__doc__,
__base__=NodeParameters,
__module__=NodeParameters.__module__,
**{
name: (info.annotation, info)
for name, info in fields.items()
},
**{name: (info.annotation, info) for name, info in fields.items()},
)
return cast(ParametersType, new_model())
logger.warning(
"parameters_class argument is deprecated. Please use "
f"parameters argument for initializing node '{name}'."
)
if not issubclass(parameters_class, NodeParameters):
raise params_type_error
try:
return parameters_class()
except ValidationError as e:
Expand Down Expand Up @@ -247,15 +246,11 @@ def copy(
f"{name = }, {node_parameters = }"
)
if name is not None and not isinstance(name, str):
raise ValueError(
f"{self.__class__.__name__} should have a string name"
)
raise ValueError(f"{self.__class__.__name__} should have a string name")
instance = self.__copy__()
if name is not None:
instance.name = name
instance._parameters = instance.parameters_class.model_validate(
node_parameters
)
instance._parameters = instance.parameters_class.model_validate(node_parameters)
instance.parameters_class = self.build_parameters_class_from_instance(
instance._parameters
)
Expand Down Expand Up @@ -328,9 +323,7 @@ def save(self) -> None:
root_data_folder=qs.storage.location,
active_machine_path=state_path,
)
self.storage_manager.save(
node=cast("QualibrationNode[NodeParameters]", self)
)
self.storage_manager.save(node=cast("QualibrationNode[NodeParameters]", self))

def _load_from_id(
self,
Expand Down Expand Up @@ -377,9 +370,7 @@ def _load_from_id(
self.machine = quam_machine
if parameters is not None:
if build_params_class:
self.parameters_class = cast(
ParametersType, parameters
).__class__
self.parameters_class = cast(ParametersType, parameters).__class__
self._parameters = cast(ParametersType, parameters)
else:
self._parameters = self.parameters.model_construct(
Expand Down Expand Up @@ -513,9 +504,7 @@ def run(
RuntimeError: Raised if the node filepath is not provided, or
execution
"""
logger.info(
f"Run node {self.name} with parameters: {passed_parameters}"
)
logger.info(f"Run node {self.name} with parameters: {passed_parameters}")
if self.filepath is None:
ex = RuntimeError(f"Node {self.name} file path was not provided")
logger.exception("", exc_info=ex)
Expand All @@ -537,9 +526,7 @@ def run(
run_modes_token = run_modes_ctx.set(
RunModes(external=True, interactive=interactive, inspection=False)
)
external_parameters_token = external_parameters_ctx.set(
(self.name, parameters)
)
external_parameters_token = external_parameters_ctx.set((self.name, parameters))
try:
self._parameters = parameters
self.run_node_file(self.filepath)
Expand All @@ -556,9 +543,7 @@ def run(
external_parameters_ctx.reset(external_parameters_token)
last_executed_node = last_executed_node_ctx.get()
if last_executed_node is None:
logger.warning(
f"Last executed node not set after running {self}"
)
logger.warning(f"Last executed node not set after running {self}")
last_executed_node = self

run_summary = self._post_run(
Expand Down Expand Up @@ -588,9 +573,7 @@ def run_node_file(self, node_filepath: Path) -> None:
# Appending dir with nodes can cause issues with relative imports
try:
matplotlib.use("agg")
_module = import_from_path(
get_module_name(node_filepath), node_filepath
)
_module = import_from_path(get_module_name(node_filepath), node_filepath)
finally:
matplotlib.use(mpl_backend)

Expand Down Expand Up @@ -695,13 +678,9 @@ def record_state_updates(
}
try:
for cls in cls_setattr_funcs:
cls.__setattr__ = partialmethod(
record_state_update_getattr, node=self
)
cls.__setattr__ = partialmethod(record_state_update_getattr, node=self)
for cls in cls_setitem_funcs:
cls.__setitem__ = partialmethod(
record_state_update_getitem, node=self
)
cls.__setitem__ = partialmethod(record_state_update_getitem, node=self)
yield
finally:
for cls, setattr_func in cls_setattr_funcs.items():
Expand Down Expand Up @@ -796,9 +775,7 @@ def add_node(
nodes: dictionary to store nodes.
"""
if node.name in nodes:
logger.warning(
f'Node "{node.name}" already exists in library, overwriting'
)
logger.warning(f'Node "{node.name}" already exists in library, overwriting')

nodes[node.name] = node

Expand Down
25 changes: 10 additions & 15 deletions tests/unit/test_qualibration_graph/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from qualibrate import QualibrationGraph, QualibrationNode
from qualibrate.models.node_status import NodeStatus
from qualibrate.models.run_mode import RunModes
from qualibrate.parameters import GraphParameters
from qualibrate.q_runnnable import QRunnable
from qualibrate.utils.exceptions import StopInspection

Expand Down Expand Up @@ -75,7 +76,7 @@ def test_init_graph_base(self, mocker, pre_setup_graph_init):
)
graph = QualibrationGraph(
name="test_graph",
parameters=MagicMock(),
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
)
Expand All @@ -100,7 +101,7 @@ def test_init_graph_with_inspection_mode(self, pre_setup_graph_init):
) as ex:
QualibrationGraph(
name="test_graph",
parameters=MagicMock(),
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
modes=RunModes(inspection=True),
Expand Down Expand Up @@ -141,13 +142,12 @@ def test_add_node_by_name(
mocked_validate_names = mocker.patch.object(
QualibrationGraph, "_validate_nodes_names_mapping", return_value={}
)
parameters = MagicMock()
mock_get_qnode = mocker.patch.object(
QualibrationGraph, "_get_qnode_or_error", return_value=node
)
graph = QualibrationGraph(
name="test_graph",
parameters=parameters,
parameters=MagicMock(spec=GraphParameters),
nodes={},
connectivity=[],
)
Expand All @@ -166,7 +166,7 @@ def test_cleanup(self, mocker, mock_orchestrator, pre_setup_graph_init):
connectivity = [("node1", "node2")]
graph = QualibrationGraph(
name="test_graph",
parameters=MagicMock(),
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
orchestrator=mock_orchestrator,
Expand All @@ -183,7 +183,6 @@ def test_cleanup(self, mocker, mock_orchestrator, pre_setup_graph_init):
def test_completed_count(self, mocker, pre_setup_graph_init):
(nodes, _, _, _) = pre_setup_graph_init
connectivity = [("node1", "node2")]
parameters = MagicMock()

mock_get_node_attributes = mocker.patch(
"qualibrate.qualibration_graph.nx.get_node_attributes",
Expand All @@ -195,7 +194,7 @@ def test_completed_count(self, mocker, pre_setup_graph_init):

graph = QualibrationGraph(
name="test_graph",
parameters=parameters,
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
)
Expand All @@ -208,12 +207,11 @@ def test_completed_count(self, mocker, pre_setup_graph_init):
def test_run_successful(self, mocker, pre_setup_graph_init):
(nodes, _, _, _) = pre_setup_graph_init
connectivity = [("node1", "node2")]
parameters = MagicMock()
orchestrator = MagicMock()

graph = QualibrationGraph(
name="test_graph",
parameters=parameters,
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
orchestrator=orchestrator,
Expand Down Expand Up @@ -243,12 +241,11 @@ def test_run_successful(self, mocker, pre_setup_graph_init):
def test_run_error(self, mocker, pre_setup_graph_init):
(nodes, _, _, _) = pre_setup_graph_init
connectivity = [("node1", "node2")]
parameters = MagicMock()
orchestrator = MagicMock()

graph = QualibrationGraph(
name="test_graph",
parameters=parameters,
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
orchestrator=orchestrator,
Expand All @@ -271,12 +268,11 @@ def test_run_error(self, mocker, pre_setup_graph_init):
def test_stop(self, pre_setup_graph_init):
(nodes, _, _, _) = pre_setup_graph_init
connectivity = [("node1", "node2")]
parameters = MagicMock()
orchestrator = MagicMock()

graph = QualibrationGraph(
name="test_graph",
parameters=parameters,
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
orchestrator=orchestrator,
Expand All @@ -295,12 +291,11 @@ def test_stop(self, pre_setup_graph_init):
def test_stop_without_active_node(self, pre_setup_graph_init):
(nodes, _, _, _) = pre_setup_graph_init
connectivity = [("node1", "node2")]
parameters = MagicMock()
orchestrator = MagicMock()

graph = QualibrationGraph(
name="test_graph",
parameters=parameters,
parameters=MagicMock(spec=GraphParameters),
nodes=nodes,
connectivity=connectivity,
orchestrator=orchestrator,
Expand Down
Loading
Loading