Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various Fixes and QOL improvements #1853

Merged
merged 10 commits into from
Jan 8, 2025
8 changes: 5 additions & 3 deletions dace/cli/sdfv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class NewCls(cls):
return NewCls


def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: bool = True):
def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: bool = True, compress: bool = True):
"""
View an sdfg in the system's HTML viewer

Expand All @@ -33,6 +33,7 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: b
served using a basic web server on that port,
blocking the current thread.
:param verbose: Be verbose.
:param compress: Use compression for the temporary file.
"""
# If vscode is open, try to open it inside vscode
if filename is None:
Expand All @@ -41,8 +42,9 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: b
or 'VSCODE_IPC_HOOK_CLI' in os.environ
or 'VSCODE_GIT_IPC_HANDLE' in os.environ
):
fd, filename = tempfile.mkstemp(suffix='.sdfg')
sdfg.save(filename)
suffix = '.sdfgz' if compress else '.sdfg'
fd, filename = tempfile.mkstemp(suffix=suffix)
sdfg.save(filename, compress=compress)
if platform.system() == 'Darwin':
# Special case for MacOS
os.system(f'open {filename}')
Expand Down
3 changes: 2 additions & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, sdfg: SDFG):

# resolve all symbols and constants
# first handle root
sdfg.reset_cfg_list()
self._symbols_and_constants[sdfg.cfg_id] = sdfg.free_symbols.union(sdfg.constants_prop.keys())
# then recurse
for nested, state in sdfg.all_nodes_recursive():
Expand All @@ -66,7 +67,7 @@ def __init__(self, sdfg: SDFG):
# found a new nested sdfg: resolve symbols and constants
result = nsdfg.free_symbols.union(nsdfg.constants_prop.keys())

parent_constants = self._symbols_and_constants[nsdfg._parent_sdfg.cfg_id]
parent_constants = self._symbols_and_constants[nsdfg.parent_sdfg.cfg_id]
result |= parent_constants

# check for constant inputs
Expand Down
6 changes: 5 additions & 1 deletion dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2350,7 +2350,7 @@ def visit_For(self, node: ast.For):
if symbolic.issymbolic(atom, self.sdfg.constants):
astr = str(atom)
# Check for undefined variables
if astr not in self.defined:
if astr not in self.defined and not ('.' in astr and astr in self.sdfg.arrays):
raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom)
# Add to global SDFG symbols if not a scalar
if (astr not in self.sdfg.symbols and not (astr in self.variables or astr in self.sdfg.arrays)):
Expand Down Expand Up @@ -3944,6 +3944,10 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no
else:
required_args.extend(symbols)
required_args = dtypes.deduplicate(required_args)
gargs = set(a[0] for a in args)
for rarg in required_args:
if rarg not in gargs and rarg in self.sdfg.symbols:
args.append((rarg, rarg))

# Argument checks
for aname, arg in kwargs:
Expand Down
10 changes: 8 additions & 2 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,8 +1294,14 @@ def _array_array_where(visitor: ProgramVisitor,
raise ValueError('numpy.where is only supported for the case where x and y are given')

cond_arr = sdfg.arrays[cond_operand]
left_arr = sdfg.arrays.get(left_operand, None)
right_arr = sdfg.arrays.get(right_operand, None)
try:
left_arr = sdfg.arrays[left_operand]
except KeyError:
left_arr = None
try:
right_arr = sdfg.arrays[right_operand]
except KeyError:
right_arr = None

left_type = left_arr.dtype if left_arr else dtypes.dtype_to_typeclass(type(left_operand))
right_type = right_arr.dtype if right_arr else dtypes.dtype_to_typeclass(type(right_operand))
Expand Down
2 changes: 1 addition & 1 deletion dace/libraries/standard/nodes/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ class Reduce(dace.sdfg.nodes.LibraryNode):
# Properties
axes = ListProperty(element_type=int, allow_none=True)
wcr = LambdaProperty(default='lambda a, b: a')
identity = Property(allow_none=True)
identity = Property(allow_none=True, to_json=lambda x: str(x))

def __init__(self,
name,
Expand Down
5 changes: 3 additions & 2 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,15 +2509,16 @@ def apply_strict_transformations(self, validate=True, validate_all=False):
warnings.warn('SDFG.apply_strict_transformations is deprecated, use SDFG.simplify instead.', DeprecationWarning)
return self.simplify(validate, validate_all)

def simplify(self, validate=True, validate_all=False, verbose=False):
def simplify(self, validate=True, validate_all=False, verbose=False, options=None):
""" Applies safe transformations (that will surely increase the
performance) on the SDFG. For example, this fuses redundant states
(safely) and removes redundant arrays.

:note: This is an in-place operation on the SDFG.
"""
from dace.transformation.passes.simplify import SimplifyPass
return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose).apply_pass(self, {})
return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose,
pass_options=options).apply_pass(self, {})

def auto_optimize(self,
device: dtypes.DeviceType,
Expand Down
10 changes: 10 additions & 0 deletions dace/transformation/passes/fusion_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]:
def report(self, pass_retval: int) -> str:
return f'Inlined {pass_retval} SDFGs.'

def set_opts(self, opts):
tbennun marked this conversation as resolved.
Show resolved Hide resolved
opt_keys = [
'multistate',
]

for k in opt_keys:
attr_k = InlineSDFGs.__name__ + '.' + k
if attr_k in opts:
setattr(self, k, opts[attr_k])


@dataclass(unsafe_hash=True)
@properties.make_properties
Expand Down
Loading