From e36409b0e0180077a2642bfbba716a1cf75b765b Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 17:14:32 -0800 Subject: [PATCH] `mypy` finishing touches (#562) * Make class "property" a plain method Signed-off-by: liamhuber * Refactor to non-None classvar Signed-off-by: liamhuber * Explicitly cast to tuple For the sake of the name generator Signed-off-by: liamhuber * Silence mypy It is upset about the hinting `list[hint]`, because `hint` is variable. We would be able to verify that it is a type at static analysis time, but since it comes from the body node class -- which is unknown until runtime -- it is impossible to say _which_ type. Signed-off-by: liamhuber * Uniformly give and ignore classfactory hints At a minimum, getting mypy to parse these correctly requires more rigorous hinting in pyiron_snippets.factory. But actually, since the classfactory allows the parent class to be specified with _multiple bases_, I'm not even 100% sure we'd ever be able to get a single type variable that could do the trick universally. In any case, for now kick the can don't the road and always hint what you know is true, then tell mypy to not worry about it. Signed-off-by: liamhuber * Add some hints to preview Albeit pretty relaxed ones Signed-off-by: liamhuber * Add a return hint to Runnable.__init__ To get mypy to parse the body of the function Signed-off-by: liamhuber * Break loop into a method Mypy didn't like parsing the zip variable when it could be inputs or outputs (even though both inherit from the relevant DataIO in this case), but using a separate method is functionally equivalent and mypy can get a better grasp of the type values. Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/io.py | 55 +++++++++++++++++------------- pyiron_workflow/mixin/preview.py | 6 ++-- pyiron_workflow/mixin/run.py | 2 +- pyiron_workflow/nodes/for_loop.py | 4 +-- pyiron_workflow/nodes/function.py | 4 +-- pyiron_workflow/nodes/macro.py | 4 +-- pyiron_workflow/nodes/transform.py | 10 +++--- 7 files changed, 47 insertions(+), 38 deletions(-) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index ecef56ce..4d983996 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -572,30 +572,39 @@ def _copy_values( list[tuple[Channel, Any]]: A list of tuples giving channels whose value has been updated and what it used to be (for reverting changes). """ + # Leverage a separate function because mypy has trouble parsing types + # if we loop over inputs and outputs at the same time + return self._copy_panel( + other, self.inputs, other.inputs, fail_hard=fail_hard + ) + self._copy_panel(other, self.outputs, other.outputs, fail_hard=fail_hard) + + def _copy_panel( + self, + other: HasIO, + my_panel: DataIO, + other_panel: DataIO, + fail_hard: bool = False, + ) -> list[tuple[DataChannel, Any]]: old_values = [] - for my_panel, other_panel in [ - (self.inputs, other.inputs), - (self.outputs, other.outputs), - ]: - for key, to_copy in other_panel.items(): - if to_copy.value is not NOT_DATA: - try: - old_value = my_panel[key].value - my_panel[key].value = to_copy.value # Gets hint-checked - old_values.append((my_panel[key], old_value)) - except Exception as e: - if fail_hard: - # If you run into trouble, unwind what you've done - for channel, value in old_values: - channel.value = value - raise ValueCopyError( - f"{self.label} could not copy values from " - f"{other.label} due to the channel {key} on " - f"{other_panel.__class__.__name__}, which holds value " - f"{to_copy.value}" - ) from e - else: - continue + for key, to_copy in other_panel.items(): + if to_copy.value is not NOT_DATA: + try: + old_value = my_panel[key].value + my_panel[key].value = to_copy.value # Gets hint-checked + old_values.append((my_panel[key], old_value)) + except Exception as e: + if fail_hard: + # If you run into trouble, unwind what you've done + for channel, value in old_values: + channel.value = value + raise ValueCopyError( + f"{self.label} could not copy values from " + f"{other.label} due to the channel {key} on " + f"{other_panel.__class__.__name__}, which holds value " + f"{to_copy.value}" + ) from e + else: + continue return old_values @property diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index 7ae1a1c1..97695bdb 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -76,7 +76,7 @@ def preview_outputs(cls) -> dict[str, Any]: return cls._build_outputs_preview() @classmethod - def preview_io(cls) -> DotDict[str, dict]: + def preview_io(cls) -> DotDict[str, dict[str, Any | tuple[Any, Any]]]: return DotDict( {"inputs": cls.preview_inputs(), "outputs": cls.preview_outputs()} ) @@ -124,7 +124,7 @@ def _io_defining_function(cls) -> Callable: ) @classmethod - def _build_inputs_preview(cls): + def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: type_hints = cls._get_type_hints() scraped: dict[str, tuple[Any, Any]] = {} for i, (label, value) in enumerate(cls._get_input_args().items()): @@ -152,7 +152,7 @@ def _build_inputs_preview(cls): return scraped @classmethod - def _build_outputs_preview(cls): + def _build_outputs_preview(cls) -> dict[str, Any]: if cls._validate_output_labels: cls._validate() # Validate output on first call diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index ac7aae86..aa40bd67 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -50,7 +50,7 @@ class Runnable(UsesState, HasLabel, HasRun, ABC): new keyword arguments. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.running: bool = False self.failed: bool = False diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index 6d1d4150..bc408821 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -509,7 +509,7 @@ def for_node_factory( output_column_map: dict | None = None, use_cache: bool = True, /, -): +) -> type[For]: combined_docstring = ( "For node docstring:\n" + (For.__doc__ if For.__doc__ is not None else "") @@ -520,7 +520,7 @@ def for_node_factory( iter_on = (iter_on,) if isinstance(iter_on, str) else iter_on zip_on = (zip_on,) if isinstance(zip_on, str) else zip_on - return ( + return ( # type: ignore[return-value] _for_node_class_name(body_node_class, iter_on, zip_on, output_as_dataframe), (For,), { diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index 8000a6d9..877f3b28 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -357,7 +357,7 @@ def function_node_factory( use_cache: bool = True, /, *output_labels, -): +) -> type[Function]: """ Create a new :class:`Function` node class based on the given node function. This function gets executed on each :meth:`run` of the resulting function. @@ -373,7 +373,7 @@ def function_node_factory( Returns: type[Node]: A new node class. """ - return ( + return ( # type: ignore[return-value] node_function.__name__, (Function,), # Define parentage { diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index fde7ab66..ec09f8e8 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -475,7 +475,7 @@ def macro_node_factory( use_cache: bool = True, /, *output_labels: str, -): +) -> type[Macro]: """ Create a new :class:`Macro` subclass using the given graph creator function. @@ -491,7 +491,7 @@ def macro_node_factory( Returns: type[Macro]: A new :class:`Macro` subclass. """ - return ( + return ( # type: ignore[return-value] graph_creator.__name__, (Macro,), # Define parentage { diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index aed09074..6a4371be 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -108,7 +108,7 @@ def _build_outputs_preview(cls) -> dict[str, Any]: @classfactory def inputs_to_list_factory(n: int, use_cache: bool = True, /) -> type[InputsToList]: - return ( + return ( # type: ignore[return-value] f"{InputsToList.__name__}{n}", (InputsToList,), { @@ -142,7 +142,7 @@ def inputs_to_list(n: int, /, *node_args, use_cache: bool = True, **node_kwargs) @classfactory def list_to_outputs_factory(n: int, use_cache: bool = True, /) -> type[ListToOutputs]: - return ( + return ( # type: ignore[return-value] f"{ListToOutputs.__name__}{n}", (ListToOutputs,), { @@ -231,7 +231,7 @@ def inputs_to_dict_factory( class_name_suffix = str( InputsToDict.hash_specification(input_specification) ).replace("-", "m") - return ( + return ( # type: ignore[return-value] f"{InputsToDict.__name__}{class_name_suffix}", (InputsToDict,), { @@ -307,7 +307,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: def inputs_to_dataframe_factory( n: int, use_cache: bool = True, / ) -> type[InputsToDataframe]: - return ( + return ( # type: ignore[return-value] f"{InputsToDataframe.__name__}{n}", (InputsToDataframe,), { @@ -403,7 +403,7 @@ def dataclass_node_factory( # Composition is preferable over inheritance, but we want inheritance to be possible module, qualname = dataclass.__module__, dataclass.__qualname__ dataclass.__qualname__ += ".dataclass" # So output type hints know where to find it - return ( + return ( # type: ignore[return-value] dataclass.__name__, (DataclassNode,), {