Skip to content

Commit

Permalink
Merge pull request #894: Improve preprocessing and fix callback issues in Python frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Dec 17, 2021
2 parents 49e3f55 + e9a2145 commit f2f697c
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 31 deletions.
5 changes: 5 additions & 0 deletions dace/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ if(${NUM_ENV_VARS} GREATER 0)
set(${KEY} ${VAL})
endforeach()
endif()
string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_PACKAGES "${DACE_ENV_PACKAGES}")
string(REPLACE " " ";" DACE_ENV_PACKAGES "${DACE_ENV_PACKAGES}")
foreach(PACKAGE_NAME ${DACE_ENV_PACKAGES})
find_package(${PACKAGE_NAME} REQUIRED)
Expand All @@ -200,10 +201,14 @@ if(${NUM_ENV_VARS} GREATER 0)
endforeach()
endif()
# Configure specified include directories, libraries, and flags
string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_INCLUDES "${DACE_ENV_INCLUDES}")
string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_LIBRARIES "${DACE_ENV_LIBRARIES}")
string(REPLACE " " ";" DACE_ENV_INCLUDES "${DACE_ENV_INCLUDES}")
string(REPLACE " " ";" DACE_ENV_LIBRARIES "${DACE_ENV_LIBRARIES}")
include_directories(${DACE_ENV_INCLUDES})
set(DACE_LIBS ${DACE_LIBS} ${DACE_ENV_LIBRARIES})
string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_LINK_FLAGS "${DACE_ENV_LINK_FLAGS}")
string(REPLACE "_DACE_CMAKE_EXPAND" "$" DACE_ENV_COMPILE_FLAGS "${DACE_ENV_COMPILE_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DACE_ENV_COMPILE_FLAGS}")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}")
Expand Down
6 changes: 5 additions & 1 deletion dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,11 @@ def memlet_definition(self,
else:
# Variable number of reads: get a const reference that can
# be read if necessary
memlet_type = '%s const' % memlet_type
memlet_type = 'const %s' % memlet_type
if is_pointer:
# This is done to make the reference constant, otherwise
# compilers error out with initial reference value.
memlet_type += ' const'
result += "{} &{} = {};".format(memlet_type, local_name,
expr)
defined = (DefinedType.Scalar
Expand Down
53 changes: 37 additions & 16 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3657,35 +3657,48 @@ def _parse_sdfg_call(self, funcname: str,
if isinstance(func, SDFG):
sdfg = copy.deepcopy(func)
funcname = sdfg.name
args = [(aname, self._parse_function_arg(arg))
posargs = [(aname, self._parse_function_arg(arg))
for aname, arg in zip(sdfg.arg_names, node.args)]
args += [(arg.arg, self._parse_function_arg(arg.value))
kwargs = [(arg.arg, self._parse_function_arg(arg.value))
for arg in node.keywords]
args = posargs + kwargs
required_args = [
a for a in sdfg.arglist().keys()
if a not in sdfg.symbols and not a.startswith('__return')
]
all_args = required_args
elif isinstance(func, SDFGConvertible) or self._has_sdfg(func):
argnames, constant_args = func.__sdfg_signature__()
args = [(aname, self._parse_function_arg(arg))
for aname, arg in zip(argnames, node.args)]
args += [(arg.arg, self._parse_function_arg(arg.value))
for arg in node.keywords]
posargs = [(aname, self._parse_function_arg(arg))
for aname, arg in zip(argnames, node.args)]
kwargs = [(arg.arg, self._parse_function_arg(arg.value))
for arg in node.keywords]
required_args = argnames
fargs = (self._eval_arg(arg) for _, arg in args)
args = posargs + kwargs

# fcopy = copy.copy(func)
fcopy = func
if hasattr(fcopy, 'global_vars'):
fcopy.global_vars = {**self.globals, **func.global_vars}

try:
fargs = tuple(self._eval_arg(arg) for _, arg in posargs)
fkwargs = {k: self._eval_arg(arg) for k, arg in kwargs}

if isinstance(fcopy, DaceProgram):
fcopy.signature = copy.deepcopy(func.signature)
sdfg = fcopy.to_sdfg(*fargs, strict=self.strict, save=False)
sdfg = fcopy.to_sdfg(*fargs,
**fkwargs,
strict=self.strict,
save=False)
else:
sdfg = fcopy.__sdfg__(*fargs)
sdfg = fcopy.__sdfg__(*fargs, **fkwargs)

# Filter out parsed/omitted arguments
posargs = [(k, v) for k, v in posargs if k in required_args]
kwargs = [(k, v) for k, v in kwargs if k in required_args]
args = posargs + kwargs

except: # Parsing failure
# If parsing fails in an auto-parsed context, exit silently
if hasattr(node.func, 'oldnode'):
Expand Down Expand Up @@ -3757,15 +3770,15 @@ def _parse_sdfg_call(self, funcname: str,
required_args = dtypes.deduplicate(required_args)

# Argument checks
for arg in node.keywords:
for aname, arg in kwargs:
# Skip explicit return values
if arg.arg.startswith('__return'):
required_args.append(arg.arg)
if aname.startswith('__return'):
required_args.append(aname)
continue
if arg.arg not in required_args and arg.arg not in all_args:
if aname not in required_args and aname not in all_args:
raise DaceSyntaxError(
self, node, 'Invalid keyword argument "%s" in call to '
'"%s"' % (arg.arg, funcname))
'"%s"' % (aname, funcname))
if len(args) != len(required_args):
raise DaceSyntaxError(
self, node, 'Argument number mismatch in'
Expand Down Expand Up @@ -4105,11 +4118,19 @@ def create_callback(self, node: ast.Call, create_graph=True):
parent_is_toplevel = True
parent: ast.AST = None
for anode in ast.walk(self.program_ast):
if parent is not None:
break
for child in ast.iter_child_nodes(anode):
if child is node:
parent = anode
parent_is_toplevel = getattr(anode, 'toplevel', False)
break
if hasattr(child, 'func') and hasattr(child.func, 'oldnode'):
# Check if the AST node is part of a failed parse
if child.func.oldnode is node:
parent = anode
parent_is_toplevel = getattr(anode, 'toplevel', False)
break
if parent is None:
raise DaceSyntaxError(
self, node,
Expand Down Expand Up @@ -4235,12 +4256,12 @@ def parse_target(t: Union[ast.Name, ast.Subscript]):
self.sdfg.arrays[cname].dtype)

# Setup arguments in graph
for arg in args:
for arg in dtypes.deduplicate(args):
r = self.last_state.add_read(arg)
self.last_state.add_edge(r, None, tasklet, f'__in_{arg}',
Memlet(arg))

for arg in outargs:
for arg in dtypes.deduplicate(outargs):
w = self.last_state.add_write(arg)
self.last_state.add_edge(tasklet, f'__out_{arg}', w, None,
Memlet(arg))
Expand Down
50 changes: 47 additions & 3 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def __init__(self,
self.f = f
self.dec_args = args
self.dec_kwargs = kwargs
self.name = f.__name__
self.resolve_functions = constant_functions
self.argnames = _get_argnames(f)
if method:
Expand All @@ -168,6 +167,7 @@ def __init__(self,
self.symbols = set(k for k, v in self.global_vars.items()
if isinstance(v, symbolic.symbol))
self.closure_arg_mapping: Dict[str, Callable[[], Any]] = {}
self.resolver: pycommon.SDFGClosure = None

# Add type annotations from decorator arguments (DEPRECATED)
if self.dec_args:
Expand Down Expand Up @@ -206,14 +206,45 @@ def to_sdfg(self,
strict=None,
save=False,
validate=False,
use_cache=False,
**kwargs) -> SDFG:
""" Parses the DaCe function into an SDFG. """
return self._parse(args,
if use_cache:
# Update global variables with current closure
self.global_vars = _get_locals_and_globals(self.f)

# Move "self" from an argument into the closure
if self.methodobj is not None:
self.global_vars[self.objname] = self.methodobj

argtypes, arg_mapping, constant_args = self._get_type_annotations(
args, kwargs)

# Add constant arguments to globals for caching
self.global_vars.update(constant_args)

# Check cache for already-parsed SDFG
cachekey = self._cache.make_key(argtypes, self.closure_array_keys,
self.closure_constant_keys,
constant_args)

if self._cache.has(cachekey):
entry = self._cache.get(cachekey)
return entry.sdfg

sdfg = self._parse(args,
kwargs,
strict=strict,
save=save,
validate=validate)

if use_cache:
# Add to cache
self._cache.add(cachekey, sdfg, None)

return sdfg


def __sdfg__(self, *args, **kwargs) -> SDFG:
return self._parse(args,
kwargs,
Expand All @@ -239,6 +270,16 @@ def methodobj(self) -> Any:
def methodobj(self, new_obj: Any):
self._methodobj = new_obj

@property
def name(self) -> str:
""" Returns a unique name for this program. """
result = ''
if self.f.__module__ is not None and self.f.__module__ != '__main__':
result += self.f.__module__.replace('.', '_') + '_'
if self._methodobj is not None:
result += type(self._methodobj).__name__ + '_'
return result + self.f.__name__

def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]:
return self.argnames, self.constant_args

Expand All @@ -261,7 +302,10 @@ def __sdfg_closure__(

if reevaluate is None:
result = {k: v() for k, v in self.closure_arg_mapping.items()}
result.update({k: v[1] for k, v in self.resolver.callbacks.items()})
if self.resolver is not None:
result.update(
{k: v[1]
for k, v in self.resolver.callbacks.items()})
return result
else:
return {
Expand Down
12 changes: 10 additions & 2 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,8 @@ def global_value_to_node(self,

res = self.global_value_to_node(parsed, parent_node, qualname,
recurse, detect_callables)
del self.closure.callbacks[cbname]
# Keep callback in callbacks in case of parsing failure
# del self.closure.callbacks[cbname]
return res
except Exception: # Parsing failed (almost any exception can occur)
return newnode
Expand Down Expand Up @@ -838,6 +839,11 @@ def visit_Call(self, node: ast.Call):
if not isinstance(node.func, (ast.Num, ast.Constant)):
self.seen_calls.add(astutils.rname(node.func))
return self.generic_visit(node)
if hasattr(node.func, 'oldnode'):
if isinstance(node.func.oldnode, ast.Call):
self.seen_calls.add(astutils.rname(node.func.oldnode.func))
else:
self.seen_calls.add(astutils.rname(node.func.oldnode))
if isinstance(node.func, ast.Num):
value = node.func.n
else:
Expand All @@ -849,6 +855,7 @@ def visit_Call(self, node: ast.Call):
constant_args = self._eval_args(node)

# Resolve nested closure as necessary
qualname = None
try:
qualname = next(k for k, v in self.closure.closure_sdfgs.items()
if v is value)
Expand All @@ -863,7 +870,8 @@ def visit_Call(self, node: ast.Call):
raise
except Exception as ex: # Parsing failed (anything can happen here)
warnings.warn(f'Parsing SDFGConvertible {value} failed: {ex}')
del self.closure.closure_sdfgs[qualname]
if qualname in self.closure.closure_sdfgs:
del self.closure.closure_sdfgs[qualname]
# Return old call AST instead
node.func = node.func.oldnode.func

Expand Down
15 changes: 8 additions & 7 deletions dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def infer_connector_types(sdfg: SDFG):
continue
scalar = (e.data.subset and e.data.subset.num_elements() == 1)
if e.data.data is not None:
allocated_as_scalar = (sdfg.arrays[e.data.data].storage is
not dtypes.StorageType.GPU_Global)
allocated_as_scalar = (sdfg.arrays[e.data.data].storage
is not dtypes.StorageType.GPU_Global)
else:
allocated_as_scalar = True

Expand Down Expand Up @@ -64,11 +64,11 @@ def infer_connector_types(sdfg: SDFG):
and (not e.data.dynamic or
(e.data.dynamic and e.data.wcr is not None)))
if e.data.data is not None:
allocated_as_scalar = (sdfg.arrays[e.data.data].storage is
not dtypes.StorageType.GPU_Global)
allocated_as_scalar = (sdfg.arrays[e.data.data].storage
is not dtypes.StorageType.GPU_Global)
else:
allocated_as_scalar = True

if node.out_connectors[cname].type is None:
# If nested SDFG, try to use internal array type
if isinstance(node, nodes.NestedSDFG):
Expand Down Expand Up @@ -105,7 +105,7 @@ def infer_connector_types(sdfg: SDFG):


def set_default_schedule_and_storage_types(
sdfg: SDFG, toplevel_schedule: dtypes.ScheduleType):
sdfg: SDFG, toplevel_schedule: dtypes.ScheduleType):
"""
Sets default storage and schedule types throughout SDFG in-place.
Replaces `ScheduleType.Default` and `StorageType.Default`
Expand Down Expand Up @@ -244,7 +244,8 @@ def _set_default_storage_types(sdfg: SDFG,

# Take care of remaining arrays/scalars, e.g., code->code edges
for desc in sdfg.arrays.values():
if desc.storage is dtypes.StorageType.Default:
if ((desc.transient or sdfg.parent_sdfg is None)
and desc.storage is dtypes.StorageType.Default):
desc.storage = dtypes.StorageType.Register

for state in sdfg.nodes():
Expand Down
10 changes: 8 additions & 2 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,22 @@ def replace_properties(node: Any, symrepl: Dict[symbolic.symbol,
continue

if lang is dtypes.Language.CPP: # Replace in C++ code
# Avoid import loop
from dace.codegen.targets.cpp import sym2cpp

# Use local variables and shadowing to replace
replacement = 'auto %s = %s;\n' % (name, new_name)
replacement = 'auto %s = %s;\n' % (name,
sym2cpp(new_name))
propval.code = replacement + newcode
else:
warnings.warn('Replacement of %s with %s was not made '
'for string tasklet code of language %s' %
(name, new_name, lang))
elif propval.code is not None:
for stmt in propval.code:
ASTFindReplace({name: new_name}).visit(stmt)
ASTFindReplace({
name: symbolic.symstr(new_name)
}).visit(stmt)
elif (isinstance(propclass, properties.DictProperty)
and pname == 'symbol_mapping'):
# Symbol mappings for nested SDFGs
Expand Down
Loading

0 comments on commit f2f697c

Please sign in to comment.