diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py index 4fe522389a..a0a14a3033 100644 --- a/dace/codegen/cppunparse.py +++ b/dace/codegen/cppunparse.py @@ -532,7 +532,11 @@ def _write_constant(self, value): # Substitute overflowing decimal literal for AST infinities. self.write(result.replace("inf", INFSTR)) else: - self.write(result.replace('\'', '\"')) + # Special case for strings that contain byte literals (but are still strings). + if result.find("b'") >= 0: + self.write(result) + else: + self.write(result.replace('\'', '\"')) def _Constant(self, t): value = t.value diff --git a/dace/dtypes.py b/dace/dtypes.py index ffff6fb8d9..4bbe833bac 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -947,6 +947,12 @@ def get_trampoline(self, pyfunc, other_arguments): from functools import partial from dace import data, symbolic + def _string_converter(a: str, *args): + tmp = ctypes.cast(a, ctypes.c_char_p).value.decode('utf-8') + if tmp.startswith(chr(0xFFFF)): + return bytes(tmp[1:], 'utf-8') + return tmp + inp_arraypos = [] ret_arraypos = [] inp_types_and_sizes = [] @@ -961,7 +967,7 @@ def get_trampoline(self, pyfunc, other_arguments): elif isinstance(arg, data.Scalar) and arg.dtype == string: inp_arraypos.append(index) inp_types_and_sizes.append((ctypes.c_char_p, [])) - inp_converters.append(lambda a, *args: ctypes.cast(a, ctypes.c_char_p).value.decode('utf-8')) + inp_converters.append(_string_converter) elif isinstance(arg, data.Scalar) and isinstance(arg.dtype, pointer): inp_arraypos.append(index) inp_types_and_sizes.append((ctypes.c_void_p, [])) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index f553115a23..b9c3fff6b4 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -11,6 +11,7 @@ from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.memlet import Memlet from dace.frontend.common import op_repository as oprepo +from dace.frontend.python.common import StringLiteral def 
_is_sequential(index_list): @@ -161,7 +162,7 @@ def prod(iterable): def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor', sdfg: SDFG, state: SDFGState, - einsum_string: str, + einsum_string: StringLiteral, *arrays: str, dtype: Optional[dtypes.typeclass] = None, optimize: bool = False, @@ -170,7 +171,7 @@ def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor', beta: Optional[symbolic.SymbolicType] = 0.0): return _create_einsum_internal(sdfg, state, - einsum_string, + str(einsum_string), *arrays, dtype=dtype, optimize=optimize, diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 84f40415fb..812f9c390b 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -10,7 +10,7 @@ import numpy import sympy import sys -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Union from dace import dtypes, symbolic @@ -684,8 +684,10 @@ def create_constant(value: Any, node: Optional[ast.AST] = None) -> ast.AST: return newnode -def escape_string(value: str): +def escape_string(value: Union[bytes, str]): """ Converts special Python characters in strings back to their parsable version (e.g., newline to ``\n``) """ + if isinstance(value, bytes): + return f"{chr(0xFFFF)}{value.decode('utf-8')}" if sys.version_info >= (3, 0): return value.encode("unicode_escape").decode("utf-8") # Python 2.x diff --git a/dace/frontend/python/cached_program.py b/dace/frontend/python/cached_program.py index 1b5f48e688..bdbf958917 100644 --- a/dace/frontend/python/cached_program.py +++ b/dace/frontend/python/cached_program.py @@ -38,6 +38,12 @@ def _make_hashable(obj): except TypeError: return repr(obj) +def _make_sortable(obj): + try: + obj < obj + return obj + except TypeError: + return repr(obj) @dataclass class ProgramCacheKey: @@ -58,7 +64,7 @@ def __init__(self, arg_types: ArgTypes, closure_types: ArgTypes, closure_constan tuple((k, 
str(v.to_json())) for k, v in sorted(arg_types.items())), tuple((k, str(v.to_json())) for k, v in sorted(closure_types.items())), tuple((k, _make_hashable(v)) for k, v in sorted(closure_constants.items())), - tuple(sorted(specified_args)), + tuple(sorted(_make_sortable(a) for a in specified_args)), ) def __hash__(self) -> int: diff --git a/dace/frontend/python/common.py b/dace/frontend/python/common.py index 47bfb48ced..ffb5123c8d 100644 --- a/dace/frontend/python/common.py +++ b/dace/frontend/python/common.py @@ -8,6 +8,7 @@ class DaceSyntaxError(Exception): + def __init__(self, visitor, node: ast.AST, message: str): self.visitor = visitor self.node = node @@ -37,10 +38,20 @@ def inverse_dict_lookup(dict: Dict[str, Any], value: Any): return None +@dataclass(unsafe_hash=True) +class StringLiteral: + """ A string or bytes literal found in a parsed DaCe program; ``str()`` decodes bytes values as UTF-8. """ + value: Union[str, bytes] + + def __str__(self) -> str: + return self.value.decode('utf-8') if isinstance(self.value, bytes) else self.value + + class SDFGConvertible(object): """ A mixin that defines the interface to annotate SDFG-convertible objects. """ + def __sdfg__(self, *args, **kwargs) -> SDFG: """ Returns an SDFG representation of this object. 
diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 97b8e5828d..1281f3e318 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -20,7 +20,8 @@ from dace.config import Config from dace.frontend.common import op_repository as oprepo from dace.frontend.python import astutils -from dace.frontend.python.common import (DaceSyntaxError, SDFGClosure, SDFGConvertible, inverse_dict_lookup) +from dace.frontend.python.common import (DaceSyntaxError, SDFGClosure, SDFGConvertible, inverse_dict_lookup, + StringLiteral) from dace.frontend.python.astutils import ExtNodeVisitor, ExtNodeTransformer from dace.frontend.python.astutils import rname from dace.frontend.python import nested_call, replacements, preprocessing @@ -261,7 +262,7 @@ def repl_callback(repldict): ] # Extra AST node types that are disallowed after preprocessing _DISALLOWED_STMTS = DISALLOWED_STMTS + [ - 'Global', 'Assert', 'Print', 'Nonlocal', 'Raise', 'Starred', 'AsyncFor', 'Bytes', 'ListComp', 'GeneratorExp', + 'Global', 'Assert', 'Print', 'Nonlocal', 'Raise', 'Starred', 'AsyncFor', 'ListComp', 'GeneratorExp', 'SetComp', 'DictComp', 'comprehension' ] @@ -3924,14 +3925,15 @@ def create_callback(self, node: ast.Call, create_graph=True): else: allargs.append(parsed_arg) else: - if isinstance(parsed_arg, (Number, numpy.number, type(None))): + if isinstance(parsed_arg, StringLiteral): + # Special case for strings + parsed_arg = f'"{astutils.escape_string(parsed_arg.value)}"' + atype = data.Scalar(dtypes.string) + elif isinstance(parsed_arg, (Number, numpy.number, type(None))): atype = data.create_datadescriptor(type(parsed_arg)) else: atype = data.create_datadescriptor(parsed_arg) - if isinstance(parsed_arg, str): - # Special case for strings - parsed_arg = f'"{astutils.escape_string(parsed_arg)}"' allargs.append(parsed_arg) argtypes.append(atype) @@ -4444,8 +4446,12 @@ def _visitname(self, name: str, node: ast.AST): #### Visitors that return arrays def 
visit_Str(self, node: ast.Str): - # A string constant returns itself - return node.s + # A string constant returns a string literal + return StringLiteral(node.s) + + def visit_Bytes(self, node: ast.Bytes): + # A bytes constant returns a string literal + return StringLiteral(node.s) def visit_Num(self, node: ast.Num): if isinstance(node.n, bool): @@ -4459,6 +4465,8 @@ def visit_Constant(self, node: ast.Constant): return dace.bool_(node.value) if isinstance(node.value, (int, float, complex)): return dtypes.DTYPE_TO_TYPECLASS[type(node.value)](node.value) + if isinstance(node.value, (str, bytes)): + return StringLiteral(node.value) return node.value def visit_Name(self, node: ast.Name): diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 5e43bf2c0a..2e4c3809b0 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -1,5 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ DaCe Python parsing functionality and entry point to Python frontend. 
""" +import ast from dataclasses import dataclass import inspect import itertools @@ -523,7 +524,7 @@ def _get_type_annotations( types.update({f'__arg{j}': create_datadescriptor(varg) for j, varg in enumerate(vargs)}) arg_mapping.update({f'__arg{j}': varg for j, varg in enumerate(vargs)}) - gvar_mapping[aname] = tuple(f'__arg{j}' for j in range(len(vargs))) + gvar_mapping[aname] = tuple(ast.Name(id=f'__arg{j}') for j in range(len(vargs))) specified_args.update(set(gvar_mapping[aname])) # Shift arg_ind to the end arg_ind = len(given_args) @@ -538,7 +539,7 @@ def _get_type_annotations( f'arguments (invalid argument name: "{aname}").') types.update({f'__kwarg_{k}': v for k, v in vargs.items()}) arg_mapping.update({f'__kwarg_{k}': given_kwargs[k] for k in vargs.keys()}) - gvar_mapping[aname] = {k: f'__kwarg_{k}' for k in vargs.keys()} + gvar_mapping[aname] = {k: ast.Name(id=f'__kwarg_{k}') for k in vargs.keys()} specified_args.update({f'__kwarg_{k}' for k in vargs.keys()}) # END OF VARIABLE-LENGTH ARGUMENTS else: diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 94a4cc175e..966518d5f2 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -17,7 +17,7 @@ from dace.config import Config from dace.sdfg import SDFG from dace.frontend.python import astutils -from dace.frontend.python.common import (DaceSyntaxError, SDFGConvertible, SDFGClosure) +from dace.frontend.python.common import (DaceSyntaxError, SDFGConvertible, SDFGClosure, StringLiteral) class DaceRecursionError(Exception): @@ -474,10 +474,14 @@ def global_value_to_node(self, elif isinstance(value, symbolic.symbol): # Symbols resolve to the symbol name newnode = ast.Name(id=value.name, ctx=ast.Load()) - elif (dtypes.isconstant(value) or isinstance(value, SDFG) or hasattr(value, '__sdfg__')): + elif isinstance(value, ast.Name): + newnode = ast.Name(id=value.id, ctx=ast.Load()) + elif (dtypes.isconstant(value) or 
isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')): # Could be a constant, an SDFG, or SDFG-convertible object if isinstance(value, SDFG) or hasattr(value, '__sdfg__'): self.closure.closure_sdfgs[id(value)] = (qualname, value) + elif isinstance(value, StringLiteral): + value = value.value else: # If this is a function call to a None function, do not add its result to the closure if isinstance(parent_node, ast.Call): diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 5b4245f06e..cf53caec24 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -15,7 +15,7 @@ from dace.config import Config from dace import data, dtypes, subsets, symbolic, sdfg as sd from dace.frontend.common import op_repository as oprepo -from dace.frontend.python.common import DaceSyntaxError +from dace.frontend.python.common import DaceSyntaxError, StringLiteral import dace.frontend.python.memlet_parser as mem_parser from dace.frontend.python import astutils from dace.frontend.python.nested_call import NestedCall @@ -112,7 +112,7 @@ def _define_literal_ex(pv: 'ProgramVisitor', obj: Any, dtype: dace.typeclass = None, copy: bool = True, - order: str = 'K', + order: StringLiteral = StringLiteral('K'), subok: bool = False, ndmin: int = 0, like: Any = None, @@ -133,10 +133,10 @@ def _define_literal_ex(pv: 'ProgramVisitor', desc.dtype = dtype else: # From literal / constant if dtype is None: - arr = np.array(obj, copy=copy, order=order, subok=subok, ndmin=ndmin) + arr = np.array(obj, copy=copy, order=str(order), subok=subok, ndmin=ndmin) else: npdtype = dtype.as_numpy_dtype() - arr = np.array(obj, npdtype, copy=copy, order=order, subok=subok, ndmin=ndmin) + arr = np.array(obj, npdtype, copy=copy, order=str(order), subok=subok, ndmin=ndmin) desc = data.create_datadescriptor(arr) # Set extra properties @@ -3963,17 +3963,20 @@ def implement_ufunc_outer(visitor: 'ProgramVisitor', ast_node: ast.Call, 
sdfg: S @oprepo.replaces('numpy.reshape') -def reshape(pv: 'ProgramVisitor', - sdfg: SDFG, - state: SDFGState, - arr: str, - newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]], - order='C') -> str: +def reshape( + pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + arr: str, + newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]], + order: StringLiteral = StringLiteral('C') +) -> str: if isinstance(arr, (list, tuple)) and len(arr) == 1: arr = arr[0] desc = sdfg.arrays[arr] # "order" determines stride orders + order = str(order) fortran_strides = False if order == 'F' or (order == 'A' and desc.strides[0] == 1): # FORTRAN strides @@ -4063,8 +4066,10 @@ def size(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str) -> Size: @oprepo.replaces_attribute('Array', 'flat') @oprepo.replaces_attribute('Scalar', 'flat') @oprepo.replaces_attribute('View', 'flat') -def flat(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, order: str = 'C') -> str: +def flat(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, + order: StringLiteral = StringLiteral('C')) -> str: desc = sdfg.arrays[arr] + order = str(order) totalsize = data._prod(desc.shape) if order not in ('C', 'F'): raise NotImplementedError(f'Order "{order}" not yet supported for flattening') @@ -4153,12 +4158,14 @@ def _ndarray_fill(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, @oprepo.replaces_method('Array', 'reshape') @oprepo.replaces_method('View', 'reshape') -def _ndarray_reshape(pv: 'ProgramVisitor', - sdfg: SDFG, - state: SDFGState, - arr: str, - newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]], - order='C') -> str: +def _ndarray_reshape( + pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + arr: str, + newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]], + order: StringLiteral = StringLiteral('C') +) -> str: return 
reshape(pv, sdfg, state, arr, newshape, order) @@ -4175,7 +4182,11 @@ def _ndarray_transpose(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: @oprepo.replaces_method('Array', 'flatten') @oprepo.replaces_method('Scalar', 'flatten') @oprepo.replaces_method('View', 'flatten') -def _ndarray_flatten(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, order: str = 'C') -> str: +def _ndarray_flatten(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + arr: str, + order: StringLiteral = StringLiteral('C')) -> str: new_arr = flat(pv, sdfg, state, arr, order) # `flatten` always returns a copy if isinstance(new_arr, data.View): @@ -4186,7 +4197,11 @@ def _ndarray_flatten(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: st @oprepo.replaces_method('Array', 'ravel') @oprepo.replaces_method('Scalar', 'ravel') @oprepo.replaces_method('View', 'ravel') -def _ndarray_ravel(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, order: str = 'C') -> str: +def _ndarray_ravel(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + arr: str, + order: StringLiteral = StringLiteral('C')) -> str: # `ravel` returns a copy only when necessary (sounds like ndarray.flat) return flat(pv, sdfg, state, arr, order) diff --git a/setup.py b/setup.py index 27e00b2191..227fbb7bc6 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ cmake_path = shutil.which('cmake') if cmake_path: # CMake is available, check version - output = subprocess.check_output([cmake_path, '--version'], text=True) + output = subprocess.check_output([cmake_path, '--version']).decode('utf-8') cmake_version = tuple(int(t) for t in output.splitlines()[0].split(' ')[-1].split('.')) # If version meets minimum requirements, CMake is not necessary if cmake_version >= (3, 15): diff --git a/tests/python_frontend/string_test.py b/tests/python_frontend/string_test.py new file mode 100644 index 0000000000..5db79f6fa9 --- /dev/null +++ b/tests/python_frontend/string_test.py @@ -0,0 +1,96 @@ +# 
Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +import numpy as np +import pytest + + +def callback_inhibitor(f): + return f + +def test_string_literal_in_callback(): + success = False + @callback_inhibitor + def cb(a): + nonlocal success + if a == 'a': + success = True + + + @dace + def tester(a): + cb('a') + + + a = np.random.rand(1) + tester(a) + + assert success is True + + +def test_bytes_literal_in_callback(): + success = False + @callback_inhibitor + def cb(a): + nonlocal success + if a == b'Hello World!': + success = True + + + @dace + def tester(a): + cb(b'Hello World!') + + + a = np.random.rand(1) + tester(a) + + assert success is True + + +def test_string_literal_in_callback_2(): + success = False + @callback_inhibitor + def cb(a): + nonlocal success + if a == "b'Hello World!'": + success = True + + + @dace + def tester(a): + cb("b'Hello World!'") + + + a = np.random.rand(1) + tester(a) + + assert success is True + + +@pytest.mark.skip +def test_string_literal(): + + @dace + def tester(): + return 'Hello World!' + + assert tester()[0] == 'Hello World!' + + +@pytest.mark.skip +def test_bytes_literal(): + + @dace + def tester(): + return b'Hello World!' + + assert tester()[0] == b'Hello World!' + + +if __name__ == '__main__': + test_string_literal_in_callback() + test_bytes_literal_in_callback() + test_string_literal_in_callback_2() + # test_string_literal() + # test_bytes_literal()