From 1856db88d9c373cca810d4d5984a8437aa951469 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 2 Dec 2021 23:10:08 +0100 Subject: [PATCH 01/14] Handle case where closure is evaluated before preprocessing --- dace/frontend/python/parser.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 926196a54c..216bdc79d9 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -168,6 +168,7 @@ def __init__(self, self.symbols = set(k for k, v in self.global_vars.items() if isinstance(v, symbolic.symbol)) self.closure_arg_mapping: Dict[str, Callable[[], Any]] = {} + self.resolver: pycommon.SDFGClosure = None # Add type annotations from decorator arguments (DEPRECATED) if self.dec_args: @@ -261,7 +262,10 @@ def __sdfg_closure__( if reevaluate is None: result = {k: v() for k, v in self.closure_arg_mapping.items()} - result.update({k: v[1] for k, v in self.resolver.callbacks.items()}) + if self.resolver is not None: + result.update( + {k: v[1] + for k, v in self.resolver.callbacks.items()}) return result else: return { From 55893894683e9e69652551bced697cfd82c09d19 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 3 Dec 2021 01:30:35 +0100 Subject: [PATCH 02/14] Replace remaining expansion expressions with cmake variables --- dace/codegen/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index 975490130c..27f17dd5f2 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -181,6 +181,7 @@ if(${NUM_ENV_VARS} GREATER 0) set(${KEY} ${VAL}) endforeach() endif() +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_PACKAGES "${DACE_ENV_PACKAGES}") string(REPLACE " " ";" DACE_ENV_PACKAGES "${DACE_ENV_PACKAGES}") foreach(PACKAGE_NAME ${DACE_ENV_PACKAGES}) find_package(${PACKAGE_NAME} REQUIRED) @@ -200,10 +201,14 @@ if(${NUM_ENV_VARS} GREATER 0) endforeach() endif() # Configure specified include directories, libraries, and flags +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_INCLUDES "${DACE_ENV_INCLUDES}") +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_LIBRARIES "${DACE_ENV_LIBRARIES}") string(REPLACE " " ";" DACE_ENV_INCLUDES "${DACE_ENV_INCLUDES}") string(REPLACE " " ";" DACE_ENV_LIBRARIES "${DACE_ENV_LIBRARIES}") include_directories(${DACE_ENV_INCLUDES}) set(DACE_LIBS ${DACE_LIBS} ${DACE_ENV_LIBRARIES}) +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_LINK_FLAGS "${DACE_ENV_LINK_FLAGS}") +string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_COMPILE_FLAGS "${DACE_ENV_COMPILE_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DACE_ENV_COMPILE_FLAGS}") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") From 9cab345699346c73bc29e4bae5d6eae12471256d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 3 Dec 2021 14:32:05 +0100 Subject: [PATCH 03/14] Handle case where closure SDFG could not be found --- dace/frontend/python/preprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 88fd3b3578..0d9626a484 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -849,6 +849,7 @@ def visit_Call(self, node: ast.Call): constant_args = self._eval_args(node) # Resolve nested closure as necessary + qualname = None try: qualname = next(k for k, v in self.closure.closure_sdfgs.items() if v is value) @@ -863,7 +864,8 @@ def visit_Call(self, node: ast.Call): raise except Exception as ex: # Parsing failed (anything can happen here) warnings.warn(f'Parsing SDFGConvertible {value} failed: {ex}') - del self.closure.closure_sdfgs[qualname] + if qualname in self.closure.closure_sdfgs: + del self.closure.closure_sdfgs[qualname] # Return old call AST instead node.func = node.func.oldnode.func From 57cf558d26ff8da7dc7dfc089c8eb8d04bcb54f3 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 3 Dec 2021 22:10:13 +0100 Subject: [PATCH 04/14] Emit correct type when using dynamic memlets as const pointers --- dace/codegen/targets/cpu.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 8a2a7651b9..dd7fad7f7a 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -1286,7 +1286,11 @@ def memlet_definition(self, else: # Variable number of reads: get a const reference that can # be read if necessary - memlet_type = '%s const' % memlet_type + memlet_type = 'const %s' % memlet_type + if is_pointer: + # This is done to make the reference constant, otherwise + # compilers error out with initial reference value. + memlet_type += ' const' result += "{} &{} = {};".format(memlet_type, local_name, expr) defined = (DefinedType.Scalar From 1fd55a7eaaa2c8bd4a617d7dc85b31ce27980d50 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 4 Dec 2021 00:05:20 +0100 Subject: [PATCH 05/14] Properly convert symbolic expressions to code upon replacement --- dace/sdfg/replace.py | 10 ++++++++-- tests/inlining_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index fbbc768218..ab91f01632 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -92,8 +92,12 @@ def replace_properties(node: Any, symrepl: Dict[symbolic.symbol, continue if lang is dtypes.Language.CPP: # Replace in C++ code + # Avoid import loop + from dace.codegen.targets.cpp import sym2cpp + # Use local variables and shadowing to replace - replacement = 'auto %s = %s;\n' % (name, new_name) + replacement = 'auto %s = %s;\n' % (name, + sym2cpp(new_name)) propval.code = replacement + newcode else: warnings.warn('Replacement of %s with %s was not made ' @@ -101,7 +105,9 @@ def replace_properties(node: Any, symrepl: Dict[symbolic.symbol, (name, new_name, lang)) elif propval.code is not None: for stmt in propval.code: - ASTFindReplace({name: new_name}).visit(stmt) + ASTFindReplace({ + name: symbolic.symstr(new_name) + }).visit(stmt) elif (isinstance(propclass, properties.DictProperty) and pname == 'symbol_mapping'): # Symbol mappings for nested SDFGs diff --git a/tests/inlining_test.py b/tests/inlining_test.py index c0c777324f..58f3216d83 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace.transformation.interstate import InlineSDFG import numpy as np import pytest @@ -167,6 +168,36 @@ def outerprog(A: dace.float64[20]): assert np.allclose(A, expected) +def test_inline_symexpr(): + nsdfg = dace.SDFG('inner') + nsdfg.add_array('a', [20], dace.float64) + nstate = nsdfg.add_state() + nstate.add_mapped_tasklet('doit', {'k': '0:20'}, {}, + '''if k < j: + o = 2.0''', {'o': dace.Memlet('a[k]', dynamic=True)}, + external_edges=True) + + sdfg = dace.SDFG('outer') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_symbol('i', dace.int32) + state = sdfg.add_state() + w = state.add_write('A') + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {}, {'a'}, + {'j': 'min(i, 10)'}) + state.add_edge(nsdfg_node, 'a', w, None, dace.Memlet('A')) + + # Verify that compilation works before inlining + sdfg.compile() + + sdfg.apply_transformations(InlineSDFG) + + # Compile and run + a = np.random.rand(20) + sdfg(A=a, i=15) + assert np.allclose(a[:10], 2.0) + assert not np.allclose(a[10:], 2.0) + + if __name__ == "__main__": test() # Skipped to to bug that cannot be reproduced @@ -174,3 +205,4 @@ def outerprog(A: dace.float64[20]): test_empty_memlets() test_multistate_inline() test_multistate_inline_samename() + test_inline_symexpr() From c7f5d64430660258fbe7b0850fe477032d71d0f8 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 14 Dec 2021 13:11:44 +0100 Subject: [PATCH 06/14] Pass keyword arguments to SDFGConvertible objects --- dace/frontend/python/newast.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index e6c04f5eb2..0a26ad4c6d 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3677,12 +3677,11 @@ def _parse_sdfg_call(self, funcname: str, all_args = required_args elif isinstance(func, SDFGConvertible) or self._has_sdfg(func): argnames, constant_args = func.__sdfg_signature__() - args = [(aname, self._parse_function_arg(arg)) + posargs = [(aname, self._parse_function_arg(arg)) for aname, arg in zip(argnames, node.args)] - args += [(arg.arg, self._parse_function_arg(arg.value)) + kwargs = [(arg.arg, self._parse_function_arg(arg.value)) for arg in node.keywords] required_args = argnames - fargs = (self._eval_arg(arg) for _, arg in args) # fcopy = copy.copy(func) fcopy = func @@ -3690,11 +3689,14 @@ def _parse_sdfg_call(self, funcname: str, fcopy.global_vars = {**self.globals, **func.global_vars} try: + fargs = tuple(self._eval_arg(arg) for _, arg in posargs) + fkwargs = {k: self._eval_arg(arg) for k, arg in kwargs} + if isinstance(fcopy, DaceProgram): fcopy.signature = copy.deepcopy(func.signature) - sdfg = fcopy.to_sdfg(*fargs, strict=self.strict, save=False) + sdfg = fcopy.to_sdfg(*fargs, **fkwargs, strict=self.strict, save=False) else: - sdfg = fcopy.__sdfg__(*fargs) + sdfg = fcopy.__sdfg__(*fargs, **fkwargs) except: # Parsing failure # If parsing fails in an auto-parsed context, exit silently if hasattr(node.func, 'oldnode'): @@ -3702,6 +3704,8 @@ def _parse_sdfg_call(self, funcname: str, else: raise + args = posargs + kwargs + funcname = sdfg.name all_args = required_args # Try to promote args of kind `sym = scalar` From dbe078fac2cfab485167fb203aa3ee0ceb7486c3 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 14 Dec 2021 13:23:19 +0100 Subject: [PATCH 07/14] Allow SDFGConvertible objects to use arguments at parse-time --- dace/frontend/python/newast.py | 35 +++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 0a26ad4c6d..ad41ec7248 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3666,10 +3666,11 @@ def _parse_sdfg_call(self, funcname: str, if isinstance(func, SDFG): sdfg = copy.deepcopy(func) funcname = sdfg.name - args = [(aname, self._parse_function_arg(arg)) + posargs = [(aname, self._parse_function_arg(arg)) for aname, arg in zip(sdfg.arg_names, node.args)] - args += [(arg.arg, self._parse_function_arg(arg.value)) + kwargs = [(arg.arg, self._parse_function_arg(arg.value)) for arg in node.keywords] + args = posargs + kwargs required_args = [ a for a in sdfg.arglist().keys() if a not in sdfg.symbols and not a.startswith('__return') @@ -3678,10 +3679,11 @@ def _parse_sdfg_call(self, funcname: str, elif isinstance(func, SDFGConvertible) or self._has_sdfg(func): argnames, constant_args = func.__sdfg_signature__() posargs = [(aname, self._parse_function_arg(arg)) - for aname, arg in zip(argnames, node.args)] + for aname, arg in zip(argnames, node.args)] kwargs = [(arg.arg, self._parse_function_arg(arg.value)) - for arg in node.keywords] + for arg in node.keywords] required_args = argnames + args = posargs + kwargs # fcopy = copy.copy(func) fcopy = func @@ -3694,9 +3696,18 @@ def _parse_sdfg_call(self, funcname: str, if isinstance(fcopy, DaceProgram): fcopy.signature = copy.deepcopy(func.signature) - sdfg = fcopy.to_sdfg(*fargs, **fkwargs, strict=self.strict, save=False) + sdfg = fcopy.to_sdfg(*fargs, + **fkwargs, + strict=self.strict, + save=False) else: sdfg = fcopy.__sdfg__(*fargs, **fkwargs) + + # Filter out parsed/omitted arguments + posargs = [(k, v) for k, v in posargs if k in required_args] + kwargs = [(k, v) for k, v in kwargs if k in required_args] + args = posargs + kwargs + except: # Parsing failure # If parsing fails in an auto-parsed context, exit silently if hasattr(node.func, 'oldnode'): @@ -3704,8 +3715,6 @@ def _parse_sdfg_call(self, funcname: str, else: raise - args = posargs + kwargs - funcname = sdfg.name all_args = required_args # Try to promote args of kind `sym = scalar` @@ -3770,15 +3779,15 @@ def _parse_sdfg_call(self, funcname: str, required_args = dtypes.deduplicate(required_args) # Argument checks - for arg in node.keywords: + for aname, arg in kwargs: # Skip explicit return values - if arg.arg.startswith('__return'): - required_args.append(arg.arg) + if aname.startswith('__return'): + required_args.append(aname) continue - if arg.arg not in required_args and arg.arg not in all_args: + if aname not in required_args and aname not in all_args: raise DaceSyntaxError( self, node, 'Invalid keyword argument "%s" in call to ' - '"%s"' % (arg.arg, funcname)) + '"%s"' % (aname, funcname)) if len(args) != len(required_args): raise DaceSyntaxError( self, node, 'Argument number mismatch in' @@ -4173,7 +4182,7 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): if isinstance(dtype, data.Data): n, arr = self.sdfg.add_temp_transient_like(dtype) elif isinstance(dtype, dtypes.typeclass): - n, arr = self.sdfg.add_temp_transient((1,), dtype) + n, arr = self.sdfg.add_temp_transient((1, ), dtype) else: n, arr = None, None else: From d9a8c89ef421f9e41f0e27adf7f7e22a05dbad74 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 15 Dec 2021 11:41:25 +0100 Subject: [PATCH 08/14] Add module and class to parsed SDFG names --- dace/frontend/python/parser.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 216bdc79d9..422e13b45b 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -145,7 +145,6 @@ def __init__(self, self.f = f self.dec_args = args self.dec_kwargs = kwargs - self.name = f.__name__ self.resolve_functions = constant_functions self.argnames = _get_argnames(f) if method: @@ -240,6 +239,16 @@ def methodobj(self) -> Any: def methodobj(self, new_obj: Any): self._methodobj = new_obj + @property + def name(self) -> str: + """ Returns a unique name for this program. """ + result = '' + if self.f.__module__ is not None and self.f.__module__ != '__main__': + result += self.f.__module__.replace('.', '_') + '_' + if self._methodobj is not None: + result += type(self._methodobj).__name__ + '_' + return result + self.f.__name__ + def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: return self.argnames, self.constant_args From 494d4e0a0fa886e5369e79aa91e485af79ca86d0 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 16 Dec 2021 00:48:51 +0100 Subject: [PATCH 09/14] Add use_cache flag to `DaceProgram.to_sdfg` --- dace/frontend/python/parser.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 422e13b45b..9b68468057 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -206,14 +206,45 @@ def to_sdfg(self, strict=None, save=False, validate=False, + use_cache=False, **kwargs) -> SDFG: """ Parses the DaCe function into an SDFG. """ - return self._parse(args, + if use_cache: + # Update global variables with current closure + self.global_vars = _get_locals_and_globals(self.f) + + # Move "self" from an argument into the closure + if self.methodobj is not None: + self.global_vars[self.objname] = self.methodobj + + argtypes, arg_mapping, constant_args = self._get_type_annotations( + args, kwargs) + + # Add constant arguments to globals for caching + self.global_vars.update(constant_args) + + # Check cache for already-parsed SDFG + cachekey = self._cache.make_key(argtypes, self.closure_array_keys, + self.closure_constant_keys, + constant_args) + + if self._cache.has(cachekey): + entry = self._cache.get(cachekey) + return entry.sdfg + + sdfg = self._parse(args, kwargs, strict=strict, save=save, validate=validate) + if use_cache: + # Add to cache + self._cache.add(cachekey, sdfg, None) + + return sdfg + + def __sdfg__(self, *args, **kwargs) -> SDFG: return self._parse(args, kwargs, From 762205b9ea1e9537bb4c65ed87835204adfda870 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 16 Dec 2021 01:05:08 +0100 Subject: [PATCH 10/14] Keep callback in closure and use it in case of a failed parse --- dace/frontend/python/newast.py | 8 ++++++ dace/frontend/python/preprocessing.py | 5 +++- .../callback_autodetect_test.py | 26 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index f96b2367ea..693fce336e 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4118,11 +4118,19 @@ def create_callback(self, node: ast.Call, create_graph=True): parent_is_toplevel = True parent: ast.AST = None for anode in ast.walk(self.program_ast): + if parent is not None: + break for child in ast.iter_child_nodes(anode): if child is node: parent = anode parent_is_toplevel = getattr(anode, 'toplevel', False) break + if hasattr(child, 'func') and hasattr(child.func, 'oldnode'): + # Check if the AST node is part of a failed parse + if child.func.oldnode is node: + parent = anode + parent_is_toplevel = getattr(anode, 'toplevel', False) + break if parent is None: raise DaceSyntaxError( self, node, diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 0d9626a484..bf5eb3c01e 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -615,7 +615,8 @@ def global_value_to_node(self, res = self.global_value_to_node(parsed, parent_node, qualname, recurse, detect_callables) - del self.closure.callbacks[cbname] + # Keep callback in callbacks in case of parsing failure + # del self.closure.callbacks[cbname] return res except Exception: # Parsing failed (almost any exception can occur) return newnode @@ -838,6 +839,8 @@ def visit_Call(self, node: ast.Call): if not isinstance(node.func, (ast.Num, ast.Constant)): self.seen_calls.add(astutils.rname(node.func)) return self.generic_visit(node) + if hasattr(node.func, 'oldnode'): + self.seen_calls.add(astutils.rname(node.func.oldnode.func)) if isinstance(node.func, ast.Num): value = node.func.n else: diff --git a/tests/python_frontend/callback_autodetect_test.py b/tests/python_frontend/callback_autodetect_test.py index be797d0a96..2ac3f1c58e 100644 --- a/tests/python_frontend/callback_autodetect_test.py +++ b/tests/python_frontend/callback_autodetect_test.py @@ -357,6 +357,31 @@ def timeprog(A: dace.float64[20]): assert np.all(B > A) and np.all(A > now) +def test_object_with_nested_callback(): + c = np.random.rand(20) + + @dace_inhibitor + def call_another_function(a, b): + nonlocal c + c[:] = a + b + + class MyObject: + def __call__(self, a, b): + c = dict(a=a, b=b) + call_another_function(**c) + + obj = MyObject() + + @dace.program + def callobj(a, b): + obj(a, b) + + a = np.random.rand(20) + b = np.random.rand(20) + callobj(a, b) + assert np.allclose(c, a + b) + + if __name__ == '__main__': test_automatic_callback() test_automatic_callback_2() @@ -372,3 +397,4 @@ def timeprog(A: dace.float64[20]): test_callback_samename() # test_gpu_callback() test_bad_closure() + test_object_with_nested_callback() From 2359b89e60f21e8f8e3ad47c67577bd8e839fb2a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 16 Dec 2021 01:45:46 +0100 Subject: [PATCH 11/14] Minor fix --- dace/frontend/python/preprocessing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index bf5eb3c01e..9b6a1905c9 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -840,7 +840,10 @@ def visit_Call(self, node: ast.Call): self.seen_calls.add(astutils.rname(node.func)) return self.generic_visit(node) if hasattr(node.func, 'oldnode'): - self.seen_calls.add(astutils.rname(node.func.oldnode.func)) + if isinstance(node.func.oldnode, ast.Call): + self.seen_calls.add(astutils.rname(node.func.oldnode.func)) + else: + self.seen_calls.add(astutils.rname(node.func.oldnode)) if isinstance(node.func, ast.Num): value = node.func.n else: From 447d7d2f659fb64934c6f8d0f4fe625af36e7068 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 16 Dec 2021 09:28:46 +0100 Subject: [PATCH 12/14] StorageType inference does not impact non-transients --- dace/sdfg/infer_types.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index 0804d0d0dd..7c2db88714 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -25,8 +25,8 @@ def infer_connector_types(sdfg: SDFG): continue scalar = (e.data.subset and e.data.subset.num_elements() == 1) if e.data.data is not None: - allocated_as_scalar = (sdfg.arrays[e.data.data].storage is - not dtypes.StorageType.GPU_Global) + allocated_as_scalar = (sdfg.arrays[e.data.data].storage + is not dtypes.StorageType.GPU_Global) else: allocated_as_scalar = True @@ -64,11 +64,11 @@ def infer_connector_types(sdfg: SDFG): and (not e.data.dynamic or (e.data.dynamic and e.data.wcr is not None))) if e.data.data is not None: - allocated_as_scalar = (sdfg.arrays[e.data.data].storage is - not dtypes.StorageType.GPU_Global) + allocated_as_scalar = (sdfg.arrays[e.data.data].storage + is not dtypes.StorageType.GPU_Global) else: allocated_as_scalar = True - + if node.out_connectors[cname].type is None: # If nested SDFG, try to use internal array type if isinstance(node, nodes.NestedSDFG): @@ -105,7 +105,7 @@ def infer_connector_types(sdfg: SDFG): def set_default_schedule_and_storage_types( - sdfg: SDFG, toplevel_schedule: dtypes.ScheduleType): + sdfg: SDFG, toplevel_schedule: dtypes.ScheduleType): """ Sets default storage and schedule types throughout SDFG in-place. Replaces `ScheduleType.Default` and `StorageType.Default` @@ -244,7 +244,8 @@ def _set_default_storage_types(sdfg: SDFG, # Take care of remaining arrays/scalars, e.g., code->code edges for desc in sdfg.arrays.values(): - if desc.storage is dtypes.StorageType.Default: + if ((desc.transient or sdfg.parent_sdfg is None) + and desc.storage is dtypes.StorageType.Default): desc.storage = dtypes.StorageType.Register for state in sdfg.nodes(): From 3e2a002d9c2e1a6fb742ba3a1a7e8ac3483f175d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 16 Dec 2021 09:37:37 +0100 Subject: [PATCH 13/14] Allow minor updates to Python 3.10 to be installed --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b8d5cf3a21..9bde2b17ad 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", ], - python_requires='>=3.6, <=3.10', + python_requires='>=3.6, <3.11', packages=find_packages( exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), package_data={ From ee1582338a0616234d325cfe4d0396fc413eda62 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 16 Dec 2021 15:02:07 +0100 Subject: [PATCH 14/14] Python callbacks: Fix multiple uses of same parameter --- dace/frontend/python/newast.py | 4 +-- .../callback_autodetect_test.py | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 693fce336e..3dd03d16b1 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -4256,12 +4256,12 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): self.sdfg.arrays[cname].dtype) # Setup arguments in graph - for arg in args: + for arg in dtypes.deduplicate(args): r = self.last_state.add_read(arg) self.last_state.add_edge(r, None, tasklet, f'__in_{arg}', Memlet(arg)) - for arg in outargs: + for arg in dtypes.deduplicate(outargs): w = self.last_state.add_write(arg) self.last_state.add_edge(tasklet, f'__out_{arg}', w, None, Memlet(arg)) diff --git a/tests/python_frontend/callback_autodetect_test.py b/tests/python_frontend/callback_autodetect_test.py index 2ac3f1c58e..02a7c4c93a 100644 --- a/tests/python_frontend/callback_autodetect_test.py +++ b/tests/python_frontend/callback_autodetect_test.py @@ -382,6 +382,36 @@ def callobj(a, b): assert np.allclose(c, a + b) +def test_two_parameters_same_name(): + @dace_inhibitor + def add(a, b): + return a + b + + @dace.program + def calladd(A: dace.float64[20], B: dace.float64[20]): + B[:] = add(A, A) + + a = np.random.rand(20) + b = np.random.rand(20) + calladd(a, b) + assert np.allclose(b, a + a) + + +def test_inout_same_name(): + @dace_inhibitor + def add(a, b): + return a + b + + @dace.program + def calladd(A: dace.float64[20]): + A[:] = add(A, A) + + a = np.random.rand(20) + expected = a + a + calladd(a) + assert np.allclose(expected, a) + + if __name__ == '__main__': test_automatic_callback() test_automatic_callback_2() @@ -398,3 +428,5 @@ def callobj(a, b): # test_gpu_callback() test_bad_closure() test_object_with_nested_callback() + test_two_parameters_same_name() + test_inout_same_name()