Skip to content

Commit

Permalink
More symbol fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Oct 19, 2023
1 parent 909f032 commit 9099017
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
2 changes: 2 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def __label__(self, sdfg, state):
return self.data

def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']):
if isinstance(sdfg, (dace.sdfg.SDFGState, dace.sdfg.ScopeSubgraphView)):
sdfg = sdfg.parent
return sdfg.arrays[self.data]

def validate(self, sdfg, state):
Expand Down
18 changes: 7 additions & 11 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,13 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi
# Query, subgraph, and replacement methods

@abc.abstractmethod
def used_symbols(self, all_symbols: bool) -> Set[str]:
def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]:
"""
Returns a set of symbol names that are used in the graph.
:param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code).
:param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping
will be removed from the set of defined symbols.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -618,13 +620,7 @@ def is_leaf_memlet(self, e):
return False
return True

def used_symbols(self, all_symbols: bool) -> Set[str]:
"""
Returns a set of symbol names that are used in the state.
:param all_symbols: If False, only returns the set of symbols that will be used
in the generated code and are needed as arguments.
"""
def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]:
state = self.graph if isinstance(self, SubgraphView) else self
sdfg = state.parent
new_symbols = set()
Expand Down Expand Up @@ -1013,8 +1009,8 @@ def _used_symbols_internal(self,
keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]:
raise NotImplementedError()

def used_symbols(self, all_symbols: bool) -> Set[str]:
return self._used_symbols_internal(all_symbols)[0]
def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]:
return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0]

def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
read_set = set()
Expand Down Expand Up @@ -2557,7 +2553,7 @@ def all_state_scopes_recursive(self, recurse_into_sdfgs=False) -> Iterator['Scop

def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']:
""" Iterate over this and all nested SDFGs. """
for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=False):
for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=True):
if isinstance(cfg, dace.SDFG):
yield cfg

Expand Down

0 comments on commit 9099017

Please sign in to comment.