From 1d1b0c60f2cfb550e21df24a1523e08e69e20bd1 Mon Sep 17 00:00:00 2001 From: 1597463007 Date: Fri, 4 Oct 2024 09:25:41 -0400 Subject: [PATCH 1/5] Fix default argument bug Signed-off-by: 1597463007 --- pargraph/graph/decorators.py | 27 +++++++++++++++++++++++---- pargraph/graph/objects.py | 4 ++++ tests/test_graph_generation.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/pargraph/graph/decorators.py b/pargraph/graph/decorators.py index b65ef17..70671f5 100644 --- a/pargraph/graph/decorators.py +++ b/pargraph/graph/decorators.py @@ -162,6 +162,16 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . # Short circuit if external input is passed in (for top-level graph calls) if any(arg._target is None for arg in bound_args.arguments.values() if isinstance(arg, GraphContext)): + # Inject default values for external inputs + for name, arg in bound_args.arguments.items(): + default_value = bound_args.signature.parameters[name].default + if default_value is inspect.Parameter.empty: + continue + + const_id = f"_{uuid.uuid4().hex}" + sub_graph.consts[ConstKey(key=const_id)] = Const.from_value(default_value) + sub_graph.inputs[InputKey(key=name)] = ConstKey(key=const_id) + return sub_graph # Inject sub graph into parent graph @@ -242,7 +252,7 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . # Short circuit if external input is passed in (for top-level delayed calls) if any(graph_context._target is None for graph_context in arg_dict.values()): - return Graph( + graph_result = Graph( consts={ ConstKey(key=name): arg._graph.consts[arg._target] for name, arg in arg_dict.items() @@ -264,6 +274,18 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . }, ) + # Inject default values for external inputs + for name, arg in bound_args.arguments.items(): + default_value = bound_args.signature.parameters[name].default + if default_value is inspect.Parameter.empty: + continue + + const_id = f"_{uuid.uuid4().hex}" + graph_result.consts[ConstKey(key=const_id)] = Const.from_value(default_value) + graph_result.inputs[InputKey(key=name)] = ConstKey(key=const_id) + + return graph_result + # Inject function call node into graph node_id = f"{function.__name__}_{uuid.uuid4().hex}" merged_graph = _merge_graphs(*(arg._graph for arg in arg_dict.values())) @@ -327,9 +349,6 @@ def _generate_graph(func: Callable, /, *args, **kwargs): new_args = [] new_kwargs = {} for name, param in bound_args.signature.parameters.items(): - if param.default is not inspect.Parameter.empty: - continue - if param.kind == param.POSITIONAL_ONLY: new_args.append(bound_args.arguments.get(name, external_input())) elif param.kind in {param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY}: diff --git a/pargraph/graph/objects.py b/pargraph/graph/objects.py index 793929b..7870435 100644 --- a/pargraph/graph/objects.py +++ b/pargraph/graph/objects.py @@ -809,6 +809,10 @@ def _convert_graph_to_dask_graph( for input_key in self.inputs.keys(): graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}" dask_graph[graph_key] = inputs[input_key.key] + # if input key is not in inputs, use the default value + dask_graph[graph_key] = ( + inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value() + ) key_to_uuid[input_key] = graph_key # assign random keys to all node paths and node output paths beforehand diff --git a/tests/test_graph_generation.py b/tests/test_graph_generation.py index 604afb6..cfbca04 100644 --- a/tests/test_graph_generation.py +++ b/tests/test_graph_generation.py @@ -265,3 +265,31 @@ def sample_graph(x: int, y: int) -> int: assert isinstance(function_call, FunctionCall) self.assertEqual(getattr(function_call.function, "__implicit", False), True) + + def test_graph_default_argument(self): + @graph + def sample_graph(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2, y=3))[0], sample_graph(x=2, y=3)) + + def test_graph_default_argument_missing(self): + @graph + def sample_graph(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2))[0], sample_graph(x=2)) + + def test_function_default_argument(self): + @graph + def add(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2, y=3))[0], add(x=2, y=3)) + + def test_function_default_argument_missing(self): + @graph + def add(x: int, y: int = 1) -> int: + return x + y + + self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2))[0], add(x=2)) From e23e9e2df8069876270a588bb9cc72d83446c5ba Mon Sep 17 00:00:00 2001 From: 1597463007 Date: Fri, 4 Oct 2024 09:35:50 -0400 Subject: [PATCH 2/5] Remove line Signed-off-by: 1597463007 --- pargraph/graph/objects.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pargraph/graph/objects.py b/pargraph/graph/objects.py index 7870435..c60f11a 100644 --- a/pargraph/graph/objects.py +++ b/pargraph/graph/objects.py @@ -808,7 +808,6 @@ def _convert_graph_to_dask_graph( if inputs is not None: for input_key in self.inputs.keys(): graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}" - dask_graph[graph_key] = inputs[input_key.key] # if input key is not in inputs, use the default value dask_graph[graph_key] = ( inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value() From 172a9b1f61826d6eec2e67b3a56943323e05ce03 Mon Sep 17 00:00:00 2001 From: 1597463007 Date: Fri, 4 Oct 2024 11:56:31 -0400 Subject: [PATCH 3/5] Fix linting errors Signed-off-by: 1597463007 --- pargraph/graph/decorators.py | 4 ++-- tests/test_graph_generation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pargraph/graph/decorators.py b/pargraph/graph/decorators.py index 70671f5..7e5206a 100644 --- a/pargraph/graph/decorators.py +++ b/pargraph/graph/decorators.py @@ -167,7 +167,7 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . default_value = bound_args.signature.parameters[name].default if default_value is inspect.Parameter.empty: continue - + const_id = f"_{uuid.uuid4().hex}" sub_graph.consts[ConstKey(key=const_id)] = Const.from_value(default_value) sub_graph.inputs[InputKey(key=name)] = ConstKey(key=const_id) @@ -283,7 +283,7 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . const_id = f"_{uuid.uuid4().hex}" graph_result.consts[ConstKey(key=const_id)] = Const.from_value(default_value) graph_result.inputs[InputKey(key=name)] = ConstKey(key=const_id) - + return graph_result # Inject function call node into graph diff --git a/tests/test_graph_generation.py b/tests/test_graph_generation.py index cfbca04..85c57d6 100644 --- a/tests/test_graph_generation.py +++ b/tests/test_graph_generation.py @@ -270,7 +270,7 @@ def test_graph_default_argument(self): @graph def sample_graph(x: int, y: int = 1) -> int: return x + y - + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2, y=3))[0], sample_graph(x=2, y=3)) def test_graph_default_argument_missing(self): From f9fd166f487adbf31b9aa91ee43f10b2974e6ed1 Mon Sep 17 00:00:00 2001 From: 1597463007 Date: Fri, 11 Oct 2024 13:59:42 -0400 Subject: [PATCH 4/5] Bump version Signed-off-by: 1597463007 --- pargraph/about.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pargraph/about.py b/pargraph/about.py index 732155f..fa3ddd8 100644 --- a/pargraph/about.py +++ b/pargraph/about.py @@ -1 +1 @@ -__version__ = "0.8.3" +__version__ = "0.8.4" From b67762bc2e911a05721e8c4e5ae1ae72196156d8 Mon Sep 17 00:00:00 2001 From: 1597463007 Date: Tue, 15 Oct 2024 19:43:51 -0400 Subject: [PATCH 5/5] Fix mypy errors Signed-off-by: 1597463007 --- pargraph/graph/decorators.py | 13 ++++++++----- pargraph/graph/objects.py | 9 +++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pargraph/graph/decorators.py b/pargraph/graph/decorators.py index 7e5206a..004a6a8 100644 --- a/pargraph/graph/decorators.py +++ b/pargraph/graph/decorators.py @@ -6,7 +6,7 @@ import uuid import warnings from dataclasses import dataclass -from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast +from typing import Any, Callable, Optional, Protocol, Tuple, Union, cast, Iterator from pargraph.graph.annotation import _get_output_names from pargraph.graph.objects import ( @@ -151,10 +151,13 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, . nodes=graph_context._graph.nodes, outputs={OutputKey(key=output_name): graph_context._target}, ) - for output_name, graph_context in ( - zip(output_names, graph_result) - if isinstance(output_names, tuple) - else ((output_names, graph_result),) + for output_name, graph_context in cast( + Iterator[Tuple[str, GraphContext]], + ( + zip(output_names, graph_result) + if isinstance(output_names, tuple) + else ((output_names, graph_result),) + ), ) if isinstance(graph_context, GraphContext) ) diff --git a/pargraph/graph/objects.py b/pargraph/graph/objects.py index c60f11a..b8b0681 100644 --- a/pargraph/graph/objects.py +++ b/pargraph/graph/objects.py @@ -7,7 +7,7 @@ import warnings from collections import defaultdict, deque from dataclasses import dataclass -from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Tuple, TypedDict, Union, cast +from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Tuple, TypedDict, Union, cast, Iterator import cloudpickle import jsonschema @@ -703,7 +703,7 @@ def _peel_subgraphs(graph: Graph) -> Graph: return graph graph = self - for _ in range(depth) if depth >= 0 else itertools.count(): + for _ in cast(Iterator, range(depth) if depth >= 0 else itertools.count()): graph = _peel_subgraphs(graph) # break if there are no more subgraphs @@ -873,8 +873,9 @@ def _convert_graph_to_dask_graph( assert callable(node.function) output_names = _get_output_names(node.function) node_uuid = f"node_{self._get_function_name(node.function)}_{uuid.uuid4().hex}" - for output_position, output_name in ( - enumerate(output_names) if isinstance(output_names, tuple) else ((None, output_names),) + for output_position, output_name in cast( + Iterator[Tuple[Optional[int], str]], + enumerate(output_names) if isinstance(output_names, tuple) else ((None, output_names),), ): graph_key = key_to_uuid[NodeOutputKey(key=node_key.key, output=output_name)]