From f6a9dc8d462f436706749e7664e794b605bcde09 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 24 Dec 2024 10:57:39 +0100 Subject: [PATCH] Import some of the offset changes --- dace/sdfg/propagation.py | 10 ++- .../transformation/interstate/sdfg_nesting.py | 80 ++++++++++++++----- .../fortran/non-interactive/function_test.py | 1 - tests/fortran/tasklet_test.py | 1 - 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 2983ec3c63..4ba80b4ea9 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1117,7 +1117,10 @@ def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState' if internal_memlet is None: continue try: - iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True) + ext_desc = parent_sdfg.arrays[iedge.data.data] + int_desc = sdfg.arrays[iedge.dst_conn] + iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True, internal_offset=int_desc.offset, + external_offset=ext_desc.offset) # If no appropriate memlet found, use array dimension for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[iedge.data.data].shape)): if rng[1] + 1 == s: @@ -1137,7 +1140,10 @@ def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState' if internal_memlet is None: continue try: - oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True) + ext_desc = parent_sdfg.arrays[oedge.data.data] + int_desc = sdfg.arrays[oedge.src_conn] + oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True, internal_offset=int_desc.offset, + external_offset=ext_desc.offset) # If no appropriate memlet found, use array dimension for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[oedge.data.data].shape)): if rng[1] + 1 == s: diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 31e751bb6a..3ea55e3cab 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -509,14 +509,24 @@ def apply(self, state: SDFGState, sdfg: SDFG): if (edge not in modified_edges and edge.data.data == node.data): for e in state.memlet_tree(edge): if e._data.get_dst_subset(e, state): - new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_dst_subset=True) + offset = sdfg.arrays[e.data.data].offset + new_memlet = helpers.unsqueeze_memlet(e.data, + outer_edge.data, + use_dst_subset=True, + internal_offset=offset, + external_offset=offset) e._data.dst_subset = new_memlet.subset # NOTE: Node is source for edge in state.out_edges(node): if (edge not in modified_edges and edge.data.data == node.data): for e in state.memlet_tree(edge): if e._data.get_src_subset(e, state): - new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_src_subset=True) + offset = sdfg.arrays[e.data.data].offset + new_memlet = helpers.unsqueeze_memlet(e.data, + outer_edge.data, + use_src_subset=True, + internal_offset=offset, + external_offset=offset) e._data.src_subset = new_memlet.subset # If source/sink node is not connected to a source/destination access @@ -625,10 +635,17 @@ def _modify_access_to_access(self, state.out_edges_by_connector(nsdfg_node, inner_data)) # Create memlet by unsqueezing both w.r.t. src and # dst subsets - in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True) + offset = state.parent.arrays[top_edge.data.data].offset + in_memlet = helpers.unsqueeze_memlet(inner_edge.data, + top_edge.data, + use_src_subset=True, + internal_offset=offset, + external_offset=offset) out_memlet = helpers.unsqueeze_memlet(inner_edge.data, matching_edge.data, - use_dst_subset=True) + use_dst_subset=True, + internal_offset=offset, + external_offset=offset) new_memlet = in_memlet new_memlet.other_subset = out_memlet.subset @@ -651,10 +668,17 @@ def _modify_access_to_access(self, state.out_edges_by_connector(nsdfg_node, inner_data)) # Create memlet by unsqueezing both w.r.t. src and # dst subsets - in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True) + offset = state.parent.arrays[top_edge.data.data].offset + in_memlet = helpers.unsqueeze_memlet(inner_edge.data, + top_edge.data, + use_src_subset=True, + internal_offset=offset, + external_offset=offset) out_memlet = helpers.unsqueeze_memlet(inner_edge.data, matching_edge.data, - use_dst_subset=True) + use_dst_subset=True, + internal_offset=offset, + external_offset=offset) new_memlet = in_memlet new_memlet.other_subset = out_memlet.subset @@ -689,7 +713,11 @@ def _modify_memlet_path( if inner_edge in edges_to_ignore: new_memlet = inner_edge.data else: - new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data) + offset = state.parent.arrays[top_edge.data.data].offset + new_memlet = helpers.unsqueeze_memlet(inner_edge.data, + top_edge.data, + internal_offset=offset, + external_offset=offset) if inputs: if inner_edge.dst in inner_to_outer: dst = inner_to_outer[inner_edge.dst] @@ -708,15 +736,19 @@ def _modify_memlet_path( mtree = state.memlet_tree(new_edge) # Modify all memlets going forward/backward - def traverse(mtree_node): + def traverse(mtree_node, state, nstate): result.add(mtree_node.edge) - mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data, top_edge.data) + offset = state.parent.arrays[top_edge.data.data].offset + mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data, + top_edge.data, + internal_offset=offset, + external_offset=offset) for child in mtree_node.children: - traverse(child) + traverse(child, state, nstate) result.add(new_edge) for child in mtree.children: - traverse(child) + traverse(child, state, nstate) return result @@ -1035,8 +1067,8 @@ def _check_cand(candidates, outer_edges): # If there are any symbols here that are not defined # in "defined_symbols" - missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), - list(indices)) - set(nsdfg.symbol_mapping.keys())) + missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - + set(nsdfg.symbol_mapping.keys())) if missing_symbols: ignore.add(cname) continue @@ -1045,10 +1077,13 @@ def _check_cand(candidates, outer_edges): _check_cand(out_candidates, state.out_edges_by_connector) # Return result, filtering out the states - return ({k: (dc(v), ind) - for k, (v, _, ind) in in_candidates.items() - if k not in ignore}, {k: (dc(v), ind) - for k, (v, _, ind) in out_candidates.items() if k not in ignore}) + return ({ + k: (dc(v), ind) + for k, (v, _, ind) in in_candidates.items() if k not in ignore + }, { + k: (dc(v), ind) + for k, (v, _, ind) in out_candidates.items() if k not in ignore + }) def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False): nsdfg = self.nsdfg @@ -1071,7 +1106,16 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], outer_edge = next(iter(outer_edges(nsdfg_node, aname))) except StopIteration: continue - new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data) + if isinstance(outer_edge.dst, nodes.NestedSDFG): + conn = outer_edge.dst_conn + else: + conn = outer_edge.src_conn + int_desc = nsdfg.arrays[conn] + ext_desc = sdfg.arrays[outer_edge.data.data] + new_memlet = helpers.unsqueeze_memlet(refine, + outer_edge.data, + internal_offset=int_desc.offset, + external_offset=ext_desc.offset) outer_edge.data.subset = subsets.Range([ ns if i in indices else os for i, (os, ns) in enumerate(zip(outer_edge.data.subset, new_memlet.subset)) diff --git a/tests/fortran/non-interactive/function_test.py b/tests/fortran/non-interactive/function_test.py index 87cfd260c3..c637de41ad 100644 --- a/tests/fortran/non-interactive/function_test.py +++ b/tests/fortran/non-interactive/function_test.py @@ -267,7 +267,6 @@ def test_fortran_frontend_function_test3(): sdfg.parent_nsdfg_node = None sdfg.reset_sdfg_list() sdfg.simplify(verbose=True) - sdfg.view() sdfg.compile() diff --git a/tests/fortran/tasklet_test.py b/tests/fortran/tasklet_test.py index 49a2f5ac79..5c125f3e0f 100644 --- a/tests/fortran/tasklet_test.py +++ b/tests/fortran/tasklet_test.py @@ -32,7 +32,6 @@ def test_fortran_frontend_tasklet(): """ sdfg = fortran_parser.create_sdfg_from_string(test_string, "tasklet", normalize_offsets=True) - sdfg.view() sdfg.simplify(verbose=True) sdfg.compile()