diff --git a/pyproject.toml b/pyproject.toml index f0a55423..aa12801d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,10 @@ exclude = "/tests/" follow_imports = "silent" [[tool.mypy.overrides]] -# strict config for public API +# strict config for fully typed modules and public API module = [ - "chameleon.exc.RenderError", - "chameleon.exc.TemplateError", + "chameleon.exc.*", + "chameleon.utils.*", "chameleon.zpt.loader.*", "chameleon.zpt.template.*", ] diff --git a/src/chameleon/exc.py b/src/chameleon/exc.py index a50951d1..3767473d 100644 --- a/src/chameleon/exc.py +++ b/src/chameleon/exc.py @@ -2,6 +2,7 @@ import traceback from typing import TYPE_CHECKING +from typing import Any from chameleon.config import SOURCE_EXPRESSION_MARKER_LENGTH as LENGTH from chameleon.tokenize import Token @@ -10,10 +11,19 @@ if TYPE_CHECKING: + from collections.abc import Callable + from collections.abc import Iterable + from collections.abc import Iterator + from collections.abc import Mapping from typing_extensions import Self -def compute_source_marker(line, column, expression, size): +def compute_source_marker( + line: str, + column: int, + expression: str, + size: int +) -> tuple[str, str]: """Computes source marker location string. >>> def test(l, c, e, s): @@ -76,9 +86,9 @@ def compute_source_marker(line, column, expression, size): size = len(expression) else: window = (size - len(expression)) / 2.0 - offset = column - window - offset -= min(3, max(0, column + window + len(expression) - len(s))) - offset = int(offset) + f_offset = column - window + f_offset -= min(3, max(0, column + window + len(expression) - len(s))) + offset = int(f_offset) if offset > 0: s = s[offset:] @@ -97,7 +107,13 @@ def compute_source_marker(line, column, expression, size): return s, column * " " + marker -def iter_source_marker_lines(source, expression, line, column): +def iter_source_marker_lines( + source: Iterable[str], + expression: str, + line: int, + column: int +) -> Iterator[str]: + for i, l in enumerate(source): if i + 1 != line: continue @@ -111,7 +127,7 @@ def iter_source_marker_lines(source, expression, line, column): break -def ellipsify(string, limit): +def ellipsify(string: str, limit: int) -> str: if len(string) > limit: return "... " + string[-(limit - 4):] @@ -172,6 +188,7 @@ def __str__(self) -> str: text += "\n" text += " - Location: (line %d: col %d)" % (lineno, column) + lines: Iterable[str] if lineno and column: if self.token.source: lines = iter_source_marker_lines( @@ -184,7 +201,7 @@ def __str__(self) -> str: except OSError: pass else: - iter_source_marker_lines( + lines = iter_source_marker_lines( iter(f), self.token, lineno, column ) try: @@ -262,7 +279,14 @@ class ExpressionError(LanguageError): class ExceptionFormatter: - def __init__(self, errors, econtext, rcontext, value_repr) -> None: + def __init__( + self, + errors: list[tuple[str, int, int, str, BaseException]], + econtext: Mapping[str, object], + rcontext: dict[str, Any], + value_repr: Callable[[object], str] + ) -> None: + kwargs = rcontext.copy() kwargs.update(econtext) @@ -274,16 +298,16 @@ def __init__(self, errors, econtext, rcontext, value_repr) -> None: self._kwargs = kwargs self._value_repr = value_repr - def __call__(self): + def __call__(self) -> str: # Format keyword arguments; consecutive arguments are indented # for readability - formatted = [ + formatted_args = [ "{}: {}".format(name, self._value_repr(value)) for name, value in self._kwargs.items() ] - for index, string in enumerate(formatted[1:]): - formatted[index + 1] = " " * 15 + string + for index, string in enumerate(formatted_args[1:]): + formatted_args[index + 1] = " " * 15 + string out = [] @@ -319,14 +343,14 @@ def __call__(self): finally: f.close() - out.append(" - Arguments: %s" % "\n".join(formatted)) + out.append(" - Arguments: %s" % "\n".join(formatted_args)) if isinstance(exc.__str__, ExceptionFormatter): # This is a nested error that has already been wrapped # We must unwrap it before trying to format it to prevent # recursion exc = create_formatted_exception( - exc, type(exc), exc._original__str__) + exc, type(exc), exc._original__str__) # type: ignore formatted = traceback.format_exception_only(type(exc), exc)[-1] formatted_class = "%s:" % type(exc).__name__ diff --git a/src/chameleon/template.py b/src/chameleon/template.py index 0ce512a7..508ad7b5 100644 --- a/src/chameleon/template.py +++ b/src/chameleon/template.py @@ -239,7 +239,7 @@ def render(self, **__kw: Any) -> str: try: exc = create_formatted_exception( - exc, cls, formatter, RenderError + exc, cls, formatter, RenderError # type: ignore ) except TypeError: pass diff --git a/src/chameleon/utils.py b/src/chameleon/utils.py index 3d1080fb..6d394e73 100644 --- a/src/chameleon/utils.py +++ b/src/chameleon/utils.py @@ -4,19 +4,38 @@ import logging import os import re +from enum import Enum from html import entities as htmlentitydefs from typing import TYPE_CHECKING from typing import Any +from typing import Generic +from typing import Literal from typing import NoReturn +from typing import TypeVar +from typing import overload if TYPE_CHECKING: + from collections.abc import Callable from collections.abc import Iterable + from collections.abc import Iterator + from collections.abc import Mapping + from collections.abc import Sequence from types import TracebackType +_KT = TypeVar('_KT') +_VT_co = TypeVar('_VT_co') + log = logging.getLogger('chameleon.utils') -marker = object() + + +class _Marker(Enum): + marker = object() + + +# NOTE: Enums are better markers for type narrowing +marker: Literal[_Marker.marker] = _Marker.marker def safe_native(s: str | bytes, encoding: str = 'utf-8') -> str: @@ -152,13 +171,16 @@ def mangle(filename: str) -> str: return base.replace('.', '_').replace('-', '_') -def char2entity(c): +def char2entity(c: str | bytes | bytearray) -> str: cp = ord(c) name = htmlentitydefs.codepoint2name.get(cp) return '&%s;' % name if name is not None else '&#%d;' % cp -def substitute_entity(match, n2cp=htmlentitydefs.name2codepoint): +def substitute_entity( + match: re.Match[str], + n2cp: Mapping[str, int] = htmlentitydefs.name2codepoint +) -> str: ent = match.group(3) if match.group(1) == "#": @@ -166,6 +188,10 @@ def substitute_entity(match, n2cp=htmlentitydefs.name2codepoint): return chr(int(ent)) elif match.group(2) == 'x': return chr(int('0x' + ent, 16)) + else: + # FIXME: This should be unreachable, so we can + # try raising an AssertionError instead + return '' else: cp = n2cp.get(ent) @@ -175,7 +201,12 @@ def substitute_entity(match, n2cp=htmlentitydefs.name2codepoint): return match.group() -def create_formatted_exception(exc, cls, formatter, base=Exception): +def create_formatted_exception( + exc: BaseException, + cls: type[object], + formatter: Callable[..., str], + base: type[BaseException] = Exception +) -> BaseException: try: try: new = type(cls.__name__, (cls, base), { @@ -187,13 +218,14 @@ def create_formatted_exception(exc, cls, formatter, base=Exception): except TypeError: new = cls + inst: BaseException try: inst = BaseException.__new__(new) except TypeError: inst = cls.__new__(new) BaseException.__init__(inst, *exc.args) - inst.__dict__ = exc.__dict__ + inst.__dict__ = exc.__dict__ # type: ignore[attr-defined] return inst except ValueError: @@ -230,7 +262,7 @@ def join(stream: Iterable[str]) -> str: raise -def decode_htmlentities(string): +def decode_htmlentities(string: str) -> str: """ >>> str(decode_htmlentities('&amp;')) '&' @@ -244,21 +276,21 @@ def decode_htmlentities(string): # Taken from zope.dottedname -def _resolve_dotted(name, module=None): - name = name.split('.') - if not name[0]: +def _resolve_dotted(name: str, module: str | None = None) -> Any: + name_parts = name.split('.') + if not name_parts[0]: if module is None: raise ValueError("relative name without base module") - module = module.split('.') - name.pop(0) + module_parts = module.split('.') + name_parts.pop(0) while not name[0]: - module.pop() - name.pop(0) - name = module + name + module_parts.pop() + name_parts.pop(0) + name_parts = module_parts + name_parts - used = name.pop(0) + used = name_parts.pop(0) found = __import__(used) - for n in name: + for n in name_parts: used += '.' + n try: found = getattr(found, n) @@ -269,33 +301,36 @@ def _resolve_dotted(name, module=None): return found -def resolve_dotted(dotted): +def resolve_dotted(dotted: str) -> Any: if dotted not in module_cache: resolved = _resolve_dotted(dotted) module_cache[dotted] = resolved return module_cache[dotted] -def limit_string(s, max_length=53): +def limit_string(s: str, max_length: int = 53) -> str: if len(s) > max_length: return s[:max_length - 3] + '...' return s -def value_repr(value): +def value_repr(value: object) -> str: if isinstance(value, str): short = limit_string(value) return short.replace('\n', '\\n') if isinstance(value, (int, float)): - return value + return value # type: ignore[return-value] if isinstance(value, dict): return '{...} (%d)' % len(value) try: + # FIXME: Is this trailing comma intentional? + # it changes the formatting + # vs. name = str(getattr(value, '__name__', None)), except: # noqa: E722 do not use bare 'except' - name = '-' + name = '-' # type: ignore[assignment] return '<{} {} at {}>'.format( type(value).__name__, name, hex(abs(id(value)))) @@ -304,41 +339,41 @@ def value_repr(value): class callablestr(str): __slots__ = () - def __call__(self): + def __call__(self) -> str: return self class callableint(int): __slots__ = () - def __call__(self): + def __call__(self) -> int: return self class descriptorstr: __slots__ = "function", "__name__" - def __init__(self, function) -> None: + def __init__(self, function: Callable[[Any], str]) -> None: self.function = function self.__name__ = function.__name__ - def __get__(self, context, cls): + def __get__(self, context: object, cls: type[object]) -> callablestr: return callablestr(self.function(context)) class descriptorint: __slots__ = "function", "__name__" - def __init__(self, function) -> None: + def __init__(self, function: Callable[[Any], int]) -> None: self.function = function self.__name__ = function.__name__ - def __get__(self, context, cls): + def __get__(self, context: object, cls: type[object]) -> callableint: return callableint(self.function(context)) class DebuggingOutputStream(list[str]): - def append(self, value): + def append(self, value: str) -> None: if not isinstance(value, str): raise TypeError(value) @@ -381,31 +416,36 @@ class Scope(dict[str, Any]): set_local = dict.__setitem__ - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: value = self.get(key, marker) if value is marker: raise KeyError(key) return value - def __contains__(self, key) -> bool: - return self.get(key, marker) is not marker + def __contains__(self, key: object) -> bool: + return self.get(key, marker) is not marker # type: ignore - def __iter__(self): + def __iter__(self) -> Iterator[str]: root = getattr(self, "_root", marker) - yield from dict.__iter__(self) + yield from super().__iter__() if root is not marker: for key in root: - if not dict.__contains__(self, key): + if not super().__contains__(key): yield key - def get(self, key, default=None): - value = dict.get(self, key, marker) + @overload + def get(self, key: str, default: None = None) -> Any | None: ... + @overload + def get(self, key: str, default: object) -> Any: ... + + def get(self, key: str, default: object = None) -> Any: + value = super().get(key, marker) if value is not marker: return value root = getattr(self, "_root", marker) if root is not marker: - value = dict.get(root, key, marker) + value = super(Scope, root).get(key, marker) if value is not marker: return value @@ -413,20 +453,20 @@ def get(self, key, default=None): return default @property - def vars(self): + def vars(self) -> Mapping[str, Any]: return self - def copy(self): + def copy(self) -> Scope: inst = Scope(self) root = getattr(self, "_root", self) - inst._root = root + inst._root = root # type: ignore[attr-defined] return inst - def set_global(self, name, value) -> None: + def set_global(self, name: str, value: Any) -> None: root = getattr(self, "_root", self) root[name] = value - def get_name(self, key): + def get_name(self, key: str) -> Any: value = self.get(key, marker) if value is marker: raise NameError(key) @@ -436,11 +476,14 @@ def get_name(self, key): setGlobal = set_global -class ListDictProxy: - def __init__(self, _l) -> None: +class ListDictProxy(Generic[_KT, _VT_co]): + def __init__( + self: ListDictProxy[_KT, _VT_co], + _l: Sequence[Mapping[_KT, _VT_co]] + ) -> None: self._l = _l - def get(self, key): + def get(self, key: _KT) -> _VT_co | None: return self._l[-1].get(key) @@ -471,12 +514,14 @@ def __repr__(self) -> str: return '<%s>' % self.name -def lookup_attr(obj, key): +def lookup_attr(obj: object, key: str) -> Any: try: return getattr(obj, key) except AttributeError as exc: + # FIXME: What are the two try excepts here for? + # We just raise the thing we catch... try: - get = obj.__getitem__ + get = obj.__getitem__ # type: ignore[index] except AttributeError: raise exc try: