Skip to content

Commit

Permalink
Adds just the framework for integral loops
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Oct 18, 2023
1 parent 0755385 commit 73616ee
Show file tree
Hide file tree
Showing 6 changed files with 924 additions and 395 deletions.
4 changes: 2 additions & 2 deletions dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG):
:param sdfg: The SDFG to infer.
"""
# Loop over states, and in a topological sort over each state's nodes
for state in sdfg.nodes():
for state in sdfg.states():
for node in dfs_topological_sort(state):
# Try to infer input connector type from node type or previous edges
for e in state.in_edges(node):
Expand Down Expand Up @@ -167,7 +167,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E

if isinstance(scope, SDFG):
# Set device for default top-level schedules and storages
for state in scope.nodes():
for state in scope.states():
set_default_schedule_and_storage_types(state,
parent_schedules,
use_parent_schedule=use_parent_schedule,
Expand Down
5 changes: 2 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,8 @@ def label(self):
def __label__(self, sdfg, state):
return self.data

def desc(self, sdfg):
from dace.sdfg import SDFGState, ScopeSubgraphView
if isinstance(sdfg, (SDFGState, ScopeSubgraphView)):
def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']):
if isinstance(sdfg, dace.sdfg.SDFG):
sdfg = sdfg.parent
return sdfg.arrays[self.data]

Expand Down
29 changes: 15 additions & 14 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]):
sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname]
del sdfg.constants_prop[aname]

# Replace in interstate edges
for e in sdfg.edges():
e.data.replace_dict(repl, replace_keys=False)

for state in sdfg.nodes():
# Replace in access nodes
for node in state.data_nodes():
if node.data in repl:
node.data = repl[node.data]

# Replace in memlets
for edge in state.edges():
if edge.data.data in repl:
edge.data.data = repl[edge.data.data]
for cf in sdfg.all_state_scopes_recursive():
# Replace in interstate edges
for e in cf.edges():
e.data.replace_dict(repl, replace_keys=False)

for block in cf.nodes():
if isinstance(block, dace.SDFGState):
# Replace in access nodes
for node in block.data_nodes():
if node.data in repl:
node.data = repl[node.data]
# Replace in memlets
for edge in block.edges():
if edge.data.data in repl:
edge.data.data = repl[edge.data.data]
Loading

0 comments on commit 73616ee

Please sign in to comment.