diff --git a/dace/cli/sdfv.py b/dace/cli/sdfv.py index d14059468f..3abe9ea249 100644 --- a/dace/cli/sdfv.py +++ b/dace/cli/sdfv.py @@ -22,7 +22,7 @@ class NewCls(cls): return NewCls -def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: bool = True): +def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: bool = True, compress: bool = True): """ View an sdfg in the system's HTML viewer @@ -33,6 +33,7 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: b served using a basic web server on that port, blocking the current thread. :param verbose: Be verbose. + :param compress: Use compression for the temporary file. """ # If vscode is open, try to open it inside vscode if filename is None: @@ -41,8 +42,9 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: b or 'VSCODE_IPC_HOOK_CLI' in os.environ or 'VSCODE_GIT_IPC_HANDLE' in os.environ ): - fd, filename = tempfile.mkstemp(suffix='.sdfg') - sdfg.save(filename) + suffix = '.sdfgz' if compress else '.sdfg' + fd, filename = tempfile.mkstemp(suffix=suffix) + sdfg.save(filename, compress=compress) if platform.system() == 'Darwin': # Special case for MacOS os.system(f'open {filename}') diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 238b0f7a22..5abcc770aa 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -55,6 +55,7 @@ def __init__(self, sdfg: SDFG): # resolve all symbols and constants # first handle root + sdfg.reset_cfg_list() self._symbols_and_constants[sdfg.cfg_id] = sdfg.free_symbols.union(sdfg.constants_prop.keys()) # then recurse for nested, state in sdfg.all_nodes_recursive(): @@ -66,7 +67,7 @@ def __init__(self, sdfg: SDFG): # found a new nested sdfg: resolve symbols and constants result = nsdfg.free_symbols.union(nsdfg.constants_prop.keys()) - parent_constants = self._symbols_and_constants[nsdfg._parent_sdfg.cfg_id] + parent_constants = self._symbols_and_constants[nsdfg.parent_sdfg.cfg_id] result |= parent_constants # check for constant inputs diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 3c66400b3d..36c40e4e24 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2350,7 +2350,7 @@ def visit_For(self, node: ast.For): if symbolic.issymbolic(atom, self.sdfg.constants): astr = str(atom) # Check for undefined variables - if astr not in self.defined: + if astr not in self.defined and not ('.' in astr and astr in self.sdfg.arrays): raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom) # Add to global SDFG symbols if not a scalar if (astr not in self.sdfg.symbols and not (astr in self.variables or astr in self.sdfg.arrays)): @@ -3944,6 +3944,10 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no else: required_args.extend(symbols) required_args = dtypes.deduplicate(required_args) + gargs = set(a[0] for a in args) + for rarg in required_args: + if rarg not in gargs and rarg in self.sdfg.symbols: + args.append((rarg, rarg)) # Argument checks for aname, arg in kwargs: diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index b6aad43664..73ab0d154c 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -1294,8 +1294,14 @@ def _array_array_where(visitor: ProgramVisitor, raise ValueError('numpy.where is only supported for the case where x and y are given') cond_arr = sdfg.arrays[cond_operand] - left_arr = sdfg.arrays.get(left_operand, None) - right_arr = sdfg.arrays.get(right_operand, None) + try: + left_arr = sdfg.arrays[left_operand] + except KeyError: + left_arr = None + try: + right_arr = sdfg.arrays[right_operand] + except KeyError: + right_arr = None left_type = left_arr.dtype if left_arr else dtypes.dtype_to_typeclass(type(left_operand)) right_type = right_arr.dtype if right_arr else dtypes.dtype_to_typeclass(type(right_operand)) diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py index 970dfcef3a..7cda4017da 100644 --- a/dace/libraries/standard/nodes/reduce.py +++ b/dace/libraries/standard/nodes/reduce.py @@ -1562,7 +1562,7 @@ class Reduce(dace.sdfg.nodes.LibraryNode): # Properties axes = ListProperty(element_type=int, allow_none=True) wcr = LambdaProperty(default='lambda a, b: a') - identity = Property(allow_none=True) + identity = Property(allow_none=True, to_json=lambda x: str(x)) def __init__(self, name, diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 1bd343ecfb..841f11a7b1 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2509,7 +2509,7 @@ def apply_strict_transformations(self, validate=True, validate_all=False): warnings.warn('SDFG.apply_strict_transformations is deprecated, use SDFG.simplify instead.', DeprecationWarning) return self.simplify(validate, validate_all) - def simplify(self, validate=True, validate_all=False, verbose=False): + def simplify(self, validate=True, validate_all=False, verbose=False, options=None): """ Applies safe transformations (that will surely increase the performance) on the SDFG. For example, this fuses redundant states (safely) and removes redundant arrays. @@ -2517,7 +2517,8 @@ def simplify(self, validate=True, validate_all=False, verbose=False): :note: This is an in-place operation on the SDFG. """ from dace.transformation.passes.simplify import SimplifyPass - return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose).apply_pass(self, {}) + return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose, + pass_options=options).apply_pass(self, {}) def auto_optimize(self, device: dtypes.DeviceType, diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index a873bf0888..00ab59c1de 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -88,6 +88,16 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: def report(self, pass_retval: int) -> str: return f'Inlined {pass_retval} SDFGs.' + def set_opts(self, opts): + opt_keys = [ + 'multistate', + ] + + for k in opt_keys: + attr_k = InlineSDFGs.__name__ + '.' + k + if attr_k in opts: + setattr(self, k, opts[attr_k]) + @dataclass(unsafe_hash=True) @properties.make_properties