Skip to content

Commit

Permalink
Import some of the offset changes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 24, 2024
1 parent 2cbfcd5 commit f6a9dc8
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
10 changes: 8 additions & 2 deletions dace/sdfg/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,10 @@ def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState'
if internal_memlet is None:
continue
try:
iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True)
ext_desc = parent_sdfg.arrays[iedge.data.data]
int_desc = sdfg.arrays[iedge.dst_conn]
iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True, internal_offset=int_desc.offset,
external_offset=ext_desc.offset)
# If no appropriate memlet found, use array dimension
for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[iedge.data.data].shape)):
if rng[1] + 1 == s:
Expand All @@ -1137,7 +1140,10 @@ def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState'
if internal_memlet is None:
continue
try:
oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True)
ext_desc = parent_sdfg.arrays[oedge.data.data]
int_desc = sdfg.arrays[oedge.src_conn]
oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True, internal_offset=int_desc.offset,
external_offset=ext_desc.offset)
# If no appropriate memlet found, use array dimension
for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[oedge.data.data].shape)):
if rng[1] + 1 == s:
Expand Down
80 changes: 62 additions & 18 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,14 +509,24 @@ def apply(self, state: SDFGState, sdfg: SDFG):
if (edge not in modified_edges and edge.data.data == node.data):
for e in state.memlet_tree(edge):
if e._data.get_dst_subset(e, state):
new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_dst_subset=True)
offset = sdfg.arrays[e.data.data].offset
new_memlet = helpers.unsqueeze_memlet(e.data,
outer_edge.data,
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
e._data.dst_subset = new_memlet.subset
# NOTE: Node is source
for edge in state.out_edges(node):
if (edge not in modified_edges and edge.data.data == node.data):
for e in state.memlet_tree(edge):
if e._data.get_src_subset(e, state):
new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_src_subset=True)
offset = sdfg.arrays[e.data.data].offset
new_memlet = helpers.unsqueeze_memlet(e.data,
outer_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
e._data.src_subset = new_memlet.subset

# If source/sink node is not connected to a source/destination access
Expand Down Expand Up @@ -625,10 +635,17 @@ def _modify_access_to_access(self,
state.out_edges_by_connector(nsdfg_node, inner_data))
# Create memlet by unsqueezing both w.r.t. src and
# dst subsets
in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True)
offset = state.parent.arrays[top_edge.data.data].offset
in_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
out_memlet = helpers.unsqueeze_memlet(inner_edge.data,
matching_edge.data,
use_dst_subset=True)
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
new_memlet = in_memlet
new_memlet.other_subset = out_memlet.subset

Expand All @@ -651,10 +668,17 @@ def _modify_access_to_access(self,
state.out_edges_by_connector(nsdfg_node, inner_data))
# Create memlet by unsqueezing both w.r.t. src and
# dst subsets
in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True)
offset = state.parent.arrays[top_edge.data.data].offset
in_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
out_memlet = helpers.unsqueeze_memlet(inner_edge.data,
matching_edge.data,
use_dst_subset=True)
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
new_memlet = in_memlet
new_memlet.other_subset = out_memlet.subset

Expand Down Expand Up @@ -689,7 +713,11 @@ def _modify_memlet_path(
if inner_edge in edges_to_ignore:
new_memlet = inner_edge.data
else:
new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data)
offset = state.parent.arrays[top_edge.data.data].offset
new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
internal_offset=offset,
external_offset=offset)
if inputs:
if inner_edge.dst in inner_to_outer:
dst = inner_to_outer[inner_edge.dst]
Expand All @@ -708,15 +736,19 @@ def _modify_memlet_path(
mtree = state.memlet_tree(new_edge)

# Modify all memlets going forward/backward
def traverse(mtree_node):
def traverse(mtree_node, state, nstate):
result.add(mtree_node.edge)
mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data, top_edge.data)
offset = state.parent.arrays[top_edge.data.data].offset
mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data,
top_edge.data,
internal_offset=offset,
external_offset=offset)
for child in mtree_node.children:
traverse(child)
traverse(child, state, nstate)

result.add(new_edge)
for child in mtree.children:
traverse(child)
traverse(child, state, nstate)

return result

Expand Down Expand Up @@ -1035,8 +1067,8 @@ def _check_cand(candidates, outer_edges):

# If there are any symbols here that are not defined
# in "defined_symbols"
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices),
list(indices)) - set(nsdfg.symbol_mapping.keys()))
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) -
set(nsdfg.symbol_mapping.keys()))
if missing_symbols:
ignore.add(cname)
continue
Expand All @@ -1045,10 +1077,13 @@ def _check_cand(candidates, outer_edges):
_check_cand(out_candidates, state.out_edges_by_connector)

# Return result, filtering out the states
return ({k: (dc(v), ind)
for k, (v, _, ind) in in_candidates.items()
if k not in ignore}, {k: (dc(v), ind)
for k, (v, _, ind) in out_candidates.items() if k not in ignore})
return ({
k: (dc(v), ind)
for k, (v, _, ind) in in_candidates.items() if k not in ignore
}, {
k: (dc(v), ind)
for k, (v, _, ind) in out_candidates.items() if k not in ignore
})

def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False):
nsdfg = self.nsdfg
Expand All @@ -1071,7 +1106,16 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]],
outer_edge = next(iter(outer_edges(nsdfg_node, aname)))
except StopIteration:
continue
new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data)
if isinstance(outer_edge.dst, nodes.NestedSDFG):
conn = outer_edge.dst_conn
else:
conn = outer_edge.src_conn
int_desc = nsdfg.arrays[conn]
ext_desc = sdfg.arrays[outer_edge.data.data]
new_memlet = helpers.unsqueeze_memlet(refine,
outer_edge.data,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset)
outer_edge.data.subset = subsets.Range([
ns if i in indices else os
for i, (os, ns) in enumerate(zip(outer_edge.data.subset, new_memlet.subset))
Expand Down
1 change: 0 additions & 1 deletion tests/fortran/non-interactive/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def test_fortran_frontend_function_test3():
sdfg.parent_nsdfg_node = None
sdfg.reset_sdfg_list()
sdfg.simplify(verbose=True)
sdfg.view()
sdfg.compile()


Expand Down
1 change: 0 additions & 1 deletion tests/fortran/tasklet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_fortran_frontend_tasklet():
"""

sdfg = fortran_parser.create_sdfg_from_string(test_string, "tasklet", normalize_offsets=True)
sdfg.view()
sdfg.simplify(verbose=True)

sdfg.compile()
Expand Down

0 comments on commit f6a9dc8

Please sign in to comment.