diff --git a/CHANGES.rst b/CHANGES.rst index 006f1522..50f438d2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,11 @@ Changes In next release ... +- Parsing the AST back to Python code now uses the built-in + `ast.unparse` function. This change is not directly surfaced but + means that the unparsing code is now more correctly tracking changes + to the interpreter. + - Drop support for platforms where AST nodes aren't weakref-capable (e.g., older PyPy). diff --git a/src/chameleon/astutil.py b/src/chameleon/astutil.py index 5c635571..9b699f06 100644 --- a/src/chameleon/astutil.py +++ b/src/chameleon/astutil.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import ClassVar -from typing import cast if TYPE_CHECKING: @@ -127,23 +126,11 @@ def copy(source, target) -> None: target.__dict__ = source.__dict__ -def swap(body, replacement, name) -> None: - root = ast.Expression(body=body) - for node in ast.walk(root): - if ( - isinstance(node, ast.Name) - and isinstance(node.ctx, ast.Load) - and node.id == name - ): - assert hasattr(replacement, '_fields') - node_annotations.setdefault(node, replacement) - - def marker(name): return ast.Str(s="__%s" % name) -class Node: +class Node(ast.AST): """AST baseclass that gives us a convenient initialization method. We explicitly declare and use the ``_fields`` attribute.""" @@ -214,822 +201,6 @@ class TokenRef(Node): _fields = "pos", "length" -class ASTCodeGenerator: - """General purpose base class for AST transformations. - - Every visitor method can be overridden to return an AST node that has been - altered or replaced in some way. - """ - - def __init__(self, tree): - self.lines_info = [] - self.line_info = [] - self.lines = [] - self.line = "" - self.last = None - self.indent = 0 - self.blame_stack = [] - self.visit(tree) - - if self.line.strip(): - self._new_line() - - self.line = None - self.line_info = None - - # strip trivial lines - self.code = "\n".join( - line.strip() and line or "" for line in self.lines - ) - - def _change_indent(self, delta) -> None: - self.indent += delta - - def _new_line(self) -> None: - if self.line is not None: - self.lines.append(self.line) - self.lines_info.append(self.line_info) - self.line = ' ' * 4 * self.indent - if len(self.blame_stack) == 0: - self.line_info = [] - self.last = None - else: - self.line_info = [ - ( - 0, - self.blame_stack[-1], - ) - ] - self.last = self.blame_stack[-1] - - def _write(self, s) -> None: - if len(s) == 0: - return - if len(self.blame_stack) == 0: - if self.last is not None: - self.last = None - self.line_info.append((len(self.line), self.last)) - else: - if self.last != self.blame_stack[-1]: - self.last = self.blame_stack[-1] - self.line_info.append((len(self.line), self.last)) - self.line += s - - def flush(self) -> None: - if self.line: - self._new_line() - - def visit(self, node): - if node is None: - return None - if isinstance(node, tuple): - return tuple([self.visit(n) for n in node]) - try: - self.blame_stack.append( - ( - node.lineno, - node.col_offset, - ) - ) - info = True - except AttributeError: - info = False - visitor = getattr(self, 'visit_%s' % node.__class__.__name__, None) - if visitor is None: - raise Exception( - 'No handler for ``{}`` ({}).'.format( - node.__class__.__name__, repr(node) - ) - ) - ret = visitor(node) - if info: - self.blame_stack.pop() - return ret - - def visit_Module(self, node) -> None: - for n in node.body: - self.visit(n) - - visit_Interactive = visit_Module - visit_Suite = visit_Module - - def visit_Expression(self, node): - return self.visit(node.body) - - # arguments = (expr* args, identifier? vararg, - # identifier? kwarg, expr* defaults) - def visit_arguments(self, node) -> None: - first = True - no_default_count = len(node.args) - len(node.defaults) - for i, arg in enumerate(node.args): - if not first: - self._write(', ') - else: - first = False - self.visit(arg) - if i >= no_default_count: - self._write('=') - self.visit(node.defaults[i - no_default_count]) - if getattr(node, 'vararg', None): - if not first: - self._write(', ') - else: - first = False - self._write('*' + node.vararg) - if getattr(node, 'kwarg', None): - if not first: - self._write(', ') - else: - first = False - self._write('**' + node.kwarg) - - def visit_arg(self, node) -> None: - self._write(node.arg) - - # FunctionDef(identifier name, arguments args, - # stmt* body, expr* decorators) - def visit_FunctionDef(self, node) -> None: - self._new_line() - for decorator in getattr(node, 'decorator_list', ()): - self._new_line() - self._write('@') - self.visit(decorator) - self._new_line() - self._write('def ' + node.name + '(') - self.visit(node.args) - self._write('):') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - - # ClassDef(identifier name, expr* bases, stmt* body) - def visit_ClassDef(self, node) -> None: - self._new_line() - self._write('class ' + node.name) - if node.bases: - self._write('(') - self.visit(node.bases[0]) - for base in node.bases[1:]: - self._write(', ') - self.visit(base) - self._write(')') - self._write(':') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - - # Return(expr? value) - def visit_Return(self, node) -> None: - self._new_line() - self._write('return') - if getattr(node, 'value', None): - self._write(' ') - self.visit(node.value) - - # Delete(expr* targets) - def visit_Delete(self, node) -> None: - self._new_line() - self._write('del ') - self.visit(node.targets[0]) - for target in node.targets[1:]: - self._write(', ') - self.visit(target) - - # Assign(expr* targets, expr value) - def visit_Assign(self, node) -> None: - self._new_line() - for target in node.targets: - self.visit(target) - self._write(' = ') - self.visit(node.value) - - # AugAssign(expr target, operator op, expr value) - def visit_AugAssign(self, node) -> None: - self._new_line() - self.visit(node.target) - self._write(' ' + self.binary_operators[node.op.__class__] + '= ') - self.visit(node.value) - - # JoinedStr(expr* values) - def visit_JoinedStr(self, node) -> None: - if node.values: - self._write('"".join((') - for value in node.values: - self.visit(value) - self._write(',') - self._write('))') - else: - self._write('""') - - # FormattedValue(expr value) - def visit_FormattedValue(self, node) -> None: - if node.conversion == ord('r'): - self._write('repr') - elif node.conversion == ord('a'): - self._write('ascii') - else: - self._write('str') - self._write('(') - self.visit(node.value) - if node.format_spec is not None: - self._write('.__format__(') - self.visit(node.format_spec) - self._write(')') - self._write(')') - - # Print(expr? dest, expr* values, bool nl) - def visit_Print(self, node) -> None: - self._new_line() - self._write('print') - if getattr(node, 'dest', None): - self._write(' >> ') - self.visit(node.dest) - if getattr(node, 'values', None): - self._write(', ') - else: - self._write(' ') - if getattr(node, 'values', None): - self.visit(node.values[0]) - for value in node.values[1:]: - self._write(', ') - self.visit(value) - if not node.nl: - self._write(',') - - # For(expr target, expr iter, stmt* body, stmt* orelse) - def visit_For(self, node) -> None: - self._new_line() - self._write('for ') - self.visit(node.target) - self._write(' in ') - self.visit(node.iter) - self._write(':') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - if getattr(node, 'orelse', None): - self._new_line() - self._write('else:') - self._change_indent(1) - for statement in node.orelse: - self.visit(statement) - self._change_indent(-1) - - # While(expr test, stmt* body, stmt* orelse) - def visit_While(self, node) -> None: - self._new_line() - self._write('while ') - self.visit(node.test) - self._write(':') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - if getattr(node, 'orelse', None): - self._new_line() - self._write('else:') - self._change_indent(1) - for statement in node.orelse: - self.visit(statement) - self._change_indent(-1) - - # If(expr test, stmt* body, stmt* orelse) - def visit_If(self, node) -> None: - self._new_line() - self._write('if ') - self.visit(node.test) - self._write(':') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - if getattr(node, 'orelse', None): - self._new_line() - self._write('else:') - self._change_indent(1) - for statement in node.orelse: - self.visit(statement) - self._change_indent(-1) - - # With(expr context_expr, expr? optional_vars, stmt* body) - def visit_With(self, node) -> None: - self._new_line() - self._write('with ') - self.visit(node.context_expr) - if getattr(node, 'optional_vars', None): - self._write(' as ') - self.visit(node.optional_vars) - self._write(':') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - - # Raise(expr? type, expr? inst, expr? tback) - def visit_Raise(self, node): - self._new_line() - self._write('raise') - if not getattr(node, "type", None): - exc = getattr(node, "exc", None) - if exc is None: - return - self._write(' ') - return self.visit(exc) - self._write(' ') - self.visit(node.type) - if not node.inst: - return - self._write(', ') - self.visit(node.inst) - if not node.tback: - return - self._write(', ') - self.visit(node.tback) - - # Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody) - def visit_Try(self, node) -> None: - self._new_line() - self._write('try:') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - if getattr(node, 'handlers', None): - for handler in node.handlers: - self.visit(handler) - self._new_line() - - if getattr(node, 'orelse', None): - self._write('else:') - self._change_indent(1) - for statement in node.orelse: - self.visit(statement) - self._change_indent(-1) - - if getattr(node, 'finalbody', None): - self._new_line() - self._write('finally:') - self._change_indent(1) - for statement in node.finalbody: - self.visit(statement) - self._change_indent(-1) - - # TryExcept(stmt* body, excepthandler* handlers, stmt* orelse) - def visit_TryExcept(self, node) -> None: - self._new_line() - self._write('try:') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - if getattr(node, 'handlers', None): - for handler in node.handlers: - self.visit(handler) - self._new_line() - if getattr(node, 'orelse', None): - self._write('else:') - self._change_indent(1) - for statement in node.orelse: - self.visit(statement) - self._change_indent(-1) - - # excepthandler = (expr? type, expr? name, stmt* body) - def visit_ExceptHandler(self, node) -> None: - self._new_line() - self._write('except') - if getattr(node, 'type', None): - self._write(' ') - self.visit(node.type) - if getattr(node, 'name', None): - self._write(' as ') - self.visit(node.name) - self._write(':') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - - visit_excepthandler = visit_ExceptHandler - - # TryFinally(stmt* body, stmt* finalbody) - def visit_TryFinally(self, node) -> None: - self._new_line() - self._write('try:') - self._change_indent(1) - for statement in node.body: - self.visit(statement) - self._change_indent(-1) - - if getattr(node, 'finalbody', None): - self._new_line() - self._write('finally:') - self._change_indent(1) - for statement in node.finalbody: - self.visit(statement) - self._change_indent(-1) - - # Assert(expr test, expr? msg) - def visit_Assert(self, node) -> None: - self._new_line() - self._write('assert ') - self.visit(node.test) - if getattr(node, 'msg', None): - self._write(', ') - self.visit(node.msg) - - def visit_alias(self, node) -> None: - self._write(node.name) - if getattr(node, 'asname', None): - self._write(' as ') - self._write(node.asname) - - # Import(alias* names) - def visit_Import(self, node) -> None: - self._new_line() - self._write('import ') - self.visit(node.names[0]) - for name in node.names[1:]: - self._write(', ') - self.visit(name) - - # ImportFrom(identifier module, alias* names, int? level) - def visit_ImportFrom(self, node) -> None: - self._new_line() - self._write('from ') - if node.level: - self._write('.' * node.level) - self._write(node.module) - self._write(' import ') - self.visit(node.names[0]) - for name in node.names[1:]: - self._write(', ') - self.visit(name) - - # Exec(expr body, expr? globals, expr? locals) - def visit_Exec(self, node) -> None: - self._new_line() - self._write('exec ') - self.visit(node.body) - if not node.globals: - return - self._write(', ') - self.visit(node.globals) - if not node.locals: - return - self._write(', ') - self.visit(node.locals) - - # Global(identifier* names) - def visit_Global(self, node) -> None: - self._new_line() - self._write('global ') - self.visit(node.names[0]) - for name in node.names[1:]: - self._write(', ') - self.visit(name) - - # Expr(expr value) - def visit_Expr(self, node) -> None: - self._new_line() - self.visit(node.value) - - # Pass - def visit_Pass(self, node) -> None: - self._new_line() - self._write('pass') - - # Break - def visit_Break(self, node) -> None: - self._new_line() - self._write('break') - - # Continue - def visit_Continue(self, node) -> None: - self._new_line() - self._write('continue') - - # EXPRESSIONS - def with_parens(f: _F) -> _F: # type: ignore[misc] - def _f(self, node) -> None: - self._write('(') - f(self, node) - self._write(')') - - return cast('_F', _f) - - bool_operators = {ast.And: 'and', ast.Or: 'or'} - - # BoolOp(boolop op, expr* values) - @with_parens - def visit_BoolOp(self, node) -> None: - joiner = ' ' + self.bool_operators[node.op.__class__] + ' ' - self.visit(node.values[0]) - for value in node.values[1:]: - self._write(joiner) - self.visit(value) - - binary_operators = { - ast.Add: '+', - ast.Sub: '-', - ast.Mult: '*', - ast.Div: '/', - ast.Mod: '%', - ast.Pow: '**', - ast.LShift: '<<', - ast.RShift: '>>', - ast.BitOr: '|', - ast.BitXor: '^', - ast.BitAnd: '&', - ast.FloorDiv: '//', - } - - # BinOp(expr left, operator op, expr right) - @with_parens - def visit_BinOp(self, node) -> None: - self.visit(node.left) - self._write(' ' + self.binary_operators[node.op.__class__] + ' ') - self.visit(node.right) - - unary_operators = { - ast.Invert: '~', - ast.Not: 'not', - ast.UAdd: '+', - ast.USub: '-', - } - - # UnaryOp(unaryop op, expr operand) - def visit_UnaryOp(self, node) -> None: - self._write(self.unary_operators[node.op.__class__] + ' ') - self.visit(node.operand) - - # Lambda(arguments args, expr body) - @with_parens - def visit_Lambda(self, node) -> None: - self._write('lambda ') - self.visit(node.args) - self._write(': ') - self.visit(node.body) - - # IfExp(expr test, expr body, expr orelse) - @with_parens - def visit_IfExp(self, node) -> None: - self.visit(node.body) - self._write(' if ') - self.visit(node.test) - self._write(' else ') - self.visit(node.orelse) - - # Dict(expr* keys, expr* values) - def visit_Dict(self, node) -> None: - self._write('{') - for key, value in zip(node.keys, node.values): - self.visit(key) - self._write(': ') - self.visit(value) - self._write(', ') - self._write('}') - - def visit_Set(self, node) -> None: - self._write('{') - elts = list(node.elts) - last = elts.pop() - for elt in elts: - self.visit(elt) - self._write(', ') - self.visit(last) - self._write('}') - - # DictComp(expr key, expr value, comprehension* generators) - def visit_DictComp(self, node) -> None: - self._write('{') - self.visit(node.key) - self._write(': ') - self.visit(node.value) - for generator in node.generators: - # comprehension = (expr target, expr iter, expr* ifs) - self._write(' for ') - self.visit(generator.target) - self._write(' in ') - self.visit(generator.iter) - for ifexpr in generator.ifs: - self._write(' if ') - self.visit(ifexpr) - self._write('}') - - # ListComp(expr elt, comprehension* generators) - def visit_ListComp(self, node) -> None: - self._write('[') - self.visit(node.elt) - for generator in node.generators: - # comprehension = (expr target, expr iter, expr* ifs) - self._write(' for ') - self.visit(generator.target) - self._write(' in ') - self.visit(generator.iter) - for ifexpr in generator.ifs: - self._write(' if ') - self.visit(ifexpr) - self._write(']') - - # GeneratorExp(expr elt, comprehension* generators) - def visit_GeneratorExp(self, node) -> None: - self._write('(') - self.visit(node.elt) - for generator in node.generators: - # comprehension = (expr target, expr iter, expr* ifs) - self._write(' for ') - self.visit(generator.target) - self._write(' in ') - self.visit(generator.iter) - for ifexpr in generator.ifs: - self._write(' if ') - self.visit(ifexpr) - self._write(')') - - # SetComp(expr elt, comprehension* generators) - def visit_SetComp(self, node) -> None: - self._write('{') - self.visit(node.elt) - for generator in node.generators: - # comprehension = (expr target, expr iter, expr* ifs) - self._write(' for ') - self.visit(generator.target) - self._write(' in ') - self.visit(generator.iter) - for ifexpr in generator.ifs: - self._write(' if ') - self.visit(ifexpr) - self._write('}') - - # Yield(expr? value) - def visit_Yield(self, node) -> None: - self._write('yield') - if getattr(node, 'value', None): - self._write(' ') - self.visit(node.value) - - comparison_operators = { - ast.Eq: '==', - ast.NotEq: '!=', - ast.Lt: '<', - ast.LtE: '<=', - ast.Gt: '>', - ast.GtE: '>=', - ast.Is: 'is', - ast.IsNot: 'is not', - ast.In: 'in', - ast.NotIn: 'not in', - } - - # Compare(expr left, cmpop* ops, expr* comparators) - @with_parens - def visit_Compare(self, node) -> None: - self.visit(node.left) - for op, comparator in zip(node.ops, node.comparators): - self._write(' ' + self.comparison_operators[op.__class__] + ' ') - self.visit(comparator) - - # Call(expr func, expr* args, keyword* keywords, - # expr? starargs, expr? kwargs) - def visit_Call(self, node) -> None: - self.visit(node.func) - self._write('(') - first = True - for arg in node.args: - if not first: - self._write(', ') - first = False - self.visit(arg) - - for keyword in node.keywords: - if not first: - self._write(', ') - first = False - # keyword = (identifier arg, expr value) - if keyword.arg is not None: - self._write(keyword.arg) - self._write('=') - else: - self._write('**') - self.visit(keyword.value) - - self._write(')') - - # Repr(expr value) - def visit_Repr(self, node) -> None: - self._write('`') - self.visit(node.value) - self._write('`') - - # Constant(object value) - def visit_Constant(self, node) -> None: - if node.value is Ellipsis: - self._write('...') - else: - self._write(repr(node.value)) - - # Num(object n) - def visit_Num(self, node) -> None: - self._write(repr(node.n)) - - # Str(string s) - def visit_Str(self, node) -> None: - self._write(repr(node.s)) - - def visit_Ellipsis(self, node) -> None: - self._write('...') - - # Attribute(expr value, identifier attr, expr_context ctx) - def visit_Attribute(self, node) -> None: - self.visit(node.value) - self._write('.') - self._write(node.attr) - - # Subscript(expr value, slice slice, expr_context ctx) - def visit_Subscript(self, node) -> None: - self.visit(node.value) - self._write('[') - if isinstance(node.slice, ast.Tuple) and node.slice.elts: - self.visit(node.slice.elts[0]) - if len(node.slice.elts) == 1: - self._write(', ') - else: - for dim in node.slice.elts[1:]: - self._write(', ') - self.visit(dim) - elif isinstance(node.slice, ast.Slice): - self.visit_Slice(node.slice, True) - else: - self.visit(node.slice) - self._write(']') - - # Slice(expr? lower, expr? upper, expr? step) - def visit_Slice(self, node, subscription: bool = False) -> None: - if subscription: - if getattr(node, 'lower', None) is not None: - self.visit(node.lower) - self._write(':') - if getattr(node, 'upper', None) is not None: - self.visit(node.upper) - if getattr(node, 'step', None) is not None: - self._write(':') - self.visit(node.step) - else: - self._write('slice(') - self.visit(getattr(node, "lower", None) or AST_NONE) - self._write(', ') - self.visit(getattr(node, "upper", None) or AST_NONE) - self._write(', ') - self.visit(getattr(node, "step", None) or AST_NONE) - self._write(')') - - # Index(expr value) - def visit_Index(self, node) -> None: - self.visit(node.value) - - # ExtSlice(slice* dims) - def visit_ExtSlice(self, node) -> None: - self.visit(node.dims[0]) - if len(node.dims) == 1: - self._write(', ') - else: - for dim in node.dims[1:]: - self._write(', ') - self.visit(dim) - - # Starred(expr value, expr_context ctx) - def visit_Starred(self, node) -> None: - self._write('*') - self.visit(node.value) - - # Name(identifier id, expr_context ctx) - def visit_Name(self, node) -> None: - self._write(node.id) - - # List(expr* elts, expr_context ctx) - def visit_List(self, node) -> None: - self._write('[') - for elt in node.elts: - self.visit(elt) - self._write(', ') - self._write(']') - - # Tuple(expr *elts, expr_context ctx) - def visit_Tuple(self, node) -> None: - self._write('(') - for elt in node.elts: - self.visit(elt) - self._write(', ') - self._write(')') - - # NameConstant(singleton value) - def visit_NameConstant(self, node) -> None: - self._write(str(node.value)) - - class AnnotationAwareVisitor(ast.NodeVisitor): def visit(self, node) -> None: annotation = node_annotations.get(node) diff --git a/src/chameleon/codegen.py b/src/chameleon/codegen.py index 3c5b1635..b7007e70 100644 --- a/src/chameleon/codegen.py +++ b/src/chameleon/codegen.py @@ -1,11 +1,23 @@ from __future__ import annotations -import ast import builtins +import re import textwrap import types +from ast import AST +from ast import Assign +from ast import Constant +from ast import Expr +from ast import FunctionDef +from ast import Import +from ast import ImportFrom +from ast import Module +from ast import NodeTransformer +from ast import NodeVisitor +from ast import Num +from ast import alias +from ast import unparse -from chameleon.astutil import ASTCodeGenerator from chameleon.astutil import Builtin from chameleon.astutil import Symbol from chameleon.astutil import load @@ -37,17 +49,18 @@ def wrapper(*vargs, **kwargs): symbols = dict(zip(args, vargs + defaults)) symbols.update(kwargs) - class Visitor(ast.NodeVisitor): + class Visitor(NodeVisitor): def visit_FunctionDef(self, node) -> None: self.generic_visit(node) name = symbols.get(node.name, self) if name is not self: - node_annotations[node] = ast.FunctionDef( + node_annotations[node] = FunctionDef( name=name, args=node.args, body=node.body, decorator_list=getattr(node, "decorator_list", []), + lineno=None, ) def visit_Name(self, node) -> None: @@ -80,76 +93,41 @@ def visit_Name(self, node) -> None: return wrapper(**kw) -class TemplateCodeGenerator(ASTCodeGenerator): - """Extends the standard Python code generator class with handlers - for the helper node classes: - - - Symbol (an importable value) - - Static (value that can be made global) - - Builtin (from the builtins module) - - Marker (short-hand for a unique static object) +class TemplateCodeGenerator(NodeTransformer): + """Generate code from AST tree. + The syntax tree has been extended with internal nodes. We first + transform the tree to process the internal nodes before generating + the code string. """ names = () def __init__(self, tree): - self.imports = {} + self.comments = [] self.defines = {} - self.markers = {} + self.imports = {} self.tokens = [] - # Generate code - super().__init__(tree) - - def visit_Module(self, node): - super().visit_Module(node) - - # Make sure we terminate the line printer - self.flush() - - # Clear lines array for import visits - body = self.lines - self.lines = [] + # Run transform. + tree = self.visit(tree) - while self.defines: - name, node = self.defines.popitem() - assignment = ast.Assign(targets=[store(name)], value=node) - self.visit(assignment) - - # Make sure we terminate the line printer - self.flush() - - # Clear lines array for import visits - defines = self.lines - self.lines = [] - - while self.imports: - value, node = self.imports.popitem() - - if isinstance(value, types.ModuleType): - stmt = ast.Import( - names=[ast.alias(name=value.__name__, asname=node.id)]) - elif hasattr(value, '__name__'): - path = reverse_builtin_map.get(value) - if path is None: - path = value.__module__ - name = value.__name__ - stmt = ast.ImportFrom( - module=path, - names=[ast.alias(name=name, asname=node.id)], - level=0, - ) - else: - raise TypeError(value) + # Generate code. + code = unparse(tree) - self.visit(stmt) + # Fix-up comments. + comments = iter(self.comments) + code = re.sub( + r'^(\s*)\.\.\.$', + lambda m: "\n".join( + (m.group(1) + "#" + line) + for line in next(comments).replace("\r", "\n").split("\n") + ), + code, + flags=re.MULTILINE + ) - # Clear last import - self.flush() - - # Stich together lines - self.lines += defines + body + self.code = code def define(self, name, node): assert node is not None @@ -178,40 +156,64 @@ def require(self, value): return node - def visit(self, node) -> None: + def visit(self, node) -> AST: annotation = node_annotations.get(node) if annotation is None: - super().visit(node) - else: - self.visit(annotation) + return super().visit(node) + return self.visit(annotation) - def visit_Comment(self, node) -> None: - if node.stmt is None: - self._new_line() - else: - self.visit(node.stmt) + def visit_Module(self, module) -> AST: + assert isinstance(module, Module) + module = super().generic_visit(module) + preamble: list[AST] = [] + + for name, node in self.defines.items(): + assignment = Assign(targets=[store(name)], value=node, lineno=None) + preamble.append(assignment) + + for value, node in self.imports.items(): + stmt: AST + + if isinstance(value, types.ModuleType): + stmt = Import( + names=[alias(name=value.__name__, asname=node.id)]) + elif hasattr(value, '__name__'): + path = reverse_builtin_map.get(value) + if path is None: + path = value.__module__ + name = value.__name__ + stmt = ImportFrom( + module=path, + names=[alias(name=name, asname=node.id)], + level=0, + ) + else: + raise TypeError(value) + + preamble.append(stmt) - for line in node.text.replace('\r', '\n').split('\n'): - self._new_line() - self._write("{}#{}".format(node.space, line)) + preamble = [self.visit(stmt) for stmt in preamble] + return Module(preamble + module.body, ()) - def visit_Builtin(self, node) -> None: + def visit_Comment(self, node) -> AST: + self.comments.append(node.text) + return Expr(Constant(...)) + + def visit_Builtin(self, node) -> AST: name = load(node.id) - self.visit(name) + return self.visit(name) - def visit_Symbol(self, node) -> None: - node = self.require(node.value) - self.visit(node) + def visit_Symbol(self, node) -> AST: + return self.require(node.value) - def visit_Static(self, node) -> None: + def visit_Static(self, node) -> AST: if node.name is None: name = "_static_%s" % str(id(node.value)).replace('-', '_') else: name = node.name - node = self.define(name, node.value) - self.visit(node) + return self.visit(node) - def visit_TokenRef(self, node) -> None: + def visit_TokenRef(self, node) -> AST: self.tokens.append((node.pos, node.length)) - super().visit(ast.Num(n=node.pos)) + return self.visit(Num(n=node.pos)) diff --git a/src/chameleon/compiler.py b/src/chameleon/compiler.py index d1613f28..91264c6f 100644 --- a/src/chameleon/compiler.py +++ b/src/chameleon/compiler.py @@ -11,7 +11,6 @@ import sys import textwrap import threading -from ast import Try from chameleon.astutil import AST_NONE from chameleon.astutil import Builtin @@ -26,10 +25,8 @@ from chameleon.astutil import param from chameleon.astutil import store from chameleon.astutil import subscript -from chameleon.astutil import swap from chameleon.codegen import TemplateCodeGenerator from chameleon.codegen import template -from chameleon.config import DEBUG_MODE from chameleon.exc import ExpressionError from chameleon.exc import TranslationError from chameleon.i18n import simple_translate @@ -48,7 +45,6 @@ from chameleon.tal import NAME from chameleon.tal import ErrorInfo from chameleon.tokenize import Token -from chameleon.utils import DebuggingOutputStream from chameleon.utils import ListDictProxy from chameleon.utils import char2entity from chameleon.utils import decode_htmlentities @@ -77,11 +73,6 @@ RE_MANGLE = re.compile(r'[^\w_]') RE_NAME = re.compile('^%s$' % NAME) -if DEBUG_MODE: - LIST = template("cls()", cls=DebuggingOutputStream, mode="eval") -else: - LIST = template("[]", mode="eval") - def identifier(prefix, suffix=None) -> str: return "__{}_{}".format(prefix, mangle(suffix or id(prefix))) @@ -290,7 +281,10 @@ class EmitText(Node): class Scope(Node): - """"Set a local output scope.""" + """Set a local output scope. + + This is used for the translation machinery. + """ _fields = "body", "append", "stream" @@ -736,7 +730,7 @@ class NameTransform: >>> def test(node): ... rewritten = nt(node) - ... module = ast.Module([ast.fix_missing_locations(rewritten)]) + ... module = ast.Module([ast.fix_missing_locations(rewritten)], []) ... codegen = TemplateCodeGenerator(module) ... return codegen.code @@ -998,6 +992,7 @@ def __init__( source, builtins={}, strict=True, + stream_factory=list, ): self._scopes = [set()] self._expression_cache = {} @@ -1007,6 +1002,17 @@ def __init__( self._macros = [] self._current_slot = [] + # Prepare stream factory (callable) + self._new_list = ( + ast.List([], ast.Load()) if stream_factory is list else + ast.Call( + ast.Symbol(stream_factory), + args=[], + kwargs=[], + lineno=None, + ) + ) + internals = COMPILER_INTERNALS_OR_DISALLOWED | set(self.defaults) transform = NameTransform( @@ -1024,29 +1030,39 @@ def __init__( strict=strict, ) - module = ast.Module([]) + module = ast.Module([], []) module.body += self.visit(node) ast.fix_missing_locations(module) class Generator(TemplateCodeGenerator): scopes = [Scope()] - def visit_EmitText(self, node) -> None: + def visit_EmitText(self, node) -> ast.AST: append = load(self.scopes[-1].append or "__append") - for node in template( - "append(s)", append=append, s=ast.Str(s=node.s) - ): - self.visit(node) - - def visit_Scope(self, node) -> None: + node = ast.Expr(ast.Call( + func=append, + args=[ast.Str(s=node.s)], + keywords=[], + starargs=None, + kwargs=None + )) + return self.visit(node) + + def visit_Name(self, node) -> ast.AST: + if isinstance(node.ctx, ast.Load): + scope = self.scopes[-1] + for name in ("append", "stream"): + if node.id == f"__{name}": + identifier = getattr(scope, name, None) + if identifier: + return load(identifier) + return node + + def visit_Scope(self, node) -> list[ast.AST]: self.scopes.append(node) - body = list(node.body) - swap(body, load(node.append), "__append") - if node.stream: - swap(body, load(node.stream), "__stream") - for node in body: - self.visit(node) + stmts = list(map(self.visit, node.body)) self.scopes.pop() + return stmts generator = Generator(module) tokens = [ @@ -1094,7 +1110,6 @@ def visit_Element(self, node): def visit_Module(self, node): body = [] - body += template("import re") body += template("import functools") body += template("from itertools import chain as __chain") @@ -1102,7 +1117,7 @@ def visit_Module(self, node): body += template("__default = intern('__default__')") body += template("__marker = object()") body += template( - r"g_re_amp = re.compile(r'&(?!([A-Za-z]+|#[0-9]+);)')" + "g_re_amp = re.compile(r'&(?!([A-Za-z]+|#[0-9]+);)')" ) body += template( r"g_re_needs_escape = re.compile(r'[&<>\"\']').search") @@ -1116,11 +1131,16 @@ def visit_Module(self, node): program = self.visit(node.program) body += [ast.FunctionDef( - name=node.name, args=ast.arguments( + name=node.name, + args=ast.arguments( args=[param(b) for b in self._builtins], - defaults=(), + defaults=[], + kwonlyargs=[], + posonlyargs=[], ), - body=program + body=program, + decorator_list=[], + lineno=None, )] return body @@ -1173,7 +1193,7 @@ def visit_Macro(self, node): self._slots = set() # Visit macro body - nodes = itertools.chain(*tuple(map(self.visit, node.body))) + nodes = list(itertools.chain(*tuple(map(self.visit, node.body)))) # Slot resolution for name in self._slots: @@ -1196,9 +1216,11 @@ def visit_Macro(self, node): # Wrap visited nodes in try-except error handler. body += [ - Try( + ast.Try( body=nodes, - handlers=[ast.ExceptHandler(body=exc_handler)] + handlers=[ast.ExceptHandler(body=exc_handler)], + finalbody=[], + orelse=[], ) ] @@ -1216,8 +1238,12 @@ def visit_Macro(self, node): param("target_language"), ], defaults=[load("None"), load("None"), load("None")], + kwonlyargs=[], + posonlyargs=[], ), - body=body + body=body, + decorator_list=[], + lineno=None, ) yield function @@ -1266,16 +1292,18 @@ def visit_OnError(self, node): key=ast.Str(s=node.name), ) - body += [Try( + body += [ast.Try( body=self.visit(node.node), handlers=[ast.ExceptHandler( type=ast.Tuple(elts=[Builtin("Exception")], ctx=ast.Load()), - name=store("__exc"), + name="__exc", body=(error_assignment + template("del __stream[fallback:]", fallback=fallback) + fallback_body ), - )] + )], + finalbody=[], + orelse=[], )] return body @@ -1424,7 +1452,8 @@ def visit_Translate(self, node): # Prepare new stream append = identifier("append", id(node)) stream = identifier("stream", id(node)) - body += template("s = new_list", s=stream, new_list=LIST) + \ + + body += template("s = new_list", s=stream, new_list=self._new_list) + \ template("a = s.append", a=append, s=stream) # Visit body to generate the message body @@ -1672,7 +1701,7 @@ def visit_Name(self, node): # prepare new stream stream, append = self._get_translation_identifiers(node.name) - body += template("s = new_list", s=stream, new_list=LIST) + \ + body += template("s = new_list", s=stream, new_list=self._new_list) + \ template("a = s.append", a=append, s=stream) # generate code @@ -1727,9 +1756,12 @@ def visit_UseExternalMacro(self, node): load("__i18n_context"), load("target_language"), ], + kwonlyargs=[], + posonlyargs=[], ), - body=body or [ - ast.Pass()], + body=body or [ast.Pass()], + decorator_list=[], + lineno=None, )) key = ast.Str(s=key) @@ -1742,9 +1774,10 @@ def visit_UseExternalMacro(self, node): if node.extend: append = template("_slots.appendleft(NAME)", NAME=fun) - assignment = [Try( + assignment = [ast.Try( body=template("_slots = getname(KEY)", KEY=key), handlers=[ast.ExceptHandler(body=assignment)], + finalbody=[], orelse=append, )] @@ -1847,6 +1880,7 @@ def visit_Repeat(self, node): target=store("__item"), iter=load("__iterator"), body=assignment + inner, + orelse=[], )] # Finally, clean up assignment if it's local diff --git a/src/chameleon/tales.py b/src/chameleon/tales.py index bdc82843..d8c05ac6 100644 --- a/src/chameleon/tales.py +++ b/src/chameleon/tales.py @@ -2,7 +2,6 @@ import ast import re -from ast import Try from chameleon.astutil import Builtin from chameleon.astutil import ItemLookupOnAttributeErrorVisitor @@ -148,15 +147,17 @@ def __call__(self, target, engine): if i == 0: body = assignment else: - body = [Try( + body = [ast.Try( body=assignment, handlers=[ast.ExceptHandler( type=ast.Tuple( - elts=map(resolve_global, self.exceptions), + elts=list(map(resolve_global, self.exceptions)), ctx=ast.Load()), name=None, body=body, )], + finalbody=[], + orelse=[], )] return body @@ -505,16 +506,17 @@ def __call__(self, target, engine): compiler = engine.parse(self.expression, False) body = compiler.assign_value(ignore) - classes = map(resolve_global, self.exceptions) + classes = list(map(resolve_global, self.exceptions)) return [ - Try( + ast.Try( body=body, handlers=[ast.ExceptHandler( type=ast.Tuple(elts=classes, ctx=ast.Load()), name=None, body=template("target = 0", target=target), )], + finalbody=[], orelse=template("target = 1", target=target) ) ] diff --git a/src/chameleon/template.py b/src/chameleon/template.py index 6ebd81bc..bf5b0c39 100644 --- a/src/chameleon/template.py +++ b/src/chameleon/template.py @@ -59,6 +59,7 @@ def safe_get_package_version(name: str) -> str | None: except importlib_metadata.PackageNotFoundError: return None + def get_package_versions() -> list[tuple[str, str]]: distributions = importlib_metadata.packages_distributions().values() versions = { @@ -335,8 +336,12 @@ def _compile(self, body: str, builtins: Collection[str]) -> str: program = self.parse(body) module = Module(PROGRAM_NAME, program) compiler = Compiler( - self.engine, module, str(self.filename), body, - builtins, strict=self.strict + self.engine, + module, + str(self.filename), + body, + builtins=builtins, + strict=self.strict ) return compiler.code diff --git a/src/chameleon/tests/test_astutil.py b/src/chameleon/tests/test_astutil.py deleted file mode 100644 index 5b02fc4a..00000000 --- a/src/chameleon/tests/test_astutil.py +++ /dev/null @@ -1,29 +0,0 @@ -import ast -import unittest - - -class ASTCodeGeneratorTestCase(unittest.TestCase): - def _eval(self, tree, env): - from chameleon.astutil import ASTCodeGenerator - source = ASTCodeGenerator(tree).code - code = compile(source, '', 'exec') - exec(code, env) - - def test_slice(self): - tree = ast.Module( - body=[ - ast.Assign( - targets=[ - ast.Name(id='x', ctx=ast.Store())], - value=ast.Call( - func=ast.Name(id='f', ctx=ast.Load()), - args=[ - ast.Slice( - upper=ast.Constant(value=0))], - keywords=[]))], - type_ignores=[] - ) - def f(x): return x - d = {"f": f} - self._eval(tree, d) - assert d['x'] == slice(None, 0, None) diff --git a/src/chameleon/tests/test_templates.py b/src/chameleon/tests/test_templates.py index d94030f9..8bc3a020 100644 --- a/src/chameleon/tests/test_templates.py +++ b/src/chameleon/tests/test_templates.py @@ -1,16 +1,16 @@ import glob import os import re -import shutil import sys from functools import partial from functools import wraps +import pytest + from chameleon.exc import RenderError from chameleon.exc import TemplateError from chameleon.tales import DEFAULT_MARKER -import pytest ROOT = os.path.dirname(__file__)