Skip to content

Commit

Permalink
Fix default argument bug
Browse files Browse the repository at this point in the history
Signed-off-by: 1597463007 <[email protected]>
  • Loading branch information
1597463007 committed Oct 4, 2024
1 parent 618b279 commit 202f271
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
27 changes: 23 additions & 4 deletions pargraph/graph/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, .

# Short circuit if external input is passed in (for top-level graph calls)
if any(arg._target is None for arg in bound_args.arguments.values() if isinstance(arg, GraphContext)):
# Inject default values for external inputs
for name, arg in bound_args.arguments.items():
default_value = bound_args.signature.parameters[name].default
if default_value is inspect.Parameter.empty:
continue

const_id = f"_{uuid.uuid4().hex}"
sub_graph.consts[ConstKey(key=const_id)] = Const.from_value(default_value)
sub_graph.inputs[InputKey(key=name)] = ConstKey(key=const_id)

return sub_graph

# Inject sub graph into parent graph
Expand Down Expand Up @@ -242,7 +252,7 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, .

# Short circuit if external input is passed in (for top-level delayed calls)
if any(graph_context._target is None for graph_context in arg_dict.values()):
return Graph(
graph_result = Graph(
consts={
ConstKey(key=name): arg._graph.consts[arg._target]
for name, arg in arg_dict.items()
Expand All @@ -264,6 +274,18 @@ def wrapper(*args, **kwargs) -> Union[Graph, GraphContext, Tuple[GraphContext, .
},
)

# Inject default values for external inputs
for name, arg in bound_args.arguments.items():
default_value = bound_args.signature.parameters[name].default
if default_value is inspect.Parameter.empty:
continue

const_id = f"_{uuid.uuid4().hex}"
graph_result.consts[ConstKey(key=const_id)] = Const.from_value(default_value)
graph_result.inputs[InputKey(key=name)] = ConstKey(key=const_id)

return graph_result

# Inject function call node into graph
node_id = f"{function.__name__}_{uuid.uuid4().hex}"
merged_graph = _merge_graphs(*(arg._graph for arg in arg_dict.values()))
Expand Down Expand Up @@ -327,9 +349,6 @@ def _generate_graph(func: Callable, /, *args, **kwargs):
new_args = []
new_kwargs = {}
for name, param in bound_args.signature.parameters.items():
if param.default is not inspect.Parameter.empty:
continue

if param.kind == param.POSITIONAL_ONLY:
new_args.append(bound_args.arguments.get(name, external_input()))
elif param.kind in {param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY}:
Expand Down
4 changes: 4 additions & 0 deletions pargraph/graph/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,10 @@ def _convert_graph_to_dask_graph(
for input_key in self.inputs.keys():
graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}"
dask_graph[graph_key] = inputs[input_key.key]
# if input key is not in inputs, use the default value
dask_graph[graph_key] = (
inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value()
)
key_to_uuid[input_key] = graph_key

# assign random keys to all node paths and node output paths beforehand
Expand Down
28 changes: 28 additions & 0 deletions tests/test_graph_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,31 @@ def sample_graph(x: int, y: int) -> int:

assert isinstance(function_call, FunctionCall)
self.assertEqual(getattr(function_call.function, "__implicit", False), True)

def test_graph_default_argument(self):
@graph
def sample_graph(x: int, y: int = 1) -> int:
return x + y

self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2, y=3))[0], sample_graph(x=2, y=3))

def test_graph_default_argument_missing(self):
@graph
def sample_graph(x: int, y: int = 1) -> int:
return x + y

self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2))[0], sample_graph(x=2))

def test_function_default_argument(self):
@graph
def add(x: int, y: int = 1) -> int:
return x + y

self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2, y=3))[0], add(x=2, y=3))

def test_function_default_argument_missing(self):
@graph
def add(x: int, y: int = 1) -> int:
return x + y

self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2))[0], add(x=2))

0 comments on commit 202f271

Please sign in to comment.