diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index 975490130c..27f17dd5f2 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -181,6 +181,7 @@ if(${NUM_ENV_VARS} GREATER 0) set(${KEY} ${VAL}) endforeach() endif() +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_PACKAGES "${DACE_ENV_PACKAGES}") string(REPLACE " " ";" DACE_ENV_PACKAGES "${DACE_ENV_PACKAGES}") foreach(PACKAGE_NAME ${DACE_ENV_PACKAGES}) find_package(${PACKAGE_NAME} REQUIRED) @@ -200,10 +201,14 @@ if(${NUM_ENV_VARS} GREATER 0) endforeach() endif() # Configure specified include directories, libraries, and flags +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_INCLUDES "${DACE_ENV_INCLUDES}") +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_LIBRARIES "${DACE_ENV_LIBRARIES}") string(REPLACE " " ";" DACE_ENV_INCLUDES "${DACE_ENV_INCLUDES}") string(REPLACE " " ";" DACE_ENV_LIBRARIES "${DACE_ENV_LIBRARIES}") include_directories(${DACE_ENV_INCLUDES}) set(DACE_LIBS ${DACE_LIBS} ${DACE_ENV_LIBRARIES}) +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_LINK_FLAGS "${DACE_ENV_LINK_FLAGS}") +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_COMPILE_FLAGS "${DACE_ENV_COMPILE_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DACE_ENV_COMPILE_FLAGS}") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 8a2a7651b9..dd7fad7f7a 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -1286,7 +1286,11 @@ def memlet_definition(self, else: # Variable number of reads: get a const reference that can # be read if necessary - memlet_type = '%s const' % memlet_type + memlet_type = 'const %s' % memlet_type + if is_pointer: + # This is done to make the reference constant, otherwise + # compilers error out with initial reference value. + memlet_type += ' const' result += "{} &{} = {};".format(memlet_type, local_name, expr) defined = (DefinedType.Scalar diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index d3f9fad0d0..3dd03d16b1 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3657,10 +3657,11 @@ def _parse_sdfg_call(self, funcname: str, if isinstance(func, SDFG): sdfg = copy.deepcopy(func) funcname = sdfg.name - args = [(aname, self._parse_function_arg(arg)) + posargs = [(aname, self._parse_function_arg(arg)) for aname, arg in zip(sdfg.arg_names, node.args)] - args += [(arg.arg, self._parse_function_arg(arg.value)) + kwargs = [(arg.arg, self._parse_function_arg(arg.value)) for arg in node.keywords] + args = posargs + kwargs required_args = [ a for a in sdfg.arglist().keys() if a not in sdfg.symbols and not a.startswith('__return') @@ -3668,12 +3669,12 @@ def _parse_sdfg_call(self, funcname: str, all_args = required_args elif isinstance(func, SDFGConvertible) or self._has_sdfg(func): argnames, constant_args = func.__sdfg_signature__() - args = [(aname, self._parse_function_arg(arg)) - for aname, arg in zip(argnames, node.args)] - args += [(arg.arg, self._parse_function_arg(arg.value)) - for arg in node.keywords] + posargs = [(aname, self._parse_function_arg(arg)) + for aname, arg in zip(argnames, node.args)] + kwargs = [(arg.arg, self._parse_function_arg(arg.value)) + for arg in node.keywords] required_args = argnames - fargs = (self._eval_arg(arg) for _, arg in args) + args = posargs + kwargs # fcopy = copy.copy(func) fcopy = func @@ -3681,11 +3682,23 @@ def _parse_sdfg_call(self, funcname: str, fcopy.global_vars = {**self.globals, **func.global_vars} try: + fargs = tuple(self._eval_arg(arg) for _, arg in posargs) + fkwargs = {k: self._eval_arg(arg) for k, arg in kwargs} + if isinstance(fcopy, DaceProgram): fcopy.signature = copy.deepcopy(func.signature) - sdfg = fcopy.to_sdfg(*fargs, strict=self.strict, save=False) + sdfg = fcopy.to_sdfg(*fargs, + **fkwargs, + strict=self.strict, + save=False) else: - sdfg = fcopy.__sdfg__(*fargs) + sdfg = fcopy.__sdfg__(*fargs, **fkwargs) + + # Filter out parsed/omitted arguments + posargs = [(k, v) for k, v in posargs if k in required_args] + kwargs = [(k, v) for k, v in kwargs if k in required_args] + args = posargs + kwargs + except: # Parsing failure # If parsing fails in an auto-parsed context, exit silently if hasattr(node.func, 'oldnode'): @@ -3757,15 +3770,15 @@ def _parse_sdfg_call(self, funcname: str, required_args = dtypes.deduplicate(required_args) # Argument checks - for arg in node.keywords: + for aname, arg in kwargs: # Skip explicit return values - if arg.arg.startswith('__return'): - required_args.append(arg.arg) + if aname.startswith('__return'): + required_args.append(aname) continue - if arg.arg not in required_args and arg.arg not in all_args: + if aname not in required_args and aname not in all_args: raise DaceSyntaxError( self, node, 'Invalid keyword argument "%s" in call to ' - '"%s"' % (arg.arg, funcname)) + '"%s"' % (aname, funcname)) if len(args) != len(required_args): raise DaceSyntaxError( self, node, 'Argument number mismatch in' @@ -4105,11 +4118,19 @@ def create_callback(self, node: ast.Call, create_graph=True): parent_is_toplevel = True parent: ast.AST = None for anode in ast.walk(self.program_ast): + if parent is not None: + break for child in ast.iter_child_nodes(anode): if child is node: parent = anode parent_is_toplevel = getattr(anode, 'toplevel', False) break + if hasattr(child, 'func') and hasattr(child.func, 'oldnode'): + # Check if the AST node is part of a failed parse + if child.func.oldnode is node: + parent = anode + parent_is_toplevel = getattr(anode, 'toplevel', False) + break if parent is None: raise DaceSyntaxError( self, node, @@ -4235,12 +4256,12 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): self.sdfg.arrays[cname].dtype) # Setup arguments in graph - for arg in args: + for arg in dtypes.deduplicate(args): r = self.last_state.add_read(arg) self.last_state.add_edge(r, None, tasklet, f'__in_{arg}', Memlet(arg)) - for arg in outargs: + for arg in dtypes.deduplicate(outargs): w = self.last_state.add_write(arg) self.last_state.add_edge(tasklet, f'__out_{arg}', w, None, Memlet(arg)) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 926196a54c..9b68468057 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -145,7 +145,6 @@ def __init__(self, self.f = f self.dec_args = args self.dec_kwargs = kwargs - self.name = f.__name__ self.resolve_functions = constant_functions self.argnames = _get_argnames(f) if method: @@ -168,6 +167,7 @@ def __init__(self, self.symbols = set(k for k, v in self.global_vars.items() if isinstance(v, symbolic.symbol)) self.closure_arg_mapping: Dict[str, Callable[[], Any]] = {} + self.resolver: pycommon.SDFGClosure = None # Add type annotations from decorator arguments (DEPRECATED) if self.dec_args: @@ -206,14 +206,45 @@ def to_sdfg(self, strict=None, save=False, validate=False, + use_cache=False, **kwargs) -> SDFG: """ Parses the DaCe function into an SDFG. """ - return self._parse(args, + if use_cache: + # Update global variables with current closure + self.global_vars = _get_locals_and_globals(self.f) + + # Move "self" from an argument into the closure + if self.methodobj is not None: + self.global_vars[self.objname] = self.methodobj + + argtypes, arg_mapping, constant_args = self._get_type_annotations( + args, kwargs) + + # Add constant arguments to globals for caching + self.global_vars.update(constant_args) + + # Check cache for already-parsed SDFG + cachekey = self._cache.make_key(argtypes, self.closure_array_keys, + self.closure_constant_keys, + constant_args) + + if self._cache.has(cachekey): + entry = self._cache.get(cachekey) + return entry.sdfg + + sdfg = self._parse(args, kwargs, strict=strict, save=save, validate=validate) + if use_cache: + # Add to cache + self._cache.add(cachekey, sdfg, None) + + return sdfg + + def __sdfg__(self, *args, **kwargs) -> SDFG: return self._parse(args, kwargs, @@ -239,6 +270,16 @@ def methodobj(self) -> Any: def methodobj(self, new_obj: Any): self._methodobj = new_obj + @property + def name(self) -> str: + """ Returns a unique name for this program. """ + result = '' + if self.f.__module__ is not None and self.f.__module__ != '__main__': + result += self.f.__module__.replace('.', '_') + '_' + if self._methodobj is not None: + result += type(self._methodobj).__name__ + '_' + return result + self.f.__name__ + def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: return self.argnames, self.constant_args @@ -261,7 +302,10 @@ def __sdfg_closure__( if reevaluate is None: result = {k: v() for k, v in self.closure_arg_mapping.items()} - result.update({k: v[1] for k, v in self.resolver.callbacks.items()}) + if self.resolver is not None: + result.update( + {k: v[1] + for k, v in self.resolver.callbacks.items()}) return result else: return { diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 88fd3b3578..9b6a1905c9 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -615,7 +615,8 @@ def global_value_to_node(self, res = self.global_value_to_node(parsed, parent_node, qualname, recurse, detect_callables) - del self.closure.callbacks[cbname] + # Keep callback in callbacks in case of parsing failure + # del self.closure.callbacks[cbname] return res except Exception: # Parsing failed (almost any exception can occur) return newnode @@ -838,6 +839,11 @@ def visit_Call(self, node: ast.Call): if not isinstance(node.func, (ast.Num, ast.Constant)): self.seen_calls.add(astutils.rname(node.func)) return self.generic_visit(node) + if hasattr(node.func, 'oldnode'): + if isinstance(node.func.oldnode, ast.Call): + self.seen_calls.add(astutils.rname(node.func.oldnode.func)) + else: + self.seen_calls.add(astutils.rname(node.func.oldnode)) if isinstance(node.func, ast.Num): value = node.func.n else: @@ -849,6 +855,7 @@ def visit_Call(self, node: ast.Call): constant_args = self._eval_args(node) # Resolve nested closure as necessary + qualname = None try: qualname = next(k for k, v in self.closure.closure_sdfgs.items() if v is value) @@ -863,7 +870,8 @@ def visit_Call(self, node: ast.Call): raise except Exception as ex: # Parsing failed (anything can happen here) warnings.warn(f'Parsing SDFGConvertible {value} failed: {ex}') - del self.closure.closure_sdfgs[qualname] + if qualname in self.closure.closure_sdfgs: + del self.closure.closure_sdfgs[qualname] # Return old call AST instead node.func = node.func.oldnode.func diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index 0804d0d0dd..7c2db88714 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -25,8 +25,8 @@ def infer_connector_types(sdfg: SDFG): continue scalar = (e.data.subset and e.data.subset.num_elements() == 1) if e.data.data is not None: - allocated_as_scalar = (sdfg.arrays[e.data.data].storage is - not dtypes.StorageType.GPU_Global) + allocated_as_scalar = (sdfg.arrays[e.data.data].storage + is not dtypes.StorageType.GPU_Global) else: allocated_as_scalar = True @@ -64,11 +64,11 @@ def infer_connector_types(sdfg: SDFG): and (not e.data.dynamic or (e.data.dynamic and e.data.wcr is not None))) if e.data.data is not None: - allocated_as_scalar = (sdfg.arrays[e.data.data].storage is - not dtypes.StorageType.GPU_Global) + allocated_as_scalar = (sdfg.arrays[e.data.data].storage + is not dtypes.StorageType.GPU_Global) else: allocated_as_scalar = True - + if node.out_connectors[cname].type is None: # If nested SDFG, try to use internal array type if isinstance(node, nodes.NestedSDFG): @@ -105,7 +105,7 @@ def infer_connector_types(sdfg: SDFG): def set_default_schedule_and_storage_types( - sdfg: SDFG, toplevel_schedule: dtypes.ScheduleType): + sdfg: SDFG, toplevel_schedule: dtypes.ScheduleType): """ Sets default storage and schedule types throughout SDFG in-place. Replaces `ScheduleType.Default` and `StorageType.Default` @@ -244,7 +244,8 @@ def _set_default_storage_types(sdfg: SDFG, # Take care of remaining arrays/scalars, e.g., code->code edges for desc in sdfg.arrays.values(): - if desc.storage is dtypes.StorageType.Default: + if ((desc.transient or sdfg.parent_sdfg is None) + and desc.storage is dtypes.StorageType.Default): desc.storage = dtypes.StorageType.Register for state in sdfg.nodes(): diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index fbbc768218..ab91f01632 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -92,8 +92,12 @@ def replace_properties(node: Any, symrepl: Dict[symbolic.symbol, continue if lang is dtypes.Language.CPP: # Replace in C++ code + # Avoid import loop + from dace.codegen.targets.cpp import sym2cpp + # Use local variables and shadowing to replace - replacement = 'auto %s = %s;\n' % (name, new_name) + replacement = 'auto %s = %s;\n' % (name, + sym2cpp(new_name)) propval.code = replacement + newcode else: warnings.warn('Replacement of %s with %s was not made ' @@ -101,7 +105,9 @@ def replace_properties(node: Any, symrepl: Dict[symbolic.symbol, (name, new_name, lang)) elif propval.code is not None: for stmt in propval.code: - ASTFindReplace({name: new_name}).visit(stmt) + ASTFindReplace({ + name: symbolic.symstr(new_name) + }).visit(stmt) elif (isinstance(propclass, properties.DictProperty) and pname == 'symbol_mapping'): # Symbol mappings for nested SDFGs diff --git a/tests/inlining_test.py b/tests/inlining_test.py index c0c777324f..58f3216d83 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace.transformation.interstate import InlineSDFG import numpy as np import pytest @@ -167,6 +168,36 @@ def outerprog(A: dace.float64[20]): assert np.allclose(A, expected) +def test_inline_symexpr(): + nsdfg = dace.SDFG('inner') + nsdfg.add_array('a', [20], dace.float64) + nstate = nsdfg.add_state() + nstate.add_mapped_tasklet('doit', {'k': '0:20'}, {}, + '''if k < j: + o = 2.0''', {'o': dace.Memlet('a[k]', dynamic=True)}, + external_edges=True) + + sdfg = dace.SDFG('outer') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_symbol('i', dace.int32) + state = sdfg.add_state() + w = state.add_write('A') + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {'a'}, + {'j': 'min(i, 10)'}) + state.add_edge(nsdfg_node, 'a', w, None, dace.Memlet('A')) + + # Verify that compilation works before inlining + sdfg.compile() + + sdfg.apply_transformations(InlineSDFG) + + # Compile and run + a = np.random.rand(20) + sdfg(A=a, i=15) + assert np.allclose(a[:10], 2.0) + assert not np.allclose(a[10:], 2.0) + + if __name__ == "__main__": test() # Skipped to to bug that cannot be reproduced @@ -174,3 +205,4 @@ def outerprog(A: dace.float64[20]): test_empty_memlets() test_multistate_inline() test_multistate_inline_samename() + test_inline_symexpr() diff --git a/tests/python_frontend/callback_autodetect_test.py b/tests/python_frontend/callback_autodetect_test.py index be797d0a96..02a7c4c93a 100644 --- a/tests/python_frontend/callback_autodetect_test.py +++ b/tests/python_frontend/callback_autodetect_test.py @@ -357,6 +357,61 @@ def timeprog(A: dace.float64[20]): assert np.all(B > A) and np.all(A > now) +def test_object_with_nested_callback(): + c = np.random.rand(20) + + @dace_inhibitor + def call_another_function(a, b): + nonlocal c + c[:] = a + b + + class MyObject: + def __call__(self, a, b): + c = dict(a=a, b=b) + call_another_function(**c) + + obj = MyObject() + + @dace.program + def callobj(a, b): + obj(a, b) + + a = np.random.rand(20) + b = np.random.rand(20) + callobj(a, b) + assert np.allclose(c, a + b) + + +def test_two_parameters_same_name(): + @dace_inhibitor + def add(a, b): + return a + b + + @dace.program + def calladd(A: dace.float64[20], B: dace.float64[20]): + B[:] = add(A, A) + + a = np.random.rand(20) + b = np.random.rand(20) + calladd(a, b) + assert np.allclose(b, a + a) + + +def test_inout_same_name(): + @dace_inhibitor + def add(a, b): + return a + b + + @dace.program + def calladd(A: dace.float64[20]): + A[:] = add(A, A) + + a = np.random.rand(20) + expected = a + a + calladd(a) + assert np.allclose(expected, a) + + if __name__ == '__main__': test_automatic_callback() test_automatic_callback_2() @@ -372,3 +427,6 @@ def timeprog(A: dace.float64[20]): test_callback_samename() # test_gpu_callback() test_bad_closure() + test_object_with_nested_callback() + test_two_parameters_same_name() + test_inout_same_name()