diff --git a/pargraph/about.py b/pargraph/about.py index 732155f..fa3ddd8 100644 --- a/pargraph/about.py +++ b/pargraph/about.py @@ -1 +1 @@ -__version__ = "0.8.3" +__version__ = "0.8.4" diff --git a/pargraph/graph/decorators.py b/pargraph/graph/decorators.py index b65ef17..004a6a8 100644 --- a/pargraph/graph/decorators.py +++ b/pargraph/graph/decorators.py @@ -6,7 +6,7 @@ import uuid import warnings from dataclasses import dataclass -from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast +from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast, Iterator from pargraph.graph.annotation import _get_output_names from pargraph.graph.objects import ( @@ -151,10 +151,13 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . nodes=graph_context._graph.nodes, outputs={OutputKey(key=output_name): graph_context._target}, ) - for output_name, graph_context in ( - zip(output_names, graph_result) - if isinstance(output_names, tuple) - else ((output_names, graph_result),) + for output_name, graph_context in cast( + Iterator[Tuple[str, GraphContext]], + ( + zip(output_names, graph_result) + if isinstance(output_names, tuple) + else ((output_names, graph_result),) + ), ) if isinstance(graph_context, GraphContext) ) @@ -162,6 +165,16 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . # Short circuit if external input is passed in (for top-level graph calls) if any(arg._target is None for arg in bound_args.arguments.values() if isinstance(arg, GraphContext)): + # Inject default values for external inputs + for name, arg in bound_args.arguments.items(): + default_value = bound_args.signature.parameters[name].default + if default_value is inspect.Parameter.empty: + continue + + const_id = f"_{uuid.uuid4().hex}" + sub_graph.consts[ConstKey(key=const_id)] = Const.from_value(default_value) + sub_graph.inputs[InputKey(key=name)] = ConstKey(key=const_id) + return sub_graph # Inject sub graph into parent graph @@ -242,7 +255,7 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . # Short circuit if external input is passed in (for top-level delayed calls) if any(graph_context._target is None for graph_context in arg_dict.values()): - return Graph( + graph_result = Graph( consts={ ConstKey(key=name): arg._graph.consts[arg._target] for name, arg in arg_dict.items() @@ -264,6 +277,18 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . }, ) + # Inject default values for external inputs + for name, arg in bound_args.arguments.items(): + default_value = bound_args.signature.parameters[name].default + if default_value is inspect.Parameter.empty: + continue + + const_id = f"_{uuid.uuid4().hex}" + graph_result.consts[ConstKey(key=const_id)] = Const.from_value(default_value) + graph_result.inputs[InputKey(key=name)] = ConstKey(key=const_id) + + return graph_result + # Inject function call node into graph node_id = f"{function.__name__}_{uuid.uuid4().hex}" merged_graph = _merge_graphs(*(arg._graph for arg in arg_dict.values())) @@ -327,9 +352,6 @@ def _generate_graph(func: Callable, /, *args, **kwargs): new_args = [] new_kwargs = {} for name, param in bound_args.signature.parameters.items(): - if param.default is not inspect.Parameter.empty: - continue - if param.kind == param.POSITIONAL_ONLY: new_args.append(bound_args.arguments.get(name, external_input())) elif param.kind in {param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY}: diff --git a/pargraph/graph/objects.py b/pargraph/graph/objects.py index 793929b..b8b0681 100644 --- a/pargraph/graph/objects.py +++ b/pargraph/graph/objects.py @@ -7,7 +7,7 @@ import warnings from collections import defaultdict, deque from dataclasses import dataclass -from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Tuple, TypedDict, Union, cast +from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Tuple, TypedDict, Union, cast, Iterator import cloudpickle import jsonschema @@ -703,7 +703,7 @@ def _peel_subgraphs(graph: Graph) -> Graph: return graph graph = self - for _ in range(depth) if depth >= 0 else itertools.count(): + for _ in cast(Iterator, range(depth) if depth >= 0 else itertools.count()): graph = _peel_subgraphs(graph) # break if there are no more subgraphs @@ -808,7 +808,10 @@ def _convert_graph_to_dask_graph( if inputs is not None: for input_key in self.inputs.keys(): graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}" - dask_graph[graph_key] = inputs[input_key.key] + # if input key is not in inputs, use the default value + dask_graph[graph_key] = ( + inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value() + ) key_to_uuid[input_key] = graph_key # assign random keys to all node paths and node output paths beforehand @@ -870,8 +873,9 @@ def _convert_graph_to_dask_graph( assert callable(node.function) output_names = _get_output_names(node.function) node_uuid = f"node_{self._get_function_name(node.function)}_{uuid.uuid4().hex}" - for output_position, output_name in ( - enumerate(output_names) if isinstance(output_names, tuple) else ((None, output_names),) + for output_position, output_name in cast( + Iterator[Tuple[Optional[int], str]], + enumerate(output_names) if isinstance(output_names, tuple) else ((None, output_names),), ): graph_key = key_to_uuid[NodeOutputKey(key=node_key.key, output=output_name)] diff --git a/tests/test_graph_generation.py b/tests/test_graph_generation.py index 604afb6..85c57d6 100644 --- a/tests/test_graph_generation.py +++ b/tests/test_graph_generation.py @@ -265,3 +265,31 @@ def sample_graph(x: int, y: int) -> int: assert isinstance(function_call, FunctionCall) self.assertEqual(getattr(function_call.function, "__implicit", False), True) + + def test_graph_default_argument(self): + @graph + def sample_graph(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2, y=3))[0], sample_graph(x=2, y=3)) + + def test_graph_default_argument_missing(self): + @graph + def sample_graph(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2))[0], sample_graph(x=2)) + + def test_function_default_argument(self): + @graph + def add(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2, y=3))[0], add(x=2, y=3)) + + def test_function_default_argument_missing(self): + @graph + def add(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2))[0], add(x=2))