Skip to content

Commit

Permalink
Merge pull request #1081 from spcl/string-literal
Browse files Browse the repository at this point in the history
Explicit support for string literals in Python frontend
  • Loading branch information
tbennun authored Aug 19, 2022
2 parents 8835427 + 62ec75b commit 330861a
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 39 deletions.
6 changes: 5 additions & 1 deletion dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,11 @@ def _write_constant(self, value):
# Substitute overflowing decimal literal for AST infinities.
self.write(result.replace("inf", INFSTR))
else:
self.write(result.replace('\'', '\"'))
# Special case for strings of containing byte literals (but are still strings).
if result.find("b'") >= 0:
self.write(result)
else:
self.write(result.replace('\'', '\"'))

def _Constant(self, t):
value = t.value
Expand Down
8 changes: 7 additions & 1 deletion dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,12 @@ def get_trampoline(self, pyfunc, other_arguments):
from functools import partial
from dace import data, symbolic

def _string_converter(a: str, *args):
tmp = ctypes.cast(a, ctypes.c_char_p).value.decode('utf-8')
if tmp.startswith(chr(0xFFFF)):
return bytes(tmp[1:], 'utf-8')
return tmp

inp_arraypos = []
ret_arraypos = []
inp_types_and_sizes = []
Expand All @@ -961,7 +967,7 @@ def get_trampoline(self, pyfunc, other_arguments):
elif isinstance(arg, data.Scalar) and arg.dtype == string:
inp_arraypos.append(index)
inp_types_and_sizes.append((ctypes.c_char_p, []))
inp_converters.append(lambda a, *args: ctypes.cast(a, ctypes.c_char_p).value.decode('utf-8'))
inp_converters.append(_string_converter)
elif isinstance(arg, data.Scalar) and isinstance(arg.dtype, pointer):
inp_arraypos.append(index)
inp_types_and_sizes.append((ctypes.c_void_p, []))
Expand Down
5 changes: 3 additions & 2 deletions dace/frontend/common/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dace.sdfg import SDFG, SDFGState, InterstateEdge
from dace.memlet import Memlet
from dace.frontend.common import op_repository as oprepo
from dace.frontend.python.common import StringLiteral


