Skip to content

Commit

Permalink
yapf in taskletfusion classes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Jun 6, 2023
1 parent 9290fd2 commit c3f5548
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 34 deletions.
25 changes: 7 additions & 18 deletions dace/transformation/dataflow/tasklet_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def visit_Name(self, node: ast.Name) -> Any:


class CPPConnectorRenamer():

def __init__(self, repl_dict: Dict[str, str]) -> None:
self.repl_dict = repl_dict

Expand All @@ -44,7 +43,6 @@ def rename(self, code: str) -> str:


class PythonInliner(ast.NodeTransformer):

def __init__(self, target_id, target_ast):
self.target_id = target_id
self.target_ast = target_ast
Expand All @@ -57,7 +55,6 @@ def visit_Name(self, node: ast.AST):


class CPPInliner():

def __init__(self, inline_target, inline_val):
self.inline_target = inline_target
self.inline_val = inline_val
Expand Down Expand Up @@ -144,10 +141,7 @@ class TaskletFusion(pm.SingleStateTransformation):

@classmethod
def expressions(cls):
return [
sdutil.node_path_graph(cls.t1, cls.data, cls.t2),
sdutil.node_path_graph(cls.t1, cls.t2)
]
return [sdutil.node_path_graph(cls.t1, cls.data, cls.t2), sdutil.node_path_graph(cls.t1, cls.t2)]

def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG, permissive: bool = False) -> bool:
t1 = self.t1
Expand Down Expand Up @@ -191,9 +185,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
t2_in_edge = graph.out_edges(data if data is not None else t1)[0]

# Remove the connector from the second Tasklet.
inputs = {
k: v for k, v in t2.in_connectors.items() if k != t2_in_edge.dst_conn
}
inputs = {k: v for k, v in t2.in_connectors.items() if k != t2_in_edge.dst_conn}

# Copy the first Tasklet's in connectors.
repldict = {}
Expand All @@ -214,8 +206,8 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
break
else:
t2edge = conflict_edges[0]
if t2edge is not None and (in_edge.data != t2edge.data or in_edge.data.data != t2edge.data.data or
in_edge.data is None or in_edge.data.data is None):
if t2edge is not None and (in_edge.data != t2edge.data or in_edge.data.data != t2edge.data.data
or in_edge.data is None or in_edge.data.data is None):
in_edge.dst_conn = dace.data.find_new_name(in_edge.dst_conn, set(inputs))
repldict[old_value] = in_edge.dst_conn
else:
Expand All @@ -231,9 +223,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
if repldict:
assigned_value = PythonConnectorRenamer(repldict).visit(assigned_value)

new_code = [
PythonInliner(t2_in_edge.dst_conn, assigned_value).visit(line) for line in t2.code.code
]
new_code = [PythonInliner(t2_in_edge.dst_conn, assigned_value).visit(line) for line in t2.code.code]
new_code_str = '\n'.join(astunparse.unparse(line) for line in new_code)
elif t1.language == Language.CPP:
assigned_value = t1.code.as_string
Expand All @@ -255,9 +245,8 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
else:
return

new_tasklet = graph.add_tasklet(
t1.label + '_fused_' + t2.label, inputs, t2.out_connectors, new_code_str, t1.language
)
new_tasklet = graph.add_tasklet(t1.label + '_fused_' + t2.label, inputs, t2.out_connectors, new_code_str,
t1.language)

for in_edge in graph.in_edges(t1):
graph.add_edge(in_edge.src, in_edge.src_conn, new_tasklet, in_edge.dst_conn, in_edge.data)
Expand Down
29 changes: 13 additions & 16 deletions tests/transformations/tasklet_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
M = 10
N = 2 * M


@dace.program
def map_with_tasklets(A: datatype[N], B: datatype[M]):
C = np.zeros_like(B)
Expand Down Expand Up @@ -42,15 +43,11 @@ def _make_sdfg(language: str, with_data: bool = False):
outputs = {
'__out': datatype,
}
ta = state.add_tasklet(
'a', inputs, {
'__out1': datatype,
'__out2': datatype,
'__out3': datatype,
},
f'__out1 = __inp1 + __inp2{endl}__out2 = __out1{endl}__out3 = __out1{endl}',
lang
)
ta = state.add_tasklet('a', inputs, {
'__out1': datatype,
'__out2': datatype,
'__out3': datatype,
}, f'__out1 = __inp1 + __inp2{endl}__out2 = __out1{endl}__out3 = __out1{endl}', lang)
tb = state.add_tasklet('b', inputs, outputs, f'__out = __inp1 * __inp2{endl}', lang)
tc = state.add_tasklet('c', inputs, outputs, f'__out = __inp1 + __inp2{endl}', lang)
td = state.add_tasklet('d', inputs, outputs, f'__out = __inp1 / __inp2{endl}', lang)
Expand All @@ -60,12 +57,12 @@ def _make_sdfg(language: str, with_data: bool = False):
state.add_memlet_path(A, me, tb, memlet=dace.Memlet('A[2*i]'), dst_conn='__inp2')
state.add_memlet_path(B, me, tc, memlet=dace.Memlet('B[i]'), dst_conn='__inp2')
if with_data:
sdfg.add_array('tmp1', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp2', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp3', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp4', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp5', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp6', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp1', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp2', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp3', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp4', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp5', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp6', (1, ), datatype, dtypes.StorageType.Default, None, True)
atemp1 = state.add_access('tmp1')
atemp2 = state.add_access('tmp2')
atemp3 = state.add_access('tmp3')
Expand Down Expand Up @@ -101,7 +98,7 @@ def test_basic():
def test_basic_tf(A: datatype[5, 5]):
B = A + 1
return B * 2

sdfg = test_basic_tf.to_sdfg(simplify=True)

num_map_fusions = sdfg.apply_transformations(MapFusion)
Expand Down

0 comments on commit c3f5548

Please sign in to comment.