Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment: Annotated invalid expressions #199

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ jobs:
uses: actions/upload-artifact@v3
if: failure()
with:
name: "python-${{ matrix.python }}-pybind-${{ matrix.pybind11-branch }}.patch"
path: "./tests/stubs/python-${{ matrix.python }}/pybind11-${{ matrix.pybind11-branch }}.patch"
name: "python-${{ matrix.python }}-pybind-${{ matrix.pybind11-branch }}-${{ matrix.numpy-format }}.patch"
path: "./tests/stubs/python-${{ matrix.python }}/pybind11-${{ matrix.pybind11-branch }}/${{ matrix.numpy-format }}.patch"
retention-days: 30
if-no-files-found: ignore

Expand Down
4 changes: 3 additions & 1 deletion pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
RemoveSelfAnnotation,
ReplaceReadWritePropertyWithField,
RewritePybind11EnumValueRepr,
WrapInvalidExpressions,
)
from pybind11_stubgen.parser.mixins.parse import (
BaseParser,
Expand Down Expand Up @@ -276,6 +277,7 @@ class Parser(
FixRedundantMethodsFromBuiltinObject,
RemoveSelfAnnotation,
FixPybind11EnumStrDoc,
WrapInvalidExpressions,
ExtractSignaturesFromPybind11Docstrings,
ParserDispatchMixin,
BaseParser,
Expand Down Expand Up @@ -306,7 +308,7 @@ def main():
args = arg_parser().parse_args(namespace=CLIArgs())

parser = stub_parser_from_args(args)
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
printer = Printer()

out_dir, sub_dir = to_output_and_subdir(
output_dir=args.output_dir,
Expand Down
39 changes: 36 additions & 3 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ def handle_value(self, value: Any) -> Value:
result = super().handle_value(value)
if inspect.isroutine(value) and result.is_print_safe:
self._add_import(QualifiedName.from_str(result.repr))
else:
type_ = type(value)
self._add_import(
QualifiedName.from_str(f"{type_.__module__}.{type_.__qualname__}")
)
return result

def parse_annotation_str(
Expand All @@ -159,7 +164,9 @@ def _add_import(self, name: QualifiedName) -> None:
return
if len(name) == 1 and len(name[0]) == 0:
return
if hasattr(builtins, name[0]):
if len(name) == 1 and hasattr(builtins, name[0]):
return
if len(name) > 0 and name[0] == "builtins":
return
if self.__current_class is not None and hasattr(self.__current_class, name[0]):
return
Expand All @@ -171,6 +178,8 @@ def _add_import(self, name: QualifiedName) -> None:
if module_name is None:
self.report_error(NameResolutionError(name))
return
if self.__current_module.__name__ == str(module_name):
return
self.__extra_imports.add(Import(name=None, origin=module_name))

def _get_parent_module(self, name: QualifiedName) -> QualifiedName | None:
Expand Down Expand Up @@ -498,6 +507,14 @@ def handle_value(self, value: Any) -> Value:
result.repr = self._pattern.sub(r"<\g<name> object>", result.repr)
return result

def parse_value_str(self, value: str) -> Value | InvalidExpression:
result = super().parse_value_str(value)
if isinstance(result, Value):
result.repr = self._pattern.sub(r"<\g<name> object>", result.repr)
else:
result.text = self._pattern.sub(r"<\g<name> object>", result.text)
return result


class FixNumpyArrayDimAnnotation(IParser):
__array_names: set[QualifiedName] = {
Expand Down Expand Up @@ -867,6 +884,24 @@ def handle_class_member(
return result


class WrapInvalidExpressions(IParser):
def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
result = super().parse_annotation_str(annotation_str)
if not isinstance(result, InvalidExpression):
return result

substitute_t = self.parse_annotation_str("typing.Any")
return ResolvedType(
QualifiedName.from_str("Annotated"),
parameters=[
substitute_t,
result,
],
)


class FixMissingFixedSizeImport(IParser):
def parse_annotation_str(
self, annotation_str: str
Expand All @@ -887,8 +922,6 @@ def parse_annotation_str(
except ValueError:
pass
else:
# call `handle_type` to trigger implicit import
self.handle_type(FixedSize)
return self.handle_value(FixedSize(*dimensions))
return result

Expand Down
18 changes: 14 additions & 4 deletions pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import ast
import datetime
import inspect
import re
import types
import typing
from typing import Any

from pybind11_stubgen.parser.errors import (
Expand Down Expand Up @@ -34,6 +36,7 @@
TypeVar_,
Value,
)
from pybind11_stubgen.typing_ext import DynamicSize, FixedSize

_generic_args = [
Argument(name=Identifier("args"), variadic=True),
Expand Down Expand Up @@ -250,16 +253,16 @@ def handle_function(self, path: QualifiedName, func: Any) -> list[Function]:
func_args[arg_name].annotation = self.parse_annotation_str(
annotation
)
elif not isinstance(annotation, type):
func_args[arg_name].annotation = self.handle_value(annotation)
elif self._is_generic_alias(annotation):
func_args[arg_name].annotation = self.parse_annotation_str(
str(annotation)
)
else:
elif isinstance(annotation, type):
func_args[arg_name].annotation = ResolvedType(
name=self.handle_type(annotation),
)
else:
func_args[arg_name].annotation = self.handle_value(annotation)
if "return" in func_args:
returns = func_args["return"].annotation
else:
Expand Down Expand Up @@ -291,7 +294,10 @@ def _is_generic_alias(self, annotation: type) -> bool:
generic_alias_t: type | None = getattr(types, "GenericAlias", None)
if generic_alias_t is None:
return False
return isinstance(annotation, generic_alias_t)
typing_generic_alias_t = type(typing.List[int])
return isinstance(annotation, generic_alias_t) or isinstance(
annotation, typing_generic_alias_t
)

def handle_import(self, path: QualifiedName, origin: Any) -> Import | None:
full_name = self._get_full_name(path, origin)
Expand Down Expand Up @@ -370,6 +376,10 @@ def handle_value(self, value: Any) -> Value:
return Value(repr=str(self.handle_type(value)), is_print_safe=True)
if inspect.ismodule(value):
return Value(repr=value.__name__, is_print_safe=True)
if isinstance(value, datetime.timedelta):
return Value(repr=repr(value), is_print_safe=True)
if isinstance(value, (FixedSize, DynamicSize)):
return Value(repr=repr(value), is_print_safe=True)
return Value(repr=repr(value), is_print_safe=False)

def handle_type(self, type_: type) -> QualifiedName:
Expand Down
73 changes: 41 additions & 32 deletions pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Modifier,
Module,
Property,
QualifiedName,
ResolvedType,
TypeVar_,
Value,
Expand All @@ -30,8 +31,8 @@ def indent_lines(lines: list[str], by=4) -> list[str]:


class Printer:
def __init__(self, invalid_expr_as_ellipses: bool):
self.invalid_expr_as_ellipses = invalid_expr_as_ellipses
def __init__(self):
self._need_typing_ext = False

def print_alias(self, alias: Alias) -> list[str]:
return [f"{alias.name} = {alias.origin}"]
Expand All @@ -43,13 +44,12 @@ def print_attribute(self, attr: Attribute) -> list[str]:
if attr.annotation is not None:
parts.append(f": {self.print_annotation(attr.annotation)}")

if attr.value is not None and attr.value.is_print_safe:
parts.append(f" = {self.print_value(attr.value)}")
else:
if attr.annotation is None:
parts.append(" = ...")
if attr.value is not None:
parts.append(f" # value = {self.print_value(attr.value)}")
if attr.value is not None:
if attr.value.is_print_safe or attr.annotation is None:
parts.append(f" = {self.print_value(attr.value)}")
else:
repr_first_line = attr.value.repr.split("\n", 1)[0]
parts.append(f" # value = {repr_first_line}")

return ["".join(parts)]

Expand Down Expand Up @@ -202,40 +202,51 @@ def print_method(self, method: Method) -> list[str]:
return result

def print_module(self, module: Module) -> list[str]:
result = []

if module.doc is not None:
result.extend(self.print_docstring(module.doc))

for import_ in sorted(module.imports, key=lambda x: x.origin):
result.extend(self.print_import(import_))
result_bottom = []
tmp = self._need_typing_ext

for sub_module in module.sub_modules:
result.extend(self.print_submodule_import(sub_module.name))
result_bottom.extend(self.print_submodule_import(sub_module.name))

# Place __all__ above everything
for attr in sorted(module.attributes, key=lambda a: a.name):
if attr.name == "__all__":
result.extend(self.print_attribute(attr))
result_bottom.extend(self.print_attribute(attr))
break

for type_var in sorted(module.type_vars, key=lambda t: t.name):
result.extend(self.print_type_var(type_var))
result_bottom.extend(self.print_type_var(type_var))

for class_ in sorted(module.classes, key=lambda c: c.name):
result.extend(self.print_class(class_))
result_bottom.extend(self.print_class(class_))

for func in sorted(module.functions, key=lambda f: f.name):
result.extend(self.print_function(func))
result_bottom.extend(self.print_function(func))

for attr in sorted(module.attributes, key=lambda a: a.name):
if attr.name != "__all__":
result.extend(self.print_attribute(attr))
result_bottom.extend(self.print_attribute(attr))

for alias in module.aliases:
result.extend(self.print_alias(alias))
result_bottom.extend(self.print_alias(alias))

if self._need_typing_ext:
module.imports.add(
Import(
name=None,
origin=QualifiedName.from_str("pybind11_stubgen.typing_ext"),
)
)

return result
result_top = []
if module.doc is not None:
result_top.extend(self.print_docstring(module.doc))

for import_ in sorted(module.imports, key=lambda x: x.origin):
result_top.extend(self.print_import(import_))

self._need_typing_ext = tmp
return result_top + result_bottom

def print_property(self, prop: Property) -> list[str]:
if not prop.getter:
Expand Down Expand Up @@ -276,11 +287,10 @@ def print_property(self, prop: Property) -> list[str]:
return result

def print_value(self, value: Value) -> str:
split = value.repr.split("\n", 1)
if len(split) == 1:
return split[0]
else:
return split[0] + "..."
if value.is_print_safe:
return value.repr
self._need_typing_ext = True
return f"pybind11_stubgen.typing_ext.ValueExpr({repr(value.repr)})"

def print_type(self, type_: ResolvedType) -> str:
if (
Expand Down Expand Up @@ -312,6 +322,5 @@ def print_annotation(self, annotation: Annotation) -> str:
raise AssertionError()

def print_invalid_exp(self, invalid_expr: InvalidExpression) -> str:
if self.invalid_expr_as_ellipses:
return "..."
return invalid_expr.text
self._need_typing_ext = True
return f"pybind11_stubgen.typing_ext.InvalidExpr({repr(invalid_expr.text)})"
16 changes: 16 additions & 0 deletions pybind11_stubgen/typing_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any


class FixedSize:
def __init__(self, *dim: int):
Expand All @@ -23,3 +25,17 @@ def __repr__(self):
f"{self.__class__.__qualname__}"
f"({', '.join(repr(d) for d in self.dim)})"
)


def InvalidExpr(expr: str) -> Any:
raise RuntimeError(
"The method exists only for annotation purposes in stub files. "
"Should never not be used at runtime"
)


def ValueExpr(expr: str) -> Any:
raise RuntimeError(
"The method exists only for annotation purposes in stub files. "
"Should never not be used at runtime"
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ from __future__ import annotations

import typing

import pybind11_stubgen.typing_ext

__all__ = ["ConsoleForegroundColor", "Magenta", "accepts_ambiguous_enum"]

class ConsoleForegroundColor:
Expand Down Expand Up @@ -32,6 +34,10 @@ class ConsoleForegroundColor:
@property
def value(self) -> int: ...

def accepts_ambiguous_enum(color: ConsoleForegroundColor = ...) -> None: ...
def accepts_ambiguous_enum(
color: ConsoleForegroundColor = pybind11_stubgen.typing_ext.InvalidExpr(
"<ConsoleForegroundColor.Magenta: 35>"
),
) -> None: ...

Magenta: ConsoleForegroundColor # value = <ConsoleForegroundColor.Magenta: 35>
Loading