Skip to content

Commit

Permalink
Rename task graph to dict graph
Browse files Browse the repository at this point in the history
Signed-off-by: 1597463007 <[email protected]>
  • Loading branch information
1597463007 committed Nov 18, 2024
1 parent 36b53cc commit abfb28a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 81 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ map_reduce_sort_recursive.to_graph(partition_counts=4).to_dot().write_png("map_r

![Map-Reduce Sort Recursive](docs/_static/map_reduce_sort_recursive.png)

Use the `to_dask` method to convert the generated graph to a Dask task graph.
Use the `to_dict` method to convert the generated graph to a dict graph.

```python
import numpy as np
from distributed import Client

with Client() as client:
client.get(map_reduce_sort.to_graph(partition_count=4).to_dask(array=np.random.rand(20)))[0]
client.get(map_reduce_sort.to_graph(partition_count=4).to_dict(array=np.random.rand(20)))[0]

# [0.06253707 0.06795382 0.11492823 0.14512393 0.20183152 0.41109117
# 0.42613798 0.45156214 0.4714821 0.54000373 0.54902451 0.62671881
Expand Down
4 changes: 2 additions & 2 deletions pargraph/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def set_parallel_backend(self, backend: Backend) -> None:

def get(self, graph: Dict, keys: Any, **kwargs) -> Any:
"""
Compute task graph
Compute dict graph
:param graph: task graph
:param graph: dict graph
:param keys: keys to compute (e.g. ``"x"``, ``["x", "y", "z"]``, etc)
:param kwargs: keyword arguments to forward to the parallel backend
:return: results in the same structure as keys
Expand Down
98 changes: 49 additions & 49 deletions pargraph/graph/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def __post_init__(self):
assert isinstance(self.value, str), f"Value must be a string; got type '{type(self.value)}'"

@staticmethod
def from_dict(data: Dict) -> "Const":
def from_json(data: Dict) -> "Const":
return Const(**data)

def to_dict(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Any]:
return {"type": self.type, "value": self.value}

@staticmethod
Expand Down Expand Up @@ -234,7 +234,7 @@ def __post_init__(self):
), f"Arg '{arg}' must ConstKey, InputKey, or NodeOutputKey; got type '{type(arg)}'"

@staticmethod
def from_dict(data: Dict) -> "FunctionCall":
def from_json(data: Dict) -> "FunctionCall":
data = data.copy()
function = data.pop("function")
return FunctionCall(
Expand All @@ -247,7 +247,7 @@ def from_dict(data: Dict) -> "FunctionCall":
**data,
)

def to_dict(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Any]:
return {
"function": (
base64.b64encode(cloudpickle.dumps(self.function)).decode("ascii")
Expand Down Expand Up @@ -279,16 +279,16 @@ def __post_init__(self):
), f"Arg '{arg}' must ConstKey, InputKey, or NodeOutputKey; got type '{type(arg)}'"

@staticmethod
def from_dict(data: Dict) -> "GraphCall":
def from_json(data: Dict) -> "GraphCall":
data = data.copy()
return GraphCall(
graph=Graph.from_dict(data.pop("graph")),
graph=Graph.from_json(data.pop("graph")),
args={arg: _get_key_from_str(key_str) for arg, key_str in data.pop("args").items()},
**data,
)

def to_dict(self) -> Dict[str, Any]:
dct: dict = {"graph": self.graph.to_dict(), "args": {arg: key.to_str() for arg, key in self.args.items()}}
def to_json(self) -> Dict[str, Any]:
dct: dict = {"graph": self.graph.to_json(), "args": {arg: key.to_str() for arg, key in self.args.items()}}
if self.graph_name is not None:
dct["graph_name"] = self.graph_name
return dct
Expand Down Expand Up @@ -342,22 +342,22 @@ def __post_init__(self):
), f"Output '{output}' must be type '{ConstKey}', '{InputKey}', or '{NodeOutputKey}'"

@staticmethod
def from_dict(data: Dict) -> "Graph":
def from_json(data: Dict) -> "Graph":
"""
Create graph from graph dict by inferring the graph dict type
Create graph from json serializable dictionary by inferring the graph type
:param data: graph dict
:return: graph
"""
if "edges" in data:
return Graph.from_dict_with_edge_list(data)
return Graph.from_json_with_edge_list(data)

return Graph.from_dict_with_node_arguments(data)
return Graph.from_json_with_node_arguments(data)

@staticmethod
def from_dict_with_edge_list(data: Dict) -> "Graph":
def from_json_with_edge_list(data: Dict) -> "Graph":
"""
Create graph from graph dict with edge list
Create graph from json serializable dictionary with edge list
:param data: graph dict with edge list
:return: graph
Expand Down Expand Up @@ -411,53 +411,53 @@ def from_dict_with_edge_list(data: Dict) -> "Graph":

outputs[key] = new_output

return Graph.from_dict_with_node_arguments(data)
return Graph.from_json_with_node_arguments(data)

@staticmethod
def from_dict_with_node_arguments(data: Dict) -> "Graph":
def from_json_with_node_arguments(data: Dict) -> "Graph":
"""
Create graph from graph dict with node arguments
Create graph from json serializable dictionary with node arguments
:param data: graph dict with node arguments
:return: graph
"""

def _graph_node_from_dict(data: Union[Dict, str]) -> Union[FunctionCall, "GraphCall"]:
def _graph_node_from_json(data: Union[Dict, str]) -> Union[FunctionCall, "GraphCall"]:
if isinstance(data, dict) and "function" in data:
return FunctionCall.from_dict(data)
return FunctionCall.from_json(data)
elif isinstance(data, dict) and "graph" in data:
return GraphCall.from_dict(data)
return GraphCall.from_json(data)

raise ValueError(f"invalid graph node dict '{data}'")

data = data.copy()
return Graph(
consts={ConstKey(key=key): Const.from_dict(value) for key, value in data.pop("consts").items()},
consts={ConstKey(key=key): Const.from_json(value) for key, value in data.pop("consts").items()},
inputs={
InputKey(key=key): cast(ConstKey, _get_key_from_str(value)) if value is not None else None
for key, value in data.pop("inputs").items()
},
nodes={NodeKey(key=key): _graph_node_from_dict(value) for key, value in data.pop("nodes").items()},
nodes={NodeKey(key=key): _graph_node_from_json(value) for key, value in data.pop("nodes").items()},
outputs={OutputKey(key=key): _get_key_from_str(value) for key, value in data.pop("outputs").items()},
**data,
)

def to_dict(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Any]:
"""
Convert graph representation to serializable dictionary
Convert graph representation to json serializable dictionary
:return: graph dictionary
:return: json serializable dictionary
"""
graph_dict: GraphDict = {"consts": {}, "inputs": {}, "nodes": {}, "edges": [], "outputs": {}}

for const_node_key, const_node in self.consts.items():
graph_dict["consts"][const_node_key.key] = const_node.to_dict()
graph_dict["consts"][const_node_key.key] = const_node.to_json()

for input_node_key, input_node in self.inputs.items():
graph_dict["inputs"][input_node_key.key] = input_node.to_str() if input_node is not None else None

for func_node_key, func_node in self.nodes.items():
func_node_dict = func_node.to_dict()
func_node_dict = func_node.to_json()
func_node_dict.pop("args")

graph_dict["nodes"][func_node_key.key] = func_node_dict
Expand All @@ -483,11 +483,11 @@ def to_dict(self) -> Dict[str, Any]:

return cast(dict, graph_dict)

def to_task_graph(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
def to_dict(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
"""
Convert graph to task graph
Convert graph to dict graph
Task graph representation:
Dict graph representation:
.. code-block:: json
Expand All @@ -504,10 +504,10 @@ def to_task_graph(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
:param args: positional arguments
:param kwargs: keyword arguments
:return: task graph and output keys
:return: dict graph and output keys
"""
inputs: dict = {**dict(zip((key.key for key in self.inputs.keys()), args)), **kwargs}
return self._convert_graph_to_task_graph(inputs=inputs)
return self._convert_graph_to_dict(inputs=inputs)

def to_dask(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
"""
Expand All @@ -516,17 +516,17 @@ def to_dask(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
.. warning::
This method is deprecated and will be removed in a future release.
Please use :func:`to_task_graph` instead.
Please use :func:`to_dict` instead.
:param args: positional arguments
:param kwargs: keyword arguments
:return: dask graph and output keys
"""
warnings.warn(
"This method is deprecated and will be removed in a future release. Please use 'to_task_graph' instead.",
"This method is deprecated and will be removed in a future release. Please use 'to_dict' instead.",
DeprecationWarning,
)
return self.to_task_graph(*args, **kwargs)
return self.to_dict(*args, **kwargs)

def to_dot(
self,
Expand Down Expand Up @@ -813,37 +813,37 @@ def _create_dot_edge(src: str, dst: str) -> pydot.Edge:

return edge

def _convert_graph_to_task_graph(
def _convert_graph_to_dict(
self,
inputs: Optional[Dict[str, Any]] = None,
input_mapping: Optional[Dict[InputKey, str]] = None,
output_mapping: Optional[Dict[OutputKey, str]] = None,
) -> Tuple[Dict[str, Any], List[str]]:
"""
Convert our own graph format to a task graph.
Convert our own graph format to a dict graph.
:param inputs: inputs dictionary
:param input_mapping: input mapping for subgraphs
:param output_mapping: output mapping for subgraphs
:return: tuple containing task graph and targets
:return: tuple containing dict graph and targets
"""
assert inputs is None or input_mapping is None, "cannot specify both inputs and input_mapping"

task_graph: dict = {}
dict_graph: dict = {}
key_to_uuid: dict = {}

# create constants
for const_key, const in self.consts.items():
graph_key = f"const_{self._get_const_label(const)}_{uuid.uuid4().hex}"
task_graph[graph_key] = const.to_value()
dict_graph[graph_key] = const.to_value()
key_to_uuid[const_key] = graph_key

# create inputs
if inputs is not None:
for input_key in self.inputs.keys():
graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}"
# if input key is not in inputs, use the default value
task_graph[graph_key] = (
dict_graph[graph_key] = (
inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value()
)
key_to_uuid[input_key] = graph_key
Expand Down Expand Up @@ -879,7 +879,7 @@ def _convert_graph_to_task_graph(
else:
key_to_uuid[input_key] = key_to_uuid[const_path]

# build task graph
# build dict graph
for node_key, node in self.nodes.items():
if isinstance(node, FunctionCall):
assert callable(node.function)
Expand All @@ -896,7 +896,7 @@ def _convert_graph_to_task_graph(
# handle default arguments
if param_name not in node.args:
graph_key = f"const_{self._get_const_label(input_annotation.default)}_{uuid.uuid4().hex}"
task_graph[graph_key] = input_annotation.default
dict_graph[graph_key] = input_annotation.default
args.append(graph_key)
continue

Expand All @@ -918,10 +918,10 @@ def _convert_graph_to_task_graph(
break

constant_key = f"const_{self._get_const_label(output_position)}_{uuid.uuid4().hex}"
task_graph[constant_key] = output_position
task_graph[graph_key] = (_unpack_tuple, node_uuid, constant_key)
dict_graph[constant_key] = output_position
dict_graph[graph_key] = (_unpack_tuple, node_uuid, constant_key)

task_graph[node_uuid] = (node.function,) + tuple(args)
dict_graph[node_uuid] = (node.function,) + tuple(args)

elif isinstance(node, GraphCall):
new_input_mapping = {
Expand All @@ -931,12 +931,12 @@ def _convert_graph_to_task_graph(
output_key: key_to_uuid[NodeOutputKey(key=node_key.key, output=output_key.key)]
for output_key in node.graph.outputs
}
task_subgraph, _ = node.graph._convert_graph_to_task_graph(
dict_subgraph, _ = node.graph._convert_graph_to_dict(
input_mapping=new_input_mapping, output_mapping=new_output_mapping
)
task_graph.update(task_subgraph)
dict_graph.update(dict_subgraph)

return task_graph, [key_to_uuid[output_path] for output_path in self.outputs.values()]
return dict_graph, [key_to_uuid[output_path] for output_path in self.outputs.values()]

def _scramble_keys(
self, old_to_new: Optional[bidict[Union[ConstKey, NodeKey], Union[ConstKey, NodeKey]]] = None
Expand Down
Loading

0 comments on commit abfb28a

Please sign in to comment.