From 903aa91811548fb74d473544ecb4c139b1e6801c Mon Sep 17 00:00:00 2001 From: jsh9 <25124332+jsh9@users.noreply.github.com> Date: Sun, 15 Dec 2024 17:23:04 -0500 Subject: [PATCH] Add mypy static type checking to pydoclint (#185) --- CHANGELOG.md | 2 ++ pydoclint/flake8_entry.py | 6 ++-- pydoclint/main.py | 2 +- pydoclint/parse_config.py | 6 ++-- pydoclint/utils/arg.py | 35 ++++++++++++------- pydoclint/utils/astTypes.py | 9 ----- pydoclint/utils/doc.py | 17 ++++----- pydoclint/utils/generic.py | 27 ++++++++++----- pydoclint/utils/return_anno.py | 14 +++++--- pydoclint/utils/return_yield_raise.py | 21 +++++------ pydoclint/utils/unparser_custom.py | 3 +- pydoclint/utils/visitor_helper.py | 50 +++++++++++++++------------ pydoclint/utils/walk.py | 12 ++++--- pydoclint/visitor.py | 28 ++++++++------- pyproject.toml | 4 +++ tox.ini | 14 ++++++-- 16 files changed, 149 insertions(+), 101 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index efbeb60..92b8333 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ - Changed - Dropped support for Python 3.8 +- Added + - Added static type checking using `mypy` ## [0.5.11] - 2024-12-14 diff --git a/pydoclint/flake8_entry.py b/pydoclint/flake8_entry.py index 93b8513..74558f2 100644 --- a/pydoclint/flake8_entry.py +++ b/pydoclint/flake8_entry.py @@ -1,3 +1,5 @@ +# mypy: disable-error-code=attr-defined + import ast import importlib.metadata as importlib_metadata from typing import Any, Generator, Tuple @@ -15,7 +17,7 @@ def __init__(self, tree: ast.AST) -> None: self._tree = tree @classmethod - def add_options(cls, parser): # noqa: D102 + def add_options(cls, parser: Any) -> None: # noqa: D102 parser.add_option( '--style', action='store', @@ -196,7 +198,7 @@ def add_options(cls, parser): # noqa: D102 ) @classmethod - def parse_options(cls, options): # noqa: D102 + def parse_options(cls, options: Any) -> None: # noqa: D102 cls.type_hints_in_signature = options.type_hints_in_signature cls.type_hints_in_docstring = options.type_hints_in_docstring cls.arg_type_hints_in_signature = options.arg_type_hints_in_signature diff --git a/pydoclint/main.py b/pydoclint/main.py index 06f03f6..7a0b8d5 100644 --- a/pydoclint/main.py +++ b/pydoclint/main.py @@ -362,7 +362,7 @@ def main( # noqa: C901 ctx.exit(1) # it means users supply this option - if require_return_section_when_returning_none != 'None': + if require_return_section_when_returning_none != 'None': # type:ignore[comparison-overlap] click.echo( click.style( ''.join([ diff --git a/pydoclint/parse_config.py b/pydoclint/parse_config.py index 35b9da9..25c4847 100644 --- a/pydoclint/parse_config.py +++ b/pydoclint/parse_config.py @@ -87,10 +87,10 @@ def findCommonParentFolder( makeAbsolute: bool = True, # allow makeAbsolute=False just for testing ) -> Path: """Find the common parent folder of the given ``paths``""" - paths = [Path(path) for path in paths] + paths_: Sequence[Path] = [Path(path) for path in paths] - common_parent = paths[0] - for path in paths[1:]: + common_parent = paths_[0] + for path in paths_[1:]: if len(common_parent.parts) > len(path.parts): common_parent, path = path, common_parent diff --git a/pydoclint/utils/arg.py b/pydoclint/utils/arg.py index 752ae57..4b0e08e 100644 --- a/pydoclint/utils/arg.py +++ b/pydoclint/utils/arg.py @@ -29,7 +29,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return f'{self.name}: {self.typeHint}' - def __eq__(self, other: 'Arg') -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, Arg): return False @@ -84,17 +84,22 @@ def fromDocstringAttr(cls, attr: DocstringAttr) -> 'Arg': @classmethod def fromAstArg(cls, astArg: ast.arg) -> 'Arg': """Construct an Arg object from a Python AST argument object""" - anno = astArg.annotation - typeHint: str = '' if anno is None else unparseName(anno) + anno: Optional[ast.expr] = astArg.annotation + typeHint: Optional[str] = '' if anno is None else unparseName(anno) + assert typeHint is not None # to help mypy better understand type return Arg(name=astArg.arg, typeHint=typeHint) @classmethod def fromAstAnnAssign(cls, astAnnAssign: ast.AnnAssign) -> 'Arg': """Construct an Arg object from a Python ast.AnnAssign object""" - return Arg( - name=unparseName(astAnnAssign.target), - typeHint=unparseName(astAnnAssign.annotation), - ) + unparsedArgName = unparseName(astAnnAssign.target) + unparsedTypeHint = unparseName(astAnnAssign.annotation) + + # These assertions are to help mypy better interpret types + assert unparsedArgName is not None + assert unparsedTypeHint is not None + + return Arg(name=unparsedArgName, typeHint=unparsedTypeHint) @classmethod def _str(cls, typeName: Optional[str]) -> str: @@ -113,12 +118,12 @@ def _typeHintsEq(cls, hint1: str, hint2: str) -> bool: # >>> "ghi", # >>> ] try: - hint1_: str = unparseName(ast.parse(stripQuotes(hint1))) + hint1_: str = unparseName(ast.parse(stripQuotes(hint1))) # type:ignore[arg-type,assignment] except SyntaxError: hint1_ = hint1 try: - hint2_: str = unparseName(ast.parse(stripQuotes(hint2))) + hint2_: str = unparseName(ast.parse(stripQuotes(hint2))) # type:ignore[arg-type,assignment] except SyntaxError: hint2_ = hint2 @@ -156,7 +161,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return '[' + ', '.join(str(_) for _ in self.infoList) + ']' - def __eq__(self, other: 'ArgList') -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, ArgList): return False @@ -221,7 +226,9 @@ def fromAstAssign(cls, astAssign: ast.Assign) -> 'ArgList': elif isinstance(target, ast.Name): # such as `a = 1` or `a = b = 2` infoList.append(Arg(name=target.id, typeHint='')) elif isinstance(target, ast.Attribute): # e.g., uvw.xyz = 1 - infoList.append(Arg(name=unparseName(target), typeHint='')) + unparsedTarget: Optional[str] = unparseName(target) + assert unparsedTarget is not None # to help mypy understand type + infoList.append(Arg(name=unparsedTarget, typeHint='')) else: raise EdgeCaseError( f'astAssign.targets[{i}] is of type {type(target)}' @@ -303,7 +310,11 @@ def findArgsWithDifferentTypeHints(self, other: 'ArgList') -> List[Arg]: return result - def subtract(self, other: 'ArgList', checkTypeHint=True) -> Set[Arg]: + def subtract( + self, + other: 'ArgList', + checkTypeHint: bool = True, + ) -> Set[Arg]: """Find the args that are in this object but not in `other`.""" if checkTypeHint: return set(self.infoList) - set(other.infoList) diff --git a/pydoclint/utils/astTypes.py b/pydoclint/utils/astTypes.py index 8443db9..8e02f46 100644 --- a/pydoclint/utils/astTypes.py +++ b/pydoclint/utils/astTypes.py @@ -4,15 +4,6 @@ FuncOrAsyncFuncDef = Union[ast.AsyncFunctionDef, ast.FunctionDef] ClassOrFunctionDef = Union[ast.ClassDef, ast.AsyncFunctionDef, ast.FunctionDef] -AnnotationType = Union[ - ast.Name, - ast.Subscript, - ast.Index, - ast.Tuple, - ast.Constant, - ast.BinOp, - ast.Attribute, -] LegacyBlockTypes = [ ast.If, diff --git a/pydoclint/utils/doc.py b/pydoclint/utils/doc.py index 3b82085..720c1c0 100644 --- a/pydoclint/utils/doc.py +++ b/pydoclint/utils/doc.py @@ -1,5 +1,5 @@ import pprint -from typing import Any, List +from typing import Any, List, Optional, Union from docstring_parser.common import ( Docstring, @@ -23,6 +23,7 @@ def __init__(self, docstring: str, style: str = 'numpy') -> None: self.docstring = docstring self.style = style + parser: Union[NumpydocParser, GoogleParser] if style == 'numpy': parser = NumpydocParser() self.parsed = parser.parse(docstring) @@ -38,7 +39,7 @@ def __repr__(self) -> str: return pprint.pformat(self.__dict__, indent=2) @property - def isShortDocstring(self) -> bool: + def isShortDocstring(self) -> bool: # type:ignore[return] """Is the docstring a short one (containing only a summary)""" if self.style in {'google', 'numpy', 'sphinx'}: # API documentation: @@ -60,7 +61,7 @@ def isShortDocstring(self) -> bool: self._raiseException() # noqa: R503 @property - def argList(self) -> ArgList: + def argList(self) -> ArgList: # type:ignore[return] """The argument info in the docstring, presented as an ArgList""" if self.style in {'google', 'numpy', 'sphinx'}: return ArgList.fromDocstringParam(self.parsed.params) @@ -68,7 +69,7 @@ def argList(self) -> ArgList: self._raiseException() # noqa: R503 @property - def attrList(self) -> ArgList: + def attrList(self) -> ArgList: # type:ignore[return] """The attributes info in the docstring, presented as an ArgList""" if self.style in {'google', 'numpy', 'sphinx'}: return ArgList.fromDocstringAttr(self.parsed.attrs) @@ -76,16 +77,16 @@ def attrList(self) -> ArgList: self._raiseException() # noqa: R503 @property - def hasReturnsSection(self) -> bool: + def hasReturnsSection(self) -> bool: # type:ignore[return] """Whether the docstring has a 'Returns' section""" if self.style in {'google', 'numpy', 'sphinx'}: - retSection: DocstringReturns = self.parsed.returns + retSection: Optional[DocstringReturns] = self.parsed.returns return retSection is not None and not retSection.is_generator self._raiseException() # noqa: R503 @property - def hasYieldsSection(self) -> bool: + def hasYieldsSection(self) -> bool: # type:ignore[return] """Whether the docstring has a 'Yields' section""" if self.style in {'google', 'numpy', 'sphinx'}: yieldSection: DocstringYields = self.parsed.yields @@ -94,7 +95,7 @@ def hasYieldsSection(self) -> bool: self._raiseException() # noqa: R503 @property - def hasRaisesSection(self) -> bool: + def hasRaisesSection(self) -> bool: # type:ignore[return] """Whether the docstring has a 'Raises' section""" if self.style in {'google', 'numpy', 'sphinx'}: return len(self.parsed.raises) > 0 diff --git a/pydoclint/utils/generic.py b/pydoclint/utils/generic.py index 2d7ea72..7560617 100644 --- a/pydoclint/utils/generic.py +++ b/pydoclint/utils/generic.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import ast import copy import re -from typing import List, Match, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Match, Optional, Tuple, Union from pydoclint.utils.astTypes import ClassOrFunctionDef, FuncOrAsyncFuncDef from pydoclint.utils.method_type import MethodType from pydoclint.utils.violation import Violation +if TYPE_CHECKING: + from pydoclint.utils.arg import Arg, ArgList + def collectFuncArgs(node: FuncOrAsyncFuncDef) -> List[ast.arg]: """ @@ -70,7 +75,7 @@ def getFunctionId(node: FuncOrAsyncFuncDef) -> Tuple[int, int, str]: return node.lineno, node.col_offset, node.name -def detectMethodType(node: ast.FunctionDef) -> MethodType: +def detectMethodType(node: FuncOrAsyncFuncDef) -> MethodType: """ Detect whether the function def is an instance method, a classmethod, or a staticmethod. @@ -159,11 +164,17 @@ def getNodeName(node: ast.AST) -> str: if node is None: return '' - return node.name if 'name' in node.__dict__ else '' + return getattr(node, 'name', '') -def stringStartsWith(string: str, substrings: Tuple[str, ...]) -> bool: +def stringStartsWith( + string: Optional[str], + substrings: Tuple[str, ...], +) -> bool: """Check whether the string starts with any of the substrings""" + if string is None: + return False + for substring in substrings: if string.startswith(substring): return True @@ -202,11 +213,11 @@ def _replacer(match: Match[str]) -> str: def appendArgsToCheckToV105( *, original_v105: Violation, - funcArgs: 'ArgList', # noqa: F821 - docArgs: 'ArgList', # noqa: F821 + funcArgs: ArgList, + docArgs: ArgList, ) -> Violation: """Append the arg names to check to the error message of v105 or v605""" - argsToCheck: List['Arg'] = funcArgs.findArgsWithDifferentTypeHints(docArgs) # noqa: F821 + argsToCheck: List[Arg] = funcArgs.findArgsWithDifferentTypeHints(docArgs) argNames: str = ', '.join(_.name for _ in argsToCheck) return original_v105.appendMoreMsg(moreMsg=argNames) @@ -244,4 +255,4 @@ def getFullAttributeName(node: Union[ast.Attribute, ast.Name]) -> str: if isinstance(node, ast.Name): return node.id - return getFullAttributeName(node.value) + '.' + node.attr + return getFullAttributeName(node.value) + '.' + node.attr # type:ignore[arg-type] diff --git a/pydoclint/utils/return_anno.py b/pydoclint/utils/return_anno.py index 0c96288..603c8e9 100644 --- a/pydoclint/utils/return_anno.py +++ b/pydoclint/utils/return_anno.py @@ -38,6 +38,8 @@ def decompose(self) -> List[str]: When the annotation string has strange values """ if self._isTuple(): # noqa: R506 + assert self.annotation is not None # to help mypy understand type + if not self.annotation.endswith(']'): raise EdgeCaseError('Return annotation not ending with `]`') @@ -49,15 +51,16 @@ def decompose(self) -> List[str]: insideTuple: str = self.annotation[6:-1] if insideTuple.endswith('...'): # like this: Tuple[int, ...] - return [self.annotation] # b/c we don't know the tuple's length + # because we don't know the tuple's length + return [self.annotation] - parsedBody0: ast.Expr = ast.parse(insideTuple).body[0] + parsedBody0: ast.Expr = ast.parse(insideTuple).body[0] # type:ignore[assignment] if isinstance(parsedBody0.value, ast.Name): # like this: Tuple[int] return [insideTuple] if isinstance(parsedBody0.value, ast.Tuple): # like Tuple[int, str] - elts: List = parsedBody0.value.elts - return [unparseName(_) for _ in elts] + elts: List[ast.expr] = parsedBody0.value.elts + return [unparseName(_) for _ in elts] # type:ignore[misc] raise EdgeCaseError('decompose(): This should not have happened') else: @@ -65,7 +68,8 @@ def decompose(self) -> List[str]: def _isTuple(self) -> bool: try: - annoHead = ast.parse(self.annotation).body[0].value.value.id + assert self.annotation is not None # to help mypy understand type + annoHead = ast.parse(self.annotation).body[0].value.value.id # type:ignore[attr-defined] return annoHead in {'tuple', 'Tuple'} except Exception: return False diff --git a/pydoclint/utils/return_yield_raise.py b/pydoclint/utils/return_yield_raise.py index dde2525..a94a6d7 100644 --- a/pydoclint/utils/return_yield_raise.py +++ b/pydoclint/utils/return_yield_raise.py @@ -7,7 +7,7 @@ from pydoclint.utils.unparser_custom import unparseName ReturnType = Type[ast.Return] -ExprType = Type[ast.Expr] +ExprType = Type[ast.expr] YieldAndYieldFromTypes = Tuple[Type[ast.Yield], Type[ast.YieldFrom]] FuncOrAsyncFuncTypes = Tuple[Type[ast.FunctionDef], Type[ast.AsyncFunctionDef]] FuncOrAsyncFunc = (ast.FunctionDef, ast.AsyncFunctionDef) @@ -23,7 +23,7 @@ def isReturnAnnotationNone(node: FuncOrAsyncFuncDef) -> bool: return _isNone(node.returns) -def _isNone(node: ast.AST) -> bool: +def _isNone(node: Optional[ast.expr]) -> bool: return isinstance(node, ast.Constant) and node.value is None @@ -32,7 +32,7 @@ def isReturnAnnotationNoReturn(node: FuncOrAsyncFuncDef) -> bool: if node.returns is None: return False - returnAnnotation: str = unparseName(node.returns) + returnAnnotation: Optional[str] = unparseName(node.returns) return returnAnnotation == 'NoReturn' @@ -41,7 +41,7 @@ def hasGeneratorAsReturnAnnotation(node: FuncOrAsyncFuncDef) -> bool: if node.returns is None: return False - returnAnno: str = unparseName(node.returns) + returnAnno: Optional[str] = unparseName(node.returns) return returnAnno in {'Generator', 'AsyncGenerator'} or stringStartsWith( returnAnno, ('Generator[', 'AsyncGenerator[') ) @@ -52,7 +52,7 @@ def hasIteratorOrIterableAsReturnAnnotation(node: FuncOrAsyncFuncDef) -> bool: if node.returns is None: return False - returnAnnotation: str = unparseName(node.returns) + returnAnnotation: Optional[str] = unparseName(node.returns) return returnAnnotation in { 'Iterator', 'Iterable', @@ -240,14 +240,15 @@ def _updateFamilyTree( def _getLineNum(node: ast.AST) -> int: + lineNum: int try: if 'lineno' in node.__dict__: # normal case - lineNum = node.lineno + lineNum = node.lineno # type:ignore[attr-defined] elif 'pattern' in node.__dict__: # the node is a `case ...:` - lineNum = node.pattern.lineno - else: - lineNum = node.lineno # this could fail - except Exception: + lineNum = node.pattern.lineno # type:ignore[attr-defined] + else: # fallback case, but this could still fail + lineNum = node.lineno # type:ignore[attr-defined] + except AttributeError: # if `node` doesn't have any of those attributes lineNum = -1 return lineNum diff --git a/pydoclint/utils/unparser_custom.py b/pydoclint/utils/unparser_custom.py index 7f6a910..df6cf33 100644 --- a/pydoclint/utils/unparser_custom.py +++ b/pydoclint/utils/unparser_custom.py @@ -3,7 +3,6 @@ import sys from typing import Optional, Union -from pydoclint.utils.astTypes import AnnotationType from pydoclint.utils.edge_case_error import EdgeCaseError @@ -32,7 +31,7 @@ def py311unparse(astObj: ast.AST) -> str: def unparseName( - node: Union[AnnotationType, ast.Module, None], + node: Union[ast.expr, ast.Module, None], ) -> Optional[str]: """Parse type annotations from argument list or return annotation.""" if node is None: diff --git a/pydoclint/utils/visitor_helper.py b/pydoclint/utils/visitor_helper.py index 2b193a9..743bc7b 100644 --- a/pydoclint/utils/visitor_helper.py +++ b/pydoclint/utils/visitor_helper.py @@ -170,7 +170,7 @@ def extractClassAttributesFromNode( atl.append( Arg( name=itm.name, - typeHint=unparseName(itm.returns), + typeHint=unparseName(itm.returns), # type:ignore[arg-type] ) ) @@ -355,7 +355,7 @@ def checkReturnTypesForNumpyStyle( returnAnnoItems: List[str] = returnAnnotation.decompose() returnAnnoInList: List[str] = returnAnnotation.putAnnotationInList() - returnSecTypes: List[str] = [stripQuotes(_.argType) for _ in returnSection] + returnSecTypes: List[str] = [stripQuotes(_.argType) for _ in returnSection] # type:ignore[misc] if returnAnnoInList != returnSecTypes: if len(returnAnnoItems) != len(returnSection): @@ -389,7 +389,7 @@ def checkReturnTypesForGoogleOrSphinxStyle( # use one compound style for tuples. if len(returnSection) > 0: - retArgType: str = stripQuotes(returnSection[0].argType) + retArgType: str = stripQuotes(returnSection[0].argType) # type:ignore[assignment] if returnAnnotation.annotation is None: msg = 'Return annotation has 0 type(s); docstring' msg += ' return section has 1 type(s).' @@ -426,7 +426,9 @@ def checkYieldTypesForViolations( # to check and less ambiguous. returnAnnoText: Optional[str] = returnAnnotation.annotation - yieldType: str = extractYieldTypeFromGeneratorOrIteratorAnnotation( + + extract = extractYieldTypeFromGeneratorOrIteratorAnnotation + yieldType: Optional[str] = extract( returnAnnoText, hasGeneratorAsReturnAnnotation, hasIteratorOrIterableAsReturnAnnotation, @@ -467,23 +469,24 @@ def checkYieldTypesForViolations( def extractYieldTypeFromGeneratorOrIteratorAnnotation( - returnAnnoText: str, + returnAnnoText: Optional[str], hasGeneratorAsReturnAnnotation: bool, hasIteratorOrIterableAsReturnAnnotation: bool, -) -> str: +) -> Optional[str]: """Extract yield type from Generator or Iterator annotations""" - try: - # "Yield type" is the 0th element in a Generator - # type annotation (Generator[YieldType, SendType, - # ReturnType]) - # https://docs.python.org/3/library/typing.html#typing.Generator - # Or it's the 0th (only) element in Iterator - yieldType: str + # + # "Yield type" is the 0th element in a Generator + # type annotation (Generator[YieldType, SendType, + # ReturnType]) + # https://docs.python.org/3/library/typing.html#typing.Generator + # Or it's the 0th (only) element in Iterator + yieldType: Optional[str] + try: if hasGeneratorAsReturnAnnotation: if sys.version_info >= (3, 9): yieldType = unparseName( - ast.parse(returnAnnoText).body[0].value.slice.elts[0] + ast.parse(returnAnnoText).body[0].value.slice.elts[0] # type:ignore[attr-defined,arg-type] ) else: yieldType = unparseName( @@ -491,7 +494,7 @@ def extractYieldTypeFromGeneratorOrIteratorAnnotation( ) elif hasIteratorOrIterableAsReturnAnnotation: yieldType = unparseName( - ast.parse(returnAnnoText).body[0].value.slice + ast.parse(returnAnnoText).body[0].value.slice # type:ignore[attr-defined,arg-type] ) else: yieldType = returnAnnoText @@ -501,17 +504,20 @@ def extractYieldTypeFromGeneratorOrIteratorAnnotation( return stripQuotes(yieldType) -def extractReturnTypeFromGenerator(returnAnnoText: str) -> str: +def extractReturnTypeFromGenerator( + returnAnnoText: Optional[str], +) -> Optional[str]: """Extract return type from Generator annotations""" + # + # "Return type" is the last element in a Generator + # type annotation (Generator[YieldType, SendType, + # ReturnType]) + # https://docs.python.org/3/library/typing.html#typing.Generator + returnType: Optional[str] try: - # "Return type" is the last element in a Generator - # type annotation (Generator[YieldType, SendType, - # ReturnType]) - # https://docs.python.org/3/library/typing.html#typing.Generator - returnType: str if sys.version_info >= (3, 9): returnType = unparseName( - ast.parse(returnAnnoText).body[0].value.slice.elts[-1] + ast.parse(returnAnnoText).body[0].value.slice.elts[-1] # type:ignore[attr-defined,arg-type] ) else: returnType = unparseName( diff --git a/pydoclint/utils/walk.py b/pydoclint/utils/walk.py index 0f809b4..28baf89 100644 --- a/pydoclint/utils/walk.py +++ b/pydoclint/utils/walk.py @@ -17,29 +17,33 @@ """ import ast from collections import deque +from typing import Deque, Generator, Tuple -def walk(node): +def walk(node: ast.AST) -> Generator[Tuple[ast.AST, ast.AST], None, None]: """ Recursively yield all descendant nodes in the tree starting at *node* (including *node* itself), in no specified order. This is useful if you only want to modify nodes in place and don't care about the context. """ - todo = deque([(node, None)]) + emptyNode: ast.AST = ast.Pass() + todo: Deque[Tuple[ast.AST, ast.AST]] = deque([(node, emptyNode)]) while todo: node, parent = todo.popleft() todo.extend(iter_child_nodes(node)) yield node, parent -def walk_dfs(node): +def walk_dfs(node: ast.AST) -> Generator[Tuple[ast.AST, ast.AST], None, None]: """Depth-first traversal of AST. Modified from `walk()` in this file""" for child, parent in iter_child_nodes(node): yield child, parent yield from walk_dfs(child) -def iter_child_nodes(node): +def iter_child_nodes( + node: ast.AST, +) -> Generator[Tuple[ast.AST, ast.AST], None, None]: """ Yield all direct child nodes of *node*, that is, all fields that are nodes and all items of fields that are lists of nodes. diff --git a/pydoclint/visitor.py b/pydoclint/visitor.py index acdaedc..eb94c7e 100644 --- a/pydoclint/visitor.py +++ b/pydoclint/visitor.py @@ -1,5 +1,5 @@ import ast -from typing import List, Optional +from typing import List, Optional, Union from pydoclint.utils.arg import Arg, ArgList from pydoclint.utils.astTypes import FuncOrAsyncFuncDef @@ -90,10 +90,10 @@ def __init__( requireYieldSectionWhenYieldingNothing ) - self.parent: Optional[ast.AST] = None # keep track of parent node + self.parent: ast.AST = ast.Pass() # keep track of parent node self.violations: List[Violation] = [] - def visit_ClassDef(self, node: ast.ClassDef): # noqa: D102 + def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: D102 currentParent = self.parent # keep aside self.parent = node @@ -120,8 +120,8 @@ def visit_ClassDef(self, node: ast.ClassDef): # noqa: D102 self.parent = currentParent # restore - def visit_FunctionDef(self, node: FuncOrAsyncFuncDef): # noqa: D102 - parent_ = self.parent # keep aside + def visit_FunctionDef(self, node: FuncOrAsyncFuncDef) -> None: # noqa: D102 + parent_: Union[ast.ClassDef, FuncOrAsyncFuncDef] = self.parent # type:ignore[assignment] self.parent = node isClassConstructor: bool = node.name == '__init__' and isinstance( @@ -133,6 +133,7 @@ def visit_FunctionDef(self, node: FuncOrAsyncFuncDef): # noqa: D102 self.isAbstractMethod = checkIsAbstractMethod(node) if isClassConstructor: + assert isinstance(parent_, ast.ClassDef) # to help mypy know type docstring = self._checkClassDocstringAndConstructorDocstrings( node=node, parent_=parent_, @@ -205,7 +206,8 @@ def visit_FunctionDef(self, node: FuncOrAsyncFuncDef): # noqa: D102 # different for class constructors. returnViolations = ( self.checkReturnsAndYieldsInClassConstructor( - parent=parent_, doc=doc + parent=parent_, # type: ignore[arg-type] + doc=doc, ) ) @@ -218,11 +220,11 @@ def visit_FunctionDef(self, node: FuncOrAsyncFuncDef): # noqa: D102 self.parent = parent_ # restore - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): # noqa: D102 + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: D102 # Treat async functions similarly to regular ones self.visit_FunctionDef(node) - def visit_Raise(self, node: ast.Raise): # noqa: D102 + def visit_Raise(self, node: ast.Raise) -> None: # noqa: D102 self.generic_visit(node) def _checkClassDocstringAndConstructorDocstrings( # noqa: C901 @@ -618,7 +620,8 @@ def checkYields( # noqa: C901 returnAnno = ReturnAnnotation(None) if not docstringHasYieldsSection: - yieldType: str = extractYieldTypeFromGeneratorOrIteratorAnnotation( + extract = extractYieldTypeFromGeneratorOrIteratorAnnotation + yieldType: Optional[str] = extract( returnAnnoText=returnAnno.annotation, hasGeneratorAsReturnAnnotation=hasGenAsRetAnno, hasIteratorOrIterableAsReturnAnnotation=hasIterAsRetAnno, @@ -722,6 +725,7 @@ def my_function(num: int) -> Generator[int, None, str]: returnSec: List[ReturnArg] = doc.returnSection # Check the return section in the docstring + retTypeInGenerator: Optional[str] if not docstringHasReturnSection: if doc.isShortDocstring and self.skipCheckingShortDocstrings: pass @@ -735,7 +739,7 @@ def my_function(num: int) -> Generator[int, None, str]: # fmt: on ): - retTypeInGenerator: str = extractReturnTypeFromGenerator( + retTypeInGenerator = extractReturnTypeFromGenerator( returnAnnoText=returnAnno.annotation, ) # If "Generator[...]" is put in the return type annotation, @@ -748,7 +752,7 @@ def my_function(num: int) -> Generator[int, None, str]: else: if self.checkReturnTypes: if hasGenAsRetAnno: - retTypeInGenerator: str = extractReturnTypeFromGenerator( + retTypeInGenerator = extractReturnTypeFromGenerator( returnAnnoText=returnAnno.annotation, ) checkReturnTypesForViolations( @@ -775,7 +779,7 @@ def my_function(num: int) -> Generator[int, None, str]: if hasGenAsRetAnno or hasIterAsRetAnno: extract = extractYieldTypeFromGeneratorOrIteratorAnnotation - yieldType: str = extract( + yieldType: Optional[str] = extract( returnAnnoText=returnAnno.annotation, hasGeneratorAsReturnAnnotation=hasGenAsRetAnno, hasIteratorOrIterableAsReturnAnnotation=hasIterAsRetAnno, diff --git a/pyproject.toml b/pyproject.toml index 473ad06..46a3773 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,7 @@ style = 'numpy' exclude = '\.git|.?venv|\.tox|tests/data|build' require-return-section-when-returning-nothing = true check-class-attributes = false + +[tool.mypy] +strict = true +exclude = "pydoclint/utils/unparser.py" diff --git a/tox.ini b/tox.ini index 93963c2..c5d39f0 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,7 @@ envlist = py310 py311 py312 + mypy cercis check-self flake8-basic @@ -14,9 +15,16 @@ envlist = [gh-actions] python = - 3.9: py39, cercis, check-self, flake8-basic, flake8-misc, flake8-docstrings, pre-commit - 3.10: py310, cercis, check-self, flake8-basic, flake8-misc, flake8-docstrings, pre-commit - 3.11: py311, cercis, check-self, flake8-basic, flake8-misc, flake8-docstrings, pre-commit + 3.9: py39, mypy, cercis, check-self, flake8-basic, flake8-misc, flake8-docstrings, pre-commit + 3.10: py310, mypy, cercis, check-self, flake8-basic, flake8-misc, flake8-docstrings, pre-commit + 3.11: py311, mypy, cercis, check-self, flake8-basic, flake8-misc, flake8-docstrings, pre-commit + + +[testenv:mypy] +deps = + mypy +commands = + mypy pydoclint/ [testenv:cercis]