def _is_sequential(index_list):
Expand Down Expand Up @@ -161,7 +162,7 @@ def prod(iterable):
def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
einsum_string: str,
einsum_string: StringLiteral,
*arrays: str,
dtype: Optional[dtypes.typeclass] = None,
optimize: bool = False,
Expand All @@ -170,7 +171,7 @@ def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor',
beta: Optional[symbolic.SymbolicType] = 0.0):
return _create_einsum_internal(sdfg,
state,
einsum_string,
str(einsum_string),
*arrays,
dtype=dtype,
optimize=optimize,
Expand Down
6 changes: 4 additions & 2 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy
import sympy
import sys
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Union

from dace import dtypes, symbolic

Expand Down Expand Up @@ -684,8 +684,10 @@ def create_constant(value: Any, node: Optional[ast.AST] = None) -> ast.AST:
return newnode


def escape_string(value: str):
def escape_string(value: Union[bytes, str]):
""" Converts special Python characters in strings back to their parsable version (e.g., newline to ``\n``) """
if isinstance(value, bytes):
return f"{chr(0xFFFF)}{value.decode('utf-8')}"
if sys.version_info >= (3, 0):
return value.encode("unicode_escape").decode("utf-8")
# Python 2.x
Expand Down
8 changes: 7 additions & 1 deletion dace/frontend/python/cached_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def _make_hashable(obj):
except TypeError:
return repr(obj)

def _make_sortable(obj):
try:
obj < obj
return obj
except TypeError:
return repr(obj)

@dataclass
class ProgramCacheKey:
Expand All @@ -58,7 +64,7 @@ def __init__(self, arg_types: ArgTypes, closure_types: ArgTypes, closure_constan
tuple((k, str(v.to_json())) for k, v in sorted(arg_types.items())),
tuple((k, str(v.to_json())) for k, v in sorted(closure_types.items())),
tuple((k, _make_hashable(v)) for k, v in sorted(closure_constants.items())),
tuple(sorted(specified_args)),
tuple(sorted(_make_sortable(a) for a in specified_args)),
)

def __hash__(self) -> int:
Expand Down
11 changes: 11 additions & 0 deletions dace/frontend/python/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


class DaceSyntaxError(Exception):

def __init__(self, visitor, node: ast.AST, message: str):
self.visitor = visitor
self.node = node
Expand Down Expand Up @@ -37,10 +38,20 @@ def inverse_dict_lookup(dict: Dict[str, Any], value: Any):
return None


@dataclass(unsafe_hash=True)
class StringLiteral:
""" A string literal found in a parsed DaCe program. """
value: Union[str, bytes]

def __str__(self) -> str:
return self.value


class SDFGConvertible(object):
"""
A mixin that defines the interface to annotate SDFG-convertible objects.
"""

def __sdfg__(self, *args, **kwargs) -> SDFG:
"""
Returns an SDFG representation of this object.
Expand Down
24 changes: 16 additions & 8 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from dace.config import Config
from dace.frontend.common import op_repository as oprepo
from dace.frontend.python import astutils
from dace.frontend.python.common import (DaceSyntaxError, SDFGClosure, SDFGConvertible, inverse_dict_lookup)
from dace.frontend.python.common import (DaceSyntaxError, SDFGClosure, SDFGConvertible, inverse_dict_lookup,
StringLiteral)
from dace.frontend.python.astutils import ExtNodeVisitor, ExtNodeTransformer
from dace.frontend.python.astutils import rname
from dace.frontend.python import nested_call, replacements, preprocessing
Expand Down Expand Up @@ -261,7 +262,7 @@ def repl_callback(repldict):
]
# Extra AST node types that are disallowed after preprocessing
_DISALLOWED_STMTS = DISALLOWED_STMTS + [
'Global', 'Assert', 'Print', 'Nonlocal', 'Raise', 'Starred', 'AsyncFor', 'Bytes', 'ListComp', 'GeneratorExp',
'Global', 'Assert', 'Print', 'Nonlocal', 'Raise', 'Starred', 'AsyncFor', 'ListComp', 'GeneratorExp',
'SetComp', 'DictComp', 'comprehension'
]

Expand Down Expand Up @@ -3924,14 +3925,15 @@ def create_callback(self, node: ast.Call, create_graph=True):
else:
allargs.append(parsed_arg)
else:
if isinstance(parsed_arg, (Number, numpy.number, type(None))):
if isinstance(parsed_arg, StringLiteral):
# Special case for strings
parsed_arg = f'"{astutils.escape_string(parsed_arg.value)}"'
atype = data.Scalar(dtypes.string)
elif isinstance(parsed_arg, (Number, numpy.number, type(None))):
atype = data.create_datadescriptor(type(parsed_arg))
else:
atype = data.create_datadescriptor(parsed_arg)

if isinstance(parsed_arg, str):
# Special case for strings
parsed_arg = f'"{astutils.escape_string(parsed_arg)}"'
allargs.append(parsed_arg)

argtypes.append(atype)
Expand Down Expand Up @@ -4444,8 +4446,12 @@ def _visitname(self, name: str, node: ast.AST):

#### Visitors that return arrays
def visit_Str(self, node: ast.Str):
# A string constant returns itself
return node.s
# A string constant returns a string literal
return StringLiteral(node.s)

def visit_Bytes(self, node: ast.Bytes):
# A bytes constant returns a string literal
return StringLiteral(node.s)

def visit_Num(self, node: ast.Num):
if isinstance(node.n, bool):
Expand All @@ -4459,6 +4465,8 @@ def visit_Constant(self, node: ast.Constant):
return dace.bool_(node.value)
if isinstance(node.value, (int, float, complex)):
return dtypes.DTYPE_TO_TYPECLASS[type(node.value)](node.value)
if isinstance(node.value, (str, bytes)):
return StringLiteral(node.value)
return node.value

def visit_Name(self, node: ast.Name):
Expand Down
5 changes: 3 additions & 2 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
""" DaCe Python parsing functionality and entry point to Python frontend. """
import ast
from dataclasses import dataclass
import inspect
import itertools
Expand Down Expand Up @@ -523,7 +524,7 @@ def _get_type_annotations(

types.update({f'__arg{j}': create_datadescriptor(varg) for j, varg in enumerate(vargs)})
arg_mapping.update({f'__arg{j}': varg for j, varg in enumerate(vargs)})
gvar_mapping[aname] = tuple(f'__arg{j}' for j in range(len(vargs)))
gvar_mapping[aname] = tuple(ast.Name(id=f'__arg{j}') for j in range(len(vargs)))
specified_args.update(set(gvar_mapping[aname]))
# Shift arg_ind to the end
arg_ind = len(given_args)
Expand All @@ -538,7 +539,7 @@ def _get_type_annotations(
f'arguments (invalid argument name: "{aname}").')
types.update({f'__kwarg_{k}': v for k, v in vargs.items()})
arg_mapping.update({f'__kwarg_{k}': given_kwargs[k] for k in vargs.keys()})
gvar_mapping[aname] = {k: f'__kwarg_{k}' for k in vargs.keys()}
gvar_mapping[aname] = {k: ast.Name(id=f'__kwarg_{k}') for k in vargs.keys()}
specified_args.update({f'__kwarg_{k}' for k in vargs.keys()})
# END OF VARIABLE-LENGTH ARGUMENTS
else:
Expand Down
8 changes: 6 additions & 2 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dace.config import Config
from dace.sdfg import SDFG
from dace.frontend.python import astutils
from dace.frontend.python.common import (DaceSyntaxError, SDFGConvertible, SDFGClosure)
from dace.frontend.python.common import (DaceSyntaxError, SDFGConvertible, SDFGClosure, StringLiteral)


class DaceRecursionError(Exception):
Expand Down Expand Up @@ -474,10 +474,14 @@ def global_value_to_node(self,
elif isinstance(value, symbolic.symbol):
# Symbols resolve to the symbol name
newnode = ast.Name(id=value.name, ctx=ast.Load())
elif (dtypes.isconstant(value) or isinstance(value, SDFG) or hasattr(value, '__sdfg__')):
elif isinstance(value, ast.Name):
newnode = ast.Name(id=value.id, ctx=ast.Load())
elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')):
# Could be a constant, an SDFG, or SDFG-convertible object
if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
self.closure.closure_sdfgs[id(value)] = (qualname, value)
elif isinstance(value, StringLiteral):
value = value.value
else:
# If this is a function call to a None function, do not add its result to the closure
if isinstance(parent_node, ast.Call):
Expand Down
53 changes: 34 additions & 19 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dace.config import Config
from dace import data, dtypes, subsets, symbolic, sdfg as sd
from dace.frontend.common import op_repository as oprepo
from dace.frontend.python.common import DaceSyntaxError
from dace.frontend.python.common import DaceSyntaxError, StringLiteral
import dace.frontend.python.memlet_parser as mem_parser
from dace.frontend.python import astutils
from dace.frontend.python.nested_call import NestedCall
Expand Down Expand Up @@ -112,7 +112,7 @@ def _define_literal_ex(pv: 'ProgramVisitor',
obj: Any,
dtype: dace.typeclass = None,
copy: bool = True,
order: str = 'K',
order: StringLiteral = StringLiteral('K'),
subok: bool = False,
ndmin: int = 0,
like: Any = None,
Expand All @@ -133,10 +133,10 @@ def _define_literal_ex(pv: 'ProgramVisitor',
desc.dtype = dtype
else: # From literal / constant
if dtype is None:
arr = np.array(obj, copy=copy, order=order, subok=subok, ndmin=ndmin)
arr = np.array(obj, copy=copy, order=str(order), subok=subok, ndmin=ndmin)
else:
npdtype = dtype.as_numpy_dtype()
arr = np.array(obj, npdtype, copy=copy, order=order, subok=subok, ndmin=ndmin)
arr = np.array(obj, npdtype, copy=copy, order=str(order), subok=subok, ndmin=ndmin)
desc = data.create_datadescriptor(arr)

# Set extra properties
Expand Down Expand Up @@ -3963,17 +3963,20 @@ def implement_ufunc_outer(visitor: 'ProgramVisitor', ast_node: ast.Call, sdfg: S


@oprepo.replaces('numpy.reshape')
def reshape(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
arr: str,
newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
order='C') -> str:
def reshape(
pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
arr: str,
newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
order: StringLiteral = StringLiteral('C')
) -> str:
if isinstance(arr, (list, tuple)) and len(arr) == 1:
arr = arr[0]
desc = sdfg.arrays[arr]

# "order" determines stride orders
order = str(order)
fortran_strides = False
if order == 'F' or (order == 'A' and desc.strides[0] == 1):
# FORTRAN strides
Expand Down Expand Up @@ -4063,8 +4066,10 @@ def size(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str) -> Size:
@oprepo.replaces_attribute('Array', 'flat')
@oprepo.replaces_attribute('Scalar', 'flat')
@oprepo.replaces_attribute('View', 'flat')
def flat(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, order: str = 'C') -> str:
def flat(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str,
order: StringLiteral = StringLiteral('C')) -> str:
desc = sdfg.arrays[arr]
order = str(order)
totalsize = data._prod(desc.shape)
if order not in ('C', 'F'):
raise NotImplementedError(f'Order "{order}" not yet supported for flattening')
Expand Down Expand Up @@ -4153,12 +4158,14 @@ def _ndarray_fill(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str,

@oprepo.replaces_method('Array', 'reshape')
@oprepo.replaces_method('View', 'reshape')
def _ndarray_reshape(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
arr: str,
newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
order='C') -> str:
def _ndarray_reshape(
pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
arr: str,
newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
order: StringLiteral = StringLiteral('C')
) -> str:
return reshape(pv, sdfg, state, arr, newshape, order)


Expand All @@ -4175,7 +4182,11 @@ def _ndarray_transpose(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr:
@oprepo.replaces_method('Array', 'flatten')
@oprepo.replaces_method('Scalar', 'flatten')
@oprepo.replaces_method('View', 'flatten')
def _ndarray_flatten(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, order: str = 'C') -> str:
def _ndarray_flatten(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
arr: str,
order: StringLiteral = StringLiteral('C')) -> str:
new_arr = flat(pv, sdfg, state, arr, order)
# `flatten` always returns a copy
if isinstance(new_arr, data.View):
Expand All @@ -4186,7 +4197,11 @@ def _ndarray_flatten(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: st
@oprepo.replaces_method('Array', 'ravel')
@oprepo.replaces_method('Scalar', 'ravel')
@oprepo.replaces_method('View', 'ravel')
def _ndarray_ravel(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str, order: str = 'C') -> str:
def _ndarray_ravel(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
arr: str,
order: StringLiteral = StringLiteral('C')) -> str:
# `ravel` returns a copy only when necessary (sounds like ndarray.flat)
return flat(pv, sdfg, state, arr, order)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
cmake_path = shutil.which('cmake')
if cmake_path:
# CMake is available, check version
output = subprocess.check_output([cmake_path, '--version'], text=True)
output = subprocess.check_output([cmake_path, '--version']).decode('utf-8')
cmake_version = tuple(int(t) for t in output.splitlines()[0].split(' ')[-1].split('.'))
# If version meets minimum requirements, CMake is not necessary
if cmake_version >= (3, 15):
Expand Down
Loading

0 comments on commit 330861a

Please sign in to comment.